Respair committed
Commit 59b7eeb · verified · Parent(s): 5711454

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitattributes +33 -0
  2. LICENSE +21 -0
  3. README.md +88 -0
  4. UniSpeech/.gitignore +3 -0
  5. UniSpeech/ILS-SSL/README.md +63 -0
  6. UniSpeech/LICENSE +74 -0
  7. UniSpeech/README.md +148 -0
  8. UniSpeech/SECURITY.md +41 -0
  9. UniSpeech/UniSpeech-SAT/README.md +76 -0
  10. UniSpeech/UniSpeech-SAT/UniSpeech_SAT_SUPERB_Results.png +3 -0
  11. UniSpeech/UniSpeech/README.md +56 -0
  12. UniSpeech/WavLM/README.md +125 -0
  13. UniSpeech/WavLM/WavLM.py +743 -0
  14. UniSpeech/WavLM/WavLM_ASR.PNG +3 -0
  15. UniSpeech/WavLM/WavLM_SUPERB_Leaderboard.png +3 -0
  16. UniSpeech/WavLM/WavLM_SUPERB_Results.png +3 -0
  17. UniSpeech/WavLM/modules.py +827 -0
  18. UniSpeech/azure-pipelines.yml +47 -0
  19. UniSpeech/downstreams/speaker_diarization/README.md +34 -0
  20. UniSpeech/downstreams/speaker_diarization/config/infer_est_nspk1.yaml +31 -0
  21. UniSpeech/downstreams/speaker_diarization/config/unispeech_sat.th +0 -0
  22. UniSpeech/downstreams/speaker_diarization/diarization.py +321 -0
  23. UniSpeech/downstreams/speaker_diarization/models/models.py +391 -0
  24. UniSpeech/downstreams/speaker_diarization/models/transformer.py +147 -0
  25. UniSpeech/downstreams/speaker_diarization/models/utils.py +78 -0
  26. UniSpeech/downstreams/speaker_diarization/requirements.txt +136 -0
  27. UniSpeech/downstreams/speaker_diarization/tmp/mix_0000496.wav +1 -0
  28. UniSpeech/downstreams/speaker_diarization/utils/dataset.py +484 -0
  29. UniSpeech/downstreams/speaker_diarization/utils/kaldi_data.py +162 -0
  30. UniSpeech/downstreams/speaker_diarization/utils/parse_options.sh +97 -0
  31. UniSpeech/downstreams/speaker_diarization/utils/utils.py +189 -0
  32. UniSpeech/downstreams/speaker_verification/README.md +49 -0
  33. UniSpeech/downstreams/speaker_verification/config/unispeech_sat.th +0 -0
  34. UniSpeech/downstreams/speaker_verification/models/__init__.py +0 -0
  35. UniSpeech/downstreams/speaker_verification/models/ecapa_tdnn.py +301 -0
  36. UniSpeech/downstreams/speaker_verification/models/utils.py +78 -0
  37. UniSpeech/downstreams/speaker_verification/requirements.txt +85 -0
  38. UniSpeech/downstreams/speaker_verification/verification.py +67 -0
  39. UniSpeech/downstreams/speaker_verification/vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav +3 -0
  40. UniSpeech/downstreams/speaker_verification/vox1_data/David_Faustino/xTOk1Jz-F_g_0000015.wav +3 -0
  41. UniSpeech/downstreams/speaker_verification/vox1_data/Josh_Gad/HXUqYaOwrxA_0000015.wav +3 -0
  42. UniSpeech/downstreams/speaker_verification/vox1_data/Josh_Gad/RFyw7V3SOnQ_0000001.wav +3 -0
  43. UniSpeech/downstreams/speaker_verification/vox1_data/Lea_Thompson/HladKGyKTLM_0000006.wav +3 -0
  44. UniSpeech/downstreams/speaker_verification/vox1_data/Lea_Thompson/mHTAr5dlAgc_0000004.wav +3 -0
  45. UniSpeech/downstreams/speaker_verification/vox1_data/Zulay_Henao/WbB8m9-wlIQ_0000001.wav +3 -0
  46. UniSpeech/downstreams/speaker_verification/vox1_data/Zulay_Henao/gFfcgOVmiO0_0000002.wav +3 -0
  47. UniSpeech/src/CODE_OF_CONDUCT.md +77 -0
  48. UniSpeech/src/CONTRIBUTING.md +28 -0
  49. UniSpeech/src/LICENSE +21 -0
  50. UniSpeech/src/config/config.yaml +111 -0
.gitattributes CHANGED
@@ -33,3 +33,36 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/UniSpeech-SAT/UniSpeech_SAT_SUPERB_Results.png filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/WavLM/WavLM_ASR.PNG filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/WavLM/WavLM_SUPERB_Leaderboard.png filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/WavLM/WavLM_SUPERB_Results.png filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/downstreams/speaker_verification/vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/downstreams/speaker_verification/vox1_data/David_Faustino/xTOk1Jz-F_g_0000015.wav filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/downstreams/speaker_verification/vox1_data/Josh_Gad/HXUqYaOwrxA_0000015.wav filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/downstreams/speaker_verification/vox1_data/Josh_Gad/RFyw7V3SOnQ_0000001.wav filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/downstreams/speaker_verification/vox1_data/Lea_Thompson/HladKGyKTLM_0000006.wav filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/downstreams/speaker_verification/vox1_data/Lea_Thompson/mHTAr5dlAgc_0000004.wav filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/downstreams/speaker_verification/vox1_data/Zulay_Henao/WbB8m9-wlIQ_0000001.wav filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/downstreams/speaker_verification/vox1_data/Zulay_Henao/gFfcgOVmiO0_0000002.wav filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/examples/speaker_verification/vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/examples/speaker_verification/vox1_data/David_Faustino/xTOk1Jz-F_g_0000015.wav filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/examples/speaker_verification/vox1_data/Josh_Gad/HXUqYaOwrxA_0000015.wav filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/examples/speaker_verification/vox1_data/Josh_Gad/RFyw7V3SOnQ_0000001.wav filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/examples/speaker_verification/vox1_data/Lea_Thompson/HladKGyKTLM_0000006.wav filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/examples/speaker_verification/vox1_data/Lea_Thompson/mHTAr5dlAgc_0000004.wav filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/examples/speaker_verification/vox1_data/Zulay_Henao/WbB8m9-wlIQ_0000001.wav filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/examples/speaker_verification/vox1_data/Zulay_Henao/gFfcgOVmiO0_0000002.wav filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/examples/unispeech/data/en/pretrain_1350_16k.id filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/examples/unispeech/data/en/pretrain_1350_16k.tsv filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/examples/unispeech/data/es/pretrain_168_16k.phone filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/examples/unispeech/data/es/pretrain_168_16k_sep.id filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/examples/unispeech/data/es/pretrain_168_16k_share.id filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/examples/unispeech/data/fr/pretrain_353_16k.phone filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/examples/unispeech/data/fr/pretrain_353_16k.tsv filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/examples/unispeech/data/fr/pretrain_353_16k_sep.id filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/examples/unispeech/data/fr/pretrain_353_16k_share.id filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/examples/unispeech/data/it/pretrain_90_16k_sep.id filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/fairseq/data/data_utils_fast.cpython-36m-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ UniSpeech/src/fairseq/data/data_utils_fast.cpython-37m-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
+ audio_high_quality_lmao.txt filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 YE Zhen
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,88 @@
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2502.04128)
+
+ # X-Codec-2.0
+ Paper: LLaSA: Scaling Train-time and Inference-time Compute for LLaMA-based Speech Synthesis
+
+ **Update (2025-02-13):** Added the [Llasa finetune instructions](https://github.com/zhenye234/LLaSA_training/tree/main/finetune).
+
+ **Update (2025-02-07):** Our paper has been released!
+
+
+ ## Use directly on Hugging Face
+
+ **Codec**: [xcodec2](https://huggingface.co/HKUST-Audio/xcodec2) (Use `xcodec2==0.1.5` for codec inference and Llasa fine-tuning. I have removed unnecessary dependencies and it works fine in my testing, but other problems may still arise. If you prefer more stability, I recommend `xcodec2==0.1.3`, which is exactly aligned with my codec training.)
+
+
+ **Llasa-collections**: [Llasa-collections](https://huggingface.co/collections/HKUSTAudio/llasa-679b87dbd06ac556cc0e0f44)
+
+ ## Features
+
+ - **Single Vector Quantization**
+   - 65,536-entry codebook using Finite Scalar Quantization, achieving 99% codebook usage (comparable to text tokenizers; LLaMA 3 uses 128,256)
+   - 50 tokens per second from a single token stream
+
+ - **Multilingual Speech Semantic Support**
+   - Uses Wav2Vec2-BERT, a semantic encoder pre-trained on 4.5M hours of unlabeled audio data covering more than 143 languages.
+   - Codec trained on 150k hours of multilingual speech data, including Emilia (En/Zh/De/Fr/Ja/Ko) and MLS (En/Fr/De/Nl/Es/It/Pt/Pl).
+
+ - **High-Quality Speech Reconstruction**
+   - Transformer + Vocos decoder
+   - BigCodec encoder
+   - Spec discriminator with FFT sizes {78, 126, 206, 334, 542, 876, 1418, 2296}, tailored for the transformer decoder. [Details here](https://openreview.net/pdf?id=4YpMrGfldX)
+   - Achieves UTMOS 4.13, WER 2.47 (hubert-large-ls960-ft), SIM 0.82 (wavlm_large_finetune), STOI 0.92, PESQ-NB 3.05, and PESQ-WB 2.44 on LibriSpeech test-clean reconstruction (ground truth: WER 1.96, UTMOS 4.09)
+   - 16 kHz speech only
+
+
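+ For intuition, the bitrate implied by the numbers above follows directly from the codebook size and token rate: log2(65536) = 16 bits per token, and 50 tokens per second gives 800 bps. A minimal check (illustration only, not part of the repository code):
+ ```python
+ import math
+
+ codebook_size = 65536       # FSQ codebook entries
+ tokens_per_second = 50      # single token stream at 50 Hz
+
+ bits_per_token = math.log2(codebook_size)          # 16.0
+ bitrate_bps = tokens_per_second * bits_per_token   # 800.0
+ print(f"{bits_per_token:.0f} bits/token -> {bitrate_bps:.0f} bps")
+ ```
+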
+ ## Command-line Usage
+ ## Setup
+ Code is tested on `python3.9`.
+
+ Please follow these steps to set up your environment:
+ 1. Clone this repo
+ 2. `conda create --name xcodec2 python=3.9`
+ 3. `conda activate xcodec2`
+ 4. `pip install -r requirements.txt`
+ 5. [Download the pretrained checkpoint here](https://huggingface.co/HKUST-Audio/xcodec2/blob/main/ckpt/epoch%3D4-step%3D1400000.ckpt)
+
+
+ ## Inference
+ ```bash
+ python inference.py
+ ```
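+
+ The codec expects 16 kHz speech (see Features above), so audio at other sample rates should be resampled and mixed down to mono before inference. A minimal preprocessing sketch using `torchaudio`; the paths are placeholders and the loading logic inside `inference.py` itself is not shown in this diff:
+ ```python
+ import torchaudio
+
+ wav, sr = torchaudio.load("input.wav")           # shape: (channels, samples)
+ wav = wav.mean(dim=0, keepdim=True)              # mix down to mono
+ if sr != 16000:
+     wav = torchaudio.functional.resample(wav, orig_freq=sr, new_freq=16000)
+ torchaudio.save("input_16k.wav", wav, 16000)     # feed this file to the codec
+ ```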
+
+ ## Train
+ To train X-Codec-2.0, first prepare your data:
+
+ 1. Make a file list by running:
+ ```bash
+ python get_tsv.py
+ ```
+
+ 2. Train X-Codec-2.0 with the default settings by running:
+
+ ```bash
+ python train.py log_dir=/path/to/log_dir
+ ```
+
+ ## Large-scale training, batch inference, and large-scale code extraction
+
+ Batch inference:
+ ```bash
+ python inference_save_code.py
+ ```
+ Training:
+ ```bash
+ sbatch train_slurm.sh
+ ```
+
+ Code extraction:
+ ```bash
+ sbatch large_scale_save_code.sh
+ ```
+
+ Codes are saved in the output folder with the same subfolder structure as the audio files.
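+
+ As a sketch of what "the same subfolder structure" means in practice: the extracted code file for each audio file sits at the same relative path under the output folder. The snippet below only illustrates that path mapping; the directory names and the `.npy` extension are assumptions, not taken from `inference_save_code.py`:
+ ```python
+ from pathlib import Path
+
+ audio_root = Path("/data/audio")    # placeholder input tree
+ output_root = Path("/data/codes")   # placeholder output tree
+
+ for wav_path in audio_root.rglob("*.wav"):
+     rel = wav_path.relative_to(audio_root)                # e.g. spk1/utt001.wav
+     code_path = (output_root / rel).with_suffix(".npy")   # e.g. /data/codes/spk1/utt001.npy
+     code_path.parent.mkdir(parents=True, exist_ok=True)
+     # the extraction script would write the codes for wav_path here
+ ```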
+
+
+
+ ## Acknowledgement
+ I would like to extend special thanks to the authors of BigCodec, since our codebase is mainly borrowed from [BigCodec](https://github.com/Aria-K-Alethia/BigCodec).
UniSpeech/.gitignore ADDED
@@ -0,0 +1,3 @@
+ .idea/
+
+ __pycache__/
UniSpeech/ILS-SSL/README.md ADDED
@@ -0,0 +1,63 @@
1
+
2
+ # ILS-SSL
3
+
4
+ > [**ILS-SSL**](https://arxiv.org/pdf/2112.08778.pdf): Self-Supervised Learning for Speech Recognition with Intermediate Layer Supervision
5
+
6
+ The data preparation and pre-training for the first iteration follow the same pipeline as Hubert. We give example scripts for ILS-Hubert pre-training and fine-tuning in src/examples/hubert/scripts.
7
+
8
+ ## Pre-Trained and Fine-tuned Models
9
+ Model | Pretraining Dataset | Finetuning Dataset | Model
10
+ |---|---|---|---
11
+ ILS-Base | 960h LibriSpeech | - | [Download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/el_hubert_4_12/checkpoint_best.pt?st=2022-01-04T08%3A05%3A24Z&se=2024-01-05T08%3A05%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=JI8ZOgBhrrKUY4DE2ommnKpyAUuX6OrHfWgdjAT2Xnc%3D)
12
+ ILS-Large | 60k hrs Libri-Light | - | [Download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/ils_hubert_large/checkpoint_fixed.pt?st=2022-01-04T08%3A24%3A37Z&se=2025-01-05T08%3A24%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=Dv6svAaI7Td%2BZWUTjTFkhChFbpnAAU6xKNjPbPQnIKM%3D)
13
+ ILS-Large | 60k hrs Libri-Light | 960h LibriSpeech | [Download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/ils_hubert_large/checkpoint_ft.pt?st=2022-01-04T08%3A40%3A17Z&se=2025-01-05T08%3A40%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=GKIe%2F1kz%2F1fjGTsQsakJy68jlsFDbKmIVYjH61dhrwA%3D)
14
+
15
+
16
+ ## Results on Librispeech
17
+ Base Model | Finetuning set| LM | test-clean | test-other
18
+ |---|---|---|---|---
19
+ wav2vec2.0 | 1 hour | None | 24.5 | 29.7
20
+ Hubert | 1 hour | None| 20.9 | 27.5
21
+ ILS-SSL | 1 hour | None | 17.9 | 23.1
22
+ wav2vec2.0 | 1 hour | 4-gram | 5.5 | 11.3
23
+ Hubert | 1 hour | 4-gram | 6.1 | 11.3
24
+ ILS-SSL | 1 hour | 4-gram | 5.4 | 10.2
25
+ wav2vec2.0 | 10 hour | None | 11.1 | 17.6
26
+ Hubert | 10 hour | None| 10.1 | 16.8
27
+ ILS-SSL | 10 hour | None | 8.3 | 13.6
28
+ wav2vec2.0 | 10 hour | 4-gram | 4.3 | 9.5
29
+ Hubert | 10 hour | 4-gram | 4.3 | 9.4
30
+ ILS-SSL | 10 hour | 4-gram | 3.8 | 8.1
31
+ wav2vec2.0 | 100 hour | None | 6.1 | 13.3
32
+ Hubert | 100 hour | None| 6.3 | 13.2
33
+ ILS-SSL | 100 hour | None | 4.7 | 10.1
34
+ wav2vec2.0 | 100 hour | 4-gram | 3.4| 8.0
35
+ Hubert | 100 hour | 4-gram | 3.4 | 8.1
36
+ ILS-SSL | 100 hour | 4-gram | 3.0 | 6.9
37
+
38
+ Large Model | Finetuning set| LM | test-clean | test-other
39
+ |---|---|---|---|---
40
+ wav2vec2.0 | 1 hour | None | 17.2 | 20.3
41
+ Hubert | 1 hour | None| 17.4 | 20.3
42
+ ILS-SSL | 1 hour | None | 14.3 | 16.9
43
+ wav2vec2.0 | 1 hour | Transf | 2.9 | 5.8
44
+ Hubert | 1 hour | Transf | 2.9 | 5.4
45
+ ILS-SSL | 1 hour | Transf | 2.8 | 5.3
46
+ wav2vec2.0 | 10 hour | None | 6.3 | 10.0
47
+ Hubert | 10 hour | None | 6.2 | 9.6
48
+ ILS-SSL | 10 hour | None | 6.1 | 9.1
49
+ wav2vec2.0 | 10 hour | Transf | 2.6 | 4.9
50
+ Hubert | 10 hour | Transf | 2.4 | 4.6
51
+ ILS-SSL | 10 hour | Transf | 2.5 | 4.5
52
+ wav2vec2.0 | 100 hour | None | 3.1 | 6.3
53
+ Hubert | 100 hour | None| 2.9 | 6.0
54
+ ILS-SSL | 100 hour | None | 2.9 | 5.8
55
+ wav2vec2.0 | 100 hour | Transf | 2.0 | 4.0
56
+ Hubert | 100 hour | Transf | 2.1 | 3.9
57
+ ILS-SSL | 100 hour | Transf | 2.0 | 4.0
58
+ wav2vec2.0 | 960 hour | None | 2.2 | 4.5
59
+ Hubert | 960 hour | None | 2.1 | 4.3
60
+ ILS-SSL | 960 hour | None | 1.9 | 3.8
61
+ wav2vec2.0 | 960 hour | Transf | 1.8 | 3.3
62
+ Hubert | 960 hour | Transf | 1.9 | 3.3
63
+ ILS-SSL | 960 hour | Transf | 1.8 | 3.2
UniSpeech/LICENSE ADDED
@@ -0,0 +1,74 @@
1
+ Attribution-ShareAlike 3.0 Unported
2
+
3
+ CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE LEGAL SERVICES. DISTRIBUTION OF THIS LICENSE DOES NOT CREATE AN ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES REGARDING THE INFORMATION PROVIDED, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM ITS USE.
4
+
5
+ License
6
+
7
+ THE WORK (AS DEFINED BELOW) IS PROVIDED UNDER THE TERMS OF THIS CREATIVE COMMONS PUBLIC LICENSE ("CCPL" OR "LICENSE"). THE WORK IS PROTECTED BY COPYRIGHT AND/OR OTHER APPLICABLE LAW. ANY USE OF THE WORK OTHER THAN AS AUTHORIZED UNDER THIS LICENSE OR COPYRIGHT LAW IS PROHIBITED.
8
+
9
+ BY EXERCISING ANY RIGHTS TO THE WORK PROVIDED HERE, YOU ACCEPT AND AGREE TO BE BOUND BY THE TERMS OF THIS LICENSE. TO THE EXTENT THIS LICENSE MAY BE CONSIDERED TO BE A CONTRACT, THE LICENSOR GRANTS YOU THE RIGHTS CONTAINED HERE IN CONSIDERATION OF YOUR ACCEPTANCE OF SUCH TERMS AND CONDITIONS.
10
+
11
+ 1. Definitions
12
+
13
+ "Adaptation" means a work based upon the Work, or upon the Work and other pre-existing works, such as a translation, adaptation, derivative work, arrangement of music or other alterations of a literary or artistic work, or phonogram or performance and includes cinematographic adaptations or any other form in which the Work may be recast, transformed, or adapted including in any form recognizably derived from the original, except that a work that constitutes a Collection will not be considered an Adaptation for the purpose of this License. For the avoidance of doubt, where the Work is a musical work, performance or phonogram, the synchronization of the Work in timed-relation with a moving image ("synching") will be considered an Adaptation for the purpose of this License.
14
+ "Collection" means a collection of literary or artistic works, such as encyclopedias and anthologies, or performances, phonograms or broadcasts, or other works or subject matter other than works listed in Section 1(f) below, which, by reason of the selection and arrangement of their contents, constitute intellectual creations, in which the Work is included in its entirety in unmodified form along with one or more other contributions, each constituting separate and independent works in themselves, which together are assembled into a collective whole. A work that constitutes a Collection will not be considered an Adaptation (as defined below) for the purposes of this License.
15
+ "Creative Commons Compatible License" means a license that is listed at https://creativecommons.org/compatiblelicenses that has been approved by Creative Commons as being essentially equivalent to this License, including, at a minimum, because that license: (i) contains terms that have the same purpose, meaning and effect as the License Elements of this License; and, (ii) explicitly permits the relicensing of adaptations of works made available under that license under this License or a Creative Commons jurisdiction license with the same License Elements as this License.
16
+ "Distribute" means to make available to the public the original and copies of the Work or Adaptation, as appropriate, through sale or other transfer of ownership.
17
+ "License Elements" means the following high-level license attributes as selected by Licensor and indicated in the title of this License: Attribution, ShareAlike.
18
+ "Licensor" means the individual, individuals, entity or entities that offer(s) the Work under the terms of this License.
19
+ "Original Author" means, in the case of a literary or artistic work, the individual, individuals, entity or entities who created the Work or if no individual or entity can be identified, the publisher; and in addition (i) in the case of a performance the actors, singers, musicians, dancers, and other persons who act, sing, deliver, declaim, play in, interpret or otherwise perform literary or artistic works or expressions of folklore; (ii) in the case of a phonogram the producer being the person or legal entity who first fixes the sounds of a performance or other sounds; and, (iii) in the case of broadcasts, the organization that transmits the broadcast.
20
+ "Work" means the literary and/or artistic work offered under the terms of this License including without limitation any production in the literary, scientific and artistic domain, whatever may be the mode or form of its expression including digital form, such as a book, pamphlet and other writing; a lecture, address, sermon or other work of the same nature; a dramatic or dramatico-musical work; a choreographic work or entertainment in dumb show; a musical composition with or without words; a cinematographic work to which are assimilated works expressed by a process analogous to cinematography; a work of drawing, painting, architecture, sculpture, engraving or lithography; a photographic work to which are assimilated works expressed by a process analogous to photography; a work of applied art; an illustration, map, plan, sketch or three-dimensional work relative to geography, topography, architecture or science; a performance; a broadcast; a phonogram; a compilation of data to the extent it is protected as a copyrightable work; or a work performed by a variety or circus performer to the extent it is not otherwise considered a literary or artistic work.
21
+ "You" means an individual or entity exercising rights under this License who has not previously violated the terms of this License with respect to the Work, or who has received express permission from the Licensor to exercise rights under this License despite a previous violation.
22
+ "Publicly Perform" means to perform public recitations of the Work and to communicate to the public those public recitations, by any means or process, including by wire or wireless means or public digital performances; to make available to the public Works in such a way that members of the public may access these Works from a place and at a place individually chosen by them; to perform the Work to the public by any means or process and the communication to the public of the performances of the Work, including by public digital performance; to broadcast and rebroadcast the Work by any means including signs, sounds or images.
23
+ "Reproduce" means to make copies of the Work by any means including without limitation by sound or visual recordings and the right of fixation and reproducing fixations of the Work, including storage of a protected performance or phonogram in digital form or other electronic medium.
24
+
25
+ 2. Fair Dealing Rights. Nothing in this License is intended to reduce, limit, or restrict any uses free from copyright or rights arising from limitations or exceptions that are provided for in connection with the copyright protection under copyright law or other applicable laws.
26
+
27
+ 3. License Grant. Subject to the terms and conditions of this License, Licensor hereby grants You a worldwide, royalty-free, non-exclusive, perpetual (for the duration of the applicable copyright) license to exercise the rights in the Work as stated below:
28
+
29
+ to Reproduce the Work, to incorporate the Work into one or more Collections, and to Reproduce the Work as incorporated in the Collections;
30
+ to create and Reproduce Adaptations provided that any such Adaptation, including any translation in any medium, takes reasonable steps to clearly label, demarcate or otherwise identify that changes were made to the original Work. For example, a translation could be marked "The original work was translated from English to Spanish," or a modification could indicate "The original work has been modified.";
31
+ to Distribute and Publicly Perform the Work including as incorporated in Collections; and,
32
+ to Distribute and Publicly Perform Adaptations.
33
+
34
+ For the avoidance of doubt:
35
+ Non-waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme cannot be waived, the Licensor reserves the exclusive right to collect such royalties for any exercise by You of the rights granted under this License;
36
+ Waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme can be waived, the Licensor waives the exclusive right to collect such royalties for any exercise by You of the rights granted under this License; and,
37
+ Voluntary License Schemes. The Licensor waives the right to collect royalties, whether individually or, in the event that the Licensor is a member of a collecting society that administers voluntary licensing schemes, via that society, from any exercise by You of the rights granted under this License.
38
+
39
+ The above rights may be exercised in all media and formats whether now known or hereafter devised. The above rights include the right to make such modifications as are technically necessary to exercise the rights in other media and formats. Subject to Section 8(f), all rights not expressly granted by Licensor are hereby reserved.
40
+
41
+ 4. Restrictions. The license granted in Section 3 above is expressly made subject to and limited by the following restrictions:
42
+
43
+ You may Distribute or Publicly Perform the Work only under the terms of this License. You must include a copy of, or the Uniform Resource Identifier (URI) for, this License with every copy of the Work You Distribute or Publicly Perform. You may not offer or impose any terms on the Work that restrict the terms of this License or the ability of the recipient of the Work to exercise the rights granted to that recipient under the terms of the License. You may not sublicense the Work. You must keep intact all notices that refer to this License and to the disclaimer of warranties with every copy of the Work You Distribute or Publicly Perform. When You Distribute or Publicly Perform the Work, You may not impose any effective technological measures on the Work that restrict the ability of a recipient of the Work from You to exercise the rights granted to that recipient under the terms of the License. This Section 4(a) applies to the Work as incorporated in a Collection, but this does not require the Collection apart from the Work itself to be made subject to the terms of this License. If You create a Collection, upon notice from any Licensor You must, to the extent practicable, remove from the Collection any credit as required by Section 4(c), as requested. If You create an Adaptation, upon notice from any Licensor You must, to the extent practicable, remove from the Adaptation any credit as required by Section 4(c), as requested.
44
+ You may Distribute or Publicly Perform an Adaptation only under the terms of: (i) this License; (ii) a later version of this License with the same License Elements as this License; (iii) a Creative Commons jurisdiction license (either this or a later license version) that contains the same License Elements as this License (e.g., Attribution-ShareAlike 3.0 US)); (iv) a Creative Commons Compatible License. If you license the Adaptation under one of the licenses mentioned in (iv), you must comply with the terms of that license. If you license the Adaptation under the terms of any of the licenses mentioned in (i), (ii) or (iii) (the "Applicable License"), you must comply with the terms of the Applicable License generally and the following provisions: (I) You must include a copy of, or the URI for, the Applicable License with every copy of each Adaptation You Distribute or Publicly Perform; (II) You may not offer or impose any terms on the Adaptation that restrict the terms of the Applicable License or the ability of the recipient of the Adaptation to exercise the rights granted to that recipient under the terms of the Applicable License; (III) You must keep intact all notices that refer to the Applicable License and to the disclaimer of warranties with every copy of the Work as included in the Adaptation You Distribute or Publicly Perform; (IV) when You Distribute or Publicly Perform the Adaptation, You may not impose any effective technological measures on the Adaptation that restrict the ability of a recipient of the Adaptation from You to exercise the rights granted to that recipient under the terms of the Applicable License. This Section 4(b) applies to the Adaptation as incorporated in a Collection, but this does not require the Collection apart from the Adaptation itself to be made subject to the terms of the Applicable License.
45
+ If You Distribute, or Publicly Perform the Work or any Adaptations or Collections, You must, unless a request has been made pursuant to Section 4(a), keep intact all copyright notices for the Work and provide, reasonable to the medium or means You are utilizing: (i) the name of the Original Author (or pseudonym, if applicable) if supplied, and/or if the Original Author and/or Licensor designate another party or parties (e.g., a sponsor institute, publishing entity, journal) for attribution ("Attribution Parties") in Licensor's copyright notice, terms of service or by other reasonable means, the name of such party or parties; (ii) the title of the Work if supplied; (iii) to the extent reasonably practicable, the URI, if any, that Licensor specifies to be associated with the Work, unless such URI does not refer to the copyright notice or licensing information for the Work; and (iv) , consistent with Ssection 3(b), in the case of an Adaptation, a credit identifying the use of the Work in the Adaptation (e.g., "French translation of the Work by Original Author," or "Screenplay based on original Work by Original Author"). The credit required by this Section 4(c) may be implemented in any reasonable manner; provided, however, that in the case of a Adaptation or Collection, at a minimum such credit will appear, if a credit for all contributing authors of the Adaptation or Collection appears, then as part of these credits and in a manner at least as prominent as the credits for the other contributing authors. For the avoidance of doubt, You may only use the credit required by this Section for the purpose of attribution in the manner set out above and, by exercising Your rights under this License, You may not implicitly or explicitly assert or imply any connection with, sponsorship or endorsement by the Original Author, Licensor and/or Attribution Parties, as appropriate, of You or Your use of the Work, without the separate, express prior written permission of the Original Author, Licensor and/or Attribution Parties.
46
+ Except as otherwise agreed in writing by the Licensor or as may be otherwise permitted by applicable law, if You Reproduce, Distribute or Publicly Perform the Work either by itself or as part of any Adaptations or Collections, You must not distort, mutilate, modify or take other derogatory action in relation to the Work which would be prejudicial to the Original Author's honor or reputation. Licensor agrees that in those jurisdictions (e.g. Japan), in which any exercise of the right granted in Section 3(b) of this License (the right to make Adaptations) would be deemed to be a distortion, mutilation, modification or other derogatory action prejudicial to the Original Author's honor and reputation, the Licensor will waive or not assert, as appropriate, this Section, to the fullest extent permitted by the applicable national law, to enable You to reasonably exercise Your right under Section 3(b) of this License (right to make Adaptations) but not otherwise.
47
+
48
+ 5. Representations, Warranties and Disclaimer
49
+
50
+ UNLESS OTHERWISE MUTUALLY AGREED TO BY THE PARTIES IN WRITING, LICENSOR OFFERS THE WORK AS-IS AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE WORK, EXPRESS, IMPLIED, STATUTORY OR OTHERWISE, INCLUDING, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTIBILITY, FITNESS FOR A PARTICULAR PURPOSE, NONINFRINGEMENT, OR THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OF ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED WARRANTIES, SO SUCH EXCLUSION MAY NOT APPLY TO YOU.
51
+
52
+ 6. Limitation on Liability. EXCEPT TO THE EXTENT REQUIRED BY APPLICABLE LAW, IN NO EVENT WILL LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY SPECIAL, INCIDENTAL, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES ARISING OUT OF THIS LICENSE OR THE USE OF THE WORK, EVEN IF LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
53
+
54
+ 7. Termination
55
+
56
+ This License and the rights granted hereunder will terminate automatically upon any breach by You of the terms of this License. Individuals or entities who have received Adaptations or Collections from You under this License, however, will not have their licenses terminated provided such individuals or entities remain in full compliance with those licenses. Sections 1, 2, 5, 6, 7, and 8 will survive any termination of this License.
57
+ Subject to the above terms and conditions, the license granted here is perpetual (for the duration of the applicable copyright in the Work). Notwithstanding the above, Licensor reserves the right to release the Work under different license terms or to stop distributing the Work at any time; provided, however that any such election will not serve to withdraw this License (or any other license that has been, or is required to be, granted under the terms of this License), and this License will continue in full force and effect unless terminated as stated above.
58
+
59
+ 8. Miscellaneous
60
+
61
+ Each time You Distribute or Publicly Perform the Work or a Collection, the Licensor offers to the recipient a license to the Work on the same terms and conditions as the license granted to You under this License.
62
+ Each time You Distribute or Publicly Perform an Adaptation, Licensor offers to the recipient a license to the original Work on the same terms and conditions as the license granted to You under this License.
63
+ If any provision of this License is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this License, and without further action by the parties to this agreement, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable.
64
+ No term or provision of this License shall be deemed waived and no breach consented to unless such waiver or consent shall be in writing and signed by the party to be charged with such waiver or consent.
65
+ This License constitutes the entire agreement between the parties with respect to the Work licensed here. There are no understandings, agreements or representations with respect to the Work not specified here. Licensor shall not be bound by any additional provisions that may appear in any communication from You. This License may not be modified without the mutual written agreement of the Licensor and You.
66
+ The rights granted under, and the subject matter referenced, in this License were drafted utilizing the terminology of the Berne Convention for the Protection of Literary and Artistic Works (as amended on September 28, 1979), the Rome Convention of 1961, the WIPO Copyright Treaty of 1996, the WIPO Performances and Phonograms Treaty of 1996 and the Universal Copyright Convention (as revised on July 24, 1971). These rights and subject matter take effect in the relevant jurisdiction in which the License terms are sought to be enforced according to the corresponding provisions of the implementation of those treaty provisions in the applicable national law. If the standard suite of rights granted under applicable copyright law includes additional rights not granted under this License, such additional rights are deemed to be included in the License; this License is not intended to restrict the license of any rights under applicable law.
67
+
68
+ Creative Commons Notice
69
+
70
+ Creative Commons is not a party to this License, and makes no warranty whatsoever in connection with the Work. Creative Commons will not be liable to You or any party on any legal theory for any damages whatsoever, including without limitation any general, special, incidental or consequential damages arising in connection to this license. Notwithstanding the foregoing two (2) sentences, if Creative Commons has expressly identified itself as the Licensor hereunder, it shall have all rights and obligations of Licensor.
71
+
72
+ Except for the limited purpose of indicating to the public that the Work is licensed under the CCPL, Creative Commons does not authorize the use by either party of the trademark "Creative Commons" or any related trademark or logo of Creative Commons without the prior written consent of Creative Commons. Any permitted use will be in compliance with Creative Commons' then-current trademark usage guidelines, as may be published on its website or otherwise made available upon request from time to time. For the avoidance of doubt, this trademark restriction does not form part of the License.
73
+
74
+ Creative Commons may be contacted at https://creativecommons.org/.
UniSpeech/README.md ADDED
@@ -0,0 +1,148 @@
1
+ # UniSpeech
2
+
3
+ <!--**Pre-trained models for speech related tasks**-->
4
+
5
+ The family of UniSpeech:
6
+ > [**WavLM**](https://arxiv.org/pdf/2110.13900.pdf) (```arXiv```): **WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing**
7
+
8
+ > [**UniSpeech**](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech) (```ICML 2021```): **Unified Pre-training for Self-Supervised Learning and Supervised Learning for ASR**
9
+
10
+ > [**UniSpeech-SAT**](https://arxiv.org/pdf/2110.05752.pdf) (```ICASSP 2022 Submission```): **Universal Speech Representation Learning with Speaker Aware Pre-Training**
11
+
12
+ > [**ILS-SSL**](https://arxiv.org/pdf/2112.08778.pdf) (```ICASSP 2022 Submission```): **Self-Supervised Learning for Speech Recognition with Intermediate Layer Supervision**
13
+
14
+ Model introductions, evaluation results, and model inference instructions are located in their corresponding folders. The source code is available at [https://github.com/microsoft/UniSpeech/tree/main/src](https://github.com/microsoft/UniSpeech/tree/main/src).
15
+
16
+ ## Update
17
+ - [HuggingFace Integration] Dec 23, 2021: [**WavLM**](https://huggingface.co/models?other=wavlm) models are on [HuggingFace](https://huggingface.co/models?other=wavlm) .
18
+ - [HuggingFace Integration] Octorber 26, 2021: [**UniSpeech-SAT**](https://huggingface.co/microsoft/unispeech-sat-large) models are on [HuggingFace](https://huggingface.co/models?other=unispeech-sat) .
19
+ - [Model Release] Octorber 13, 2021: [**UniSpeech-SAT**](https://arxiv.org/pdf/2110.05752.pdf) models are releaseed.
20
+ - [HuggingFace Integration] Octorber 11, 2021: [**UniSpeech**](https://huggingface.co/microsoft/unispeech-large-1500h-cv) models are on [HuggingFace](https://huggingface.co/models?other=unispeech) .
21
+ - [Model Release] June, 2021: [**UniSpeech v1**](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech) models are released.
22
+ ## Pre-trained models
23
+ We strongly suggest using our UniSpeech-SAT model for speaker related tasks, since it shows very powerful performance on various speaker related benchmarks.
24
+ Model | Pretraining Dataset | Finetuning Dataset | Model
25
+ |---|---|---|-----
26
+ UniSpeech Large EN | [Labeled: 1350 hrs en](https://commonvoice.mozilla.org/) | - | [download](https://releasemodel.blob.core.windows.net/models/CommonVoicePretrainedModel/CommonVoiceEnglishPretrainedModel/checkpoint_best.pt?sv=2019-12-12&st=2021-07-14T09%3A00%3A07Z&se=2022-07-15T09%3A00%3A00Z&sr=b&sp=r&sig=5sxvEwVRoGtkazNQYkOuFLlPYau8nl5Ng%2FfRJa0Vnc4%3D)
27
+ UniSpeech Large Multilingual | [Labeled: 1350 hrs en + 353 hrs fr + 168 hrs es + 90 hrs it](https://commonvoice.mozilla.org/) | - | [download](https://releasemodel.blob.core.windows.net/models/CommonVoicePretrainedModel/CommonVoiceMultilingualPretrainedModel/checkpoint_best.pt?sv=2019-12-12&st=2021-07-14T09%3A00%3A39Z&se=2022-07-15T09%3A00%3A00Z&sr=b&sp=r&sig=y%2Fd3rqtbyqW0ZCwR7Czho5any90khA%2Ft3w9PTZ6N9vU%3D)
28
+ Unispeech Large+ | [Labeled: 1350 hrs en, Unlabeled: 353 hrs fr](https://commonvoice.mozilla.org/) | - | [download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/pt_fr353.large.one2one_unispeech/checkpoint_best.pt?st=2021-10-25T06%3A44%3A54Z&se=2023-10-26T06%3A44%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=7tYuYMxVFfM2Vgi%2BoqUh%2ByJXD4hSuoafHgBP5VZApw0%3D)
29
+ UniSpeech Large+ | [Labeld: 1350 hrs en, Unlabeled: 168 hrs es](https://commonvoice.mozilla.org/) | - | [download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/pt_es168.large.one2one_unispeech/checkpoint_best.pt?st=2021-10-25T06%3A39%3A37Z&se=2023-10-26T06%3A39%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=T2B5%2BlOI6v64TNdLSe9rdp3R%2B9Q2E35taUOigGW0nsQ%3D)
30
+ UniSpeech Large+ | [Labeled: 1350 hrs en, Unlabeld: 90 hrs it](https://commonvoice.mozilla.org/) | -| [download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/pt_it90.large.one2one_unispeech/checkpoint_best.pt?st=2021-10-25T06%3A52%3A08Z&se=2023-10-26T06%3A52%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=kXsSJXK9r8UEYlUr2LaJxtPf8m9J2G23MfG725k2DBk%3D)
31
+ UniSpeech Large Multilingual | [Labeled: 1350 hrs en + 353 hrs fr + 168 hrs es + 90 hrs it, Unlabeled: 17 hrs ky](https://commonvoice.mozilla.org/) | - | [download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/pt_ky17.large.many2one_unispeech/checkpoint_best.pt?st=2021-10-25T06%3A53%3A00Z&se=2022-10-26T06%3A53%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=oCQecalXzC5daaurLLJGQdFNtfYwsBM6pNQrDAsf5i0%3D)
32
+ UniSpeech Large+ | [Labeled: 1350 hrs en, Unlabeled: 353 hrs fr](https://commonvoice.mozilla.org/) | 1 hr fr | [download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/ft_fr-pt_fr353.large.one2one_unispeech/checkpoint_best.pt?st=2021-10-25T06%3A27%3A53Z&se=2023-10-26T06%3A27%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=9vEa3xqzWu7SYkACn9TQqDtcm%2BKmUcOHhabjbjZuPys%3D)
33
+ UniSpeech Large+ | [Labeld: 1350 hrs en, Unlabeled: 168 hrs es](https://commonvoice.mozilla.org/) | 1 hr es | [download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/ft_es-pt_es168.large.one2one_unispeech/checkpoint_best.pt?st=2021-10-25T06%3A21%3A34Z&se=2024-10-26T06%3A21%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=G%2B0RddgOh653UzXG95Ljuwv7aG3tu9gXtPXn1ixCiug%3D)
34
+ UniSpeech Large+ | [Labeled: 1350 hrs en, Unlabeld: 90 hrs it](https://commonvoice.mozilla.org/) | 1 hr it | [download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/ft_it-pt_it90.large.one2one_unispeech/checkpoint_best.pt?st=2021-10-25T06%3A36%3A17Z&se=2023-10-26T06%3A36%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=e1WD9uOCo9sCAdH%2FPZQ4wCD30aCDpZvvu43kJrqq2HE%3D)
35
+ UniSpeech Large Multilingual | [Labeled: 1350 hrs en + 353 hrs fr + 168 hrs es + 90 hrs it, Unlabeled: 17 hrs ky](https://commonvoice.mozilla.org/) | 1 hr ky | [download](https://msranlcmtteamdrive.blob.core.windows.net/teamdrive/v-chengw/models/pt_ky17.large.many2one_unispeech/checkpoint_best.pt?st=2021-10-25T06%3A54%3A04Z&se=2023-10-26T06%3A54%3A00Z&sp=rl&sv=2018-03-28&sr=b&sig=2K3VjMcsbKfBkLVyDlqGhVpIX%2B2ZcA5DTlMhjdkXo3g%3D)
36
+ UniSpeech-SAT Base | [960 hrs LibriSpeech](http://www.openslr.org/12) | - | [download](https://valle.blob.core.windows.net/share/unispeech-sat/unispeech_repo/UniSpeech-SAT-Base.pt?sv=2021-10-04&st=2024-01-30T06%3A26%3A06Z&se=2094-01-31T06%3A26%3A00Z&sr=b&sp=r&sig=Ts8al%2FPc%2BksI%2BY4tKvDVZhmyw02c9pMhFDxLrPntSd0%3D)
37
+ UniSpeech-SAT Base+ | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [10k hrs GigaSpeech](https://github.com/SpeechColab/GigaSpeech) + [24k hrs VoxPopuli](https://github.com/facebookresearch/voxpopuli/tree/main) | - | [download](https://valle.blob.core.windows.net/share/unispeech-sat/unispeech_repo/UniSpeech-SAT-Base+.pt?sv=2021-10-04&st=2024-01-30T06%3A26%3A25Z&se=2094-01-31T06%3A26%3A00Z&sr=b&sp=r&sig=m6XAIXsC4rVNDW%2FFwXPNKjX2A%2BV9zBwmWAV93vAXcvc%3D)
38
+ UniSpeech-SAT Large | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [10k hrs GigaSpeech](https://github.com/SpeechColab/GigaSpeech) + [24k hrs VoxPopuli](https://github.com/facebookresearch/voxpopuli/tree/main) | - | [download](https://valle.blob.core.windows.net/share/unispeech-sat/unispeech_repo/UniSpeech-SAT-Large.pt?sv=2021-10-04&st=2024-01-30T06%3A26%3A43Z&se=2094-01-31T06%3A26%3A00Z&sr=b&sp=r&sig=TJXNIfJzB%2Bsja3uh2xbCUxdbvp8gQP0zzlmZvU2si1I%3D)
39
+ WavLM Base | [960 hrs LibriSpeech](http://www.openslr.org/12)| - | [download](https://valle.blob.core.windows.net/share/wavlm/unispeech_repo/WavLM-Base.pt?sv=2021-10-04&st=2024-01-30T06%3A27%3A08Z&se=2094-01-31T06%3A27%3A00Z&sr=b&sp=r&sig=ThlNPycn578KFcON1NTJ7hzkpNLZR%2B3D4ImTgXQR%2B9E%3D)
40
+ WavLM Base+ | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [10k hrs GigaSpeech](https://github.com/SpeechColab/GigaSpeech) + [24k hrs VoxPopuli](https://github.com/facebookresearch/voxpopuli/tree/main)| - | [download](https://valle.blob.core.windows.net/share/wavlm/unispeech_repo/WavLM-Base+.pt?sv=2021-10-04&st=2024-01-30T06%3A27%3A23Z&se=2094-01-31T06%3A27%3A00Z&sr=b&sp=r&sig=6XiFiDMiKNRLRYzNNAL9UWm0dAAFuRweRLFZ2h9IYzg%3D)
41
+ WavLM Large | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [10k hrs GigaSpeech](https://github.com/SpeechColab/GigaSpeech) + [24k hrs VoxPopuli](https://github.com/facebookresearch/voxpopuli/tree/main)| - | [download](https://valle.blob.core.windows.net/share/wavlm/unispeech_repo/WavLM-Large.pt?sv=2021-10-04&st=2024-01-30T06%3A27%3A39Z&se=2094-01-31T06%3A27%3A00Z&sr=b&sp=r&sig=eIFLuFlZxmBR7a642KnnLen7yoSzv465iLHLKokO7VM%3D)
42
+
43
+ ## Universal Representation Evaluation on SUPERB
44
+ ![alt text](WavLM/WavLM_SUPERB_Results.png)
45
+
46
+ ## Downstream Task Performance
47
+ We also evaluate our models on typical speaker related benchmarks.
48
+ ### Speaker Verification
49
+ Finetune the model with VoxCeleb2 dev data, and evaluate it on the [VoxCeleb1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/#:~:text=VoxCeleb%20is%20an%20audio%2Dvisual,interview%20videos%20uploaded%20to%20YouTube)
50
+ | Model |Fix pre-train| Vox1-O | Vox1-E | Vox1-H |
51
+ | ------------- |------------- | ---------- | ---------- | ---------- |
52
+ | ECAPA-TDNN | - | 0.87 | 1.12 | 2.12 |
53
+ | HuBERT large | Yes| 0.888 |0.912| 1.853 |
54
+ | Wav2Vec2.0 (XLSR)| Yes | 0.915| 0.945 |1.895|
55
+ | UniSpeech-SAT large | Yes | 0.771 | 0.781| 1.669|
56
+ | WavLM large | Yes | 0.59 | 0.65| 1.328|
57
+ | WavLM large | No | 0.505 | 0.579| 1.176|
58
+ |+Large Margin Finetune and Score Calibration|
59
+ | HuBERT large | No| 0.585| 0.654 |1.342|
60
+ | Wav2Vec2.0 (XLSR) | No| 0.564| 0.605 |1.23|
61
+ | UniSpeech-SAT large | No | 0.564 | 0.561| 1.23 |
62
+ | **WavLM large (New)** | No | **0.33** | **0.477**| **0.984** |
63
+
64
+ [Large-scale Self-Supervised Speech Representation Learning for Automatic Speaker Verification](https://arxiv.org/pdf/2110.05777.pdf)
65
+
66
+
67
+
68
+ ### Speech Separation
69
+
70
+ Evaluation on [LibriCSS](https://github.com/chenzhuo1011/libri_css)
71
+ | Model |0S | 0L | OV10 | OV20 |OV30 |OV40 |
72
+ | ---------------- |------| ------ | ------ | ------ | ------ | ------ |
73
+ | [Conformer](https://ieeexplore.ieee.org/abstract/document/9413423/) (SOTA) | 4.5 | 4.4 |6.2 |8.5| 11 |12.6|
74
+ | UniSpeech-SAT base | 4.4| 4.4 |5.4| 7.2| 9.2 |10.5|
75
+ | UniSpeech-SAT large | 4.3| 4.2 |5.0 |6.3| 8.2| 8.8|
76
+ | WavLM base+ | 4.5| 4.4 |5.6| 7.5| 9.4 |10.9|
77
+ | **WavLM large** | 4.2| 4.1 | 4.8 | 5.8 | 7.4| 8.5|
78
+
79
+
80
+ ### Speaker Diarization
81
+
82
+ Evaluation on CALLHOME
83
+ | Model |spk_2 |spk_3| spk_4| spk_5| spk_6| spk_all |
84
+ | ---------------- |------| ------ | ------ | ------ | ------ | ------ |
85
+ | [EEND-vector clustering](https://arxiv.org/pdf/2105.09040.pdf) | 7.96| 11.93 |16.38| 21.21| 23.1 |12.49||
86
+ | [EEND-EDA clustering](https://arxiv.org/abs/2107.01545) (SOTA) | 7.11| 11.88 |14.37| 25.95| 21.95 |11.84||
87
+ | UniSpeech-SAT large | 5.93| 10.66| 12.9 |16.48| 23.25| 10.92|
88
+ | WavLM Base| 6.99| 11.12| 15.20 |16.48| 21.61| 11.75|
89
+ | **WavLM large** | 6.46| 10.69| 11.84 |12.89| 20.70| 10.35|
90
+
91
+
92
+ ## License
93
+ This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
94
+ Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) project.
95
+
96
+ [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
97
+
98
+
99
+ ### Reference
100
+ If you find our work is useful in your research, please cite the following paper:
101
+ ``` latex
102
+ @inproceedings{Wang2021UniSpeech,
103
+ author = {Chengyi Wang and Yu Wu and Yao Qian and Kenichi Kumatani and Shujie Liu and Furu Wei and Michael Zeng and Xuedong Huang},
104
+ editor = {Marina Meila and Tong Zhang},
105
+ title = {UniSpeech: Unified Speech Representation Learning with Labeled and
106
+ Unlabeled Data},
107
+ booktitle = {Proceedings of the 38th International Conference on Machine Learning,
108
+ {ICML} 2021, 18-24 July 2021, Virtual Event},
109
+ series = {Proceedings of Machine Learning Research},
110
+ volume = {139},
111
+ pages = {10937--10947},
112
+ publisher = {{PMLR}},
113
+ year = {2021},
114
+ url = {http://proceedings.mlr.press/v139/wang21y.html},
115
+ timestamp = {Thu, 21 Oct 2021 16:06:12 +0200},
116
+ biburl = {https://dblp.org/rec/conf/icml/0002WQK0WZ021.bib},
117
+ bibsource = {dblp computer science bibliography, https://dblp.org}
118
+ }
119
+ ```
120
+
121
+ ``` latex
122
+ @article{Chen2021WavLM,
123
+ title = {WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing},
124
+ author = {Sanyuan Chen and Chengyi Wang and Zhengyang Chen and Yu Wu and Shujie Liu and Zhuo Chen and Jinyu Li and Naoyuki Kanda and Takuya Yoshioka and Xiong Xiao and Jian Wu and Long Zhou and Shuo Ren and Yanmin Qian and Yao Qian and Jian Wu and Michael Zeng and Furu Wei},
125
+ eprint={2110.13900},
126
+ archivePrefix={arXiv},
127
+ primaryClass={cs.CL},
128
+ year={2021}
129
+ }
130
+ ```
131
+
132
+ ``` latex
133
+ @article{Chen2021UniSpeechSAT,
134
+ title = {UniSpeech-SAT: Universal Speech Representation Learning with Speaker Aware Pre-Training},
135
+ author = {Sanyuan Chen and Yu Wu and Chengyi Wang and Zhengyang Chen and Zhuo Chen and Shujie Liu and Jian Wu and Yao Qian and Furu Wei and Jinyu Li and Xiangzhan Yu},
136
+ eprint={2110.05752},
137
+ archivePrefix={arXiv},
138
+ primaryClass={cs.CL},
139
+ year={2021}
140
+ }
141
+ ```
142
+
143
+
144
+ ### Contact Information
145
+
146
+ For help or issues using UniSpeech models, please submit a GitHub issue.
147
+
148
+ For other communications related to UniSpeech, please contact Yu Wu (`yuwu1@microsoft.com`).
UniSpeech/SECURITY.md ADDED
@@ -0,0 +1,41 @@
1
+ <!-- BEGIN MICROSOFT SECURITY.MD V0.0.8 BLOCK -->
2
+
3
+ ## Security
4
+
5
+ Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6
+
7
+ If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8
+
9
+ ## Reporting Security Issues
10
+
11
+ **Please do not report security vulnerabilities through public GitHub issues.**
12
+
13
+ Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14
+
15
+ If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16
+
17
+ You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18
+
19
+ Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20
+
21
+ * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22
+ * Full paths of source file(s) related to the manifestation of the issue
23
+ * The location of the affected source code (tag/branch/commit or direct URL)
24
+ * Any special configuration required to reproduce the issue
25
+ * Step-by-step instructions to reproduce the issue
26
+ * Proof-of-concept or exploit code (if possible)
27
+ * Impact of the issue, including how an attacker might exploit the issue
28
+
29
+ This information will help us triage your report more quickly.
30
+
31
+ If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32
+
33
+ ## Preferred Languages
34
+
35
+ We prefer all communications to be in English.
36
+
37
+ ## Policy
38
+
39
+ Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40
+
41
+ <!-- END MICROSOFT SECURITY.MD BLOCK -->
UniSpeech/UniSpeech-SAT/README.md ADDED
@@ -0,0 +1,76 @@
1
+ # UniSpeech-SAT
2
+
3
+ > [**UniSpeech-SAT**](https://arxiv.org/pdf/2110.05752.pdf) (```ICASSP 2022 Submission```): **Universal Speech Representation Learning with Speaker Aware Pre-Training**
4
+
5
+ ## Universal Representation Evaluation on SUPERB
6
+ ![alt text](UniSpeech_SAT_SUPERB_Results.png)
7
+
8
+ ## Downstream Task Performance
9
+ We also evaluate our models on typical speaker related benchmarks.
10
+ ### Speaker Verification
11
+ | Model |Fix pre-train| Vox1-O | Vox1-E | Vox1-H |
12
+ | ------------- |------------- | ---------- | ---------- | ---------- |
13
+ | ECAPA-TDNN | - | 0.87 | 1.12 | 2.12 |
14
+ | HuBERT large | Yes| 0.888 |0.912| 1.853 |
15
+ | Wav2Vec2.0 (XLSR)| Yes | 0.915| 0.945 |1.895|
16
+ | UniSpeech-SAT large | Yes | 0.771 | 0.781| 1.669|
17
+ | HuBERT large | No| 0.585| 0.654 |1.342|
18
+ | Wav2Vec2.0 (XLSR) | No| **0.564**| 0.605 |**1.23**|
19
+ | **UniSpeech-SAT large** | No | **0.564** | **0.561** | **1.23** |
20
+
21
+ [Our paper for verification](https://arxiv.org/pdf/2110.05777.pdf)
22
+
23
+
24
+
25
+ ### Speech Separation
26
+
27
+ Evaluation on [LibriCSS](https://github.com/chenzhuo1011/libri_css)
28
+
29
+ | Model |0S | 0L | OV10 | OV20 |OV30 |OV40 |
30
+ | ---------------- |------| ------ | ------ | ------ | ------ | ------ |
31
+ | [Conformer](https://ieeexplore.ieee.org/abstract/document/9413423/) (SOTA) | 4.5 | 4.4 |6.2 |8.5| 11 |12.6|
32
+ | HuBERT base | 4.7| 4.6 | 6.1 | 7.9| 10.6| 12.3|
33
+ | UniSpeech-SAT base+ | 4.4| 4.4 |5.4| 7.2| 9.2 |10.5|
34
+ | **UniSpeech-SAT large** | **4.3**| **4.2** |**5.0** |**6.3**| **8.2**| **8.8**|
35
+
36
+
37
+ ### Speaker Diarization
38
+
39
+ Evaluation on CALLHOME
40
+
41
+ | Model |spk_2 |spk_3| spk_4| spk_5| spk_6| spk_all |
42
+ | ---------------- |------| ------ | ------ | ------ | ------ | ------ |
43
+ | [EEND-vector clustering](https://arxiv.org/pdf/2105.09040.pdf) | 7.96| 11.93 |16.38| 21.21| 23.1 |12.49|
44
+ | [EEND-EDA clustering](https://arxiv.org/abs/2107.01545) (SOTA) | 7.11| 11.88 |14.37| 25.95| 21.95 |11.84|
45
+ | HuBERT base| 7.93|12.07| 15.21 |19.59| 23.32| 12.63|
46
+ | HuBERT large| 7.39| 11.97| 15.76 |19.82| 22.10| 12.40|
47
+ | **UniSpeech-SAT large** | **5.93**| **10.66**| **12.9** |**16.48**| **23.25**| **10.92**|
48
+
49
+
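For a quick way to extract features from a pre-trained UniSpeech-SAT checkpoint, a minimal sketch is shown below. It assumes the Hugging Face `transformers` integration of UniSpeech-SAT (`UniSpeechSatModel`) and the `microsoft/unispeech-sat-base` checkpoint; the checkpoint name and the mean-pooling step are illustrative placeholders, not part of the recipes in this repository.

```python
import torch
from transformers import Wav2Vec2FeatureExtractor, UniSpeechSatModel

# assumed checkpoint name; replace with the checkpoint you actually use
name = "microsoft/unispeech-sat-base"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(name)
model = UniSpeechSatModel.from_pretrained(name)
model.eval()

# one second of dummy 16 kHz audio as a placeholder input
wav = torch.randn(16000).numpy()
inputs = feature_extractor(wav, sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

frame_features = outputs.last_hidden_state   # (batch, frames, hidden_size)
utt_embedding = frame_features.mean(dim=1)    # naive mean pooling into one vector per utterance
```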
50
+ ## License
51
+ This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
52
+ Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) project.
53
+
54
+ [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
55
+
56
+
57
+ ### Reference
58
+ If you find our work useful in your research, please cite the following paper:
59
+
60
+ ``` latex
61
+ @article{Chen2021UniSpeechSAT,
62
+ title = {UniSpeech-SAT: Universal Speech Representation Learning with Speaker Aware Pre-Training},
63
+ author = {Sanyuan Chen and Yu Wu and Chengyi Wang and Zhengyang Chen and Zhuo Chen and Shujie Liu and Jian Wu and Yao Qian and Furu Wei and Jinyu Li and Xiangzhan Yu},
64
+ eprint={2110.05752},
65
+ archivePrefix={arXiv},
66
+ primaryClass={cs.CL},
67
+ year={2021}
68
+ }
69
+ ```
70
+
71
+
72
+ ### Contact Information
73
+
74
+ For help or issues using UniSpeech models, please submit a GitHub issue.
75
+
76
+ For other communications related to UniSpeech, please contact Yu Wu (`yuwu1@microsoft.com`).
UniSpeech/UniSpeech-SAT/UniSpeech_SAT_SUPERB_Results.png ADDED

Git LFS Details

  • SHA256: 65fa914aabddd7261bb06d347a07865956b65edc1a2fa3d20a00f521173b83b8
  • Pointer size: 131 Bytes
  • Size of remote file: 339 kB
UniSpeech/UniSpeech/README.md ADDED
@@ -0,0 +1,56 @@
1
+ # UniSpeech
2
+
3
+ This is the official implementation of the paper "[UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597)". The implementation is mainly based on the [fairseq](https://github.com/pytorch/fairseq) codebase. We release the training recipes for the CommonVoice dataset.
4
+
5
+ ## Requirements and Installation
6
+
7
+ - PyTorch >= 1.6.0
8
+ - Python version >= 3.6
9
+ ``` bash
10
+ cd src
11
+ pip install soundfile
12
+ pip install librosa
13
+ pip install pydub
14
+ pip install --editable ./
15
+ ```
16
+ ## Data Preparation
17
+ Download pretraining audio data from [here](https://commonvoice.mozilla.org/datasets). (We use the June 2020 release version in our paper).
18
+ Get the wav list and the transcription for each dataset by running:
19
+ ```
20
+ python examples/unispeech/unispeech_manifest.py input_meta_file --dest examples/unispeech/data/LANG
21
+ ```
22
+
23
+ Then convert the Common Voice audio files to 16 kHz using the command:
24
+ ```
25
+ python examples/unispeech/adjust_sample_rate.py --wav-path /path/to/wav/ --dest-path /path/to/16kwav/ --input examples/unispeech/data/LANG/*.tsv --output examples/unispeech/data/LANG/*_16k.tsv
26
+ ```
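Under the hood this step simply resamples every clip to 16 kHz mono. A minimal sketch of the same operation with `librosa` and `soundfile` (both installed above) is given below; the file paths are placeholders.

```python
import librosa
import soundfile as sf

# placeholder paths; adjust to your Common Voice layout
src = "/path/to/wav/clip.wav"
dst = "/path/to/16kwav/clip.wav"

# librosa resamples to the requested rate while loading
audio, sr = librosa.load(src, sr=16000, mono=True)
sf.write(dst, audio, 16000)
```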
27
+ For the finetuning data, our train/val/test splits follow [this](https://dl.fbaipublicfiles.com/cpc_audio/common_voices_splits.tar.gz).
28
+ The phoneme transcriptions are generated with [phonemizer](https://github.com/bootphon/phonemizer), which converts the texts to phonemes. We then create .id files using different vocabularies. All our pre-processed data as well as the dictionaries can be downloaded from [here].
29
+
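A minimal sketch of the phonemization step is shown below; it assumes the `espeak` backend of phonemizer is installed, and the sentence is only a placeholder for a Common Voice transcription.

```python
from phonemizer import phonemize

# placeholder transcription; in practice this comes from the Common Voice metadata
text = "hello world"

# requires the espeak backend to be installed on the system
phones = phonemize(text, language="en-us", backend="espeak", strip=True)
print(phones)  # a space-separated phoneme string
```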
30
+ ## Pretraining
31
+
32
+ We give the training examples for the large models here.
33
+ ### Stage 1. Pretraining UniSpeech with labeled data.
34
+ The following script can be used to pre-train an English model:
35
+ ```
36
+ bash examples/unispeech/scripts/one2one_large_pretrain_en1350.sh
37
+ ```
38
+ To train a multilingual model:
39
+ ```
40
+ bash examples/unispeech/scripts/multilingual_large_pretrain.sh
41
+ ```
42
+
43
+ ### Stage 2. Continue pre-training with low-resource unlabeled data. (Optional)
44
+ After stage 1, you can continue pre-training the UniSpeech model with only contrastive loss:
45
+ ```
46
+ bash examples/unispeech/scripts/continue_pretran.sh
47
+ ```
48
+
49
+ ### Stage 3. Finetuning with low-resource labeled data.
50
+ Finally, fine-tune the model with 1 hour of labeled data.
51
+ For multilingual models, you can choose either a separate vocabulary (examples/unispeech/data/en/vocab_sep.json) or a shared vocabulary (examples/unispeech/data/en/vocab_share.json).
52
+ ```
53
+ bash examples/unispeech/scripts/finetune.sh
54
+ ```
55
+
56
+
UniSpeech/WavLM/README.md ADDED
@@ -0,0 +1,125 @@
1
+
2
+ # WavLM
3
+
4
+ <!--**Pre-trained models for speech related tasks**-->
5
+
6
+
7
+ [**WavLM**](https://arxiv.org/pdf/2110.13900.pdf) : **WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing**
8
+
9
+ Official PyTorch implementation and pretrained models of WavLM
10
+
11
+ - Oct 2021: released the preprint on [arXiv](https://arxiv.org/pdf/2110.13900.pdf)
12
+
13
+ ## Pre-Trained Models
14
+ Model | Pre-training Dataset | Fine-tuning Dataset | Download
15
+ |---|---|---|---
16
+ WavLM Base | [960 hrs LibriSpeech](http://www.openslr.org/12)| - | [Azure Storage](https://valle.blob.core.windows.net/share/wavlm/WavLM-Base.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) <br> [Google Drive](https://drive.google.com/file/d/1BhTPLUkfN6e2xkqR8LEm9lByXbLY1IYd/view?usp=share_link)
17
+ WavLM Base+ | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [10k hrs GigaSpeech](https://github.com/SpeechColab/GigaSpeech) + [24k hrs VoxPopuli](https://github.com/facebookresearch/voxpopuli/tree/main)| - | [Azure Storage](https://valle.blob.core.windows.net/share/wavlm/WavLM-Base+.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) <br> [Google Drive](https://drive.google.com/file/d/1-zlAj2SyVJVsbhifwpTlAfrgc9qu-HDb/view?usp=share_link)
18
+ WavLM Large | [60k hrs Libri-Light](https://github.com/facebookresearch/libri-light) + [10k hrs GigaSpeech](https://github.com/SpeechColab/GigaSpeech) + [24k hrs VoxPopuli](https://github.com/facebookresearch/voxpopuli/tree/main)| - | [Azure Storage](https://valle.blob.core.windows.net/share/wavlm/WavLM-Large.pt?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D) <br> [Google Drive](https://drive.google.com/file/d/12-cB34qCTvByWT-QtOcZaqwwO21FLSqU/view?usp=share_link)
19
+
20
+ ## Load Pre-Trained Models for Inference
21
+
22
+ ```python
23
+ import torch
24
+ from WavLM import WavLM, WavLMConfig
25
+
26
+ # load the pre-trained checkpoints
27
+ checkpoint = torch.load('/path/to/wavlm.pt')
28
+ cfg = WavLMConfig(checkpoint['cfg'])
29
+ model = WavLM(cfg)
30
+ model.load_state_dict(checkpoint['model'])
31
+ model.eval()
32
+
33
+ # extract the representation of the last layer
34
+ wav_input_16khz = torch.randn(1,10000)
35
+ rep = model.extract_features(wav_input_16khz)[0]
36
+
37
+ # extract the representation of each layer
38
+ wav_input_16khz = torch.randn(1,10000)
39
+ rep, layer_results = model.extract_features(wav_input_16khz, output_layer=model.cfg.encoder_layers, ret_layer_results=True)[0]
40
+ layer_reps = [x.transpose(0, 1) for x, _ in layer_results]
41
+ ```
42
+
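Note that the checkpoints differ in whether the raw waveform is expected to be normalized: if `cfg.normalize` is `True` for the checkpoint you load, layer-normalize the waveform before calling `extract_features`. Continuing the example above, a minimal sketch (the flag check keeps it checkpoint-agnostic):

```python
import torch

# normalize the input waveform when the checkpoint was trained with cfg.normalize = True
if cfg.normalize:
    wav_input_16khz = torch.nn.functional.layer_norm(wav_input_16khz, wav_input_16khz.shape)

rep = model.extract_features(wav_input_16khz)[0]
```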
43
+
44
+ ## Universal Representation Evaluation on SUPERB
45
+ ![alt text](WavLM_SUPERB_Results.png)
46
+
47
+ ![alt text](WavLM_SUPERB_Leaderboard.png)
48
+ ## Downstream Task Performance
49
+ We also evaluate our models on typical speech processing benchmarks.
50
+ ### Speaker Verification
51
+
52
+ Evaluation on the [VoxCeleb](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/#:~:text=VoxCeleb%20is%20an%20audio%2Dvisual,interview%20videos%20uploaded%20to%20YouTube) dataset
53
+
54
+ | Model |Fix pre-train| Vox1-O | Vox1-E | Vox1-H |
55
+ | ------------- |------------- | ---------- | ---------- | ---------- |
56
+ | ECAPA-TDNN | - | 0.87 | 1.12 | 2.12 |
57
+ | HuBERT large | Yes| 0.888 |0.912| 1.853 |
58
+ | Wav2Vec2.0 (XLSR)| Yes | 0.915| 0.945 |1.895|
59
+ | UniSpeech-SAT large | Yes | 0.771 | 0.781| 1.669|
60
+ | WavLM large | Yes | 0.638 | 0.687| 1.457|
61
+ | HuBERT large | No| 0.585| 0.654 |1.342|
62
+ | Wav2Vec2.0 (XLSR) | No| 0.564| 0.605 |1.23|
63
+ | UniSpeech-SAT large | No | 0.564 | 0.561| 1.23 |
64
+ | **WavLM large** | No | **0.431** | **0.538**| **1.154** |
65
+
66
+
67
+
68
+ ### Speech Separation
69
+
70
+ Evaluation on the [LibriCSS](https://github.com/chenzhuo1011/libri_css)
71
+ | Model |0S | 0L | OV10 | OV20 |OV30 |OV40 |
72
+ | ---------------- |------| ------ | ------ | ------ | ------ | ------ |
73
+ | [Conformer](https://ieeexplore.ieee.org/abstract/document/9413423/) (SOTA) | 4.5 | 4.4 |6.2 |8.5| 11 |12.6|
74
+ | HuBERT base | 4.7| 4.6 | 6.1 | 7.9| 10.6| 12.3|
75
+ | UniSpeech-SAT base | 4.4| 4.4 |5.4| 7.2| 9.2 |10.5|
76
+ | UniSpeech-SAT large | 4.3| 4.2 |5.0 |6.3| 8.2| 8.8|
77
+ | WavLM base+ | 4.5| 4.4 |5.6| 7.5| 9.4 |10.9|
78
+ | **WavLM large** | 4.2| 4.1 | 4.8 | 5.8 | 7.4| 8.5|
79
+
80
+
81
+ ### Speaker Diarization
82
+
83
+ Evaluation on the [CALLHOME](https://arxiv.org/pdf/1909.06247.pdf)
84
+ | Model |spk_2 |spk_3| spk_4| spk_5| spk_6| spk_all |
85
+ | ---------------- |------| ------ | ------ | ------ | ------ | ------ |
86
+ | [EEND-vector clustering](https://arxiv.org/pdf/2105.09040.pdf) | 7.96| 11.93 |16.38| 21.21| 23.1 |12.49|
87
+ | [EEND-EDA clustering](https://arxiv.org/abs/2107.01545) (SOTA) | 7.11| 11.88 |14.37| 25.95| 21.95 |11.84|
88
+ | HuBERT base| 7.93|12.07| 15.21 |19.59| 23.32| 12.63|
89
+ | HuBERT large| 7.39| 11.97| 15.76 |19.82| 22.10| 12.40|
90
+ | UniSpeech-SAT large| 5.93| 10.66| 12.9 |16.48| 23.25| 10.92|
91
+ | WavLM Base| 6.99| 11.12| 15.20 |16.48| 21.61| 11.75|
92
+ | **WavLM large** | 6.46| 10.69| 11.84 |12.89| 20.70| 10.35|
93
+
94
+ ### Speech Recognition
95
+ Evaluation on the [LibriSpeech](https://www.openslr.org/12) benchmark
96
+
97
+ ![alt text](WavLM_ASR.PNG)
98
+
99
+
100
+ ## License
101
+ This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
102
+ Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq) project.
103
+
104
+ [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
105
+
106
+
107
+ ### Reference
108
+ If you find our work useful in your research, please cite the following paper:
109
+ ``` latex
110
+ @article{Chen2021WavLM,
111
+ title = {WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing},
112
+ author = {Sanyuan Chen and Chengyi Wang and Zhengyang Chen and Yu Wu and Shujie Liu and Zhuo Chen and Jinyu Li and Naoyuki Kanda and Takuya Yoshioka and Xiong Xiao and Jian Wu and Long Zhou and Shuo Ren and Yanmin Qian and Yao Qian and Jian Wu and Michael Zeng and Furu Wei},
113
+ eprint={2110.13900},
114
+ archivePrefix={arXiv},
115
+ primaryClass={cs.CL},
116
+ year={2021}
117
+ }
118
+ ```
119
+
120
+
121
+ ### Contact Information
122
+
123
+ For help or issues using WavLM models, please submit a GitHub issue.
124
+
125
+ For other communications related to WavLM, please contact Yu Wu (`yuwu1@microsoft.com`).
UniSpeech/WavLM/WavLM.py ADDED
@@ -0,0 +1,743 @@
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import logging
12
+ from typing import List, Optional, Tuple
13
+
14
+ import numpy as np
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from torch.nn import LayerNorm
20
+ from modules import (
21
+ Fp32GroupNorm,
22
+ Fp32LayerNorm,
23
+ GradMultiply,
24
+ MultiheadAttention,
25
+ SamePad,
26
+ init_bert_params,
27
+ get_activation_fn,
28
+ TransposeLast,
29
+ GLU_Linear,
30
+ )
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ def compute_mask_indices(
36
+ shape: Tuple[int, int],
37
+ padding_mask: Optional[torch.Tensor],
38
+ mask_prob: float,
39
+ mask_length: int,
40
+ mask_type: str = "static",
41
+ mask_other: float = 0.0,
42
+ min_masks: int = 0,
43
+ no_overlap: bool = False,
44
+ min_space: int = 0,
45
+ ) -> np.ndarray:
46
+ """
47
+ Computes random mask spans for a given shape
48
+
49
+ Args:
50
+ shape: the shape for which to compute masks.
51
+ should be of size 2 where first element is batch size and 2nd is timesteps
52
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
53
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
54
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
55
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
56
+ mask_type: how to compute mask lengths
57
+ static = fixed size
58
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
59
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
60
+ poisson = sample from poisson distribution with lambda = mask length
61
+ min_masks: minimum number of masked spans
62
+ no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
63
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
64
+ """
65
+
66
+ bsz, all_sz = shape
67
+ mask = np.full((bsz, all_sz), False)
68
+
69
+ all_num_mask = int(
70
+ # add a random number for probabilistic rounding
71
+ mask_prob * all_sz / float(mask_length)
72
+ + np.random.rand()
73
+ )
74
+
75
+ all_num_mask = max(min_masks, all_num_mask)
76
+
77
+ mask_idcs = []
78
+ for i in range(bsz):
79
+ if padding_mask is not None:
80
+ sz = all_sz - padding_mask[i].long().sum().item()
81
+ num_mask = int(
82
+ # add a random number for probabilistic rounding
83
+ mask_prob * sz / float(mask_length)
84
+ + np.random.rand()
85
+ )
86
+ num_mask = max(min_masks, num_mask)
87
+ else:
88
+ sz = all_sz
89
+ num_mask = all_num_mask
90
+
91
+ if mask_type == "static":
92
+ lengths = np.full(num_mask, mask_length)
93
+ elif mask_type == "uniform":
94
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
95
+ elif mask_type == "normal":
96
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
97
+ lengths = [max(1, int(round(x))) for x in lengths]
98
+ elif mask_type == "poisson":
99
+ lengths = np.random.poisson(mask_length, size=num_mask)
100
+ lengths = [int(round(x)) for x in lengths]
101
+ else:
102
+ raise Exception("unknown mask selection " + mask_type)
103
+
104
+ if sum(lengths) == 0:
105
+ lengths[0] = min(mask_length, sz - 1)
106
+
107
+ if no_overlap:
108
+ mask_idc = []
109
+
110
+ def arrange(s, e, length, keep_length):
111
+ span_start = np.random.randint(s, e - length)
112
+ mask_idc.extend(span_start + i for i in range(length))
113
+
114
+ new_parts = []
115
+ if span_start - s - min_space >= keep_length:
116
+ new_parts.append((s, span_start - min_space + 1))
117
+ if e - span_start - keep_length - min_space > keep_length:
118
+ new_parts.append((span_start + length + min_space, e))
119
+ return new_parts
120
+
121
+ parts = [(0, sz)]
122
+ min_length = min(lengths)
123
+ for length in sorted(lengths, reverse=True):
124
+ lens = np.fromiter(
125
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
126
+ int,  # np.int was removed in recent NumPy versions; the builtin int is equivalent here
127
+ )
128
+ l_sum = np.sum(lens)
129
+ if l_sum == 0:
130
+ break
131
+ probs = lens / np.sum(lens)
132
+ c = np.random.choice(len(parts), p=probs)
133
+ s, e = parts.pop(c)
134
+ parts.extend(arrange(s, e, length, min_length))
135
+ mask_idc = np.asarray(mask_idc)
136
+ else:
137
+ min_len = min(lengths)
138
+ if sz - min_len <= num_mask:
139
+ min_len = sz - num_mask - 1
140
+
141
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
142
+
143
+ mask_idc = np.asarray(
144
+ [
145
+ mask_idc[j] + offset
146
+ for j in range(len(mask_idc))
147
+ for offset in range(lengths[j])
148
+ ]
149
+ )
150
+
151
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
152
+
153
+ min_len = min([len(m) for m in mask_idcs])
154
+ for i, mask_idc in enumerate(mask_idcs):
155
+ if len(mask_idc) > min_len:
156
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
157
+ mask[i, mask_idc] = True
158
+
159
+ return mask
160
+
161
+
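# Illustrative usage note (added comment, not in the original source): for a batch of
# 2 sequences of 100 frames, compute_mask_indices((2, 100), None, mask_prob=0.65,
# mask_length=10) returns a (2, 100) boolean numpy array whose True entries form spans
# of 10 consecutive frames covering up to roughly 65% of each sequence; WavLM.apply_mask
# uses this mask to decide which frames are replaced by the learned mask embedding.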
162
+ class WavLMConfig:
163
+ def __init__(self, cfg=None):
164
+ self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
165
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
166
+
167
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
168
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
169
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
170
+ self.activation_fn: str = "gelu" # activation function to use
171
+
172
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
173
+ self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
174
+ self.conv_bias: bool = False # include bias in conv encoder
175
+ self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
176
+
177
+ self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
178
+
179
+ # dropouts
180
+ self.dropout: float = 0.1 # dropout probability for the transformer
181
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
182
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
183
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
184
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
185
+ self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
186
+
187
+ # masking
188
+ self.mask_length: int = 10 # mask length
189
+ self.mask_prob: float = 0.65 # probability of replacing a token with mask
190
+ self.mask_selection: str = "static" # how to choose mask length
191
+ self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
192
+ self.no_mask_overlap: bool = False # whether to allow masks to overlap
193
+ self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
194
+
195
+ # channel masking
196
+ self.mask_channel_length: int = 10 # length of the mask for features (channels)
197
+ self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
198
+ self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
199
+ self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
200
+ self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
201
+ self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
202
+
203
+ # positional embeddings
204
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
205
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
206
+
207
+ # relative position embedding
208
+ self.relative_position_embedding: bool = False # apply relative position embedding
209
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
210
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
211
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
212
+
213
+ if cfg is not None:
214
+ self.update(cfg)
215
+
216
+ def update(self, cfg: dict):
217
+ self.__dict__.update(cfg)
218
+
219
+
220
+ class WavLM(nn.Module):
221
+ def __init__(
222
+ self,
223
+ cfg: WavLMConfig,
224
+ ) -> None:
225
+ super().__init__()
226
+ logger.info(f"WavLM Config: {cfg.__dict__}")
227
+
228
+ self.cfg = cfg
229
+ feature_enc_layers = eval(cfg.conv_feature_layers)
230
+ self.embed = feature_enc_layers[-1][0]
231
+
232
+ self.feature_extractor = ConvFeatureExtractionModel(
233
+ conv_layers=feature_enc_layers,
234
+ dropout=0.0,
235
+ mode=cfg.extractor_mode,
236
+ conv_bias=cfg.conv_bias,
237
+ )
238
+
239
+ self.post_extract_proj = (
240
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
241
+ if self.embed != cfg.encoder_embed_dim
242
+ else None
243
+ )
244
+
245
+ self.mask_prob = cfg.mask_prob
246
+ self.mask_selection = cfg.mask_selection
247
+ self.mask_other = cfg.mask_other
248
+ self.mask_length = cfg.mask_length
249
+ self.no_mask_overlap = cfg.no_mask_overlap
250
+ self.mask_min_space = cfg.mask_min_space
251
+
252
+ self.mask_channel_prob = cfg.mask_channel_prob
253
+ self.mask_channel_selection = cfg.mask_channel_selection
254
+ self.mask_channel_other = cfg.mask_channel_other
255
+ self.mask_channel_length = cfg.mask_channel_length
256
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
257
+ self.mask_channel_min_space = cfg.mask_channel_min_space
258
+
259
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
260
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
261
+
262
+ self.feature_grad_mult = cfg.feature_grad_mult
263
+
264
+ self.mask_emb = nn.Parameter(
265
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
266
+ )
267
+
268
+ self.encoder = TransformerEncoder(cfg)
269
+ self.layer_norm = LayerNorm(self.embed)
270
+
271
+ def apply_mask(self, x, padding_mask):
272
+ B, T, C = x.shape
273
+ if self.mask_prob > 0:
274
+ mask_indices = compute_mask_indices(
275
+ (B, T),
276
+ padding_mask,
277
+ self.mask_prob,
278
+ self.mask_length,
279
+ self.mask_selection,
280
+ self.mask_other,
281
+ min_masks=2,
282
+ no_overlap=self.no_mask_overlap,
283
+ min_space=self.mask_min_space,
284
+ )
285
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
286
+ x[mask_indices] = self.mask_emb
287
+ else:
288
+ mask_indices = None
289
+
290
+ if self.mask_channel_prob > 0:
291
+ mask_channel_indices = compute_mask_indices(
292
+ (B, C),
293
+ None,
294
+ self.mask_channel_prob,
295
+ self.mask_channel_length,
296
+ self.mask_channel_selection,
297
+ self.mask_channel_other,
298
+ no_overlap=self.no_mask_channel_overlap,
299
+ min_space=self.mask_channel_min_space,
300
+ )
301
+ mask_channel_indices = (
302
+ torch.from_numpy(mask_channel_indices)
303
+ .to(x.device)
304
+ .unsqueeze(1)
305
+ .expand(-1, T, -1)
306
+ )
307
+ x[mask_channel_indices] = 0
308
+
309
+ return x, mask_indices
310
+
311
+ def forward_padding_mask(
312
+ self, features: torch.Tensor, padding_mask: torch.Tensor,
313
+ ) -> torch.Tensor:
314
+ extra = padding_mask.size(1) % features.size(1)
315
+ if extra > 0:
316
+ padding_mask = padding_mask[:, :-extra]
317
+ padding_mask = padding_mask.view(
318
+ padding_mask.size(0), features.size(1), -1
319
+ )
320
+ padding_mask = padding_mask.all(-1)
321
+ return padding_mask
322
+
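# Note (added comment, not in the original source): with the default conv_feature_layers
# the extractor strides are (5, 2, 2, 2, 2, 2, 2), an overall downsampling factor of 320,
# so forward_padding_mask above regroups the sample-level padding mask into chunks of
# roughly 320 samples and marks a frame as padded only when every sample in its chunk is padded.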
323
+ def extract_features(
324
+ self,
325
+ source: torch.Tensor,
326
+ padding_mask: Optional[torch.Tensor] = None,
327
+ mask: bool = False,
328
+ ret_conv: bool = False,
329
+ output_layer: Optional[int] = None,
330
+ ret_layer_results: bool = False,
331
+ ):
332
+
333
+ if self.feature_grad_mult > 0:
334
+ features = self.feature_extractor(source)
335
+ if self.feature_grad_mult != 1.0:
336
+ features = GradMultiply.apply(features, self.feature_grad_mult)
337
+ else:
338
+ with torch.no_grad():
339
+ features = self.feature_extractor(source)
340
+
341
+ features = features.transpose(1, 2)
342
+ features = self.layer_norm(features)
343
+
344
+ if padding_mask is not None:
345
+ padding_mask = self.forward_padding_mask(features, padding_mask)
346
+
347
+ if self.post_extract_proj is not None:
348
+ features = self.post_extract_proj(features)
349
+
350
+ features = self.dropout_input(features)
351
+
352
+ if mask:
353
+ x, mask_indices = self.apply_mask(
354
+ features, padding_mask
355
+ )
356
+ else:
357
+ x = features
358
+
359
+ # feature: (B, T, D), float
360
+ # target: (B, T), long
361
+ # x: (B, T, D), float
362
+ # padding_mask: (B, T), bool
363
+ # mask_indices: (B, T), bool
364
+ x, layer_results = self.encoder(
365
+ x,
366
+ padding_mask=padding_mask,
367
+ layer=None if output_layer is None else output_layer - 1
368
+ )
369
+
370
+ res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
371
+
372
+ feature = res["features"] if ret_conv else res["x"]
373
+ if ret_layer_results:
374
+ feature = (feature, res["layer_results"])
375
+ return feature, res["padding_mask"]
376
+
377
+
378
+ class ConvFeatureExtractionModel(nn.Module):
379
+ def __init__(
380
+ self,
381
+ conv_layers: List[Tuple[int, int, int]],
382
+ dropout: float = 0.0,
383
+ mode: str = "default",
384
+ conv_bias: bool = False,
385
+ conv_type: str = "default"
386
+ ):
387
+ super().__init__()
388
+
389
+ assert mode in {"default", "layer_norm"}
390
+
391
+ def block(
392
+ n_in,
393
+ n_out,
394
+ k,
395
+ stride,
396
+ is_layer_norm=False,
397
+ is_group_norm=False,
398
+ conv_bias=False,
399
+ ):
400
+ def make_conv():
401
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
402
+ nn.init.kaiming_normal_(conv.weight)
403
+ return conv
404
+
405
+ assert (
406
+ is_layer_norm and is_group_norm
407
+ ) == False, "layer norm and group norm are exclusive"
408
+
409
+ if is_layer_norm:
410
+ return nn.Sequential(
411
+ make_conv(),
412
+ nn.Dropout(p=dropout),
413
+ nn.Sequential(
414
+ TransposeLast(),
415
+ Fp32LayerNorm(dim, elementwise_affine=True),
416
+ TransposeLast(),
417
+ ),
418
+ nn.GELU(),
419
+ )
420
+ elif is_group_norm:
421
+ return nn.Sequential(
422
+ make_conv(),
423
+ nn.Dropout(p=dropout),
424
+ Fp32GroupNorm(dim, dim, affine=True),
425
+ nn.GELU(),
426
+ )
427
+ else:
428
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
429
+
430
+ self.conv_type = conv_type
431
+ if self.conv_type == "default":
432
+ in_d = 1
433
+ self.conv_layers = nn.ModuleList()
434
+ for i, cl in enumerate(conv_layers):
435
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
436
+ (dim, k, stride) = cl
437
+
438
+ self.conv_layers.append(
439
+ block(
440
+ in_d,
441
+ dim,
442
+ k,
443
+ stride,
444
+ is_layer_norm=mode == "layer_norm",
445
+ is_group_norm=mode == "default" and i == 0,
446
+ conv_bias=conv_bias,
447
+ )
448
+ )
449
+ in_d = dim
450
+ elif self.conv_type == "conv2d":
451
+ in_d = 1
452
+ self.conv_layers = nn.ModuleList()
453
+ for i, cl in enumerate(conv_layers):
454
+ assert len(cl) == 3
455
+ (dim, k, stride) = cl
456
+
457
+ self.conv_layers.append(
458
+ torch.nn.Conv2d(in_d, dim, k, stride)
459
+ )
460
+ self.conv_layers.append(torch.nn.ReLU())
461
+ in_d = dim
462
+ elif self.conv_type == "custom":
463
+ in_d = 1
464
+ idim = 80
465
+ self.conv_layers = nn.ModuleList()
466
+ for i, cl in enumerate(conv_layers):
467
+ assert len(cl) == 3
468
+ (dim, k, stride) = cl
469
+ self.conv_layers.append(
470
+ torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
471
+ )
472
+ self.conv_layers.append(
473
+ torch.nn.LayerNorm([dim, idim])
474
+ )
475
+ self.conv_layers.append(torch.nn.ReLU())
476
+ in_d = dim
477
+ if (i + 1) % 2 == 0:
478
+ self.conv_layers.append(
479
+ torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
480
+ )
481
+ idim = int(math.ceil(idim / 2))
482
+ else:
483
+ pass
484
+
485
+ def forward(self, x, mask=None):
486
+
487
+ # BxT -> BxCxT
488
+ x = x.unsqueeze(1)
489
+ if self.conv_type == "custom":
490
+ for conv in self.conv_layers:
491
+ if isinstance(conv, nn.LayerNorm):
492
+ x = x.transpose(1, 2)
493
+ x = conv(x).transpose(1, 2)
494
+ else:
495
+ x = conv(x)
496
+ x = x.transpose(2, 3).contiguous()
497
+ x = x.view(x.size(0), -1, x.size(-1))
498
+ else:
499
+ for conv in self.conv_layers:
500
+ x = conv(x)
501
+ if self.conv_type == "conv2d":
502
+ b, c, t, f = x.size()
503
+ x = x.transpose(2, 3).contiguous().view(b, c * f, t)
504
+ return x
505
+
506
+
507
+ class TransformerEncoder(nn.Module):
508
+ def __init__(self, args):
509
+ super().__init__()
510
+
511
+ self.dropout = args.dropout
512
+ self.embedding_dim = args.encoder_embed_dim
513
+
514
+ self.pos_conv = nn.Conv1d(
515
+ self.embedding_dim,
516
+ self.embedding_dim,
517
+ kernel_size=args.conv_pos,
518
+ padding=args.conv_pos // 2,
519
+ groups=args.conv_pos_groups,
520
+ )
521
+ dropout = 0
522
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
523
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
524
+ nn.init.constant_(self.pos_conv.bias, 0)
525
+
526
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
527
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
528
+
529
+ if hasattr(args, "relative_position_embedding"):
530
+ self.relative_position_embedding = args.relative_position_embedding
531
+ self.num_buckets = args.num_buckets
532
+ self.max_distance = args.max_distance
533
+ else:
534
+ self.relative_position_embedding = False
535
+ self.num_buckets = 0
536
+ self.max_distance = 0
537
+
538
+ self.layers = nn.ModuleList(
539
+ [
540
+ TransformerSentenceEncoderLayer(
541
+ embedding_dim=self.embedding_dim,
542
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
543
+ num_attention_heads=args.encoder_attention_heads,
544
+ dropout=self.dropout,
545
+ attention_dropout=args.attention_dropout,
546
+ activation_dropout=args.activation_dropout,
547
+ activation_fn=args.activation_fn,
548
+ layer_norm_first=args.layer_norm_first,
549
+ has_relative_attention_bias=(self.relative_position_embedding and i == 0),
550
+ num_buckets=self.num_buckets,
551
+ max_distance=self.max_distance,
552
+ gru_rel_pos=args.gru_rel_pos,
553
+ )
554
+ for i in range(args.encoder_layers)
555
+ ]
556
+ )
557
+
558
+ self.layer_norm_first = args.layer_norm_first
559
+ self.layer_norm = LayerNorm(self.embedding_dim)
560
+ self.layerdrop = args.encoder_layerdrop
561
+
562
+ self.apply(init_bert_params)
563
+
564
+ def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
565
+ x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
566
+
567
+ if self.layer_norm_first and layer is None:
568
+ x = self.layer_norm(x)
569
+
570
+ return x, layer_results
571
+
572
+ def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
573
+
574
+ if padding_mask is not None:
575
+ x[padding_mask] = 0
576
+
577
+ x_conv = self.pos_conv(x.transpose(1, 2))
578
+ x_conv = x_conv.transpose(1, 2)
579
+ x += x_conv
580
+
581
+ if not self.layer_norm_first:
582
+ x = self.layer_norm(x)
583
+
584
+ x = F.dropout(x, p=self.dropout, training=self.training)
585
+
586
+ # B x T x C -> T x B x C
587
+ x = x.transpose(0, 1)
588
+
589
+ layer_results = []
590
+ z = None
591
+ if tgt_layer is not None:
592
+ layer_results.append((x, z))
593
+ r = None
594
+ pos_bias = None
595
+ for i, layer in enumerate(self.layers):
596
+ dropout_probability = np.random.random()
597
+ if not self.training or (dropout_probability > self.layerdrop):
598
+ x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False,
599
+ self_attn_mask=streaming_mask, pos_bias=pos_bias)
600
+ if tgt_layer is not None:
601
+ layer_results.append((x, z))
602
+ if i == tgt_layer:
603
+ r = x
604
+ break
605
+
606
+ if r is not None:
607
+ x = r
608
+
609
+ # T x B x C -> B x T x C
610
+ x = x.transpose(0, 1)
611
+
612
+ return x, layer_results
613
+
614
+
615
+ class TransformerSentenceEncoderLayer(nn.Module):
616
+ """
617
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
618
+ models.
619
+ """
620
+
621
+ def __init__(
622
+ self,
623
+ embedding_dim: float = 768,
624
+ ffn_embedding_dim: float = 3072,
625
+ num_attention_heads: float = 8,
626
+ dropout: float = 0.1,
627
+ attention_dropout: float = 0.1,
628
+ activation_dropout: float = 0.1,
629
+ activation_fn: str = "relu",
630
+ layer_norm_first: bool = False,
631
+ has_relative_attention_bias: bool = False,
632
+ num_buckets: int = 0,
633
+ max_distance: int = 0,
634
+ rescale_init: bool = False,
635
+ gru_rel_pos: bool = False,
636
+ ) -> None:
637
+
638
+ super().__init__()
639
+ # Initialize parameters
640
+ self.embedding_dim = embedding_dim
641
+ self.dropout = dropout
642
+ self.activation_dropout = activation_dropout
643
+
644
+ # Initialize blocks
645
+ self.activation_name = activation_fn
646
+ self.activation_fn = get_activation_fn(activation_fn)
647
+ self.self_attn = MultiheadAttention(
648
+ self.embedding_dim,
649
+ num_attention_heads,
650
+ dropout=attention_dropout,
651
+ self_attention=True,
652
+ has_relative_attention_bias=has_relative_attention_bias,
653
+ num_buckets=num_buckets,
654
+ max_distance=max_distance,
655
+ rescale_init=rescale_init,
656
+ gru_rel_pos=gru_rel_pos,
657
+ )
658
+
659
+ self.dropout1 = nn.Dropout(dropout)
660
+ self.dropout2 = nn.Dropout(self.activation_dropout)
661
+ self.dropout3 = nn.Dropout(dropout)
662
+
663
+ self.layer_norm_first = layer_norm_first
664
+
665
+ # layer norm associated with the self attention layer
666
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
667
+
668
+ if self.activation_name == "glu":
669
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
670
+ else:
671
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
672
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
673
+
674
+ # layer norm associated with the position wise feed-forward NN
675
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
676
+
677
+ def forward(
678
+ self,
679
+ x: torch.Tensor,
680
+ self_attn_mask: torch.Tensor = None,
681
+ self_attn_padding_mask: torch.Tensor = None,
682
+ need_weights: bool = False,
683
+ pos_bias=None
684
+ ):
685
+ """
686
+ LayerNorm is applied either before or after the self-attention/ffn
687
+ modules, similar to the original Transformer implementation.
688
+ """
689
+ residual = x
690
+
691
+ if self.layer_norm_first:
692
+ x = self.self_attn_layer_norm(x)
693
+ x, attn, pos_bias = self.self_attn(
694
+ query=x,
695
+ key=x,
696
+ value=x,
697
+ key_padding_mask=self_attn_padding_mask,
698
+ need_weights=False,
699
+ attn_mask=self_attn_mask,
700
+ position_bias=pos_bias
701
+ )
702
+ x = self.dropout1(x)
703
+ x = residual + x
704
+
705
+ residual = x
706
+ x = self.final_layer_norm(x)
707
+ if self.activation_name == "glu":
708
+ x = self.fc1(x)
709
+ else:
710
+ x = self.activation_fn(self.fc1(x))
711
+ x = self.dropout2(x)
712
+ x = self.fc2(x)
713
+ x = self.dropout3(x)
714
+ x = residual + x
715
+ else:
716
+ x, attn, pos_bias = self.self_attn(
717
+ query=x,
718
+ key=x,
719
+ value=x,
720
+ key_padding_mask=self_attn_padding_mask,
721
+ need_weights=need_weights,
722
+ attn_mask=self_attn_mask,
723
+ position_bias=pos_bias
724
+ )
725
+
726
+ x = self.dropout1(x)
727
+ x = residual + x
728
+
729
+ x = self.self_attn_layer_norm(x)
730
+
731
+ residual = x
732
+ if self.activation_name == "glu":
733
+ x = self.fc1(x)
734
+ else:
735
+ x = self.activation_fn(self.fc1(x))
736
+ x = self.dropout2(x)
737
+ x = self.fc2(x)
738
+ x = self.dropout3(x)
739
+ x = residual + x
740
+ x = self.final_layer_norm(x)
741
+
742
+ return x, attn, pos_bias
743
+
UniSpeech/WavLM/WavLM_ASR.PNG ADDED

Git LFS Details

  • SHA256: 4ed0c965745753b1d78fc15ac47a020423fb280a0c819198f78b5737973c20c7
  • Pointer size: 131 Bytes
  • Size of remote file: 108 kB
UniSpeech/WavLM/WavLM_SUPERB_Leaderboard.png ADDED

Git LFS Details

  • SHA256: 4c90b744bf4488726fb905d794ccc74994406b27cabf9dab1e50279b81d274b7
  • Pointer size: 131 Bytes
  • Size of remote file: 144 kB
UniSpeech/WavLM/WavLM_SUPERB_Results.png ADDED

Git LFS Details

  • SHA256: 492980df15c17ecdbd0d6696ea79e9648432c8667c6e426cf35a50bf300592a6
  • Pointer size: 131 Bytes
  • Size of remote file: 474 kB
UniSpeech/WavLM/modules.py ADDED
@@ -0,0 +1,827 @@
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import warnings
12
+ from typing import Dict, Optional, Tuple
13
+ import torch
14
+ from torch import Tensor, nn
15
+ from torch.nn import Parameter
16
+ import torch.nn.functional as F
17
+
18
+
19
+ class TransposeLast(nn.Module):
20
+ def __init__(self, deconstruct_idx=None):
21
+ super().__init__()
22
+ self.deconstruct_idx = deconstruct_idx
23
+
24
+ def forward(self, x):
25
+ if self.deconstruct_idx is not None:
26
+ x = x[self.deconstruct_idx]
27
+ return x.transpose(-2, -1)
28
+
29
+
30
+ class Fp32LayerNorm(nn.LayerNorm):
31
+ def __init__(self, *args, **kwargs):
32
+ super().__init__(*args, **kwargs)
33
+
34
+ def forward(self, input):
35
+ output = F.layer_norm(
36
+ input.float(),
37
+ self.normalized_shape,
38
+ self.weight.float() if self.weight is not None else None,
39
+ self.bias.float() if self.bias is not None else None,
40
+ self.eps,
41
+ )
42
+ return output.type_as(input)
43
+
44
+
45
+ class Fp32GroupNorm(nn.GroupNorm):
46
+ def __init__(self, *args, **kwargs):
47
+ super().__init__(*args, **kwargs)
48
+
49
+ def forward(self, input):
50
+ output = F.group_norm(
51
+ input.float(),
52
+ self.num_groups,
53
+ self.weight.float() if self.weight is not None else None,
54
+ self.bias.float() if self.bias is not None else None,
55
+ self.eps,
56
+ )
57
+ return output.type_as(input)
58
+
59
+
60
+ class GradMultiply(torch.autograd.Function):
61
+ @staticmethod
62
+ def forward(ctx, x, scale):
63
+ ctx.scale = scale
64
+ res = x.new(x)
65
+ return res
66
+
67
+ @staticmethod
68
+ def backward(ctx, grad):
69
+ return grad * ctx.scale, None
70
+
71
+
72
+ class SamePad(nn.Module):
73
+ def __init__(self, kernel_size, causal=False):
74
+ super().__init__()
75
+ if causal:
76
+ self.remove = kernel_size - 1
77
+ else:
78
+ self.remove = 1 if kernel_size % 2 == 0 else 0
79
+
80
+ def forward(self, x):
81
+ if self.remove > 0:
82
+ x = x[:, :, : -self.remove]
83
+ return x
84
+
85
+
86
+ class Swish(nn.Module):
87
+ """Swish function
88
+ """
89
+
90
+ def __init__(self):
91
+ """Construct a Swish activation module."""
92
+ super(Swish, self).__init__()
93
+ self.act = torch.nn.Sigmoid()
94
+
95
+ def forward(self, x):
96
+ return x * self.act(x)
97
+
98
+
99
+ class GLU_Linear(nn.Module):
100
+ def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
101
+ super(GLU_Linear, self).__init__()
102
+
103
+ self.glu_type = glu_type
104
+ self.output_dim = output_dim
105
+
106
+ if glu_type == "sigmoid":
107
+ self.glu_act = torch.nn.Sigmoid()
108
+ elif glu_type == "swish":
109
+ self.glu_act = Swish()
110
+ elif glu_type == "relu":
111
+ self.glu_act = torch.nn.ReLU()
112
+ elif glu_type == "gelu":
113
+ self.glu_act = torch.nn.GELU()
114
+
115
+ if bias_in_glu:
116
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
117
+ else:
118
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
119
+
120
+ def forward(self, x):
121
+ # To be consistent with GLU_Linear, we assume the input always has the channel (dim) in the last dimension of the tensor, so the dimensions would need to be switched first in the 1D-conv case
122
+ x = self.linear(x)
123
+
124
+ if self.glu_type == "bilinear":
125
+ x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
126
+ else:
127
+ x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
128
+
129
+ return x
130
+
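# Illustrative note (added comment, not in the original source): GLU_Linear projects the
# input to 2 * output_dim and gates the first half with an activation of the second half,
# i.e. out = x_a * act(x_b). With input_dim=768 and output_dim=3072 the inner Linear is
# 768 -> 6144 and the returned tensor has 3072 channels.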
131
+
132
+ def gelu_accurate(x):
133
+ if not hasattr(gelu_accurate, "_a"):
134
+ gelu_accurate._a = math.sqrt(2 / math.pi)
135
+ return (
136
+ 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
137
+ )
138
+
139
+
140
+ def gelu(x: torch.Tensor) -> torch.Tensor:
141
+ return torch.nn.functional.gelu(x.float()).type_as(x)
142
+
143
+
144
+ def get_activation_fn(activation: str):
145
+ """Returns the activation function corresponding to `activation`"""
146
+
147
+ if activation == "relu":
148
+ return F.relu
149
+ elif activation == "gelu":
150
+ return gelu
151
+ elif activation == "gelu_fast":
152
+ warnings.warn(
153
+ "--activation-fn=gelu_fast has been renamed to gelu_accurate"
154
+ )
155
+ return gelu_accurate
156
+ elif activation == "gelu_accurate":
157
+ return gelu_accurate
158
+ elif activation == "tanh":
159
+ return torch.tanh
160
+ elif activation == "linear":
161
+ return lambda x: x
162
+ elif activation == "glu":
163
+ return lambda x: x
164
+ else:
165
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
166
+
167
+
168
+ def init_bert_params(module):
169
+ """
170
+ Initialize the weights specific to the BERT Model.
171
+ This overrides the default initializations depending on the specified arguments.
172
+ 1. If normal_init_linear_weights is set then weights of linear
173
+ layer will be initialized using the normal distribution and
174
+ bais will be set to the specified value.
175
+ 2. If normal_init_embed_weights is set then weights of embedding
176
+ layer will be initialized using the normal distribution.
177
+ 3. If normal_init_proj_weights is set then weights of
178
+ in_project_weight for MultiHeadAttention initialized using
179
+ the normal distribution (to be validated).
180
+ """
181
+
182
+ def normal_(data):
183
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
184
+ # so that the RNG is consistent with and without FSDP
185
+ data.copy_(
186
+ data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
187
+ )
188
+
189
+ if isinstance(module, nn.Linear):
190
+ normal_(module.weight.data)
191
+ if module.bias is not None:
192
+ module.bias.data.zero_()
193
+ if isinstance(module, nn.Embedding):
194
+ normal_(module.weight.data)
195
+ if module.padding_idx is not None:
196
+ module.weight.data[module.padding_idx].zero_()
197
+ if isinstance(module, MultiheadAttention):
198
+ normal_(module.q_proj.weight.data)
199
+ normal_(module.k_proj.weight.data)
200
+ normal_(module.v_proj.weight.data)
201
+
202
+
203
+ def quant_noise(module, p, block_size):
204
+ """
205
+ Wraps modules and applies quantization noise to the weights for
206
+ subsequent quantization with Iterative Product Quantization as
207
+ described in "Training with Quantization Noise for Extreme Model Compression"
208
+
209
+ Args:
210
+ - module: nn.Module
211
+ - p: amount of Quantization Noise
212
+ - block_size: size of the blocks for subsequent quantization with iPQ
213
+
214
+ Remarks:
215
+ - Module weights must have the right sizes wrt the block size
216
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
217
+ - For more detail on how to quantize by blocks with convolutional weights,
218
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
219
+ - We implement the simplest form of noise here as stated in the paper
220
+ which consists in randomly dropping blocks
221
+ """
222
+
223
+ # if no quantization noise, don't register hook
224
+ if p <= 0:
225
+ return module
226
+
227
+ # supported modules
228
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
229
+
230
+ # test whether module.weight has the right sizes wrt block_size
231
+ is_conv = module.weight.ndim == 4
232
+
233
+ # 2D matrix
234
+ if not is_conv:
235
+ assert (
236
+ module.weight.size(1) % block_size == 0
237
+ ), "Input features must be a multiple of block sizes"
238
+
239
+ # 4D matrix
240
+ else:
241
+ # 1x1 convolutions
242
+ if module.kernel_size == (1, 1):
243
+ assert (
244
+ module.in_channels % block_size == 0
245
+ ), "Input channels must be a multiple of block sizes"
246
+ # regular convolutions
247
+ else:
248
+ k = module.kernel_size[0] * module.kernel_size[1]
249
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
250
+
251
+ def _forward_pre_hook(mod, input):
252
+ # no noise for evaluation
253
+ if mod.training:
254
+ if not is_conv:
255
+ # gather weight and sizes
256
+ weight = mod.weight
257
+ in_features = weight.size(1)
258
+ out_features = weight.size(0)
259
+
260
+ # split weight matrix into blocks and randomly drop selected blocks
261
+ mask = torch.zeros(
262
+ in_features // block_size * out_features, device=weight.device
263
+ )
264
+ mask.bernoulli_(p)
265
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
266
+
267
+ else:
268
+ # gather weight and sizes
269
+ weight = mod.weight
270
+ in_channels = mod.in_channels
271
+ out_channels = mod.out_channels
272
+
273
+ # split weight matrix into blocks and randomly drop selected blocks
274
+ if mod.kernel_size == (1, 1):
275
+ mask = torch.zeros(
276
+ int(in_channels // block_size * out_channels),
277
+ device=weight.device,
278
+ )
279
+ mask.bernoulli_(p)
280
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
281
+ else:
282
+ mask = torch.zeros(
283
+ weight.size(0), weight.size(1), device=weight.device
284
+ )
285
+ mask.bernoulli_(p)
286
+ mask = (
287
+ mask.unsqueeze(2)
288
+ .unsqueeze(3)
289
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
290
+ )
291
+
292
+ # scale weights and apply mask
293
+ mask = mask.to(
294
+ torch.bool
295
+ ) # x.bool() is not currently supported in TorchScript
296
+ s = 1 / (1 - p)
297
+ mod.weight.data = s * weight.masked_fill(mask, 0)
298
+
299
+ module.register_forward_pre_hook(_forward_pre_hook)
300
+ return module
301
+
302
+
303
+ class MultiheadAttention(nn.Module):
304
+ """Multi-headed attention.
305
+
306
+ See "Attention Is All You Need" for more details.
307
+ """
308
+
309
+ def __init__(
310
+ self,
311
+ embed_dim,
312
+ num_heads,
313
+ kdim=None,
314
+ vdim=None,
315
+ dropout=0.0,
316
+ bias=True,
317
+ add_bias_kv=False,
318
+ add_zero_attn=False,
319
+ self_attention=False,
320
+ encoder_decoder_attention=False,
321
+ q_noise=0.0,
322
+ qn_block_size=8,
323
+ has_relative_attention_bias=False,
324
+ num_buckets=32,
325
+ max_distance=128,
326
+ gru_rel_pos=False,
327
+ rescale_init=False,
328
+ ):
329
+ super().__init__()
330
+ self.embed_dim = embed_dim
331
+ self.kdim = kdim if kdim is not None else embed_dim
332
+ self.vdim = vdim if vdim is not None else embed_dim
333
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
334
+
335
+ self.num_heads = num_heads
336
+ self.dropout_module = nn.Dropout(dropout)
337
+
338
+ self.has_relative_attention_bias = has_relative_attention_bias
339
+ self.num_buckets = num_buckets
340
+ self.max_distance = max_distance
341
+ if self.has_relative_attention_bias:
342
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
343
+
344
+ self.head_dim = embed_dim // num_heads
345
+ self.q_head_dim = self.head_dim
346
+ self.k_head_dim = self.head_dim
347
+ assert (
348
+ self.head_dim * num_heads == self.embed_dim
349
+ ), "embed_dim must be divisible by num_heads"
350
+ self.scaling = self.head_dim ** -0.5
351
+
352
+ self.self_attention = self_attention
353
+ self.encoder_decoder_attention = encoder_decoder_attention
354
+
355
+ assert not self.self_attention or self.qkv_same_dim, (
356
+ "Self-attention requires query, key and " "value to be of the same size"
357
+ )
358
+
359
+ k_bias = True
360
+ if rescale_init:
361
+ k_bias = False
362
+
363
+ k_embed_dim = embed_dim
364
+ q_embed_dim = embed_dim
365
+
366
+ self.k_proj = quant_noise(
367
+ nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
368
+ )
369
+ self.v_proj = quant_noise(
370
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
371
+ )
372
+ self.q_proj = quant_noise(
373
+ nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
374
+ )
375
+
376
+ self.out_proj = quant_noise(
377
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
378
+ )
379
+
380
+ if add_bias_kv:
381
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
382
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
383
+ else:
384
+ self.bias_k = self.bias_v = None
385
+
386
+ self.add_zero_attn = add_zero_attn
387
+
388
+ self.gru_rel_pos = gru_rel_pos
389
+ if self.gru_rel_pos:
390
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
391
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
392
+
393
+ self.reset_parameters()
394
+
395
+ def reset_parameters(self):
396
+ if self.qkv_same_dim:
397
+ # Empirically observed the convergence to be much better with
398
+ # the scaled initialization
399
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
400
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
401
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
402
+ else:
403
+ nn.init.xavier_uniform_(self.k_proj.weight)
404
+ nn.init.xavier_uniform_(self.v_proj.weight)
405
+ nn.init.xavier_uniform_(self.q_proj.weight)
406
+
407
+ nn.init.xavier_uniform_(self.out_proj.weight)
408
+ if self.out_proj.bias is not None:
409
+ nn.init.constant_(self.out_proj.bias, 0.0)
410
+ if self.bias_k is not None:
411
+ nn.init.xavier_normal_(self.bias_k)
412
+ if self.bias_v is not None:
413
+ nn.init.xavier_normal_(self.bias_v)
414
+ if self.has_relative_attention_bias:
415
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
416
+
417
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
418
+ num_buckets = self.num_buckets
419
+ max_distance = self.max_distance
420
+ relative_buckets = 0
421
+
422
+ if bidirectional:
423
+ num_buckets = num_buckets // 2
424
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
425
+ relative_positions = torch.abs(relative_positions)
426
+ else:
427
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
428
+
429
+ max_exact = num_buckets // 2
430
+ is_small = relative_positions < max_exact
431
+
432
+ relative_postion_if_large = max_exact + (
433
+ torch.log(relative_positions.float() / max_exact)
434
+ / math.log(max_distance / max_exact)
435
+ * (num_buckets - max_exact)
436
+ ).to(torch.long)
437
+ relative_postion_if_large = torch.min(
438
+ relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
439
+ )
440
+
441
+ relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
442
+ return relative_buckets
443
+
444
+ def compute_bias(self, query_length, key_length):
445
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
446
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
447
+ relative_position = memory_position - context_position
448
+ relative_position_bucket = self._relative_positions_bucket(
449
+ relative_position,
450
+ bidirectional=True
451
+ )
452
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
453
+ values = self.relative_attention_bias(relative_position_bucket)
454
+ values = values.permute([2, 0, 1])
455
+ return values
456
+
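# Illustrative note (added comment, not in the original source): this is T5-style relative
# position bucketing. With num_buckets=32 and bidirectional=True there are 16 buckets per
# direction: distances 0..7 each get their own bucket and the remaining 8 buckets cover
# larger distances on a log scale up to max_distance, beyond which everything is clamped
# into the last bucket, so e.g. distances 3 and 20 land in different buckets while 200 and
# 250 (with max_distance=128) share one.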
457
+ def forward(
458
+ self,
459
+ query,
460
+ key: Optional[Tensor],
461
+ value: Optional[Tensor],
462
+ key_padding_mask: Optional[Tensor] = None,
463
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
464
+ need_weights: bool = True,
465
+ static_kv: bool = False,
466
+ attn_mask: Optional[Tensor] = None,
467
+ before_softmax: bool = False,
468
+ need_head_weights: bool = False,
469
+ position_bias: Optional[Tensor] = None
470
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
471
+ """Input shape: Time x Batch x Channel
472
+
473
+ Args:
474
+ key_padding_mask (ByteTensor, optional): mask to exclude
475
+ keys that are pads, of shape `(batch, src_len)`, where
476
+ padding elements are indicated by 1s.
477
+ need_weights (bool, optional): return the attention weights,
478
+ averaged over heads (default: False).
479
+ attn_mask (ByteTensor, optional): typically used to
480
+ implement causal attention, where the mask prevents the
481
+ attention from looking forward in time (default: None).
482
+ before_softmax (bool, optional): return the raw attention
483
+ weights and values before the attention softmax.
484
+ need_head_weights (bool, optional): return the attention
485
+ weights for each head. Implies *need_weights*. Default:
486
+ return the average attention weights over all heads.
487
+ """
488
+ if need_head_weights:
489
+ need_weights = True
490
+
491
+ is_tpu = query.device.type == "xla"
492
+
493
+ tgt_len, bsz, embed_dim = query.size()
494
+ src_len = tgt_len
495
+ assert embed_dim == self.embed_dim
496
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
497
+ if key is not None:
498
+ src_len, key_bsz, _ = key.size()
499
+ if not torch.jit.is_scripting():
500
+ assert key_bsz == bsz
501
+ assert value is not None
502
+ assert value.shape[:2] == (src_len, bsz)
503
+
504
+ if self.has_relative_attention_bias and position_bias is None:
505
+ position_bias = self.compute_bias(tgt_len, src_len)
506
+ position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
507
+
508
+ if (
509
+ not is_tpu # don't use PyTorch version on TPUs
510
+ and incremental_state is None
511
+ and not static_kv
512
+ # A workaround for quantization to work. Otherwise JIT compilation
513
+ # treats bias in linear module as method.
514
+ and not torch.jit.is_scripting()
515
+ and self.q_head_dim == self.head_dim
516
+ ):
517
+ assert key is not None and value is not None
518
+ assert attn_mask is None
519
+
520
+ attn_mask_rel_pos = None
521
+ if position_bias is not None:
522
+ attn_mask_rel_pos = position_bias
523
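+ # Gated relative position bias: sigmoid gates computed from the query rescale the shared position_bias per head before it is passed to the fused attention kernel as an additive mask.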
+ if self.gru_rel_pos:
524
+ query_layer = query.transpose(0, 1)
525
+ new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
526
+ query_layer = query_layer.view(*new_x_shape)
527
+ query_layer = query_layer.permute(0, 2, 1, 3)
528
+ _B, _H, _L, __ = query_layer.size()
529
+
530
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
531
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
532
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
533
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
534
+
535
+ attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
536
+ k_proj_bias = self.k_proj.bias
537
+ if k_proj_bias is None:
538
+ k_proj_bias = torch.zeros_like(self.q_proj.bias)
539
+
540
+ x, attn = F.multi_head_attention_forward(
541
+ query,
542
+ key,
543
+ value,
544
+ self.embed_dim,
545
+ self.num_heads,
546
+ torch.empty([0]),
547
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
548
+ self.bias_k,
549
+ self.bias_v,
550
+ self.add_zero_attn,
551
+ self.dropout_module.p,
552
+ self.out_proj.weight,
553
+ self.out_proj.bias,
554
+ self.training,
555
+ # self.training or self.dropout_module.apply_during_inference,
556
+ key_padding_mask,
557
+ need_weights,
558
+ attn_mask_rel_pos,
559
+ use_separate_proj_weight=True,
560
+ q_proj_weight=self.q_proj.weight,
561
+ k_proj_weight=self.k_proj.weight,
562
+ v_proj_weight=self.v_proj.weight,
563
+ )
564
+ return x, attn, position_bias
565
+
566
+ if incremental_state is not None:
567
+ saved_state = self._get_input_buffer(incremental_state)
568
+ if saved_state is not None and "prev_key" in saved_state:
569
+ # previous time steps are cached - no need to recompute
570
+ # key and value if they are static
571
+ if static_kv:
572
+ assert self.encoder_decoder_attention and not self.self_attention
573
+ key = value = None
574
+ else:
575
+ saved_state = None
576
+
577
+ if self.self_attention:
578
+ q = self.q_proj(query)
579
+ k = self.k_proj(query)
580
+ v = self.v_proj(query)
581
+ elif self.encoder_decoder_attention:
582
+ # encoder-decoder attention
583
+ q = self.q_proj(query)
584
+ if key is None:
585
+ assert value is None
586
+ k = v = None
587
+ else:
588
+ k = self.k_proj(key)
589
+ v = self.v_proj(key)
590
+
591
+ else:
592
+ assert key is not None and value is not None
593
+ q = self.q_proj(query)
594
+ k = self.k_proj(key)
595
+ v = self.v_proj(value)
596
+ q *= self.scaling
597
+
598
+ if self.bias_k is not None:
599
+ assert self.bias_v is not None
600
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
601
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
602
+ if attn_mask is not None:
603
+ attn_mask = torch.cat(
604
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
605
+ )
606
+ if key_padding_mask is not None:
607
+ key_padding_mask = torch.cat(
608
+ [
609
+ key_padding_mask,
610
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
611
+ ],
612
+ dim=1,
613
+ )
614
+
615
+ q = (
616
+ q.contiguous()
617
+ .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
618
+ .transpose(0, 1)
619
+ )
620
+ if k is not None:
621
+ k = (
622
+ k.contiguous()
623
+ .view(-1, bsz * self.num_heads, self.k_head_dim)
624
+ .transpose(0, 1)
625
+ )
626
+ if v is not None:
627
+ v = (
628
+ v.contiguous()
629
+ .view(-1, bsz * self.num_heads, self.head_dim)
630
+ .transpose(0, 1)
631
+ )
632
+
633
+ if saved_state is not None:
634
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
635
+ if "prev_key" in saved_state:
636
+ _prev_key = saved_state["prev_key"]
637
+ assert _prev_key is not None
638
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
639
+ if static_kv:
640
+ k = prev_key
641
+ else:
642
+ assert k is not None
643
+ k = torch.cat([prev_key, k], dim=1)
644
+ src_len = k.size(1)
645
+ if "prev_value" in saved_state:
646
+ _prev_value = saved_state["prev_value"]
647
+ assert _prev_value is not None
648
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
649
+ if static_kv:
650
+ v = prev_value
651
+ else:
652
+ assert v is not None
653
+ v = torch.cat([prev_value, v], dim=1)
654
+ prev_key_padding_mask: Optional[Tensor] = None
655
+ if "prev_key_padding_mask" in saved_state:
656
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
657
+ assert k is not None and v is not None
658
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
659
+ key_padding_mask=key_padding_mask,
660
+ prev_key_padding_mask=prev_key_padding_mask,
661
+ batch_size=bsz,
662
+ src_len=k.size(1),
663
+ static_kv=static_kv,
664
+ )
665
+
666
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
667
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
668
+ saved_state["prev_key_padding_mask"] = key_padding_mask
669
+ # In this branch incremental_state is never None
670
+ assert incremental_state is not None
671
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
672
+ assert k is not None
673
+ assert k.size(1) == src_len
674
+
675
+ # This is part of a workaround to get around fork/join parallelism
676
+ # not supporting Optional types.
677
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
678
+ key_padding_mask = None
679
+
680
+ if key_padding_mask is not None:
681
+ assert key_padding_mask.size(0) == bsz
682
+ assert key_padding_mask.size(1) == src_len
683
+
684
+ if self.add_zero_attn:
685
+ assert v is not None
686
+ src_len += 1
687
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
688
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
689
+ if attn_mask is not None:
690
+ attn_mask = torch.cat(
691
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
692
+ )
693
+ if key_padding_mask is not None:
694
+ key_padding_mask = torch.cat(
695
+ [
696
+ key_padding_mask,
697
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
698
+ key_padding_mask
699
+ ),
700
+ ],
701
+ dim=1,
702
+ )
703
+
704
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
705
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
706
+
707
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
708
+
709
+ if attn_mask is not None:
710
+ attn_mask = attn_mask.unsqueeze(0)
711
+ attn_weights += attn_mask
712
+
713
+ if key_padding_mask is not None:
714
+ # don't attend to padding symbols
715
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
716
+ if not is_tpu:
717
+ attn_weights = attn_weights.masked_fill(
718
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
719
+ float("-inf"),
720
+ )
721
+ else:
722
+ attn_weights = attn_weights.transpose(0, 2)
723
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
724
+ attn_weights = attn_weights.transpose(0, 2)
725
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
726
+
727
+ if before_softmax:
728
+ return attn_weights, v, position_bias
729
+
730
+ if position_bias is not None:
731
+ if self.gru_rel_pos == 1:
732
+ query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
733
+ _B, _H, _L, __ = query_layer.size()
734
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
735
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
736
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
737
+ position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
738
+
739
+ position_bias = position_bias.view(attn_weights.size())
740
+
741
+ attn_weights = attn_weights + position_bias
742
+
743
+ attn_weights_float = F.softmax(
744
+ attn_weights, dim=-1
745
+ )
746
+ attn_weights = attn_weights_float.type_as(attn_weights)
747
+ attn_probs = self.dropout_module(attn_weights)
748
+
749
+ assert v is not None
750
+ attn = torch.bmm(attn_probs, v)
751
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
752
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
753
+ attn = self.out_proj(attn)
754
+ attn_weights: Optional[Tensor] = None
755
+ if need_weights:
756
+ attn_weights = attn_weights_float.view(
757
+ bsz, self.num_heads, tgt_len, src_len
758
+ ).transpose(1, 0)
759
+ if not need_head_weights:
760
+ # average attention weights over heads
761
+ attn_weights = attn_weights.mean(dim=0)
762
+
763
+ return attn, attn_weights, position_bias
764
+
765
+ @staticmethod
766
+ def _append_prev_key_padding_mask(
767
+ key_padding_mask: Optional[Tensor],
768
+ prev_key_padding_mask: Optional[Tensor],
769
+ batch_size: int,
770
+ src_len: int,
771
+ static_kv: bool,
772
+ ) -> Optional[Tensor]:
773
+ # saved key padding masks have shape (bsz, seq_len)
774
+ if prev_key_padding_mask is not None and static_kv:
775
+ new_key_padding_mask = prev_key_padding_mask
776
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
777
+ new_key_padding_mask = torch.cat(
778
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
779
+ )
780
+ # During incremental decoding, as the padding token enters and
781
+ # leaves the frame, there will be a time when prev or current
782
+ # is None
783
+ elif prev_key_padding_mask is not None:
784
+ if src_len > prev_key_padding_mask.size(1):
785
+ filler = torch.zeros(
786
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
787
+ device=prev_key_padding_mask.device,
788
+ )
789
+ new_key_padding_mask = torch.cat(
790
+ [prev_key_padding_mask.float(), filler.float()], dim=1
791
+ )
792
+ else:
793
+ new_key_padding_mask = prev_key_padding_mask.float()
794
+ elif key_padding_mask is not None:
795
+ if src_len > key_padding_mask.size(1):
796
+ filler = torch.zeros(
797
+ (batch_size, src_len - key_padding_mask.size(1)),
798
+ device=key_padding_mask.device,
799
+ )
800
+ new_key_padding_mask = torch.cat(
801
+ [filler.float(), key_padding_mask.float()], dim=1
802
+ )
803
+ else:
804
+ new_key_padding_mask = key_padding_mask.float()
805
+ else:
806
+ new_key_padding_mask = prev_key_padding_mask
807
+ return new_key_padding_mask
808
+
809
+ def _get_input_buffer(
810
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
811
+ ) -> Dict[str, Optional[Tensor]]:
812
+ result = self.get_incremental_state(incremental_state, "attn_state")
813
+ if result is not None:
814
+ return result
815
+ else:
816
+ empty_result: Dict[str, Optional[Tensor]] = {}
817
+ return empty_result
818
+
819
+ def _set_input_buffer(
820
+ self,
821
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
822
+ buffer: Dict[str, Optional[Tensor]],
823
+ ):
824
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
825
+
826
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
827
+ return attn_weights
UniSpeech/azure-pipelines.yml ADDED
@@ -0,0 +1,47 @@
1
+ trigger:
2
+ - master
3
+
4
+ pool:
5
+ vmImage: 'windows-latest'
6
+
7
+ steps:
8
+ - script: echo Hello, world!
9
+ displayName: 'Run a one-line script'
10
+
11
+ - script: |
12
+ echo Add other tasks to build, test, and deploy your project.
13
+ echo See https://aka.ms/yaml
14
+ displayName: 'Run a multi-line script'
15
+
16
+ - task: CredScan@2
17
+ inputs:
18
+ toolMajorVersion: 'V2'
19
+
20
+
21
+ - task: Semmle@0
22
+ env:
23
+ SYSTEM_ACCESSTOKEN: $(PATToken)
24
+ inputs:
25
+ sourceCodeDirectory: '$(Build.SourcesDirectory)'
26
+ language: 'python'
27
+ includeNodeModules: true
28
+ querySuite: 'Recommended'
29
+ timeout: '1800'
30
+ ram: '16384'
31
+ addProjectDirToScanningExclusionList: true
32
+
33
+ - task: ComponentGovernanceComponentDetection@0
34
+ inputs:
35
+ scanType: 'Register'
36
+ verbosity: 'Verbose'
37
+ alertWarningLevel: 'High'
38
+
39
+ - task: PublishSecurityAnalysisLogs@2
40
+ inputs:
41
+ ArtifactName: 'CodeAnalysisLogs'
42
+ ArtifactType: 'Container'
43
+ AllTools: true
44
+ ToolLogsNotFoundAction: 'Standard'
45
+
46
+
47
+
UniSpeech/downstreams/speaker_diarization/README.md ADDED
@@ -0,0 +1,34 @@
1
+ ## Pre-training Representations for Speaker Diarization
2
+
3
+ ### Downstream Model
4
+
5
+ [EEND-vector-clustering](https://arxiv.org/abs/2105.09040)
6
+
7
+ ### Pre-trained models
8
+
9
+ - Note that the diarization system is trained on 8 kHz audio data.
10
+
11
+ | Model | 2 spk DER | 3 spk DER | 4 spk DER | 5 spk DER | 6 spk DER | ALL spk DER |
12
+ | ------------------------------------------------------------ | --------- | --------- | --------- | --------- | --------- | ----------- |
13
+ | EEND-vector-clustering | 7.96 | 11.93 | 16.38 | 21.21 | 23.1 | 12.49 |
14
+ | [**UniSpeech-SAT large**](https://drive.google.com/file/d/16OwIyOk2uYm0aWtSPaS0S12xE8RxF7k_/view?usp=sharing) | 5.93 | 10.66 | 12.90 | 16.48 | 23.25 | 10.92 |
15
+
16
+ ### How to use?
17
+
18
+ #### Environment Setup
19
+
20
+ 1. `pip install --require-hashes -r requirements.txt`
21
+ 2. Install fairseq code
22
+ - For UniSpeech-SAT large, install the [UniSpeech-SAT](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech-SAT) fairseq code.
23
+
24
+ #### Example
25
+
26
+ 1. First, download the pre-trained model from the table above to `checkpoint_path`.
27
+ 2. Then, run the following command:
28
+ - The wav file is multi-talker speech simulated from the LibriSpeech corpus.
29
+ 3. The output is written to `out.rttm` by default (an example line is shown after the command).
30
+
31
+ ```bash
32
+ python diarization.py --wav_path tmp/mix_0000496.wav --model_init $checkpoint_path
33
+ ```
34
+
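+ For reference, `diarization.py` writes one `SPEAKER` record per detected segment in the standard RTTM layout (session name, channel, onset and duration in seconds, then the speaker label). With the default `--session Anonymous`, an output line looks roughly like the following (the times here are only illustrative):
+
+ ```
+ SPEAKER Anonymous 1    4.56    2.24 <NA> <NA> Anonymous_0 <NA>
+ ```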
UniSpeech/downstreams/speaker_diarization/config/infer_est_nspk1.yaml ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT).
2
+ # All rights reserved
3
+
4
+ # inference options
5
+ est_nspk: 1
6
+ sil_spk_th: 0.05
7
+ ahc_dis_th: 1.0
8
+ clink_dis: 1.0e+4
9
+ model:
10
+ n_speakers: 3
11
+ all_n_speakers: 0
12
+ feat_dim: 1024
13
+ n_units: 256
14
+ n_heads: 8
15
+ n_layers: 6
16
+ dropout_rate: 0.1
17
+ spk_emb_dim: 256
18
+ sr: 8000
19
+ frame_shift: 320
20
+ frame_size: 200
21
+ context_size: 0
22
+ subsampling: 1
23
+ feat_type: "config/unispeech_sat.th"
24
+ feature_selection: "hidden_states"
25
+ interpolate_mode: "linear"
26
+ dataset:
27
+ chunk_size: 750
28
+ frame_shift: 320
29
+ sampling_rate: 8000
30
+ subsampling: 1
31
+ num_speakers: 3
UniSpeech/downstreams/speaker_diarization/config/unispeech_sat.th ADDED
Binary file (20.7 kB).
UniSpeech/downstreams/speaker_diarization/diarization.py ADDED
@@ -0,0 +1,321 @@
1
+ import sys
2
+ import h5py
3
+ import soundfile as sf
4
+ import fire
5
+ import math
6
+ import yamlargparse
7
+ import numpy as np
8
+ from torch.utils.data import DataLoader
9
+ import torch
10
+ from utils.utils import parse_config_or_kwargs
11
+ from utils.dataset import DiarizationDataset
12
+ from models.models import TransformerDiarization
13
+ from scipy.signal import medfilt
14
+ from sklearn.cluster import AgglomerativeClustering
15
+ from scipy.spatial import distance
16
+ from utils.kaldi_data import KaldiData
17
+
18
+
19
+ def get_cl_sil(args, acti, cls_num):
20
+ n_chunks = len(acti)
21
+ mean_acti = np.array([np.mean(acti[i], axis=0)
22
+ for i in range(n_chunks)]).flatten()
23
+ n = args.num_speakers
24
+ sil_spk_th = args.sil_spk_th
25
+
26
+ cl_lst = []
27
+ sil_lst = []
28
+ for chunk_idx in range(n_chunks):
29
+ if cls_num is not None:
30
+ if args.num_speakers > cls_num:
31
+ mean_acti_bi = np.array([mean_acti[n * chunk_idx + s_loc_idx]
32
+ for s_loc_idx in range(n)])
33
+ min_idx = np.argmin(mean_acti_bi)
34
+ mean_acti[n * chunk_idx + min_idx] = 0.0
35
+
36
+ for s_loc_idx in range(n):
37
+ a = n * chunk_idx + (s_loc_idx + 0) % n
38
+ b = n * chunk_idx + (s_loc_idx + 1) % n
39
+ if mean_acti[a] > sil_spk_th and mean_acti[b] > sil_spk_th:
40
+ cl_lst.append((a, b))
41
+ else:
42
+ if mean_acti[a] <= sil_spk_th:
43
+ sil_lst.append(a)
44
+
45
+ return cl_lst, sil_lst
46
+
47
+
48
+ def clustering(args, svec, cls_num, ahc_dis_th, cl_lst, sil_lst):
49
+ org_svec_len = len(svec)
50
+ svec = np.delete(svec, sil_lst, 0)
51
+
52
+ # update cl_lst idx
53
+ _tbl = [i - sum(sil < i for sil in sil_lst) for i in range(org_svec_len)]
54
+ cl_lst = [(_tbl[_cl[0]], _tbl[_cl[1]]) for _cl in cl_lst]
55
+
56
+ distMat = distance.cdist(svec, svec, metric='euclidean')
57
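+ # Cannot-link constraint: speaker pairs that are active within the same chunk are forced apart by assigning them a very large pairwise distance (clink_dis) before agglomerative clustering.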
+ for cl in cl_lst:
58
+ distMat[cl[0], cl[1]] = args.clink_dis
59
+ distMat[cl[1], cl[0]] = args.clink_dis
60
+
61
+ clusterer = AgglomerativeClustering(
62
+ n_clusters=cls_num,
63
+ affinity='precomputed',
64
+ linkage='average',
65
+ distance_threshold=ahc_dis_th)
66
+ clusterer.fit(distMat)
67
+
68
+ if cls_num is not None:
69
+ print("oracle n_clusters is known")
70
+ else:
71
+ print("oracle n_clusters is unknown")
72
+ print("estimated n_clusters by constrained AHC: {}"
73
+ .format(len(np.unique(clusterer.labels_))))
74
+ cls_num = len(np.unique(clusterer.labels_))
75
+
76
+ sil_lab = cls_num
77
+ insert_sil_lab = [sil_lab for i in range(len(sil_lst))]
78
+ insert_sil_lab_idx = [sil_lst[i] - i for i in range(len(sil_lst))]
79
+ print("insert_sil_lab : {}".format(insert_sil_lab))
80
+ print("insert_sil_lab_idx : {}".format(insert_sil_lab_idx))
81
+ clslab = np.insert(clusterer.labels_,
82
+ insert_sil_lab_idx,
83
+ insert_sil_lab).reshape(-1, args.num_speakers)
84
+ print("clslab : {}".format(clslab))
85
+
86
+ return clslab, cls_num
87
+
88
+
89
+ def merge_act_max(act, i, j):
90
+ for k in range(len(act)):
91
+ act[k, i] = max(act[k, i], act[k, j])
92
+ act[k, j] = 0.0
93
+ return act
94
+
95
+
96
+ def merge_acti_clslab(args, acti, clslab, cls_num):
97
+ sil_lab = cls_num
98
+ for i in range(len(clslab)):
99
+ _lab = clslab[i].reshape(-1, 1)
100
+ distM = distance.cdist(_lab, _lab, metric='euclidean').astype(np.int64)
101
+ for j in range(len(distM)):
102
+ distM[j][:j] = -1
103
+ idx_lst = np.where(np.count_nonzero(distM == 0, axis=1) > 1)
104
+ merge_done = []
105
+ for j in idx_lst[0]:
106
+ for k in (np.where(distM[j] == 0))[0]:
107
+ if j != k and clslab[i, j] != sil_lab and k not in merge_done:
108
+ print("merge : (i, j, k) == ({}, {}, {})".format(i, j, k))
109
+ acti[i] = merge_act_max(acti[i], j, k)
110
+ clslab[i, k] = sil_lab
111
+ merge_done.append(j)
112
+
113
+ return acti, clslab
114
+
115
+
116
+ def stitching(args, acti, clslab, cls_num):
117
+ n_chunks = len(acti)
118
+ s_loc = args.num_speakers
119
+ sil_lab = cls_num
120
+ s_tot = max(cls_num, s_loc-1)
121
+
122
+ # Extend the max value of s_loc_idx to s_tot+1
123
+ add_acti = []
124
+ for chunk_idx in range(n_chunks):
125
+ zeros = np.zeros((len(acti[chunk_idx]), s_tot+1))
126
+ if s_tot+1 > s_loc:
127
+ zeros[:, :-(s_tot+1-s_loc)] = acti[chunk_idx]
128
+ else:
129
+ zeros = acti[chunk_idx]
130
+ add_acti.append(zeros)
131
+ acti = np.array(add_acti)
132
+
133
+ out_chunks = []
134
+ for chunk_idx in range(n_chunks):
135
+ # Make sloci2lab_dct.
136
+ # key: s_loc_idx
137
+ # value: estimated label by clustering or sil_lab
138
+ cls_set = set()
139
+ for s_loc_idx in range(s_tot+1):
140
+ cls_set.add(s_loc_idx)
141
+
142
+ sloci2lab_dct = {}
143
+ for s_loc_idx in range(s_tot+1):
144
+ if s_loc_idx < s_loc:
145
+ sloci2lab_dct[s_loc_idx] = clslab[chunk_idx][s_loc_idx]
146
+ if clslab[chunk_idx][s_loc_idx] in cls_set:
147
+ cls_set.remove(clslab[chunk_idx][s_loc_idx])
148
+ else:
149
+ if clslab[chunk_idx][s_loc_idx] != sil_lab:
150
+ raise ValueError
151
+ else:
152
+ sloci2lab_dct[s_loc_idx] = list(cls_set)[s_loc_idx-s_loc]
153
+
154
+ # Sort by label value
155
+ sloci2lab_lst = sorted(sloci2lab_dct.items(), key=lambda x: x[1])
156
+
157
+ # Select sil_lab_idx
158
+ sil_lab_idx = None
159
+ for idx_lab in sloci2lab_lst:
160
+ if idx_lab[1] == sil_lab:
161
+ sil_lab_idx = idx_lab[0]
162
+ break
163
+ if sil_lab_idx is None:
164
+ raise ValueError
165
+
166
+ # Get swap_idx
167
+ # [idx of label(0), idx of label(1), ..., idx of label(s_tot)]
168
+ swap_idx = [sil_lab_idx for j in range(s_tot+1)]
169
+ for lab in range(s_tot+1):
170
+ for idx_lab in sloci2lab_lst:
171
+ if lab == idx_lab[1]:
172
+ swap_idx[lab] = idx_lab[0]
173
+
174
+ print("swap_idx {}".format(swap_idx))
175
+ swap_acti = acti[chunk_idx][:, swap_idx]
176
+ swap_acti = np.delete(swap_acti, sil_lab, 1)
177
+ out_chunks.append(swap_acti)
178
+
179
+ return out_chunks
180
+
181
+
182
+ def prediction(num_speakers, net, wav_list, chunk_len_list):
183
+ acti_lst = []
184
+ svec_lst = []
185
+ len_list = []
186
+
187
+ with torch.no_grad():
188
+ for wav, chunk_len in zip(wav_list, chunk_len_list):
189
+ wav = wav.to('cuda')
190
+ outputs = net.batch_estimate(torch.unsqueeze(wav, 0))
191
+ ys = outputs[0]
192
+
193
+ for i in range(num_speakers):
194
+ spkivecs = outputs[i+1]
195
+ svec_lst.append(spkivecs[0].cpu().detach().numpy())
196
+
197
+ acti = ys[0][-chunk_len:].cpu().detach().numpy()
198
+ acti_lst.append(acti)
199
+ len_list.append(chunk_len)
200
+
201
+ acti_arr = np.concatenate(acti_lst, axis=0) # total_len x num_speakers
202
+ svec_arr = np.stack(svec_lst) # (chunk_num x num_speakers) x emb_dim
203
+ len_arr = np.array(len_list) # chunk_num
204
+
205
+ return acti_arr, svec_arr, len_arr
206
+
207
+ def cluster(args, conf, acti_arr, svec_arr, len_arr):
208
+
209
+ acti_list = []
210
+ n_chunks = len_arr.shape[0]
211
+ start = 0
212
+ for i in range(n_chunks):
213
+ chunk_len = len_arr[i]
214
+ acti_list.append(acti_arr[start: start+chunk_len])
215
+ start += chunk_len
216
+ acti = np.array(acti_list)
217
+ svec = svec_arr
218
+
219
+ # initialize clustering setting
220
+ cls_num = None
221
+ ahc_dis_th = args.ahc_dis_th
222
+ # Get cannot-link index list and silence index list
223
+ cl_lst, sil_lst = get_cl_sil(args, acti, cls_num)
224
+
225
+ n_samples = n_chunks * args.num_speakers - len(sil_lst)
226
+ min_n_samples = 2
227
+ if cls_num is not None:
228
+ min_n_samples = cls_num
229
+
230
+ if n_samples >= min_n_samples:
231
+ # clustering (if cls_num is None, update cls_num)
232
+ clslab, cls_num =\
233
+ clustering(args, svec, cls_num, ahc_dis_th, cl_lst, sil_lst)
234
+ # merge
235
+ acti, clslab = merge_acti_clslab(args, acti, clslab, cls_num)
236
+ # stitching
237
+ out_chunks = stitching(args, acti, clslab, cls_num)
238
+ else:
239
+ out_chunks = acti
240
+
241
+ outdata = np.vstack(out_chunks)
242
+ # Return the result; make_rttm writes it to the RTTM file
243
+ return outdata
244
+
245
+ def make_rttm(args, conf, cluster_data):
246
+ args.frame_shift = conf['model']['frame_shift']
247
+ args.subsampling = conf['model']['subsampling']
248
+ args.sampling_rate = conf['dataset']['sampling_rate']
249
+
250
+ with open(args.out_rttm_file, 'w') as wf:
251
+ a = np.where(cluster_data > args.threshold, 1, 0)
252
+ if args.median > 1:
253
+ a = medfilt(a, (args.median, 1))
254
+ for spkid, frames in enumerate(a.T):
255
+ frames = np.pad(frames, (1, 1), 'constant')
256
+ changes, = np.where(np.diff(frames, axis=0) != 0)
257
+ fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
258
+ for s, e in zip(changes[::2], changes[1::2]):
259
+ print(fmt.format(
260
+ args.session,
261
+ s * args.frame_shift * args.subsampling / args.sampling_rate,
262
+ (e - s) * args.frame_shift * args.subsampling / args.sampling_rate,
263
+ args.session + "_" + str(spkid)), file=wf)
264
+
265
+ def main(args):
266
+ conf = parse_config_or_kwargs(args.config_path)
267
+ num_speakers = conf['dataset']['num_speakers']
268
+ args.num_speakers = num_speakers
269
+
270
+ # Prepare model
271
+ model_parameter_dict = torch.load(args.model_init)['model']
272
+ model_all_n_speakers = model_parameter_dict["embed.weight"].shape[0]
273
+ conf['model']['all_n_speakers'] = model_all_n_speakers
274
+ net = TransformerDiarization(**conf['model'])
275
+ net.load_state_dict(model_parameter_dict, strict=False)
276
+ net.eval()
277
+ net = net.to("cuda")
278
+
279
+ audio, sr = sf.read(args.wav_path, dtype="float32")
280
+ audio_len = audio.shape[0]
281
+ chunk_size, frame_shift, subsampling = conf['dataset']['chunk_size'], conf['model']['frame_shift'], conf['model']['subsampling']
282
+ scale_ratio = int(frame_shift * subsampling)
283
+ chunk_audio_size = chunk_size * scale_ratio
284
+ wav_list, chunk_len_list = [], []
285
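+ # Split the waveform into fixed-size chunks; the last chunk is right-aligned (it may overlap the previous one) so the model always sees chunk_audio_size samples, while chunk_len_list records how many output frames of each chunk are valid.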
+ for i in range(0, math.ceil(1.0 * audio_len / chunk_audio_size)):
286
+ start, end = i*chunk_audio_size, (i+1)*chunk_audio_size
287
+ if end > audio_len:
288
+ chunk_len_list.append(int((audio_len-start) / scale_ratio))
289
+ end = audio_len
290
+ start = max(0, audio_len - chunk_audio_size)
291
+ else:
292
+ chunk_len_list.append(chunk_size)
293
+ wav_list.append(audio[start:end])
294
+ wav_list = [torch.from_numpy(wav).float() for wav in wav_list]
295
+
296
+ acti_arr, svec_arr, len_arr = prediction(num_speakers, net, wav_list, chunk_len_list)
297
+ cluster_data = cluster(args, conf, acti_arr, svec_arr, len_arr)
298
+ make_rttm(args, conf, cluster_data)
299
+
300
+ if __name__ == '__main__':
301
+ parser = yamlargparse.ArgumentParser(description='decoding')
302
+ parser.add_argument('--wav_path',
303
+ help='the input wav path',
304
+ default="tmp/mix_0000496.wav")
305
+ parser.add_argument('--config_path',
306
+ help='config file path',
307
+ default="config/infer_est_nspk1.yaml")
308
+ parser.add_argument('--model_init',
309
+ help='model initialize path',
310
+ default="")
311
+ parser.add_argument('--sil_spk_th', default=0.05, type=float)
312
+ parser.add_argument('--ahc_dis_th', default=1.0, type=float)
313
+ parser.add_argument('--clink_dis', default=1.0e+4, type=float)
314
+ parser.add_argument('--session', default='Anonymous', help='the name of the output speaker')
315
+ parser.add_argument('--out_rttm_file', default='out.rttm', help='the output rttm file')
316
+ parser.add_argument('--threshold', default=0.4, type=float)
317
+ parser.add_argument('--median', default=25, type=int)
318
+
319
+
320
+ args = parser.parse_args()
321
+ main(args)
UniSpeech/downstreams/speaker_diarization/models/models.py ADDED
@@ -0,0 +1,391 @@
1
+ # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT).
2
+ # All rights reserved
3
+
4
+ import sys
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.nn as nn
9
+ import torchaudio.transforms as trans
10
+ from collections import OrderedDict
11
+ from itertools import permutations
12
+ from models.transformer import TransformerEncoder
13
+ from .utils import UpstreamExpert
14
+
15
+
16
+ class GradMultiply(torch.autograd.Function):
17
+ @staticmethod
18
+ def forward(ctx, x, scale):
19
+ ctx.scale = scale
20
+ res = x.new(x)
21
+ return res
22
+
23
+ @staticmethod
24
+ def backward(ctx, grad):
25
+ return grad * ctx.scale, None
26
+
27
+
28
+
29
+ """
30
+ P: number of permutations
31
+ T: number of frames
32
+ C: number of speakers (classes)
33
+ B: mini-batch size
34
+ """
35
+
36
+
37
+ def batch_pit_loss_parallel(outputs, labels, ilens=None):
38
+ """ calculate the batch PIT loss in parallel
39
+ Args:
40
+ outputs (torch.Tensor): B x T x C
41
+ labels (torch.Tensor): B x T x C
42
+ ilens (torch.Tensor): B
43
+ Returns:
44
+ perm (torch.Tensor): permutation for outputs (Batch, num_spk)
45
+ loss
46
+ """
47
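+ # The masked BCE loss is evaluated for every permutation of the speaker outputs and, for each utterance, the permutation with the smallest loss is kept (permutation-invariant training).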
+
48
+ if ilens is None:
49
+ mask, scale = 1.0, outputs.shape[1]
50
+ else:
51
+ scale = torch.unsqueeze(torch.LongTensor(ilens), 1).to(outputs.device)
52
+ mask = outputs.new_zeros(outputs.size()[:-1])
53
+ for i, chunk_len in enumerate(ilens):
54
+ mask[i, :chunk_len] += 1.0
55
+ mask /= scale
56
+
57
+ def loss_func(output, label):
58
+ # return torch.mean(F.binary_cross_entropy_with_logits(output, label, reduction='none'), dim=tuple(range(1, output.dim())))
59
+ return torch.sum(F.binary_cross_entropy_with_logits(output, label, reduction='none') * mask, dim=-1)
60
+
61
+ def pair_loss(outputs, labels, permutation):
62
+ return sum([loss_func(outputs[:,:,s], labels[:,:,t]) for s, t in enumerate(permutation)]) / len(permutation)
63
+
64
+ device = outputs.device
65
+ num_spk = outputs.shape[-1]
66
+ all_permutations = list(permutations(range(num_spk)))
67
+ losses = torch.stack([pair_loss(outputs, labels, p) for p in all_permutations], dim=1)
68
+ loss, perm = torch.min(losses, dim=1)
69
+ perm = torch.index_select(torch.tensor(all_permutations, device=device, dtype=torch.long), 0, perm)
70
+ return torch.mean(loss), perm
71
+
72
+
73
+ def fix_state_dict(state_dict):
74
+ new_state_dict = OrderedDict()
75
+ for k, v in state_dict.items():
76
+ if k.startswith('module.'):
77
+ # remove 'module.' of DataParallel
78
+ k = k[7:]
79
+ if k.startswith('net.'):
80
+ # remove 'net.' of PadertorchModel
81
+ k = k[4:]
82
+ new_state_dict[k] = v
83
+ return new_state_dict
84
+
85
+
86
+ class TransformerDiarization(nn.Module):
87
+ def __init__(self,
88
+ n_speakers,
89
+ all_n_speakers,
90
+ feat_dim,
91
+ n_units,
92
+ n_heads,
93
+ n_layers,
94
+ dropout_rate,
95
+ spk_emb_dim,
96
+ sr=8000,
97
+ frame_shift=256,
98
+ frame_size=1024,
99
+ context_size=0,
100
+ subsampling=1,
101
+ feat_type='fbank',
102
+ feature_selection='default',
103
+ interpolate_mode='linear',
104
+ update_extract=False,
105
+ feature_grad_mult=1.0
106
+ ):
107
+ super(TransformerDiarization, self).__init__()
108
+ self.context_size = context_size
109
+ self.subsampling = subsampling
110
+ self.feat_type = feat_type
111
+ self.feature_selection = feature_selection
112
+ self.sr = sr
113
+ self.frame_shift = frame_shift
114
+ self.interpolate_mode = interpolate_mode
115
+ self.update_extract = update_extract
116
+ self.feature_grad_mult = feature_grad_mult
117
+
118
+ if feat_type == 'fbank':
119
+ self.feature_extract = trans.MelSpectrogram(sample_rate=sr,
120
+ n_fft=frame_size,
121
+ win_length=frame_size,
122
+ hop_length=frame_shift,
123
+ f_min=0.0,
124
+ f_max=sr // 2,
125
+ pad=0,
126
+ n_mels=feat_dim)
127
+ else:
128
+ self.feature_extract = UpstreamExpert(feat_type)
129
+ # self.feature_extract = torch.hub.load('s3prl/s3prl', 'hubert_local', ckpt=feat_type)
130
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
131
+ self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
132
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
133
+ self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
134
+ self.feat_num = self.get_feat_num()
135
+ self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
136
+ # for param in self.feature_extract.parameters():
137
+ # param.requires_grad = False
138
+ self.resample = trans.Resample(orig_freq=sr, new_freq=16000)
139
+
140
+ if feat_type != 'fbank' and feat_type != 'mfcc':
141
+ freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer', 'spk_proj', 'layer_norm_for_extract']
142
+ for name, param in self.feature_extract.named_parameters():
143
+ for freeze_val in freeze_list:
144
+ if freeze_val in name:
145
+ param.requires_grad = False
146
+ break
147
+ if not self.update_extract:
148
+ for param in self.feature_extract.parameters():
149
+ param.requires_grad = False
150
+
151
+ self.instance_norm = nn.InstanceNorm1d(feat_dim)
152
+
153
+ feat_dim = feat_dim * (self.context_size*2 + 1)
154
+ self.enc = TransformerEncoder(
155
+ feat_dim, n_layers, n_units, h=n_heads, dropout_rate=dropout_rate)
156
+ self.linear = nn.Linear(n_units, n_speakers)
157
+
158
+ for i in range(n_speakers):
159
+ setattr(self, '{}{:d}'.format("linear", i), nn.Linear(n_units, spk_emb_dim))
160
+
161
+ self.n_speakers = n_speakers
162
+ self.embed = nn.Embedding(all_n_speakers, spk_emb_dim)
163
+ self.alpha = nn.Parameter(torch.rand(1)[0] + torch.Tensor([0.5])[0])
164
+ self.beta = nn.Parameter(torch.rand(1)[0] + torch.Tensor([0.5])[0])
165
+
166
+ def get_feat_num(self):
167
+ self.feature_extract.eval()
168
+ wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
169
+ with torch.no_grad():
170
+ features = self.feature_extract(wav)
171
+ select_feature = features[self.feature_selection]
172
+ if isinstance(select_feature, (list, tuple)):
173
+ return len(select_feature)
174
+ else:
175
+ return 1
176
+
177
+ def fix_except_embedding(self, requires_grad=False):
178
+ for name, param in self.named_parameters():
179
+ if 'embed' not in name:
180
+ param.requires_grad = requires_grad
181
+
182
+ def modfy_emb(self, weight):
183
+ self.embed = nn.Embedding.from_pretrained(weight)
184
+
185
+ def splice(self, data, context_size):
186
+ # data: B x feat_dim x time_len
187
+ data = torch.unsqueeze(data, -1)
188
+ kernel_size = context_size*2 + 1
189
+ splice_data = F.unfold(data, kernel_size=(kernel_size, 1), padding=(context_size, 0))
190
+ return splice_data
191
+
192
+ def get_feat(self, xs):
193
+ wav_len = xs.shape[-1]
194
+ chunk_size = int(wav_len / self.frame_shift)
195
+ chunk_size = int(chunk_size / self.subsampling)
196
+
197
+ self.feature_extract.eval()
198
+ if self.update_extract:
199
+ xs = self.resample(xs)
200
+ feature = self.feature_extract([sample for sample in xs])
201
+ else:
202
+ with torch.no_grad():
203
+ if self.feat_type == 'fbank':
204
+ feature = self.feature_extract(xs) + 1e-6 # B x feat_dim x time_len
205
+ feature = feature.log()
206
+ else:
207
+ xs = self.resample(xs)
208
+ feature = self.feature_extract([sample for sample in xs])
209
+
210
+ if self.feat_type != "fbank" and self.feat_type != "mfcc":
211
+ feature = feature[self.feature_selection]
212
+ if isinstance(feature, (list, tuple)):
213
+ feature = torch.stack(feature, dim=0)
214
+ else:
215
+ feature = feature.unsqueeze(0)
216
+
217
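+ # Learnable weighted sum over the upstream model's hidden states: one softmax-normalized weight per layer.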
+ norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
218
+ feature = (norm_weights * feature).sum(dim=0)
219
+ feature = torch.transpose(feature, 1, 2) + 1e-6
220
+
221
+ feature = self.instance_norm(feature)
222
+ feature = self.splice(feature, self.context_size)
223
+ feature = feature[:, :, ::self.subsampling]
224
+ feature = F.interpolate(feature, chunk_size, mode=self.interpolate_mode)
225
+ feature = torch.transpose(feature, 1, 2)
226
+
227
+ if self.feature_grad_mult != 1.0:
228
+ feature = GradMultiply.apply(feature, self.feature_grad_mult)
229
+
230
+ return feature
231
+
232
+ def forward(self, inputs):
233
+ if isinstance(inputs, list):
234
+ xs = inputs[0]
235
+ else:
236
+ xs = inputs
237
+ feature = self.get_feat(xs)
238
+
239
+ pad_shape = feature.shape
240
+ emb = self.enc(feature)
241
+ ys = self.linear(emb)
242
+ ys = ys.reshape(pad_shape[0], pad_shape[1], -1)
243
+
244
+ spksvecs = []
245
+ for i in range(self.n_speakers):
246
+ spkivecs = getattr(self, '{}{:d}'.format("linear", i))(emb)
247
+ spkivecs = spkivecs.reshape(pad_shape[0], pad_shape[1], -1)
248
+ spksvecs.append(spkivecs)
249
+
250
+ return ys, spksvecs
251
+
252
+ def get_loss(self, inputs, ys, spksvecs, cal_spk_loss=True):
253
+ ts = inputs[1]
254
+ ss = inputs[2]
255
+ ns = inputs[3]
256
+ ilens = inputs[4]
257
+ ilens = [ilen.item() for ilen in ilens]
258
+
259
+ pit_loss, sigmas = batch_pit_loss_parallel(ys, ts, ilens)
260
+ if cal_spk_loss:
261
+ spk_loss = self.spk_loss_parallel(spksvecs, ys, ts, ss, sigmas, ns, ilens)
262
+ else:
263
+ spk_loss = torch.tensor(0.0).to(pit_loss.device)
264
+
265
+ alpha = torch.clamp(self.alpha, min=sys.float_info.epsilon)
266
+
267
+ return {'spk_loss':spk_loss,
268
+ 'pit_loss': pit_loss}
269
+
270
+
271
+ def batch_estimate(self, xs):
272
+ out = self(xs)
273
+ ys = out[0]
274
+ spksvecs = out[1]
275
+ spksvecs = list(zip(*spksvecs))
276
+ outputs = [
277
+ self.estimate(spksvec, y)
278
+ for (spksvec, y) in zip(spksvecs, ys)]
279
+ outputs = list(zip(*outputs))
280
+
281
+ return outputs
282
+
283
+ def batch_estimate_with_perm(self, xs, ts, ilens=None):
284
+ out = self(xs)
285
+ ys = out[0]
286
+ if ts[0].shape[1] > ys[0].shape[1]:
287
+ # e.g. the case of training 3-spk model with 4-spk data
288
+ add_dim = ts[0].shape[1] - ys[0].shape[1]
289
+ y_device = ys[0].device
290
+ zeros = [torch.zeros(ts[0].shape).to(y_device)
291
+ for i in range(len(ts))]
292
+ _ys = []
293
+ for zero, y in zip(zeros, ys):
294
+ _zero = zero
295
+ _zero[:, :-add_dim] = y
296
+ _ys.append(_zero)
297
+ _, sigmas = batch_pit_loss_parallel(_ys, ts, ilens)
298
+ else:
299
+ _, sigmas = batch_pit_loss_parallel(ys, ts, ilens)
300
+ spksvecs = out[1]
301
+ spksvecs = list(zip(*spksvecs))
302
+ outputs = [self.estimate(spksvec, y)
303
+ for (spksvec, y) in zip(spksvecs, ys)]
304
+ outputs = list(zip(*outputs))
305
+ zs = outputs[0]
306
+
307
+ if ts[0].shape[1] > ys[0].shape[1]:
308
+ # e.g. the case of training 3-spk model with 4-spk data
309
+ add_dim = ts[0].shape[1] - ys[0].shape[1]
310
+ z_device = zs[0].device
311
+ zeros = [torch.zeros(ts[0].shape).to(z_device)
312
+ for i in range(len(ts))]
313
+ _zs = []
314
+ for zero, z in zip(zeros, zs):
315
+ _zero = zero
316
+ _zero[:, :-add_dim] = z
317
+ _zs.append(_zero)
318
+ zs = _zs
319
+ outputs[0] = zs
320
+ outputs.append(sigmas)
321
+
322
+ # outputs: [zs, nmz_wavg_spk0vecs, nmz_wavg_spk1vecs, ..., sigmas]
323
+ return outputs
324
+
325
+ def estimate(self, spksvec, y):
326
+ outputs = []
327
+ z = torch.sigmoid(y.transpose(1, 0))
328
+
329
+ outputs.append(z.transpose(1, 0))
330
+ for spkid, spkvec in enumerate(spksvec):
331
+ norm_spkvec_inv = 1.0 / torch.norm(spkvec, dim=1)
332
+ # Normalize speaker vectors before weighted average
333
+ spkvec = torch.mul(
334
+ spkvec.transpose(1, 0), norm_spkvec_inv
335
+ ).transpose(1, 0)
336
+ wavg_spkvec = torch.mul(
337
+ spkvec.transpose(1, 0), z[spkid]
338
+ ).transpose(1, 0)
339
+ sum_wavg_spkvec = torch.sum(wavg_spkvec, dim=0)
340
+ nmz_wavg_spkvec = sum_wavg_spkvec / torch.norm(sum_wavg_spkvec)
341
+ outputs.append(nmz_wavg_spkvec)
342
+
343
+ # outputs: [z, nmz_wavg_spk0vec, nmz_wavg_spk1vec, ...]
344
+ return outputs
345
+
346
+ def spk_loss_parallel(self, spksvecs, ys, ts, ss, sigmas, ns, ilens):
347
+ '''
348
+ spksvecs (List[torch.Tensor, ...]): [B x T x emb_dim, ...]
349
+ ys (torch.Tensor): B x T x 3
350
+ ts (torch.Tensor): B x T x 3
351
+ ss (torch.Tensor): B x 3
352
+ sigmas (torch.Tensor): B x 3
353
+ ns (torch.Tensor): B x total_spk_num x 1
354
+ ilens (List): B
355
+ '''
356
+ chunk_spk_num = len(spksvecs) # 3
357
+
358
+ len_mask = ys.new_zeros((ys.size()[:-1])) # B x T
359
+ for i, len_val in enumerate(ilens):
360
+ len_mask[i,:len_val] += 1.0
361
+ ts = ts * len_mask.unsqueeze(-1)
362
+ len_mask = len_mask.repeat((chunk_spk_num, 1)) # B*3 x T
363
+
364
+ spk_vecs = torch.cat(spksvecs, dim=0) # B*3 x T x emb_dim
365
+ # Normalize speaker vectors before weighted average
366
+ spk_vecs = F.normalize(spk_vecs, dim=-1)
367
+
368
+ ys = torch.permute(torch.sigmoid(ys), dims=(2, 0, 1)) # 3 x B x T
369
+ ys = ys.reshape(-1, ys.shape[-1]).unsqueeze(-1) # B*3 x T x 1
370
+
371
+ weight_spk_vec = ys * spk_vecs # B*3 x T x emb_dim
372
+ weight_spk_vec *= len_mask.unsqueeze(-1)
373
+ sum_spk_vec = torch.sum(weight_spk_vec, dim=1) # B*3 x emb_dim
374
+ norm_spk_vec = F.normalize(sum_spk_vec, dim=1)
375
+
376
+ embeds = F.normalize(self.embed(ns[0]).squeeze(), dim=1) # total_spk_num x emb_dim
377
+ dist = torch.cdist(norm_spk_vec, embeds) # B*3 x total_spk_num
378
+ logits = -1.0 * torch.add(torch.clamp(self.alpha, min=sys.float_info.epsilon) * torch.pow(dist, 2), self.beta)
379
+ label = torch.gather(ss, 1, sigmas).transpose(0, 1).reshape(-1, 1).squeeze() # B*3
380
+ label[label==-1] = 0
381
+ valid_spk_mask = torch.gather(torch.sum(ts, dim=1), 1, sigmas).transpose(0, 1) # 3 x B
382
+ valid_spk_mask = (torch.flatten(valid_spk_mask) > 0).float() # B*3
383
+
384
+ valid_spk_loss_num = torch.sum(valid_spk_mask).item()
385
+ if valid_spk_loss_num > 0:
386
+ loss = F.cross_entropy(logits, label, reduction='none') * valid_spk_mask / valid_spk_loss_num
387
+ # uncommenting the line below makes the loss identical to batch_spk_loss
388
+ # loss = F.cross_entropy(logits, label, reduction='none') * valid_spk_mask / valid_spk_mask.shape[0]
389
+ return torch.sum(loss)
390
+ else:
391
+ return torch.tensor(0.0).to(ys.device)
UniSpeech/downstreams/speaker_diarization/models/transformer.py ADDED
@@ -0,0 +1,147 @@
1
+ # Copyright (c) 2021 Nippon Telegraph and Telephone corporation (NTT).
2
+ # All rights reserved
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torch.nn as nn
8
+ from torch.optim.lr_scheduler import _LRScheduler
9
+
10
+
11
+ class NoamScheduler(_LRScheduler):
12
+ """ learning rate scheduler used in the transformer
13
+ See https://arxiv.org/pdf/1706.03762.pdf
14
+ lrate = d_model**(-0.5) * \
15
+ min(step_num**(-0.5), step_num*warmup_steps**(-1.5))
16
+ Scaling factor is implemented as in
17
+ http://nlp.seas.harvard.edu/2018/04/03/attention.html#optimizer
18
+ """
19
+
20
+ def __init__(
21
+ self, optimizer, d_model, warmup_steps, tot_step, scale,
22
+ last_epoch=-1
23
+ ):
24
+ self.d_model = d_model
25
+ self.warmup_steps = warmup_steps
26
+ self.tot_step = tot_step
27
+ self.scale = scale
28
+ super(NoamScheduler, self).__init__(optimizer, last_epoch)
29
+
30
+ def get_lr(self):
31
+ self.last_epoch = max(1, self.last_epoch)
32
+ step_num = self.last_epoch
33
+ val = self.scale * self.d_model ** (-0.5) * \
34
+ min(step_num ** (-0.5), step_num * self.warmup_steps ** (-1.5))
35
+
36
+ return [base_lr / base_lr * val for base_lr in self.base_lrs]
37
+
38
+
39
+ class MultiHeadSelfAttention(nn.Module):
40
+ """ Multi head "self" attention layer
41
+ """
42
+
43
+ def __init__(self, n_units, h=8, dropout_rate=0.1):
44
+ super(MultiHeadSelfAttention, self).__init__()
45
+ self.linearQ = nn.Linear(n_units, n_units)
46
+ self.linearK = nn.Linear(n_units, n_units)
47
+ self.linearV = nn.Linear(n_units, n_units)
48
+ self.linearO = nn.Linear(n_units, n_units)
49
+ self.d_k = n_units // h
50
+ self.h = h
51
+ self.dropout = nn.Dropout(p=dropout_rate)
52
+ # attention for plot
53
+ self.att = None
54
+
55
+ def forward(self, x, batch_size):
56
+ # x: (BT, F)
57
+ q = self.linearQ(x).reshape(batch_size, -1, self.h, self.d_k)
58
+ k = self.linearK(x).reshape(batch_size, -1, self.h, self.d_k)
59
+ v = self.linearV(x).reshape(batch_size, -1, self.h, self.d_k)
60
+
61
+ scores = torch.matmul(
62
+ q.transpose(1, 2), k.permute(0, 2, 3, 1)) / np.sqrt(self.d_k)
63
+ # scores: (B, h, T, T) = (B, h, T, d_k) x (B, h, d_k, T)
64
+ self.att = F.softmax(scores, dim=3)
65
+ p_att = self.dropout(self.att)
66
+ x = torch.matmul(p_att, v.transpose(1, 2))
67
+ x = x.transpose(1, 2).reshape(-1, self.h * self.d_k)
68
+
69
+ return self.linearO(x)
70
+
71
+
72
+ class PositionwiseFeedForward(nn.Module):
73
+ """ Positionwise feed-forward layer
74
+ """
75
+
76
+ def __init__(self, n_units, d_units, dropout_rate):
77
+ super(PositionwiseFeedForward, self).__init__()
78
+ self.linear1 = nn.Linear(n_units, d_units)
79
+ self.linear2 = nn.Linear(d_units, n_units)
80
+ self.dropout = nn.Dropout(p=dropout_rate)
81
+
82
+ def forward(self, x):
83
+ return self.linear2(self.dropout(F.relu(self.linear1(x))))
84
+
85
+
86
+ class PositionalEncoding(nn.Module):
87
+ """ Positional encoding function
88
+ """
89
+
90
+ def __init__(self, n_units, dropout_rate, max_len):
91
+ super(PositionalEncoding, self).__init__()
92
+ self.dropout = nn.Dropout(p=dropout_rate)
93
+ positions = np.arange(0, max_len, dtype='f')[:, None]
94
+ dens = np.exp(
95
+ np.arange(0, n_units, 2, dtype='f') * -(np.log(10000.) / n_units))
96
+ self.enc = np.zeros((max_len, n_units), dtype='f')
97
+ self.enc[:, ::2] = np.sin(positions * dens)
98
+ self.enc[:, 1::2] = np.cos(positions * dens)
99
+ self.scale = np.sqrt(n_units)
100
+
101
+ def forward(self, x):
102
+ x = x * self.scale + torch.from_numpy(self.enc[None, :x.shape[1]]).to(x.device)
103
+ return self.dropout(x)
104
+
105
+
106
+ class TransformerEncoder(nn.Module):
107
+ def __init__(self, idim, n_layers, n_units,
108
+ e_units=2048, h=8, dropout_rate=0.1):
109
+ super(TransformerEncoder, self).__init__()
110
+ self.linear_in = nn.Linear(idim, n_units)
111
+ # self.lnorm_in = nn.LayerNorm(n_units)
112
+ self.pos_enc = PositionalEncoding(n_units, dropout_rate, 5000)
113
+ self.n_layers = n_layers
114
+ self.dropout = nn.Dropout(p=dropout_rate)
115
+ for i in range(n_layers):
116
+ setattr(self, '{}{:d}'.format("lnorm1_", i),
117
+ nn.LayerNorm(n_units))
118
+ setattr(self, '{}{:d}'.format("self_att_", i),
119
+ MultiHeadSelfAttention(n_units, h, dropout_rate))
120
+ setattr(self, '{}{:d}'.format("lnorm2_", i),
121
+ nn.LayerNorm(n_units))
122
+ setattr(self, '{}{:d}'.format("ff_", i),
123
+ PositionwiseFeedForward(n_units, e_units, dropout_rate))
124
+ self.lnorm_out = nn.LayerNorm(n_units)
125
+
126
+ def forward(self, x):
127
+ # x: (B, T, F) ... batch, time, (mel)freq
128
+ BT_size = x.shape[0] * x.shape[1]
129
+ # e: (BT, F)
130
+ e = self.linear_in(x.reshape(BT_size, -1))
131
+ # Encoder stack
132
+ for i in range(self.n_layers):
133
+ # layer normalization
134
+ e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e)
135
+ # self-attention
136
+ s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0])
137
+ # residual
138
+ e = e + self.dropout(s)
139
+ # layer normalization
140
+ e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e)
141
+ # positionwise feed-forward
142
+ s = getattr(self, '{}{:d}'.format("ff_", i))(e)
143
+ # residual
144
+ e = e + self.dropout(s)
145
+ # final layer normalization
146
+ # output: (BT, F)
147
+ return self.lnorm_out(e)
UniSpeech/downstreams/speaker_diarization/models/utils.py ADDED
1
+ import torch
2
+ import fairseq
3
+ from packaging import version
4
+ import torch.nn.functional as F
5
+ from fairseq import tasks
6
+ from fairseq.checkpoint_utils import load_checkpoint_to_cpu
7
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
8
+ from omegaconf import OmegaConf
9
+ from s3prl.upstream.interfaces import UpstreamBase
10
+ from torch.nn.utils.rnn import pad_sequence
11
+
12
+ def load_model(filepath):
13
+ state = torch.load(filepath, map_location=lambda storage, loc: storage)
14
+ # state = load_checkpoint_to_cpu(filepath)
15
+ state["cfg"] = OmegaConf.create(state["cfg"])
16
+
17
+ if "args" in state and state["args"] is not None:
18
+ cfg = convert_namespace_to_omegaconf(state["args"])
19
+ elif "cfg" in state and state["cfg"] is not None:
20
+ cfg = state["cfg"]
21
+ else:
22
+ raise RuntimeError(
23
+ f"Neither args nor cfg exist in state keys = {state.keys()}"
24
+ )
25
+
26
+ task = tasks.setup_task(cfg.task)
27
+ if "task_state" in state:
28
+ task.load_state_dict(state["task_state"])
29
+
30
+ model = task.build_model(cfg.model)
31
+
32
+ return model, cfg, task
33
+
34
+
35
+ ###################
36
+ # UPSTREAM EXPERT #
37
+ ###################
38
+ class UpstreamExpert(UpstreamBase):
39
+ def __init__(self, ckpt, **kwargs):
40
+ super().__init__(**kwargs)
41
+ assert version.parse(fairseq.__version__) > version.parse(
42
+ "0.10.2"
43
+ ), "Please install the fairseq master branch."
44
+
45
+ model, cfg, task = load_model(ckpt)
46
+ self.model = model
47
+ self.task = task
48
+
49
+ if len(self.hooks) == 0:
50
+ module_name = "self.model.encoder.layers"
51
+ for module_id in range(len(eval(module_name))):
52
+ self.add_hook(
53
+ f"{module_name}[{module_id}]",
54
+ lambda input, output: input[0].transpose(0, 1),
55
+ )
56
+ self.add_hook("self.model.encoder", lambda input, output: output[0])
57
+
58
+ def forward(self, wavs):
59
+ if self.task.cfg.normalize:
60
+ wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]
61
+
62
+ device = wavs[0].device
63
+ wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
64
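+ # Boolean padding mask, True at padded positions (frame index >= utterance length).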
+ wav_padding_mask = ~torch.lt(
65
+ torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
66
+ wav_lengths.unsqueeze(1),
67
+ )
68
+ padded_wav = pad_sequence(wavs, batch_first=True)
69
+
70
+ features, feat_padding_mask = self.model.extract_features(
71
+ padded_wav,
72
+ padding_mask=wav_padding_mask,
73
+ mask=None,
74
+ )
75
+ return {
76
+ "default": features,
77
+ }
78
+
UniSpeech/downstreams/speaker_diarization/requirements.txt ADDED
@@ -0,0 +1,136 @@
1
+ SoundFile==0.10.3.post1 \
2
+ --hash=sha256:2d17e0a6fc2af0d6c1d868bafa5ec80aae6e186a97fec8db07ad6af29842fbc7 \
3
+ --hash=sha256:4555438c2c4f02b39fea2ed40f6ddeda88a80cd1ee9dd129be4d5f5134698cc2 \
4
+ --hash=sha256:490cff42650733d1832728b937fe99fa1802896f5ef4d61bcf78cf7ebecb107b \
5
+ --hash=sha256:5e342ee293b896d31da67617fe65d0bdca217af193991b0cb6052353b1e0e506 \
6
+ --hash=sha256:b361d4ac1519a2e516cabafa6bf7e93492f999f35d7d25350cd87fdc3e5cb27e
7
+ fire==0.4.0 \
8
+ --hash=sha256:c5e2b8763699d1142393a46d0e3e790c5eb2f0706082df8f647878842c216a62
9
+ sentencepiece==0.1.96 \
10
+ --hash=sha256:1dac8c2ad02b5ebc1179c0a14cbc7d7c6f4fd73d4dd51820626402d0aefc974e \
11
+ --hash=sha256:26d20d713b3ba1b7a19205336afb1e93a4327c372b2f795e907b8dc2315ac92e \
12
+ --hash=sha256:335bf84d72112cc91f3c3b691d61802fc963503b7772fd8280d20368048b8f3e \
13
+ --hash=sha256:36e9ff61e7b67c5b7ee96733613622620b4802fc8cf188a4dbc1f355b03dde02 \
14
+ --hash=sha256:384148cead5cdab34a4d74fe1fb6a5a8abaafed25eaa4a7698b49dd9482e4c4e \
15
+ --hash=sha256:3c703e68ea192e45b65c5d5836f6980849d828a18da4189899d7150fad82dc9e \
16
+ --hash=sha256:3e61e0757e49c306fff78ea75d6b75773418fe22214b4a460959203be934e834 \
17
+ --hash=sha256:466e381f0a812da8fda97a9707498cef3210ea8385a3421bcbadcb5384063969 \
18
+ --hash=sha256:48c6d13b3bfff08060c138248e85df60f6fad11135ad7a8fc2ef6005aacca839 \
19
+ --hash=sha256:4997c7ccf2ae462320250314aa5709a88d8a09fa271d073458a07bebf33f8e7c \
20
+ --hash=sha256:5388882bb24d083f6cc8cffc5c435f3694a7772b018e06ea6fd84d1044009efb \
21
+ --hash=sha256:5513298d62fe63dd0862d08a6eb52a9aa3537006f597f2386184e3f95bb88889 \
22
+ --hash=sha256:78e18d9106c36dcca929e18fd2c412378deac661d47fa3ee25defc55eef8a215 \
23
+ --hash=sha256:8179785883b556cd517416cdbda6244745414b00ec83132cfe1d26000971f3ae \
24
+ --hash=sha256:81bb77ba3651114943b2f8f77829cf764137dff06e38f4bf7fa43efea12c7f84 \
25
+ --hash=sha256:89c038da7f827a6e2ca4c73aeb4e4b25b99d981ce47dd61b04d446c8200cba1e \
26
+ --hash=sha256:940a6999c7d3f55e9d7b194fd5e1f41a7dbed26d3519fb95333216292a39599e \
27
+ --hash=sha256:99ea2d9db19e63a2d17d5dc64f9ace83fb9308a735be05a1aaf98eb4b496fba7 \
28
+ --hash=sha256:9bdf097d5bd1d8ce42dfee51f6ff05f5578b96e48c6f6006aa4eff69edfa3639 \
29
+ --hash=sha256:a336575463d75d3aac1f7e32470b8998643ccd9a73786bd726f6b0470520b6b4 \
30
+ --hash=sha256:a697257a2cd7581732d7741a8d32a06927f0311c3d277dbc47fa1043350c9d17 \
31
+ --hash=sha256:a92e1932ee8fd500680ccbe1bf53eb33228f4c9d6524ed6f300bcc80ac359f27 \
32
+ --hash=sha256:aeb090ad462833df03af1debce4ae607a2766ef861f992003ad0c56d074ab805 \
33
+ --hash=sha256:b1c24c1d9405b2148184ff27c062493d5e3be5c144575f95b5a0d7c660a515af \
34
+ --hash=sha256:b77d27f59d515c43b61745b8173fbe7c7b3014b14b3702a75bf1793471e7def6 \
35
+ --hash=sha256:b8b1dd2712f8a7de5b4c8ec912e6c041d25750bf03e1ce325cdba43bae0944ae \
36
+ --hash=sha256:bedf0355117fb4e9b1fc9fc92b4d5ee743a7d468be9f6196e3b94447710ea589 \
37
+ --hash=sha256:cc969e6694fb27fba7cee2953f350804faf03913f25ae1ee713a7b8a1bc08018 \
38
+ --hash=sha256:d45e3f78e746aa161bc9f5a31c6a2839c512101113a4065f4d2e7a3ab8198d8c \
39
+ --hash=sha256:d501713a8396193883aa526f48dc609f5f031a5df1afbafa561cf9ab492ffc76 \
40
+ --hash=sha256:d954d25a8705f972e8bfc1dea5464d7e697dd6f4ade092f1a487387e6d6c829a \
41
+ --hash=sha256:dadccb2e49244b6e64b4527d13ec14d5e094a90b41cf9b963e457e64182f1941 \
42
+ --hash=sha256:e811984b0908c14c56de7d8226fdd494d87a7ccb75af8ac3a07423037aaafc35 \
43
+ --hash=sha256:e88354b61f59dfdeb41023f7be8ae31dc627c2dc2dacbc2de8b2d82a0997135c \
44
+ --hash=sha256:e8ec5bb6777e2060e1499750c50e1b69dca5a0f80f90f2c66656c5f3e5244593 \
45
+ --hash=sha256:e9e9fe8094ca57549d801e9a2017ac5c24108bbf485ea4f8994a72e8e96ee135 \
46
+ --hash=sha256:eba0471ab0bb2e07ed06d91ecf5185d402c83d194155a41d8e2aa547d187712e \
47
+ --hash=sha256:ef59ba19340dc1d002ce5713b911c0ef23c577b08f8ed57998ee3c8e62c5bf6e \
48
+ --hash=sha256:f8c90df663cd9759b2cf8dd29998b63140ac39e51ada2e739dc13bdac0b4f001 \
49
+ --hash=sha256:f8cb24d8d0b2f8b7463815a59183eb81ec1d7a06e3217bed456063f3303eddfb \
50
+ --hash=sha256:fd907a8f744e5337de7fc532dd800c4416b571ea47f8c3c66be10cd1bc67c925 \
51
+ --hash=sha256:ff7d752a7f82d87711ec1a95c2262cb74f98be5b457f0300d81a1aefe5be2a95
52
+ tqdm==4.62.0 \
53
+ --hash=sha256:3642d483b558eec80d3c831e23953582c34d7e4540db86d9e5ed9dad238dabc6 \
54
+ --hash=sha256:706dea48ee05ba16e936ee91cb3791cd2ea6da348a0e50b46863ff4363ff4340
55
+ PyYAML==5.4.1 \
56
+ --hash=sha256:08682f6b72c722394747bddaf0aa62277e02557c0fd1c42cb853016a38f8dedf \
57
+ --hash=sha256:0f5f5786c0e09baddcd8b4b45f20a7b5d61a7e7e99846e3c799b05c7c53fa696 \
58
+ --hash=sha256:129def1b7c1bf22faffd67b8f3724645203b79d8f4cc81f674654d9902cb4393 \
59
+ --hash=sha256:294db365efa064d00b8d1ef65d8ea2c3426ac366c0c4368d930bf1c5fb497f77 \
60
+ --hash=sha256:3b2b1824fe7112845700f815ff6a489360226a5609b96ec2190a45e62a9fc922 \
61
+ --hash=sha256:3bd0e463264cf257d1ffd2e40223b197271046d09dadf73a0fe82b9c1fc385a5 \
62
+ --hash=sha256:4465124ef1b18d9ace298060f4eccc64b0850899ac4ac53294547536533800c8 \
63
+ --hash=sha256:49d4cdd9065b9b6e206d0595fee27a96b5dd22618e7520c33204a4a3239d5b10 \
64
+ --hash=sha256:4e0583d24c881e14342eaf4ec5fbc97f934b999a6828693a99157fde912540cc \
65
+ --hash=sha256:5accb17103e43963b80e6f837831f38d314a0495500067cb25afab2e8d7a4018 \
66
+ --hash=sha256:607774cbba28732bfa802b54baa7484215f530991055bb562efbed5b2f20a45e \
67
+ --hash=sha256:6c78645d400265a062508ae399b60b8c167bf003db364ecb26dcab2bda048253 \
68
+ --hash=sha256:72a01f726a9c7851ca9bfad6fd09ca4e090a023c00945ea05ba1638c09dc3347 \
69
+ --hash=sha256:74c1485f7707cf707a7aef42ef6322b8f97921bd89be2ab6317fd782c2d53183 \
70
+ --hash=sha256:895f61ef02e8fed38159bb70f7e100e00f471eae2bc838cd0f4ebb21e28f8541 \
71
+ --hash=sha256:8c1be557ee92a20f184922c7b6424e8ab6691788e6d86137c5d93c1a6ec1b8fb \
72
+ --hash=sha256:bb4191dfc9306777bc594117aee052446b3fa88737cd13b7188d0e7aa8162185 \
73
+ --hash=sha256:bfb51918d4ff3d77c1c856a9699f8492c612cde32fd3bcd344af9be34999bfdc \
74
+ --hash=sha256:c20cfa2d49991c8b4147af39859b167664f2ad4561704ee74c1de03318e898db \
75
+ --hash=sha256:cb333c16912324fd5f769fff6bc5de372e9e7a202247b48870bc251ed40239aa \
76
+ --hash=sha256:d2d9808ea7b4af864f35ea216be506ecec180628aced0704e34aca0b040ffe46 \
77
+ --hash=sha256:d483ad4e639292c90170eb6f7783ad19490e7a8defb3e46f97dfe4bacae89122 \
78
+ --hash=sha256:dd5de0646207f053eb0d6c74ae45ba98c3395a571a2891858e87df7c9b9bd51b \
79
+ --hash=sha256:e1d4970ea66be07ae37a3c2e48b5ec63f7ba6804bdddfdbd3cfd954d25a82e63 \
80
+ --hash=sha256:e4fac90784481d221a8e4b1162afa7c47ed953be40d31ab4629ae917510051df \
81
+ --hash=sha256:fa5ae20527d8e831e8230cbffd9f8fe952815b2b7dae6ffec25318803a7528fc \
82
+ --hash=sha256:fd7f6999a8070df521b6384004ef42833b9bd62cfee11a09bda1079b4b704247 \
83
+ --hash=sha256:fdc842473cd33f45ff6bce46aea678a54e3d21f1b61a7750ce3c498eedfe25d6 \
84
+ --hash=sha256:fe69978f3f768926cfa37b867e3843918e012cf83f680806599ddce33c2c68b0
85
+ h5py==3.3.0 \
86
+ --hash=sha256:09e78cefdef0b7566ab66366c5c7d9984c7b23142245bd51b82b744ad1eebf65 \
87
+ --hash=sha256:13355234c004ff8bd819f7d3420188aa1936b17d7f8470d622974a373421b7a5 \
88
+ --hash=sha256:5e2f22e66a3fb1815405cfe5711670450c973b8552507c535a546a23a468af3d \
89
+ --hash=sha256:7ca7d23ebbdd59a4be9b4820de52fe67adc74e6a44d5084881305461765aac47 \
90
+ --hash=sha256:89d7e10409b62fed81c571e35798763cb8375442b98f8ebfc52ba41ac019e081 \
91
+ --hash=sha256:8e09b682e4059c8cd259ddcc34bee35d639b9170105efeeae6ad195e7c1cea7a \
92
+ --hash=sha256:baef1a2cdef287a83e7f95ce9e0f4d762a9852fe7117b471063442c78b973695 \
93
+ --hash=sha256:e0dac887d779929778b3cfd13309a939359cc9e74756fc09af7c527a82797186 \
94
+ --hash=sha256:e0ea3330bf136f8213e43db67448994046ce501585dddc7ea4e8ceef0ef1600c \
95
+ --hash=sha256:f3bba8ffddd1fd2bf06127c5ff7b73f022cc1c8b7164355ddc760dc3f8570136
96
+ yamlargparse==1.31.1 \
97
+ --hash=sha256:2c09fc8e20c147d074f765512b880757a6fea669d57a3dc672a5e1be6c68c667
98
+ sklearn==0.0 \
99
+ --hash=sha256:e23001573aa194b834122d2b9562459bf5ae494a2d59ca6b8aa22c85a44c0e31
100
+ matplotlib==3.4.2 \
101
+ --hash=sha256:0bea5ec5c28d49020e5d7923c2725b837e60bc8be99d3164af410eb4b4c827da \
102
+ --hash=sha256:1c1779f7ab7d8bdb7d4c605e6ffaa0614b3e80f1e3c8ccf7b9269a22dbc5986b \
103
+ --hash=sha256:21b31057bbc5e75b08e70a43cefc4c0b2c2f1b1a850f4a0f7af044eb4163086c \
104
+ --hash=sha256:32fa638cc10886885d1ca3d409d4473d6a22f7ceecd11322150961a70fab66dd \
105
+ --hash=sha256:3a5c18dbd2c7c366da26a4ad1462fe3e03a577b39e3b503bbcf482b9cdac093c \
106
+ --hash=sha256:5826f56055b9b1c80fef82e326097e34dc4af8c7249226b7dd63095a686177d1 \
107
+ --hash=sha256:6382bc6e2d7e481bcd977eb131c31dee96e0fb4f9177d15ec6fb976d3b9ace1a \
108
+ --hash=sha256:6475d0209024a77f869163ec3657c47fed35d9b6ed8bccba8aa0f0099fbbdaa8 \
109
+ --hash=sha256:6a6a44f27aabe720ec4fd485061e8a35784c2b9ffa6363ad546316dfc9cea04e \
110
+ --hash=sha256:7a58f3d8fe8fac3be522c79d921c9b86e090a59637cb88e3bc51298d7a2c862a \
111
+ --hash=sha256:7ad19f3fb6145b9eb41c08e7cbb9f8e10b91291396bee21e9ce761bb78df63ec \
112
+ --hash=sha256:85f191bb03cb1a7b04b5c2cca4792bef94df06ef473bc49e2818105671766fee \
113
+ --hash=sha256:956c8849b134b4a343598305a3ca1bdd3094f01f5efc8afccdebeffe6b315247 \
114
+ --hash=sha256:a9d8cb5329df13e0cdaa14b3b43f47b5e593ec637f13f14db75bb16e46178b05 \
115
+ --hash=sha256:b1d5a2cedf5de05567c441b3a8c2651fbde56df08b82640e7f06c8cd91e201f6 \
116
+ --hash=sha256:b26535b9de85326e6958cdef720ecd10bcf74a3f4371bf9a7e5b2e659c17e153 \
117
+ --hash=sha256:c541ee5a3287efe066bbe358320853cf4916bc14c00c38f8f3d8d75275a405a9 \
118
+ --hash=sha256:d8d994cefdff9aaba45166eb3de4f5211adb4accac85cbf97137e98f26ea0219 \
119
+ --hash=sha256:df815378a754a7edd4559f8c51fc7064f779a74013644a7f5ac7a0c31f875866
120
+ torchaudio==0.9.0 \
121
+ --hash=sha256:0a387e78eeaf6e0abd36df70e9d8a15d242b49c2507dbd9522568f5d4af5fb96 \
122
+ --hash=sha256:18763c05cb7d85a08b8ea960e40f6984e9513b02e76f4526d920493c701b0671 \
123
+ --hash=sha256:48e33bb96b7ff2dc10a778a695429dbd6dfc8c8baa0d7c9b63569cb002bb87cd \
124
+ --hash=sha256:62fd9393ddbe40aadaabef7595f5bff0057e39f7e519195a010731542815f5a4 \
125
+ --hash=sha256:76a5b8ea0e4ddafd5b8f24abdf1a6f7afe847d892570da13cf0fc9bceeac437f \
126
+ --hash=sha256:87520525da10b5f00d3e5e1180db6ee37b1fa305edb2260c7335e0859dbe634e \
127
+ --hash=sha256:9d3f5d6df7d91676e67a38a448253b74d77da723f8e24bd833ff7ed0f82fa4ef \
128
+ --hash=sha256:acf0d736a5c1ea6b94adf08b0a31670009b6e78dfe50a1b0bdabf2b0f7895dc0 \
129
+ --hash=sha256:ad221258fc5d1d446f2c1ce9a1bb54cc05ca2b208491d4eaa5af443f1c0f16a2 \
130
+ --hash=sha256:ba52ae64611773bec7fc664c29f9ea3e02c9e5c817693726b978ed1bdedd07f2 \
131
+ --hash=sha256:c6126556d529df73b676e023063388d551be3c0cb2d42a4ff5c4cfd44ef3e012 \
132
+ --hash=sha256:ef5f0b22646a94f95869001b40ab940468b1ae399d0ffd3bc73d5c43342a013a \
133
+ --hash=sha256:ef8dc4ab1ec807382a713e71e8493d1985930537c933273e3c0739f02183cedc \
134
+ --hash=sha256:efb16c593b2a5ada07b180580c7612617e84f4714ce86928ad54baefe71ef29d
135
+ s3prl==0.3.1 \
136
+ --hash=sha256:e497989b10d4e058b619cf3e7a547820fceb3fe18c14c566427eb7b8c770d62e
UniSpeech/downstreams/speaker_diarization/tmp/mix_0000496.wav ADDED
@@ -0,0 +1 @@
1
+ /mnt/lustre/sjtu/home/czy97/workspace/sd/EEND-vec-clustering/EEND-vector-clustering/egs/mini_librispeech/v1/data/simu/wav/dev_clean_2_ns3_beta2_500/100/mix_0000496.wav
UniSpeech/downstreams/speaker_diarization/utils/dataset.py ADDED
@@ -0,0 +1,484 @@
1
+ # -*- coding: utf-8 -*- #
2
+ """*********************************************************************************************"""
3
+ # FileName [ dataset.py ]
4
+ # Synopsis [ the speaker diarization dataset ]
5
+ # Source [ Refactored from https://github.com/hitachi-speech/EEND ]
6
+ # Author [ Jiatong Shi ]
7
+ # Copyright [ Copyright(c), Johns Hopkins University ]
8
+ """*********************************************************************************************"""
9
+
10
+
11
+ ###############
12
+ # IMPORTATION #
13
+ ###############
14
+ import io
15
+ import os
16
+ import subprocess
17
+ import sys
18
+
19
+ # -------------#
20
+ import numpy as np
21
+ import soundfile as sf
22
+ import torch
23
+ from torch.nn.utils.rnn import pad_sequence
24
+
25
+ # -------------#
26
+ from torch.utils.data.dataset import Dataset
27
+ # -------------#
28
+
29
+
30
+ def _count_frames(data_len, size, step):
31
+ # no padding at edges, last remaining samples are ignored
32
+ return int((data_len - size + step) / step)
33
+
34
+
35
+ def _gen_frame_indices(data_length, size=2000, step=2000):
36
+ i = -1
37
+ for i in range(_count_frames(data_length, size, step)):
38
+ yield i * step, i * step + size
39
+
40
+ if i * step + size < data_length:
41
+ if data_length - (i + 1) * step > 0:
42
+ if i == -1:
43
+ yield (i + 1) * step, data_length
44
+ else:
45
+ yield data_length - size, data_length
46
+
47
+
48
+ def _gen_chunk_indices(data_len, chunk_size):
49
+ step = chunk_size
50
+ start = 0
51
+ while start < data_len:
52
+ end = min(data_len, start + chunk_size)
53
+ yield start, end
54
+ start += step
55
+
56
+
57
+ #######################
58
+ # Diarization Dataset #
59
+ #######################
60
+ class DiarizationDataset(Dataset):
61
+ def __init__(
62
+ self,
63
+ mode,
64
+ data_dir,
65
+ chunk_size=2000,
66
+ frame_shift=256,
67
+ sampling_rate=16000,
68
+ subsampling=1,
69
+ use_last_samples=True,
70
+ num_speakers=3,
71
+ filter_spk=False
72
+ ):
73
+ super(DiarizationDataset, self).__init__()
74
+
75
+ self.mode = mode
76
+ self.data_dir = data_dir
77
+ self.chunk_size = chunk_size
78
+ self.frame_shift = frame_shift
79
+ self.subsampling = subsampling
80
+ self.n_speakers = num_speakers
81
+ self.chunk_indices = [] if mode != "test" else {}
82
+
83
+ self.data = KaldiData(self.data_dir)
84
+ self.all_speakers = sorted(self.data.spk2utt.keys())
85
+ self.all_n_speakers = len(self.all_speakers)
86
+
87
+ # make chunk indices: filepath, start_frame, end_frame
88
+ for rec in self.data.wavs:
89
+ data_len = int(self.data.reco2dur[rec] * sampling_rate / frame_shift)
90
+ data_len = int(data_len / self.subsampling)
91
+ if mode == "test":
92
+ self.chunk_indices[rec] = []
93
+ if mode != "test":
94
+ for st, ed in _gen_frame_indices(data_len, chunk_size, chunk_size):
95
+ self.chunk_indices.append(
96
+ (rec, st * self.subsampling, ed * self.subsampling)
97
+ )
98
+ else:
99
+ for st, ed in _gen_chunk_indices(data_len, chunk_size):
100
+ self.chunk_indices[rec].append(
101
+ (rec, st, ed)
102
+ )
103
+
104
+ if mode != "test":
105
+ if filter_spk:
106
+ self.filter_spk()
107
+ print(len(self.chunk_indices), " chunks")
108
+ else:
109
+ self.rec_list = list(self.chunk_indices.keys())
110
+ print(len(self.rec_list), " recordings")
111
+
112
+ def __len__(self):
113
+ return (
114
+ len(self.rec_list)
115
+ if type(self.chunk_indices) == dict
116
+ else len(self.chunk_indices)
117
+ )
118
+
119
+ def filter_spk(self):
120
+ # filter out speakers that are listed in spk2utt but will not be used in training,
121
+ # i.e. speakers that only appear in chunks with more speakers than self.n_speakers
122
+ occur_spk_set = set()
123
+
124
+ new_chunk_indices = [] # drop chunks that contain more than self.n_speakers speakers
125
+ for idx in range(self.__len__()):
126
+ rec, st, ed = self.chunk_indices[idx]
127
+
128
+ filtered_segments = self.data.segments[rec]
129
+ # all the speakers in this recording, not just this chunk
130
+ speakers = np.unique(
131
+ [self.data.utt2spk[seg['utt']] for seg in filtered_segments]
132
+ ).tolist()
133
+ n_speakers = self.n_speakers
134
+ # we assume that in each chunk the number of speakers is less than or equal to self.n_speakers,
135
+ # but the speaker number in the whole recording may exceed self.n_speakers
136
+ if self.n_speakers < len(speakers):
137
+ n_speakers = len(speakers)
138
+
139
+ # Y: (length,), T: (frame_num, n_speakers)
140
+ Y, T = self._get_labeled_speech(rec, st, ed, n_speakers)
141
+ # the spk index exist in this chunk data
142
+ exist_spk_idx = np.sum(T, axis=0) > 0.5 # bool index
143
+ chunk_spk_num = np.sum(exist_spk_idx)
144
+ if chunk_spk_num <= self.n_speakers:
145
+ spk_arr = np.array(speakers)
146
+ valid_spk_arr = spk_arr[exist_spk_idx[:spk_arr.shape[0]]]
147
+ for spk in valid_spk_arr:
148
+ occur_spk_set.add(spk)
149
+
150
+ new_chunk_indices.append((rec, st, ed))
151
+ self.chunk_indices = new_chunk_indices
152
+ self.all_speakers = sorted(list(occur_spk_set))
153
+ self.all_n_speakers = len(self.all_speakers)
154
+
155
+ def __getitem__(self, i):
156
+ if self.mode != "test":
157
+ rec, st, ed = self.chunk_indices[i]
158
+
159
+ filtered_segments = self.data.segments[rec]
160
+ # all the speakers in this recording, not just this chunk
161
+ speakers = np.unique(
162
+ [self.data.utt2spk[seg['utt']] for seg in filtered_segments]
163
+ ).tolist()
164
+ n_speakers = self.n_speakers
165
+ # we assume that in each chunk the number of speakers is less than or equal to self.n_speakers,
166
+ # but the speaker number in the whole recording may exceed self.n_speakers
167
+ if self.n_speakers < len(speakers):
168
+ n_speakers = len(speakers)
169
+
170
+ # Y: (length,), T: (frame_num, n_speakers)
171
+ Y, T = self._get_labeled_speech(rec, st, ed, n_speakers)
172
+ # the spk index exist in this chunk data
173
+ exist_spk_idx = np.sum(T, axis=0) > 0.5 # bool index
174
+ chunk_spk_num = np.sum(exist_spk_idx)
175
+ if chunk_spk_num > self.n_speakers:
176
+ # the number of speakers in this chunk exceeds our pre-set value
177
+ return None, None, None
178
+
179
+ # the map from within recording speaker index to global speaker index
180
+ S_arr = -1 * np.ones(n_speakers).astype(np.int64)
181
+ for seg in filtered_segments:
182
+ speaker_index = speakers.index(self.data.utt2spk[seg['utt']])
183
+ try:
184
+ all_speaker_index = self.all_speakers.index(
185
+ self.data.utt2spk[seg['utt']])
186
+ except ValueError:
187
+ # this speaker was filtered out earlier in self.filter_spk
188
+ all_speaker_index = -1
189
+ S_arr[speaker_index] = all_speaker_index
190
+ # If T[:, n_speakers - 1] == 0.0, then S_arr[n_speakers - 1] == -1,
191
+ # so S_arr[n_speakers - 1] is not used for training,
192
+ # e.g., in the case of training 3-spk model with 2-spk data
193
+
194
+ # filter out speakers that do not appear in this chunk and ensure there are self.n_speakers outputs
195
+ T_exist = T[:,exist_spk_idx]
196
+ T = np.zeros((T_exist.shape[0], self.n_speakers), dtype=np.int32)
197
+ T[:,:T_exist.shape[1]] = T_exist
198
+ # subsampling for Y will be done in the model forward function
199
+ T = T[::self.subsampling]
200
+
201
+ S_arr_exist = S_arr[exist_spk_idx]
202
+ S_arr = -1 * np.ones(self.n_speakers).astype(np.int64)
203
+ S_arr[:S_arr_exist.shape[0]] = S_arr_exist
204
+
205
+ n = np.arange(self.all_n_speakers, dtype=np.int64).reshape(self.all_n_speakers, 1)
206
+ return Y, T, S_arr, n, T.shape[0]
207
+ else:
208
+ len_ratio = self.frame_shift * self.subsampling
209
+ chunks = self.chunk_indices[self.rec_list[i]]
210
+ Ys = []
211
+ chunk_len_list = []
212
+ for (rec, st, ed) in chunks:
213
+ chunk_len = ed - st
214
+ if chunk_len != self.chunk_size:
215
+ st = max(0, ed - self.chunk_size)
216
+ Y, _ = self.data.load_wav(rec, st * len_ratio, ed * len_ratio)
217
+ Ys.append(Y)
218
+ chunk_len_list.append(chunk_len)
219
+ return Ys, self.rec_list[i], chunk_len_list
220
+
221
+ def get_allnspk(self):
222
+ return self.all_n_speakers
223
+
224
+ def _get_labeled_speech(
225
+ self, rec, start, end, n_speakers=None, use_speaker_id=False
226
+ ):
227
+ """Extracts speech chunks and corresponding labels
228
+
229
+ Extracts speech chunks and corresponding diarization labels for
230
+ given recording id and start/end times
231
+
232
+ Args:
233
+ rec (str): recording id
234
+ start (int): start frame index
235
+ end (int): end frame index
236
+ n_speakers (int): number of speakers
237
+ if None, the value is given from data
238
+ Returns:
239
+ data: speech chunk
240
+ (n_samples)
241
+ T: label
242
+ (n_frames, n_speakers)-shaped np.int32 array.
243
+ """
244
+ data, rate = self.data.load_wav(
245
+ rec, start * self.frame_shift, end * self.frame_shift
246
+ )
247
+ frame_num = end - start
248
+ filtered_segments = self.data.segments[rec]
249
+ # filtered_segments = self.data.segments[self.data.segments['rec'] == rec]
250
+ speakers = np.unique(
251
+ [self.data.utt2spk[seg["utt"]] for seg in filtered_segments]
252
+ ).tolist()
253
+ if n_speakers is None:
254
+ n_speakers = len(speakers)
255
+ T = np.zeros((frame_num, n_speakers), dtype=np.int32)
256
+
257
+ if use_speaker_id:
258
+ all_speakers = sorted(self.data.spk2utt.keys())
259
+ S = np.zeros((frame_num, len(all_speakers)), dtype=np.int32)
260
+
261
+ for seg in filtered_segments:
262
+ speaker_index = speakers.index(self.data.utt2spk[seg["utt"]])
263
+ if use_speaker_id:
264
+ all_speaker_index = all_speakers.index(self.data.utt2spk[seg["utt"]])
265
+ start_frame = np.rint(seg["st"] * rate / self.frame_shift).astype(int)
266
+ end_frame = np.rint(seg["et"] * rate / self.frame_shift).astype(int)
267
+ rel_start = rel_end = None
268
+ if start <= start_frame and start_frame < end:
269
+ rel_start = start_frame - start
270
+ if start < end_frame and end_frame <= end:
271
+ rel_end = end_frame - start
272
+ if rel_start is not None or rel_end is not None:
273
+ T[rel_start:rel_end, speaker_index] = 1
274
+ if use_speaker_id:
275
+ S[rel_start:rel_end, all_speaker_index] = 1
276
+
277
+ if use_speaker_id:
278
+ return data, T, S
279
+ else:
280
+ return data, T
281
+
282
+ def collate_fn(self, batch):
283
+ valid_samples = [sample for sample in batch if sample[0] is not None]
284
+
285
+ wav_list, binary_label_list, spk_label_list= [], [], []
286
+ all_spk_idx_list, len_list = [], []
287
+ for sample in valid_samples:
288
+ wav_list.append(torch.from_numpy(sample[0]).float())
289
+ binary_label_list.append(torch.from_numpy(sample[1]).long())
290
+ spk_label_list.append(torch.from_numpy(sample[2]).long())
291
+ all_spk_idx_list.append(torch.from_numpy(sample[3]).long())
292
+ len_list.append(sample[4])
293
+ wav_batch = pad_sequence(wav_list, batch_first=True, padding_value=0.0)
294
+ binary_label_batch = pad_sequence(binary_label_list, batch_first=True, padding_value=1).long()
295
+ spk_label_batch = torch.stack(spk_label_list)
296
+ all_spk_idx_batch = torch.stack(all_spk_idx_list)
297
+ len_batch = torch.LongTensor(len_list)
298
+
299
+ return wav_batch, binary_label_batch.float(), spk_label_batch, all_spk_idx_batch, len_batch
300
+
301
+ def collate_fn_infer(self, batch):
302
+ assert len(batch) == 1 # each batch should contain one recording
303
+ Ys, rec, chunk_len_list = batch[0]
304
+ wav_list = [torch.from_numpy(Y).float() for Y in Ys]
305
+
306
+ return wav_list, rec, chunk_len_list
307
+
308
+
309
+ #######################
310
+ # Kaldi-style Dataset #
311
+ #######################
312
+ class KaldiData:
313
+ """This class holds data in kaldi-style directory."""
314
+
315
+ def __init__(self, data_dir):
316
+ """Load kaldi data directory."""
317
+ self.data_dir = data_dir
318
+ self.segments = self._load_segments_rechash(
319
+ os.path.join(self.data_dir, "segments")
320
+ )
321
+ self.utt2spk = self._load_utt2spk(os.path.join(self.data_dir, "utt2spk"))
322
+ self.wavs = self._load_wav_scp(os.path.join(self.data_dir, "wav.scp"))
323
+ self.reco2dur = self._load_reco2dur(os.path.join(self.data_dir, "reco2dur"))
324
+ self.spk2utt = self._load_spk2utt(os.path.join(self.data_dir, "spk2utt"))
325
+
326
+ def load_wav(self, recid, start=0, end=None):
327
+ """Load wavfile given recid, start time and end time."""
328
+ data, rate = self._load_wav(self.wavs[recid], start, end)
329
+ return data, rate
330
+
331
+ def _load_segments(self, segments_file):
332
+ """Load segments file as array."""
333
+ if not os.path.exists(segments_file):
334
+ return None
335
+ return np.loadtxt(
336
+ segments_file,
337
+ dtype=[("utt", "object"), ("rec", "object"), ("st", "f"), ("et", "f")],
338
+ ndmin=1,
339
+ )
340
+
341
+ def _load_segments_hash(self, segments_file):
342
+ """Load segments file as dict with uttid index."""
343
+ ret = {}
344
+ if not os.path.exists(segments_file):
345
+ return None
346
+ for line in open(segments_file):
347
+ utt, rec, st, et = line.strip().split()
348
+ ret[utt] = (rec, float(st), float(et))
349
+ return ret
350
+
351
+ def _load_segments_rechash(self, segments_file):
352
+ """Load segments file as dict with recid index."""
353
+ ret = {}
354
+ if not os.path.exists(segments_file):
355
+ return None
356
+ for line in open(segments_file):
357
+ utt, rec, st, et = line.strip().split()
358
+ if rec not in ret:
359
+ ret[rec] = []
360
+ ret[rec].append({"utt": utt, "st": float(st), "et": float(et)})
361
+ return ret
362
+
363
+ def _load_wav_scp(self, wav_scp_file):
364
+ """Return dictionary { rec: wav_rxfilename }."""
365
+ if os.path.exists(wav_scp_file):
366
+ lines = [line.strip().split(None, 1) for line in open(wav_scp_file)]
367
+ return {x[0]: x[1] for x in lines}
368
+ else:
369
+ wav_dir = os.path.join(self.data_dir, "wav")
370
+ return {
371
+ os.path.splitext(filename)[0]: os.path.join(wav_dir, filename)
372
+ for filename in sorted(os.listdir(wav_dir))
373
+ }
374
+
375
+ def _load_wav(self, wav_rxfilename, start=0, end=None):
376
+ """This function reads audio file and return data in numpy.float32 array.
377
+ "lru_cache" holds recently loaded audio so that can be called
378
+ many times on the same audio file.
379
+ OPTIMIZE: controls lru_cache size for random access,
380
+ considering memory size
381
+ """
382
+ if wav_rxfilename.endswith("|"):
383
+ # input piped command
384
+ p = subprocess.Popen(
385
+ wav_rxfilename[:-1],
386
+ shell=True,
387
+ stdout=subprocess.PIPE,
388
+ )
389
+ data, samplerate = sf.read(
390
+ io.BytesIO(p.stdout.read()),
391
+ dtype="float32",
392
+ )
393
+ # cannot seek
394
+ data = data[start:end]
395
+ elif wav_rxfilename == "-":
396
+ # stdin
397
+ data, samplerate = sf.read(sys.stdin, dtype="float32")
398
+ # cannot seek
399
+ data = data[start:end]
400
+ else:
401
+ # normal wav file
402
+ data, samplerate = sf.read(wav_rxfilename, start=start, stop=end)
403
+ return data, samplerate
404
+
405
+ def _load_utt2spk(self, utt2spk_file):
406
+ """Returns dictionary { uttid: spkid }."""
407
+ lines = [line.strip().split(None, 1) for line in open(utt2spk_file)]
408
+ return {x[0]: x[1] for x in lines}
409
+
410
+ def _load_spk2utt(self, spk2utt_file):
411
+ """Returns dictionary { spkid: list of uttids }."""
412
+ if not os.path.exists(spk2utt_file):
413
+ return None
414
+ lines = [line.strip().split() for line in open(spk2utt_file)]
415
+ return {x[0]: x[1:] for x in lines}
416
+
417
+ def _load_reco2dur(self, reco2dur_file):
418
+ """Returns dictionary { recid: duration }."""
419
+ if not os.path.exists(reco2dur_file):
420
+ return None
421
+ lines = [line.strip().split(None, 1) for line in open(reco2dur_file)]
422
+ return {x[0]: float(x[1]) for x in lines}
423
+
424
+ def _process_wav(self, wav_rxfilename, process):
425
+ """This function returns preprocessed wav_rxfilename.
426
+ Args:
427
+ wav_rxfilename:
428
+ input
429
+ process:
430
+ command which can be connected via pipe, use stdin and stdout
431
+ Returns:
432
+ wav_rxfilename: output piped command
433
+ """
434
+ if wav_rxfilename.endswith("|"):
435
+ # input piped command
436
+ return wav_rxfilename + process + "|"
437
+ # stdin "-" or normal file
438
+ return "cat {0} | {1} |".format(wav_rxfilename, process)
439
+
440
+ def _extract_segments(self, wavs, segments=None):
441
+ """This function returns generator of segmented audio.
442
+ Yields (utterance id, numpy.float32 array).
443
+ TODO?: sampling rate is not converted.
444
+ """
445
+ if segments is not None:
446
+ # segments should be sorted by rec-id
447
+ for seg in segments:
448
+ wav = wavs[seg["rec"]]
449
+ data, samplerate = self.load_wav(wav)
450
+ st_sample = np.rint(seg["st"] * samplerate).astype(int)
451
+ et_sample = np.rint(seg["et"] * samplerate).astype(int)
452
+ yield seg["utt"], data[st_sample:et_sample]
453
+ else:
454
+ # segments file not found,
455
+ # wav.scp is used as segmented audio list
456
+ for rec in wavs:
457
+ data, samplerate = self.load_wav(wavs[rec])
458
+ yield rec, data
459
+
460
+ if __name__ == "__main__":
461
+ args = {
462
+ 'mode': 'train',
463
+ 'data_dir': "/mnt/lustre/sjtu/home/czy97/workspace/sd/EEND-vec-clustering/EEND-vector-clustering/egs/mini_librispeech/v1/data/simu/data/train_clean_5_ns3_beta2_500",
464
+ 'chunk_size': 2001,
465
+ 'frame_shift': 256,
466
+ 'sampling_rate': 8000,
467
+ 'num_speakers':3
468
+ }
469
+
470
+ torch.manual_seed(6)
471
+ dataset = DiarizationDataset(**args)
472
+
473
+ from torch.utils.data import DataLoader
474
+ dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=dataset.collate_fn)
475
+ data_iter = iter(dataloader)
476
+ # wav_batch, binary_label_batch, spk_label_batch, all_spk_idx_batch, len_batch = next(data_iter)
477
+ data = next(data_iter)
478
+ for val in data:
479
+ print(val.shape)
480
+
481
+ # from torch.utils.data import DataLoader
482
+ # dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=dataset.collate_fn_infer)
483
+ # data_iter = iter(dataloader)
484
+ # wav_list, binary_label_list, rec = next(data_iter)
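The commented-out lines above hint at the inference path. A minimal sketch of test-mode loading, assuming a Kaldi-style test directory at a placeholder path (not one from this repo):

```python
# Sketch only: DiarizationDataset is the class defined above; data_dir is a placeholder.
from torch.utils.data import DataLoader

test_set = DiarizationDataset(mode="test", data_dir="/path/to/kaldi_test_dir",
                              chunk_size=2000, frame_shift=256,
                              sampling_rate=16000, num_speakers=3)
# collate_fn_infer asserts batch_size == 1: each batch is one full recording,
# returned as a list of chunk waveforms plus the recording id and per-chunk lengths.
loader = DataLoader(test_set, batch_size=1, shuffle=False,
                    collate_fn=test_set.collate_fn_infer)
for wav_list, rec, chunk_len_list in loader:
    print(rec, len(wav_list), chunk_len_list)
```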
UniSpeech/downstreams/speaker_diarization/utils/kaldi_data.py ADDED
@@ -0,0 +1,162 @@
1
+ # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita)
2
+ # Licensed under the MIT license.
3
+ #
4
+ # This library provides utilities for kaldi-style data directory.
5
+
6
+
7
+ from __future__ import print_function
8
+ import os
9
+ import sys
10
+ import numpy as np
11
+ import subprocess
12
+ import soundfile as sf
13
+ import io
14
+ from functools import lru_cache
15
+
16
+
17
+ def load_segments(segments_file):
18
+ """ load segments file as array """
19
+ if not os.path.exists(segments_file):
20
+ return None
21
+ return np.loadtxt(
22
+ segments_file,
23
+ dtype=[('utt', 'object'),
24
+ ('rec', 'object'),
25
+ ('st', 'f'),
26
+ ('et', 'f')],
27
+ ndmin=1)
28
+
29
+
30
+ def load_segments_hash(segments_file):
31
+ ret = {}
32
+ if not os.path.exists(segments_file):
33
+ return None
34
+ for line in open(segments_file):
35
+ utt, rec, st, et = line.strip().split()
36
+ ret[utt] = (rec, float(st), float(et))
37
+ return ret
38
+
39
+
40
+ def load_segments_rechash(segments_file):
41
+ ret = {}
42
+ if not os.path.exists(segments_file):
43
+ return None
44
+ for line in open(segments_file):
45
+ utt, rec, st, et = line.strip().split()
46
+ if rec not in ret:
47
+ ret[rec] = []
48
+ ret[rec].append({'utt':utt, 'st':float(st), 'et':float(et)})
49
+ return ret
50
+
51
+
52
+ def load_wav_scp(wav_scp_file):
53
+ """ return dictionary { rec: wav_rxfilename } """
54
+ lines = [line.strip().split(None, 1) for line in open(wav_scp_file)]
55
+ return {x[0]: x[1] for x in lines}
56
+
57
+
58
+ @lru_cache(maxsize=1)
59
+ def load_wav(wav_rxfilename, start=0, end=None):
60
+ """ This function reads audio file and return data in numpy.float32 array.
61
+ "lru_cache" holds recently loaded audio so that can be called
62
+ many times on the same audio file.
63
+ OPTIMIZE: controls lru_cache size for random access,
64
+ considering memory size
65
+ """
66
+ if wav_rxfilename.endswith('|'):
67
+ # input piped command
68
+ p = subprocess.Popen(wav_rxfilename[:-1], shell=True,
69
+ stdout=subprocess.PIPE)
70
+ data, samplerate = sf.read(io.BytesIO(p.stdout.read()),
71
+ dtype='float32')
72
+ # cannot seek
73
+ data = data[start:end]
74
+ elif wav_rxfilename == '-':
75
+ # stdin
76
+ data, samplerate = sf.read(sys.stdin, dtype='float32')
77
+ # cannot seek
78
+ data = data[start:end]
79
+ else:
80
+ # normal wav file
81
+ data, samplerate = sf.read(wav_rxfilename, start=start, stop=end)
82
+ return data, samplerate
83
+
84
+
85
+ def load_utt2spk(utt2spk_file):
86
+ """ returns dictionary { uttid: spkid } """
87
+ lines = [line.strip().split(None, 1) for line in open(utt2spk_file)]
88
+ return {x[0]: x[1] for x in lines}
89
+
90
+
91
+ def load_spk2utt(spk2utt_file):
92
+ """ returns dictionary { spkid: list of uttids } """
93
+ if not os.path.exists(spk2utt_file):
94
+ return None
95
+ lines = [line.strip().split() for line in open(spk2utt_file)]
96
+ return {x[0]: x[1:] for x in lines}
97
+
98
+
99
+ def load_reco2dur(reco2dur_file):
100
+ """ returns dictionary { recid: duration } """
101
+ if not os.path.exists(reco2dur_file):
102
+ return None
103
+ lines = [line.strip().split(None, 1) for line in open(reco2dur_file)]
104
+ return {x[0]: float(x[1]) for x in lines}
105
+
106
+
107
+ def process_wav(wav_rxfilename, process):
108
+ """ This function returns preprocessed wav_rxfilename
109
+ Args:
110
+ wav_rxfilename: input
111
+ process: command which can be connected via pipe,
112
+ use stdin and stdout
113
+ Returns:
114
+ wav_rxfilename: output piped command
115
+ """
116
+ if wav_rxfilename.endswith('|'):
117
+ # input piped command
118
+ return wav_rxfilename + process + "|"
119
+ else:
120
+ # stdin "-" or normal file
121
+ return "cat {} | {} |".format(wav_rxfilename, process)
122
+
123
+
124
+ def extract_segments(wavs, segments=None):
125
+ """ This function returns generator of segmented audio as
126
+ (utterance id, numpy.float32 array)
127
+ TODO?: sampling rate is not converted.
128
+ """
129
+ if segments is not None:
130
+ # segments should be sorted by rec-id
131
+ for seg in segments:
132
+ wav = wavs[seg['rec']]
133
+ data, samplerate = load_wav(wav)
134
+ st_sample = np.rint(seg['st'] * samplerate).astype(int)
135
+ et_sample = np.rint(seg['et'] * samplerate).astype(int)
136
+ yield seg['utt'], data[st_sample:et_sample]
137
+ else:
138
+ # segments file not found,
139
+ # wav.scp is used as segmented audio list
140
+ for rec in wavs:
141
+ data, samplerate = load_wav(wavs[rec])
142
+ yield rec, data
143
+
144
+
145
+ class KaldiData:
146
+ def __init__(self, data_dir):
147
+ self.data_dir = data_dir
148
+ self.segments = load_segments_rechash(
149
+ os.path.join(self.data_dir, 'segments'))
150
+ self.utt2spk = load_utt2spk(
151
+ os.path.join(self.data_dir, 'utt2spk'))
152
+ self.wavs = load_wav_scp(
153
+ os.path.join(self.data_dir, 'wav.scp'))
154
+ self.reco2dur = load_reco2dur(
155
+ os.path.join(self.data_dir, 'reco2dur'))
156
+ self.spk2utt = load_spk2utt(
157
+ os.path.join(self.data_dir, 'spk2utt'))
158
+
159
+ def load_wav(self, recid, start=0, end=None):
160
+ data, rate = load_wav(
161
+ self.wavs[recid], start, end)
162
+ return data, rate
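A rough usage sketch for the helpers above, assuming a standard Kaldi-style directory (wav.scp and utt2spk, plus optional segments/spk2utt/reco2dur) at a placeholder path:

```python
# Sketch only: the directory path is a placeholder, not a path from this repo.
data = KaldiData("/path/to/kaldi_dir")
print(len(data.wavs), "recordings,", len(data.utt2spk), "utterances")

# load_wav accepts optional start/end sample indices; here we read one full recording.
rec = next(iter(data.wavs))
audio, rate = data.load_wav(rec)
print(rec, audio.shape, rate)
```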
UniSpeech/downstreams/speaker_diarization/utils/parse_options.sh ADDED
@@ -0,0 +1,97 @@
1
+ #!/usr/bin/env bash
2
+
3
+ # Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
4
+ # Arnab Ghoshal, Karel Vesely
5
+
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13
+ # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
14
+ # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
15
+ # MERCHANTABLITY OR NON-INFRINGEMENT.
16
+ # See the Apache 2 License for the specific language governing permissions and
17
+ # limitations under the License.
18
+
19
+
20
+ # Parse command-line options.
21
+ # To be sourced by another script (as in ". parse_options.sh").
22
+ # Option format is: --option-name arg
23
+ # and shell variable "option_name" gets set to value "arg."
24
+ # The exception is --help, which takes no arguments, but prints the
25
+ # $help_message variable (if defined).
26
+
27
+
28
+ ###
29
+ ### The --config file options have lower priority to command line
30
+ ### options, so we need to import them first...
31
+ ###
32
+
33
+ # Now import all the configs specified by command-line, in left-to-right order
34
+ for ((argpos=1; argpos<$#; argpos++)); do
35
+ if [ "${!argpos}" == "--config" ]; then
36
+ argpos_plus1=$((argpos+1))
37
+ config=${!argpos_plus1}
38
+ [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
39
+ . $config # source the config file.
40
+ fi
41
+ done
42
+
43
+
44
+ ###
45
+ ### Now we process the command line options
46
+ ###
47
+ while true; do
48
+ [ -z "${1:-}" ] && break; # break if there are no arguments
49
+ case "$1" in
50
+ # If the enclosing script is called with --help option, print the help
51
+ # message and exit. Scripts should put help messages in $help_message
52
+ --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
53
+ else printf "$help_message\n" 1>&2 ; fi;
54
+ exit 0 ;;
55
+ --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
56
+ exit 1 ;;
57
+ # If the first command-line argument begins with "--" (e.g. --foo-bar),
58
+ # then work out the variable name as $name, which will equal "foo_bar".
59
+ --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
60
+ # Next we test whether the variable in question is undefined -- if so it's
61
+ # an invalid option and we die. Note: $0 evaluates to the name of the
62
+ # enclosing script.
63
+ # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
64
+ # is undefined. We then have to wrap this test inside "eval" because
65
+ # foo_bar is itself inside a variable ($name).
66
+ eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
67
+
68
+ oldval="`eval echo \\$$name`";
69
+ # Work out whether we seem to be expecting a Boolean argument.
70
+ if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
71
+ was_bool=true;
72
+ else
73
+ was_bool=false;
74
+ fi
75
+
76
+ # Set the variable to the right value-- the escaped quotes make it work if
77
+ # the option had spaces, like --cmd "queue.pl -sync y"
78
+ eval $name=\"$2\";
79
+
80
+ # Check that Boolean-valued arguments are really Boolean.
81
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
82
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
83
+ exit 1;
84
+ fi
85
+ shift 2;
86
+ ;;
87
+ *) break;
88
+ esac
89
+ done
90
+
91
+
92
+ # Check for an empty argument to the --cmd option, which can easily occur as a
93
+ # result of scripting errors.
94
+ [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
95
+
96
+
97
+ true; # so this script returns exit code 0.
UniSpeech/downstreams/speaker_diarization/utils/utils.py ADDED
@@ -0,0 +1,189 @@
1
+ import os
2
+ import struct
3
+ import logging
4
+ import torch
5
+ import math
6
+ import numpy as np
7
+ import random
8
+ import yaml
9
+ import torch.distributed as dist
10
+ import torch.nn.functional as F
11
+
12
+
13
+ # ------------------------------ Logger ------------------------------
14
+ # log to console or a file
15
+ def get_logger(
16
+ name,
17
+ format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
18
+ date_format="%Y-%m-%d %H:%M:%S",
19
+ file=False):
20
+ """
21
+ Get python logger instance
22
+ """
23
+ logger = logging.getLogger(name)
24
+ logger.setLevel(logging.INFO)
25
+ # file or console
26
+ handler = logging.StreamHandler() if not file else logging.FileHandler(
27
+ name)
28
+ handler.setLevel(logging.INFO)
29
+ formatter = logging.Formatter(fmt=format_str, datefmt=date_format)
30
+ handler.setFormatter(formatter)
31
+ logger.addHandler(handler)
32
+ return logger
33
+
34
+
35
+ # log to console and file at the same time
36
+ def get_logger_2(
37
+ name,
38
+ format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
39
+ date_format="%Y-%m-%d %H:%M:%S"):
40
+ logger = logging.getLogger(name)
41
+ logger.setLevel(logging.INFO)
42
+
43
+ # Create handlers
44
+ c_handler = logging.StreamHandler()
45
+ f_handler = logging.FileHandler(name)
46
+ c_handler.setLevel(logging.INFO)
47
+ f_handler.setLevel(logging.INFO)
48
+
49
+ # Create formatters and add it to handlers
50
+ c_format = logging.Formatter(fmt=format_str, datefmt=date_format)
51
+ f_format = logging.Formatter(fmt=format_str, datefmt=date_format)
52
+ c_handler.setFormatter(c_format)
53
+ f_handler.setFormatter(f_format)
54
+
55
+ # Add handlers to the logger
56
+ logger.addHandler(c_handler)
57
+ logger.addHandler(f_handler)
58
+
59
+ return logger
60
+
61
+
62
+ # ------------------------------ Logger ------------------------------
63
+
64
+ # ------------------------------ Pytorch Distributed Training ------------------------------
65
+ def getoneNode():
66
+ nodelist = os.environ['SLURM_JOB_NODELIST']
67
+ nodelist = nodelist.strip().split(',')[0]
68
+ import re
69
+ text = re.split('[-\[\]]', nodelist)
70
+ if ('' in text):
71
+ text.remove('')
72
+ return text[0] + '-' + text[1] + '-' + text[2]
73
+
74
+
75
+ def dist_init(host_addr, rank, local_rank, world_size, port=23456):
76
+ host_addr_full = 'tcp://' + host_addr + ':' + str(port)
77
+ dist.init_process_group("nccl", init_method=host_addr_full,
78
+ rank=rank, world_size=world_size)
79
+ num_gpus = torch.cuda.device_count()
80
+ # torch.cuda.set_device(local_rank)
81
+ assert dist.is_initialized()
82
+
83
+
84
+ def cleanup():
85
+ dist.destroy_process_group()
86
+
87
+
88
+ def average_gradients(model, world_size):
89
+ size = float(world_size)
90
+ for param in model.parameters():
91
+ if (param.requires_grad and param.grad is not None):
92
+ dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
93
+ param.grad.data /= size
94
+
95
+
96
+ def data_reduce(data):
97
+ dist.all_reduce(data, op=dist.ReduceOp.SUM)
98
+ return data / torch.distributed.get_world_size()
99
+
100
+
101
+ # ------------------------------ Pytorch Distributed Training ------------------------------
102
+
103
+
104
+ # ------------------------------ Hyper-parameter Dynamic Change ------------------------------
105
+ def reduce_lr(optimizer, initial_lr, final_lr, current_iter, max_iter, coeff=1.0):
106
+ current_lr = coeff * math.exp((current_iter / max_iter) * math.log(final_lr / initial_lr)) * initial_lr
107
+ for param_group in optimizer.param_groups:
108
+ param_group['lr'] = current_lr
109
+
110
+
111
+ def get_reduce_lr(initial_lr, final_lr, current_iter, max_iter):
112
+ current_lr = math.exp((current_iter / max_iter) * math.log(final_lr / initial_lr)) * initial_lr
113
+ return current_lr
114
+
115
+
116
+ def set_lr(optimizer, lr):
117
+ for param_group in optimizer.param_groups:
118
+ param_group['lr'] = lr
119
+
120
+ # ------------------------------ Hyper-parameter Dynamic Change ------------------------------
121
+
122
+ # ---------------------- About Configuration --------------------
123
+ def parse_config_or_kwargs(config_file, **kwargs):
124
+ with open(config_file) as con_read:
125
+ yaml_config = yaml.load(con_read, Loader=yaml.FullLoader)
126
+ # passed kwargs will override yaml config
127
+ return dict(yaml_config, **kwargs)
128
+
129
+
130
+ def store_yaml(config_file, store_path, **kwargs):
131
+ with open(config_file, 'r') as f:
132
+ config_lines = f.readlines()
133
+
134
+ keys_list = list(kwargs.keys())
135
+ with open(store_path, 'w') as f:
136
+ for line in config_lines:
137
+ if ':' in line and line.split(':')[0] in keys_list:
138
+ key = line.split(':')[0]
139
+ line = '{}: {}\n'.format(key, kwargs[key])
140
+ f.write(line)
141
+
142
+
143
+ # ---------------------- About Configuration --------------------
144
+
145
+
146
+ def check_dir(dir):
147
+ if not os.path.exists(dir):
148
+ os.mkdir(dir)
149
+
150
+
151
+ def set_seed(seed=66):
152
+ np.random.seed(seed)
153
+ random.seed(seed)
154
+
155
+ torch.manual_seed(seed)
156
+ torch.cuda.manual_seed(seed)
157
+ torch.cuda.manual_seed_all(seed)
158
+
159
+ # torch.backends.cudnn.deterministic = True
160
+ # torch.backends.cudnn.benchmark = False
161
+
162
+
163
+ # when the model was stored with a "module." prefix (e.g. saved from DataParallel),
164
+ # we remove it here
165
+ def correct_key(state_dict):
166
+ keys = list(state_dict.keys())
167
+ if 'module' not in keys[0]:
168
+ return state_dict
169
+ else:
170
+ new_state_dict = {}
171
+ for key in keys:
172
+ new_key = '.'.join(key.split('.')[1:])
173
+ new_state_dict[new_key] = state_dict[key]
174
+ return new_state_dict
175
+
176
+
177
+ def validate_path(dir_name):
178
+ """
179
+ :param dir_name: Create the directory if it doesn't exist
180
+ :return: None
181
+ """
182
+ dir_name = os.path.dirname(dir_name) # get the path
183
+ if not os.path.exists(dir_name) and (dir_name != ''):
184
+ os.makedirs(dir_name)
185
+
186
+
187
+ def get_lr(optimizer):
188
+ for param_group in optimizer.param_groups:
189
+ return param_group['lr']
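The learning-rate helpers above implement a geometric (exponential) interpolation from `initial_lr` to `final_lr` over `max_iter` steps, i.e. `lr(t) = initial_lr * (final_lr / initial_lr) ** (t / max_iter)`. A quick numeric check of `get_reduce_lr` (illustrative values only):

```python
# Equivalent closed form: initial_lr * (final_lr / initial_lr) ** (current_iter / max_iter)
print(get_reduce_lr(1e-3, 1e-5, 0,   100))   # ~1e-3  (start of training)
print(get_reduce_lr(1e-3, 1e-5, 50,  100))   # ~1e-4  (geometric midpoint)
print(get_reduce_lr(1e-3, 1e-5, 100, 100))   # ~1e-5  (final value)
```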
UniSpeech/downstreams/speaker_verification/README.md ADDED
@@ -0,0 +1,49 @@
1
+ ## Pre-training Representations for Speaker Verification
2
+
3
+ ### Pre-trained models
4
+
5
+ | Model | Fix pre-train | Vox1-O | Vox1-E | Vox1-H |
6
+ | ------------------------------------------------------------ | ------------- | --------- | --------- | -------- |
7
+ | [ECAPA-TDNN](https://drive.google.com/file/d/1kWmLyTGkBExTdxtwmrXoP4DhWz_7ZAv3/view?usp=sharing) | - | 1.080 | 1.200 | 2.127 |
8
+ | [HuBERT large](https://1drv.ms/u/s!AqeByhGUtINrgcsDGB2GnI61eMcCWA?e=olWmeG) | Yes | 0.888 | 0.912 | 1.853 |
9
+ | [Wav2Vec2.0 (XLSR)](https://1drv.ms/u/s!AqeByhGUtINrgcsCVGF9JpPlxNkVKg?e=KknWbW) | Yes | 0.915 | 0.945 | 1.895 |
10
+ | [UniSpeech-SAT large](https://1drv.ms/u/s!AqeByhGUtINrgcsEXX-_TS5VBxIF-Q?e=nRfkcT) | Yes | 0.771 | 0.781 | 1.669 |
11
+ | [WavLM Base](https://1drv.ms/u/s!AqeByhGUtINrgcsBsAf8EJT7x-HMhA?e=gXjLRD) | Yes | 0.84 | 0.928 | 1.758 |
12
+ | [**WavLM large**](https://1drv.ms/u/s!AqeByhGUtINrgcp_7CsbcBjYW2Tr-w?e=VeCMic) | Yes | 0.75 | 0.764 | 1.548 |
13
+ | [HuBERT large](https://drive.google.com/file/d/1nit9Z6RyM8Sdb3n8ccaglOQVNnqsjnui/view?usp=sharing) | No | 0.585 | 0.654 | 1.342 |
14
+ | [Wav2Vec2.0 (XLSR)](https://drive.google.com/file/d/1TgKro9pp197TCgIF__IlE_rMVQOk50Eb/view?usp=sharing) | No | 0.564 | 0.605 | 1.23 |
15
+ | [UniSpeech-SAT large](https://drive.google.com/file/d/10o6NHZsPXJn2k8n57e8Z_FkKh3V4TC3g/view?usp=sharing) | No | 0.564 | 0.561 | 1.23 |
16
+ | [**WavLM large**](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view?usp=sharing) | No | **0.431** | **0.538** | **1.154** |
17
+
18
+ ### How to use?
19
+
20
+ #### Environment Setup
21
+
22
+ 1. `pip install --require-hashes -r requirements.txt`
23
+ 2. Install fairseq code
24
+ - For HuBERT_Large and Wav2Vec2.0 (XLSR), we should install the official [fairseq](https://github.com/pytorch/fairseq).
25
+ - For UniSpeech-SAT large, we should install the [Unispeech-SAT](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech-SAT) fairseq code.
26
+ - For WavLM, we should install the latest s3prl: `pip install s3prl@git+https://github.com/s3prl/s3prl.git@7ab62aaf2606d83da6c71ee74e7d16e0979edbc3#egg=s3prl`
27
+
28
+ #### Example
29
+
30
+ Take `unispeech_sat` and `ecapa_tdnn` as examples:
31
+
32
+ 1. First, download one of the pre-trained models from the table above to `checkpoint_path`.
33
+ 2. Then, run the following commands:
34
+ - The wav files are sampled from [voxceleb1](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html).
35
+
36
+ ```bash
37
+ python verification.py --model_name unispeech_sat --wav1 vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav --wav2 vox1_data/Josh_Gad/HXUqYaOwrxA_0000015.wav --checkpoint $checkpoint_path
38
+ # output: The similarity score between two audios is 0.0317 (-1.0, 1.0).
39
+
40
+ python verification.py --model_name unispeech_sat --wav1 vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav --wav2 vox1_data/David_Faustino/xTOk1Jz-F_g_0000015.wav --checkpoint $checkpoint_path
41
+ # output: The similarity score between two audios is 0.5389 (-1.0, 1.0).
42
+
43
+ python verification.py --model_name ecapa_tdnn --wav1 vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav --wav2 vox1_data/Josh_Gad/HXUqYaOwrxA_0000015.wav --checkpoint $checkpoint_path
44
+ # output: The similarity score between two audios is 0.2053 (-1.0, 1.0).
45
+
46
+ python verification.py --model_name ecapa_tdnn --wav1 vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav --wav2 vox1_data/David_Faustino/xTOk1Jz-F_g_0000015.wav --checkpoint $checkpoint_path
47
+ # output: The similarity score between two audios is 0.5302 (-1.0, 1.0).
48
+ ```
49
+
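Judging from the model code in this directory, the printed score is the cosine similarity between the two speaker embeddings. A minimal sketch of that core computation (checkpoint path, state-dict key, and upstream name are assumptions; `verification.py` remains the authoritative entry point):

```python
import torch
import torch.nn.functional as F
import soundfile as sf
from models.ecapa_tdnn import ECAPA_TDNN_SMALL

# Assumed setup: a WavLM-Large style upstream with 1024-dim hidden states.
model = ECAPA_TDNN_SMALL(feat_dim=1024, emb_dim=256, feat_type='wavlm_large')
state = torch.load('checkpoint.pt', map_location='cpu')         # placeholder path
model.load_state_dict(state.get('model', state), strict=False)  # key name is a guess
model.eval()

wav1, _ = sf.read('vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav')
wav2, _ = sf.read('vox1_data/Josh_Gad/HXUqYaOwrxA_0000015.wav')
with torch.no_grad():
    emb1 = model(torch.from_numpy(wav1).float().unsqueeze(0))
    emb2 = model(torch.from_numpy(wav2).float().unsqueeze(0))
score = F.cosine_similarity(emb1, emb2).item()  # ranges over (-1.0, 1.0)
print(f"The similarity score between two audios is {score:.4f} (-1.0, 1.0).")
```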
UniSpeech/downstreams/speaker_verification/config/unispeech_sat.th ADDED
Binary file (20.7 kB).
 
UniSpeech/downstreams/speaker_verification/models/__init__.py ADDED
File without changes
UniSpeech/downstreams/speaker_verification/models/ecapa_tdnn.py ADDED
@@ -0,0 +1,301 @@
1
+ # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torchaudio.transforms as trans
7
+ from .utils import UpstreamExpert
8
+
9
+
10
+ ''' Res2Conv1d + BatchNorm1d + ReLU
11
+ '''
12
+
13
+
14
+ class Res2Conv1dReluBn(nn.Module):
15
+ '''
16
+ in_channels == out_channels == channels
17
+ '''
18
+
19
+ def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
20
+ super().__init__()
21
+ assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
22
+ self.scale = scale
23
+ self.width = channels // scale
24
+ self.nums = scale if scale == 1 else scale - 1
25
+
26
+ self.convs = []
27
+ self.bns = []
28
+ for i in range(self.nums):
29
+ self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
30
+ self.bns.append(nn.BatchNorm1d(self.width))
31
+ self.convs = nn.ModuleList(self.convs)
32
+ self.bns = nn.ModuleList(self.bns)
33
+
34
+ def forward(self, x):
35
+ out = []
36
+ spx = torch.split(x, self.width, 1)
37
+ for i in range(self.nums):
38
+ if i == 0:
39
+ sp = spx[i]
40
+ else:
41
+ sp = sp + spx[i]
42
+ # Order: conv -> relu -> bn
43
+ sp = self.convs[i](sp)
44
+ sp = self.bns[i](F.relu(sp))
45
+ out.append(sp)
46
+ if self.scale != 1:
47
+ out.append(spx[self.nums])
48
+ out = torch.cat(out, dim=1)
49
+
50
+ return out
51
+
52
+
53
+ ''' Conv1d + BatchNorm1d + ReLU
54
+ '''
55
+
56
+
57
+ class Conv1dReluBn(nn.Module):
58
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
59
+ super().__init__()
60
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
61
+ self.bn = nn.BatchNorm1d(out_channels)
62
+
63
+ def forward(self, x):
64
+ return self.bn(F.relu(self.conv(x)))
65
+
66
+
67
+ ''' The SE connection of 1D case.
68
+ '''
69
+
70
+
71
+ class SE_Connect(nn.Module):
72
+ def __init__(self, channels, se_bottleneck_dim=128):
73
+ super().__init__()
74
+ self.linear1 = nn.Linear(channels, se_bottleneck_dim)
75
+ self.linear2 = nn.Linear(se_bottleneck_dim, channels)
76
+
77
+ def forward(self, x):
78
+ out = x.mean(dim=2)
79
+ out = F.relu(self.linear1(out))
80
+ out = torch.sigmoid(self.linear2(out))
81
+ out = x * out.unsqueeze(2)
82
+
83
+ return out
84
+
85
+
86
+ ''' SE-Res2Block of the ECAPA-TDNN architecture.
87
+ '''
88
+
89
+
90
+ # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
91
+ # return nn.Sequential(
92
+ # Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
93
+ # Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
94
+ # Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
95
+ # SE_Connect(channels)
96
+ # )
97
+
98
+
99
+ class SE_Res2Block(nn.Module):
100
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
101
+ super().__init__()
102
+ self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
103
+ self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
104
+ self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
105
+ self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
106
+
107
+ self.shortcut = None
108
+ if in_channels != out_channels:
109
+ self.shortcut = nn.Conv1d(
110
+ in_channels=in_channels,
111
+ out_channels=out_channels,
112
+ kernel_size=1,
113
+ )
114
+
115
+ def forward(self, x):
116
+ residual = x
117
+ if self.shortcut:
118
+ residual = self.shortcut(x)
119
+
120
+ x = self.Conv1dReluBn1(x)
121
+ x = self.Res2Conv1dReluBn(x)
122
+ x = self.Conv1dReluBn2(x)
123
+ x = self.SE_Connect(x)
124
+
125
+ return x + residual
126
+
127
+
128
+ ''' Attentive weighted mean and standard deviation pooling.
129
+ '''
130
+
131
+
132
+ class AttentiveStatsPool(nn.Module):
133
+ def __init__(self, in_dim, attention_channels=128, global_context_att=False):
134
+ super().__init__()
135
+ self.global_context_att = global_context_att
136
+
137
+ # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
138
+ if global_context_att:
139
+ self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper
140
+ else:
141
+ self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper
142
+ self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
143
+
144
+ def forward(self, x):
145
+
146
+ if self.global_context_att:
147
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
148
+ context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
149
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
150
+ else:
151
+ x_in = x
152
+
153
+ # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
154
+ alpha = torch.tanh(self.linear1(x_in))
155
+ # alpha = F.relu(self.linear1(x_in))
156
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
157
+ mean = torch.sum(alpha * x, dim=2)
158
+ residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
159
+ std = torch.sqrt(residuals.clamp(min=1e-9))
160
+ return torch.cat([mean, std], dim=1)
161
+
162
+
163
+ class ECAPA_TDNN(nn.Module):
164
+ def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
165
+ feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
166
+ super().__init__()
167
+
168
+ self.feat_type = feat_type
169
+ self.feature_selection = feature_selection
170
+ self.update_extract = update_extract
171
+ self.sr = sr
172
+
173
+ if feat_type == "fbank" or feat_type == "mfcc":
174
+ self.update_extract = False
175
+
176
+ win_len = int(sr * 0.025)
177
+ hop_len = int(sr * 0.01)
178
+
179
+ if feat_type == 'fbank':
180
+ self.feature_extract = trans.MelSpectrogram(sample_rate=sr, n_fft=512, win_length=win_len,
181
+ hop_length=hop_len, f_min=0.0, f_max=sr // 2,
182
+ pad=0, n_mels=feat_dim)
183
+ elif feat_type == 'mfcc':
184
+ melkwargs = {
185
+ 'n_fft': 512,
186
+ 'win_length': win_len,
187
+ 'hop_length': hop_len,
188
+ 'f_min': 0.0,
189
+ 'f_max': sr // 2,
190
+ 'pad': 0
191
+ }
192
+ self.feature_extract = trans.MFCC(sample_rate=sr, n_mfcc=feat_dim, log_mels=False,
193
+ melkwargs=melkwargs)
194
+ else:
195
+ if config_path is None:
196
+ self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
197
+ else:
198
+ self.feature_extract = UpstreamExpert(config_path)
199
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
200
+ self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
201
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
202
+ self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
203
+
204
+ self.feat_num = self.get_feat_num()
205
+ self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
206
+
207
+ if feat_type != 'fbank' and feat_type != 'mfcc':
208
+ freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
209
+ for name, param in self.feature_extract.named_parameters():
210
+ for freeze_val in freeze_list:
211
+ if freeze_val in name:
212
+ param.requires_grad = False
213
+ break
214
+
215
+ if not self.update_extract:
216
+ for param in self.feature_extract.parameters():
217
+ param.requires_grad = False
218
+
219
+ self.instance_norm = nn.InstanceNorm1d(feat_dim)
220
+ # self.channels = [channels] * 4 + [channels * 3]
221
+ self.channels = [channels] * 4 + [1536]
222
+
223
+ self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
224
+ self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
225
+ self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
226
+ self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
227
+
228
+ # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
229
+ cat_channels = channels * 3
230
+ self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
231
+ self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
232
+ self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
233
+ self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
234
+
235
+
236
+ def get_feat_num(self):
237
+ self.feature_extract.eval()
238
+ wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
239
+ with torch.no_grad():
240
+ features = self.feature_extract(wav)
241
+ select_feature = features[self.feature_selection]
242
+ if isinstance(select_feature, (list, tuple)):
243
+ return len(select_feature)
244
+ else:
245
+ return 1
246
+
247
+ def get_feat(self, x):
248
+ if self.update_extract:
249
+ x = self.feature_extract([sample for sample in x])
250
+ else:
251
+ with torch.no_grad():
252
+ if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
253
+ x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
254
+ else:
255
+ x = self.feature_extract([sample for sample in x])
256
+
257
+ if self.feat_type == 'fbank':
258
+ x = x.log()
259
+
260
+ if self.feat_type != "fbank" and self.feat_type != "mfcc":
261
+ x = x[self.feature_selection]
262
+ if isinstance(x, (list, tuple)):
263
+ x = torch.stack(x, dim=0)
264
+ else:
265
+ x = x.unsqueeze(0)
266
+ norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
267
+ x = (norm_weights * x).sum(dim=0)
268
+ x = torch.transpose(x, 1, 2) + 1e-6
269
+
270
+ x = self.instance_norm(x)
271
+ return x
272
+
273
+ def forward(self, x):
274
+ x = self.get_feat(x)
275
+
276
+ out1 = self.layer1(x)
277
+ out2 = self.layer2(out1)
278
+ out3 = self.layer3(out2)
279
+ out4 = self.layer4(out3)
280
+
281
+ out = torch.cat([out2, out3, out4], dim=1)
282
+ out = F.relu(self.conv(out))
283
+ out = self.bn(self.pooling(out))
284
+ out = self.linear(out)
285
+
286
+ return out
287
+
288
+
289
+ def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
290
+ return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
291
+ feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
292
+
293
+ if __name__ == '__main__':
294
+ x = torch.zeros(2, 32000)
295
+ model = ECAPA_TDNN_SMALL(feat_dim=768, emb_dim=256, feat_type='hubert_base', feature_selection="hidden_states",
296
+ update_extract=False)
297
+
298
+ out = model(x)
299
+ # print(model)
300
+ print(out.shape)
301
+
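A minimal, self-contained sketch of the layer-weighting step performed in `get_feat` above: hidden states from every SSL layer are combined with softmax-normalised learnable weights before being passed to the TDNN blocks. The tensor shapes below are invented for the demonstration and are not taken from this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_layers, batch, frames, feat_dim = 13, 2, 50, 768          # invented shapes for the demo
hidden_states = torch.randn(num_layers, batch, frames, feat_dim)

feature_weight = nn.Parameter(torch.zeros(num_layers))        # as in ECAPA_TDNN.__init__
norm_weights = F.softmax(feature_weight, dim=-1).view(-1, 1, 1, 1)

x = (norm_weights * hidden_states).sum(dim=0)                 # B x T x D: weighted sum over layers
x = torch.transpose(x, 1, 2) + 1e-6                           # B x D x T, as the Conv1d layers expect
print(x.shape)                                                 # torch.Size([2, 768, 50])
```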
UniSpeech/downstreams/speaker_verification/models/utils.py ADDED
@@ -0,0 +1,78 @@
1
+ import torch
2
+ import fairseq
3
+ from packaging import version
4
+ import torch.nn.functional as F
5
+ from fairseq import tasks
6
+ from fairseq.checkpoint_utils import load_checkpoint_to_cpu
7
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
8
+ from omegaconf import OmegaConf
9
+ from s3prl.upstream.interfaces import UpstreamBase
10
+ from torch.nn.utils.rnn import pad_sequence
11
+
12
+ def load_model(filepath):
13
+ state = torch.load(filepath, map_location=lambda storage, loc: storage)
14
+ # state = load_checkpoint_to_cpu(filepath)
15
+ state["cfg"] = OmegaConf.create(state["cfg"])
16
+
17
+ if "args" in state and state["args"] is not None:
18
+ cfg = convert_namespace_to_omegaconf(state["args"])
19
+ elif "cfg" in state and state["cfg"] is not None:
20
+ cfg = state["cfg"]
21
+ else:
22
+ raise RuntimeError(
23
+ f"Neither args nor cfg exist in state keys = {state.keys()}"
24
+ )
25
+
26
+ task = tasks.setup_task(cfg.task)
27
+ if "task_state" in state:
28
+ task.load_state_dict(state["task_state"])
29
+
30
+ model = task.build_model(cfg.model)
31
+
32
+ return model, cfg, task
33
+
34
+
35
+ ###################
36
+ # UPSTREAM EXPERT #
37
+ ###################
38
+ class UpstreamExpert(UpstreamBase):
39
+ def __init__(self, ckpt, **kwargs):
40
+ super().__init__(**kwargs)
41
+ assert version.parse(fairseq.__version__) > version.parse(
42
+ "0.10.2"
43
+ ), "Please install the fairseq master branch."
44
+
45
+ model, cfg, task = load_model(ckpt)
46
+ self.model = model
47
+ self.task = task
48
+
49
+ if len(self.hooks) == 0:
50
+ module_name = "self.model.encoder.layers"
51
+ for module_id in range(len(eval(module_name))):
52
+ self.add_hook(
53
+ f"{module_name}[{module_id}]",
54
+ lambda input, output: input[0].transpose(0, 1),
55
+ )
56
+ self.add_hook("self.model.encoder", lambda input, output: output[0])
57
+
58
+ def forward(self, wavs):
59
+ if self.task.cfg.normalize:
60
+ wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]
61
+
62
+ device = wavs[0].device
63
+ wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
64
+ wav_padding_mask = ~torch.lt(
65
+ torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
66
+ wav_lengths.unsqueeze(1),
67
+ )
68
+ padded_wav = pad_sequence(wavs, batch_first=True)
69
+
70
+ features, feat_padding_mask = self.model.extract_features(
71
+ padded_wav,
72
+ padding_mask=wav_padding_mask,
73
+ mask=None,
74
+ )
75
+ return {
76
+ "default": features,
77
+ }
78
+
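A hedged usage sketch for the upstream wrapper defined above. The checkpoint filename is a placeholder (no fairseq checkpoint is shipped with this upload), and running it assumes the fairseq and s3prl versions listed in the requirements.

```python
import torch
from models.utils import UpstreamExpert  # path relative to downstreams/speaker_verification/

# "unispeech_sat_large.pt" is a hypothetical fairseq checkpoint path.
expert = UpstreamExpert("unispeech_sat_large.pt")
expert.eval()

# Two mono 16 kHz waveforms of different lengths; forward() pads them internally.
wavs = [torch.randn(16000), torch.randn(24000)]

with torch.no_grad():
    outputs = expert(wavs)

# "default" holds the final encoder output; the s3prl hooks registered in __init__
# additionally collect per-layer outputs, which ecapa_tdnn.py consumes as "hidden_states".
print(outputs["default"].shape)
```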
UniSpeech/downstreams/speaker_verification/requirements.txt ADDED
@@ -0,0 +1,85 @@
1
+ scipy==1.7.1 \
2
+ --hash=sha256:2a0eeaab01258e0870c4022a6cd329aef3b7c6c2b606bd7cf7bb2ba9820ae561 \
3
+ --hash=sha256:3304bd5bc32e00954ac4b3f4cc382ca8824719bf348aacbec6347337d6b125fe \
4
+ --hash=sha256:3f52470e0548cdb74fb8ddf06773ffdcca7c97550f903b1c51312ec19243a7f7 \
5
+ --hash=sha256:4729b41a4cdaf4cd011aeac816b532f990bdf97710cef59149d3e293115cf467 \
6
+ --hash=sha256:4ee952f39a4a4c7ba775a32b664b1f4b74818548b65f765987adc14bb78f5802 \
7
+ --hash=sha256:611f9cb459d0707dd8e4de0c96f86e93f61aac7475fcb225e9ec71fecdc5cebf \
8
+ --hash=sha256:6b47d5fa7ea651054362561a28b1ccc8da9368a39514c1bbf6c0977a1c376764 \
9
+ --hash=sha256:71cfc96297617eab911e22216e8a8597703202e95636d9406df9af5c2ac99a2b \
10
+ --hash=sha256:787749110a23502031fb1643c55a2236c99c6b989cca703ea2114d65e21728ef \
11
+ --hash=sha256:90c07ba5f34f33299a428b0d4fa24c30d2ceba44d63f8385b2b05be460819fcb \
12
+ --hash=sha256:a496b42dbcd04ea9924f5e92be63af3d8e0f43a274b769bfaca0a297327d54ee \
13
+ --hash=sha256:bc61e3e5ff92d2f32bb263621d54a9cff5e3f7c420af3d1fa122ce2529de2bd9 \
14
+ --hash=sha256:c9951e3746b68974125e5e3445008a4163dd6d20ae0bbdae22b38cb8951dc11b \
15
+ --hash=sha256:d1388fbac9dd591ea630da75c455f4cc637a7ca5ecb31a6b6cef430914749cde \
16
+ --hash=sha256:d13f31457f2216e5705304d9f28e2826edf75487410a57aa99263fa4ffd792c2 \
17
+ --hash=sha256:d648aa85dd5074b1ed83008ae987c3fbb53d68af619fce1dee231f4d8bd40e2f \
18
+ --hash=sha256:da9c6b336e540def0b7fd65603da8abeb306c5fc9a5f4238665cbbb5ff95cf58 \
19
+ --hash=sha256:e101bceeb9e65a90dadbc5ca31283403a2d4667b9c178db29109750568e8d112 \
20
+ --hash=sha256:efdd3825d54c58df2cc394366ca4b9166cf940a0ebddeb87b6c10053deb625ea
21
+ fire==0.4.0 \
22
+ --hash=sha256:c5e2b8763699d1142393a46d0e3e790c5eb2f0706082df8f647878842c216a62
23
+ sklearn==0.0 \
24
+ --hash=sha256:e23001573aa194b834122d2b9562459bf5ae494a2d59ca6b8aa22c85a44c0e31
25
+ s3prl==0.3.1 \
26
+ --hash=sha256:e497989b10d4e058b619cf3e7a547820fceb3fe18c14c566427eb7b8c770d62e
27
+ torchaudio==0.9.0 \
28
+ --hash=sha256:0a387e78eeaf6e0abd36df70e9d8a15d242b49c2507dbd9522568f5d4af5fb96 \
29
+ --hash=sha256:18763c05cb7d85a08b8ea960e40f6984e9513b02e76f4526d920493c701b0671 \
30
+ --hash=sha256:48e33bb96b7ff2dc10a778a695429dbd6dfc8c8baa0d7c9b63569cb002bb87cd \
31
+ --hash=sha256:62fd9393ddbe40aadaabef7595f5bff0057e39f7e519195a010731542815f5a4 \
32
+ --hash=sha256:76a5b8ea0e4ddafd5b8f24abdf1a6f7afe847d892570da13cf0fc9bceeac437f \
33
+ --hash=sha256:87520525da10b5f00d3e5e1180db6ee37b1fa305edb2260c7335e0859dbe634e \
34
+ --hash=sha256:9d3f5d6df7d91676e67a38a448253b74d77da723f8e24bd833ff7ed0f82fa4ef \
35
+ --hash=sha256:acf0d736a5c1ea6b94adf08b0a31670009b6e78dfe50a1b0bdabf2b0f7895dc0 \
36
+ --hash=sha256:ad221258fc5d1d446f2c1ce9a1bb54cc05ca2b208491d4eaa5af443f1c0f16a2 \
37
+ --hash=sha256:ba52ae64611773bec7fc664c29f9ea3e02c9e5c817693726b978ed1bdedd07f2 \
38
+ --hash=sha256:c6126556d529df73b676e023063388d551be3c0cb2d42a4ff5c4cfd44ef3e012 \
39
+ --hash=sha256:ef5f0b22646a94f95869001b40ab940468b1ae399d0ffd3bc73d5c43342a013a \
40
+ --hash=sha256:ef8dc4ab1ec807382a713e71e8493d1985930537c933273e3c0739f02183cedc \
41
+ --hash=sha256:efb16c593b2a5ada07b180580c7612617e84f4714ce86928ad54baefe71ef29d
42
+ sentencepiece==0.1.96 \
43
+ --hash=sha256:1dac8c2ad02b5ebc1179c0a14cbc7d7c6f4fd73d4dd51820626402d0aefc974e \
44
+ --hash=sha256:26d20d713b3ba1b7a19205336afb1e93a4327c372b2f795e907b8dc2315ac92e \
45
+ --hash=sha256:335bf84d72112cc91f3c3b691d61802fc963503b7772fd8280d20368048b8f3e \
46
+ --hash=sha256:36e9ff61e7b67c5b7ee96733613622620b4802fc8cf188a4dbc1f355b03dde02 \
47
+ --hash=sha256:384148cead5cdab34a4d74fe1fb6a5a8abaafed25eaa4a7698b49dd9482e4c4e \
48
+ --hash=sha256:3c703e68ea192e45b65c5d5836f6980849d828a18da4189899d7150fad82dc9e \
49
+ --hash=sha256:3e61e0757e49c306fff78ea75d6b75773418fe22214b4a460959203be934e834 \
50
+ --hash=sha256:466e381f0a812da8fda97a9707498cef3210ea8385a3421bcbadcb5384063969 \
51
+ --hash=sha256:48c6d13b3bfff08060c138248e85df60f6fad11135ad7a8fc2ef6005aacca839 \
52
+ --hash=sha256:4997c7ccf2ae462320250314aa5709a88d8a09fa271d073458a07bebf33f8e7c \
53
+ --hash=sha256:5388882bb24d083f6cc8cffc5c435f3694a7772b018e06ea6fd84d1044009efb \
54
+ --hash=sha256:5513298d62fe63dd0862d08a6eb52a9aa3537006f597f2386184e3f95bb88889 \
55
+ --hash=sha256:78e18d9106c36dcca929e18fd2c412378deac661d47fa3ee25defc55eef8a215 \
56
+ --hash=sha256:8179785883b556cd517416cdbda6244745414b00ec83132cfe1d26000971f3ae \
57
+ --hash=sha256:81bb77ba3651114943b2f8f77829cf764137dff06e38f4bf7fa43efea12c7f84 \
58
+ --hash=sha256:89c038da7f827a6e2ca4c73aeb4e4b25b99d981ce47dd61b04d446c8200cba1e \
59
+ --hash=sha256:940a6999c7d3f55e9d7b194fd5e1f41a7dbed26d3519fb95333216292a39599e \
60
+ --hash=sha256:99ea2d9db19e63a2d17d5dc64f9ace83fb9308a735be05a1aaf98eb4b496fba7 \
61
+ --hash=sha256:9bdf097d5bd1d8ce42dfee51f6ff05f5578b96e48c6f6006aa4eff69edfa3639 \
62
+ --hash=sha256:a336575463d75d3aac1f7e32470b8998643ccd9a73786bd726f6b0470520b6b4 \
63
+ --hash=sha256:a697257a2cd7581732d7741a8d32a06927f0311c3d277dbc47fa1043350c9d17 \
64
+ --hash=sha256:a92e1932ee8fd500680ccbe1bf53eb33228f4c9d6524ed6f300bcc80ac359f27 \
65
+ --hash=sha256:aeb090ad462833df03af1debce4ae607a2766ef861f992003ad0c56d074ab805 \
66
+ --hash=sha256:b1c24c1d9405b2148184ff27c062493d5e3be5c144575f95b5a0d7c660a515af \
67
+ --hash=sha256:b77d27f59d515c43b61745b8173fbe7c7b3014b14b3702a75bf1793471e7def6 \
68
+ --hash=sha256:b8b1dd2712f8a7de5b4c8ec912e6c041d25750bf03e1ce325cdba43bae0944ae \
69
+ --hash=sha256:bedf0355117fb4e9b1fc9fc92b4d5ee743a7d468be9f6196e3b94447710ea589 \
70
+ --hash=sha256:cc969e6694fb27fba7cee2953f350804faf03913f25ae1ee713a7b8a1bc08018 \
71
+ --hash=sha256:d45e3f78e746aa161bc9f5a31c6a2839c512101113a4065f4d2e7a3ab8198d8c \
72
+ --hash=sha256:d501713a8396193883aa526f48dc609f5f031a5df1afbafa561cf9ab492ffc76 \
73
+ --hash=sha256:d954d25a8705f972e8bfc1dea5464d7e697dd6f4ade092f1a487387e6d6c829a \
74
+ --hash=sha256:dadccb2e49244b6e64b4527d13ec14d5e094a90b41cf9b963e457e64182f1941 \
75
+ --hash=sha256:e811984b0908c14c56de7d8226fdd494d87a7ccb75af8ac3a07423037aaafc35 \
76
+ --hash=sha256:e88354b61f59dfdeb41023f7be8ae31dc627c2dc2dacbc2de8b2d82a0997135c \
77
+ --hash=sha256:e8ec5bb6777e2060e1499750c50e1b69dca5a0f80f90f2c66656c5f3e5244593 \
78
+ --hash=sha256:e9e9fe8094ca57549d801e9a2017ac5c24108bbf485ea4f8994a72e8e96ee135 \
79
+ --hash=sha256:eba0471ab0bb2e07ed06d91ecf5185d402c83d194155a41d8e2aa547d187712e \
80
+ --hash=sha256:ef59ba19340dc1d002ce5713b911c0ef23c577b08f8ed57998ee3c8e62c5bf6e \
81
+ --hash=sha256:f8c90df663cd9759b2cf8dd29998b63140ac39e51ada2e739dc13bdac0b4f001 \
82
+ --hash=sha256:f8cb24d8d0b2f8b7463815a59183eb81ec1d7a06e3217bed456063f3303eddfb \
83
+ --hash=sha256:fd907a8f744e5337de7fc532dd800c4416b571ea47f8c3c66be10cd1bc67c925 \
84
+ --hash=sha256:ff7d752a7f82d87711ec1a95c2262cb74f98be5b457f0300d81a1aefe5be2a95
85
+
UniSpeech/downstreams/speaker_verification/verification.py ADDED
@@ -0,0 +1,67 @@
1
+ import soundfile as sf
2
+ import torch
3
+ import fire
4
+ import torch.nn.functional as F
5
+ from torchaudio.transforms import Resample
6
+ from models.ecapa_tdnn import ECAPA_TDNN_SMALL
7
+
8
+ MODEL_LIST = ['ecapa_tdnn', 'hubert_large', 'wav2vec2_xlsr', 'unispeech_sat', "wavlm_base_plus", "wavlm_large"]
9
+
10
+
11
+ def init_model(model_name, checkpoint=None):
12
+ if model_name == 'unispeech_sat':
13
+ config_path = 'config/unispeech_sat.th'
14
+ model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='unispeech_sat', config_path=config_path)
15
+ elif model_name == 'wavlm_base_plus':
16
+ config_path = None
17
+ model = ECAPA_TDNN_SMALL(feat_dim=768, feat_type='wavlm_base_plus', config_path=config_path)
18
+ elif model_name == 'wavlm_large':
19
+ config_path = None
20
+ model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=config_path)
21
+ elif model_name == 'hubert_large':
22
+ config_path = None
23
+ model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='hubert_large_ll60k', config_path=config_path)
24
+ elif model_name == 'wav2vec2_xlsr':
25
+ config_path = None
26
+ model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wav2vec2_xlsr', config_path=config_path)
27
+ else:
28
+ model = ECAPA_TDNN_SMALL(feat_dim=40, feat_type='fbank')
29
+
30
+ if checkpoint is not None:
31
+ state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)
32
+ model.load_state_dict(state_dict['model'], strict=False)
33
+ return model
34
+
35
+
36
+ def verification(model_name, wav1, wav2, use_gpu=True, checkpoint=None):
37
+
38
+ assert model_name in MODEL_LIST, 'The model_name should be in {}'.format(MODEL_LIST)
39
+ model = init_model(model_name, checkpoint)
40
+
41
+ wav1, sr1 = sf.read(wav1)
42
+ wav2, sr2 = sf.read(wav2)
43
+
44
+ wav1 = torch.from_numpy(wav1).unsqueeze(0).float()
45
+ wav2 = torch.from_numpy(wav2).unsqueeze(0).float()
46
+ resample1 = Resample(orig_freq=sr1, new_freq=16000)
47
+ resample2 = Resample(orig_freq=sr2, new_freq=16000)
48
+ wav1 = resample1(wav1)
49
+ wav2 = resample2(wav2)
50
+
51
+ if use_gpu:
52
+ model = model.cuda()
53
+ wav1 = wav1.cuda()
54
+ wav2 = wav2.cuda()
55
+
56
+ model.eval()
57
+ with torch.no_grad():
58
+ emb1 = model(wav1)
59
+ emb2 = model(wav2)
60
+
61
+ sim = F.cosine_similarity(emb1, emb2)
62
+ print("The similarity score between the two audio files is {:.4f} (range: -1.0 to 1.0).".format(sim[0].item()))
63
+
64
+
65
+ if __name__ == "__main__":
66
+ fire.Fire(verification)
67
+
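A usage sketch for verification.py, assuming the trial wavs shipped under vox1_data/ and a fine-tuned speaker-verification checkpoint; the checkpoint filename below is a placeholder and is not part of this upload. Because the script ends with `fire.Fire(verification)`, the same arguments are also available as command-line flags.

```python
# Equivalent CLI form (placeholder checkpoint name):
#   python verification.py --model_name wavlm_large \
#       --wav1 vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav \
#       --wav2 vox1_data/David_Faustino/xTOk1Jz-F_g_0000015.wav \
#       --checkpoint wavlm_large_sv_finetune.pth
from verification import verification

verification(
    model_name="wavlm_large",
    wav1="vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav",
    wav2="vox1_data/David_Faustino/xTOk1Jz-F_g_0000015.wav",
    use_gpu=False,                               # set True if a CUDA device is available
    checkpoint="wavlm_large_sv_finetune.pth",    # placeholder path, not part of this upload
)
```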
UniSpeech/downstreams/speaker_verification/vox1_data/David_Faustino/hn8GyCJIfLM_0000012.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c76f757c218238a1bb30fa3fb03ccfe3cdb2376e411b984536238296d30b0c0
3
+ size 157486
UniSpeech/downstreams/speaker_verification/vox1_data/David_Faustino/xTOk1Jz-F_g_0000015.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc70db51fba0095b2820a1afce3781b2327fa3e74677deeaa300eac9c740679c
3
+ size 285486
UniSpeech/downstreams/speaker_verification/vox1_data/Josh_Gad/HXUqYaOwrxA_0000015.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dca90fb57b11034d3adf378a17786336497eee71176fc4966ed4fdcd5c63c142
3
+ size 312366
UniSpeech/downstreams/speaker_verification/vox1_data/Josh_Gad/RFyw7V3SOnQ_0000001.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:981c26c1adc3a833872d44702ff8d0ee2da114096aa7ba11cb2ec683c3b0b95d
3
+ size 454446
UniSpeech/downstreams/speaker_verification/vox1_data/Lea_Thompson/HladKGyKTLM_0000006.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7263ce6b8c9b7e4bc137270798c94491d5f962854ef9187f84477e2118cee7eb
3
+ size 541486
UniSpeech/downstreams/speaker_verification/vox1_data/Lea_Thompson/mHTAr5dlAgc_0000004.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a09cc35b58825fac276e20a2cc71146b9c0143d5f64127235f56274a02300ea3
3
+ size 140846
UniSpeech/downstreams/speaker_verification/vox1_data/Zulay_Henao/WbB8m9-wlIQ_0000001.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08fbfaa3d34cbf1abfc56721ece687d306adc7b2bf2d725ca8b5f0dd10d5d84f
3
+ size 160046
UniSpeech/downstreams/speaker_verification/vox1_data/Zulay_Henao/gFfcgOVmiO0_0000002.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9b47b677fba76b3125537c7562b0be223f3ae7cf881bdac6d7b1c59311c02c7d
3
+ size 253486
UniSpeech/src/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,77 @@
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ ## Enforcement
56
+
57
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
58
+ reported by contacting the project team at <conduct@pytorch.org>. All
59
+ complaints will be reviewed and investigated and will result in a response that
60
+ is deemed necessary and appropriate to the circumstances. The project team is
61
+ obligated to maintain confidentiality with regard to the reporter of an incident.
62
+ Further details of specific enforcement policies may be posted separately.
63
+
64
+ Project maintainers who do not follow or enforce the Code of Conduct in good
65
+ faith may face temporary or permanent repercussions as determined by other
66
+ members of the project's leadership.
67
+
68
+ ## Attribution
69
+
70
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72
+
73
+ [homepage]: https://www.contributor-covenant.org
74
+
75
+ For answers to common questions about this code of conduct, see
76
+ https://www.contributor-covenant.org/faq
77
+
UniSpeech/src/CONTRIBUTING.md ADDED
@@ -0,0 +1,28 @@
1
+ # Contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq)
2
+ We want to make contributing to this project as easy and transparent as
3
+ possible.
4
+
5
+ ## Pull Requests
6
+ We actively welcome your pull requests.
7
+
8
+ 1. Fork the repo and create your branch from `master`.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If you've changed APIs, update the documentation.
11
+ 4. Ensure the test suite passes.
12
+ 5. Make sure your code lints.
13
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14
+
15
+ ## Contributor License Agreement ("CLA")
16
+ In order to accept your pull request, we need you to submit a CLA. You only need
17
+ to do this once to work on any of Facebook's open source projects.
18
+
19
+ Complete your CLA here: <https://code.facebook.com/cla>
20
+
21
+ ## Issues
22
+ We use GitHub issues to track public bugs. Please ensure your description is
23
+ clear and has sufficient instructions to be able to reproduce the issue.
24
+
25
+ ## License
26
+ By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq),
27
+ you agree that your contributions will be licensed under the LICENSE file in
28
+ the root directory of this source tree.
UniSpeech/src/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) Facebook, Inc. and its affiliates.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
UniSpeech/src/config/config.yaml ADDED
@@ -0,0 +1,111 @@
1
+ # @package _group_
2
+ common:
3
+ no_progress_bar: false
4
+ log_interval: 100
5
+ log_format: null
6
+ tensorboard_logdir: null
7
+ seed: 1
8
+ cpu: false
9
+ tpu: false
10
+ bf16: false
11
+ fp16: false
12
+ memory_efficient_fp16: false
13
+ memory_efficient_bf16: false
14
+ fp16_no_flatten_grads: false
15
+ fp16_init_scale: 128
16
+ fp16_scale_window: null
17
+ fp16_scale_tolerance: 0.0
18
+ min_loss_scale: 1.0e-4
19
+ threshold_loss_scale: null
20
+ user_dir: null
21
+ empty_cache_freq: 0
22
+ all_gather_list_size: 16384
23
+ model_parallel_size: 1
24
+ quantization_config_path: null
25
+ profile: false
26
+ distributed_training:
27
+ distributed_rank: 0
28
+ distributed_backend: "nccl"
29
+ distributed_init_method: null
30
+ distributed_port: -1
31
+ device_id: 0
32
+ local_rank: 0
33
+ distributed_no_spawn: false
34
+ ddp_backend: "c10d"
35
+ bucket_cap_mb: 25
36
+ fix_batches_to_gpus: false
37
+ find_unused_parameters: false
38
+ fast_stat_sync: false
39
+ broadcast_buffers: false
40
+ distributed_wrapper: "DDP"
41
+ slowmo_momentum: null
42
+ slowmo_algorithm: "LocalSGD"
43
+ localsgd_frequency: 3
44
+ dataset:
45
+ num_workers: 1
46
+ skip_invalid_size_inputs_valid_test: false
47
+ max_tokens: null
48
+ batch_size: null
49
+ required_batch_size_multiple: 8
50
+ dataset_impl: null
51
+ data_buffer_size: 10
52
+ train_subset: "train"
53
+ valid_subset: "valid"
54
+ validate_interval: 1
55
+ fixed_validation_seed: null
56
+ disable_validation: false
57
+ curriculum: 0
58
+ gen_subset: "test"
59
+ num_shards: 1
60
+ shard_id: 0
61
+ max_tokens_valid: ${dataset.max_tokens}
62
+ batch_size_valid: ${dataset.batch_size}
63
+ optimization:
64
+ max_epoch: 0
65
+ max_update: 0
66
+ clip_norm: 25.0
67
+ sentence_avg: false
68
+ update_freq: [ 1 ]
69
+ lr: [ 0.25 ]
70
+ min_lr: -1.0
71
+ use_bmuf: false
72
+ checkpoint:
73
+ save_dir: "checkpoints"
74
+ restore_file: "checkpoint_last.pt"
75
+ reset_dataloader: false
76
+ reset_lr_scheduler: false
77
+ reset_meters: false
78
+ reset_optimizer: false
79
+ optimizer_overrides: "{}"
80
+ save_interval: 1
81
+ save_interval_updates: 0
82
+ keep_interval_updates: -1
83
+ keep_last_epochs: -1
84
+ keep_best_checkpoints: -1
85
+ no_save: false
86
+ no_epoch_checkpoints: false
87
+ no_last_checkpoints: false
88
+ no_save_optimizer_state: false
89
+ best_checkpoint_metric: "loss"
90
+ maximize_best_checkpoint_metric: false
91
+ patience: -1
92
+ checkpoint_suffix: ""
93
+ bmuf:
94
+ block_lr: 1
95
+ block_momentum: 0.875
96
+ global_sync_iter: 50
97
+ warmup_iterations: 500
98
+ use_nbm: false
99
+ average_sync: false
100
+ defaults:
101
+ - task: language_modeling
102
+ - model: null
103
+ - criterion: null
104
+ - optimizer: null
105
+ - lr_scheduler: null
106
+ - bpe: null
107
+ - tokenizer: null
108
+ - scoring: null
109
+ - generation: null
110
+ - common_eval: null
111
+ - eval_lm: null
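A small sketch of how a fairseq/Hydra-style config like the one above can be inspected or overridden with OmegaConf (already a fairseq dependency). Note this is only an illustration: Hydra would normally also compose the `defaults` list, which plain loading does not do, and the config path below assumes the repository layout of this upload.

```python
from omegaconf import OmegaConf

# Load the raw YAML; the `defaults` list is left uncomposed without Hydra.
cfg = OmegaConf.load("UniSpeech/src/config/config.yaml")
print(cfg.optimization.lr)                   # [0.25]
print(cfg.distributed_training.ddp_backend)  # c10d

# Dot-list overrides in the style of fairseq-hydra command lines.
overrides = OmegaConf.from_dotlist(["optimization.max_update=400000", "common.fp16=true"])
cfg = OmegaConf.merge(cfg, overrides)
print(cfg.optimization.max_update, cfg.common.fp16)  # 400000 True
```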