haoxiangsnr committed
Commit 9c38bf7 · verified · 1 Parent(s): 7926343

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See raw diff

Files changed (50)
  1. Amphion/egs/tts/VITS/README.md +221 -0
  2. Amphion/egs/vocoder/diffusion/exp_config_base.json +71 -0
  3. Amphion/egs/vocoder/gan/bigvgan_large/run.sh +141 -0
  4. Amphion/evaluation/metrics/similarity/__init__.py +0 -0
  5. Amphion/models/base/base_dataset.py +464 -0
  6. Amphion/models/base/base_inference.py +220 -0
  7. Amphion/models/base/new_inference.py +253 -0
  8. Amphion/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
  9. Amphion/models/codec/ns3_codec/alias_free_torch/__pycache__/act.cpython-310.pyc +0 -0
  10. Amphion/models/codec/ns3_codec/alias_free_torch/__pycache__/filter.cpython-310.pyc +0 -0
  11. Amphion/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
  12. Amphion/models/codec/ns3_codec/facodec.py +1163 -0
  13. Amphion/models/codec/ns3_codec/gradient_reversal.py +35 -0
  14. Amphion/models/codec/ns3_codec/melspec.py +102 -0
  15. Amphion/models/codec/ns3_codec/quantize/__pycache__/__init__.cpython-310.pyc +0 -0
  16. Amphion/models/codec/ns3_codec/quantize/fvq.py +116 -0
  17. Amphion/models/svc/transformer/transformer_inference.py +45 -0
  18. Amphion/models/svc/vits/vits_trainer.py +704 -0
  19. Amphion/models/tta/autoencoder/autoencoder_dataset.py +112 -0
  20. Amphion/models/tta/ldm/__init__.py +0 -0
  21. Amphion/models/tta/ldm/audioldm_dataset.py +151 -0
  22. Amphion/models/tta/ldm/audioldm_trainer.py +251 -0
  23. Amphion/models/tta/ldm/inference_utils/vocoder.py +408 -0
  24. Amphion/models/tts/base/__init__.py +7 -0
  25. Amphion/models/tts/base/tts_trainer.py +721 -0
  26. Amphion/models/tts/fastspeech2/fs2_trainer.py +155 -0
  27. Amphion/models/tts/naturalspeech2/ns2.py +259 -0
  28. Amphion/models/tts/naturalspeech2/ns2_dataset.py +524 -0
  29. Amphion/models/tts/naturalspeech2/ns2_inference.py +128 -0
  30. Amphion/models/tts/naturalspeech2/ns2_trainer.py +798 -0
  31. Amphion/models/tts/valle/__init__.py +0 -0
  32. Amphion/models/vocoders/autoregressive/autoregressive_vocoder_inference.py +0 -0
  33. Amphion/models/vocoders/autoregressive/wavenet/conv.py +66 -0
  34. Amphion/models/vocoders/autoregressive/wavenet/wavenet.py +170 -0
  35. Amphion/models/vocoders/diffusion/diffusion_vocoder_inference.py +131 -0
  36. Amphion/models/vocoders/flow/flow_vocoder_dataset.py +0 -0
  37. Amphion/models/vocoders/flow/flow_vocoder_inference.py +0 -0
  38. Amphion/models/vocoders/gan/discriminator/msd.py +88 -0
  39. Amphion/models/vocoders/gan/gan_vocoder_inference.py +96 -0
  40. Amphion/models/vocoders/vocoder_dataset.py +264 -0
  41. Amphion/models/vocoders/vocoder_sampler.py +126 -0
  42. Amphion/modules/activation_functions/snake.py +122 -0
  43. Amphion/modules/diffusion/bidilconv/bidilated_conv.py +102 -0
  44. Amphion/modules/diffusion/karras/sample.py +185 -0
  45. Amphion/modules/diffusion/unet/attention.py +241 -0
  46. Amphion/modules/diffusion/unet/resblock.py +178 -0
  47. Amphion/modules/diffusion/unet/unet.py +310 -0
  48. Amphion/modules/encoder/__init__.py +1 -0
  49. Amphion/modules/general/utils.py +100 -0
  50. Amphion/modules/norms/__init__.py +1 -0
Amphion/egs/tts/VITS/README.md ADDED
@@ -0,0 +1,221 @@
1
+ # VITS Recipe
2
+
3
+ [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Spaces-yellow)](https://huggingface.co/spaces/amphion/Text-to-Speech)
4
+ [![openxlab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/Amphion/Text-to-Speech)
5
+
6
+ In this recipe, we will show how to train VITS using Amphion's infrastructure. [VITS](https://arxiv.org/abs/2106.06103) is an end-to-end TTS architecture that utilizes a conditional variational autoencoder with adversarial learning.
7
+
8
+ There are four stages in total:
9
+
10
+ 1. Data preparation
11
+ 2. Features extraction
12
+ 3. Training
13
+ 4. Inference
14
+
15
+ > **NOTE:** You need to run every command in this recipe from the `Amphion` root directory:
16
+ > ```bash
17
+ > cd Amphion
18
+ > ```
19
+
20
+ ## 1. Data Preparation
21
+
22
+ ### Dataset Download
23
+ You can use any of the commonly used TTS datasets to train the TTS model, e.g., LJSpeech, VCTK, Hi-Fi TTS, LibriTTS, etc. We strongly recommend using LJSpeech when training a single-speaker TTS model for the first time, and Hi-Fi TTS when training a multi-speaker TTS model for the first time. The process of downloading these datasets is detailed [here](../../datasets/README.md).
24
+
25
+ ### Configuration
26
+
27
+ After downloading the dataset, you can set the dataset paths in `exp_config.json`. Note that you can change the `dataset` list to use your preferred datasets.
28
+
29
+ ```json
30
+ "dataset": [
31
+ "LJSpeech",
32
+ //"hifitts"
33
+ ],
34
+ "dataset_path": {
35
+ // TODO: Fill in your dataset path
36
+ "LJSpeech": "[LJSpeech dataset path]",
37
+ //"hifitts": "[Hi-Fi TTS dataset path]"
38
+ },
39
+ ```
40
+
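+ Before running preprocessing, you can quickly check that the paths you filled in actually exist. The snippet below is a minimal sketch, not part of the official recipe; it assumes the config lives at `egs/tts/VITS/exp_config.json` and uses the `json5` package because Amphion's config files contain `//` comments.
+ 
+ ```python
+ # Sanity-check the dataset paths in exp_config.json (illustrative sketch only).
+ import os
+ import json5  # exp_config.json uses // comments, so plain json.load would fail
+ 
+ with open("egs/tts/VITS/exp_config.json") as f:
+     cfg = json5.load(f)
+ 
+ for name, path in cfg.get("dataset_path", {}).items():
+     status = "ok" if os.path.isdir(path) else "MISSING"
+     print(f"{name}: {path} [{status}]")
+ ```
+ 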
41
+ ## 2. Features Extraction
42
+
43
+ ### Configuration
44
+
45
+ In `exp_config.json`, specify the `log_dir` for saving the checkpoints and logs, and specify the `processed_dir` for saving processed data. For preprocessing the multi-speaker TTS dataset, set `extract_audio` and `use_spkid` to `true`:
46
+
47
+ ```json
48
+ // TODO: Fill in the output log path. The default value is "Amphion/ckpts/tts"
49
+ "log_dir": "ckpts/tts",
50
+ "preprocess": {
51
+ //"extract_audio": true,
52
+ "use_phone": true,
53
+ // linguistic features
54
+ "extract_phone": true,
55
+ "phone_extractor": "espeak", // Options: "espeak", "pypinyin", "pypinyin_initials_finals", "lexicon" (the last is only for language=en-us right now)
56
+ // TODO: Fill in the output data path. The default value is "Amphion/data"
57
+ "processed_dir": "data",
58
+ "sample_rate": 22050, //target sampling rate
59
+ "valid_file": "valid.json", //validation set
60
+ //"use_spkid": true, //use speaker ID to train multi-speaker TTS model
61
+ },
62
+ ```
63
+
64
+ ### Run
65
+
66
+ Run the `run.sh` as the preprocess stage (set `--stage 1`):
67
+
68
+ ```bash
69
+ sh egs/tts/VITS/run.sh --stage 1
70
+ ```
71
+
72
+ > **NOTE:** `CUDA_VISIBLE_DEVICES` is set to `"0"` by default. You can change it when running `run.sh` by specifying, for example, `--gpu "1"`.
73
+
74
+ ## 3. Training
75
+
76
+ ### Configuration
77
+
78
+ We provide the default hyperparameters in `exp_config.json`. They work on a single 24 GB NVIDIA GPU. You can adjust them to fit your GPU machines.
79
+ For training the multi-speaker TTS model, set `n_speakers` to a value greater than or equal to the number of speakers in your dataset(s) (a larger value leaves room for fine-tuning on new speakers), and set `multi_speaker_training` to `true`.
80
+
81
+ ```json
82
+ "model": {
83
+ //"n_speakers": 10 //Number of speakers in the dataset(s) used. The default value is 0 if not specified.
84
+ },
85
+ "train": {
86
+ "batch_size": 16,
87
+ //"multi_speaker_training": true,
88
+ }
89
+ ```
90
+
91
+ ### Train From Scratch
92
+
93
+ Run `run.sh` as the training stage (set `--stage 2`) and specify an experiment name. The TensorBoard logs and checkpoints will be saved in `Amphion/ckpts/tts/[YourExptName]`.
94
+
95
+ ```bash
96
+ sh egs/tts/VITS/run.sh --stage 2 --name [YourExptName]
97
+ ```
98
+
99
+ ### Train From Existing Source
100
+
101
+ We support training from existing sources for various purposes. You can resume training the model from a checkpoint or fine-tune a model from another checkpoint.
102
+
103
+ By setting `--resume true`, training will resume from the **latest checkpoint** of the current `[YourExptName]` by default. For example, if you want to resume training from the latest checkpoint in `Amphion/ckpts/tts/[YourExptName]/checkpoint`, run:
104
+
105
+ ```bash
106
+ sh egs/tts/VITS/run.sh --stage 2 --name [YourExptName] \
107
+ --resume true
108
+ ```
109
+
110
+ You can also choose a **specific checkpoint** for retraining with the `--resume_from_ckpt_path` argument. For example, if you want to resume training from the checkpoint `Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificCheckpoint]`, run:
111
+
112
+ ```bash
113
+ sh egs/tts/VITS/run.sh --stage 2 --name [YourExptName] \
114
+ --resume true \
115
+ --resume_from_ckpt_path "Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificCheckpoint]"
116
+ ```
117
+
118
+ If you want to **fine-tune from another checkpoint**, just use `--resume_type` and set it to `"finetune"`. For example, if you want to fine-tune the model from the checkpoint `Amphion/ckpts/tts/[AnotherExperiment]/checkpoint/[SpecificCheckpoint]`, run:
119
+
120
+
121
+ ```bash
122
+ sh egs/tts/VITS/run.sh --stage 2 --name [YourExptName] \
123
+ --resume true \
124
+ --resume_from_ckpt_path "Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificCheckpoint]" \
125
+ --resume_type "finetune"
126
+ ```
127
+
128
+ > **NOTE:** `--resume_type` is set to `"resume"` by default, so it is not necessary to specify it when resuming training.
129
+ >
130
+ > The difference between `"resume"` and `"finetune"` is that `"finetune"` **only** loads the pretrained model weights from the checkpoint, while `"resume"` loads all the training states (including the optimizer, scheduler, etc.) from the checkpoint, as the sketch below illustrates.
131
+
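+ The following is a conceptual sketch of that difference, written with plain PyTorch objects for illustration; it is not Amphion's actual trainer code (which manages checkpoints through `accelerate`), but it shows which states each mode restores.
+ 
+ ```python
+ # "resume" vs. "finetune", illustrated with toy stand-ins (not Amphion's trainer code).
+ import torch
+ 
+ model = torch.nn.Linear(4, 4)
+ optimizer = torch.optim.AdamW(model.parameters())
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)
+ 
+ # Pretend this dict is a checkpoint written by a previous run.
+ ckpt = {
+     "model": model.state_dict(),
+     "optimizer": optimizer.state_dict(),
+     "scheduler": scheduler.state_dict(),
+     "step": 10000,
+ }
+ 
+ # --resume_type "finetune": only the pretrained model weights are restored.
+ model.load_state_dict(ckpt["model"])
+ 
+ # --resume_type "resume": the optimizer, scheduler, and step counter are restored too,
+ # so training continues exactly where it stopped.
+ model.load_state_dict(ckpt["model"])
+ optimizer.load_state_dict(ckpt["optimizer"])
+ scheduler.load_state_dict(ckpt["scheduler"])
+ start_step = ckpt["step"]
+ ```
+ 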
132
+ Here are some example scenarios to better understand how to use these arguments:
133
+ | Scenario | `--resume` | `--resume_from_ckpt_path` | `--resume_type` |
134
+ | ------ | -------- | ----------------------- | ------------- |
135
+ | You want to train from scratch | no | no | no |
136
+ | The machine breaks down during training and you want to resume training from the latest checkpoint | `true` | no | no |
137
+ | You find the latest model is overfitting and you want to re-train from the checkpoint before | `true` | `SpecificCheckpoint Path` | no |
138
+ | You want to fine-tune a model from another checkpoint | `true` | `SpecificCheckpoint Path` | `"finetune"` |
139
+
140
+
141
+ > **NOTE:** `CUDA_VISIBLE_DEVICES` is set to `"0"` by default. You can change it when running `run.sh` by specifying, for example, `--gpu "0,1,2,3"`.
142
+
143
+
144
+ ## 4. Inference
145
+
146
+ ### Pre-trained Model Download
147
+
148
+ We have released a pre-trained Amphion VITS model trained on LJSpeech, so you can download it [here](https://huggingface.co/amphion/vits-ljspeech) and generate speech following the inference instructions below.
149
+
150
+
151
+ ### Configuration
152
+
153
+ For inference, you need to specify the following configurations when running `run.sh`:
154
+
155
+
156
+ | Parameters | Description | Example |
157
+ | --------------------- | -------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
158
+ | `--infer_expt_dir` | The experimental directory which contains `checkpoint` | `Amphion/ckpts/tts/[YourExptName]` |
159
+ | `--infer_output_dir` | The output directory to save inferred audios. | `Amphion/ckpts/tts/[YourExptName]/result` |
160
+ | `--infer_mode` | The inference mode, e.g., "`single`", "`batch`". | "`single`" to generate a clip of speech, "`batch`" to generate a batch of speech at a time. |
161
+ | `--infer_dataset` | The dataset used for inference. | For LJSpeech dataset, the inference dataset would be `LJSpeech`.<br> For Hi-Fi TTS dataset, the inference dataset would be `hifitts`. |
162
+ | `--infer_testing_set` | The subset of the inference dataset used for inference, e.g., train, test, golden_test | For the LJSpeech dataset, the testing set would be "`test`", split from LJSpeech during feature extraction, or "`golden_test`", cherry-picked from the test set as a template testing set.<br>For the Hi-Fi TTS dataset, the testing set would be "`test`", split from Hi-Fi TTS during the feature extraction process. |
163
+ | `--infer_text` | The text to be synthesized. | "`This is a clip of generated speech with the given text from a TTS model.`" |
164
+ | `--infer_speaker_name` | The target speaker whose voice is to be synthesized.<br> (***Note: only applicable to multi-speaker TTS model***) | For the Hi-Fi TTS dataset, the list of available speakers includes: "`hifitts_11614`", "`hifitts_11697`", "`hifitts_12787`", "`hifitts_6097`", "`hifitts_6670`", "`hifitts_6671`", "`hifitts_8051`", "`hifitts_9017`", "`hifitts_9136`", "`hifitts_92`". <br> You may find the list of available speakers in the `spk2id.json` file generated in ```log_dir/[YourExptName]``` that you have specified in `exp_config.json`. |
165
+
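+ If you are not sure which speaker names are available, the short sketch below prints them from `spk2id.json`. The path is an assumption based on the default `log_dir` described above; adjust it to wherever your experiment writes `spk2id.json`.
+ 
+ ```python
+ # List the speaker names accepted by --infer_speaker_name (illustrative sketch).
+ import json
+ 
+ with open("ckpts/tts/[YourExptName]/spk2id.json") as f:
+     spk2id = json.load(f)  # e.g. {"hifitts_92": 0, "hifitts_6097": 1, ...}
+ 
+ for name in sorted(spk2id):
+     print(name)
+ ```
+ 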
166
+ ### Run
167
+ #### Single text inference:
168
+ For the single-speaker TTS model, if you want to generate a single clip of speech from a given text, just run:
169
+
170
+ ```bash
171
+ sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \
172
+ --infer_expt_dir Amphion/ckpts/tts/[YourExptName] \
173
+ --infer_output_dir Amphion/ckpts/tts/[YourExptName]/result \
174
+ --infer_mode "single" \
175
+ --infer_text "This is a clip of generated speech with the given text from a TTS model."
176
+ ```
177
+
178
+ For the multi-speaker TTS model, in addition to the above-mentioned arguments, you need to add the `--infer_speaker_name` argument, and run:
179
+ ```bash
180
+ sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \
181
+ --infer_expt_dir Amphion/ckpts/tts/[YourExptName] \
182
+ --infer_output_dir Amphion/ckpts/tts/[YourExptName]/result \
183
+ --infer_mode "single" \
184
+ --infer_text "This is a clip of generated speech with the given text from a TTS model." \
185
+ --infer_speaker_name "hifitts_92"
186
+ ```
187
+
188
+ #### Batch inference:
189
+ For the single-speaker TTS model, if you want to generate speech for the whole testing set split from LJSpeech, just run:
190
+
191
+ ```bash
192
+ sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \
193
+ --infer_expt_dir Amphion/ckpts/tts/[YourExptName] \
194
+ --infer_output_dir Amphion/ckpts/tts/[YourExptName]/result \
195
+ --infer_mode "batch" \
196
+ --infer_dataset "LJSpeech" \
197
+ --infer_testing_set "test"
198
+ ```
199
+ For the multi-speaker TTS model, if you want to generate speech for the whole testing set split from Hi-Fi TTS, follow the same procedure as above, with ```LJSpeech``` replaced by ```hifitts```:
200
+ ```bash
201
+ sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \
202
+ --infer_expt_dir Amphion/ckpts/tts/[YourExptName] \
203
+ --infer_output_dir Amphion/ckpts/tts/[YourExptName]/result \
204
+ --infer_mode "batch" \
205
+ --infer_dataset "hifitts" \
206
+ --infer_testing_set "test"
207
+ ```
208
+
209
+
210
+ We have released a pre-trained Amphion VITS model trained on LJSpeech, so you can download it [here](https://huggingface.co/amphion/vits-ljspeech) and generate speech following the above inference instructions. A pre-trained multi-speaker VITS model trained on Hi-Fi TTS will be released soon. Stay tuned.
211
+
212
+
213
+ ```bibtex
214
+ @inproceedings{kim2021conditional,
215
+ title={Conditional variational autoencoder with adversarial learning for end-to-end text-to-speech},
216
+ author={Kim, Jaehyeon and Kong, Jungil and Son, Juhee},
217
+ booktitle={International Conference on Machine Learning},
218
+ pages={5530--5540},
219
+ year={2021},
220
+ }
221
+ ```
Amphion/egs/vocoder/diffusion/exp_config_base.json ADDED
@@ -0,0 +1,71 @@
1
+ {
2
+ "base_config": "config/vocoder.json",
3
+ "model_type": "DiffusionVocoder",
4
+ // TODO: Choose your needed datasets
5
+ "dataset": [
6
+ "csd",
7
+ "kising",
8
+ "m4singer",
9
+ "nus48e",
10
+ "opencpop",
11
+ "opensinger",
12
+ "opera",
13
+ "pjs",
14
+ "popbutfy",
15
+ "popcs",
16
+ "ljspeech",
17
+ "vctk",
18
+ "libritts",
19
+ ],
20
+ "dataset_path": {
21
+ // TODO: Fill in your dataset path
22
+ "csd": "[dataset path]",
23
+ "kising": "[dataset path]",
24
+ "m4singer": "[dataset path]",
25
+ "nus48e": "[dataset path]",
26
+ "opencpop": "[dataset path]",
27
+ "opensinger": "[dataset path]",
28
+ "opera": "[dataset path]",
29
+ "pjs": "[dataset path]",
30
+ "popbutfy": "[dataset path]",
31
+ "popcs": "[dataset path]",
32
+ "ljspeech": "[dataset path]",
33
+ "vctk": "[dataset path]",
34
+ "libritts": "[dataset path]",
35
+ },
36
+ // TODO: Fill in the output log path
37
+ "log_dir": "ckpts/vocoder",
38
+ "preprocess": {
39
+ // Acoustic features
40
+ "extract_mel": true,
41
+ "extract_audio": true,
42
+ "extract_pitch": false,
43
+ "extract_uv": false,
44
+ "pitch_extractor": "parselmouth",
45
+
46
+ // Features used for model training
47
+ "use_mel": true,
48
+ "use_frame_pitch": false,
49
+ "use_uv": false,
50
+ "use_audio": true,
51
+
52
+ // TODO: Fill in the output data path
53
+ "processed_dir": "data/",
54
+ "n_mel": 100,
55
+ "sample_rate": 24000
56
+ },
57
+ "train": {
58
+ // TODO: Choose a suitable batch size, training epoch, and save stride
59
+ "batch_size": 32,
60
+ "max_epoch": 1000000,
61
+ "save_checkpoint_stride": [20],
62
+ "adamw": {
63
+ "lr": 2.0e-4,
64
+ "adam_b1": 0.8,
65
+ "adam_b2": 0.99
66
+ },
67
+ "exponential_lr": {
68
+ "lr_decay": 0.999
69
+ },
70
+ }
71
+ }
Amphion/egs/vocoder/gan/bigvgan_large/run.sh ADDED
@@ -0,0 +1,141 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ ######## Build Experiment Environment ###########
7
+ exp_dir=$(cd `dirname $0`; pwd)
8
+ work_dir=$(dirname $(dirname $(dirname $(dirname $exp_dir))))
9
+
10
+ export WORK_DIR=$work_dir
11
+ export PYTHONPATH=$work_dir
12
+ export PYTHONIOENCODING=UTF-8
13
+
14
+ ######## Parse the Given Parameters from the Command ###########
15
+ options=$(getopt -o c:n:s --long gpu:,config:,name:,stage:,checkpoint:,resume_type:,main_process_port:,infer_mode:,infer_datasets:,infer_feature_dir:,infer_audio_dir:,infer_expt_dir:,infer_output_dir: -- "$@")
16
+ eval set -- "$options"
17
+
18
+ while true; do
19
+ case $1 in
20
+ # Experimental Configuration File
21
+ -c | --config) shift; exp_config=$1 ; shift ;;
22
+ # Experimental Name
23
+ -n | --name) shift; exp_name=$1 ; shift ;;
24
+ # Running Stage
25
+ -s | --stage) shift; running_stage=$1 ; shift ;;
26
+ # Visible GPU machines. The default value is "0".
27
+ --gpu) shift; gpu=$1 ; shift ;;
28
+
29
+ # [Only for Training] The specific checkpoint path that you want to resume from.
30
+ --checkpoint) shift; checkpoint=$1 ; shift ;;
31
+ # [Only for Training] `resume` for loading all the things (including model weights, optimizer, scheduler, and random states). `finetune` for loading only the model weights.
32
+ --resume_type) shift; resume_type=$1 ; shift ;;
33
+ # [Only for Training] `main_process_port` for multi-GPU training
34
+ --main_process_port) shift; main_process_port=$1 ; shift ;;
35
+
36
+ # [Only for Inference] The inference mode
37
+ --infer_mode) shift; infer_mode=$1 ; shift ;;
38
+ # [Only for Inference] The datasets to run inference on
39
+ --infer_datasets) shift; infer_datasets=$1 ; shift ;;
40
+ # [Only for Inference] The feature dir for inference
41
+ --infer_feature_dir) shift; infer_feature_dir=$1 ; shift ;;
42
+ # [Only for Inference] The audio dir for inference
43
+ --infer_audio_dir) shift; infer_audio_dir=$1 ; shift ;;
44
+ # [Only for Inference] The experiment dir. The value is like "[Your path to save logs and checkpoints]/[YourExptName]"
45
+ --infer_expt_dir) shift; infer_expt_dir=$1 ; shift ;;
46
+ # [Only for Inference] The output dir to save inferred audios. Its default value is "$infer_expt_dir/result"
47
+ --infer_output_dir) shift; infer_output_dir=$1 ; shift ;;
48
+
49
+ --) shift ; break ;;
50
+ *) echo "Invalid option: $1"; exit 1 ;;
51
+ esac
52
+ done
53
+
54
+
55
+ ### Value check ###
56
+ if [ -z "$running_stage" ]; then
57
+ echo "[Error] Please specify the running stage"
58
+ exit 1
59
+ fi
60
+
61
+ if [ -z "$exp_config" ]; then
62
+ exp_config="${exp_dir}"/exp_config.json
63
+ fi
64
+ echo "Experimental Configuration File: $exp_config"
65
+
66
+ if [ -z "$gpu" ]; then
67
+ gpu="0"
68
+ fi
69
+
70
+ if [ -z "$main_process_port" ]; then
71
+ main_process_port=29500
72
+ fi
73
+ echo "Main Process Port: $main_process_port"
74
+
75
+ ######## Features Extraction ###########
76
+ if [ $running_stage -eq 1 ]; then
77
+ CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/bins/vocoder/preprocess.py \
78
+ --config $exp_config \
79
+ --num_workers 8
80
+ fi
81
+
82
+ ######## Training ###########
83
+ if [ $running_stage -eq 2 ]; then
84
+ if [ -z "$exp_name" ]; then
85
+ echo "[Error] Please specify the experiment name"
86
+ exit 1
87
+ fi
88
+ echo "Experimental Name: $exp_name"
89
+
90
+ CUDA_VISIBLE_DEVICES=$gpu accelerate launch \
91
+ --main_process_port "$main_process_port" \
92
+ "${work_dir}"/bins/vocoder/train.py \
93
+ --config "$exp_config" \
94
+ --exp_name "$exp_name" \
95
+ --log_level info \
96
+ --checkpoint "$checkpoint" \
97
+ --resume_type "$resume_type"
98
+ fi
99
+
100
+ ######## Inference/Conversion ###########
101
+ if [ $running_stage -eq 3 ]; then
102
+ if [ -z "$infer_expt_dir" ]; then
103
+ echo "[Error] Please specify the experimental directory. The value is like [Your path to save logs and checkpoints]/[YourExptName]"
104
+ exit 1
105
+ fi
106
+
107
+ if [ -z "$infer_output_dir" ]; then
108
+ infer_output_dir="$infer_expt_dir/result"
109
+ fi
110
+
111
+ if [ $infer_mode = "infer_from_dataset" ]; then
112
+ CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/vocoder/inference.py \
113
+ --config $exp_config \
114
+ --infer_mode $infer_mode \
115
+ --infer_datasets $infer_datasets \
116
+ --vocoder_dir $infer_expt_dir \
117
+ --output_dir $infer_output_dir \
118
+ --log_level debug
119
+ fi
120
+
121
+ if [ $infer_mode = "infer_from_feature" ]; then
122
+ CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/vocoder/inference.py \
123
+ --config $exp_config \
124
+ --infer_mode $infer_mode \
125
+ --feature_folder $infer_feature_dir \
126
+ --vocoder_dir $infer_expt_dir \
127
+ --output_dir $infer_output_dir \
128
+ --log_level debug
129
+ fi
130
+
131
+ if [ $infer_mode = "infer_from_audio" ]; then
132
+ CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/vocoder/inference.py \
133
+ --config $exp_config \
134
+ --infer_mode $infer_mode \
135
+ --audio_folder $infer_audio_dir \
136
+ --vocoder_dir $infer_expt_dir \
137
+ --output_dir $infer_output_dir \
138
+ --log_level debug
139
+ fi
140
+
141
+ fi
Amphion/evaluation/metrics/similarity/__init__.py ADDED
File without changes
Amphion/models/base/base_dataset.py ADDED
@@ -0,0 +1,464 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import numpy as np
8
+ import torch.utils.data
9
+ from torch.nn.utils.rnn import pad_sequence
10
+ import librosa
11
+
12
+ from utils.data_utils import *
13
+ from processors.acoustic_extractor import cal_normalized_mel
14
+ from text import text_to_sequence
15
+ from text.text_token_collation import phoneIDCollation
16
+
17
+
18
+ class BaseOfflineDataset(torch.utils.data.Dataset):
19
+ def __init__(self, cfg, dataset, is_valid=False):
20
+ """
21
+ Args:
22
+ cfg: config
23
+ dataset: dataset name
24
+ is_valid: whether to use train or valid dataset
25
+ """
26
+
27
+ assert isinstance(dataset, str)
28
+
29
+ # self.data_root = processed_data_dir
30
+ self.cfg = cfg
31
+
32
+ processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
33
+ meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
34
+ self.metafile_path = os.path.join(processed_data_dir, meta_file)
35
+ self.metadata = self.get_metadata()
36
+
37
+ """
38
+ load spk2id and utt2spk from json file
39
+ spk2id: {spk1: 0, spk2: 1, ...}
40
+ utt2spk: {dataset_uid: spk1, ...}
41
+ """
42
+ if cfg.preprocess.use_spkid:
43
+ spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
44
+ with open(spk2id_path, "r") as f:
45
+ self.spk2id = json.load(f)
46
+
47
+ utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
48
+ self.utt2spk = dict()
49
+ with open(utt2spk_path, "r") as f:
50
+ for line in f.readlines():
51
+ utt, spk = line.strip().split("\t")
52
+ self.utt2spk[utt] = spk
53
+
54
+ if cfg.preprocess.use_uv:
55
+ self.utt2uv_path = {}
56
+ for utt_info in self.metadata:
57
+ dataset = utt_info["Dataset"]
58
+ uid = utt_info["Uid"]
59
+ utt = "{}_{}".format(dataset, uid)
60
+ self.utt2uv_path[utt] = os.path.join(
61
+ cfg.preprocess.processed_dir,
62
+ dataset,
63
+ cfg.preprocess.uv_dir,
64
+ uid + ".npy",
65
+ )
66
+
67
+ if cfg.preprocess.use_frame_pitch:
68
+ self.utt2frame_pitch_path = {}
69
+ for utt_info in self.metadata:
70
+ dataset = utt_info["Dataset"]
71
+ uid = utt_info["Uid"]
72
+ utt = "{}_{}".format(dataset, uid)
73
+
74
+ self.utt2frame_pitch_path[utt] = os.path.join(
75
+ cfg.preprocess.processed_dir,
76
+ dataset,
77
+ cfg.preprocess.pitch_dir,
78
+ uid + ".npy",
79
+ )
80
+
81
+ if cfg.preprocess.use_frame_energy:
82
+ self.utt2frame_energy_path = {}
83
+ for utt_info in self.metadata:
84
+ dataset = utt_info["Dataset"]
85
+ uid = utt_info["Uid"]
86
+ utt = "{}_{}".format(dataset, uid)
87
+
88
+ self.utt2frame_energy_path[utt] = os.path.join(
89
+ cfg.preprocess.processed_dir,
90
+ dataset,
91
+ cfg.preprocess.energy_dir,
92
+ uid + ".npy",
93
+ )
94
+
95
+ if cfg.preprocess.use_mel:
96
+ self.utt2mel_path = {}
97
+ for utt_info in self.metadata:
98
+ dataset = utt_info["Dataset"]
99
+ uid = utt_info["Uid"]
100
+ utt = "{}_{}".format(dataset, uid)
101
+
102
+ self.utt2mel_path[utt] = os.path.join(
103
+ cfg.preprocess.processed_dir,
104
+ dataset,
105
+ cfg.preprocess.mel_dir,
106
+ uid + ".npy",
107
+ )
108
+
109
+ if cfg.preprocess.use_linear:
110
+ self.utt2linear_path = {}
111
+ for utt_info in self.metadata:
112
+ dataset = utt_info["Dataset"]
113
+ uid = utt_info["Uid"]
114
+ utt = "{}_{}".format(dataset, uid)
115
+
116
+ self.utt2linear_path[utt] = os.path.join(
117
+ cfg.preprocess.processed_dir,
118
+ dataset,
119
+ cfg.preprocess.linear_dir,
120
+ uid + ".npy",
121
+ )
122
+
123
+ if cfg.preprocess.use_audio:
124
+ self.utt2audio_path = {}
125
+ for utt_info in self.metadata:
126
+ dataset = utt_info["Dataset"]
127
+ uid = utt_info["Uid"]
128
+ utt = "{}_{}".format(dataset, uid)
129
+
130
+ self.utt2audio_path[utt] = os.path.join(
131
+ cfg.preprocess.processed_dir,
132
+ dataset,
133
+ cfg.preprocess.audio_dir,
134
+ uid + ".npy",
135
+ )
136
+ elif cfg.preprocess.use_label:
137
+ self.utt2label_path = {}
138
+ for utt_info in self.metadata:
139
+ dataset = utt_info["Dataset"]
140
+ uid = utt_info["Uid"]
141
+ utt = "{}_{}".format(dataset, uid)
142
+
143
+ self.utt2label_path[utt] = os.path.join(
144
+ cfg.preprocess.processed_dir,
145
+ dataset,
146
+ cfg.preprocess.label_dir,
147
+ uid + ".npy",
148
+ )
149
+ elif cfg.preprocess.use_one_hot:
150
+ self.utt2one_hot_path = {}
151
+ for utt_info in self.metadata:
152
+ dataset = utt_info["Dataset"]
153
+ uid = utt_info["Uid"]
154
+ utt = "{}_{}".format(dataset, uid)
155
+
156
+ self.utt2one_hot_path[utt] = os.path.join(
157
+ cfg.preprocess.processed_dir,
158
+ dataset,
159
+ cfg.preprocess.one_hot_dir,
160
+ uid + ".npy",
161
+ )
162
+
163
+ if cfg.preprocess.use_text or cfg.preprocess.use_phone:
164
+ self.utt2seq = {}
165
+ for utt_info in self.metadata:
166
+ dataset = utt_info["Dataset"]
167
+ uid = utt_info["Uid"]
168
+ utt = "{}_{}".format(dataset, uid)
169
+
170
+ if cfg.preprocess.use_text:
171
+ text = utt_info["Text"]
172
+ sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
173
+ elif cfg.preprocess.use_phone:
174
+ # load phoneme sequence from phone file
175
+ phone_path = os.path.join(
176
+ processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone"
177
+ )
178
+ with open(phone_path, "r") as fin:
179
+ phones = fin.readlines()
180
+ assert len(phones) == 1
181
+ phones = phones[0].strip()
182
+ phones_seq = phones.split(" ")
183
+
184
+ phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
185
+ sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)
186
+
187
+ self.utt2seq[utt] = sequence
188
+
189
+ def get_metadata(self):
190
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
191
+ metadata = json.load(f)
192
+
193
+ return metadata
194
+
195
+ def get_dataset_name(self):
196
+ return self.metadata[0]["Dataset"]
197
+
198
+ def __getitem__(self, index):
199
+ utt_info = self.metadata[index]
200
+
201
+ dataset = utt_info["Dataset"]
202
+ uid = utt_info["Uid"]
203
+ utt = "{}_{}".format(dataset, uid)
204
+
205
+ single_feature = dict()
206
+
207
+ if self.cfg.preprocess.use_spkid:
208
+ single_feature["spk_id"] = np.array(
209
+ [self.spk2id[self.utt2spk[utt]]], dtype=np.int32
210
+ )
211
+
212
+ if self.cfg.preprocess.use_mel:
213
+ mel = np.load(self.utt2mel_path[utt])
214
+ assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
215
+ if self.cfg.preprocess.use_min_max_norm_mel:
216
+ # do mel norm
217
+ mel = cal_normalized_mel(mel, utt_info["Dataset"], self.cfg.preprocess)
218
+
219
+ if "target_len" not in single_feature.keys():
220
+ single_feature["target_len"] = mel.shape[1]
221
+ single_feature["mel"] = mel.T # [T, n_mels]
222
+
223
+ if self.cfg.preprocess.use_linear:
224
+ linear = np.load(self.utt2linear_path[utt])
225
+ if "target_len" not in single_feature.keys():
226
+ single_feature["target_len"] = linear.shape[1]
227
+ single_feature["linear"] = linear.T # [T, n_linear]
228
+
229
+ if self.cfg.preprocess.use_frame_pitch:
230
+ frame_pitch_path = self.utt2frame_pitch_path[utt]
231
+ frame_pitch = np.load(frame_pitch_path)
232
+ if "target_len" not in single_feature.keys():
233
+ single_feature["target_len"] = len(frame_pitch)
234
+ aligned_frame_pitch = align_length(
235
+ frame_pitch, single_feature["target_len"]
236
+ )
237
+ single_feature["frame_pitch"] = aligned_frame_pitch
238
+
239
+ if self.cfg.preprocess.use_uv:
240
+ frame_uv_path = self.utt2uv_path[utt]
241
+ frame_uv = np.load(frame_uv_path)
242
+ aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
243
+ aligned_frame_uv = [
244
+ 0 if frame_uv else 1 for frame_uv in aligned_frame_uv
245
+ ]
246
+ aligned_frame_uv = np.array(aligned_frame_uv)
247
+ single_feature["frame_uv"] = aligned_frame_uv
248
+
249
+ if self.cfg.preprocess.use_frame_energy:
250
+ frame_energy_path = self.utt2frame_energy_path[utt]
251
+ frame_energy = np.load(frame_energy_path)
252
+ if "target_len" not in single_feature.keys():
253
+ single_feature["target_len"] = len(frame_energy)
254
+ aligned_frame_energy = align_length(
255
+ frame_energy, single_feature["target_len"]
256
+ )
257
+ single_feature["frame_energy"] = aligned_frame_energy
258
+
259
+ if self.cfg.preprocess.use_audio:
260
+ audio = np.load(self.utt2audio_path[utt])
261
+ single_feature["audio"] = audio
262
+ single_feature["audio_len"] = audio.shape[0]
263
+
264
+ if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
265
+ single_feature["phone_seq"] = np.array(self.utt2seq[utt])
266
+ single_feature["phone_len"] = len(self.utt2seq[utt])
267
+
268
+ return single_feature
269
+
270
+ def __len__(self):
271
+ return len(self.metadata)
272
+
273
+
274
+ class BaseOfflineCollator(object):
275
+ """Zero-pads model inputs and targets based on number of frames per step"""
276
+
277
+ def __init__(self, cfg):
278
+ self.cfg = cfg
279
+
280
+ def __call__(self, batch):
281
+ packed_batch_features = dict()
282
+
283
+ # mel: [b, T, n_mels]
284
+ # frame_pitch, frame_energy: [1, T]
285
+ # target_len: [b]
286
+ # spk_id: [b, 1]
287
+ # mask: [b, T, 1]
288
+
289
+ for key in batch[0].keys():
290
+ if key == "target_len":
291
+ packed_batch_features["target_len"] = torch.LongTensor(
292
+ [b["target_len"] for b in batch]
293
+ )
294
+ masks = [
295
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
296
+ ]
297
+ packed_batch_features["mask"] = pad_sequence(
298
+ masks, batch_first=True, padding_value=0
299
+ )
300
+ elif key == "phone_len":
301
+ packed_batch_features["phone_len"] = torch.LongTensor(
302
+ [b["phone_len"] for b in batch]
303
+ )
304
+ masks = [
305
+ torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch
306
+ ]
307
+ packed_batch_features["phn_mask"] = pad_sequence(
308
+ masks, batch_first=True, padding_value=0
309
+ )
310
+ elif key == "audio_len":
311
+ packed_batch_features["audio_len"] = torch.LongTensor(
312
+ [b["audio_len"] for b in batch]
313
+ )
314
+ masks = [
315
+ torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch
316
+ ]
317
+ else:
318
+ values = [torch.from_numpy(b[key]) for b in batch]
319
+ packed_batch_features[key] = pad_sequence(
320
+ values, batch_first=True, padding_value=0
321
+ )
322
+ return packed_batch_features
323
+
324
+
325
+ class BaseOnlineDataset(torch.utils.data.Dataset):
326
+ def __init__(self, cfg, dataset, is_valid=False):
327
+ """
328
+ Args:
329
+ cfg: config
330
+ dataset: dataset name
331
+ is_valid: whether to use train or valid dataset
332
+ """
333
+ assert isinstance(dataset, str)
334
+
335
+ self.cfg = cfg
336
+ self.sample_rate = cfg.preprocess.sample_rate
337
+ self.hop_size = self.cfg.preprocess.hop_size
338
+
339
+ processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
340
+ meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
341
+ self.metafile_path = os.path.join(processed_data_dir, meta_file)
342
+ self.metadata = self.get_metadata()
343
+
344
+ """
345
+ load spk2id and utt2spk from json file
346
+ spk2id: {spk1: 0, spk2: 1, ...}
347
+ utt2spk: {dataset_uid: spk1, ...}
348
+ """
349
+ if cfg.preprocess.use_spkid:
350
+ spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
351
+ with open(spk2id_path, "r") as f:
352
+ self.spk2id = json.load(f)
353
+
354
+ utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
355
+ self.utt2spk = dict()
356
+ with open(utt2spk_path, "r") as f:
357
+ for line in f.readlines():
358
+ utt, spk = line.strip().split("\t")
359
+ self.utt2spk[utt] = spk
360
+
361
+ def get_metadata(self):
362
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
363
+ metadata = json.load(f)
364
+
365
+ return metadata
366
+
367
+ def get_dataset_name(self):
368
+ return self.metadata[0]["Dataset"]
369
+
370
+ def __getitem__(self, index):
371
+ """
372
+ single_feature:
373
+ wav: (T)
374
+ wav_len: int
375
+ target_len: int
376
+ mask: (n_frames, 1)
377
+ spk_id: (1)
378
+ """
379
+ utt_item = self.metadata[index]
380
+
381
+ wav_path = utt_item["Path"]
382
+ wav, _ = librosa.load(wav_path, sr=self.sample_rate)
383
+ # wav: (T)
384
+ wav = torch.as_tensor(wav, dtype=torch.float32)
385
+ wav_len = len(wav)
386
+ # mask: (n_frames, 1)
387
+ frame_len = wav_len // self.hop_size
388
+ mask = torch.ones(frame_len, 1, dtype=torch.long)
389
+
390
+ single_feature = {
391
+ "wav": wav,
392
+ "wav_len": wav_len,
393
+ "target_len": frame_len,
394
+ "mask": mask,
395
+ }
396
+
397
+ if self.cfg.preprocess.use_spkid:
398
+ utt = "{}_{}".format(utt_item["Dataset"], utt_item["Uid"])
399
+ single_feature["spk_id"] = torch.tensor(
400
+ [self.spk2id[self.utt2spk[utt]]], dtype=torch.int32
401
+ )
402
+
403
+ return single_feature
404
+
405
+ def __len__(self):
406
+ return len(self.metadata)
407
+
408
+
409
+ class BaseOnlineCollator(object):
410
+ """Zero-pads model inputs and targets based on number of frames per step (For on-the-fly features extraction, whose iterative item contains only wavs)"""
411
+
412
+ def __init__(self, cfg):
413
+ self.cfg = cfg
414
+
415
+ def __call__(self, batch):
416
+ """
417
+ BaseOnlineDataset.__getitem__:
418
+ wav: (T,)
419
+ wav_len: int
420
+ target_len: int
421
+ mask: (n_frames, 1)
422
+ spk_id: (1)
423
+
424
+ Returns:
425
+ wav: (B, T), torch.float32
426
+ wav_len: (B), torch.long
427
+ target_len: (B), torch.long
428
+ mask: (B, n_frames, 1), torch.long
429
+ spk_id: (B, 1), torch.int32
430
+ """
431
+ packed_batch_features = dict()
432
+
433
+ for key in batch[0].keys():
434
+ if key in ["wav_len", "target_len"]:
435
+ packed_batch_features[key] = torch.LongTensor([b[key] for b in batch])
436
+ else:
437
+ packed_batch_features[key] = pad_sequence(
438
+ [b[key] for b in batch], batch_first=True, padding_value=0
439
+ )
440
+ return packed_batch_features
441
+
442
+
443
+ class BaseTestDataset(torch.utils.data.Dataset):
444
+ def __init__(self, cfg, args):
445
+ raise NotImplementedError
446
+
447
+ def get_metadata(self):
448
+ raise NotImplementedError
449
+
450
+ def __getitem__(self, index):
451
+ raise NotImplementedError
452
+
453
+ def __len__(self):
454
+ return len(self.metadata)
455
+
456
+
457
+ class BaseTestCollator(object):
458
+ """Zero-pads model inputs and targets based on number of frames per step"""
459
+
460
+ def __init__(self, cfg):
461
+ raise NotImplementedError
462
+
463
+ def __call__(self, batch):
464
+ raise NotImplementedError
Amphion/models/base/base_inference.py ADDED
@@ -0,0 +1,220 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import os
8
+ import re
9
+ import time
10
+ from pathlib import Path
11
+
12
+ import torch
13
+ from torch.utils.data import DataLoader
14
+ from tqdm import tqdm
15
+
16
+ from models.vocoders.vocoder_inference import synthesis
17
+ from torch.utils.data import DataLoader
18
+ from utils.util import set_all_random_seed
19
+ from utils.util import load_config
20
+
21
+
22
+ def parse_vocoder(vocoder_dir):
23
+ r"""Parse vocoder config"""
24
+ vocoder_dir = os.path.abspath(vocoder_dir)
25
+ ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
26
+ ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
27
+ ckpt_path = str(ckpt_list[0])
28
+ vocoder_cfg = load_config(os.path.join(vocoder_dir, "args.json"), lowercase=True)
29
+ vocoder_cfg.model.bigvgan = vocoder_cfg.vocoder
30
+ return vocoder_cfg, ckpt_path
31
+
32
+
33
+ class BaseInference(object):
34
+ def __init__(self, cfg, args):
35
+ self.cfg = cfg
36
+ self.args = args
37
+ self.model_type = cfg.model_type
38
+ self.avg_rtf = list()
39
+ set_all_random_seed(10086)
40
+ os.makedirs(args.output_dir, exist_ok=True)
41
+
42
+ if torch.cuda.is_available():
43
+ self.device = torch.device("cuda")
44
+ else:
45
+ self.device = torch.device("cpu")
46
+ torch.set_num_threads(10) # limit the number of CPU threads used when running inference on CPU
47
+
48
+ # Load acoustic model
49
+ self.model = self.create_model().to(self.device)
50
+ state_dict = self.load_state_dict()
51
+ self.load_model(state_dict)
52
+ self.model.eval()
53
+
54
+ # Load vocoder model if necessary
55
+ if self.args.checkpoint_dir_vocoder is not None:
56
+ self.get_vocoder_info()
57
+
58
+ def create_model(self):
59
+ raise NotImplementedError
60
+
61
+ def load_state_dict(self):
62
+ self.checkpoint_file = self.args.checkpoint_file
63
+ if self.checkpoint_file is None:
64
+ assert self.args.checkpoint_dir is not None
65
+ checkpoint_path = os.path.join(self.args.checkpoint_dir, "checkpoint")
66
+ checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
67
+ self.checkpoint_file = os.path.join(
68
+ self.args.checkpoint_dir, checkpoint_filename
69
+ )
70
+
71
+ self.checkpoint_dir = os.path.split(self.checkpoint_file)[0]
72
+
73
+ print("Restore acoustic model from {}".format(self.checkpoint_file))
74
+ raw_state_dict = torch.load(self.checkpoint_file, map_location=self.device)
75
+ self.am_restore_step = re.findall(r"step-(.+?)_loss", self.checkpoint_file)[0]
76
+
77
+ return raw_state_dict
78
+
79
+ def load_model(self, model):
80
+ raise NotImplementedError
81
+
82
+ def get_vocoder_info(self):
83
+ self.checkpoint_dir_vocoder = self.args.checkpoint_dir_vocoder
84
+ self.vocoder_cfg = os.path.join(
85
+ os.path.dirname(self.checkpoint_dir_vocoder), "args.json"
86
+ )
87
+ self.cfg.vocoder = load_config(self.vocoder_cfg, lowercase=True)
88
+ self.vocoder_tag = self.checkpoint_dir_vocoder.split("/")[-2].split(":")[-1]
89
+ self.vocoder_steps = self.checkpoint_dir_vocoder.split("/")[-1].split(".")[0]
90
+
91
+ def build_test_utt_data(self):
92
+ raise NotImplementedError
93
+
94
+ def build_testdata_loader(self, args, target_speaker=None):
95
+ datasets, collate = self.build_test_dataset()
96
+ self.test_dataset = datasets(self.cfg, args, target_speaker)
97
+ self.test_collate = collate(self.cfg)
98
+ self.test_batch_size = min(
99
+ self.cfg.train.batch_size, len(self.test_dataset.metadata)
100
+ )
101
+ test_loader = DataLoader(
102
+ self.test_dataset,
103
+ collate_fn=self.test_collate,
104
+ num_workers=self.args.num_workers,
105
+ batch_size=self.test_batch_size,
106
+ shuffle=False,
107
+ )
108
+ return test_loader
109
+
110
+ def inference_each_batch(self, batch_data):
111
+ raise NotImplementedError
112
+
113
+ def inference_for_batches(self, args, target_speaker=None):
114
+ ###### Construct test_batch ######
115
+ loader = self.build_testdata_loader(args, target_speaker)
116
+
117
+ n_batch = len(loader)
118
+ now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
119
+ print(
120
+ "Model eval time: {}, batch_size = {}, n_batch = {}".format(
121
+ now, self.test_batch_size, n_batch
122
+ )
123
+ )
124
+ self.model.eval()
125
+
126
+ ###### Inference for each batch ######
127
+ pred_res = []
128
+ with torch.no_grad():
129
+ for i, batch_data in enumerate(loader if n_batch == 1 else tqdm(loader)):
130
+ # Put the data to device
131
+ for k, v in batch_data.items():
132
+ batch_data[k] = batch_data[k].to(self.device)
133
+
134
+ y_pred, stats = self.inference_each_batch(batch_data)
135
+
136
+ pred_res += y_pred
137
+
138
+ return pred_res
139
+
140
+ def inference(self, feature):
141
+ raise NotImplementedError
142
+
143
+ def synthesis_by_vocoder(self, pred):
144
+ audios_pred = synthesis(
145
+ self.vocoder_cfg,
146
+ self.checkpoint_dir_vocoder,
147
+ len(pred),
148
+ pred,
149
+ )
150
+ return audios_pred
151
+
152
+ def __call__(self, utt):
153
+ feature = self.build_test_utt_data(utt)
154
+ start_time = time.time()
155
+ with torch.no_grad():
156
+ outputs = self.inference(feature)[0]
157
+ time_used = time.time() - start_time
158
+ rtf = time_used / (
159
+ outputs.shape[1]
160
+ * self.cfg.preprocess.hop_size
161
+ / self.cfg.preprocess.sample_rate
162
+ )
163
+ print("Time used: {:.3f}, RTF: {:.4f}".format(time_used, rtf))
164
+ self.avg_rtf.append(rtf)
165
+ audios = outputs.cpu().squeeze().numpy().reshape(-1, 1)
166
+ return audios
167
+
168
+
169
+ def base_parser():
170
+ parser = argparse.ArgumentParser()
171
+ parser.add_argument(
172
+ "--config", default="config.json", help="json files for configurations."
173
+ )
174
+ parser.add_argument("--use_ddp_inference", default=False)
175
+ parser.add_argument("--n_workers", default=1, type=int)
176
+ parser.add_argument("--local_rank", default=-1, type=int)
177
+ parser.add_argument(
178
+ "--batch_size", default=1, type=int, help="Batch size for inference"
179
+ )
180
+ parser.add_argument(
181
+ "--num_workers",
182
+ default=1,
183
+ type=int,
184
+ help="Worker number for inference dataloader",
185
+ )
186
+ parser.add_argument(
187
+ "--checkpoint_dir",
188
+ type=str,
189
+ default=None,
190
+ help="Checkpoint dir including model file and configuration",
191
+ )
192
+ parser.add_argument(
193
+ "--checkpoint_file", help="checkpoint file", type=str, default=None
194
+ )
195
+ parser.add_argument(
196
+ "--test_list", help="test utterance list for testing", type=str, default=None
197
+ )
198
+ parser.add_argument(
199
+ "--checkpoint_dir_vocoder",
200
+ help="Vocoder's checkpoint dir including model file and configuration",
201
+ type=str,
202
+ default=None,
203
+ )
204
+ parser.add_argument(
205
+ "--output_dir",
206
+ type=str,
207
+ default=None,
208
+ help="Output dir for saving generated results",
209
+ )
210
+ return parser
211
+
212
+
213
+ if __name__ == "__main__":
214
+ parser = base_parser()
215
+ args = parser.parse_args()
216
+ cfg = load_config(args.config)
217
+
218
+ # Build inference
219
+ inference = BaseInference(cfg, args)
220
+ inference()
Amphion/models/base/new_inference.py ADDED
@@ -0,0 +1,253 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import random
8
+ import re
9
+ import time
10
+ from abc import abstractmethod
11
+ from pathlib import Path
12
+
13
+ import accelerate
14
+ import json5
15
+ import numpy as np
16
+ import torch
17
+ from accelerate.logging import get_logger
18
+ from torch.utils.data import DataLoader
19
+
20
+ from models.vocoders.vocoder_inference import synthesis
21
+ from utils.io import save_audio
22
+ from utils.util import load_config
23
+ from utils.audio_slicer import is_silence
24
+
25
+ EPS = 1.0e-12
26
+
27
+
28
+ class BaseInference(object):
29
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
30
+ super().__init__()
31
+
32
+ start = time.monotonic_ns()
33
+ self.args = args
34
+ self.cfg = cfg
35
+
36
+ assert infer_type in ["from_dataset", "from_file"]
37
+ self.infer_type = infer_type
38
+
39
+ # init with accelerate
40
+ self.accelerator = accelerate.Accelerator()
41
+ self.accelerator.wait_for_everyone()
42
+
43
+ # Use accelerate logger for distributed inference
44
+ with self.accelerator.main_process_first():
45
+ self.logger = get_logger("inference", log_level=args.log_level)
46
+
47
+ # Log some info
48
+ self.logger.info("=" * 56)
49
+ self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
50
+ self.logger.info("=" * 56)
51
+ self.logger.info("\n")
52
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
53
+
54
+ self.acoustics_dir = args.acoustics_dir
55
+ self.logger.debug(f"Acoustic dir: {args.acoustics_dir}")
56
+ self.vocoder_dir = args.vocoder_dir
57
+ self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
58
+ # should be in svc inferencer
59
+ # self.target_singer = args.target_singer
60
+ # self.logger.info(f"Target singers: {args.target_singer}")
61
+ # self.trans_key = args.trans_key
62
+ # self.logger.info(f"Trans key: {args.trans_key}")
63
+
64
+ os.makedirs(args.output_dir, exist_ok=True)
65
+
66
+ # set random seed
67
+ with self.accelerator.main_process_first():
68
+ start = time.monotonic_ns()
69
+ self._set_random_seed(self.cfg.train.random_seed)
70
+ end = time.monotonic_ns()
71
+ self.logger.debug(
72
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
73
+ )
74
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
75
+
76
+ # setup data_loader
77
+ with self.accelerator.main_process_first():
78
+ self.logger.info("Building dataset...")
79
+ start = time.monotonic_ns()
80
+ self.test_dataloader = self._build_dataloader()
81
+ end = time.monotonic_ns()
82
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
83
+
84
+ # setup model
85
+ with self.accelerator.main_process_first():
86
+ self.logger.info("Building model...")
87
+ start = time.monotonic_ns()
88
+ self.model = self._build_model()
89
+ end = time.monotonic_ns()
90
+ # self.logger.debug(self.model)
91
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")
92
+
93
+ # init with accelerate
94
+ self.logger.info("Initializing accelerate...")
95
+ start = time.monotonic_ns()
96
+ self.accelerator = accelerate.Accelerator()
97
+ self.model = self.accelerator.prepare(self.model)
98
+ end = time.monotonic_ns()
99
+ self.accelerator.wait_for_everyone()
100
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")
101
+
102
+ with self.accelerator.main_process_first():
103
+ self.logger.info("Loading checkpoint...")
104
+ start = time.monotonic_ns()
105
+ # TODO: Also, suppose only use latest one yet
106
+ self.__load_model(os.path.join(args.acoustics_dir, "checkpoint"))
107
+ end = time.monotonic_ns()
108
+ self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")
109
+
110
+ self.model.eval()
111
+ self.accelerator.wait_for_everyone()
112
+
113
+ ### Abstract methods ###
114
+ @abstractmethod
115
+ def _build_test_dataset(self):
116
+ pass
117
+
118
+ @abstractmethod
119
+ def _build_model(self):
120
+ pass
121
+
122
+ @abstractmethod
123
+ @torch.inference_mode()
124
+ def _inference_each_batch(self, batch_data):
125
+ pass
126
+
127
+ ### Abstract methods end ###
128
+
129
+ @torch.inference_mode()
130
+ def inference(self):
131
+ for i, batch in enumerate(self.test_dataloader):
132
+ y_pred = self._inference_each_batch(batch).cpu()
133
+
134
+ # Check whether min-max normalization was used
135
+ if self.cfg.preprocess.use_min_max_norm_mel:
136
+ mel_min, mel_max = self.test_dataset.target_mel_extrema
137
+ y_pred = (y_pred + 1.0) / 2.0 * (mel_max - mel_min + EPS) + mel_min
138
+
139
+ y_ls = y_pred.chunk(self.test_batch_size)
140
+ tgt_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
141
+ j = 0
142
+ for it, l in zip(y_ls, tgt_ls):
143
+ l = l.item()
144
+ it = it.squeeze(0)[:l]
145
+ uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
146
+ torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
147
+ j += 1
148
+
149
+ vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
150
+
151
+ res = synthesis(
152
+ cfg=vocoder_cfg,
153
+ vocoder_weight_file=vocoder_ckpt,
154
+ n_samples=None,
155
+ pred=[
156
+ torch.load(
157
+ os.path.join(self.args.output_dir, "{}.pt".format(i["Uid"]))
158
+ ).numpy(force=True)
159
+ for i in self.test_dataset.metadata
160
+ ],
161
+ )
162
+
163
+ output_audio_files = []
164
+ for it, wav in zip(self.test_dataset.metadata, res):
165
+ uid = it["Uid"]
166
+ file = os.path.join(self.args.output_dir, f"{uid}.wav")
167
+ output_audio_files.append(file)
168
+
169
+ wav = wav.numpy(force=True)
170
+ save_audio(
171
+ file,
172
+ wav,
173
+ self.cfg.preprocess.sample_rate,
174
+ add_silence=False,
175
+ turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate),
176
+ )
177
+ os.remove(os.path.join(self.args.output_dir, f"{uid}.pt"))
178
+
179
+ return sorted(output_audio_files)
180
+
181
+ # TODO: LEGACY CODE
182
+ def _build_dataloader(self):
183
+ datasets, collate = self._build_test_dataset()
184
+ self.test_dataset = datasets(self.args, self.cfg, self.infer_type)
185
+ self.test_collate = collate(self.cfg)
186
+ self.test_batch_size = min(
187
+ self.cfg.train.batch_size, len(self.test_dataset.metadata)
188
+ )
189
+ test_dataloader = DataLoader(
190
+ self.test_dataset,
191
+ collate_fn=self.test_collate,
192
+ num_workers=1,
193
+ batch_size=self.test_batch_size,
194
+ shuffle=False,
195
+ )
196
+ return test_dataloader
197
+
198
+ def __load_model(self, checkpoint_dir: str = None, checkpoint_path: str = None):
199
+ r"""Load model from checkpoint. If checkpoint_path is None, it will
200
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
201
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
202
+ method after** ``accelerator.prepare()``.
203
+ """
204
+ if checkpoint_path is None:
205
+ ls = []
206
+ for i in Path(checkpoint_dir).iterdir():
207
+ if re.match(r"epoch-\d+_step-\d+_loss-[\d.]+", str(i.stem)):
208
+ ls.append(i)
209
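+ # stems look like "epoch-X_step-Y_loss-Z"; sort by epoch (descending) and pick the newest checkpoint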
+ ls.sort(
210
+ key=lambda x: int(x.stem.split("_")[-3].split("-")[-1]), reverse=True
211
+ )
212
+ checkpoint_path = ls[0]
213
+ else:
214
+ checkpoint_path = Path(checkpoint_path)
215
+ self.accelerator.load_state(str(checkpoint_path))
216
+ # set epoch and step
217
+ self.epoch = int(checkpoint_path.stem.split("_")[-3].split("-")[-1])
218
+ self.step = int(checkpoint_path.stem.split("_")[-2].split("-")[-1])
219
+ return str(checkpoint_path)
220
+
221
+ @staticmethod
222
+ def _set_random_seed(seed):
223
+ r"""Set random seed for all possible random modules."""
224
+ random.seed(seed)
225
+ np.random.seed(seed)
226
+ torch.random.manual_seed(seed)
227
+
228
+ @staticmethod
229
+ def _parse_vocoder(vocoder_dir):
230
+ r"""Parse vocoder config"""
231
+ vocoder_dir = os.path.abspath(vocoder_dir)
232
+ ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
233
+ ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
234
+ ckpt_path = str(ckpt_list[0])
235
+ vocoder_cfg = load_config(
236
+ os.path.join(vocoder_dir, "args.json"), lowercase=True
237
+ )
238
+ return vocoder_cfg, ckpt_path
239
+
240
+ @staticmethod
241
+ def __count_parameters(model):
242
+ return sum(p.numel() for p in model.parameters())
243
+
244
+ def __dump_cfg(self, path):
245
+ os.makedirs(os.path.dirname(path), exist_ok=True)
246
+ json5.dump(
247
+ self.cfg,
248
+ open(path, "w"),
249
+ indent=4,
250
+ sort_keys=True,
251
+ ensure_ascii=False,
252
+ quote_keys=True,
253
+ )
Amphion/models/codec/ns3_codec/alias_free_torch/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ from .filter import *
4
+ from .resample import *
5
+ from .act import *
Amphion/models/codec/ns3_codec/alias_free_torch/__pycache__/act.cpython-310.pyc ADDED
Binary file (1.11 kB).
 
Amphion/models/codec/ns3_codec/alias_free_torch/__pycache__/filter.cpython-310.pyc ADDED
Binary file (2.69 kB).
 
Amphion/models/codec/ns3_codec/alias_free_torch/act.py ADDED
@@ -0,0 +1,29 @@
1
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2
+
3
+ import torch.nn as nn
4
+ from .resample import UpSample1d, DownSample1d
5
+
6
+
7
+ class Activation1d(nn.Module):
8
+ def __init__(
9
+ self,
10
+ activation,
11
+ up_ratio: int = 2,
12
+ down_ratio: int = 2,
13
+ up_kernel_size: int = 12,
14
+ down_kernel_size: int = 12,
15
+ ):
16
+ super().__init__()
17
+ self.up_ratio = up_ratio
18
+ self.down_ratio = down_ratio
19
+ self.act = activation
20
+ self.upsample = UpSample1d(up_ratio, up_kernel_size)
21
+ self.downsample = DownSample1d(down_ratio, down_kernel_size)
22
+
23
+ # x: [B,C,T]
24
+ def forward(self, x):
25
+ x = self.upsample(x)
26
+ x = self.act(x)
27
+ x = self.downsample(x)
28
+
29
+ return x
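
For reference, a minimal usage sketch of the `Activation1d` wrapper defined above. It applies the activation at twice the sample rate and then downsamples back, which suppresses aliasing introduced by the nonlinearity. The choice of `nn.SiLU` below is illustrative only (the codec itself wraps `SnakeBeta`), and it assumes `UpSample1d`/`DownSample1d` from `resample.py` preserve the overall length for matching up/down ratios:

```python
import torch
import torch.nn as nn

# Upsample by 2 -> activation -> downsample by 2.
act = Activation1d(activation=nn.SiLU(), up_ratio=2, down_ratio=2)

x = torch.randn(2, 16, 100)  # [B, C, T]
y = act(x)                   # [2, 16, 100], same temporal length as the input
```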
Amphion/models/codec/ns3_codec/facodec.py ADDED
@@ -0,0 +1,1163 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import torch
8
+ from einops import rearrange
9
+ from einops.layers.torch import Rearrange
10
+ from torch import nn, pow, sin
11
+ from torch.nn import Parameter
12
+ from torch.nn.utils import weight_norm
13
+
14
+ from .alias_free_torch import *
15
+ from .gradient_reversal import GradientReversal
16
+ from .melspec import MelSpectrogram
17
+ from .quantize import *
18
+ from .transformer import TransformerEncoder
19
+
20
+
21
+ def init_weights(m):
22
+ if isinstance(m, nn.Conv1d):
23
+ nn.init.trunc_normal_(m.weight, std=0.02)
24
+ nn.init.constant_(m.bias, 0)
25
+
26
+
27
+ def WNConv1d(*args, **kwargs):
28
+ return weight_norm(nn.Conv1d(*args, **kwargs))
29
+
30
+
31
+ def WNConvTranspose1d(*args, **kwargs):
32
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
33
+
34
+
35
+ class CNNLSTM(nn.Module):
36
+ def __init__(self, indim, outdim, head, global_pred=False):
37
+ super().__init__()
38
+ self.global_pred = global_pred
39
+ self.model = nn.Sequential(
40
+ ResidualUnit(indim, dilation=1),
41
+ ResidualUnit(indim, dilation=2),
42
+ ResidualUnit(indim, dilation=3),
43
+ Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
44
+ Rearrange("b c t -> b t c"),
45
+ )
46
+ self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])
47
+
48
+ def forward(self, x):
49
+ # x: [B, C, T]
50
+ x = self.model(x)
51
+ if self.global_pred:
52
+ x = torch.mean(x, dim=1, keepdim=False)
53
+ outs = [head(x) for head in self.heads]
54
+ return outs
55
+
56
+
57
+ class SnakeBeta(nn.Module):
58
+ """
59
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
60
+ Shape:
61
+ - Input: (B, C, T)
62
+ - Output: (B, C, T), same shape as the input
63
+ Parameters:
64
+ - alpha - trainable parameter that controls frequency
65
+ - beta - trainable parameter that controls magnitude
66
+ References:
67
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
68
+ https://arxiv.org/abs/2006.08195
69
+ Examples:
70
+ >>> a1 = snakebeta(256)
71
+ >>> x = torch.randn(256)
72
+ >>> x = a1(x)
73
+ """
74
+
75
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
76
+ """
77
+ Initialization.
78
+ INPUT:
79
+ - in_features: shape of the input
80
+ - alpha - trainable parameter that controls frequency
81
+ - beta - trainable parameter that controls magnitude
82
+ alpha is initialized to 1 by default, higher values = higher-frequency.
83
+ beta is initialized to 1 by default, higher values = higher-magnitude.
84
+ alpha will be trained along with the rest of your model.
85
+ """
86
+ super(SnakeBeta, self).__init__()
87
+ self.in_features = in_features
88
+
89
+ # initialize alpha
90
+ self.alpha_logscale = alpha_logscale
91
+ if self.alpha_logscale: # log scale alphas initialized to zeros
92
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
93
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
94
+ else: # linear scale alphas initialized to ones
95
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
96
+ self.beta = Parameter(torch.ones(in_features) * alpha)
97
+
98
+ self.alpha.requires_grad = alpha_trainable
99
+ self.beta.requires_grad = alpha_trainable
100
+
101
+ self.no_div_by_zero = 0.000000001
102
+
103
+ def forward(self, x):
104
+ """
105
+ Forward pass of the function.
106
+ Applies the function to the input elementwise.
107
+ SnakeBeta := x + 1/b * sin^2 (xa)
108
+ """
109
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
110
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
111
+ if self.alpha_logscale:
112
+ alpha = torch.exp(alpha)
113
+ beta = torch.exp(beta)
114
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
115
+
116
+ return x
117
+
118
+
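
The forward pass above implements the SnakeBeta nonlinearity `x + (1/beta) * sin^2(alpha * x)`. A quick numeric sanity check under the default linear-scale initialization (`alpha = beta = 1`), assuming the class defined above is importable:

```python
import torch

snake = SnakeBeta(in_features=4)  # linear scale: alpha = beta = 1 at init
x = torch.randn(2, 4, 10)         # [B, C, T]
y = snake(x)

# With alpha = beta = 1 this reduces to x + sin(x)^2 (up to the 1e-9 stabilizer on beta).
assert torch.allclose(y, x + torch.sin(x) ** 2, atol=1e-6)
```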
119
+ class ResidualUnit(nn.Module):
120
+ def __init__(self, dim: int = 16, dilation: int = 1):
121
+ super().__init__()
122
+ pad = ((7 - 1) * dilation) // 2
123
+ self.block = nn.Sequential(
124
+ Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
125
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
126
+ Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
127
+ WNConv1d(dim, dim, kernel_size=1),
128
+ )
129
+
130
+ def forward(self, x):
131
+ return x + self.block(x)
132
+
133
+
134
+ class EncoderBlock(nn.Module):
135
+ def __init__(self, dim: int = 16, stride: int = 1):
136
+ super().__init__()
137
+ self.block = nn.Sequential(
138
+ ResidualUnit(dim // 2, dilation=1),
139
+ ResidualUnit(dim // 2, dilation=3),
140
+ ResidualUnit(dim // 2, dilation=9),
141
+ Activation1d(activation=SnakeBeta(dim // 2, alpha_logscale=True)),
142
+ WNConv1d(
143
+ dim // 2,
144
+ dim,
145
+ kernel_size=2 * stride,
146
+ stride=stride,
147
+ padding=stride // 2 + stride % 2,
148
+ ),
149
+ )
150
+
151
+ def forward(self, x):
152
+ return self.block(x)
153
+
154
+
155
+ class FACodecEncoder(nn.Module):
156
+ def __init__(
157
+ self,
158
+ ngf=32,
159
+ up_ratios=(2, 4, 5, 5),
160
+ out_channels=1024,
161
+ ):
162
+ super().__init__()
163
+ self.hop_length = np.prod(up_ratios)
164
+ self.up_ratios = up_ratios
165
+
166
+ # Create first convolution
167
+ d_model = ngf
168
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
169
+
170
+ # Create EncoderBlocks that double channels as they downsample by `stride`
171
+ for stride in up_ratios:
172
+ d_model *= 2
173
+ self.block += [EncoderBlock(d_model, stride=stride)]
174
+
175
+ # Create last convolution
176
+ self.block += [
177
+ Activation1d(activation=SnakeBeta(d_model, alpha_logscale=True)),
178
+ WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
179
+ ]
180
+
181
+ # Wrap block into nn.Sequential
182
+ self.block = nn.Sequential(*self.block)
183
+ self.enc_dim = d_model
184
+
185
+ self.reset_parameters()
186
+
187
+ def forward(self, x):
188
+ out = self.block(x)
189
+ return out
190
+
191
+ def inference(self, x):
192
+ return self.block(x)
193
+
194
+ def remove_weight_norm(self):
195
+ """Remove weight normalization module from all of the layers."""
196
+
197
+ def _remove_weight_norm(m):
198
+ try:
199
+ torch.nn.utils.remove_weight_norm(m)
200
+ except ValueError: # this module didn't have weight norm
201
+ return
202
+
203
+ self.apply(_remove_weight_norm)
204
+
205
+ def apply_weight_norm(self):
206
+ """Apply weight normalization module from all of the layers."""
207
+
208
+ def _apply_weight_norm(m):
209
+ if isinstance(m, nn.Conv1d):
210
+ torch.nn.utils.weight_norm(m)
211
+
212
+ self.apply(_apply_weight_norm)
213
+
214
+ def reset_parameters(self):
215
+ self.apply(init_weights)
216
+
217
+
218
+ class DecoderBlock(nn.Module):
219
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
220
+ super().__init__()
221
+ self.block = nn.Sequential(
222
+ Activation1d(activation=SnakeBeta(input_dim, alpha_logscale=True)),
223
+ WNConvTranspose1d(
224
+ input_dim,
225
+ output_dim,
226
+ kernel_size=2 * stride,
227
+ stride=stride,
228
+ padding=stride // 2 + stride % 2,
229
+ output_padding=stride % 2,
230
+ ),
231
+ ResidualUnit(output_dim, dilation=1),
232
+ ResidualUnit(output_dim, dilation=3),
233
+ ResidualUnit(output_dim, dilation=9),
234
+ )
235
+
236
+ def forward(self, x):
237
+ return self.block(x)
238
+
239
+
240
+ class FACodecDecoder(nn.Module):
241
+ def __init__(
242
+ self,
243
+ in_channels=256,
244
+ upsample_initial_channel=1536,
245
+ ngf=32,
246
+ up_ratios=(5, 5, 4, 2),
247
+ vq_num_q_c=2,
248
+ vq_num_q_p=1,
249
+ vq_num_q_r=3,
250
+ vq_dim=1024,
251
+ vq_commit_weight=0.005,
252
+ vq_weight_init=False,
253
+ vq_full_commit_loss=False,
254
+ codebook_dim=8,
255
+ codebook_size_prosody=10, # true codebook size is equal to 2^codebook_size
256
+ codebook_size_content=10,
257
+ codebook_size_residual=10,
258
+ quantizer_dropout=0.0,
259
+ dropout_type="linear",
260
+ use_gr_content_f0=False,
261
+ use_gr_prosody_phone=False,
262
+ use_gr_residual_f0=False,
263
+ use_gr_residual_phone=False,
264
+ use_gr_x_timbre=False,
265
+ use_random_mask_residual=True,
266
+ prob_random_mask_residual=0.75,
267
+ ):
268
+ super().__init__()
269
+ self.hop_length = np.prod(up_ratios)
270
+ self.ngf = ngf
271
+ self.up_ratios = up_ratios
272
+
273
+ self.use_random_mask_residual = use_random_mask_residual
274
+ self.prob_random_mask_residual = prob_random_mask_residual
275
+
276
+ self.vq_num_q_p = vq_num_q_p
277
+ self.vq_num_q_c = vq_num_q_c
278
+ self.vq_num_q_r = vq_num_q_r
279
+
280
+ self.codebook_size_prosody = codebook_size_prosody
281
+ self.codebook_size_content = codebook_size_content
282
+ self.codebook_size_residual = codebook_size_residual
283
+
284
+ quantizer_class = ResidualVQ
285
+
286
+ self.quantizer = nn.ModuleList()
287
+
288
+ # prosody
289
+ quantizer = quantizer_class(
290
+ num_quantizers=vq_num_q_p,
291
+ dim=vq_dim,
292
+ codebook_size=codebook_size_prosody,
293
+ codebook_dim=codebook_dim,
294
+ threshold_ema_dead_code=2,
295
+ commitment=vq_commit_weight,
296
+ weight_init=vq_weight_init,
297
+ full_commit_loss=vq_full_commit_loss,
298
+ quantizer_dropout=quantizer_dropout,
299
+ dropout_type=dropout_type,
300
+ )
301
+ self.quantizer.append(quantizer)
302
+
303
+ # phone
304
+ quantizer = quantizer_class(
305
+ num_quantizers=vq_num_q_c,
306
+ dim=vq_dim,
307
+ codebook_size=codebook_size_content,
308
+ codebook_dim=codebook_dim,
309
+ threshold_ema_dead_code=2,
310
+ commitment=vq_commit_weight,
311
+ weight_init=vq_weight_init,
312
+ full_commit_loss=vq_full_commit_loss,
313
+ quantizer_dropout=quantizer_dropout,
314
+ dropout_type=dropout_type,
315
+ )
316
+ self.quantizer.append(quantizer)
317
+
318
+ # residual
319
+ if self.vq_num_q_r > 0:
320
+ quantizer = quantizer_class(
321
+ num_quantizers=vq_num_q_r,
322
+ dim=vq_dim,
323
+ codebook_size=codebook_size_residual,
324
+ codebook_dim=codebook_dim,
325
+ threshold_ema_dead_code=2,
326
+ commitment=vq_commit_weight,
327
+ weight_init=vq_weight_init,
328
+ full_commit_loss=vq_full_commit_loss,
329
+ quantizer_dropout=quantizer_dropout,
330
+ dropout_type=dropout_type,
331
+ )
332
+ self.quantizer.append(quantizer)
333
+
334
+ # Add first conv layer
335
+ channels = upsample_initial_channel
336
+ layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
337
+
338
+ # Add upsampling + MRF blocks
339
+ for i, stride in enumerate(up_ratios):
340
+ input_dim = channels // 2**i
341
+ output_dim = channels // 2 ** (i + 1)
342
+ layers += [DecoderBlock(input_dim, output_dim, stride)]
343
+
344
+ # Add final conv layer
345
+ layers += [
346
+ Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)),
347
+ WNConv1d(output_dim, 1, kernel_size=7, padding=3),
348
+ nn.Tanh(),
349
+ ]
350
+
351
+ self.model = nn.Sequential(*layers)
352
+
353
+ self.timbre_encoder = TransformerEncoder(
354
+ enc_emb_tokens=None,
355
+ encoder_layer=4,
356
+ encoder_hidden=256,
357
+ encoder_head=4,
358
+ conv_filter_size=1024,
359
+ conv_kernel_size=5,
360
+ encoder_dropout=0.1,
361
+ use_cln=False,
362
+ )
363
+
364
+ self.timbre_linear = nn.Linear(in_channels, in_channels * 2)
365
+ self.timbre_linear.bias.data[:in_channels] = 1
366
+ self.timbre_linear.bias.data[in_channels:] = 0
367
+ self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False)
368
+
369
+ self.f0_predictor = CNNLSTM(in_channels, 1, 2)
370
+ self.phone_predictor = CNNLSTM(in_channels, 5003, 1)
371
+
372
+ self.use_gr_content_f0 = use_gr_content_f0
373
+ self.use_gr_prosody_phone = use_gr_prosody_phone
374
+ self.use_gr_residual_f0 = use_gr_residual_f0
375
+ self.use_gr_residual_phone = use_gr_residual_phone
376
+ self.use_gr_x_timbre = use_gr_x_timbre
377
+
378
+ if self.vq_num_q_r > 0 and self.use_gr_residual_f0:
379
+ self.res_f0_predictor = nn.Sequential(GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2))
380
+
381
+ if self.vq_num_q_r > 0 and self.use_gr_residual_phone > 0:
382
+ self.res_phone_predictor = nn.Sequential(GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1))
383
+
384
+ if self.use_gr_content_f0:
385
+ self.content_f0_predictor = nn.Sequential(GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2))
386
+
387
+ if self.use_gr_prosody_phone:
388
+ self.prosody_phone_predictor = nn.Sequential(GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1))
389
+
390
+ if self.use_gr_x_timbre:
391
+ self.x_timbre_predictor = nn.Sequential(
392
+ GradientReversal(alpha=1),
393
+ CNNLSTM(in_channels, 245200, 1, global_pred=True),
394
+ )
395
+
396
+ self.reset_parameters()
397
+
398
+ def quantize(self, x, n_quantizers=None):
399
+ outs, qs, commit_loss, quantized_buf = 0, [], [], []
400
+
401
+ # prosody
402
+ f0_input = x # (B, d, T)
403
+ f0_quantizer = self.quantizer[0]
404
+ out, q, commit, quantized = f0_quantizer(f0_input, n_quantizers=n_quantizers)
405
+ outs += out
406
+ qs.append(q)
407
+ quantized_buf.append(quantized.sum(0))
408
+ commit_loss.append(commit)
409
+
410
+ # phone
411
+ phone_input = x
412
+ phone_quantizer = self.quantizer[1]
413
+ out, q, commit, quantized = phone_quantizer(phone_input, n_quantizers=n_quantizers)
414
+ outs += out
415
+ qs.append(q)
416
+ quantized_buf.append(quantized.sum(0))
417
+ commit_loss.append(commit)
418
+
419
+ # residual
420
+ if self.vq_num_q_r > 0:
421
+ residual_quantizer = self.quantizer[2]
422
+ residual_input = x - (quantized_buf[0] + quantized_buf[1]).detach()
423
+ out, q, commit, quantized = residual_quantizer(residual_input, n_quantizers=n_quantizers)
424
+ outs += out
425
+ qs.append(q)
426
+ quantized_buf.append(quantized.sum(0)) # [L, B, C, T] -> [B, C, T]
427
+ commit_loss.append(commit)
428
+
429
+ qs = torch.cat(qs, dim=0)
430
+ commit_loss = torch.cat(commit_loss, dim=0)
431
+ return outs, qs, commit_loss, quantized_buf
432
+
433
+ def forward(
434
+ self,
435
+ x,
436
+ vq=True,
437
+ get_vq=False,
438
+ eval_vq=True,
439
+ speaker_embedding=None,
440
+ n_quantizers=None,
441
+ quantized=None,
442
+ ):
443
+ if get_vq:
444
+ return self.quantizer.get_emb()
445
+ if vq is True:
446
+ if eval_vq:
447
+ self.quantizer.eval()
448
+ x_timbre = x
449
+ outs, qs, commit_loss, quantized_buf = self.quantize(x, n_quantizers=n_quantizers)
450
+
451
+ x_timbre = x_timbre.transpose(1, 2)
452
+ x_timbre = self.timbre_encoder(x_timbre, None, None)
453
+ x_timbre = x_timbre.transpose(1, 2)
454
+ spk_embs = torch.mean(x_timbre, dim=2)
455
+ return outs, qs, commit_loss, quantized_buf, spk_embs
456
+
457
+ out = {}
458
+
459
+ layer_0 = quantized[0]
460
+ f0, uv = self.f0_predictor(layer_0)
461
+ f0 = rearrange(f0, "... 1 -> ...")
462
+ uv = rearrange(uv, "... 1 -> ...")
463
+
464
+ layer_1 = quantized[1]
465
+ (phone,) = self.phone_predictor(layer_1)
466
+
467
+ out = {"f0": f0, "uv": uv, "phone": phone}
468
+
469
+ if self.use_gr_prosody_phone:
470
+ (prosody_phone,) = self.prosody_phone_predictor(layer_0)
471
+ out["prosody_phone"] = prosody_phone
472
+
473
+ if self.use_gr_content_f0:
474
+ content_f0, content_uv = self.content_f0_predictor(layer_1)
475
+ content_f0 = rearrange(content_f0, "... 1 -> ...")
476
+ content_uv = rearrange(content_uv, "... 1 -> ...")
477
+ out["content_f0"] = content_f0
478
+ out["content_uv"] = content_uv
479
+
480
+ if self.vq_num_q_r > 0:
481
+ layer_2 = quantized[2]
482
+
483
+ if self.use_gr_residual_f0:
484
+ res_f0, res_uv = self.res_f0_predictor(layer_2)
485
+ res_f0 = rearrange(res_f0, "... 1 -> ...")
486
+ res_uv = rearrange(res_uv, "... 1 -> ...")
487
+ out["res_f0"] = res_f0
488
+ out["res_uv"] = res_uv
489
+
490
+ if self.use_gr_residual_phone:
491
+ (res_phone,) = self.res_phone_predictor(layer_2)
492
+ out["res_phone"] = res_phone
493
+
494
+ style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
495
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
496
+ if self.vq_num_q_r > 0:
497
+ if self.use_random_mask_residual:
498
+ bsz = quantized[2].shape[0]
499
+ res_mask = np.random.choice(
500
+ [0, 1],
501
+ size=bsz,
502
+ p=[
503
+ self.prob_random_mask_residual,
504
+ 1 - self.prob_random_mask_residual,
505
+ ],
506
+ )
507
+ res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) # (B, 1, 1)
508
+ res_mask = res_mask.to(device=quantized[2].device, dtype=quantized[2].dtype)
509
+ x = quantized[0].detach() + quantized[1].detach() + quantized[2] * res_mask
510
+ # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2] * res_mask
511
+ else:
512
+ x = quantized[0].detach() + quantized[1].detach() + quantized[2]
513
+ # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2]
514
+ else:
515
+ x = quantized[0].detach() + quantized[1].detach()
516
+ # x = quantized_perturbe[0].detach() + quantized[1].detach()
517
+
518
+ if self.use_gr_x_timbre:
519
+ (x_timbre,) = self.x_timbre_predictor(x)
520
+ out["x_timbre"] = x_timbre
521
+
522
+ x = x.transpose(1, 2)
523
+ x = self.timbre_norm(x)
524
+ x = x.transpose(1, 2)
525
+ x = x * gamma + beta
526
+
527
+ x = self.model(x)
528
+ out["audio"] = x
529
+
530
+ return out
531
+
532
+ def vq2emb(self, vq, use_residual_code=True):
533
+ # vq: [num_quantizer, B, T]
534
+ self.quantizer = self.quantizer.eval()
535
+ out = 0
536
+ out += self.quantizer[0].vq2emb(vq[0 : self.vq_num_q_p])
537
+ out += self.quantizer[1].vq2emb(vq[self.vq_num_q_p : self.vq_num_q_p + self.vq_num_q_c])
538
+ if self.vq_num_q_r > 0 and use_residual_code:
539
+ out += self.quantizer[2].vq2emb(vq[self.vq_num_q_p + self.vq_num_q_c :])
540
+ return out
541
+
542
+ def inference(self, x, speaker_embedding):
543
+ style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
544
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
545
+ x = x.transpose(1, 2)
546
+ x = self.timbre_norm(x)
547
+ x = x.transpose(1, 2)
548
+ x = x * gamma + beta
549
+ x = self.model(x)
550
+ return x
551
+
552
+ def remove_weight_norm(self):
553
+ """Remove weight normalization module from all of the layers."""
554
+
555
+ def _remove_weight_norm(m):
556
+ try:
557
+ torch.nn.utils.remove_weight_norm(m)
558
+ except ValueError: # this module didn't have weight norm
559
+ return
560
+
561
+ self.apply(_remove_weight_norm)
562
+
563
+ def apply_weight_norm(self):
564
+ """Apply weight normalization module from all of the layers."""
565
+
566
+ def _apply_weight_norm(m):
567
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
568
+ torch.nn.utils.weight_norm(m)
569
+
570
+ self.apply(_apply_weight_norm)
571
+
572
+ def reset_parameters(self):
573
+ self.apply(init_weights)
574
+
575
+
576
+ class FACodecRedecoder(nn.Module):
577
+ def __init__(
578
+ self,
579
+ in_channels=256,
580
+ upsample_initial_channel=1280,
581
+ up_ratios=(5, 5, 4, 2),
582
+ vq_num_q_c=2,
583
+ vq_num_q_p=1,
584
+ vq_num_q_r=3,
585
+ vq_dim=256,
586
+ codebook_size_prosody=10,
587
+ codebook_size_content=10,
588
+ codebook_size_residual=10,
589
+ ):
590
+ super().__init__()
591
+ self.hop_length = np.prod(up_ratios)
592
+ self.up_ratios = up_ratios
593
+
594
+ self.vq_num_q_p = vq_num_q_p
595
+ self.vq_num_q_c = vq_num_q_c
596
+ self.vq_num_q_r = vq_num_q_r
597
+
598
+ self.vq_dim = vq_dim
599
+
600
+ self.codebook_size_prosody = codebook_size_prosody
601
+ self.codebook_size_content = codebook_size_content
602
+ self.codebook_size_residual = codebook_size_residual
603
+
604
+ self.prosody_embs = nn.ModuleList()
605
+ for i in range(self.vq_num_q_p):
606
+ emb_tokens = nn.Embedding(
607
+ num_embeddings=2**self.codebook_size_prosody,
608
+ embedding_dim=self.vq_dim,
609
+ )
610
+ emb_tokens.weight.data.normal_(mean=0.0, std=1e-5)
611
+ self.prosody_embs.append(emb_tokens)
612
+ self.content_embs = nn.ModuleList()
613
+ for i in range(self.vq_num_q_c):
614
+ emb_tokens = nn.Embedding(
615
+ num_embeddings=2**self.codebook_size_content,
616
+ embedding_dim=self.vq_dim,
617
+ )
618
+ emb_tokens.weight.data.normal_(mean=0.0, std=1e-5)
619
+ self.content_embs.append(emb_tokens)
620
+ self.residual_embs = nn.ModuleList()
621
+ for i in range(self.vq_num_q_r):
622
+ emb_tokens = nn.Embedding(
623
+ num_embeddings=2**self.codebook_size_residual,
624
+ embedding_dim=self.vq_dim,
625
+ )
626
+ emb_tokens.weight.data.normal_(mean=0.0, std=1e-5)
627
+ self.residual_embs.append(emb_tokens)
628
+
629
+ # Add first conv layer
630
+ channels = upsample_initial_channel
631
+ layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
632
+
633
+ # Add upsampling + MRF blocks
634
+ for i, stride in enumerate(up_ratios):
635
+ input_dim = channels // 2**i
636
+ output_dim = channels // 2 ** (i + 1)
637
+ layers += [DecoderBlock(input_dim, output_dim, stride)]
638
+
639
+ # Add final conv layer
640
+ layers += [
641
+ Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)),
642
+ WNConv1d(output_dim, 1, kernel_size=7, padding=3),
643
+ nn.Tanh(),
644
+ ]
645
+
646
+ self.model = nn.Sequential(*layers)
647
+
648
+ self.timbre_linear = nn.Linear(in_channels, in_channels * 2)
649
+ self.timbre_linear.bias.data[:in_channels] = 1
650
+ self.timbre_linear.bias.data[in_channels:] = 0
651
+ self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False)
652
+
653
+ self.timbre_cond_prosody_enc = TransformerEncoder(
654
+ enc_emb_tokens=None,
655
+ encoder_layer=4,
656
+ encoder_hidden=256,
657
+ encoder_head=4,
658
+ conv_filter_size=1024,
659
+ conv_kernel_size=5,
660
+ encoder_dropout=0.1,
661
+ use_cln=True,
662
+ cfg=None,
663
+ )
664
+
665
+ def forward(
666
+ self,
667
+ vq,
668
+ speaker_embedding,
669
+ use_residual_code=False,
670
+ ):
671
+ x = 0
672
+
673
+ x_p = 0
674
+ for i in range(self.vq_num_q_p):
675
+ x_p = x_p + self.prosody_embs[i](vq[i]) # (B, T, d)
676
+ spk_cond = speaker_embedding.unsqueeze(1).expand(-1, x_p.shape[1], -1)
677
+ x_p = self.timbre_cond_prosody_enc(x_p, key_padding_mask=None, condition=spk_cond)
678
+ x = x + x_p
679
+
680
+ x_c = 0
681
+ for i in range(self.vq_num_q_c):
682
+ x_c = x_c + self.content_embs[i](vq[self.vq_num_q_p + i])
683
+
684
+ x = x + x_c
685
+
686
+ if use_residual_code:
687
+ x_r = 0
688
+ for i in range(self.vq_num_q_r):
689
+ x_r = x_r + self.residual_embs[i](vq[self.vq_num_q_p + self.vq_num_q_c + i])
690
+ x = x + x_r
691
+
692
+ style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
693
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
694
+ x = x.transpose(1, 2)
695
+ x = self.timbre_norm(x)
696
+ x = x.transpose(1, 2)
697
+ x = x * gamma + beta
698
+ x = self.model(x)
699
+
700
+ return x
701
+
702
+ def vq2emb(self, vq, speaker_embedding, use_residual=True):
703
+ out = 0
704
+
705
+ x_t = 0
706
+ for i in range(self.vq_num_q_p):
707
+ x_t += self.prosody_embs[i](vq[i]) # (B, T, d)
708
+ spk_cond = speaker_embedding.unsqueeze(1).expand(-1, x_t.shape[1], -1)
709
+ x_t = self.timbre_cond_prosody_enc(x_t, key_padding_mask=None, condition=spk_cond)
710
+
711
+ # prosody
712
+ out += x_t
713
+
714
+ # content
715
+ for i in range(self.vq_num_q_c):
716
+ out += self.content_embs[i](vq[self.vq_num_q_p + i])
717
+
718
+ # residual
719
+ if use_residual:
720
+ for i in range(self.vq_num_q_r):
721
+ out += self.residual_embs[i](vq[self.vq_num_q_p + self.vq_num_q_c + i])
722
+
723
+ out = out.transpose(1, 2) # (B, T, d) -> (B, d, T)
724
+ return out
725
+
726
+ def inference(self, x, speaker_embedding):
727
+ style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
728
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
729
+ x = x.transpose(1, 2)
730
+ x = self.timbre_norm(x)
731
+ x = x.transpose(1, 2)
732
+ x = x * gamma + beta
733
+ x = self.model(x)
734
+ return x
735
+
736
+
737
+ class FACodecEncoderV2(nn.Module):
738
+ def __init__(
739
+ self,
740
+ ngf=32,
741
+ up_ratios=(2, 4, 5, 5),
742
+ out_channels=1024,
743
+ ):
744
+ super().__init__()
745
+ self.hop_length = np.prod(up_ratios)
746
+ self.up_ratios = up_ratios
747
+
748
+ # Create first convolution
749
+ d_model = ngf
750
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
751
+
752
+ # Create EncoderBlocks that double channels as they downsample by `stride`
753
+ for stride in up_ratios:
754
+ d_model *= 2
755
+ self.block += [EncoderBlock(d_model, stride=stride)]
756
+
757
+ # Create last convolution
758
+ self.block += [
759
+ Activation1d(activation=SnakeBeta(d_model, alpha_logscale=True)),
760
+ WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
761
+ ]
762
+
763
+ # Wrap block into nn.Sequential
764
+ self.block = nn.Sequential(*self.block)
765
+ self.enc_dim = d_model
766
+
767
+ self.mel_transform = MelSpectrogram(
768
+ n_fft=1024,
769
+ num_mels=80,
770
+ sampling_rate=16000,
771
+ hop_size=200,
772
+ win_size=800,
773
+ fmin=0,
774
+ fmax=8000,
775
+ )
776
+
777
+ self.reset_parameters()
778
+
779
+ def forward(self, x):
780
+ out = self.block(x)
781
+ return out
782
+
783
+ def inference(self, x):
784
+ return self.block(x)
785
+
786
+ def get_prosody_feature(self, x):
787
+ return self.mel_transform(x.squeeze(1))[:, :20, :]
788
+
789
+ def remove_weight_norm(self):
790
+ """Remove weight normalization module from all of the layers."""
791
+
792
+ def _remove_weight_norm(m):
793
+ try:
794
+ torch.nn.utils.remove_weight_norm(m)
795
+ except ValueError: # this module didn't have weight norm
796
+ return
797
+
798
+ self.apply(_remove_weight_norm)
799
+
800
+ def apply_weight_norm(self):
801
+ """Apply weight normalization module from all of the layers."""
802
+
803
+ def _apply_weight_norm(m):
804
+ if isinstance(m, nn.Conv1d):
805
+ torch.nn.utils.weight_norm(m)
806
+
807
+ self.apply(_apply_weight_norm)
808
+
809
+ def reset_parameters(self):
810
+ self.apply(init_weights)
811
+
812
+
813
+ class FACodecDecoderV2(nn.Module):
814
+ def __init__(
815
+ self,
816
+ in_channels=256,
817
+ upsample_initial_channel=1536,
818
+ ngf=32,
819
+ up_ratios=(5, 5, 4, 2),
820
+ vq_num_q_c=2,
821
+ vq_num_q_p=1,
822
+ vq_num_q_r=3,
823
+ vq_dim=1024,
824
+ vq_commit_weight=0.005,
825
+ vq_weight_init=False,
826
+ vq_full_commit_loss=False,
827
+ codebook_dim=8,
828
+ codebook_size_prosody=10, # true codebook size is equal to 2^codebook_size
829
+ codebook_size_content=10,
830
+ codebook_size_residual=10,
831
+ quantizer_dropout=0.0,
832
+ dropout_type="linear",
833
+ use_gr_content_f0=False,
834
+ use_gr_prosody_phone=False,
835
+ use_gr_residual_f0=False,
836
+ use_gr_residual_phone=False,
837
+ use_gr_x_timbre=False,
838
+ use_random_mask_residual=True,
839
+ prob_random_mask_residual=0.75,
840
+ ):
841
+ super().__init__()
842
+ self.hop_length = np.prod(up_ratios)
843
+ self.ngf = ngf
844
+ self.up_ratios = up_ratios
845
+
846
+ self.use_random_mask_residual = use_random_mask_residual
847
+ self.prob_random_mask_residual = prob_random_mask_residual
848
+
849
+ self.vq_num_q_p = vq_num_q_p
850
+ self.vq_num_q_c = vq_num_q_c
851
+ self.vq_num_q_r = vq_num_q_r
852
+
853
+ self.codebook_size_prosody = codebook_size_prosody
854
+ self.codebook_size_content = codebook_size_content
855
+ self.codebook_size_residual = codebook_size_residual
856
+
857
+ quantizer_class = ResidualVQ
858
+
859
+ self.quantizer = nn.ModuleList()
860
+
861
+ # prosody
862
+ quantizer = quantizer_class(
863
+ num_quantizers=vq_num_q_p,
864
+ dim=vq_dim,
865
+ codebook_size=codebook_size_prosody,
866
+ codebook_dim=codebook_dim,
867
+ threshold_ema_dead_code=2,
868
+ commitment=vq_commit_weight,
869
+ weight_init=vq_weight_init,
870
+ full_commit_loss=vq_full_commit_loss,
871
+ quantizer_dropout=quantizer_dropout,
872
+ dropout_type=dropout_type,
873
+ )
874
+ self.quantizer.append(quantizer)
875
+
876
+ # phone
877
+ quantizer = quantizer_class(
878
+ num_quantizers=vq_num_q_c,
879
+ dim=vq_dim,
880
+ codebook_size=codebook_size_content,
881
+ codebook_dim=codebook_dim,
882
+ threshold_ema_dead_code=2,
883
+ commitment=vq_commit_weight,
884
+ weight_init=vq_weight_init,
885
+ full_commit_loss=vq_full_commit_loss,
886
+ quantizer_dropout=quantizer_dropout,
887
+ dropout_type=dropout_type,
888
+ )
889
+ self.quantizer.append(quantizer)
890
+
891
+ # residual
892
+ if self.vq_num_q_r > 0:
893
+ quantizer = quantizer_class(
894
+ num_quantizers=vq_num_q_r,
895
+ dim=vq_dim,
896
+ codebook_size=codebook_size_residual,
897
+ codebook_dim=codebook_dim,
898
+ threshold_ema_dead_code=2,
899
+ commitment=vq_commit_weight,
900
+ weight_init=vq_weight_init,
901
+ full_commit_loss=vq_full_commit_loss,
902
+ quantizer_dropout=quantizer_dropout,
903
+ dropout_type=dropout_type,
904
+ )
905
+ self.quantizer.append(quantizer)
906
+
907
+ # Add first conv layer
908
+ channels = upsample_initial_channel
909
+ layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
910
+
911
+ # Add upsampling + MRF blocks
912
+ for i, stride in enumerate(up_ratios):
913
+ input_dim = channels // 2**i
914
+ output_dim = channels // 2 ** (i + 1)
915
+ layers += [DecoderBlock(input_dim, output_dim, stride)]
916
+
917
+ # Add final conv layer
918
+ layers += [
919
+ Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)),
920
+ WNConv1d(output_dim, 1, kernel_size=7, padding=3),
921
+ nn.Tanh(),
922
+ ]
923
+
924
+ self.model = nn.Sequential(*layers)
925
+
926
+ self.timbre_encoder = TransformerEncoder(
927
+ enc_emb_tokens=None,
928
+ encoder_layer=4,
929
+ encoder_hidden=256,
930
+ encoder_head=4,
931
+ conv_filter_size=1024,
932
+ conv_kernel_size=5,
933
+ encoder_dropout=0.1,
934
+ use_cln=False,
935
+ )
936
+
937
+ self.timbre_linear = nn.Linear(in_channels, in_channels * 2)
938
+ self.timbre_linear.bias.data[:in_channels] = 1
939
+ self.timbre_linear.bias.data[in_channels:] = 0
940
+ self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False)
941
+
942
+ self.f0_predictor = CNNLSTM(in_channels, 1, 2)
943
+ self.phone_predictor = CNNLSTM(in_channels, 5003, 1)
944
+
945
+ self.use_gr_content_f0 = use_gr_content_f0
946
+ self.use_gr_prosody_phone = use_gr_prosody_phone
947
+ self.use_gr_residual_f0 = use_gr_residual_f0
948
+ self.use_gr_residual_phone = use_gr_residual_phone
949
+ self.use_gr_x_timbre = use_gr_x_timbre
950
+
951
+ if self.vq_num_q_r > 0 and self.use_gr_residual_f0:
952
+ self.res_f0_predictor = nn.Sequential(GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2))
953
+
954
+ if self.vq_num_q_r > 0 and self.use_gr_residual_phone > 0:
955
+ self.res_phone_predictor = nn.Sequential(GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1))
956
+
957
+ if self.use_gr_content_f0:
958
+ self.content_f0_predictor = nn.Sequential(GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2))
959
+
960
+ if self.use_gr_prosody_phone:
961
+ self.prosody_phone_predictor = nn.Sequential(GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1))
962
+
963
+ if self.use_gr_x_timbre:
964
+ self.x_timbre_predictor = nn.Sequential(
965
+ GradientReversal(alpha=1),
966
+ CNNLSTM(in_channels, 245200, 1, global_pred=True),
967
+ )
968
+
969
+ self.melspec_linear = nn.Linear(20, 256)
970
+ self.melspec_encoder = TransformerEncoder(
971
+ enc_emb_tokens=None,
972
+ encoder_layer=4,
973
+ encoder_hidden=256,
974
+ encoder_head=4,
975
+ conv_filter_size=1024,
976
+ conv_kernel_size=5,
977
+ encoder_dropout=0.1,
978
+ use_cln=False,
979
+ cfg=None,
980
+ )
981
+
982
+ self.reset_parameters()
983
+
984
+ def quantize(self, x, prosody_feature, n_quantizers=None):
985
+ outs, qs, commit_loss, quantized_buf = 0, [], [], []
986
+
987
+ # prosody
988
+ f0_input = prosody_feature.transpose(1, 2) # (B, T, 20)
989
+ f0_input = self.melspec_linear(f0_input)
990
+ f0_input = self.melspec_encoder(f0_input, None, None)
991
+ f0_input = f0_input.transpose(1, 2)
992
+ f0_quantizer = self.quantizer[0]
993
+ out, q, commit, quantized = f0_quantizer(f0_input, n_quantizers=n_quantizers)
994
+ outs += out
995
+ qs.append(q)
996
+ quantized_buf.append(quantized.sum(0))
997
+ commit_loss.append(commit)
998
+
999
+ # phone
1000
+ phone_input = x
1001
+ phone_quantizer = self.quantizer[1]
1002
+ out, q, commit, quantized = phone_quantizer(phone_input, n_quantizers=n_quantizers)
1003
+ outs += out
1004
+ qs.append(q)
1005
+ quantized_buf.append(quantized.sum(0))
1006
+ commit_loss.append(commit)
1007
+
1008
+ # residual
1009
+ if self.vq_num_q_r > 0:
1010
+ residual_quantizer = self.quantizer[2]
1011
+ residual_input = x - (quantized_buf[0] + quantized_buf[1]).detach()
1012
+ out, q, commit, quantized = residual_quantizer(residual_input, n_quantizers=n_quantizers)
1013
+ outs += out
1014
+ qs.append(q)
1015
+ quantized_buf.append(quantized.sum(0)) # [L, B, C, T] -> [B, C, T]
1016
+ commit_loss.append(commit)
1017
+
1018
+ qs = torch.cat(qs, dim=0)
1019
+ commit_loss = torch.cat(commit_loss, dim=0)
1020
+ return outs, qs, commit_loss, quantized_buf
1021
+
1022
+ def forward(
1023
+ self,
1024
+ x,
1025
+ prosody_feature,
1026
+ vq=True,
1027
+ get_vq=False,
1028
+ eval_vq=True,
1029
+ speaker_embedding=None,
1030
+ n_quantizers=None,
1031
+ quantized=None,
1032
+ ):
1033
+ if get_vq:
1034
+ return self.quantizer.get_emb()
1035
+ if vq is True:
1036
+ if eval_vq:
1037
+ self.quantizer.eval()
1038
+ x_timbre = x
1039
+ outs, qs, commit_loss, quantized_buf = self.quantize(x, prosody_feature, n_quantizers=n_quantizers)
1040
+
1041
+ x_timbre = x_timbre.transpose(1, 2)
1042
+ x_timbre = self.timbre_encoder(x_timbre, None, None)
1043
+ x_timbre = x_timbre.transpose(1, 2)
1044
+ spk_embs = torch.mean(x_timbre, dim=2)
1045
+ return outs, qs, commit_loss, quantized_buf, spk_embs
1046
+
1047
+ out = {}
1048
+
1049
+ layer_0 = quantized[0]
1050
+ f0, uv = self.f0_predictor(layer_0)
1051
+ f0 = rearrange(f0, "... 1 -> ...")
1052
+ uv = rearrange(uv, "... 1 -> ...")
1053
+
1054
+ layer_1 = quantized[1]
1055
+ (phone,) = self.phone_predictor(layer_1)
1056
+
1057
+ out = {"f0": f0, "uv": uv, "phone": phone}
1058
+
1059
+ if self.use_gr_prosody_phone:
1060
+ (prosody_phone,) = self.prosody_phone_predictor(layer_0)
1061
+ out["prosody_phone"] = prosody_phone
1062
+
1063
+ if self.use_gr_content_f0:
1064
+ content_f0, content_uv = self.content_f0_predictor(layer_1)
1065
+ content_f0 = rearrange(content_f0, "... 1 -> ...")
1066
+ content_uv = rearrange(content_uv, "... 1 -> ...")
1067
+ out["content_f0"] = content_f0
1068
+ out["content_uv"] = content_uv
1069
+
1070
+ if self.vq_num_q_r > 0:
1071
+ layer_2 = quantized[2]
1072
+
1073
+ if self.use_gr_residual_f0:
1074
+ res_f0, res_uv = self.res_f0_predictor(layer_2)
1075
+ res_f0 = rearrange(res_f0, "... 1 -> ...")
1076
+ res_uv = rearrange(res_uv, "... 1 -> ...")
1077
+ out["res_f0"] = res_f0
1078
+ out["res_uv"] = res_uv
1079
+
1080
+ if self.use_gr_residual_phone:
1081
+ (res_phone,) = self.res_phone_predictor(layer_2)
1082
+ out["res_phone"] = res_phone
1083
+
1084
+ style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
1085
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
1086
+ if self.vq_num_q_r > 0:
1087
+ if self.use_random_mask_residual:
1088
+ bsz = quantized[2].shape[0]
1089
+ res_mask = np.random.choice(
1090
+ [0, 1],
1091
+ size=bsz,
1092
+ p=[
1093
+ self.prob_random_mask_residual,
1094
+ 1 - self.prob_random_mask_residual,
1095
+ ],
1096
+ )
1097
+ res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) # (B, 1, 1)
1098
+ res_mask = res_mask.to(device=quantized[2].device, dtype=quantized[2].dtype)
1099
+ x = quantized[0].detach() + quantized[1].detach() + quantized[2] * res_mask
1100
+ # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2] * res_mask
1101
+ else:
1102
+ x = quantized[0].detach() + quantized[1].detach() + quantized[2]
1103
+ # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2]
1104
+ else:
1105
+ x = quantized[0].detach() + quantized[1].detach()
1106
+ # x = quantized_perturbe[0].detach() + quantized[1].detach()
1107
+
1108
+ if self.use_gr_x_timbre:
1109
+ (x_timbre,) = self.x_timbre_predictor(x)
1110
+ out["x_timbre"] = x_timbre
1111
+
1112
+ x = x.transpose(1, 2)
1113
+ x = self.timbre_norm(x)
1114
+ x = x.transpose(1, 2)
1115
+ x = x * gamma + beta
1116
+
1117
+ x = self.model(x)
1118
+ out["audio"] = x
1119
+
1120
+ return out
1121
+
1122
+ def vq2emb(self, vq, use_residual=True):
1123
+ # vq: [num_quantizer, B, T]
1124
+ self.quantizer = self.quantizer.eval()
1125
+ out = 0
1126
+ out += self.quantizer[0].vq2emb(vq[0 : self.vq_num_q_p])
1127
+ out += self.quantizer[1].vq2emb(vq[self.vq_num_q_p : self.vq_num_q_p + self.vq_num_q_c])
1128
+ if self.vq_num_q_r > 0 and use_residual:
1129
+ out += self.quantizer[2].vq2emb(vq[self.vq_num_q_p + self.vq_num_q_c :])
1130
+ return out
1131
+
1132
+ def inference(self, x, speaker_embedding):
1133
+ style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
1134
+ gamma, beta = style.chunk(2, 1) # (B, d, 1)
1135
+ x = x.transpose(1, 2)
1136
+ x = self.timbre_norm(x)
1137
+ x = x.transpose(1, 2)
1138
+ x = x * gamma + beta
1139
+ x = self.model(x)
1140
+ return x
1141
+
1142
+ def remove_weight_norm(self):
1143
+ """Remove weight normalization module from all of the layers."""
1144
+
1145
+ def _remove_weight_norm(m):
1146
+ try:
1147
+ torch.nn.utils.remove_weight_norm(m)
1148
+ except ValueError: # this module didn't have weight norm
1149
+ return
1150
+
1151
+ self.apply(_remove_weight_norm)
1152
+
1153
+ def apply_weight_norm(self):
1154
+ """Apply weight normalization module from all of the layers."""
1155
+
1156
+ def _apply_weight_norm(m):
1157
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
1158
+ torch.nn.utils.weight_norm(m)
1159
+
1160
+ self.apply(_apply_weight_norm)
1161
+
1162
+ def reset_parameters(self):
1163
+ self.apply(init_weights)
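
To make the temporal bookkeeping concrete: with the default `up_ratios=(2, 4, 5, 5)` the encoder hop length is `prod(up_ratios) = 200`, so one second of 16 kHz audio becomes 80 latent frames. A minimal shape sketch, assuming the classes above are importable:

```python
import torch

enc = FACodecEncoder()          # defaults: ngf=32, up_ratios=(2, 4, 5, 5), out_channels=1024
wav = torch.randn(1, 1, 16000)  # [B, 1, T], 1 s at 16 kHz

z = enc(wav)                    # downsampled by a factor of 200
print(z.shape)                  # torch.Size([1, 1024, 80])
```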
Amphion/models/codec/ns3_codec/gradient_reversal.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from torch.autograd import Function
7
+ import torch
8
+ from torch import nn
9
+
10
+
11
+ class GradientReversal(Function):
12
+ @staticmethod
13
+ def forward(ctx, x, alpha):
14
+ ctx.save_for_backward(x, alpha)
15
+ return x
16
+
17
+ @staticmethod
18
+ def backward(ctx, grad_output):
19
+ grad_input = None
20
+ _, alpha = ctx.saved_tensors
21
+ if ctx.needs_input_grad[0]:
22
+ grad_input = -alpha * grad_output
23
+ return grad_input, None
24
+
25
+
26
+ revgrad = GradientReversal.apply
27
+
28
+
29
+ class GradientReversal(nn.Module):
30
+ def __init__(self, alpha):
31
+ super().__init__()
32
+ self.alpha = torch.tensor(alpha, requires_grad=False)
33
+
34
+ def forward(self, x):
35
+ return revgrad(x, self.alpha)
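
The module is an identity in the forward pass and multiplies gradients by `-alpha` in the backward pass, which is what lets the adversarial predictors in `facodec.py` push information out of the representations they are attached to. A small sanity check using the class defined above:

```python
import torch

grl = GradientReversal(alpha=0.5)
x = torch.ones(3, requires_grad=True)

grl(x).sum().backward()
print(x.grad)  # tensor([-0.5000, -0.5000, -0.5000]): output unchanged, gradient negated and scaled
```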
Amphion/models/codec/ns3_codec/melspec.py ADDED
@@ -0,0 +1,102 @@
1
+ import torch
2
+ import pyworld as pw
3
+ import numpy as np
4
+ import soundfile as sf
5
+ import os
6
+ from torchaudio.functional import pitch_shift
7
+ import librosa
8
+ from librosa.filters import mel as librosa_mel_fn
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
14
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
15
+
16
+
17
+ def dynamic_range_decompression(x, C=1):
18
+ return np.exp(x) / C
19
+
20
+
21
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
22
+ return torch.log(torch.clamp(x, min=clip_val) * C)
23
+
24
+
25
+ def dynamic_range_decompression_torch(x, C=1):
26
+ return torch.exp(x) / C
27
+
28
+
29
+ def spectral_normalize_torch(magnitudes):
30
+ output = dynamic_range_compression_torch(magnitudes)
31
+ return output
32
+
33
+
34
+ def spectral_de_normalize_torch(magnitudes):
35
+ output = dynamic_range_decompression_torch(magnitudes)
36
+ return output
37
+
38
+
39
+ class MelSpectrogram(nn.Module):
40
+ def __init__(
41
+ self,
42
+ n_fft,
43
+ num_mels,
44
+ sampling_rate,
45
+ hop_size,
46
+ win_size,
47
+ fmin,
48
+ fmax,
49
+ center=False,
50
+ ):
51
+ super(MelSpectrogram, self).__init__()
52
+ self.n_fft = n_fft
53
+ self.hop_size = hop_size
54
+ self.win_size = win_size
55
+ self.sampling_rate = sampling_rate
56
+ self.num_mels = num_mels
57
+ self.fmin = fmin
58
+ self.fmax = fmax
59
+ self.center = center
60
+
61
+ mel_basis = {}
62
+ hann_window = {}
63
+
64
+ mel = librosa_mel_fn(
65
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
66
+ )
67
+ mel_basis = torch.from_numpy(mel).float()
68
+ hann_window = torch.hann_window(win_size)
69
+
70
+ self.register_buffer("mel_basis", mel_basis)
71
+ self.register_buffer("hann_window", hann_window)
72
+
73
+ def forward(self, y):
74
+ y = torch.nn.functional.pad(
75
+ y.unsqueeze(1),
76
+ (
77
+ int((self.n_fft - self.hop_size) / 2),
78
+ int((self.n_fft - self.hop_size) / 2),
79
+ ),
80
+ mode="reflect",
81
+ )
82
+ y = y.squeeze(1)
83
+ spec = torch.stft(
84
+ y,
85
+ self.n_fft,
86
+ hop_length=self.hop_size,
87
+ win_length=self.win_size,
88
+ window=self.hann_window,
89
+ center=self.center,
90
+ pad_mode="reflect",
91
+ normalized=False,
92
+ onesided=True,
93
+ return_complex=True,
94
+ )
95
+ spec = torch.view_as_real(spec)
96
+
97
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
98
+
99
+ spec = torch.matmul(self.mel_basis, spec)
100
+ spec = spectral_normalize_torch(spec)
101
+
102
+ return spec
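
A minimal usage sketch of the module with the same configuration `FACodecEncoderV2` uses for its prosody features; with `hop_size=200`, `center=False`, and the manual reflect padding above, one second of 16 kHz audio yields 80 frames:

```python
import torch

mel_fn = MelSpectrogram(
    n_fft=1024, num_mels=80, sampling_rate=16000,
    hop_size=200, win_size=800, fmin=0, fmax=8000,
)

wav = torch.randn(2, 16000)  # [B, T], 1 s at 16 kHz
mel = mel_fn(wav)            # log-compressed mel spectrogram
print(mel.shape)             # torch.Size([2, 80, 80]) -> 80 mel bins x 80 frames
```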
Amphion/models/codec/ns3_codec/quantize/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (246 Bytes).
 
Amphion/models/codec/ns3_codec/quantize/fvq.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Union
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from einops import rearrange
13
+ from torch.nn.utils import weight_norm
14
+
15
+
16
+ class FactorizedVectorQuantize(nn.Module):
17
+ def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs):
18
+ super().__init__()
19
+ self.codebook_size = codebook_size
20
+ self.codebook_dim = codebook_dim
21
+ self.commitment = commitment
22
+
23
+ if dim != self.codebook_dim:
24
+ self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim))
25
+ self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim))
26
+ else:
27
+ self.in_proj = nn.Identity()
28
+ self.out_proj = nn.Identity()
29
+ self._codebook = nn.Embedding(codebook_size, self.codebook_dim)
30
+
31
+ @property
32
+ def codebook(self):
33
+ return self._codebook
34
+
35
+ def forward(self, z):
36
+ """Quantized the input tensor using a fixed codebook and returns
37
+ the corresponding codebook vectors
38
+
39
+ Parameters
40
+ ----------
41
+ z : Tensor[B x D x T]
42
+
43
+ Returns
44
+ -------
45
+ Tensor[B x D x T]
46
+ Quantized continuous representation of input
47
+ Tensor[1]
48
+ Commitment loss to train encoder to predict vectors closer to codebook
49
+ entries
50
+ Tensor[1]
51
+ Codebook loss to update the codebook
52
+ Tensor[B x T]
53
+ Codebook indices (quantized discrete representation of input)
54
+ Tensor[B x D x T]
55
+ Projected latents (continuous representation of input before quantization)
56
+ """
57
+ # transpose since we use linear
58
+
59
+ z = rearrange(z, "b d t -> b t d")
60
+
61
+ # Factorized codes project input into low-dimensional space
62
+ z_e = self.in_proj(z) # z_e : (B x T x D)
63
+ z_e = rearrange(z_e, "b t d -> b d t")
64
+ z_q, indices = self.decode_latents(z_e)
65
+
66
+ if self.training:
67
+ commitment_loss = (
68
+ F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
69
+ * self.commitment
70
+ )
71
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
72
+ commit_loss = commitment_loss + codebook_loss
73
+ else:
74
+ commit_loss = torch.zeros(z.shape[0], device=z.device)
75
+
76
+ z_q = (
77
+ z_e + (z_q - z_e).detach()
78
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
79
+
80
+ z_q = rearrange(z_q, "b d t -> b t d")
81
+ z_q = self.out_proj(z_q)
82
+ z_q = rearrange(z_q, "b t d -> b d t")
83
+
84
+ return z_q, indices, commit_loss
85
+
86
+ def vq2emb(self, vq, proj=True):
87
+ emb = self.embed_code(vq)
88
+ if proj:
89
+ emb = self.out_proj(emb)
90
+ return emb.transpose(1, 2)
91
+
92
+ def get_emb(self):
93
+ return self.codebook.weight
94
+
95
+ def embed_code(self, embed_id):
96
+ return F.embedding(embed_id, self.codebook.weight)
97
+
98
+ def decode_code(self, embed_id):
99
+ return self.embed_code(embed_id).transpose(1, 2)
100
+
101
+ def decode_latents(self, latents):
102
+ encodings = rearrange(latents, "b d t -> (b t) d")
103
+ codebook = self.codebook.weight # codebook: (N x D)
104
+ # L2 normalize encodings and codebook
105
+ encodings = F.normalize(encodings)
106
+ codebook = F.normalize(codebook)
107
+
108
+ # Compute euclidean distance with codebook
109
+ dist = (
110
+ encodings.pow(2).sum(1, keepdim=True)
111
+ - 2 * encodings @ codebook.t()
112
+ + codebook.pow(2).sum(1, keepdim=True).t()
113
+ )
114
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
115
+ z_q = self.decode_code(indices)
116
+ return z_q, indices
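
A minimal shape sketch of the factorized quantizer defined above (the sizes below are illustrative, not taken from a config in this commit):

```python
import torch

vq = FactorizedVectorQuantize(dim=256, codebook_size=1024, codebook_dim=8, commitment=0.005)

z = torch.randn(4, 256, 50)  # [B, D, T] encoder latents
z_q, indices, commit_loss = vq(z)

print(z_q.shape)          # torch.Size([4, 256, 50]) - straight-through quantized latents
print(indices.shape)      # torch.Size([4, 50])      - one codebook index per frame
print(commit_loss.shape)  # torch.Size([4])          - per-item commitment/codebook loss (zeros in eval mode)
```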
Amphion/models/svc/transformer/transformer_inference.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import time
8
+ import numpy as np
9
+ import torch
10
+ from tqdm import tqdm
11
+ import torch.nn as nn
12
+ from collections import OrderedDict
13
+
14
+ from models.svc.base import SVCInference
15
+ from modules.encoder.condition_encoder import ConditionEncoder
16
+ from models.svc.transformer.transformer import Transformer
17
+ from models.svc.transformer.conformer import Conformer
18
+
19
+
20
+ class TransformerInference(SVCInference):
21
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
22
+ SVCInference.__init__(self, args, cfg, infer_type)
23
+
24
+ def _build_model(self):
25
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
26
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
27
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
28
+ if self.cfg.model.transformer.type == "transformer":
29
+ self.acoustic_mapper = Transformer(self.cfg.model.transformer)
30
+ elif self.cfg.model.transformer.type == "conformer":
31
+ self.acoustic_mapper = Conformer(self.cfg.model.transformer)
32
+ else:
33
+ raise NotImplementedError
34
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
35
+ return model
36
+
37
+ def _inference_each_batch(self, batch_data):
38
+ device = self.accelerator.device
39
+ for k, v in batch_data.items():
40
+ batch_data[k] = v.to(device)
41
+
42
+ condition = self.condition_encoder(batch_data)
43
+ y_pred = self.acoustic_mapper(condition, batch_data["mask"])
44
+
45
+ return y_pred
Amphion/models/svc/vits/vits_trainer.py ADDED
@@ -0,0 +1,704 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from torch.optim.lr_scheduler import ExponentialLR
8
+ from tqdm import tqdm
9
+ from pathlib import Path
10
+ import shutil
11
+ import accelerate
12
+
13
+ # from models.svc.base import SVCTrainer
14
+ from models.svc.base.svc_dataset import SVCOfflineCollator, SVCOfflineDataset
15
+ from models.svc.vits.vits import *
16
+ from models.svc.base import SVCTrainer
17
+
18
+ from utils.mel import mel_spectrogram_torch
19
+ import json
20
+
21
+ from models.vocoders.gan.discriminator.mpd import (
22
+ MultiPeriodDiscriminator_vits as MultiPeriodDiscriminator,
23
+ )
24
+
25
+
26
+ class VitsSVCTrainer(SVCTrainer):
27
+ def __init__(self, args, cfg):
28
+ self.args = args
29
+ self.cfg = cfg
30
+ SVCTrainer.__init__(self, args, cfg)
31
+
32
+ def _accelerator_prepare(self):
33
+ (
34
+ self.train_dataloader,
35
+ self.valid_dataloader,
36
+ ) = self.accelerator.prepare(
37
+ self.train_dataloader,
38
+ self.valid_dataloader,
39
+ )
40
+ if isinstance(self.model, dict):
41
+ for key in self.model.keys():
42
+ self.model[key] = self.accelerator.prepare(self.model[key])
43
+ else:
44
+ self.model = self.accelerator.prepare(self.model)
45
+
46
+ if isinstance(self.optimizer, dict):
47
+ for key in self.optimizer.keys():
48
+ self.optimizer[key] = self.accelerator.prepare(self.optimizer[key])
49
+ else:
50
+ self.optimizer = self.accelerator.prepare(self.optimizer)
51
+
52
+ if isinstance(self.scheduler, dict):
53
+ for key in self.scheduler.keys():
54
+ self.scheduler[key] = self.accelerator.prepare(self.scheduler[key])
55
+ else:
56
+ self.scheduler = self.accelerator.prepare(self.scheduler)
57
+
58
+ def _load_model(
59
+ self,
60
+ checkpoint_dir: str = None,
61
+ checkpoint_path: str = None,
62
+ resume_type: str = "",
63
+ ):
64
+ r"""Load model from checkpoint. If checkpoint_path is None, it will
65
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
66
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
67
+ method after** ``accelerator.prepare()``.
68
+ """
69
+ if checkpoint_path is None:
70
+ ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
71
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
72
+ checkpoint_path = ls[0]
73
+ self.logger.info("Resume from {}...".format(checkpoint_path))
74
+
75
+ if resume_type in ["resume", ""]:
76
+ # Load all the things, including model weights, optimizer, scheduler, and random states.
77
+ self.accelerator.load_state(input_dir=checkpoint_path)
78
+
79
+ # set epoch and step
80
+ self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
81
+ self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
82
+
83
+ elif resume_type == "finetune":
84
+ # Load only the model weights
85
+ accelerate.load_checkpoint_and_dispatch(
86
+ self.accelerator.unwrap_model(self.model["generator"]),
87
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
88
+ )
89
+ accelerate.load_checkpoint_and_dispatch(
90
+ self.accelerator.unwrap_model(self.model["discriminator"]),
91
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
92
+ )
93
+ self.logger.info("Load model weights for finetune...")
94
+
95
+ else:
96
+ raise ValueError("Resume_type must be `resume` or `finetune`.")
97
+
98
+ return checkpoint_path
99
+
100
+ def _build_model(self):
101
+ net_g = SynthesizerTrn(
102
+ self.cfg.preprocess.n_fft // 2 + 1,
103
+ self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
104
+ # directly use cfg
105
+ self.cfg,
106
+ )
107
+ net_d = MultiPeriodDiscriminator(self.cfg.model.vits.use_spectral_norm)
108
+ model = {"generator": net_g, "discriminator": net_d}
109
+
110
+ return model
111
+
112
+ def _build_dataset(self):
113
+ return SVCOfflineDataset, SVCOfflineCollator
114
+
115
+ def _build_optimizer(self):
116
+ optimizer_g = torch.optim.AdamW(
117
+ self.model["generator"].parameters(),
118
+ self.cfg.train.learning_rate,
119
+ betas=self.cfg.train.AdamW.betas,
120
+ eps=self.cfg.train.AdamW.eps,
121
+ )
122
+ optimizer_d = torch.optim.AdamW(
123
+ self.model["discriminator"].parameters(),
124
+ self.cfg.train.learning_rate,
125
+ betas=self.cfg.train.AdamW.betas,
126
+ eps=self.cfg.train.AdamW.eps,
127
+ )
128
+ optimizer = {"optimizer_g": optimizer_g, "optimizer_d": optimizer_d}
129
+
130
+ return optimizer
131
+
132
+ def _build_scheduler(self):
133
+ scheduler_g = ExponentialLR(
134
+ self.optimizer["optimizer_g"],
135
+ gamma=self.cfg.train.lr_decay,
136
+ last_epoch=self.epoch - 1,
137
+ )
138
+ scheduler_d = ExponentialLR(
139
+ self.optimizer["optimizer_d"],
140
+ gamma=self.cfg.train.lr_decay,
141
+ last_epoch=self.epoch - 1,
142
+ )
143
+
144
+ scheduler = {"scheduler_g": scheduler_g, "scheduler_d": scheduler_d}
145
+ return scheduler
146
+
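The two ``ExponentialLR`` schedulers above multiply each optimizer's learning rate by ``lr_decay`` once per epoch. A small self-contained sketch of that behaviour (the learning rate and decay values below are placeholders, not the project defaults):

import torch
from torch.optim.lr_scheduler import ExponentialLR

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=2e-4)
sched = ExponentialLR(opt, gamma=0.999)

for epoch in range(3):
    # ... one training epoch would run here ...
    sched.step()
    print(epoch, opt.param_groups[0]["lr"])  # lr shrinks by a factor of 0.999 per epoch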
147
+ def _build_criterion(self):
148
+ class GeneratorLoss(nn.Module):
149
+ def __init__(self, cfg):
150
+ super(GeneratorLoss, self).__init__()
151
+ self.cfg = cfg
152
+ self.l1_loss = nn.L1Loss()
153
+
154
+ def generator_loss(self, disc_outputs):
155
+ loss = 0
156
+ gen_losses = []
157
+ for dg in disc_outputs:
158
+ dg = dg.float()
159
+ l = torch.mean((1 - dg) ** 2)
160
+ gen_losses.append(l)
161
+ loss += l
162
+
163
+ return loss, gen_losses
164
+
165
+ def feature_loss(self, fmap_r, fmap_g):
166
+ loss = 0
167
+ for dr, dg in zip(fmap_r, fmap_g):
168
+ for rl, gl in zip(dr, dg):
169
+ rl = rl.float().detach()
170
+ gl = gl.float()
171
+ loss += torch.mean(torch.abs(rl - gl))
172
+
173
+ return loss * 2
174
+
175
+ def kl_loss(self, z_p, logs_q, m_p, logs_p, z_mask):
176
+ """
177
+ z_p, logs_q: [b, h, t_t]
178
+ m_p, logs_p: [b, h, t_t]
179
+ """
180
+ z_p = z_p.float()
181
+ logs_q = logs_q.float()
182
+ m_p = m_p.float()
183
+ logs_p = logs_p.float()
184
+ z_mask = z_mask.float()
185
+
186
+ kl = logs_p - logs_q - 0.5
187
+ kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
188
+ kl = torch.sum(kl * z_mask)
189
+ l = kl / torch.sum(z_mask)
190
+ return l
191
+
192
+ def forward(
193
+ self,
194
+ outputs_g,
195
+ outputs_d,
196
+ y_mel,
197
+ y_hat_mel,
198
+ ):
199
+ loss_g = {}
200
+
201
+ # mel loss
202
+ loss_mel = self.l1_loss(y_mel, y_hat_mel) * self.cfg.train.c_mel
203
+ loss_g["loss_mel"] = loss_mel
204
+
205
+ # kl loss
206
+ loss_kl = (
207
+ self.kl_loss(
208
+ outputs_g["z_p"],
209
+ outputs_g["logs_q"],
210
+ outputs_g["m_p"],
211
+ outputs_g["logs_p"],
212
+ outputs_g["z_mask"],
213
+ )
214
+ * self.cfg.train.c_kl
215
+ )
216
+ loss_g["loss_kl"] = loss_kl
217
+
218
+ # feature loss
219
+ loss_fm = self.feature_loss(outputs_d["fmap_rs"], outputs_d["fmap_gs"])
220
+ loss_g["loss_fm"] = loss_fm
221
+
222
+ # gan loss
223
+ loss_gen, losses_gen = self.generator_loss(outputs_d["y_d_hat_g"])
224
+ loss_g["loss_gen"] = loss_gen
225
+ loss_g["loss_gen_all"] = loss_mel + loss_kl + loss_fm + loss_gen
226
+
227
+ return loss_g
228
+
229
+ class DiscriminatorLoss(nn.Module):
230
+ def __init__(self, cfg):
231
+ super(DiscriminatorLoss, self).__init__()
232
+ self.cfg = cfg
233
+ self.l1Loss = torch.nn.L1Loss(reduction="mean")
234
+
235
+ def __call__(self, disc_real_outputs, disc_generated_outputs):
236
+ loss_d = {}
237
+
238
+ loss = 0
239
+ r_losses = []
240
+ g_losses = []
241
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
242
+ dr = dr.float()
243
+ dg = dg.float()
244
+ r_loss = torch.mean((1 - dr) ** 2)
245
+ g_loss = torch.mean(dg**2)
246
+ loss += r_loss + g_loss
247
+ r_losses.append(r_loss.item())
248
+ g_losses.append(g_loss.item())
249
+
250
+ loss_d["loss_disc_all"] = loss
251
+
252
+ return loss_d
253
+
254
+ criterion = {
255
+ "generator": GeneratorLoss(self.cfg),
256
+ "discriminator": DiscriminatorLoss(self.cfg),
257
+ }
258
+ return criterion
259
+
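A minimal sanity check of the least-squares GAN objectives defined in ``_build_criterion``, using dummy discriminator outputs (shapes and values are illustrative only):

import torch

# Dummy discriminator scores for two sub-discriminators.
real_outputs = [torch.full((2, 10), 0.9), torch.full((2, 10), 1.1)]
fake_outputs = [torch.full((2, 10), 0.1), torch.full((2, 10), -0.2)]

# Discriminator objective: push real scores towards 1 and fake scores towards 0.
d_loss = sum(torch.mean((1 - dr) ** 2) + torch.mean(dg ** 2)
             for dr, dg in zip(real_outputs, fake_outputs))

# Generator objective: push fake scores towards 1.
g_loss = sum(torch.mean((1 - dg) ** 2) for dg in fake_outputs)
print(d_loss.item(), g_loss.item())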
260
+ # Keep legacy unchanged
261
+ def write_summary(
262
+ self,
263
+ losses,
264
+ stats,
265
+ images={},
266
+ audios={},
267
+ audio_sampling_rate=24000,
268
+ tag="train",
269
+ ):
270
+ for key, value in losses.items():
271
+ self.sw.add_scalar(tag + "/" + key, value, self.step)
272
+ self.sw.add_scalar(
273
+ "learning_rate",
274
+ self.optimizer["optimizer_g"].param_groups[0]["lr"],
275
+ self.step,
276
+ )
277
+
278
+ if len(images) != 0:
279
+ for key, value in images.items():
280
+ self.sw.add_image(key, value, self.global_step, dataformats="HWC")
281
+ if len(audios) != 0:
282
+ for key, value in audios.items():
283
+ self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)
284
+
285
+ def write_valid_summary(
286
+ self, losses, stats, images={}, audios={}, audio_sampling_rate=24000, tag="val"
287
+ ):
288
+ for key, value in losses.items():
289
+ self.sw.add_scalar(tag + "/" + key, value, self.step)
290
+
291
+ if len(images) != 0:
292
+ for key, value in images.items():
293
+ self.sw.add_image(key, value, self.global_step, dataformats="HWC")
294
+ if len(audios) != 0:
295
+ for key, value in audios.items():
296
+ self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)
297
+
298
+ def _get_state_dict(self):
299
+ state_dict = {
300
+ "generator": self.model["generator"].state_dict(),
301
+ "discriminator": self.model["discriminator"].state_dict(),
302
+ "optimizer_g": self.optimizer["optimizer_g"].state_dict(),
303
+ "optimizer_d": self.optimizer["optimizer_d"].state_dict(),
304
+ "scheduler_g": self.scheduler["scheduler_g"].state_dict(),
305
+ "scheduler_d": self.scheduler["scheduler_d"].state_dict(),
306
+ "step": self.step,
307
+ "epoch": self.epoch,
308
+ "batch_size": self.cfg.train.batch_size,
309
+ }
310
+ return state_dict
311
+
312
+ def get_state_dict(self):
313
+ state_dict = {
314
+ "generator": self.model["generator"].state_dict(),
315
+ "discriminator": self.model["discriminator"].state_dict(),
316
+ "optimizer_g": self.optimizer["optimizer_g"].state_dict(),
317
+ "optimizer_d": self.optimizer["optimizer_d"].state_dict(),
318
+ "scheduler_g": self.scheduler["scheduler_g"].state_dict(),
319
+ "scheduler_d": self.scheduler["scheduler_d"].state_dict(),
320
+ "step": self.step,
321
+ "epoch": self.epoch,
322
+ "batch_size": self.cfg.train.batch_size,
323
+ }
324
+ return state_dict
325
+
326
+ def load_model(self, checkpoint):
327
+ self.step = checkpoint["step"]
328
+ self.epoch = checkpoint["epoch"]
329
+ self.model["generator"].load_state_dict(checkpoint["generator"])
330
+ self.model["discriminator"].load_state_dict(checkpoint["discriminator"])
331
+ self.optimizer["optimizer_g"].load_state_dict(checkpoint["optimizer_g"])
332
+ self.optimizer["optimizer_d"].load_state_dict(checkpoint["optimizer_d"])
333
+ self.scheduler["scheduler_g"].load_state_dict(checkpoint["scheduler_g"])
334
+ self.scheduler["scheduler_d"].load_state_dict(checkpoint["scheduler_d"])
335
+
336
+ @torch.inference_mode()
337
+ def _valid_step(self, batch):
338
+ r"""Testing forward step. Should return average loss of a sample over
339
+ one batch. Provoke ``_forward_step`` is recommended except for special case.
340
+ See ``_test_epoch`` for usage.
341
+ """
342
+
343
+ valid_losses = {}
344
+ total_loss = 0
345
+ valid_stats = {}
346
+
347
+ # Discriminator
348
+ # Generator output
349
+ outputs_g = self.model["generator"](batch)
350
+
351
+ y_mel = slice_segments(
352
+ batch["mel"].transpose(1, 2),
353
+ outputs_g["ids_slice"],
354
+ self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
355
+ )
356
+ y_hat_mel = mel_spectrogram_torch(
357
+ outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
358
+ )
359
+ y = slice_segments(
360
+ batch["audio"].unsqueeze(1),
361
+ outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
362
+ self.cfg.preprocess.segment_size,
363
+ )
364
+
365
+ # Discriminator output
366
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
367
+ ## Discriminator loss
368
+ loss_d = self.criterion["discriminator"](
369
+ outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
370
+ )
371
+ valid_losses.update(loss_d)
372
+
373
+ ## Generator
374
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
375
+ loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
376
+ valid_losses.update(loss_g)
377
+
378
+ for item in valid_losses:
379
+ valid_losses[item] = valid_losses[item].item()
380
+
381
+ total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]
382
+
383
+ return (
384
+ total_loss.item(),
385
+ valid_losses,
386
+ valid_stats,
387
+ )
388
+
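``slice_segments`` and ``mel_spectrogram_torch`` are imported from elsewhere in the repository. A rough sketch of the assumed ``slice_segments`` semantics (cutting one fixed-length window per batch item, starting at the frame given by ``ids_slice``):

import torch

def slice_segments_sketch(x, ids_str, segment_size):
    # x: (B, C, T); ids_str: (B,) start index per item. Assumed semantics only.
    out = torch.zeros(x.size(0), x.size(1), segment_size, dtype=x.dtype)
    for b in range(x.size(0)):
        start = int(ids_str[b])
        out[b] = x[b, :, start:start + segment_size]
    return out

x = torch.arange(2 * 1 * 20, dtype=torch.float32).view(2, 1, 20)
print(slice_segments_sketch(x, torch.tensor([3, 7]), segment_size=5))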
389
+ @torch.inference_mode()
390
+ def _valid_epoch(self):
391
+ r"""Testing epoch. Should return average loss of a batch (sample) over
392
+ one epoch. See ``train_loop`` for usage.
393
+ """
394
+ if isinstance(self.model, dict):
395
+ for key in self.model.keys():
396
+ self.model[key].eval()
397
+ else:
398
+ self.model.eval()
399
+
400
+ epoch_sum_loss = 0.0
401
+ epoch_losses = dict()
402
+ for batch in tqdm(
403
+ self.valid_dataloader,
404
+ desc=f"Validating Epoch {self.epoch}",
405
+ unit="batch",
406
+ colour="GREEN",
407
+ leave=False,
408
+ dynamic_ncols=True,
409
+ smoothing=0.04,
410
+ disable=not self.accelerator.is_main_process,
411
+ ):
412
+ total_loss, valid_losses, valid_stats = self._valid_step(batch)
413
+ epoch_sum_loss += total_loss
414
+ if isinstance(valid_losses, dict):
415
+ for key, value in valid_losses.items():
416
+ if key not in epoch_losses.keys():
417
+ epoch_losses[key] = value
418
+ else:
419
+ epoch_losses[key] += value
420
+
421
+ epoch_sum_loss = epoch_sum_loss / len(self.valid_dataloader)
422
+ for key in epoch_losses.keys():
423
+ epoch_losses[key] = epoch_losses[key] / len(self.valid_dataloader)
424
+
425
+ self.accelerator.wait_for_everyone()
426
+
427
+ return epoch_sum_loss, epoch_losses
428
+
429
+ ### THIS IS MAIN ENTRY ###
430
+ def train_loop(self):
431
+ r"""Training loop. The public entry of training process."""
432
+ # Wait everyone to prepare before we move on
433
+ self.accelerator.wait_for_everyone()
434
+ # dump config file
435
+ if self.accelerator.is_main_process:
436
+ self.__dump_cfg(self.config_save_path)
437
+
438
+ # self.optimizer.zero_grad()
439
+ # Wait to ensure good to go
440
+
441
+ self.accelerator.wait_for_everyone()
442
+ while self.epoch < self.max_epoch:
443
+ self.logger.info("\n")
444
+ self.logger.info("-" * 32)
445
+ self.logger.info("Epoch {}: ".format(self.epoch))
446
+
447
+ # Do training & validating epoch
448
+ train_total_loss, train_losses = self._train_epoch()
449
+ if isinstance(train_losses, dict):
450
+ for key, loss in train_losses.items():
451
+ self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss))
452
+ self.accelerator.log(
453
+ {"Epoch/Train {} Loss".format(key): loss},
454
+ step=self.epoch,
455
+ )
456
+
457
+ valid_total_loss, valid_losses = self._valid_epoch()
458
+ if isinstance(valid_losses, dict):
459
+ for key, loss in valid_losses.items():
460
+ self.logger.info(" |- Valid/{} Loss: {:.6f}".format(key, loss))
461
+ self.accelerator.log(
462
+ {"Epoch/Train {} Loss".format(key): loss},
463
+ step=self.epoch,
464
+ )
465
+
466
+ self.logger.info(" |- Train/Loss: {:.6f}".format(train_total_loss))
467
+ self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_total_loss))
468
+ self.accelerator.log(
469
+ {
470
+ "Epoch/Train Loss": train_total_loss,
471
+ "Epoch/Valid Loss": valid_total_loss,
472
+ },
473
+ step=self.epoch,
474
+ )
475
+
476
+ self.accelerator.wait_for_everyone()
477
+
478
+ # Check if hit save_checkpoint_stride and run_eval
479
+ run_eval = False
480
+ if self.accelerator.is_main_process:
481
+ save_checkpoint = False
482
+ hit_idx = []
483
+ for i, num in enumerate(self.save_checkpoint_stride):
484
+ if self.epoch % num == 0:
485
+ save_checkpoint = True
486
+ hit_idx.append(i)
487
+ run_eval |= self.run_eval[i]
488
+
489
+ self.accelerator.wait_for_everyone()
490
+ if self.accelerator.is_main_process and save_checkpoint:
491
+ path = os.path.join(
492
+ self.checkpoint_dir,
493
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
494
+ self.epoch, self.step, train_total_loss
495
+ ),
496
+ )
497
+ self.tmp_checkpoint_save_path = path
498
+ self.accelerator.save_state(path)
499
+
500
+ json.dump(
501
+ self.checkpoints_path,
502
+ open(os.path.join(path, "ckpts.json"), "w"),
503
+ ensure_ascii=False,
504
+ indent=4,
505
+ )
506
+ self._save_auxiliary_states()
507
+
508
+ # Remove old checkpoints
509
+ to_remove = []
510
+ for idx in hit_idx:
511
+ self.checkpoints_path[idx].append(path)
512
+ while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
513
+ to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
514
+
515
+ # Search conflicts
516
+ total = set()
517
+ for i in self.checkpoints_path:
518
+ total |= set(i)
519
+ do_remove = set()
520
+ for idx, path in to_remove[::-1]:
521
+ if path in total:
522
+ self.checkpoints_path[idx].insert(0, path)
523
+ else:
524
+ do_remove.add(path)
525
+
526
+ # Remove old checkpoints
527
+ for path in do_remove:
528
+ shutil.rmtree(path, ignore_errors=True)
529
+ self.logger.debug(f"Remove old checkpoint: {path}")
530
+
531
+ self.accelerator.wait_for_everyone()
532
+ if run_eval:
533
+ # TODO: run evaluation
534
+ pass
535
+
536
+ # Update info for each epoch
537
+ self.epoch += 1
538
+
539
+ # Finish training and save final checkpoint
540
+ self.accelerator.wait_for_everyone()
541
+ if self.accelerator.is_main_process:
542
+ path = os.path.join(
543
+ self.checkpoint_dir,
544
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
545
+ self.epoch, self.step, valid_total_loss
546
+ ),
547
+ )
548
+ self.tmp_checkpoint_save_path = path
549
+ self.accelerator.save_state(path)
557
+
558
+ json.dump(
559
+ self.checkpoints_path,
560
+ open(os.path.join(path, "ckpts.json"), "w"),
561
+ ensure_ascii=False,
562
+ indent=4,
563
+ )
564
+ self._save_auxiliary_states()
565
+
566
+ self.accelerator.end_training()
567
+
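The checkpoint bookkeeping in ``train_loop`` keeps only the newest ``keep_last[i]`` checkpoints per ``save_checkpoint_stride`` entry and never deletes a path that another stride still references. A small standalone sketch of that rotation policy with plain strings (the paths are hypothetical):

# Two strides track their own checkpoint lists; "ckpt_old" is shared.
checkpoints_path = [["ckpt_old", "ckpt_new"], ["ckpt_old"]]
keep_last = [1, 1]

to_remove = []
for idx, paths in enumerate(checkpoints_path):
    while len(paths) > keep_last[idx]:
        to_remove.append((idx, paths.pop(0)))

# Only remove paths no longer referenced by any stride.
still_used = set()
for paths in checkpoints_path:
    still_used |= set(paths)
do_remove = {p for _, p in to_remove if p not in still_used}
print(do_remove)  # set() -- 'ckpt_old' survives because stride 1 still references it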
568
+ def _train_step(self, batch):
569
+ r"""Forward step for training and inference. This function is called
570
+ in ``_train_step`` & ``_test_step`` function.
571
+ """
572
+
573
+ train_losses = {}
574
+ total_loss = 0
575
+ training_stats = {}
576
+
577
+ ## Train Discriminator
578
+ # Generator output
579
+ outputs_g = self.model["generator"](batch)
580
+
581
+ y_mel = slice_segments(
582
+ batch["mel"].transpose(1, 2),
583
+ outputs_g["ids_slice"],
584
+ self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
585
+ )
586
+ y_hat_mel = mel_spectrogram_torch(
587
+ outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
588
+ )
589
+
590
+ y = slice_segments(
591
+ # [1, 168418] -> [1, 1, 168418]
592
+ batch["audio"].unsqueeze(1),
593
+ outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
594
+ self.cfg.preprocess.segment_size,
595
+ )
596
+
597
+ # Discriminator output
598
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
599
+ # Discriminator loss
600
+ loss_d = self.criterion["discriminator"](
601
+ outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
602
+ )
603
+ train_losses.update(loss_d)
604
+
605
+ # BP and Grad Updated
606
+ self.optimizer["optimizer_d"].zero_grad()
607
+ self.accelerator.backward(loss_d["loss_disc_all"])
608
+ self.optimizer["optimizer_d"].step()
609
+
610
+ ## Train Generator
611
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
612
+ loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
613
+ train_losses.update(loss_g)
614
+
615
+ # BP and Grad Updated
616
+ self.optimizer["optimizer_g"].zero_grad()
617
+ self.accelerator.backward(loss_g["loss_gen_all"])
618
+ self.optimizer["optimizer_g"].step()
619
+
620
+ for item in train_losses:
621
+ train_losses[item] = train_losses[item].item()
622
+
623
+ total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]
624
+
625
+ return (
626
+ total_loss.item(),
627
+ train_losses,
628
+ training_stats,
629
+ )
630
+
631
+ def _train_epoch(self):
632
+ r"""Training epoch. Should return average loss of a batch (sample) over
633
+ one epoch. See ``train_loop`` for usage.
634
+ """
635
+ epoch_sum_loss: float = 0.0
636
+ epoch_losses: dict = {}
637
+ epoch_step: int = 0
638
+ for batch in tqdm(
639
+ self.train_dataloader,
640
+ desc=f"Training Epoch {self.epoch}",
641
+ unit="batch",
642
+ colour="GREEN",
643
+ leave=False,
644
+ dynamic_ncols=True,
645
+ smoothing=0.04,
646
+ disable=not self.accelerator.is_main_process,
647
+ ):
648
+ # Do training step and BP
649
+ with self.accelerator.accumulate(self.model):
650
+ total_loss, train_losses, training_stats = self._train_step(batch)
651
+ self.batch_count += 1
652
+
653
+ # Update info for each step
654
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
655
+ epoch_sum_loss += total_loss
656
+ for key, value in train_losses.items():
657
+ if key not in epoch_losses.keys():
658
+ epoch_losses[key] = value
659
+ else:
660
+ epoch_losses[key] += value
661
+
662
+ self.accelerator.log(
663
+ {
664
+ "Step/Generator Loss": train_losses["loss_gen_all"],
665
+ "Step/Discriminator Loss": train_losses["loss_disc_all"],
666
+ "Step/Generator Learning Rate": self.optimizer[
667
+ "optimizer_d"
668
+ ].param_groups[0]["lr"],
669
+ "Step/Discriminator Learning Rate": self.optimizer[
670
+ "optimizer_g"
671
+ ].param_groups[0]["lr"],
672
+ },
673
+ step=self.step,
674
+ )
675
+ self.step += 1
676
+ epoch_step += 1
677
+
678
+ self.accelerator.wait_for_everyone()
679
+
680
+ epoch_sum_loss = (
681
+ epoch_sum_loss
682
+ / len(self.train_dataloader)
683
+ * self.cfg.train.gradient_accumulation_step
684
+ )
685
+
686
+ for key in epoch_losses.keys():
687
+ epoch_losses[key] = (
688
+ epoch_losses[key]
689
+ / len(self.train_dataloader)
690
+ * self.cfg.train.gradient_accumulation_step
691
+ )
692
+
693
+ return epoch_sum_loss, epoch_losses
694
+
695
+ def __dump_cfg(self, path):
696
+ os.makedirs(os.path.dirname(path), exist_ok=True)
697
+ json5.dump(
698
+ self.cfg,
699
+ open(path, "w"),
700
+ indent=4,
701
+ sort_keys=True,
702
+ ensure_ascii=False,
703
+ quote_keys=True,
704
+ )
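``__dump_cfg`` above serialises the experiment config with ``json5``. A tiny round-trip sketch (the file name and keys are placeholders, not this project's actual config schema):

import json5

cfg = {"train": {"batch_size": 16, "learning_rate": 2e-4}}
with open("cfg_example.json", "w") as f:
    json5.dump(cfg, f, indent=4, quote_keys=True)

with open("cfg_example.json", "r") as f:
    print(json5.load(f)["train"]["batch_size"])  # 16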
Amphion/models/tta/autoencoder/autoencoder_dataset.py ADDED
@@ -0,0 +1,112 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import random
7
+ import torch
8
+ from torch.nn.utils.rnn import pad_sequence
9
+ from utils.data_utils import *
10
+ from models.base.base_dataset import (
11
+ BaseOfflineCollator,
12
+ BaseOfflineDataset,
13
+ BaseTestDataset,
14
+ BaseTestCollator,
15
+ )
16
+ import librosa
17
+
18
+
19
+ class AutoencoderKLDataset(BaseOfflineDataset):
20
+ def __init__(self, cfg, dataset, is_valid=False):
21
+ BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)
22
+
23
+ cfg = self.cfg
24
+
25
+ # utt2melspec
26
+ if cfg.preprocess.use_melspec:
27
+ self.utt2melspec_path = {}
28
+ for utt_info in self.metadata:
29
+ dataset = utt_info["Dataset"]
30
+ uid = utt_info["Uid"]
31
+ utt = "{}_{}".format(dataset, uid)
32
+
33
+ self.utt2melspec_path[utt] = os.path.join(
34
+ cfg.preprocess.processed_dir,
35
+ dataset,
36
+ cfg.preprocess.melspec_dir,
37
+ uid + ".npy",
38
+ )
39
+
40
+ # utt2wav
41
+ if cfg.preprocess.use_wav:
42
+ self.utt2wav_path = {}
43
+ for utt_info in self.metadata:
44
+ dataset = utt_info["Dataset"]
45
+ uid = utt_info["Uid"]
46
+ utt = "{}_{}".format(dataset, uid)
47
+
48
+ self.utt2wav_path[utt] = os.path.join(
49
+ cfg.preprocess.processed_dir,
50
+ dataset,
51
+ cfg.preprocess.wav_dir,
52
+ uid + ".wav",
53
+ )
54
+
55
+ def __getitem__(self, index):
56
+ # melspec: (n_mels, T)
57
+ # wav: (T,)
58
+
59
+ single_feature = BaseOfflineDataset.__getitem__(self, index)
60
+
61
+ utt_info = self.metadata[index]
62
+ dataset = utt_info["Dataset"]
63
+ uid = utt_info["Uid"]
64
+ utt = "{}_{}".format(dataset, uid)
65
+
66
+ if self.cfg.preprocess.use_melspec:
67
+ single_feature["melspec"] = np.load(self.utt2melspec_path[utt])
68
+
69
+ if self.cfg.preprocess.use_wav:
70
+ wav, sr = librosa.load(
71
+ self.utt2wav_path[utt], sr=16000
72
+ ) # hard coding for 16KHz...
73
+ single_feature["wav"] = wav
74
+
75
+ return single_feature
76
+
77
+ def __len__(self):
78
+ return len(self.metadata)
82
+
83
+
84
+ class AutoencoderKLCollator(BaseOfflineCollator):
85
+ def __init__(self, cfg):
86
+ BaseOfflineCollator.__init__(self, cfg)
87
+
88
+ def __call__(self, batch):
89
+ # mel: (B, n_mels, T)
90
+ # wav (option): (B, T)
91
+
92
+ packed_batch_features = dict()
93
+
94
+ for key in batch[0].keys():
95
+ if key == "melspec":
96
+ packed_batch_features["melspec"] = torch.from_numpy(
97
+ np.array([b["melspec"][:, :624] for b in batch])
98
+ )
99
+
100
+ if key == "wav":
101
+ values = [torch.from_numpy(b[key]) for b in batch]
102
+ packed_batch_features[key] = pad_sequence(
103
+ values, batch_first=True, padding_value=0
104
+ )
105
+
106
+ return packed_batch_features
107
+
108
+
109
+ class AutoencoderKLTestDataset(BaseTestDataset): ...
110
+
111
+
112
+ class AutoencoderKLTestCollator(BaseTestCollator): ...
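A minimal sketch of the collation pattern used by ``AutoencoderKLCollator`` above: mel-spectrograms are cropped to a fixed number of frames and stacked, while raw waveforms of different lengths are zero-padded (the 624-frame crop and the feature sizes below are illustrative):

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

batch = [
    {"melspec": np.random.randn(80, 700), "wav": np.random.randn(16000)},
    {"melspec": np.random.randn(80, 650), "wav": np.random.randn(12000)},
]

# Crop every mel to the same number of frames (assumes each has >= 624 frames).
mels = torch.from_numpy(np.array([b["melspec"][:, :624] for b in batch]))
# Zero-pad waveforms to the longest one in the batch.
wavs = pad_sequence([torch.from_numpy(b["wav"]) for b in batch],
                    batch_first=True, padding_value=0)
print(mels.shape, wavs.shape)  # (2, 80, 624), (2, 16000)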
Amphion/models/tta/ldm/__init__.py ADDED
File without changes
Amphion/models/tta/ldm/audioldm_dataset.py ADDED
@@ -0,0 +1,151 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import random
7
+ import torch
8
+ from torch.nn.utils.rnn import pad_sequence
9
+ from utils.data_utils import *
10
+
11
+
12
+ from models.base.base_dataset import (
13
+ BaseOfflineCollator,
14
+ BaseOfflineDataset,
15
+ BaseTestDataset,
16
+ BaseTestCollator,
17
+ )
18
+ import librosa
19
+
20
+ from transformers import AutoTokenizer
21
+
22
+
23
+ class AudioLDMDataset(BaseOfflineDataset):
24
+ def __init__(self, cfg, dataset, is_valid=False):
25
+ BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)
26
+
27
+ self.cfg = cfg
28
+
29
+ # utt2melspec
30
+ if cfg.preprocess.use_melspec:
31
+ self.utt2melspec_path = {}
32
+ for utt_info in self.metadata:
33
+ dataset = utt_info["Dataset"]
34
+ uid = utt_info["Uid"]
35
+ utt = "{}_{}".format(dataset, uid)
36
+
37
+ self.utt2melspec_path[utt] = os.path.join(
38
+ cfg.preprocess.processed_dir,
39
+ dataset,
40
+ cfg.preprocess.melspec_dir,
41
+ uid + ".npy",
42
+ )
43
+
44
+ # utt2wav
45
+ if cfg.preprocess.use_wav:
46
+ self.utt2wav_path = {}
47
+ for utt_info in self.metadata:
48
+ dataset = utt_info["Dataset"]
49
+ uid = utt_info["Uid"]
50
+ utt = "{}_{}".format(dataset, uid)
51
+
52
+ self.utt2wav_path[utt] = os.path.join(
53
+ cfg.preprocess.processed_dir,
54
+ dataset,
55
+ cfg.preprocess.wav_dir,
56
+ uid + ".wav",
57
+ )
58
+
59
+ # utt2caption
60
+ if cfg.preprocess.use_caption:
61
+ self.utt2caption = {}
62
+ for utt_info in self.metadata:
63
+ dataset = utt_info["Dataset"]
64
+ uid = utt_info["Uid"]
65
+ utt = "{}_{}".format(dataset, uid)
66
+
67
+ self.utt2caption[utt] = utt_info["Caption"]
68
+
69
+ def __getitem__(self, index):
70
+ # melspec: (n_mels, T)
71
+ # wav: (T,)
72
+
73
+ single_feature = BaseOfflineDataset.__getitem__(self, index)
74
+
75
+ utt_info = self.metadata[index]
76
+ dataset = utt_info["Dataset"]
77
+ uid = utt_info["Uid"]
78
+ utt = "{}_{}".format(dataset, uid)
79
+
80
+ if self.cfg.preprocess.use_melspec:
81
+ single_feature["melspec"] = np.load(self.utt2melspec_path[utt])
82
+
83
+ if self.cfg.preprocess.use_wav:
84
+ wav, sr = librosa.load(
85
+ self.utt2wav_path[utt], sr=16000
86
+ ) # hard coding for 16KHz...
87
+ single_feature["wav"] = wav
88
+
89
+ if self.cfg.preprocess.use_caption:
90
+ cond_mask = np.random.choice(
91
+ [1, 0],
92
+ p=[
93
+ self.cfg.preprocess.cond_mask_prob,
94
+ 1 - self.cfg.preprocess.cond_mask_prob,
95
+ ],
96
+ ) # (0.1, 0.9)
97
+ if cond_mask:
98
+ single_feature["caption"] = ""
99
+ else:
100
+ single_feature["caption"] = self.utt2caption[utt]
101
+
102
+ return single_feature
103
+
104
+ def __len__(self):
105
+ return len(self.metadata)
106
+
107
+
108
+ class AudioLDMCollator(BaseOfflineCollator):
109
+ def __init__(self, cfg):
110
+ BaseOfflineCollator.__init__(self, cfg)
111
+
112
+ self.tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
113
+
114
+ def __call__(self, batch):
115
+ # mel: (B, n_mels, T)
116
+ # wav (option): (B, T)
117
+ # text_input_ids: (B, L)
118
+ # text_attention_mask: (B, L)
119
+
120
+ packed_batch_features = dict()
121
+
122
+ for key in batch[0].keys():
123
+ if key == "melspec":
124
+ packed_batch_features["melspec"] = torch.from_numpy(
125
+ np.array([b["melspec"][:, :624] for b in batch])
126
+ )
127
+
128
+ if key == "wav":
129
+ values = [torch.from_numpy(b[key]) for b in batch]
130
+ packed_batch_features[key] = pad_sequence(
131
+ values, batch_first=True, padding_value=0
132
+ )
133
+
134
+ if key == "caption":
135
+ captions = [b[key] for b in batch]
136
+ text_input = self.tokenizer(
137
+ captions, return_tensors="pt", truncation=True, padding="longest"
138
+ )
139
+ text_input_ids = text_input["input_ids"]
140
+ text_attention_mask = text_input["attention_mask"]
141
+
142
+ packed_batch_features["text_input_ids"] = text_input_ids
143
+ packed_batch_features["text_attention_mask"] = text_attention_mask
144
+
145
+ return packed_batch_features
146
+
147
+
148
+ class AudioLDMTestDataset(BaseTestDataset): ...
149
+
150
+
151
+ class AudioLDMTestCollator(BaseTestCollator): ...
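A small sketch of the two caption-related steps above: randomly blanking a caption with probability ``cond_mask_prob`` (classifier-free-guidance-style conditioning dropout) and batch-tokenizing captions with the T5 tokenizer (the probability and captions are placeholders):

import numpy as np
from transformers import AutoTokenizer

cond_mask_prob = 0.1
caption = "a dog barking in the distance"
if np.random.choice([1, 0], p=[cond_mask_prob, 1 - cond_mask_prob]):
    caption = ""  # drop the text condition for this sample

tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
text_input = tokenizer([caption, "rain on a tin roof"],
                       return_tensors="pt", truncation=True, padding="longest")
print(text_input["input_ids"].shape, text_input["attention_mask"].shape)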
Amphion/models/tta/ldm/audioldm_trainer.py ADDED
@@ -0,0 +1,251 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from models.base.base_trainer import BaseTrainer
7
+ from diffusers import DDPMScheduler
8
+ from models.tta.ldm.audioldm_dataset import AudioLDMDataset, AudioLDMCollator
9
+ from models.tta.autoencoder.autoencoder import AutoencoderKL
10
+ from models.tta.ldm.audioldm import AudioLDM, UNetModel
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import MSELoss, L1Loss
14
+ import torch.nn.functional as F
15
+ from torch.utils.data import ConcatDataset, DataLoader
16
+
17
+ from transformers import T5EncoderModel
19
+
20
+
21
+ class AudioLDMTrainer(BaseTrainer):
22
+ def __init__(self, args, cfg):
23
+ BaseTrainer.__init__(self, args, cfg)
24
+ self.cfg = cfg
25
+
26
+ self.build_autoencoderkl()
27
+ self.build_textencoder()
28
+ self.noise_scheduler = self.build_noise_scheduler()
29
+
30
+ self.save_config_file()
31
+
32
+ def build_autoencoderkl(self):
33
+ self.autoencoderkl = AutoencoderKL(self.cfg.model.autoencoderkl)
34
+ self.autoencoder_path = self.cfg.model.autoencoder_path
35
+ checkpoint = torch.load(self.autoencoder_path, map_location="cpu")
36
+ self.autoencoderkl.load_state_dict(checkpoint["model"])
37
+ self.autoencoderkl.cuda(self.args.local_rank)
38
+ self.autoencoderkl.requires_grad_(requires_grad=False)
39
+ self.autoencoderkl.eval()
40
+
41
+ def build_textencoder(self):
42
+ self.text_encoder = T5EncoderModel.from_pretrained("t5-base")
43
+ self.text_encoder.cuda(self.args.local_rank)
44
+ self.text_encoder.requires_grad_(requires_grad=False)
45
+ self.text_encoder.eval()
46
+
47
+ def build_noise_scheduler(self):
48
+ noise_scheduler = DDPMScheduler(
49
+ num_train_timesteps=self.cfg.model.noise_scheduler.num_train_timesteps,
50
+ beta_start=self.cfg.model.noise_scheduler.beta_start,
51
+ beta_end=self.cfg.model.noise_scheduler.beta_end,
52
+ beta_schedule=self.cfg.model.noise_scheduler.beta_schedule,
53
+ clip_sample=self.cfg.model.noise_scheduler.clip_sample,
54
+ # steps_offset=self.cfg.model.noise_scheduler.steps_offset,
55
+ # set_alpha_to_one=self.cfg.model.noise_scheduler.set_alpha_to_one,
56
+ # skip_prk_steps=self.cfg.model.noise_scheduler.skip_prk_steps,
57
+ prediction_type=self.cfg.model.noise_scheduler.prediction_type,
58
+ )
59
+ return noise_scheduler
60
+
61
+ def build_dataset(self):
62
+ return AudioLDMDataset, AudioLDMCollator
63
+
64
+ def build_data_loader(self):
65
+ Dataset, Collator = self.build_dataset()
66
+ # build dataset instance for each dataset and combine them by ConcatDataset
67
+ datasets_list = []
68
+ for dataset in self.cfg.dataset:
69
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
70
+ datasets_list.append(subdataset)
71
+ train_dataset = ConcatDataset(datasets_list)
72
+
73
+ train_collate = Collator(self.cfg)
74
+
75
+ # use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size)
76
+ train_loader = DataLoader(
77
+ train_dataset,
78
+ collate_fn=train_collate,
79
+ num_workers=self.args.num_workers,
80
+ batch_size=self.cfg.train.batch_size,
81
+ pin_memory=False,
82
+ )
83
+ if not self.cfg.train.ddp or self.args.local_rank == 0:
84
+ datasets_list = []
85
+ for dataset in self.cfg.dataset:
86
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
87
+ datasets_list.append(subdataset)
88
+ valid_dataset = ConcatDataset(datasets_list)
89
+ valid_collate = Collator(self.cfg)
90
+
91
+ valid_loader = DataLoader(
92
+ valid_dataset,
93
+ collate_fn=valid_collate,
94
+ num_workers=1,
95
+ batch_size=self.cfg.train.batch_size,
96
+ )
97
+ else:
98
+ raise NotImplementedError("DDP is not supported yet.")
99
+ # valid_loader = None
100
+ data_loader = {"train": train_loader, "valid": valid_loader}
101
+ return data_loader
102
+
103
+ def build_optimizer(self):
104
+ optimizer = torch.optim.AdamW(self.model.parameters(), **self.cfg.train.adam)
105
+ return optimizer
106
+
107
+ # TODO: check it...
108
+ def build_scheduler(self):
109
+ return None
110
+ # return ReduceLROnPlateau(self.optimizer["opt_ae"], **self.cfg.train.lronPlateau)
111
+
112
+ def write_summary(self, losses, stats):
113
+ for key, value in losses.items():
114
+ self.sw.add_scalar(key, value, self.step)
115
+
116
+ def write_valid_summary(self, losses, stats):
117
+ for key, value in losses.items():
118
+ self.sw.add_scalar(key, value, self.step)
119
+
120
+ def build_criterion(self):
121
+ criterion = nn.MSELoss(reduction="mean")
122
+ return criterion
123
+
124
+ def get_state_dict(self):
125
+ if self.scheduler != None:
126
+ state_dict = {
127
+ "model": self.model.state_dict(),
128
+ "optimizer": self.optimizer.state_dict(),
129
+ "scheduler": self.scheduler.state_dict(),
130
+ "step": self.step,
131
+ "epoch": self.epoch,
132
+ "batch_size": self.cfg.train.batch_size,
133
+ }
134
+ else:
135
+ state_dict = {
136
+ "model": self.model.state_dict(),
137
+ "optimizer": self.optimizer.state_dict(),
138
+ "step": self.step,
139
+ "epoch": self.epoch,
140
+ "batch_size": self.cfg.train.batch_size,
141
+ }
142
+ return state_dict
143
+
144
+ def load_model(self, checkpoint):
145
+ self.step = checkpoint["step"]
146
+ self.epoch = checkpoint["epoch"]
147
+
148
+ self.model.load_state_dict(checkpoint["model"])
149
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
150
+ if self.scheduler != None:
151
+ self.scheduler.load_state_dict(checkpoint["scheduler"])
152
+
153
+ def build_model(self):
154
+ self.model = AudioLDM(self.cfg.model.audioldm)
155
+ return self.model
156
+
157
+ @torch.no_grad()
158
+ def mel_to_latent(self, melspec):
159
+ posterior = self.autoencoderkl.encode(melspec)
160
+ latent = posterior.sample() # (B, 4, 5, 78)
161
+ return latent
162
+
163
+ @torch.no_grad()
164
+ def get_text_embedding(self, text_input_ids, text_attention_mask):
165
+ text_embedding = self.text_encoder(
166
+ input_ids=text_input_ids, attention_mask=text_attention_mask
167
+ ).last_hidden_state
168
+ return text_embedding # (B, T, 768)
169
+
170
+ def train_step(self, data):
171
+ train_losses = {}
172
+ total_loss = 0
173
+ train_stats = {}
174
+
175
+ melspec = data["melspec"].unsqueeze(1) # (B, 80, T) -> (B, 1, 80, T)
176
+ latents = self.mel_to_latent(melspec)
177
+
178
+ text_embedding = self.get_text_embedding(
179
+ data["text_input_ids"], data["text_attention_mask"]
180
+ )
181
+
182
+ noise = torch.randn_like(latents).float()
183
+
184
+ bsz = latents.shape[0]
185
+ timesteps = torch.randint(
186
+ 0,
187
+ self.cfg.model.noise_scheduler.num_train_timesteps,
188
+ (bsz,),
189
+ device=latents.device,
190
+ )
191
+ timesteps = timesteps.long()
192
+
193
+ with torch.no_grad():
194
+ noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
195
+
196
+ model_pred = self.model(
197
+ noisy_latents, timesteps=timesteps, context=text_embedding
198
+ )
199
+
200
+ loss = self.criterion(model_pred, noise)
201
+
202
+ train_losses["loss"] = loss
203
+ total_loss += loss
204
+
205
+ self.optimizer.zero_grad()
206
+ total_loss.backward()
207
+ self.optimizer.step()
208
+
209
+ for item in train_losses:
210
+ train_losses[item] = train_losses[item].item()
211
+
212
+ return train_losses, train_stats, total_loss.item()
213
+
214
+ # TODO: eval step
215
+ @torch.no_grad()
216
+ def eval_step(self, data, index):
217
+ valid_loss = {}
218
+ total_valid_loss = 0
219
+ valid_stats = {}
220
+
221
+ melspec = data["melspec"].unsqueeze(1) # (B, 80, T) -> (B, 1, 80, T)
222
+ latents = self.mel_to_latent(melspec)
223
+
224
+ text_embedding = self.get_text_embedding(
225
+ data["text_input_ids"], data["text_attention_mask"]
226
+ )
227
+
228
+ noise = torch.randn_like(latents).float()
229
+
230
+ bsz = latents.shape[0]
231
+ timesteps = torch.randint(
232
+ 0,
233
+ self.cfg.model.noise_scheduler.num_train_timesteps,
234
+ (bsz,),
235
+ device=latents.device,
236
+ )
237
+ timesteps = timesteps.long()
238
+
239
+ noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
240
+
241
+ model_pred = self.model(noisy_latents, timesteps, text_embedding)
242
+
243
+ loss = self.criterion(model_pred, noise)
244
+ valid_loss["loss"] = loss
245
+
246
+ total_valid_loss += loss
247
+
248
+ for item in valid_loss:
249
+ valid_loss[item] = valid_loss[item].item()
250
+
251
+ return valid_loss, valid_stats, total_valid_loss.item()
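The training and evaluation steps above follow the standard denoising objective: add noise to the latents at a random timestep and regress the model output onto that noise with MSE. A toy, self-contained sketch with a stand-in network (the scheduler settings and the tiny conv model are placeholders, not the project config):

import torch
import torch.nn as nn
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
model = nn.Conv2d(4, 4, kernel_size=3, padding=1)  # stand-in for the UNet
criterion = nn.MSELoss(reduction="mean")

latents = torch.randn(2, 4, 5, 78)                 # fake VAE latents
noise = torch.randn_like(latents)
timesteps = torch.randint(0, 1000, (latents.shape[0],)).long()

noisy_latents = scheduler.add_noise(latents, noise, timesteps)
loss = criterion(model(noisy_latents), noise)      # predict the added noise
loss.backward()
print(loss.item())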
Amphion/models/tta/ldm/inference_utils/vocoder.py ADDED
@@ -0,0 +1,408 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.nn as nn
9
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
11
+ from models.tta.ldm.inference_utils.utils import get_padding, init_weights
12
+
13
+ LRELU_SLOPE = 0.1
14
+
15
+
16
+ class ResBlock1(torch.nn.Module):
17
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
18
+ super(ResBlock1, self).__init__()
19
+ self.h = h
20
+ self.convs1 = nn.ModuleList(
21
+ [
22
+ weight_norm(
23
+ Conv1d(
24
+ channels,
25
+ channels,
26
+ kernel_size,
27
+ 1,
28
+ dilation=dilation[0],
29
+ padding=get_padding(kernel_size, dilation[0]),
30
+ )
31
+ ),
32
+ weight_norm(
33
+ Conv1d(
34
+ channels,
35
+ channels,
36
+ kernel_size,
37
+ 1,
38
+ dilation=dilation[1],
39
+ padding=get_padding(kernel_size, dilation[1]),
40
+ )
41
+ ),
42
+ weight_norm(
43
+ Conv1d(
44
+ channels,
45
+ channels,
46
+ kernel_size,
47
+ 1,
48
+ dilation=dilation[2],
49
+ padding=get_padding(kernel_size, dilation[2]),
50
+ )
51
+ ),
52
+ ]
53
+ )
54
+ self.convs1.apply(init_weights)
55
+
56
+ self.convs2 = nn.ModuleList(
57
+ [
58
+ weight_norm(
59
+ Conv1d(
60
+ channels,
61
+ channels,
62
+ kernel_size,
63
+ 1,
64
+ dilation=1,
65
+ padding=get_padding(kernel_size, 1),
66
+ )
67
+ ),
68
+ weight_norm(
69
+ Conv1d(
70
+ channels,
71
+ channels,
72
+ kernel_size,
73
+ 1,
74
+ dilation=1,
75
+ padding=get_padding(kernel_size, 1),
76
+ )
77
+ ),
78
+ weight_norm(
79
+ Conv1d(
80
+ channels,
81
+ channels,
82
+ kernel_size,
83
+ 1,
84
+ dilation=1,
85
+ padding=get_padding(kernel_size, 1),
86
+ )
87
+ ),
88
+ ]
89
+ )
90
+ self.convs2.apply(init_weights)
91
+
92
+ def forward(self, x):
93
+ for c1, c2 in zip(self.convs1, self.convs2):
94
+ xt = F.leaky_relu(x, LRELU_SLOPE)
95
+ xt = c1(xt)
96
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
97
+ xt = c2(xt)
98
+ x = xt + x
99
+ return x
100
+
101
+ def remove_weight_norm(self):
102
+ for l in self.convs1:
103
+ remove_weight_norm(l)
104
+ for l in self.convs2:
105
+ remove_weight_norm(l)
106
+
107
+
108
+ class ResBlock2(torch.nn.Module):
109
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
110
+ super(ResBlock2, self).__init__()
111
+ self.h = h
112
+ self.convs = nn.ModuleList(
113
+ [
114
+ weight_norm(
115
+ Conv1d(
116
+ channels,
117
+ channels,
118
+ kernel_size,
119
+ 1,
120
+ dilation=dilation[0],
121
+ padding=get_padding(kernel_size, dilation[0]),
122
+ )
123
+ ),
124
+ weight_norm(
125
+ Conv1d(
126
+ channels,
127
+ channels,
128
+ kernel_size,
129
+ 1,
130
+ dilation=dilation[1],
131
+ padding=get_padding(kernel_size, dilation[1]),
132
+ )
133
+ ),
134
+ ]
135
+ )
136
+ self.convs.apply(init_weights)
137
+
138
+ def forward(self, x):
139
+ for c in self.convs:
140
+ xt = F.leaky_relu(x, LRELU_SLOPE)
141
+ xt = c(xt)
142
+ x = xt + x
143
+ return x
144
+
145
+ def remove_weight_norm(self):
146
+ for l in self.convs:
147
+ remove_weight_norm(l)
148
+
149
+
150
+ class Generator(torch.nn.Module):
151
+ def __init__(self, h):
152
+ super(Generator, self).__init__()
153
+ self.h = h
154
+ self.num_kernels = len(h.resblock_kernel_sizes)
155
+ self.num_upsamples = len(h.upsample_rates)
156
+ self.conv_pre = weight_norm(
157
+ Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
158
+ )
159
+ resblock = ResBlock1 if h.resblock == "1" else ResBlock2
160
+
161
+ self.ups = nn.ModuleList()
162
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
163
+ self.ups.append(
164
+ weight_norm(
165
+ ConvTranspose1d(
166
+ h.upsample_initial_channel // (2**i),
167
+ h.upsample_initial_channel // (2 ** (i + 1)),
168
+ k,
169
+ u,
170
+ padding=(k - u) // 2,
171
+ )
172
+ )
173
+ )
174
+
175
+ self.resblocks = nn.ModuleList()
176
+ for i in range(len(self.ups)):
177
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
178
+ for j, (k, d) in enumerate(
179
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
180
+ ):
181
+ self.resblocks.append(resblock(h, ch, k, d))
182
+
183
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
184
+ self.ups.apply(init_weights)
185
+ self.conv_post.apply(init_weights)
186
+
187
+ def forward(self, x):
188
+ x = self.conv_pre(x)
189
+ for i in range(self.num_upsamples):
190
+ x = F.leaky_relu(x, LRELU_SLOPE)
191
+ x = self.ups[i](x)
192
+ xs = None
193
+ for j in range(self.num_kernels):
194
+ if xs is None:
195
+ xs = self.resblocks[i * self.num_kernels + j](x)
196
+ else:
197
+ xs += self.resblocks[i * self.num_kernels + j](x)
198
+ x = xs / self.num_kernels
199
+ x = F.leaky_relu(x)
200
+ x = self.conv_post(x)
201
+ x = torch.tanh(x)
202
+
203
+ return x
204
+
205
+ def remove_weight_norm(self):
206
+ print("Removing weight norm...")
207
+ for l in self.ups:
208
+ remove_weight_norm(l)
209
+ for l in self.resblocks:
210
+ l.remove_weight_norm()
211
+ remove_weight_norm(self.conv_pre)
212
+ remove_weight_norm(self.conv_post)
213
+
214
+
215
+ class DiscriminatorP(torch.nn.Module):
216
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
217
+ super(DiscriminatorP, self).__init__()
218
+ self.period = period
219
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
220
+ self.convs = nn.ModuleList(
221
+ [
222
+ norm_f(
223
+ Conv2d(
224
+ 1,
225
+ 32,
226
+ (kernel_size, 1),
227
+ (stride, 1),
228
+ padding=(get_padding(5, 1), 0),
229
+ )
230
+ ),
231
+ norm_f(
232
+ Conv2d(
233
+ 32,
234
+ 128,
235
+ (kernel_size, 1),
236
+ (stride, 1),
237
+ padding=(get_padding(5, 1), 0),
238
+ )
239
+ ),
240
+ norm_f(
241
+ Conv2d(
242
+ 128,
243
+ 512,
244
+ (kernel_size, 1),
245
+ (stride, 1),
246
+ padding=(get_padding(5, 1), 0),
247
+ )
248
+ ),
249
+ norm_f(
250
+ Conv2d(
251
+ 512,
252
+ 1024,
253
+ (kernel_size, 1),
254
+ (stride, 1),
255
+ padding=(get_padding(5, 1), 0),
256
+ )
257
+ ),
258
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
259
+ ]
260
+ )
261
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
262
+
263
+ def forward(self, x):
264
+ fmap = []
265
+
266
+ # 1d to 2d
267
+ b, c, t = x.shape
268
+ if t % self.period != 0: # pad first
269
+ n_pad = self.period - (t % self.period)
270
+ x = F.pad(x, (0, n_pad), "reflect")
271
+ t = t + n_pad
272
+ x = x.view(b, c, t // self.period, self.period)
273
+
274
+ for l in self.convs:
275
+ x = l(x)
276
+ x = F.leaky_relu(x, LRELU_SLOPE)
277
+ fmap.append(x)
278
+ x = self.conv_post(x)
279
+ fmap.append(x)
280
+ x = torch.flatten(x, 1, -1)
281
+
282
+ return x, fmap
283
+
284
+
285
+ class MultiPeriodDiscriminator(torch.nn.Module):
286
+ def __init__(self):
287
+ super(MultiPeriodDiscriminator, self).__init__()
288
+ self.discriminators = nn.ModuleList(
289
+ [
290
+ DiscriminatorP(2),
291
+ DiscriminatorP(3),
292
+ DiscriminatorP(5),
293
+ DiscriminatorP(7),
294
+ DiscriminatorP(11),
295
+ ]
296
+ )
297
+
298
+ def forward(self, y, y_hat):
299
+ y_d_rs = []
300
+ y_d_gs = []
301
+ fmap_rs = []
302
+ fmap_gs = []
303
+ for i, d in enumerate(self.discriminators):
304
+ y_d_r, fmap_r = d(y)
305
+ y_d_g, fmap_g = d(y_hat)
306
+ y_d_rs.append(y_d_r)
307
+ fmap_rs.append(fmap_r)
308
+ y_d_gs.append(y_d_g)
309
+ fmap_gs.append(fmap_g)
310
+
311
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
312
+
313
+
314
+ class DiscriminatorS(torch.nn.Module):
315
+ def __init__(self, use_spectral_norm=False):
316
+ super(DiscriminatorS, self).__init__()
317
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
318
+ self.convs = nn.ModuleList(
319
+ [
320
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
321
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
322
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
323
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
324
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
325
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
326
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
327
+ ]
328
+ )
329
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
330
+
331
+ def forward(self, x):
332
+ fmap = []
333
+ for l in self.convs:
334
+ x = l(x)
335
+ x = F.leaky_relu(x, LRELU_SLOPE)
336
+ fmap.append(x)
337
+ x = self.conv_post(x)
338
+ fmap.append(x)
339
+ x = torch.flatten(x, 1, -1)
340
+
341
+ return x, fmap
342
+
343
+
344
+ class MultiScaleDiscriminator(torch.nn.Module):
345
+ def __init__(self):
346
+ super(MultiScaleDiscriminator, self).__init__()
347
+ self.discriminators = nn.ModuleList(
348
+ [
349
+ DiscriminatorS(use_spectral_norm=True),
350
+ DiscriminatorS(),
351
+ DiscriminatorS(),
352
+ ]
353
+ )
354
+ self.meanpools = nn.ModuleList(
355
+ [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
356
+ )
357
+
358
+ def forward(self, y, y_hat):
359
+ y_d_rs = []
360
+ y_d_gs = []
361
+ fmap_rs = []
362
+ fmap_gs = []
363
+ for i, d in enumerate(self.discriminators):
364
+ if i != 0:
365
+ y = self.meanpools[i - 1](y)
366
+ y_hat = self.meanpools[i - 1](y_hat)
367
+ y_d_r, fmap_r = d(y)
368
+ y_d_g, fmap_g = d(y_hat)
369
+ y_d_rs.append(y_d_r)
370
+ fmap_rs.append(fmap_r)
371
+ y_d_gs.append(y_d_g)
372
+ fmap_gs.append(fmap_g)
373
+
374
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
375
+
376
+
377
+ def feature_loss(fmap_r, fmap_g):
378
+ loss = 0
379
+ for dr, dg in zip(fmap_r, fmap_g):
380
+ for rl, gl in zip(dr, dg):
381
+ loss += torch.mean(torch.abs(rl - gl))
382
+
383
+ return loss * 2
384
+
385
+
386
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
387
+ loss = 0
388
+ r_losses = []
389
+ g_losses = []
390
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
391
+ r_loss = torch.mean((1 - dr) ** 2)
392
+ g_loss = torch.mean(dg**2)
393
+ loss += r_loss + g_loss
394
+ r_losses.append(r_loss.item())
395
+ g_losses.append(g_loss.item())
396
+
397
+ return loss, r_losses, g_losses
398
+
399
+
400
+ def generator_loss(disc_outputs):
401
+ loss = 0
402
+ gen_losses = []
403
+ for dg in disc_outputs:
404
+ l = torch.mean((1 - dg) ** 2)
405
+ gen_losses.append(l)
406
+ loss += l
407
+
408
+ return loss, gen_losses
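The ``Generator`` above upsamples 80-bin mel frames to waveform samples; the total upsampling factor is the product of ``upsample_rates`` and should match the hop size used to extract the mels. A quick sketch with typical HiFi-GAN hyperparameters (these values are assumptions, not necessarily this repository's config; it assumes the ``Generator`` class above is importable):

import torch
from types import SimpleNamespace

h = SimpleNamespace(
    resblock="1",
    upsample_rates=[8, 8, 2, 2],              # product = 256 = assumed hop size
    upsample_kernel_sizes=[16, 16, 4, 4],
    upsample_initial_channel=512,
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
)
g = Generator(h)                               # class defined above
mel = torch.randn(1, 80, 100)                  # 100 mel frames
wav = g(mel)
print(wav.shape)                               # (1, 1, 25600) = 100 frames * 256 samples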
Amphion/models/tts/base/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # from .tts_inferece import TTSInference
7
+ from .tts_trainer import TTSTrainer
Amphion/models/tts/base/tts_trainer.py ADDED
@@ -0,0 +1,721 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+ import shutil
9
+ import torch
10
+ import time
11
+ from pathlib import Path
12
+ import torch
13
+ from tqdm import tqdm
14
+ import re
15
+ import logging
16
+ import json5
17
+ import accelerate
18
+ from accelerate.logging import get_logger
19
+ from accelerate.utils import ProjectConfiguration
20
+ from torch.utils.data import ConcatDataset, DataLoader
21
+ from accelerate import DistributedDataParallelKwargs
22
+ from schedulers.scheduler import Eden
23
+ from models.base.base_sampler import build_samplers
24
+ from models.base.new_trainer import BaseTrainer
25
+
26
+
27
+ class TTSTrainer(BaseTrainer):
28
+ r"""The base trainer for all TTS models. It inherits from BaseTrainer and implements
29
+ ``_build_criterion``, ``_build_dataset`` and ``_build_speaker_lut`` methods. You can inherit from this
30
+ class, and implement ``_build_model``, ``_forward_step``.
31
+ """
32
+
33
+ def __init__(self, args=None, cfg=None):
34
+ self.args = args
35
+ self.cfg = cfg
36
+
37
+ cfg.exp_name = args.exp_name
38
+
39
+ # init with accelerate
40
+ self._init_accelerator()
41
+ self.accelerator.wait_for_everyone()
42
+
43
+ with self.accelerator.main_process_first():
44
+ self.logger = get_logger(args.exp_name, log_level="INFO")
45
+
46
+ # Log some info
47
+ self.logger.info("=" * 56)
48
+ self.logger.info("||\t\t" + "New training process started." + "\t\t||")
49
+ self.logger.info("=" * 56)
50
+ self.logger.info("\n")
51
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
52
+ self.logger.info(f"Experiment name: {args.exp_name}")
53
+ self.logger.info(f"Experiment directory: {self.exp_dir}")
54
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
55
+ if self.accelerator.is_main_process:
56
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
57
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
58
+
59
+ # init counts
60
+ self.batch_count: int = 0
61
+ self.step: int = 0
62
+ self.epoch: int = 0
63
+ self.max_epoch = (
64
+ self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
65
+ )
66
+ self.logger.info(
67
+ "Max epoch: {}".format(
68
+ self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
69
+ )
70
+ )
71
+
72
+ # Check values
73
+ if self.accelerator.is_main_process:
74
+ self.__check_basic_configs()
75
+ # Set runtime configs
76
+ self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
77
+ self.checkpoints_path = [
78
+ [] for _ in range(len(self.save_checkpoint_stride))
79
+ ]
80
+ self.keep_last = [
81
+ i if i > 0 else float("inf") for i in self.cfg.train.keep_last
82
+ ]
83
+ self.run_eval = self.cfg.train.run_eval
84
+
85
+ # set random seed
86
+ with self.accelerator.main_process_first():
87
+ start = time.monotonic_ns()
88
+ self._set_random_seed(self.cfg.train.random_seed)
89
+ end = time.monotonic_ns()
90
+ self.logger.debug(
91
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
92
+ )
93
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
94
+
95
+ # setup data_loader
96
+ with self.accelerator.main_process_first():
97
+ self.logger.info("Building dataset...")
98
+ start = time.monotonic_ns()
99
+ self.train_dataloader, self.valid_dataloader = self._build_dataloader()
100
+ end = time.monotonic_ns()
101
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
102
+
103
+ # save phone table to exp dir. Should be done before building model due to loading phone table in model
104
+ if cfg.preprocess.use_phone and cfg.preprocess.phone_extractor != "lexicon":
105
+ self._save_phone_symbols_file_to_exp_path()
106
+
107
+ # setup model
108
+ with self.accelerator.main_process_first():
109
+ self.logger.info("Building model...")
110
+ start = time.monotonic_ns()
111
+ self.model = self._build_model()
112
+ end = time.monotonic_ns()
113
+ self.logger.debug(self.model)
114
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
115
+ self.logger.info(
116
+ f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
117
+ )
118
+
119
+ # optimizer & scheduler
120
+ with self.accelerator.main_process_first():
121
+ self.logger.info("Building optimizer and scheduler...")
122
+ start = time.monotonic_ns()
123
+ self.optimizer = self._build_optimizer()
124
+ self.scheduler = self._build_scheduler()
125
+ end = time.monotonic_ns()
126
+ self.logger.info(
127
+ f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
128
+ )
129
+
130
+ # create criterion
131
+ with self.accelerator.main_process_first():
132
+ self.logger.info("Building criterion...")
133
+ start = time.monotonic_ns()
134
+ self.criterion = self._build_criterion()
135
+ end = time.monotonic_ns()
136
+ self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
137
+
138
+ # Resume or Finetune
139
+ with self.accelerator.main_process_first():
140
+ self._check_resume()
141
+
142
+ # accelerate prepare
143
+ self.logger.info("Initializing accelerate...")
144
+ start = time.monotonic_ns()
145
+ self._accelerator_prepare()
146
+ end = time.monotonic_ns()
147
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
148
+
149
+ # save config file path
150
+ self.config_save_path = os.path.join(self.exp_dir, "args.json")
151
+ self.device = self.accelerator.device
152
+
153
+ if cfg.preprocess.use_spkid and cfg.train.multi_speaker_training:
154
+ self.speakers = self._build_speaker_lut()
155
+ self.utt2spk_dict = self._build_utt2spk_dict()
156
+
157
+ # Only for TTS tasks
158
+ self.task_type = "TTS"
159
+ self.logger.info("Task type: {}".format(self.task_type))
160
+
161
+ def _check_resume(self):
162
+ # if args.resume:
163
+ if self.args.resume or (
164
+ self.cfg.model_type == "VALLE" and self.args.train_stage == 2
165
+ ):
166
+ checkpoint_dir = self.checkpoint_dir
167
+ if self.cfg.model_type == "VALLE" and self.args.train_stage == 2:
168
+ ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
169
+ if (
170
+ self.args.checkpoint_path is None or len(ls) == 0
171
+ ): # Train stage 2 from scratch using the checkpoint of stage 1
172
+ assert (
173
+ self.args.ar_model_ckpt_dir is not None
174
+ ), "Error: ar_model_ckpt_dir should be set to train nar model."
175
+ self.args.resume_type = "finetune"
176
+ checkpoint_dir = self.args.ar_model_ckpt_dir
177
+ self.logger.info(
178
+ f"Training NAR model at stage 2 using the checkpoint of AR model at stage 1."
179
+ )
180
+
181
+ self.logger.info(f"Resuming from checkpoint: {checkpoint_dir}")
182
+ start = time.monotonic_ns()
183
+ self.ckpt_path = self._load_model(
184
+ checkpoint_dir, self.args.checkpoint_path, self.args.resume_type
185
+ )
186
+ self.logger.info(f"Checkpoint path: {self.ckpt_path}")
187
+ end = time.monotonic_ns()
188
+ self.logger.info(
189
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
190
+ )
191
+ self.checkpoints_path = json.load(
192
+ open(os.path.join(self.ckpt_path, "ckpts.json"), "r")
193
+ )
194
+
195
+ def _init_accelerator(self):
196
+ self.exp_dir = os.path.join(
197
+ os.path.abspath(self.cfg.log_dir), self.args.exp_name
198
+ )
199
+ project_config = ProjectConfiguration(
200
+ project_dir=self.exp_dir,
201
+ logging_dir=os.path.join(self.exp_dir, "log"),
202
+ )
203
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
204
+ self.accelerator = accelerate.Accelerator(
205
+ gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
206
+ log_with=self.cfg.train.tracker,
207
+ project_config=project_config,
208
+ kwargs_handlers=[kwargs],
209
+ )
210
+ if self.accelerator.is_main_process:
211
+ os.makedirs(project_config.project_dir, exist_ok=True)
212
+ os.makedirs(project_config.logging_dir, exist_ok=True)
213
+ with self.accelerator.main_process_first():
214
+ self.accelerator.init_trackers(self.args.exp_name)
215
+
216
+ def _accelerator_prepare(self):
217
+ (
218
+ self.train_dataloader,
219
+ self.valid_dataloader,
220
+ ) = self.accelerator.prepare(
221
+ self.train_dataloader,
222
+ self.valid_dataloader,
223
+ )
224
+
225
+ if isinstance(self.model, dict):
226
+ for key in self.model.keys():
227
+ self.model[key] = self.accelerator.prepare(self.model[key])
228
+ else:
229
+ self.model = self.accelerator.prepare(self.model)
230
+
231
+ if isinstance(self.optimizer, dict):
232
+ for key in self.optimizer.keys():
233
+ self.optimizer[key] = self.accelerator.prepare(self.optimizer[key])
234
+ else:
235
+ self.optimizer = self.accelerator.prepare(self.optimizer)
236
+
237
+ if isinstance(self.scheduler, dict):
238
+ for key in self.scheduler.keys():
239
+ self.scheduler[key] = self.accelerator.prepare(self.scheduler[key])
240
+ else:
241
+ self.scheduler = self.accelerator.prepare(self.scheduler)
242
+
243
+ ### Following are methods only for TTS tasks ###
244
+ def _build_dataset(self):
245
+ pass
246
+
247
+ def _build_criterion(self):
248
+ pass
249
+
250
+ def _build_model(self):
251
+ pass
252
+
253
+ def _build_dataloader(self):
254
+ """Build dataloader which merges a series of datasets."""
255
+ # Build dataset instance for each dataset and combine them by ConcatDataset
256
+ Dataset, Collator = self._build_dataset()
257
+
258
+ # Build train set
259
+ datasets_list = []
260
+ for dataset in self.cfg.dataset:
261
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
262
+ datasets_list.append(subdataset)
263
+ train_dataset = ConcatDataset(datasets_list)
264
+ train_collate = Collator(self.cfg)
265
+ _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
266
+ train_loader = DataLoader(
267
+ train_dataset,
268
+ collate_fn=train_collate,
269
+ batch_sampler=batch_sampler,
270
+ num_workers=self.cfg.train.dataloader.num_worker,
271
+ pin_memory=self.cfg.train.dataloader.pin_memory,
272
+ )
273
+
274
+ # Build test set
275
+ datasets_list = []
276
+ for dataset in self.cfg.dataset:
277
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
278
+ datasets_list.append(subdataset)
279
+ valid_dataset = ConcatDataset(datasets_list)
280
+ valid_collate = Collator(self.cfg)
281
+ _, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
282
+ valid_loader = DataLoader(
283
+ valid_dataset,
284
+ collate_fn=valid_collate,
285
+ batch_sampler=batch_sampler,
286
+ num_workers=self.cfg.train.dataloader.num_worker,
287
+ pin_memory=self.cfg.train.dataloader.pin_memory,
288
+ )
289
+ return train_loader, valid_loader
290
+
291
+ def _build_optimizer(self):
292
+ pass
293
+
294
+ def _build_scheduler(self):
295
+ pass
296
+
297
+ def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
298
+ """Load model from checkpoint. If a folder is given, it will
299
+ load the latest checkpoint in checkpoint_dir. If a path is given
300
+ it will load the checkpoint specified by checkpoint_path.
301
+ **Only use this method after** ``accelerator.prepare()``.
302
+ """
303
+ if checkpoint_path is None or checkpoint_path == "":
304
+ ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
305
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
306
+ checkpoint_path = ls[0]
307
+ self.logger.info("Load model from {}".format(checkpoint_path))
308
+ print("Load model from {}".format(checkpoint_path))
309
+ if resume_type == "resume":
310
+ self.accelerator.load_state(checkpoint_path)
311
+ self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
312
+ self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
313
+ elif resume_type == "finetune":
314
+ if isinstance(self.model, dict):
315
+ for idx, sub_model in enumerate(self.model.keys()):
316
+ if idx == 0:
317
+ ckpt_name = "pytorch_model.bin"
318
+ else:
319
+ ckpt_name = "pytorch_model_{}.bin".format(idx)
320
+
321
+ self.model[sub_model].load_state_dict(
322
+ torch.load(os.path.join(checkpoint_path, ckpt_name))
323
+ )
324
+ self.model[sub_model].cuda(self.accelerator.device)
325
+ else:
326
+ self.model.load_state_dict(
327
+ torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
328
+ )
329
+ self.model.cuda(self.accelerator.device)
330
+ self.logger.info("Load model weights for finetune SUCCESS!")
331
+
332
+ else:
333
+ raise ValueError("Unsupported resume type: {}".format(resume_type))
334
+
335
+ return checkpoint_path
336
+
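As a quick illustration of the directory-name convention `_load_model` relies on, here is a self-contained sketch (toy paths, not repository code) of how the latest checkpoint is picked and how the epoch/step counters are recovered from its name:

```python
# Illustrative only: parsing "epoch-XXXX_step-XXXXXXX_loss-X.XXXXXX" checkpoint
# directory names the same way _load_model does. Paths are invented.
names = [
    "exp/checkpoint/epoch-0003_step-0001500_loss-1.234567",
    "exp/checkpoint/epoch-0010_step-0005000_loss-0.987654",
    "exp/checkpoint/epoch-0007_step-0003500_loss-1.010101",
]

# split("_")[-3] is the piece ending in "epoch-XXXX"; split("-")[-1] is the number
names.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
latest = names[0]
epoch = int(latest.split("_")[-3].split("-")[-1]) + 1  # resume at the next epoch
step = int(latest.split("_")[-2].split("-")[-1]) + 1
print(latest, epoch, step)  # .../epoch-0010_step-0005000_loss-0.987654 11 5001
```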
337
+ ### THIS IS MAIN ENTRY ###
338
+ def train_loop(self):
339
+ r"""Training loop. The public entry of training process."""
340
+ # Wait everyone to prepare before we move on
341
+ self.accelerator.wait_for_everyone()
342
+ # dump config file
343
+ if self.accelerator.is_main_process:
344
+ self.__dump_cfg(self.config_save_path)
345
+
346
+ # self.optimizer.zero_grad()
347
+ # Wait to ensure good to go
348
+
349
+ self.accelerator.wait_for_everyone()
350
+ while self.epoch < self.max_epoch:
351
+ self.logger.info("\n")
352
+ self.logger.info("-" * 32)
353
+ self.logger.info("Epoch {}: ".format(self.epoch))
354
+
355
+ # Do training & validating epoch
356
+ train_total_loss, train_losses = self._train_epoch()
357
+ if isinstance(train_losses, dict):
358
+ for key, loss in train_losses.items():
359
+ self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss))
360
+ self.accelerator.log(
361
+ {"Epoch/Train {} Loss".format(key): loss},
362
+ step=self.epoch,
363
+ )
364
+
365
+ valid_total_loss, valid_losses = self._valid_epoch()
366
+ if isinstance(valid_losses, dict):
367
+ for key, loss in valid_losses.items():
368
+ self.logger.info(" |- Valid/{} Loss: {:.6f}".format(key, loss))
369
+ self.accelerator.log(
370
+ {"Epoch/Train {} Loss".format(key): loss},
371
+ step=self.epoch,
372
+ )
373
+
374
+ self.logger.info(" |- Train/Loss: {:.6f}".format(train_total_loss))
375
+ self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_total_loss))
376
+ self.accelerator.log(
377
+ {
378
+ "Epoch/Train Loss": train_total_loss,
379
+ "Epoch/Valid Loss": valid_total_loss,
380
+ },
381
+ step=self.epoch,
382
+ )
383
+
384
+ self.accelerator.wait_for_everyone()
385
+
386
+ # Check if hit save_checkpoint_stride and run_eval
387
+ run_eval = False
388
+ if self.accelerator.is_main_process:
389
+ save_checkpoint = False
390
+ hit_dix = []
391
+ for i, num in enumerate(self.save_checkpoint_stride):
392
+ if self.epoch % num == 0:
393
+ save_checkpoint = True
394
+ hit_dix.append(i)
395
+ run_eval |= self.run_eval[i]
396
+
397
+ self.accelerator.wait_for_everyone()
398
+ if self.accelerator.is_main_process and save_checkpoint:
399
+ path = os.path.join(
400
+ self.checkpoint_dir,
401
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
402
+ self.epoch, self.step, train_total_loss
403
+ ),
404
+ )
405
+ self.accelerator.save_state(path)
406
+
407
+ json.dump(
408
+ self.checkpoints_path,
409
+ open(os.path.join(path, "ckpts.json"), "w"),
410
+ ensure_ascii=False,
411
+ indent=4,
412
+ )
413
+
414
+ # Remove old checkpoints
415
+ to_remove = []
416
+ for idx in hit_dix:
417
+ self.checkpoints_path[idx].append(path)
418
+ while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
419
+ to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
420
+
421
+ # Search conflicts
422
+ total = set()
423
+ for i in self.checkpoints_path:
424
+ total |= set(i)
425
+ do_remove = set()
426
+ for idx, path in to_remove[::-1]:
427
+ if path in total:
428
+ self.checkpoints_path[idx].insert(0, path)
429
+ else:
430
+ do_remove.add(path)
431
+
432
+ # Remove old checkpoints
433
+ for path in do_remove:
434
+ shutil.rmtree(path, ignore_errors=True)
435
+ self.logger.debug(f"Remove old checkpoint: {path}")
436
+
437
+ self.accelerator.wait_for_everyone()
438
+ if run_eval:
439
+ # TODO: run evaluation
440
+ pass
441
+
442
+ # Update info for each epoch
443
+ self.epoch += 1
444
+
445
+ # Finish training and save final checkpoint
446
+ self.accelerator.wait_for_everyone()
447
+ if self.accelerator.is_main_process:
448
+ path = os.path.join(
449
+ self.checkpoint_dir,
450
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
451
+ self.epoch, self.step, valid_total_loss
452
+ ),
453
+ )
454
+ self.accelerator.save_state(path)
462
+
463
+ json.dump(
464
+ self.checkpoints_path,
465
+ open(os.path.join(path, "ckpts.json"), "w"),
466
+ ensure_ascii=False,
467
+ indent=4,
468
+ )
469
+
470
+ self.accelerator.end_training()
471
+
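The checkpoint bookkeeping in `train_loop` (one list per `save_checkpoint_stride` entry, trimmed to `keep_last`, with a conflict check so a path still referenced by one stride is not deleted by another) can be hard to follow from the code alone. A toy model of the same rotation logic, with invented stride and keep values:

```python
# Toy model of the rotation above; not repository code, values invented.
save_checkpoint_stride = [1, 5]          # save every epoch, and every 5th epoch
keep_last = [2, float("inf")]            # keep last 2 of the first kind, all of the second
checkpoints_path = [[] for _ in save_checkpoint_stride]

def register(epoch, path):
    to_remove = []
    for i, stride in enumerate(save_checkpoint_stride):
        if epoch % stride == 0:
            checkpoints_path[i].append(path)
            while len(checkpoints_path[i]) > keep_last[i]:
                to_remove.append(checkpoints_path[i].pop(0))
    # conflict check: only delete paths no stride still references
    still_referenced = {p for lst in checkpoints_path for p in lst}
    return [p for p in to_remove if p not in still_referenced]

for e in range(1, 8):
    print(e, register(e, f"ckpt/epoch-{e:04d}"))
# epoch 5's checkpoint survives the stride-1 rotation because the stride-5 list still holds it
```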
472
+ ### Following are methods that can be used directly in child classes ###
473
+ def _train_epoch(self):
474
+ r"""Training epoch. Should return average loss of a batch (sample) over
475
+ one epoch. See ``train_loop`` for usage.
476
+ """
477
+ if isinstance(self.model, dict):
478
+ for key in self.model.keys():
479
+ self.model[key].train()
480
+ else:
481
+ self.model.train()
482
+
483
+ epoch_sum_loss: float = 0.0
484
+ epoch_losses: dict = {}
485
+ epoch_step: int = 0
486
+ for batch in tqdm(
487
+ self.train_dataloader,
488
+ desc=f"Training Epoch {self.epoch}",
489
+ unit="batch",
490
+ colour="GREEN",
491
+ leave=False,
492
+ dynamic_ncols=True,
493
+ smoothing=0.04,
494
+ disable=not self.accelerator.is_main_process,
495
+ ):
496
+ # Do training step and BP
497
+ with self.accelerator.accumulate(self.model):
498
+ total_loss, train_losses, _ = self._train_step(batch)
499
+ self.batch_count += 1
500
+
501
+ # Update info for each step
502
+ # TODO: step means BP counts or batch counts?
503
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
504
+ if isinstance(self.scheduler, dict):
505
+ for key in self.scheduler.keys():
506
+ self.scheduler[key].step()
507
+ else:
508
+ if isinstance(self.scheduler, Eden):
509
+ self.scheduler.step_batch(self.step)
510
+ else:
511
+ self.scheduler.step()
512
+
513
+ epoch_sum_loss += total_loss
514
+
515
+ if isinstance(train_losses, dict):
516
+ for key, value in train_losses.items():
517
+ if key not in epoch_losses.keys():
+ epoch_losses[key] = value
+ else:
+ epoch_losses[key] += value
518
+
519
+ if isinstance(train_losses, dict):
520
+ for key, loss in train_losses.items():
521
+ self.accelerator.log(
522
+ {"Epoch/Train {} Loss".format(key): loss},
523
+ step=self.step,
524
+ )
525
+
526
+ self.step += 1
527
+ epoch_step += 1
528
+
529
+ self.accelerator.wait_for_everyone()
530
+
531
+ epoch_sum_loss = (
532
+ epoch_sum_loss
533
+ / len(self.train_dataloader)
534
+ * self.cfg.train.gradient_accumulation_step
535
+ )
536
+
537
+ for key in epoch_losses.keys():
538
+ epoch_losses[key] = (
539
+ epoch_losses[key]
540
+ / len(self.train_dataloader)
541
+ * self.cfg.train.gradient_accumulation_step
542
+ )
543
+
544
+ return epoch_sum_loss, epoch_losses
545
+
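One possible reading of the normalization at the end of `_train_epoch`: `epoch_sum_loss` is a sum over micro-batches, so dividing by the number of batches and multiplying by the accumulation steps gives a loss per optimizer update rather than per micro-batch. Toy arithmetic only:

```python
# Illustrative arithmetic, not repository code.
num_batches = 100    # len(train_dataloader)
grad_accum = 4       # cfg.train.gradient_accumulation_step
epoch_sum_loss = sum([1.0] * num_batches)             # pretend each batch's loss is 1.0
per_update = epoch_sum_loss / num_batches * grad_accum
print(per_update)    # 4.0: the 4 micro-batches behind each optimizer step, summed
```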
546
+ @torch.inference_mode()
547
+ def _valid_epoch(self):
548
+ r"""Testing epoch. Should return average loss of a batch (sample) over
549
+ one epoch. See ``train_loop`` for usage.
550
+ """
551
+ if isinstance(self.model, dict):
552
+ for key in self.model.keys():
553
+ self.model[key].eval()
554
+ else:
555
+ self.model.eval()
556
+
557
+ epoch_sum_loss = 0.0
558
+ epoch_losses = dict()
559
+ for batch in tqdm(
560
+ self.valid_dataloader,
561
+ desc=f"Validating Epoch {self.epoch}",
562
+ unit="batch",
563
+ colour="GREEN",
564
+ leave=False,
565
+ dynamic_ncols=True,
566
+ smoothing=0.04,
567
+ disable=not self.accelerator.is_main_process,
568
+ ):
569
+ total_loss, valid_losses, valid_stats = self._valid_step(batch)
570
+ epoch_sum_loss += total_loss
571
+ if isinstance(valid_losses, dict):
572
+ for key, value in valid_losses.items():
573
+ if key not in epoch_losses.keys():
574
+ epoch_losses[key] = value
575
+ else:
576
+ epoch_losses[key] += value
577
+
578
+ epoch_sum_loss = epoch_sum_loss / len(self.valid_dataloader)
579
+ for key in epoch_losses.keys():
580
+ epoch_losses[key] = epoch_losses[key] / len(self.valid_dataloader)
581
+
582
+ self.accelerator.wait_for_everyone()
583
+
584
+ return epoch_sum_loss, epoch_losses
585
+
586
+ def _train_step(self):
587
+ pass
588
+
589
+ def _valid_step(self, batch):
590
+ pass
591
+
592
+ def _inference(self):
593
+ pass
594
+
595
+ def _is_valid_pattern(self, directory_name):
596
+ directory_name = str(directory_name)
597
+ pattern = r"^epoch-\d{4}_step-\d{7}_loss-\d{1}\.\d{6}"
598
+ return re.match(pattern, directory_name) is not None
599
+
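A quick, illustrative check of the pattern used by `_is_valid_pattern` against the `epoch-..._step-..._loss-...` directory names written above:

```python
import re

# Same pattern as _is_valid_pattern above; directory names are invented examples.
pattern = r"^epoch-\d{4}_step-\d{7}_loss-\d+\.\d{6}"
for name in (
    "epoch-0012_step-0034567_loss-0.123456",        # matches
    "final_epoch-0012_step-0034567_loss-0.123456",  # final checkpoints do not match
    "some_other_dir",                               # does not match
):
    print(name, bool(re.match(pattern, name)))
```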
600
+ def _check_basic_configs(self):
601
+ if self.cfg.train.gradient_accumulation_step <= 0:
602
+ self.logger.fatal("Invalid gradient_accumulation_step value!")
603
+ self.logger.error(
604
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
605
+ )
606
+ self.accelerator.end_training()
607
+ raise ValueError(
608
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
609
+ )
610
+
611
+ def __dump_cfg(self, path):
612
+ os.makedirs(os.path.dirname(path), exist_ok=True)
613
+ json5.dump(
614
+ self.cfg,
615
+ open(path, "w"),
616
+ indent=4,
617
+ sort_keys=True,
618
+ ensure_ascii=False,
619
+ quote_keys=True,
620
+ )
621
+
622
+ def __check_basic_configs(self):
623
+ if self.cfg.train.gradient_accumulation_step <= 0:
624
+ self.logger.fatal("Invalid gradient_accumulation_step value!")
625
+ self.logger.error(
626
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
627
+ )
628
+ self.accelerator.end_training()
629
+ raise ValueError(
630
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
631
+ )
632
+ # TODO: check other values
633
+
634
+ @staticmethod
635
+ def __count_parameters(model):
636
+ model_param = 0.0
637
+ if isinstance(model, dict):
638
+ for key, value in model.items():
639
+ model_param += sum(p.numel() for p in model[key].parameters())
640
+ else:
641
+ model_param = sum(p.numel() for p in model.parameters())
642
+ return model_param
643
+
644
+ def _build_speaker_lut(self):
645
+ # combine speakers
646
+ if not os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)):
647
+ speakers = {}
648
+ else:
649
+ with open(
650
+ os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "r"
651
+ ) as speaker_file:
652
+ speakers = json.load(speaker_file)
653
+ for dataset in self.cfg.dataset:
654
+ speaker_lut_path = os.path.join(
655
+ self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
656
+ )
657
+ with open(speaker_lut_path, "r") as speaker_lut_path:
658
+ singer_lut = json.load(speaker_lut_path)
659
+ for singer in singer_lut.keys():
660
+ if singer not in speakers:
661
+ speakers[singer] = len(speakers)
662
+ with open(
663
+ os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w"
664
+ ) as speaker_file:
665
+ json.dump(speakers, speaker_file, indent=4, ensure_ascii=False)
666
+ print(
667
+ "speakers have been dumped to {}".format(
668
+ os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
669
+ )
670
+ )
671
+ return speakers
672
+
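To make the merge in `_build_speaker_lut` concrete, here is a small illustration (dataset and speaker names are invented) of folding per-dataset `spk2id` tables into one experiment-level table, assigning new ids only to unseen speakers:

```python
# Illustration only; names are made up.
merged = {"LJSpeech_LJ": 0}                      # table already in the experiment dir
per_dataset_luts = [
    {"LibriTTS_103": 0, "LibriTTS_1034": 1},     # spk2id of dataset A
    {"LibriTTS_103": 0, "VCTK_p225": 1},         # spk2id of dataset B
]
for lut in per_dataset_luts:
    for spk in lut:
        if spk not in merged:
            merged[spk] = len(merged)
print(merged)
# {'LJSpeech_LJ': 0, 'LibriTTS_103': 1, 'LibriTTS_1034': 2, 'VCTK_p225': 3}
```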
673
+ def _build_utt2spk_dict(self):
674
+ # combine speakers
675
+ utt2spk = {}
676
+ if not os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk)):
677
+ utt2spk = {}
678
+ else:
679
+ with open(
680
+ os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk), "r"
681
+ ) as utt2spk_file:
682
+ for line in utt2spk_file.readlines():
683
+ utt, spk = line.strip().split("\t")
684
+ utt2spk[utt] = spk
685
+ for dataset in self.cfg.dataset:
686
+ utt2spk_dict_path = os.path.join(
687
+ self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.utt2spk
688
+ )
689
+ with open(utt2spk_dict_path, "r") as utt2spk_dict:
690
+ for line in utt2spk_dict.readlines():
691
+ utt, spk = line.strip().split("\t")
692
+ if utt not in utt2spk.keys():
693
+ utt2spk[utt] = spk
694
+ with open(
695
+ os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk), "w"
696
+ ) as utt2spk_file:
697
+ for utt, spk in utt2spk.items():
698
+ utt2spk_file.write(utt + "\t" + spk + "\n")
699
+ print(
700
+ "utterance and speaker mapper have been dumped to {}".format(
701
+ os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk)
702
+ )
703
+ )
704
+ return utt2spk
705
+
706
+ def _save_phone_symbols_file_to_exp_path(self):
707
+ phone_symbols_file = os.path.join(
708
+ self.cfg.preprocess.processed_dir,
709
+ self.cfg.dataset[0],
710
+ self.cfg.preprocess.symbols_dict,
711
+ )
712
+ phone_symbols_file_to_exp_path = os.path.join(
713
+ self.exp_dir, self.cfg.preprocess.symbols_dict
714
+ )
715
+ shutil.copy(phone_symbols_file, phone_symbols_file_to_exp_path)
716
+ os.chmod(phone_symbols_file_to_exp_path, 0o666)
717
+ print(
718
+ "phone symbols been dumped to {}".format(
719
+ os.path.join(self.exp_dir, self.cfg.preprocess.symbols_dict)
720
+ )
721
+ )
Amphion/models/tts/fastspeech2/fs2_trainer.py ADDED
@@ -0,0 +1,155 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from tqdm import tqdm
9
+ from models.tts.base import TTSTrainer
10
+ from models.tts.fastspeech2.fs2 import FastSpeech2, FastSpeech2Loss
11
+ from models.tts.fastspeech2.fs2_dataset import FS2Dataset, FS2Collator
12
+ from optimizer.optimizers import NoamLR
13
+
14
+
15
+ class FastSpeech2Trainer(TTSTrainer):
16
+ def __init__(self, args, cfg):
17
+ TTSTrainer.__init__(self, args, cfg)
18
+ self.cfg = cfg
19
+
20
+ def _build_dataset(self):
21
+ return FS2Dataset, FS2Collator
22
+
23
+ def __build_scheduler(self):
24
+ return NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
25
+
26
+ def _write_summary(self, losses, stats):
27
+ for key, value in losses.items():
28
+ self.sw.add_scalar("train/" + key, value, self.step)
29
+ lr = self.optimizer.state_dict()["param_groups"][0]["lr"]
30
+ self.sw.add_scalar("learning_rate", lr, self.step)
31
+
32
+ def _write_valid_summary(self, losses, stats):
33
+ for key, value in losses.items():
34
+ self.sw.add_scalar("val/" + key, value, self.step)
35
+
36
+ def _build_criterion(self):
37
+ return FastSpeech2Loss(self.cfg)
38
+
39
+ def get_state_dict(self):
40
+ state_dict = {
41
+ "model": self.model.state_dict(),
42
+ "optimizer": self.optimizer.state_dict(),
43
+ "scheduler": self.scheduler.state_dict(),
44
+ "step": self.step,
45
+ "epoch": self.epoch,
46
+ "batch_size": self.cfg.train.batch_size,
47
+ }
48
+ return state_dict
49
+
50
+ def _build_optimizer(self):
51
+ optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam)
52
+ return optimizer
53
+
54
+ def _build_scheduler(self):
55
+ scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
56
+ return scheduler
57
+
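The `NoamLR` scheduler built here comes from `optimizer.optimizers`, which is not part of this diff. As a point of reference only, the standard Noam schedule (Vaswani et al., 2017) that such schedulers typically implement is sketched below; whether Amphion's class takes exactly these arguments is an assumption:

```python
# Sketch of the usual Noam learning-rate curve; not Amphion's implementation.
def noam_lr(step, d_model=256, warmup_steps=4000, factor=1.0):
    step = max(step, 1)
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

for s in (1, 1000, 4000, 20000):
    print(s, f"{noam_lr(s):.6f}")  # ramps up linearly, then decays as 1/sqrt(step)
```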
58
+ def _build_model(self):
59
+ self.model = FastSpeech2(self.cfg)
60
+ return self.model
61
+
62
+ def _train_epoch(self):
63
+ r"""Training epoch. Should return average loss of a batch (sample) over
64
+ one epoch. See ``train_loop`` for usage.
65
+ """
66
+ self.model.train()
67
+ epoch_sum_loss: float = 0.0
68
+ epoch_step: int = 0
69
+ epoch_losses: dict = {}
70
+ for batch in tqdm(
71
+ self.train_dataloader,
72
+ desc=f"Training Epoch {self.epoch}",
73
+ unit="batch",
74
+ colour="GREEN",
75
+ leave=False,
76
+ dynamic_ncols=True,
77
+ smoothing=0.04,
78
+ disable=not self.accelerator.is_main_process,
79
+ ):
80
+ # Do training step and BP
81
+ with self.accelerator.accumulate(self.model):
82
+ loss, train_losses = self._train_step(batch)
83
+ self.accelerator.backward(loss)
84
+ grad_clip_thresh = self.cfg.train.grad_clip_thresh
85
+ nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip_thresh)
86
+ self.optimizer.step()
87
+ self.scheduler.step()
88
+ self.optimizer.zero_grad()
89
+ self.batch_count += 1
90
+
91
+ # Update info for each step
92
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
93
+ epoch_sum_loss += loss
94
+ for key, value in train_losses.items():
95
+ if key not in epoch_losses.keys():
96
+ epoch_losses[key] = value
97
+ else:
98
+ epoch_losses[key] += value
99
+
100
+ self.accelerator.log(
101
+ {
102
+ "Step/Train Loss": loss,
103
+ "Step/Learning Rate": self.optimizer.param_groups[0]["lr"],
104
+ },
105
+ step=self.step,
106
+ )
107
+ self.step += 1
108
+ epoch_step += 1
109
+
110
+ self.accelerator.wait_for_everyone()
111
+
112
+ epoch_sum_loss = (
113
+ epoch_sum_loss
114
+ / len(self.train_dataloader)
115
+ * self.cfg.train.gradient_accumulation_step
116
+ )
117
+
118
+ for key in epoch_losses.keys():
119
+ epoch_losses[key] = (
120
+ epoch_losses[key]
121
+ / len(self.train_dataloader)
122
+ * self.cfg.train.gradient_accumulation_step
123
+ )
124
+ return epoch_sum_loss, epoch_losses
125
+
126
+ def _train_step(self, data):
127
+ train_losses = {}
128
+ total_loss = 0
129
+ train_stats = {}
130
+
131
+ preds = self.model(data)
132
+
133
+ train_losses = self.criterion(data, preds)
134
+
135
+ total_loss = train_losses["loss"]
136
+ for key, value in train_losses.items():
137
+ train_losses[key] = value.item()
138
+
139
+ return total_loss, train_losses
140
+
141
+ @torch.no_grad()
142
+ def _valid_step(self, data):
143
+ valid_losses = {}
144
+ total_valid_loss = 0
145
+ valid_stats = {}
146
+
147
+ preds = self.model(data)
148
+
149
+ valid_losses = self.criterion(data, preds)
150
+
151
+ total_valid_loss = valid_losses["loss"]
152
+ for key, value in valid_losses.items():
153
+ valid_losses[key] = value.item()
154
+
155
+ return total_valid_loss, valid_losses, valid_stats
Amphion/models/tts/naturalspeech2/ns2.py ADDED
@@ -0,0 +1,259 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import numpy as np
9
+ import torch.nn.functional as F
10
+ from models.tts.naturalspeech2.diffusion import Diffusion
11
+ from models.tts.naturalspeech2.diffusion_flow import DiffusionFlow
12
+ from models.tts.naturalspeech2.wavenet import WaveNet
13
+ from models.tts.naturalspeech2.prior_encoder import PriorEncoder
14
+ from modules.naturalpseech2.transformers import TransformerEncoder
15
+ from encodec import EncodecModel
16
+ from einops import rearrange, repeat
17
+
18
+ import os
19
+ import json
20
+
21
+
22
+ class NaturalSpeech2(nn.Module):
23
+ def __init__(self, cfg):
24
+ super().__init__()
25
+ self.cfg = cfg
26
+
27
+ self.latent_dim = cfg.latent_dim
28
+ self.query_emb_num = cfg.query_emb.query_token_num
29
+
30
+ self.prior_encoder = PriorEncoder(cfg.prior_encoder)
31
+ if cfg.diffusion.diffusion_type == "diffusion":
32
+ self.diffusion = Diffusion(cfg.diffusion)
33
+ elif cfg.diffusion.diffusion_type == "flow":
34
+ self.diffusion = DiffusionFlow(cfg.diffusion)
35
+
36
+ self.prompt_encoder = TransformerEncoder(cfg=cfg.prompt_encoder)
37
+ if self.latent_dim != cfg.prompt_encoder.encoder_hidden:
38
+ self.prompt_lin = nn.Linear(
39
+ self.latent_dim, cfg.prompt_encoder.encoder_hidden
40
+ )
41
+ self.prompt_lin.weight.data.normal_(0.0, 0.02)
42
+ else:
43
+ self.prompt_lin = None
44
+
45
+ self.query_emb = nn.Embedding(self.query_emb_num, cfg.query_emb.hidden_size)
46
+ self.query_attn = nn.MultiheadAttention(
47
+ cfg.query_emb.hidden_size, cfg.query_emb.head_num, batch_first=True
48
+ )
49
+
50
+ codec_model = EncodecModel.encodec_model_24khz()
51
+ codec_model.set_target_bandwidth(12.0)
52
+ codec_model.requires_grad_(False)
53
+ self.quantizer = codec_model.quantizer
54
+
55
+ @torch.no_grad()
56
+ def code_to_latent(self, code):
57
+ latent = self.quantizer.decode(code.transpose(0, 1))
58
+ return latent
59
+
60
+ def latent_to_code(self, latent, nq=16):
61
+ residual = latent
62
+ all_indices = []
63
+ all_dist = []
64
+ for i in range(nq):
65
+ layer = self.quantizer.vq.layers[i]
66
+ x = rearrange(residual, "b d n -> b n d")
67
+ x = layer.project_in(x)
68
+ shape = x.shape
69
+ x = layer._codebook.preprocess(x)
70
+ embed = layer._codebook.embed.t()
71
+ dist = -(
72
+ x.pow(2).sum(1, keepdim=True)
73
+ - 2 * x @ embed
74
+ + embed.pow(2).sum(0, keepdim=True)
75
+ )
76
+ indices = dist.max(dim=-1).indices
77
+ indices = layer._codebook.postprocess_emb(indices, shape)
78
+ dist = dist.reshape(*shape[:-1], dist.shape[-1])
79
+ quantized = layer.decode(indices)
80
+ residual = residual - quantized
81
+ all_indices.append(indices)
82
+ all_dist.append(dist)
83
+
84
+ out_indices = torch.stack(all_indices)
85
+ out_dist = torch.stack(all_dist)
86
+
87
+ return out_indices, out_dist # (nq, B, T); (nq, B, T, 1024)
88
+
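`latent_to_code` re-runs EnCodec's residual quantization layer by layer so that it can also return the full distance matrix over each codebook. A minimal NumPy sketch of one such layer, with toy shapes, illustrating the nearest-code selection and the residual update:

```python
import numpy as np

# Toy shapes; an illustration of one residual-VQ layer, not the EnCodec internals.
rng = np.random.default_rng(0)
codebook = rng.normal(size=(8, 4))   # 8 code vectors of dimension 4
residual = rng.normal(size=(5, 4))   # 5 frames of dimension 4

# dist = -(||x||^2 - 2 x.e + ||e||^2): maximizing it picks the nearest code
dist = -(
    (residual**2).sum(1, keepdims=True)
    - 2 * residual @ codebook.T
    + (codebook**2).sum(1, keepdims=True).T
)
indices = dist.argmax(axis=-1)       # (5,) selected code ids, one per frame
quantized = codebook[indices]        # (5, 4)
residual = residual - quantized      # what the next quantizer layer would see
print(indices, np.abs(residual).mean())
```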
89
+ @torch.no_grad()
90
+ def latent_to_latent(self, latent, nq=16):
91
+ codes, _ = self.latent_to_code(latent, nq)
92
+ latent = self.quantizer.vq.decode(codes)
93
+ return latent
94
+
95
+ def forward(
96
+ self,
97
+ code=None,
98
+ pitch=None,
99
+ duration=None,
100
+ phone_id=None,
101
+ phone_id_frame=None,
102
+ frame_nums=None,
103
+ ref_code=None,
104
+ ref_frame_nums=None,
105
+ phone_mask=None,
106
+ mask=None,
107
+ ref_mask=None,
108
+ ):
109
+ ref_latent = self.code_to_latent(ref_code)
110
+ latent = self.code_to_latent(code)
111
+
112
+ if self.latent_dim is not None:
113
+ ref_latent = self.prompt_lin(ref_latent.transpose(1, 2))
114
+
115
+ ref_latent = self.prompt_encoder(ref_latent, ref_mask, condition=None)
116
+ spk_emb = ref_latent.transpose(1, 2) # (B, d, T')
117
+
118
+ spk_query_emb = self.query_emb(
119
+ torch.arange(self.query_emb_num).to(latent.device)
120
+ ).repeat(
121
+ latent.shape[0], 1, 1
122
+ ) # (B, query_emb_num, d)
123
+ spk_query_emb, _ = self.query_attn(
124
+ spk_query_emb,
125
+ spk_emb.transpose(1, 2),
126
+ spk_emb.transpose(1, 2),
127
+ key_padding_mask=~(ref_mask.bool()),
128
+ ) # (B, query_emb_num, d)
129
+
130
+ prior_out = self.prior_encoder(
131
+ phone_id=phone_id,
132
+ duration=duration,
133
+ pitch=pitch,
134
+ phone_mask=phone_mask,
135
+ mask=mask,
136
+ ref_emb=spk_emb,
137
+ ref_mask=ref_mask,
138
+ is_inference=False,
139
+ )
140
+ prior_condition = prior_out["prior_out"] # (B, T, d)
141
+
142
+ diff_out = self.diffusion(latent, mask, prior_condition, spk_query_emb)
143
+
144
+ return diff_out, prior_out
145
+
146
+ @torch.no_grad()
147
+ def inference(
148
+ self, ref_code=None, phone_id=None, ref_mask=None, inference_steps=1000
149
+ ):
150
+ ref_latent = self.code_to_latent(ref_code)
151
+
152
+ if self.latent_dim is not None:
153
+ ref_latent = self.prompt_lin(ref_latent.transpose(1, 2))
154
+
155
+ ref_latent = self.prompt_encoder(ref_latent, ref_mask, condition=None)
156
+ spk_emb = ref_latent.transpose(1, 2) # (B, d, T')
157
+
158
+ spk_query_emb = self.query_emb(
159
+ torch.arange(self.query_emb_num).to(ref_latent.device)
160
+ ).repeat(
161
+ ref_latent.shape[0], 1, 1
162
+ ) # (B, query_emb_num, d)
163
+ spk_query_emb, _ = self.query_attn(
164
+ spk_query_emb,
165
+ spk_emb.transpose(1, 2),
166
+ spk_emb.transpose(1, 2),
167
+ key_padding_mask=~(ref_mask.bool()),
168
+ ) # (B, query_emb_num, d)
169
+
170
+ prior_out = self.prior_encoder(
171
+ phone_id=phone_id,
172
+ duration=None,
173
+ pitch=None,
174
+ phone_mask=None,
175
+ mask=None,
176
+ ref_emb=spk_emb,
177
+ ref_mask=ref_mask,
178
+ is_inference=True,
179
+ )
180
+ prior_condition = prior_out["prior_out"] # (B, T, d)
181
+
182
+ z = torch.randn(
183
+ prior_condition.shape[0], self.latent_dim, prior_condition.shape[1]
184
+ ).to(ref_latent.device) / (1.20)
185
+ x0 = self.diffusion.reverse_diffusion(
186
+ z, None, prior_condition, inference_steps, spk_query_emb
187
+ )
188
+
189
+ return x0, prior_out
190
+
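A hedged sketch of how `inference` above is typically driven (the real entry point is `ns2_inference.py`, added later in this diff). The config path is a placeholder and the tensors are random stand-ins; shapes follow the in-code comments, with `ref_code` holding EnCodec indices of shape (B, n_q, T') and `phone_id` holding phone ids of shape (B, N):

```python
import torch
from models.tts.naturalspeech2.ns2 import NaturalSpeech2
from utils.util import load_config

cfg = load_config("path/to/exp_config.json")      # placeholder config path
model = NaturalSpeech2(cfg.model).eval()

ref_code = torch.randint(0, 1024, (1, 16, 150))   # ~2 s of EnCodec codes at 75 Hz
ref_mask = torch.ones(1, 150)
phone_id = torch.randint(0, 80, (1, 40))          # random phone ids, illustration only
with torch.no_grad():
    x0, prior_out = model.inference(ref_code, phone_id, ref_mask, inference_steps=200)
print(x0.shape)  # (1, latent_dim, total predicted frames)
```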
191
+ @torch.no_grad()
192
+ def reverse_diffusion_from_t(
193
+ self,
194
+ code=None,
195
+ pitch=None,
196
+ duration=None,
197
+ phone_id=None,
198
+ ref_code=None,
199
+ phone_mask=None,
200
+ mask=None,
201
+ ref_mask=None,
202
+ n_timesteps=None,
203
+ t=None,
204
+ ):
205
+ # o Only for debug
206
+
207
+ ref_latent = self.code_to_latent(ref_code)
208
+ latent = self.code_to_latent(code)
209
+
210
+ if self.latent_dim is not None:
211
+ ref_latent = self.prompt_lin(ref_latent.transpose(1, 2))
212
+
213
+ ref_latent = self.prompt_encoder(ref_latent, ref_mask, condition=None)
214
+ spk_emb = ref_latent.transpose(1, 2) # (B, d, T')
215
+
216
+ spk_query_emb = self.query_emb(
217
+ torch.arange(self.query_emb_num).to(latent.device)
218
+ ).repeat(
219
+ latent.shape[0], 1, 1
220
+ ) # (B, query_emb_num, d)
221
+ spk_query_emb, _ = self.query_attn(
222
+ spk_query_emb,
223
+ spk_emb.transpose(1, 2),
224
+ spk_emb.transpose(1, 2),
225
+ key_padding_mask=~(ref_mask.bool()),
226
+ ) # (B, query_emb_num, d)
227
+
228
+ prior_out = self.prior_encoder(
229
+ phone_id=phone_id,
230
+ duration=duration,
231
+ pitch=pitch,
232
+ phone_mask=phone_mask,
233
+ mask=mask,
234
+ ref_emb=spk_emb,
235
+ ref_mask=ref_mask,
236
+ is_inference=False,
237
+ )
238
+ prior_condition = prior_out["prior_out"] # (B, T, d)
239
+
240
+ diffusion_step = (
241
+ torch.ones(
242
+ latent.shape[0],
243
+ dtype=latent.dtype,
244
+ device=latent.device,
245
+ requires_grad=False,
246
+ )
247
+ * t
248
+ )
249
+ diffusion_step = torch.clamp(diffusion_step, 1e-5, 1.0 - 1e-5)
250
+ xt, _ = self.diffusion.forward_diffusion(
251
+ x0=latent, diffusion_step=diffusion_step
252
+ )
253
+ # print(torch.abs(xt-latent).max(), torch.abs(xt-latent).mean(), torch.abs(xt-latent).std())
254
+
255
+ x0 = self.diffusion.reverse_diffusion_from_t(
256
+ xt, mask, prior_condition, n_timesteps, spk_query_emb, t_start=t
257
+ )
258
+
259
+ return x0, prior_out, xt
Amphion/models/tts/naturalspeech2/ns2_dataset.py ADDED
@@ -0,0 +1,524 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import random
7
+ import torch
8
+ from torch.nn.utils.rnn import pad_sequence
9
+ from utils.data_utils import *
10
+ from processors.acoustic_extractor import cal_normalized_mel
11
+ from processors.acoustic_extractor import load_normalized
12
+ from models.base.base_dataset import (
13
+ BaseOfflineCollator,
14
+ BaseOfflineDataset,
15
+ BaseTestDataset,
16
+ BaseTestCollator,
17
+ )
18
+ from text import text_to_sequence
19
+ from text.cmudict import valid_symbols
20
+ from tqdm import tqdm
21
+ import pickle
22
+
23
+
24
+ class NS2Dataset(torch.utils.data.Dataset):
25
+ def __init__(self, cfg, dataset, is_valid=False):
26
+ assert isinstance(dataset, str)
27
+
28
+ processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
29
+
30
+ meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
31
+ # train.json
32
+
33
+ self.metafile_path = os.path.join(processed_data_dir, meta_file)
34
+
35
+ self.metadata = self.get_metadata()
36
+
37
+ self.cfg = cfg
38
+
39
+ assert cfg.preprocess.use_mel == False
40
+ if cfg.preprocess.use_mel:
41
+ self.utt2melspec_path = {}
42
+ for utt_info in self.metadata:
43
+ dataset = utt_info["Dataset"]
44
+ uid = utt_info["Uid"]
45
+ utt = "{}_{}".format(dataset, uid)
46
+
47
+ self.utt2melspec_path[utt] = os.path.join(
48
+ cfg.preprocess.processed_dir,
49
+ dataset,
50
+ cfg.preprocess.melspec_dir, # mel
51
+ utt_info["speaker"],
52
+ uid + ".npy",
53
+ )
54
+
55
+ assert cfg.preprocess.use_code == True
56
+ if cfg.preprocess.use_code:
57
+ self.utt2code_path = {}
58
+ for utt_info in self.metadata:
59
+ dataset = utt_info["Dataset"]
60
+ uid = utt_info["Uid"]
61
+ utt = "{}_{}".format(dataset, uid)
62
+
63
+ self.utt2code_path[utt] = os.path.join(
64
+ cfg.preprocess.processed_dir,
65
+ dataset,
66
+ cfg.preprocess.code_dir, # code
67
+ utt_info["speaker"],
68
+ uid + ".npy",
69
+ )
70
+
71
+ assert cfg.preprocess.use_spkid == True
72
+ if cfg.preprocess.use_spkid:
73
+ self.utt2spkid = {}
74
+ for utt_info in self.metadata:
75
+ dataset = utt_info["Dataset"]
76
+ uid = utt_info["Uid"]
77
+ utt = "{}_{}".format(dataset, uid)
78
+
79
+ self.utt2spkid[utt] = utt_info["speaker"]
80
+
81
+ assert cfg.preprocess.use_pitch == True
82
+ if cfg.preprocess.use_pitch:
83
+ self.utt2pitch_path = {}
84
+ for utt_info in self.metadata:
85
+ dataset = utt_info["Dataset"]
86
+ uid = utt_info["Uid"]
87
+ utt = "{}_{}".format(dataset, uid)
88
+
89
+ self.utt2pitch_path[utt] = os.path.join(
90
+ cfg.preprocess.processed_dir,
91
+ dataset,
92
+ cfg.preprocess.pitch_dir, # pitch
93
+ utt_info["speaker"],
94
+ uid + ".npy",
95
+ )
96
+
97
+ assert cfg.preprocess.use_duration == True
98
+ if cfg.preprocess.use_duration:
99
+ self.utt2duration_path = {}
100
+ for utt_info in self.metadata:
101
+ dataset = utt_info["Dataset"]
102
+ uid = utt_info["Uid"]
103
+ utt = "{}_{}".format(dataset, uid)
104
+
105
+ self.utt2duration_path[utt] = os.path.join(
106
+ cfg.preprocess.processed_dir,
107
+ dataset,
108
+ cfg.preprocess.duration_dir, # duration
109
+ utt_info["speaker"],
110
+ uid + ".npy",
111
+ )
112
+
113
+ assert cfg.preprocess.use_phone == True
114
+ if cfg.preprocess.use_phone:
115
+ self.utt2phone = {}
116
+ for utt_info in self.metadata:
117
+ dataset = utt_info["Dataset"]
118
+ uid = utt_info["Uid"]
119
+ utt = "{}_{}".format(dataset, uid)
120
+
121
+ self.utt2phone[utt] = utt_info["phones"]
122
+
123
+ assert cfg.preprocess.use_len == True
124
+ if cfg.preprocess.use_len:
125
+ self.utt2len = {}
126
+ for utt_info in self.metadata:
127
+ dataset = utt_info["Dataset"]
128
+ uid = utt_info["Uid"]
129
+ utt = "{}_{}".format(dataset, uid)
130
+
131
+ self.utt2len[utt] = utt_info["num_frames"]
132
+
133
+ # for cross reference
134
+ if cfg.preprocess.use_cross_reference:
135
+ self.spkid2utt = {}
136
+ for utt_info in self.metadata:
137
+ dataset = utt_info["Dataset"]
138
+ uid = utt_info["Uid"]
139
+ utt = "{}_{}".format(dataset, uid)
140
+ spkid = utt_info["speaker"]
141
+ if spkid not in self.spkid2utt:
142
+ self.spkid2utt[spkid] = []
143
+ self.spkid2utt[spkid].append(utt)
144
+
145
+ # get phone to id / id to phone map
146
+ self.phone2id, self.id2phone = self.get_phone_map()
147
+
148
+ self.all_num_frames = []
149
+ for i in range(len(self.metadata)):
150
+ self.all_num_frames.append(self.metadata[i]["num_frames"])
151
+ self.num_frame_sorted = np.array(sorted(self.all_num_frames))
152
+ self.num_frame_indices = np.array(
153
+ sorted(
154
+ range(len(self.all_num_frames)), key=lambda k: self.all_num_frames[k]
155
+ )
156
+ )
157
+
158
+ def __len__(self):
159
+ return len(self.metadata)
160
+
161
+ def get_dataset_name(self):
162
+ return self.metadata[0]["Dataset"]
163
+
164
+ def get_metadata(self):
165
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
166
+ metadata = json.load(f)
167
+
168
+ print("metadata len: ", len(metadata))
169
+
170
+ return metadata
171
+
172
+ def get_phone_map(self):
173
+ symbols = valid_symbols + ["sp", "spn", "sil"] + ["<s>", "</s>"]
174
+ phone2id = {s: i for i, s in enumerate(symbols)}
175
+ id2phone = {i: s for s, i in phone2id.items()}
176
+ return phone2id, id2phone
177
+
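Illustration of the phone map built by `get_phone_map`: ids are simply enumeration order over the CMUdict symbols plus the extra tokens. The symbol list below is a truncated stand-in for `text.cmudict.valid_symbols`, which is not shown in this diff:

```python
# Stand-in symbols; only the mapping mechanics are the point here.
valid_symbols = ["AA", "AE", "AH"]
symbols = valid_symbols + ["sp", "spn", "sil"] + ["<s>", "</s>"]
phone2id = {s: i for i, s in enumerate(symbols)}
id2phone = {i: s for s, i in phone2id.items()}

# Metadata stores phones as a brace-wrapped string, e.g. "{AH sp AA}"
phones = "{AH sp AA}".replace("{", "").replace("}", "").split()
print([phone2id[p] for p in phones])  # [2, 3, 0]
print(id2phone[5])                    # 'sil'
```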
178
+ def __getitem__(self, index):
179
+ utt_info = self.metadata[index]
180
+
181
+ dataset = utt_info["Dataset"]
182
+ uid = utt_info["Uid"]
183
+ utt = "{}_{}".format(dataset, uid)
184
+
185
+ single_feature = dict()
186
+
187
+ if self.cfg.preprocess.read_metadata:
188
+ metadata_uid_path = os.path.join(
189
+ self.cfg.preprocess.processed_dir,
190
+ self.cfg.preprocess.metadata_dir,
191
+ dataset,
192
+ # utt_info["speaker"],
193
+ uid + ".pkl",
194
+ )
195
+ with open(metadata_uid_path, "rb") as f:
196
+ metadata_uid = pickle.load(f)
197
+ # code
198
+ code = metadata_uid["code"]
199
+ # frame_nums
200
+ frame_nums = code.shape[1]
201
+ # pitch
202
+ pitch = metadata_uid["pitch"]
203
+ # duration
204
+ duration = metadata_uid["duration"]
205
+ # phone_id
206
+ phone_id = np.array(
207
+ [
208
+ *map(
209
+ self.phone2id.get,
210
+ self.utt2phone[utt].replace("{", "").replace("}", "").split(),
211
+ )
212
+ ]
213
+ )
214
+
215
+ else:
216
+ # code
217
+ code = np.load(self.utt2code_path[utt])
218
+ # frame_nums
219
+ frame_nums = code.shape[1]
220
+ # pitch
221
+ pitch = np.load(self.utt2pitch_path[utt])
222
+ # duration
223
+ duration = np.load(self.utt2duration_path[utt])
224
+ # phone_id
225
+ phone_id = np.array(
226
+ [
227
+ *map(
228
+ self.phone2id.get,
229
+ self.utt2phone[utt].replace("{", "").replace("}", "").split(),
230
+ )
231
+ ]
232
+ )
233
+
234
+ # align length
235
+ code, pitch, duration, phone_id, frame_nums = self.align_length(
236
+ code, pitch, duration, phone_id, frame_nums
237
+ )
238
+
239
+ # spkid
240
+ spkid = self.utt2spkid[utt]
241
+
242
+ # get target and reference
243
+ out = self.get_target_and_reference(code, pitch, duration, phone_id, frame_nums)
244
+ code, ref_code = out["code"], out["ref_code"]
245
+ pitch, ref_pitch = out["pitch"], out["ref_pitch"]
246
+ duration, ref_duration = out["duration"], out["ref_duration"]
247
+ phone_id, ref_phone_id = out["phone_id"], out["ref_phone_id"]
248
+ frame_nums, ref_frame_nums = out["frame_nums"], out["ref_frame_nums"]
249
+
250
+ # phone_id_frame
251
+ assert len(phone_id) == len(duration)
252
+ phone_id_frame = []
253
+ for i in range(len(phone_id)):
254
+ phone_id_frame.extend([phone_id[i] for _ in range(duration[i])])
255
+ phone_id_frame = np.array(phone_id_frame)
256
+
257
+ # ref_phone_id_frame
258
+ assert len(ref_phone_id) == len(ref_duration)
259
+ ref_phone_id_frame = []
260
+ for i in range(len(ref_phone_id)):
261
+ ref_phone_id_frame.extend([ref_phone_id[i] for _ in range(ref_duration[i])])
262
+ ref_phone_id_frame = np.array(ref_phone_id_frame)
263
+
264
+ single_feature.update(
265
+ {
266
+ "code": code,
267
+ "frame_nums": frame_nums,
268
+ "pitch": pitch,
269
+ "duration": duration,
270
+ "phone_id": phone_id,
271
+ "phone_id_frame": phone_id_frame,
272
+ "ref_code": ref_code,
273
+ "ref_frame_nums": ref_frame_nums,
274
+ "ref_pitch": ref_pitch,
275
+ "ref_duration": ref_duration,
276
+ "ref_phone_id": ref_phone_id,
277
+ "ref_phone_id_frame": ref_phone_id_frame,
278
+ "spkid": spkid,
279
+ }
280
+ )
281
+
282
+ return single_feature
283
+
284
+ def get_num_frames(self, index):
285
+ utt_info = self.metadata[index]
286
+ return utt_info["num_frames"]
287
+
288
+ def align_length(self, code, pitch, duration, phone_id, frame_nums):
289
+ # align lengths of code, pitch, duration, phone_id, and frame_nums
290
+ code_len = code.shape[1]
291
+ pitch_len = len(pitch)
292
+ dur_sum = sum(duration)
293
+ min_len = min(code_len, dur_sum)
294
+ code = code[:, :min_len]
295
+ if pitch_len >= min_len:
296
+ pitch = pitch[:min_len]
297
+ else:
298
+ pitch = np.pad(pitch, (0, min_len - pitch_len), mode="edge")
299
+ frame_nums = min_len
300
+ if dur_sum > min_len:
301
+ assert (duration[-1] - (dur_sum - min_len)) >= 0
302
+ duration[-1] = duration[-1] - (dur_sum - min_len)
303
+ assert duration[-1] >= 0
304
+
305
+ return code, pitch, duration, phone_id, frame_nums
306
+
307
+ def get_target_and_reference(self, code, pitch, duration, phone_id, frame_nums):
308
+ phone_nums = len(phone_id)
309
+ clip_phone_nums = np.random.randint(
310
+ int(phone_nums * 0.1), int(phone_nums * 0.5) + 1
311
+ )
312
+ clip_phone_nums = max(clip_phone_nums, 1)
313
+ assert clip_phone_nums < phone_nums and clip_phone_nums >= 1
314
+ if self.cfg.preprocess.clip_mode == "mid":
315
+ start_idx = np.random.randint(0, phone_nums - clip_phone_nums)
316
+ elif self.cfg.preprocess.clip_mode == "start":
317
+ if duration[0] == 0 and clip_phone_nums == 1:
318
+ start_idx = 1
319
+ else:
320
+ start_idx = 0
321
+ else:
322
+ assert self.cfg.preprocess.clip_mode in ["mid", "start"]
323
+ end_idx = start_idx + clip_phone_nums
324
+ start_frames = sum(duration[:start_idx])
325
+ end_frames = sum(duration[:end_idx])
326
+
327
+ new_code = np.concatenate(
328
+ (code[:, :start_frames], code[:, end_frames:]), axis=1
329
+ )
330
+ ref_code = code[:, start_frames:end_frames]
331
+
332
+ new_pitch = np.append(pitch[:start_frames], pitch[end_frames:])
333
+ ref_pitch = pitch[start_frames:end_frames]
334
+
335
+ new_duration = np.append(duration[:start_idx], duration[end_idx:])
336
+ ref_duration = duration[start_idx:end_idx]
337
+
338
+ new_phone_id = np.append(phone_id[:start_idx], phone_id[end_idx:])
339
+ ref_phone_id = phone_id[start_idx:end_idx]
340
+
341
+ new_frame_nums = frame_nums - (end_frames - start_frames)
342
+ ref_frame_nums = end_frames - start_frames
343
+
344
+ return {
345
+ "code": new_code,
346
+ "ref_code": ref_code,
347
+ "pitch": new_pitch,
348
+ "ref_pitch": ref_pitch,
349
+ "duration": new_duration,
350
+ "ref_duration": ref_duration,
351
+ "phone_id": new_phone_id,
352
+ "ref_phone_id": ref_phone_id,
353
+ "frame_nums": new_frame_nums,
354
+ "ref_frame_nums": ref_frame_nums,
355
+ }
356
+
357
+
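To make `get_target_and_reference` easier to follow, here is a toy walk-through (invented numbers) of cutting a contiguous run of phones out of an utterance to serve as the reference prompt, with the remainder becoming the training target:

```python
import numpy as np

# Toy example, not repository code.
duration = np.array([3, 2, 4, 1])      # frames per phone
phone_id = np.array([11, 12, 13, 14])
code = np.arange(10).reshape(1, 10)    # (n_q=1, T=10) toy codec indices

start_idx, end_idx = 1, 3              # clip phones 1..2 as the reference
start_frames = duration[:start_idx].sum()   # 3
end_frames = duration[:end_idx].sum()       # 9

ref_code = code[:, start_frames:end_frames]                          # frames 3..8
new_code = np.concatenate((code[:, :start_frames], code[:, end_frames:]), axis=1)
ref_phone_id = phone_id[start_idx:end_idx]                           # [12 13]
new_phone_id = np.append(phone_id[:start_idx], phone_id[end_idx:])   # [11 14]
print(ref_code, new_code, ref_phone_id, new_phone_id, sep="\n")
```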
358
+ class NS2Collator(BaseOfflineCollator):
359
+ def __init__(self, cfg):
360
+ BaseOfflineCollator.__init__(self, cfg)
361
+
362
+ def __call__(self, batch):
363
+ packed_batch_features = dict()
364
+
365
+ # code: (B, 16, T)
366
+ # frame_nums: (B,) not used
367
+ # pitch: (B, T)
368
+ # duration: (B, N)
369
+ # phone_id: (B, N)
370
+ # phone_id_frame: (B, T)
371
+ # ref_code: (B, 16, T')
372
+ # ref_frame_nums: (B,) not used
373
+ # ref_pitch: (B, T) not used
374
+ # ref_duration: (B, N') not used
375
+ # ref_phone_id: (B, N') not used
376
+ # ref_phone_frame: (B, T') not used
377
+ # spkid: (B,) not used
378
+ # phone_mask: (B, N)
379
+ # mask: (B, T)
380
+ # ref_mask: (B, T')
381
+
382
+ for key in batch[0].keys():
383
+ if key == "phone_id":
384
+ phone_ids = [torch.LongTensor(b["phone_id"]) for b in batch]
385
+ phone_masks = [torch.ones(len(b["phone_id"])) for b in batch]
386
+ packed_batch_features["phone_id"] = pad_sequence(
387
+ phone_ids,
388
+ batch_first=True,
389
+ padding_value=0,
390
+ )
391
+ packed_batch_features["phone_mask"] = pad_sequence(
392
+ phone_masks,
393
+ batch_first=True,
394
+ padding_value=0,
395
+ )
396
+ elif key == "phone_id_frame":
397
+ phone_id_frames = [torch.LongTensor(b["phone_id_frame"]) for b in batch]
398
+ masks = [torch.ones(len(b["phone_id_frame"])) for b in batch]
399
+ packed_batch_features["phone_id_frame"] = pad_sequence(
400
+ phone_id_frames,
401
+ batch_first=True,
402
+ padding_value=0,
403
+ )
404
+ packed_batch_features["mask"] = pad_sequence(
405
+ masks,
406
+ batch_first=True,
407
+ padding_value=0,
408
+ )
409
+ elif key == "ref_code":
410
+ ref_codes = [
411
+ torch.from_numpy(b["ref_code"]).transpose(0, 1) for b in batch
412
+ ]
413
+ ref_masks = [torch.ones(max(b["ref_code"].shape[1], 1)) for b in batch]
414
+ packed_batch_features["ref_code"] = pad_sequence(
415
+ ref_codes,
416
+ batch_first=True,
417
+ padding_value=0,
418
+ ).transpose(1, 2)
419
+ packed_batch_features["ref_mask"] = pad_sequence(
420
+ ref_masks,
421
+ batch_first=True,
422
+ padding_value=0,
423
+ )
424
+ elif key == "code":
425
+ codes = [torch.from_numpy(b["code"]).transpose(0, 1) for b in batch]
426
+ masks = [torch.ones(max(b["code"].shape[1], 1)) for b in batch]
427
+ packed_batch_features["code"] = pad_sequence(
428
+ codes,
429
+ batch_first=True,
430
+ padding_value=0,
431
+ ).transpose(1, 2)
432
+ packed_batch_features["mask"] = pad_sequence(
433
+ masks,
434
+ batch_first=True,
435
+ padding_value=0,
436
+ )
437
+ elif key == "pitch":
438
+ values = [torch.from_numpy(b[key]) for b in batch]
439
+ packed_batch_features[key] = pad_sequence(
440
+ values, batch_first=True, padding_value=50.0
441
+ )
442
+ elif key == "duration":
443
+ values = [torch.from_numpy(b[key]) for b in batch]
444
+ packed_batch_features[key] = pad_sequence(
445
+ values, batch_first=True, padding_value=0
446
+ )
447
+ elif key == "frame_nums":
448
+ packed_batch_features["frame_nums"] = torch.LongTensor(
449
+ [b["frame_nums"] for b in batch]
450
+ )
451
+ elif key == "ref_frame_nums":
452
+ packed_batch_features["ref_frame_nums"] = torch.LongTensor(
453
+ [b["ref_frame_nums"] for b in batch]
454
+ )
455
+ else:
456
+ pass
457
+
458
+ return packed_batch_features
459
+
460
+
461
+ def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
462
+ if len(batch) == 0:
463
+ return 0
464
+ if len(batch) == max_sentences:
465
+ return 1
466
+ if num_tokens > max_tokens:
467
+ return 1
468
+ return 0
469
+
470
+
471
+ def batch_by_size(
472
+ indices,
473
+ num_tokens_fn,
474
+ max_tokens=None,
475
+ max_sentences=None,
476
+ required_batch_size_multiple=1,
477
+ ):
478
+ """
479
+ Yield mini-batches of indices bucketed by size. Batches may contain
480
+ sequences of different lengths.
481
+
482
+ Args:
483
+ indices (List[int]): ordered list of dataset indices
484
+ num_tokens_fn (callable): function that returns the number of tokens at
485
+ a given index
486
+ max_tokens (int, optional): max number of tokens in each batch
487
+ (default: None).
488
+ max_sentences (int, optional): max number of sentences in each
489
+ batch (default: None).
490
+ required_batch_size_multiple (int, optional): require batch size to
491
+ be a multiple of N (default: 1).
492
+ """
493
+ bsz_mult = required_batch_size_multiple
494
+
495
+ sample_len = 0
496
+ sample_lens = []
497
+ batch = []
498
+ batches = []
499
+ for i in range(len(indices)):
500
+ idx = indices[i]
501
+ num_tokens = num_tokens_fn(idx)
502
+ sample_lens.append(num_tokens)
503
+ sample_len = max(sample_len, num_tokens)
504
+
505
+ assert (
506
+ sample_len <= max_tokens
507
+ ), "sentence at index {} of size {} exceeds max_tokens " "limit of {}!".format(
508
+ idx, sample_len, max_tokens
509
+ )
510
+ num_tokens = (len(batch) + 1) * sample_len
511
+
512
+ if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
513
+ mod_len = max(
514
+ bsz_mult * (len(batch) // bsz_mult),
515
+ len(batch) % bsz_mult,
516
+ )
517
+ batches.append(batch[:mod_len])
518
+ batch = batch[mod_len:]
519
+ sample_lens = sample_lens[mod_len:]
520
+ sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
521
+ batch.append(idx)
522
+ if len(batch) > 0:
523
+ batches.append(batch)
524
+ return batches
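A small usage example for `batch_by_size` (toy lengths; it assumes the module added in this diff is importable). Batches are cut when one more sample would push the padded token count past `max_tokens` or the sentence count past `max_sentences`:

```python
from models.tts.naturalspeech2.ns2_dataset import batch_by_size

lengths = [100, 120, 400, 90, 95, 380]   # toy frame counts per utterance
batches = batch_by_size(
    indices=list(range(len(lengths))),
    num_tokens_fn=lambda i: lengths[i],
    max_tokens=800,      # padded tokens per batch: len(batch) * longest sample
    max_sentences=4,
)
print(batches)           # [[0, 1], [2, 3], [4, 5]]
```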
Amphion/models/tts/naturalspeech2/ns2_inference.py ADDED
@@ -0,0 +1,128 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import os
8
+ import torch
9
+ import soundfile as sf
10
+ import numpy as np
11
+
12
+ from models.tts.naturalspeech2.ns2 import NaturalSpeech2
13
+ from encodec import EncodecModel
14
+ from encodec.utils import convert_audio
15
+ from utils.util import load_config
16
+
17
+ from text import text_to_sequence
18
+ from text.cmudict import valid_symbols
19
+ from text.g2p import preprocess_english, read_lexicon
20
+
21
+ import torchaudio
22
+
23
+
24
+ class NS2Inference:
25
+ def __init__(self, args, cfg):
26
+ self.cfg = cfg
27
+ self.args = args
28
+
29
+ self.model = self.build_model()
30
+ self.codec = self.build_codec()
31
+
32
+ self.symbols = valid_symbols + ["sp", "spn", "sil"] + ["<s>", "</s>"]
33
+ self.phone2id = {s: i for i, s in enumerate(self.symbols)}
34
+ self.id2phone = {i: s for s, i in self.phone2id.items()}
35
+
36
+ def build_model(self):
37
+ model = NaturalSpeech2(self.cfg.model)
38
+ model.load_state_dict(
39
+ torch.load(
40
+ os.path.join(self.args.checkpoint_path, "pytorch_model.bin"),
41
+ map_location="cpu",
42
+ )
43
+ )
44
+ model = model.to(self.args.device)
45
+ return model
46
+
47
+ def build_codec(self):
48
+ encodec_model = EncodecModel.encodec_model_24khz()
49
+ encodec_model = encodec_model.to(device=self.args.device)
50
+ encodec_model.set_target_bandwidth(12.0)
51
+ return encodec_model
52
+
53
+ def get_ref_code(self):
54
+ ref_wav_path = self.args.ref_audio
55
+ ref_wav, sr = torchaudio.load(ref_wav_path)
56
+ ref_wav = convert_audio(
57
+ ref_wav, sr, self.codec.sample_rate, self.codec.channels
58
+ )
59
+ ref_wav = ref_wav.unsqueeze(0).to(device=self.args.device)
60
+
61
+ with torch.no_grad():
62
+ encoded_frames = self.codec.encode(ref_wav)
63
+ ref_code = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)
64
+ # print(ref_code.shape)
65
+
66
+ ref_mask = torch.ones(ref_code.shape[0], ref_code.shape[-1]).to(ref_code.device)
67
+ # print(ref_mask.shape)
68
+
69
+ return ref_code, ref_mask
70
+
71
+ def inference(self):
72
+ ref_code, ref_mask = self.get_ref_code()
73
+
74
+ lexicon = read_lexicon(self.cfg.preprocess.lexicon_path)
75
+ phone_seq = preprocess_english(self.args.text, lexicon)
76
+ print(phone_seq)
77
+
78
+ phone_id = np.array(
79
+ [
80
+ *map(
81
+ self.phone2id.get,
82
+ phone_seq.replace("{", "").replace("}", "").split(),
83
+ )
84
+ ]
85
+ )
86
+ phone_id = torch.from_numpy(phone_id).unsqueeze(0).to(device=self.args.device)
87
+ print(phone_id)
88
+
89
+ x0, prior_out = self.model.inference(
90
+ ref_code, phone_id, ref_mask, self.args.inference_step
91
+ )
92
+ print(prior_out["dur_pred"])
93
+ print(prior_out["dur_pred_round"])
94
+ print(torch.sum(prior_out["dur_pred_round"]))
95
+
96
+ latent_ref = self.codec.quantizer.vq.decode(ref_code.transpose(0, 1))
97
+
98
+ rec_wav = self.codec.decoder(x0)
99
+ # ref_wav = self.codec.decoder(latent_ref)
100
+
101
+ os.makedirs(self.args.output_dir, exist_ok=True)
102
+
103
+ sf.write(
104
+ "{}/{}.wav".format(
105
+ self.args.output_dir, self.args.text.replace(" ", "_", 100)
106
+ ),
107
+ rec_wav[0, 0].detach().cpu().numpy(),
108
+ samplerate=24000,
109
+ )
110
+
111
+ def add_arguments(parser: argparse.ArgumentParser):
112
+ parser.add_argument(
113
+ "--ref_audio",
114
+ type=str,
115
+ default="",
116
+ help="Reference audio path",
117
+ )
118
+ parser.add_argument(
119
+ "--device",
120
+ type=str,
121
+ default="cuda",
122
+ )
123
+ parser.add_argument(
124
+ "--inference_step",
125
+ type=int,
126
+ default=200,
127
+ help="Total inference steps for the diffusion model",
128
+ )
Amphion/models/tts/naturalspeech2/ns2_trainer.py ADDED
@@ -0,0 +1,798 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import shutil
8
+ import json
9
+ import time
10
+ import torch
11
+ import numpy as np
12
+ from utils.util import Logger, ValueWindow
13
+ from torch.utils.data import ConcatDataset, DataLoader
14
+ from models.tts.base.tts_trainer import TTSTrainer
15
+ from models.base.base_trainer import BaseTrainer
16
+ from models.base.base_sampler import VariableSampler
17
+ from models.tts.naturalspeech2.ns2_dataset import NS2Dataset, NS2Collator, batch_by_size
18
+ from models.tts.naturalspeech2.ns2_loss import (
19
+ log_pitch_loss,
20
+ log_dur_loss,
21
+ diff_loss,
22
+ diff_ce_loss,
23
+ )
24
+ from torch.utils.data.sampler import BatchSampler, SequentialSampler
25
+ from models.tts.naturalspeech2.ns2 import NaturalSpeech2
26
+ from torch.optim import Adam, AdamW
27
+ from torch.nn import MSELoss, L1Loss
28
+ import torch.nn.functional as F
29
+ from diffusers import get_scheduler
30
+
31
+ import accelerate
32
+ from accelerate.logging import get_logger
33
+ from accelerate.utils import ProjectConfiguration
34
+
35
+
36
+ class NS2Trainer(TTSTrainer):
37
+ def __init__(self, args, cfg):
38
+ self.args = args
39
+ self.cfg = cfg
40
+
41
+ cfg.exp_name = args.exp_name
42
+
43
+ self._init_accelerator()
44
+ self.accelerator.wait_for_everyone()
45
+
46
+ # Init logger
47
+ with self.accelerator.main_process_first():
48
+ if self.accelerator.is_main_process:
49
+ os.makedirs(os.path.join(self.exp_dir, "checkpoint"), exist_ok=True)
50
+ self.log_file = os.path.join(
51
+ os.path.join(self.exp_dir, "checkpoint"), "train.log"
52
+ )
53
+ self.logger = Logger(self.log_file, level=self.args.log_level).logger
54
+
55
+ self.time_window = ValueWindow(50)
56
+
57
+ if self.accelerator.is_main_process:
58
+ # Log some info
59
+ self.logger.info("=" * 56)
60
+ self.logger.info("||\t\t" + "New training process started." + "\t\t||")
61
+ self.logger.info("=" * 56)
62
+ self.logger.info("\n")
63
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
64
+ self.logger.info(f"Experiment name: {args.exp_name}")
65
+ self.logger.info(f"Experiment directory: {self.exp_dir}")
66
+
67
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
68
+ if self.accelerator.is_main_process:
69
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
70
+
71
+ if self.accelerator.is_main_process:
72
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
73
+
74
+ # init counts
75
+ self.batch_count: int = 0
76
+ self.step: int = 0
77
+ self.epoch: int = 0
78
+ self.max_epoch = (
79
+ self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
80
+ )
81
+ if self.accelerator.is_main_process:
82
+ self.logger.info(
83
+ "Max epoch: {}".format(
84
+ self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
85
+ )
86
+ )
87
+
88
+ # Check values
89
+ if self.accelerator.is_main_process:
90
+ self._check_basic_configs()
91
+ # Set runtime configs
92
+ self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
93
+ self.checkpoints_path = [
94
+ [] for _ in range(len(self.save_checkpoint_stride))
95
+ ]
96
+ self.keep_last = [
97
+ i if i > 0 else float("inf") for i in self.cfg.train.keep_last
98
+ ]
99
+ self.run_eval = self.cfg.train.run_eval
100
+
101
+ # set random seed
102
+ with self.accelerator.main_process_first():
103
+ start = time.monotonic_ns()
104
+ self._set_random_seed(self.cfg.train.random_seed)
105
+ end = time.monotonic_ns()
106
+ if self.accelerator.is_main_process:
107
+ self.logger.debug(
108
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
109
+ )
110
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
111
+
112
+ # setup data_loader
113
+ with self.accelerator.main_process_first():
114
+ if self.accelerator.is_main_process:
115
+ self.logger.info("Building dataset...")
116
+ start = time.monotonic_ns()
117
+ self.train_dataloader, self.valid_dataloader = self._build_dataloader()
118
+ end = time.monotonic_ns()
119
+ if self.accelerator.is_main_process:
120
+ self.logger.info(
121
+ f"Building dataset done in {(end - start) / 1e6:.2f}ms"
122
+ )
123
+
124
+ # setup model
125
+ with self.accelerator.main_process_first():
126
+ if self.accelerator.is_main_process:
127
+ self.logger.info("Building model...")
128
+ start = time.monotonic_ns()
129
+ self.model = self._build_model()
130
+ end = time.monotonic_ns()
131
+ if self.accelerator.is_main_process:
132
+ self.logger.debug(self.model)
133
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
134
+ self.logger.info(
135
+ f"Model parameters: {self._count_parameters(self.model)/1e6:.2f}M"
136
+ )
137
+
138
+ # optimizer & scheduler
139
+ with self.accelerator.main_process_first():
140
+ if self.accelerator.is_main_process:
141
+ self.logger.info("Building optimizer and scheduler...")
142
+ start = time.monotonic_ns()
143
+ self.optimizer = self._build_optimizer()
144
+ self.scheduler = self._build_scheduler()
145
+ end = time.monotonic_ns()
146
+ if self.accelerator.is_main_process:
147
+ self.logger.info(
148
+ f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
149
+ )
150
+
151
+ # accelerate prepare
152
+ if not self.cfg.train.use_dynamic_batchsize:
153
+ if self.accelerator.is_main_process:
154
+ self.logger.info("Initializing accelerate...")
155
+ start = time.monotonic_ns()
156
+ (
157
+ self.train_dataloader,
158
+ self.valid_dataloader,
159
+ ) = self.accelerator.prepare(
160
+ self.train_dataloader,
161
+ self.valid_dataloader,
162
+ )
163
+
164
+ if isinstance(self.model, dict):
165
+ for key in self.model.keys():
166
+ self.model[key] = self.accelerator.prepare(self.model[key])
167
+ else:
168
+ self.model = self.accelerator.prepare(self.model)
169
+
170
+ if isinstance(self.optimizer, dict):
171
+ for key in self.optimizer.keys():
172
+ self.optimizer[key] = self.accelerator.prepare(self.optimizer[key])
173
+ else:
174
+ self.optimizer = self.accelerator.prepare(self.optimizer)
175
+
176
+ if isinstance(self.scheduler, dict):
177
+ for key in self.scheduler.keys():
178
+ self.scheduler[key] = self.accelerator.prepare(self.scheduler[key])
179
+ else:
180
+ self.scheduler = self.accelerator.prepare(self.scheduler)
181
+
182
+ end = time.monotonic_ns()
183
+ if self.accelerator.is_main_process:
184
+ self.logger.info(
185
+ f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms"
186
+ )
187
+
188
+ # create criterion
189
+ with self.accelerator.main_process_first():
190
+ if self.accelerator.is_main_process:
191
+ self.logger.info("Building criterion...")
192
+ start = time.monotonic_ns()
193
+ self.criterion = self._build_criterion()
194
+ end = time.monotonic_ns()
195
+ if self.accelerator.is_main_process:
196
+ self.logger.info(
197
+ f"Building criterion done in {(end - start) / 1e6:.2f}ms"
198
+ )
199
+
200
+ # TODO: Resume from ckpt need test/debug
201
+ with self.accelerator.main_process_first():
202
+ if args.resume:
203
+ if self.accelerator.is_main_process:
204
+ self.logger.info("Resuming from checkpoint...")
205
+ start = time.monotonic_ns()
206
+ ckpt_path = self._load_model(
207
+ self.checkpoint_dir,
208
+ args.checkpoint_path,
209
+ resume_type=args.resume_type,
210
+ )
211
+ end = time.monotonic_ns()
212
+ if self.accelerator.is_main_process:
213
+ self.logger.info(
214
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
215
+ )
216
+ self.checkpoints_path = json.load(
217
+ open(os.path.join(ckpt_path, "ckpts.json"), "r")
218
+ )
219
+
220
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
221
+ if self.accelerator.is_main_process:
222
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
223
+ if self.accelerator.is_main_process:
224
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
225
+
226
+ # save config file path
227
+ self.config_save_path = os.path.join(self.exp_dir, "args.json")
228
+
229
+ # Only for TTS tasks
230
+ self.task_type = "TTS"
231
+ if self.accelerator.is_main_process:
232
+ self.logger.info("Task type: {}".format(self.task_type))
233
+
234
+ def _init_accelerator(self):
235
+ self.exp_dir = os.path.join(
236
+ os.path.abspath(self.cfg.log_dir), self.args.exp_name
237
+ )
238
+ project_config = ProjectConfiguration(
239
+ project_dir=self.exp_dir,
240
+ logging_dir=os.path.join(self.exp_dir, "log"),
241
+ )
242
+ # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
243
+ self.accelerator = accelerate.Accelerator(
244
+ gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
245
+ log_with=self.cfg.train.tracker,
246
+ project_config=project_config,
247
+ # kwargs_handlers=[ddp_kwargs]
248
+ )
249
+ if self.accelerator.is_main_process:
250
+ os.makedirs(project_config.project_dir, exist_ok=True)
251
+ os.makedirs(project_config.logging_dir, exist_ok=True)
252
+ with self.accelerator.main_process_first():
253
+ self.accelerator.init_trackers(self.args.exp_name)
254
+
255
+ def _build_model(self):
256
+ model = NaturalSpeech2(cfg=self.cfg.model)
257
+ return model
258
+
259
+ def _build_dataset(self):
260
+ return NS2Dataset, NS2Collator
261
+
262
+ def _build_dataloader(self):
263
+ if self.cfg.train.use_dynamic_batchsize:
264
+ print("Use Dynamic Batchsize......")
265
+ Dataset, Collator = self._build_dataset()
266
+ train_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=False)
267
+ train_collate = Collator(self.cfg)
268
+ batch_sampler = batch_by_size(
269
+ train_dataset.num_frame_indices,
270
+ train_dataset.get_num_frames,
271
+ max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
272
+ max_sentences=self.cfg.train.max_sentences
273
+ * self.accelerator.num_processes,
274
+ required_batch_size_multiple=self.accelerator.num_processes,
275
+ )
276
+ np.random.seed(980205)
277
+ np.random.shuffle(batch_sampler)
278
+ print(batch_sampler[:1])
279
+ batches = [
280
+ x[
281
+ self.accelerator.local_process_index :: self.accelerator.num_processes
282
+ ]
283
+ for x in batch_sampler
284
+ if len(x) % self.accelerator.num_processes == 0
285
+ ]
286
+
287
+ train_loader = DataLoader(
288
+ train_dataset,
289
+ collate_fn=train_collate,
290
+ num_workers=self.cfg.train.dataloader.num_worker,
291
+ batch_sampler=VariableSampler(
292
+ batches, drop_last=False, use_random_sampler=True
293
+ ),
294
+ pin_memory=self.cfg.train.dataloader.pin_memory,
295
+ )
296
+ self.accelerator.wait_for_everyone()
297
+
298
+ valid_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=True)
299
+ valid_collate = Collator(self.cfg)
300
+ batch_sampler = batch_by_size(
301
+ valid_dataset.num_frame_indices,
302
+ valid_dataset.get_num_frames,
303
+ max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
304
+ max_sentences=self.cfg.train.max_sentences
305
+ * self.accelerator.num_processes,
306
+ required_batch_size_multiple=self.accelerator.num_processes,
307
+ )
308
+ batches = [
309
+ x[
310
+ self.accelerator.local_process_index :: self.accelerator.num_processes
311
+ ]
312
+ for x in batch_sampler
313
+ if len(x) % self.accelerator.num_processes == 0
314
+ ]
315
+ valid_loader = DataLoader(
316
+ valid_dataset,
317
+ collate_fn=valid_collate,
318
+ num_workers=self.cfg.train.dataloader.num_worker,
319
+ batch_sampler=VariableSampler(batches, drop_last=False),
320
+ pin_memory=self.cfg.train.dataloader.pin_memory,
321
+ )
322
+ self.accelerator.wait_for_everyone()
323
+
324
+ else:
325
+ print("Use Normal Batchsize......")
326
+ Dataset, Collator = self._build_dataset()
327
+ train_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=False)
328
+ train_collate = Collator(self.cfg)
329
+
330
+ train_loader = DataLoader(
331
+ train_dataset,
332
+ shuffle=True,
333
+ collate_fn=train_collate,
334
+ batch_size=self.cfg.train.batch_size,
335
+ num_workers=self.cfg.train.dataloader.num_worker,
336
+ pin_memory=self.cfg.train.dataloader.pin_memory,
337
+ )
338
+
339
+ valid_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=True)
340
+ valid_collate = Collator(self.cfg)
341
+
342
+ valid_loader = DataLoader(
343
+ valid_dataset,
344
+ shuffle=True,
345
+ collate_fn=valid_collate,
346
+ batch_size=self.cfg.train.batch_size,
347
+ num_workers=self.cfg.train.dataloader.num_worker,
348
+ pin_memory=self.cfg.train.dataloader.pin_memory,
349
+ )
350
+ self.accelerator.wait_for_everyone()
351
+
352
+ return train_loader, valid_loader
353
+
354
+ def _build_optimizer(self):
355
+ optimizer = torch.optim.AdamW(
356
+ filter(lambda p: p.requires_grad, self.model.parameters()),
357
+ **self.cfg.train.adam,
358
+ )
359
+ return optimizer
360
+
361
+ def _build_scheduler(self):
362
+ lr_scheduler = get_scheduler(
363
+ self.cfg.train.lr_scheduler,
364
+ optimizer=self.optimizer,
365
+ num_warmup_steps=self.cfg.train.lr_warmup_steps,
366
+ num_training_steps=self.cfg.train.num_train_steps,
367
+ )
368
+ return lr_scheduler
369
+
370
+ def _build_criterion(self):
371
+ criterion = torch.nn.L1Loss(reduction="mean")
372
+ return criterion
373
+
374
+ def write_summary(self, losses, stats):
375
+ for key, value in losses.items():
376
+ self.sw.add_scalar(key, value, self.step)
377
+
378
+ def write_valid_summary(self, losses, stats):
379
+ for key, value in losses.items():
380
+ self.sw.add_scalar(key, value, self.step)
381
+
382
+ def get_state_dict(self):
383
+ state_dict = {
384
+ "model": self.model.state_dict(),
385
+ "optimizer": self.optimizer.state_dict(),
386
+ "scheduler": self.scheduler.state_dict(),
387
+ "step": self.step,
388
+ "epoch": self.epoch,
389
+ "batch_size": self.cfg.train.batch_size,
390
+ }
391
+ return state_dict
392
+
393
+ def load_model(self, checkpoint):
394
+ self.step = checkpoint["step"]
395
+ self.epoch = checkpoint["epoch"]
396
+
397
+ self.model.load_state_dict(checkpoint["model"])
398
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
399
+ self.scheduler.load_state_dict(checkpoint["scheduler"])
400
+
401
+ def _train_step(self, batch):
402
+ train_losses = {}
403
+ total_loss = 0
404
+ train_stats = {}
405
+
406
+ code = batch["code"] # (B, 16, T)
407
+ pitch = batch["pitch"] # (B, T)
408
+ duration = batch["duration"] # (B, N)
409
+ phone_id = batch["phone_id"] # (B, N)
410
+ ref_code = batch["ref_code"] # (B, 16, T')
411
+ phone_mask = batch["phone_mask"] # (B, N)
412
+ mask = batch["mask"] # (B, T)
413
+ ref_mask = batch["ref_mask"] # (B, T')
414
+
415
+ diff_out, prior_out = self.model(
416
+ code=code,
417
+ pitch=pitch,
418
+ duration=duration,
419
+ phone_id=phone_id,
420
+ ref_code=ref_code,
421
+ phone_mask=phone_mask,
422
+ mask=mask,
423
+ ref_mask=ref_mask,
424
+ )
425
+
426
+ # pitch loss
427
+ pitch_loss = log_pitch_loss(prior_out["pitch_pred_log"], pitch, mask=mask)
428
+ total_loss += pitch_loss
429
+ train_losses["pitch_loss"] = pitch_loss
430
+
431
+ # duration loss
432
+ dur_loss = log_dur_loss(prior_out["dur_pred_log"], duration, mask=phone_mask)
433
+ total_loss += dur_loss
434
+ train_losses["dur_loss"] = dur_loss
435
+
436
+ x0 = self.model.module.code_to_latent(code)
437
+ if self.cfg.model.diffusion.diffusion_type == "diffusion":
438
+ # diff loss x0
439
+ diff_loss_x0 = diff_loss(diff_out["x0_pred"], x0, mask=mask)
440
+ total_loss += diff_loss_x0
441
+ train_losses["diff_loss_x0"] = diff_loss_x0
442
+
443
+ # diff loss noise
444
+ diff_loss_noise = diff_loss(
445
+ diff_out["noise_pred"], diff_out["noise"], mask=mask
446
+ )
447
+ total_loss += diff_loss_noise * self.cfg.train.diff_noise_loss_lambda
448
+ train_losses["diff_loss_noise"] = diff_loss_noise
449
+
450
+ elif self.cfg.model.diffusion.diffusion_type == "flow":
451
+ # diff flow matching loss
452
+ flow_gt = diff_out["noise"] - x0
453
+ diff_loss_flow = diff_loss(diff_out["flow_pred"], flow_gt, mask=mask)
454
+ total_loss += diff_loss_flow
455
+ train_losses["diff_loss_flow"] = diff_loss_flow
456
+
457
+ # diff loss ce
458
+
459
+ # (nq, B, T); (nq, B, T, 1024)
460
+ if self.cfg.train.diff_ce_loss_lambda > 0:
461
+ pred_indices, pred_dist = self.model.module.latent_to_code(
462
+ diff_out["x0_pred"], nq=code.shape[1]
463
+ )
464
+ gt_indices, _ = self.model.module.latent_to_code(x0, nq=code.shape[1])
465
+ diff_loss_ce = diff_ce_loss(pred_dist, gt_indices, mask=mask)
466
+ total_loss += diff_loss_ce * self.cfg.train.diff_ce_loss_lambda
467
+ train_losses["diff_loss_ce"] = diff_loss_ce
468
+
469
+ self.optimizer.zero_grad()
470
+ # total_loss.backward()
471
+ self.accelerator.backward(total_loss)
472
+ if self.accelerator.sync_gradients:
473
+ self.accelerator.clip_grad_norm_(
474
+ filter(lambda p: p.requires_grad, self.model.parameters()), 0.5
475
+ )
476
+ self.optimizer.step()
477
+ self.scheduler.step()
478
+
479
+ for item in train_losses:
480
+ train_losses[item] = train_losses[item].item()
481
+
482
+ if self.cfg.train.diff_ce_loss_lambda > 0:
483
+ pred_indices_list = pred_indices.long().detach().cpu().numpy()
484
+ gt_indices_list = gt_indices.long().detach().cpu().numpy()
485
+ mask_list = batch["mask"].detach().cpu().numpy()
486
+
487
+ for i in range(pred_indices_list.shape[0]):
488
+ pred_acc = np.sum(
489
+ (pred_indices_list[i] == gt_indices_list[i]) * mask_list
490
+ ) / np.sum(mask_list)
491
+ train_losses["pred_acc_{}".format(str(i))] = pred_acc
492
+
493
+ train_losses["batch_size"] = code.shape[0]
494
+ train_losses["max_frame_nums"] = np.max(
495
+ batch["frame_nums"].detach().cpu().numpy()
496
+ )
497
+
498
+ return (total_loss.item(), train_losses, train_stats)
499
+
500
+ @torch.inference_mode()
501
+ def _valid_step(self, batch):
502
+ valid_losses = {}
503
+ total_loss = 0
504
+ valid_stats = {}
505
+
506
+ code = batch["code"] # (B, 16, T)
507
+ pitch = batch["pitch"] # (B, T)
508
+ duration = batch["duration"] # (B, N)
509
+ phone_id = batch["phone_id"] # (B, N)
510
+ ref_code = batch["ref_code"] # (B, 16, T')
511
+ phone_mask = batch["phone_mask"] # (B, N)
512
+ mask = batch["mask"] # (B, T)
513
+ ref_mask = batch["ref_mask"] # (B, T')
514
+
515
+ diff_out, prior_out = self.model(
516
+ code=code,
517
+ pitch=pitch,
518
+ duration=duration,
519
+ phone_id=phone_id,
520
+ ref_code=ref_code,
521
+ phone_mask=phone_mask,
522
+ mask=mask,
523
+ ref_mask=ref_mask,
524
+ )
525
+
526
+ # pitch loss
527
+ pitch_loss = log_pitch_loss(prior_out["pitch_pred_log"], pitch, mask=mask)
528
+ total_loss += pitch_loss
529
+ valid_losses["pitch_loss"] = pitch_loss
530
+
531
+ # duration loss
532
+ dur_loss = log_dur_loss(prior_out["dur_pred_log"], duration, mask=phone_mask)
533
+ total_loss += dur_loss
534
+ valid_losses["dur_loss"] = dur_loss
535
+
536
+ x0 = self.model.module.code_to_latent(code)
537
+ if self.cfg.model.diffusion.diffusion_type == "diffusion":
538
+ # diff loss x0
539
+ diff_loss_x0 = diff_loss(diff_out["x0_pred"], x0, mask=mask)
540
+ total_loss += diff_loss_x0
541
+ valid_losses["diff_loss_x0"] = diff_loss_x0
542
+
543
+ # diff loss noise
544
+ diff_loss_noise = diff_loss(
545
+ diff_out["noise_pred"], diff_out["noise"], mask=mask
546
+ )
547
+ total_loss += diff_loss_noise * self.cfg.train.diff_noise_loss_lambda
548
+ valid_losses["diff_loss_noise"] = diff_loss_noise
549
+
550
+ elif self.cfg.model.diffusion.diffusion_type == "flow":
551
+ # diff flow matching loss
552
+ flow_gt = diff_out["noise"] - x0
553
+ diff_loss_flow = diff_loss(diff_out["flow_pred"], flow_gt, mask=mask)
554
+ total_loss += diff_loss_flow
555
+ valid_losses["diff_loss_flow"] = diff_loss_flow
556
+
557
+ # diff loss ce
558
+
559
+ # (nq, B, T); (nq, B, T, 1024)
560
+ if self.cfg.train.diff_ce_loss_lambda > 0:
561
+ pred_indices, pred_dist = self.model.module.latent_to_code(
562
+ diff_out["x0_pred"], nq=code.shape[1]
563
+ )
564
+ gt_indices, _ = self.model.module.latent_to_code(x0, nq=code.shape[1])
565
+ diff_loss_ce = diff_ce_loss(pred_dist, gt_indices, mask=mask)
566
+ total_loss += diff_loss_ce * self.cfg.train.diff_ce_loss_lambda
567
+ valid_losses["diff_loss_ce"] = diff_loss_ce
568
+
569
+ for item in valid_losses:
570
+ valid_losses[item] = valid_losses[item].item()
571
+
572
+ if self.cfg.train.diff_ce_loss_lambda > 0:
573
+ pred_indices_list = pred_indices.long().detach().cpu().numpy()
574
+ gt_indices_list = gt_indices.long().detach().cpu().numpy()
575
+ mask_list = batch["mask"].detach().cpu().numpy()
576
+
577
+ for i in range(pred_indices_list.shape[0]):
578
+ pred_acc = np.sum(
579
+ (pred_indices_list[i] == gt_indices_list[i]) * mask_list
580
+ ) / np.sum(mask_list)
581
+ valid_losses["pred_acc_{}".format(str(i))] = pred_acc
582
+
583
+ return (total_loss.item(), valid_losses, valid_stats)
584
+
585
+ @torch.inference_mode()
586
+ def _valid_epoch(self):
587
+ r"""Testing epoch. Should return average loss of a batch (sample) over
588
+ one epoch. See ``train_loop`` for usage.
589
+ """
590
+ if isinstance(self.model, dict):
591
+ for key in self.model.keys():
592
+ self.model[key].eval()
593
+ else:
594
+ self.model.eval()
595
+
596
+ epoch_sum_loss = 0.0
597
+ epoch_losses = dict()
598
+
599
+ for batch in self.valid_dataloader:
600
+ # Put the data to cuda device
601
+ device = self.accelerator.device
602
+ for k, v in batch.items():
603
+ if isinstance(v, torch.Tensor):
604
+ batch[k] = v.to(device)
605
+
606
+ total_loss, valid_losses, valid_stats = self._valid_step(batch)
607
+ epoch_sum_loss = total_loss
608
+ for key, value in valid_losses.items():
609
+ epoch_losses[key] = value
610
+
611
+ self.accelerator.wait_for_everyone()
612
+
613
+ return epoch_sum_loss, epoch_losses
614
+
615
+ def _train_epoch(self):
616
+ r"""Training epoch. Should return average loss of a batch (sample) over
617
+ one epoch. See ``train_loop`` for usage.
618
+ """
619
+ if isinstance(self.model, dict):
620
+ for key in self.model.keys():
621
+ self.model[key].train()
622
+ else:
623
+ self.model.train()
624
+
625
+ epoch_sum_loss: float = 0.0
626
+ epoch_losses: dict = {}
627
+ epoch_step: int = 0
628
+
629
+ for batch in self.train_dataloader:
630
+ # Put the data to cuda device
631
+ device = self.accelerator.device
632
+ for k, v in batch.items():
633
+ if isinstance(v, torch.Tensor):
634
+ batch[k] = v.to(device)
635
+
636
+ # Do training step and BP
637
+ with self.accelerator.accumulate(self.model):
638
+ total_loss, train_losses, training_stats = self._train_step(batch)
639
+ self.batch_count += 1
640
+
641
+ # Update info for each step
642
+ # TODO: step means BP counts or batch counts?
643
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
644
+ epoch_sum_loss = total_loss
645
+ for key, value in train_losses.items():
646
+ epoch_losses[key] = value
647
+
648
+ if isinstance(train_losses, dict):
649
+ for key, loss in train_losses.items():
650
+ self.accelerator.log(
651
+ {"Epoch/Train {} Loss".format(key): loss},
652
+ step=self.step,
653
+ )
654
+
655
+ if (
656
+ self.accelerator.is_main_process
657
+ and self.batch_count
658
+ % (1 * self.cfg.train.gradient_accumulation_step)
659
+ == 0
660
+ ):
661
+ self.echo_log(train_losses, mode="Training")
662
+
663
+ self.step += 1
664
+ epoch_step += 1
665
+
666
+ self.accelerator.wait_for_everyone()
667
+
668
+ return epoch_sum_loss, epoch_losses
669
+
670
+ def train_loop(self):
671
+ r"""Training loop. The public entry of training process."""
672
+ # Wait everyone to prepare before we move on
673
+ self.accelerator.wait_for_everyone()
674
+ # dump config file
675
+ if self.accelerator.is_main_process:
676
+ self._dump_cfg(self.config_save_path)
677
+
678
+ # self.optimizer.zero_grad()
679
+
680
+ # Wait to ensure good to go
681
+ self.accelerator.wait_for_everyone()
682
+ while self.epoch < self.max_epoch:
683
+ if self.accelerator.is_main_process:
684
+ self.logger.info("\n")
685
+ self.logger.info("-" * 32)
686
+ self.logger.info("Epoch {}: ".format(self.epoch))
687
+
688
+ # Do training & validating epoch
689
+ train_total_loss, train_losses = self._train_epoch()
690
+ if isinstance(train_losses, dict):
691
+ for key, loss in train_losses.items():
692
+ if self.accelerator.is_main_process:
693
+ self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss))
694
+ self.accelerator.log(
695
+ {"Epoch/Train {} Loss".format(key): loss},
696
+ step=self.epoch,
697
+ )
698
+
699
+ valid_total_loss, valid_losses = self._valid_epoch()
700
+ if isinstance(valid_losses, dict):
701
+ for key, loss in valid_losses.items():
702
+ if self.accelerator.is_main_process:
703
+ self.logger.info(" |- Valid/{} Loss: {:.6f}".format(key, loss))
704
+ self.accelerator.log(
705
+ {"Epoch/Train {} Loss".format(key): loss},
706
+ step=self.epoch,
707
+ )
708
+
709
+ if self.accelerator.is_main_process:
710
+ self.logger.info(" |- Train/Loss: {:.6f}".format(train_total_loss))
711
+ self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_total_loss))
712
+ self.accelerator.log(
713
+ {
714
+ "Epoch/Train Loss": train_total_loss,
715
+ "Epoch/Valid Loss": valid_total_loss,
716
+ },
717
+ step=self.epoch,
718
+ )
719
+
720
+ self.accelerator.wait_for_everyone()
721
+ if isinstance(self.scheduler, dict):
722
+ for key in self.scheduler.keys():
723
+ self.scheduler[key].step()
724
+ else:
725
+ self.scheduler.step()
726
+
727
+ # Check if hit save_checkpoint_stride and run_eval
728
+ run_eval = False
729
+ if self.accelerator.is_main_process:
730
+ save_checkpoint = False
731
+ hit_dix = []
732
+ for i, num in enumerate(self.save_checkpoint_stride):
733
+ if self.epoch % num == 0:
734
+ save_checkpoint = True
735
+ hit_dix.append(i)
736
+ run_eval |= self.run_eval[i]
737
+
738
+ self.accelerator.wait_for_everyone()
739
+ if self.accelerator.is_main_process and save_checkpoint:
740
+ path = os.path.join(
741
+ self.checkpoint_dir,
742
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
743
+ self.epoch, self.step, train_total_loss
744
+ ),
745
+ )
746
+ print("save state......")
747
+ self.accelerator.save_state(path)
748
+ print("finish saving state......")
749
+ json.dump(
750
+ self.checkpoints_path,
751
+ open(os.path.join(path, "ckpts.json"), "w"),
752
+ ensure_ascii=False,
753
+ indent=4,
754
+ )
755
+ # Remove old checkpoints
756
+ to_remove = []
757
+ for idx in hit_dix:
758
+ self.checkpoints_path[idx].append(path)
759
+ while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
760
+ to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
761
+
762
+ # Search conflicts
763
+ total = set()
764
+ for i in self.checkpoints_path:
765
+ total |= set(i)
766
+ do_remove = set()
767
+ for idx, path in to_remove[::-1]:
768
+ if path in total:
769
+ self.checkpoints_path[idx].insert(0, path)
770
+ else:
771
+ do_remove.add(path)
772
+
773
+ # Remove old checkpoints
774
+ for path in do_remove:
775
+ shutil.rmtree(path, ignore_errors=True)
776
+ if self.accelerator.is_main_process:
777
+ self.logger.debug(f"Remove old checkpoint: {path}")
778
+
779
+ self.accelerator.wait_for_everyone()
780
+ if run_eval:
781
+ # TODO: run evaluation
782
+ pass
783
+
784
+ # Update info for each epoch
785
+ self.epoch += 1
786
+
787
+ # Finish training and save final checkpoint
788
+ self.accelerator.wait_for_everyone()
789
+ if self.accelerator.is_main_process:
790
+ self.accelerator.save_state(
791
+ os.path.join(
792
+ self.checkpoint_dir,
793
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
794
+ self.epoch, self.step, valid_total_loss
795
+ ),
796
+ )
797
+ )
798
+ self.accelerator.end_training()
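
Note on the dynamic-batch-size path in `_build_dataloader` above: `batch_by_size` produces variable-sized batches, which are then sliced across processes so every rank sees a disjoint, interleaved part of each batch. A minimal standalone sketch of that slicing logic (toy indices, no Amphion imports; all values here are illustrative):

```python
# Toy sketch of the per-process batch slicing used in _build_dataloader above.
num_processes = 2          # e.g. accelerator.num_processes
local_process_index = 0    # rank of the current process

# Pretend batch_by_size already produced these variable-sized batches.
batch_sampler = [[0, 1, 2, 3], [4, 5], [6, 7, 8]]

# Keep only batches whose size divides evenly across processes, then give
# every process an interleaved slice of each remaining batch.
batches = [
    b[local_process_index::num_processes]
    for b in batch_sampler
    if len(b) % num_processes == 0
]
print(batches)  # rank 0 sees [[0, 2], [4]]; rank 1 would see [[1, 3], [5]]
```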
Amphion/models/tts/valle/__init__.py ADDED
File without changes
Amphion/models/vocoders/autoregressive/autoregressive_vocoder_inference.py ADDED
File without changes
Amphion/models/vocoders/autoregressive/wavenet/conv.py ADDED
@@ -0,0 +1,66 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+
10
+ class Conv1d(nn.Conv1d):
11
+ """Extended nn.Conv1d for incremental dilated convolutions"""
12
+
13
+ def __init__(self, *args, **kwargs):
14
+ super().__init__(*args, **kwargs)
15
+ self.clear_buffer()
16
+ self._linearized_weight = None
17
+ self.register_backward_hook(self._clear_linearized_weight)
18
+
19
+ def incremental_forward(self, input):
20
+ # input (B, T, C)
21
+ # run forward pre hooks
22
+ for hook in self._forward_pre_hooks.values():
23
+ hook(self, input)
24
+
25
+ # reshape weight
26
+ weight = self._get_linearized_weight()
27
+ kw = self.kernel_size[0]
28
+ dilation = self.dilation[0]
29
+
30
+ bsz = input.size(0)
31
+ if kw > 1:
32
+ input = input.data
33
+ if self.input_buffer is None:
34
+ self.input_buffer = input.new(
35
+ bsz, kw + (kw - 1) * (dilation - 1), input.size(2)
36
+ )
37
+ self.input_buffer.zero_()
38
+ else:
39
+ # shift buffer
40
+ self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
41
+ # append next input
42
+ self.input_buffer[:, -1, :] = input[:, -1, :]
43
+ input = self.input_buffer
44
+ if dilation > 1:
45
+ input = input[:, 0::dilation, :].contiguous()
46
+ output = F.linear(input.view(bsz, -1), weight, self.bias)
47
+ return output.view(bsz, 1, -1)
48
+
49
+ def clear_buffer(self):
50
+ self.input_buffer = None
51
+
52
+ def _get_linearized_weight(self):
53
+ if self._linearized_weight is None:
54
+ kw = self.kernel_size[0]
55
+ # nn.Conv1d
56
+ if self.weight.size() == (self.out_channels, self.in_channels, kw):
57
+ weight = self.weight.transpose(1, 2).contiguous()
58
+ else:
59
+ # fairseq.modules.conv_tbc.ConvTBC
60
+ weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
61
+ assert weight.size() == (self.out_channels, kw, self.in_channels)
62
+ self._linearized_weight = weight.view(self.out_channels, -1)
63
+ return self._linearized_weight
64
+
65
+ def _clear_linearized_weight(self, *args):
66
+ self._linearized_weight = None
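
`incremental_forward` keeps a rolling input buffer so autoregressive decoding can push one frame at a time instead of re-running the full convolution. A hedged usage sketch (channel sizes and step count are made up; it assumes the Amphion repo root is on `PYTHONPATH`):

```python
import torch
from models.vocoders.autoregressive.wavenet.conv import Conv1d

# Illustrative sizes only; Conv1d is the incremental subclass defined above.
conv = Conv1d(4, 8, kernel_size=3, dilation=2)

conv.clear_buffer()                      # start from an empty history
outputs = []
for _ in range(5):                       # 5 autoregressive steps
    frame = torch.randn(1, 1, 4)         # (B, T=1, C_in), one new frame
    outputs.append(conv.incremental_forward(frame))  # (B, 1, C_out)

y = torch.cat(outputs, dim=1)            # (B, 5, C_out)
conv.clear_buffer()                      # reset before generating the next utterance
```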
Amphion/models/vocoders/autoregressive/wavenet/wavenet.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+ from .modules import Conv1d1x1, ResidualConv1dGLU
12
+ from .upsample import ConvInUpsampleNetwork
13
+
14
+
15
+ def receptive_field_size(
16
+ total_layers, num_cycles, kernel_size, dilation=lambda x: 2**x
17
+ ):
18
+ """Compute receptive field size
19
+
20
+ Args:
21
+ total_layers (int): total layers
22
+ num_cycles (int): cycles
23
+ kernel_size (int): kernel size
24
+ dilation (lambda): lambda to compute dilation factor. ``lambda x : 1``
25
+ to disable dilated convolution.
26
+
27
+ Returns:
28
+ int: receptive field size in samples
29
+
30
+ """
31
+ assert total_layers % num_cycles == 0
32
+
33
+ layers_per_cycle = total_layers // num_cycles
34
+ dilations = [dilation(i % layers_per_cycle) for i in range(total_layers)]
35
+ return (kernel_size - 1) * sum(dilations) + 1
36
+
37
+
38
+ class WaveNet(nn.Module):
39
+ """The WaveNet model that supports local and global conditioning.
40
+
41
+ Args:
42
+ out_channels (int): Output channels. If the input is a mu-law quantized
43
+ one-hot vector, this must equal the number of quantize channels. Otherwise,
44
+ it is num_mixtures x 3 (pi, mu, log_scale).
45
+ layers (int): Number of total layers
46
+ stacks (int): Number of dilation cycles
47
+ residual_channels (int): Residual input / output channels
48
+ gate_channels (int): Gated activation channels.
49
+ skip_out_channels (int): Skip connection channels.
50
+ kernel_size (int): Kernel size of convolution layers.
51
+ dropout (float): Dropout probability.
52
+ input_dim (int): Number of mel-spec dimensions.
53
+ upsample_scales (list): List of upsample scales.
54
+ ``np.prod(upsample_scales)`` must equal the hop size. Used only if
55
+ upsample_conditional_features is enabled.
56
+ freq_axis_kernel_size (int): Freq-axis kernel_size for transposed
57
+ convolution layers for upsampling. If you only care about time-axis
58
+ upsampling, set this to 1.
59
+ scalar_input (bool): If True, a scalar input in [-1, 1] is expected; otherwise,
60
+ a quantized one-hot vector is expected.
61
+ """
62
+
63
+ def __init__(self, cfg):
64
+ super(WaveNet, self).__init__()
65
+ self.cfg = cfg
66
+ self.scalar_input = self.cfg.VOCODER.SCALAR_INPUT
67
+ self.out_channels = self.cfg.VOCODER.OUT_CHANNELS
68
+ self.cin_channels = self.cfg.VOCODER.INPUT_DIM
69
+ self.residual_channels = self.cfg.VOCODER.RESIDUAL_CHANNELS
70
+ self.layers = self.cfg.VOCODER.LAYERS
71
+ self.stacks = self.cfg.VOCODER.STACKS
72
+ self.gate_channels = self.cfg.VOCODER.GATE_CHANNELS
73
+ self.kernel_size = self.cfg.VOCODER.KERNEL_SIZE
74
+ self.skip_out_channels = self.cfg.VOCODER.SKIP_OUT_CHANNELS
75
+ self.dropout = self.cfg.VOCODER.DROPOUT
76
+ self.upsample_scales = self.cfg.VOCODER.UPSAMPLE_SCALES
77
+ self.mel_frame_pad = self.cfg.VOCODER.MEL_FRAME_PAD
78
+
79
+ assert self.layers % self.stacks == 0
80
+
81
+ layers_per_stack = self.layers // self.stacks
82
+ if self.scalar_input:
83
+ self.first_conv = Conv1d1x1(1, self.residual_channels)
84
+ else:
85
+ self.first_conv = Conv1d1x1(self.out_channels, self.residual_channels)
86
+
87
+ self.conv_layers = nn.ModuleList()
88
+ for layer in range(self.layers):
89
+ dilation = 2 ** (layer % layers_per_stack)
90
+ conv = ResidualConv1dGLU(
91
+ self.residual_channels,
92
+ self.gate_channels,
93
+ kernel_size=self.kernel_size,
94
+ skip_out_channels=self.skip_out_channels,
95
+ bias=True,
96
+ dilation=dilation,
97
+ dropout=self.dropout,
98
+ cin_channels=self.cin_channels,
99
+ )
100
+ self.conv_layers.append(conv)
101
+
102
+ self.last_conv_layers = nn.ModuleList(
103
+ [
104
+ nn.ReLU(inplace=True),
105
+ Conv1d1x1(self.skip_out_channels, self.skip_out_channels),
106
+ nn.ReLU(inplace=True),
107
+ Conv1d1x1(self.skip_out_channels, self.out_channels),
108
+ ]
109
+ )
110
+
111
+ self.upsample_net = ConvInUpsampleNetwork(
112
+ upsample_scales=self.upsample_scales,
113
+ cin_pad=self.mel_frame_pad,
114
+ cin_channels=self.cin_channels,
115
+ )
116
+
117
+ self.receptive_field = receptive_field_size(
118
+ self.layers, self.stacks, self.kernel_size
119
+ )
120
+
121
+ def forward(self, x, mel, softmax=False):
122
+ """Forward step
123
+
124
+ Args:
125
+ x (Tensor): One-hot encoded audio signal, shape (B x C x T)
126
+ mel (Tensor): Local conditioning features,
127
+ shape (B x cin_channels x T)
128
+ softmax (bool): Whether applies softmax or not.
129
+
130
+ Returns:
131
+ Tensor: output, shape B x out_channels x T
132
+ """
133
+ B, _, T = x.size()
134
+
135
+ mel = self.upsample_net(mel)
136
+ assert mel.shape[-1] == x.shape[-1]
137
+
138
+ x = self.first_conv(x)
139
+ skips = 0
140
+ for f in self.conv_layers:
141
+ x, h = f(x, mel)
142
+ skips += h
143
+ skips *= math.sqrt(1.0 / len(self.conv_layers))
144
+
145
+ x = skips
146
+ for f in self.last_conv_layers:
147
+ x = f(x)
148
+
149
+ x = F.softmax(x, dim=1) if softmax else x
150
+
151
+ return x
152
+
153
+ def clear_buffer(self):
154
+ self.first_conv.clear_buffer()
155
+ for f in self.conv_layers:
156
+ f.clear_buffer()
157
+ for f in self.last_conv_layers:
158
+ try:
159
+ f.clear_buffer()
160
+ except AttributeError:
161
+ pass
162
+
163
+ def make_generation_fast_(self):
164
+ def remove_weight_norm(m):
165
+ try:
166
+ nn.utils.remove_weight_norm(m)
167
+ except ValueError: # this module didn't have weight norm
168
+ return
169
+
170
+ self.apply(remove_weight_norm)
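
As a quick sanity check of `receptive_field_size`, a hedged worked example (the layer counts are illustrative, not a recommended configuration; the import assumes the Amphion root and the module's siblings are available):

```python
from models.vocoders.autoregressive.wavenet.wavenet import receptive_field_size

# 30 layers in 3 dilation cycles with kernel size 3: each cycle has dilations
# 1, 2, 4, ..., 512, which sum to 1023, so sum(dilations) = 3 * 1023 = 3069
# and the receptive field is (3 - 1) * 3069 + 1 = 6139 samples.
assert receptive_field_size(total_layers=30, num_cycles=3, kernel_size=3) == 6139
```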
Amphion/models/vocoders/diffusion/diffusion_vocoder_inference.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import numpy as np
8
+
9
+ from tqdm import tqdm
10
+ from utils.util import pad_mels_to_tensors, pad_f0_to_tensors
11
+
12
+
13
+ def vocoder_inference(cfg, model, mels, f0s=None, device=None, fast_inference=False):
14
+ """Inference the vocoder
15
+ Args:
16
+ mels: A tensor of mel-specs with the shape (batch_size, num_mels, frames)
17
+ Returns:
18
+ audios: A tensor of audios with the shape (batch_size, seq_len)
19
+ """
20
+ model.eval()
21
+
22
+ with torch.no_grad():
23
+ training_noise_schedule = np.array(cfg.model.diffwave.noise_schedule)
24
+ inference_noise_schedule = (
25
+ np.array(cfg.model.diffwave.inference_noise_schedule)
26
+ if fast_inference
27
+ else np.array(cfg.model.diffwave.noise_schedule)
28
+ )
29
+
30
+ talpha = 1 - training_noise_schedule
31
+ talpha_cum = np.cumprod(talpha)
32
+
33
+ beta = inference_noise_schedule
34
+ alpha = 1 - beta
35
+ alpha_cum = np.cumprod(alpha)
36
+
37
+ T = []
38
+ for s in range(len(inference_noise_schedule)):
39
+ for t in range(len(training_noise_schedule) - 1):
40
+ if talpha_cum[t + 1] <= alpha_cum[s] <= talpha_cum[t]:
41
+ twiddle = (talpha_cum[t] ** 0.5 - alpha_cum[s] ** 0.5) / (
42
+ talpha_cum[t] ** 0.5 - talpha_cum[t + 1] ** 0.5
43
+ )
44
+ T.append(t + twiddle)
45
+ break
46
+ T = np.array(T, dtype=np.float32)
47
+
48
+ mels = mels.to(device)
49
+ audio = torch.randn(
50
+ mels.shape[0],
51
+ cfg.preprocess.hop_size * mels.shape[-1],
52
+ device=device,
53
+ )
54
+
55
+ for n in tqdm(range(len(alpha) - 1, -1, -1)):
56
+ c1 = 1 / alpha[n] ** 0.5
57
+ c2 = beta[n] / (1 - alpha_cum[n]) ** 0.5
58
+ audio = c1 * (
59
+ audio
60
+ - c2
61
+ * model(audio, torch.tensor([T[n]], device=audio.device), mels).squeeze(
62
+ 1
63
+ )
64
+ )
65
+ if n > 0:
66
+ noise = torch.randn_like(audio)
67
+ sigma = (
68
+ (1.0 - alpha_cum[n - 1]) / (1.0 - alpha_cum[n]) * beta[n]
69
+ ) ** 0.5
70
+ audio += sigma * noise
71
+ audio = torch.clamp(audio, -1.0, 1.0)
72
+
73
+ return audio.detach().cpu()
74
+
75
+
76
+ def synthesis_audios(cfg, model, mels, f0s=None, batch_size=None, fast_inference=False):
77
+ """Inference the vocoder
78
+ Args:
79
+ mels: A list of mel-specs
80
+ Returns:
81
+ audios: A list of audios
82
+ """
83
+ # Get the device
84
+ device = next(model.parameters()).device
85
+
86
+ audios = []
87
+
88
+ # Pad the given list into tensors
89
+ mel_batches, mel_frames = pad_mels_to_tensors(mels, batch_size)
90
+ if f0s != None:
91
+ f0_batches = pad_f0_to_tensors(f0s, batch_size)
92
+
93
+ if f0s == None:
94
+ for mel_batch, mel_frame in zip(mel_batches, mel_frames):
95
+ for i in range(mel_batch.shape[0]):
96
+ mel = mel_batch[i]
97
+ frame = mel_frame[i]
98
+ audio = vocoder_inference(
99
+ cfg,
100
+ model,
101
+ mel.unsqueeze(0),
102
+ device=device,
103
+ fast_inference=fast_inference,
104
+ ).squeeze(0)
105
+
106
+ # calculate the audio length
107
+ audio_length = frame * cfg.preprocess.hop_size
108
+ audio = audio[:audio_length]
109
+
110
+ audios.append(audio)
111
+ else:
112
+ for mel_batch, f0_batch, mel_frame in zip(mel_batches, f0_batches, mel_frames):
113
+ for i in range(mel_batch.shape[0]):
114
+ mel = mel_batch[i]
115
+ f0 = f0_batch[i]
116
+ frame = mel_frame[i]
117
+ audio = vocoder_inference(
118
+ cfg,
119
+ model,
120
+ mel.unsqueeze(0),
121
+ f0s=f0.unsqueeze(0),
122
+ device=device,
123
+ fast_inference=fast_inference,
124
+ ).squeeze(0)
125
+
126
+ # calculate the audio length
127
+ audio_length = frame * cfg.preprocess.hop_size
128
+ audio = audio[:audio_length]
129
+
130
+ audios.append(audio)
131
+ return audios
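
The fast-inference branch above maps each inference-time noise level onto a fractional training timestep by matching cumulative alpha products. A hedged numeric sketch of that alignment with made-up schedules (not values from any Amphion config):

```python
import numpy as np

training_noise_schedule = np.linspace(1e-4, 0.05, 50)          # 50 training steps
inference_noise_schedule = np.array([1e-4, 1e-3, 1e-2, 5e-2])  # 4 fast steps

talpha_cum = np.cumprod(1 - training_noise_schedule)
alpha_cum = np.cumprod(1 - inference_noise_schedule)

# For each fast step, find the training step whose cumulative alpha brackets it
# and interpolate between the two, exactly as vocoder_inference does above.
T = []
for s in range(len(inference_noise_schedule)):
    for t in range(len(training_noise_schedule) - 1):
        if talpha_cum[t + 1] <= alpha_cum[s] <= talpha_cum[t]:
            twiddle = (talpha_cum[t] ** 0.5 - alpha_cum[s] ** 0.5) / (
                talpha_cum[t] ** 0.5 - talpha_cum[t + 1] ** 0.5
            )
            T.append(t + twiddle)
            break
print(T)  # one fractional timestep per fast inference step
```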
Amphion/models/vocoders/flow/flow_vocoder_dataset.py ADDED
File without changes
Amphion/models/vocoders/flow/flow_vocoder_inference.py ADDED
File without changes
Amphion/models/vocoders/gan/discriminator/msd.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.nn as nn
9
+ from torch.nn import Conv1d, AvgPool1d
10
+ from torch.nn.utils import weight_norm, spectral_norm
11
+ from torch import nn
12
+ from modules.vocoder_blocks import *
13
+
14
+
15
+ LRELU_SLOPE = 0.1
16
+
17
+
18
+ class DiscriminatorS(nn.Module):
19
+ def __init__(self, use_spectral_norm=False):
20
+ super(DiscriminatorS, self).__init__()
21
+
22
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
23
+
24
+ self.convs = nn.ModuleList(
25
+ [
26
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
27
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
28
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
29
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
30
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
31
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
32
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
33
+ ]
34
+ )
35
+
36
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
37
+
38
+ def forward(self, x):
39
+ fmap = []
40
+
41
+ for l in self.convs:
42
+ x = l(x)
43
+ x = F.leaky_relu(x, LRELU_SLOPE)
44
+ fmap.append(x)
45
+
46
+ x = self.conv_post(x)
47
+ fmap.append(x)
48
+ x = torch.flatten(x, 1, -1)
49
+
50
+ return x, fmap
51
+
52
+
53
+ class MultiScaleDiscriminator(nn.Module):
54
+ def __init__(self, cfg):
55
+ super(MultiScaleDiscriminator, self).__init__()
56
+
57
+ self.cfg = cfg
58
+
59
+ self.discriminators = nn.ModuleList(
60
+ [
61
+ DiscriminatorS(use_spectral_norm=True),
62
+ DiscriminatorS(),
63
+ DiscriminatorS(),
64
+ ]
65
+ )
66
+
67
+ self.meanpools = nn.ModuleList(
68
+ [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
69
+ )
70
+
71
+ def forward(self, y, y_hat):
72
+ y_d_rs = []
73
+ y_d_gs = []
74
+ fmap_rs = []
75
+ fmap_gs = []
76
+
77
+ for i, d in enumerate(self.discriminators):
78
+ if i != 0:
79
+ y = self.meanpools[i - 1](y)
80
+ y_hat = self.meanpools[i - 1](y_hat)
81
+ y_d_r, fmap_r = d(y)
82
+ y_d_g, fmap_g = d(y_hat)
83
+ y_d_rs.append(y_d_r)
84
+ fmap_rs.append(fmap_r)
85
+ y_d_gs.append(y_d_g)
86
+ fmap_gs.append(fmap_g)
87
+
88
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
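
A hedged sketch of how the multi-scale discriminator is called during GAN training (waveform shapes are placeholders; `cfg` is only stored by the module, so a stub is enough here):

```python
import torch
from models.vocoders.gan.discriminator.msd import MultiScaleDiscriminator

msd = MultiScaleDiscriminator(cfg=None)   # cfg is kept on the module but unused in forward
y = torch.randn(2, 1, 8192)               # real waveforms: (B, 1, T)
y_hat = torch.randn(2, 1, 8192)           # generated waveforms

y_d_rs, y_d_gs, fmap_rs, fmap_gs = msd(y, y_hat)
# Three scales: raw audio, /2 average-pooled, /4 average-pooled.
print(len(y_d_rs), len(fmap_rs))  # 3 3
```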
Amphion/models/vocoders/gan/gan_vocoder_inference.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from utils.util import pad_mels_to_tensors, pad_f0_to_tensors
9
+
10
+
11
+ def vocoder_inference(cfg, model, mels, f0s=None, device=None, fast_inference=False):
12
+ """Inference the vocoder
13
+ Args:
14
+ mels: A tensor of mel-specs with the shape (batch_size, num_mels, frames)
15
+ Returns:
16
+ audios: A tensor of audios with the shape (batch_size, seq_len)
17
+ """
18
+ model.eval()
19
+
20
+ with torch.no_grad():
21
+ mels = mels.to(device)
22
+ if f0s != None:
23
+ f0s = f0s.to(device)
24
+
25
+ if f0s == None and not cfg.preprocess.extract_amplitude_phase:
26
+ output = model.forward(mels)
27
+ elif cfg.preprocess.extract_amplitude_phase:
28
+ (
29
+ _,
30
+ _,
31
+ _,
32
+ _,
33
+ output,
34
+ ) = model.forward(mels)
35
+ else:
36
+ output = model.forward(mels, f0s)
37
+
38
+ return output.squeeze(1).detach().cpu()
39
+
40
+
41
+ def synthesis_audios(cfg, model, mels, f0s=None, batch_size=None, fast_inference=False):
42
+ """Inference the vocoder
43
+ Args:
44
+ mels: A list of mel-specs
45
+ Returns:
46
+ audios: A list of audios
47
+ """
48
+ # Get the device
49
+ device = next(model.parameters()).device
50
+
51
+ audios = []
52
+
53
+ # Pad the given list into tensors
54
+ mel_batches, mel_frames = pad_mels_to_tensors(mels, batch_size)
55
+ if f0s != None:
56
+ f0_batches = pad_f0_to_tensors(f0s, batch_size)
57
+
58
+ if f0s == None:
59
+ for mel_batch, mel_frame in zip(mel_batches, mel_frames):
60
+ for i in range(mel_batch.shape[0]):
61
+ mel = mel_batch[i]
62
+ frame = mel_frame[i]
63
+ audio = vocoder_inference(
64
+ cfg,
65
+ model,
66
+ mel.unsqueeze(0),
67
+ device=device,
68
+ fast_inference=fast_inference,
69
+ ).squeeze(0)
70
+
71
+ # calculate the audio length
72
+ audio_length = frame * model.cfg.preprocess.hop_size
73
+ audio = audio[:audio_length]
74
+
75
+ audios.append(audio)
76
+ else:
77
+ for mel_batch, f0_batch, mel_frame in zip(mel_batches, f0_batches, mel_frames):
78
+ for i in range(mel_batch.shape[0]):
79
+ mel = mel_batch[i]
80
+ f0 = f0_batch[i]
81
+ frame = mel_frame[i]
82
+ audio = vocoder_inference(
83
+ cfg,
84
+ model,
85
+ mel.unsqueeze(0),
86
+ f0s=f0.unsqueeze(0),
87
+ device=device,
88
+ fast_inference=fast_inference,
89
+ ).squeeze(0)
90
+
91
+ # calculate the audio length
92
+ audio_length = frame * model.cfg.preprocess.hop_size
93
+ audio = audio[:audio_length]
94
+
95
+ audios.append(audio)
96
+ return audios
Amphion/models/vocoders/vocoder_dataset.py ADDED
@@ -0,0 +1,264 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Iterable
7
+ import torch
8
+ import numpy as np
9
+ import torch.utils.data
10
+ from torch.nn.utils.rnn import pad_sequence
11
+ from utils.data_utils import *
12
+ from torch.utils.data import ConcatDataset, Dataset
13
+
14
+
15
+ class VocoderDataset(torch.utils.data.Dataset):
16
+ def __init__(self, cfg, dataset, is_valid=False):
17
+ """
18
+ Args:
19
+ cfg: config
20
+ dataset: dataset name
21
+ is_valid: whether to use train or valid dataset
22
+ """
23
+ assert isinstance(dataset, str)
24
+
25
+ processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
26
+
27
+ meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
28
+ self.metafile_path = os.path.join(processed_data_dir, meta_file)
29
+ self.metadata = self.get_metadata()
30
+
31
+ self.data_root = processed_data_dir
32
+ self.cfg = cfg
33
+
34
+ if cfg.preprocess.use_audio:
35
+ self.utt2audio_path = {}
36
+ for utt_info in self.metadata:
37
+ dataset = utt_info["Dataset"]
38
+ uid = utt_info["Uid"]
39
+ utt = "{}_{}".format(dataset, uid)
40
+
41
+ self.utt2audio_path[utt] = os.path.join(
42
+ cfg.preprocess.processed_dir,
43
+ dataset,
44
+ cfg.preprocess.audio_dir,
45
+ uid + ".npy",
46
+ )
47
+ elif cfg.preprocess.use_label:
48
+ self.utt2label_path = {}
49
+ for utt_info in self.metadata:
50
+ dataset = utt_info["Dataset"]
51
+ uid = utt_info["Uid"]
52
+ utt = "{}_{}".format(dataset, uid)
53
+
54
+ self.utt2label_path[utt] = os.path.join(
55
+ cfg.preprocess.processed_dir,
56
+ dataset,
57
+ cfg.preprocess.label_dir,
58
+ uid + ".npy",
59
+ )
60
+ elif cfg.preprocess.use_one_hot:
61
+ self.utt2one_hot_path = {}
62
+ for utt_info in self.metadata:
63
+ dataset = utt_info["Dataset"]
64
+ uid = utt_info["Uid"]
65
+ utt = "{}_{}".format(dataset, uid)
66
+
67
+ self.utt2one_hot_path[utt] = os.path.join(
68
+ cfg.preprocess.processed_dir,
69
+ dataset,
70
+ cfg.preprocess.one_hot_dir,
71
+ uid + ".npy",
72
+ )
73
+
74
+ if cfg.preprocess.use_mel:
75
+ self.utt2mel_path = {}
76
+ for utt_info in self.metadata:
77
+ dataset = utt_info["Dataset"]
78
+ uid = utt_info["Uid"]
79
+ utt = "{}_{}".format(dataset, uid)
80
+
81
+ self.utt2mel_path[utt] = os.path.join(
82
+ cfg.preprocess.processed_dir,
83
+ dataset,
84
+ cfg.preprocess.mel_dir,
85
+ uid + ".npy",
86
+ )
87
+
88
+ if cfg.preprocess.use_frame_pitch:
89
+ self.utt2frame_pitch_path = {}
90
+ for utt_info in self.metadata:
91
+ dataset = utt_info["Dataset"]
92
+ uid = utt_info["Uid"]
93
+ utt = "{}_{}".format(dataset, uid)
94
+
95
+ self.utt2frame_pitch_path[utt] = os.path.join(
96
+ cfg.preprocess.processed_dir,
97
+ dataset,
98
+ cfg.preprocess.pitch_dir,
99
+ uid + ".npy",
100
+ )
101
+
102
+ if cfg.preprocess.use_uv:
103
+ self.utt2uv_path = {}
104
+ for utt_info in self.metadata:
105
+ dataset = utt_info["Dataset"]
106
+ uid = utt_info["Uid"]
107
+ utt = "{}_{}".format(dataset, uid)
108
+ self.utt2uv_path[utt] = os.path.join(
109
+ cfg.preprocess.processed_dir,
110
+ dataset,
111
+ cfg.preprocess.uv_dir,
112
+ uid + ".npy",
113
+ )
114
+
115
+ if cfg.preprocess.use_amplitude_phase:
116
+ self.utt2logamp_path = {}
117
+ self.utt2pha_path = {}
118
+ self.utt2rea_path = {}
119
+ self.utt2imag_path = {}
120
+ for utt_info in self.metadata:
121
+ dataset = utt_info["Dataset"]
122
+ uid = utt_info["Uid"]
123
+ utt = "{}_{}".format(dataset, uid)
124
+ self.utt2logamp_path[utt] = os.path.join(
125
+ cfg.preprocess.processed_dir,
126
+ dataset,
127
+ cfg.preprocess.log_amplitude_dir,
128
+ uid + ".npy",
129
+ )
130
+ self.utt2pha_path[utt] = os.path.join(
131
+ cfg.preprocess.processed_dir,
132
+ dataset,
133
+ cfg.preprocess.phase_dir,
134
+ uid + ".npy",
135
+ )
136
+ self.utt2rea_path[utt] = os.path.join(
137
+ cfg.preprocess.processed_dir,
138
+ dataset,
139
+ cfg.preprocess.real_dir,
140
+ uid + ".npy",
141
+ )
142
+ self.utt2imag_path[utt] = os.path.join(
143
+ cfg.preprocess.processed_dir,
144
+ dataset,
145
+ cfg.preprocess.imaginary_dir,
146
+ uid + ".npy",
147
+ )
148
+
149
+ def __getitem__(self, index):
150
+ utt_info = self.metadata[index]
151
+
152
+ dataset = utt_info["Dataset"]
153
+ uid = utt_info["Uid"]
154
+ utt = "{}_{}".format(dataset, uid)
155
+
156
+ single_feature = dict()
157
+
158
+ if self.cfg.preprocess.use_mel:
159
+ mel = np.load(self.utt2mel_path[utt])
160
+ assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
161
+
162
+ if "target_len" not in single_feature.keys():
163
+ single_feature["target_len"] = mel.shape[1]
164
+
165
+ single_feature["mel"] = mel
166
+
167
+ if self.cfg.preprocess.use_frame_pitch:
168
+ frame_pitch = np.load(self.utt2frame_pitch_path[utt])
169
+
170
+ if "target_len" not in single_feature.keys():
171
+ single_feature["target_len"] = len(frame_pitch)
172
+
173
+ aligned_frame_pitch = align_length(
174
+ frame_pitch, single_feature["target_len"]
175
+ )
176
+
177
+ single_feature["frame_pitch"] = aligned_frame_pitch
178
+
179
+ if self.cfg.preprocess.use_audio:
180
+ audio = np.load(self.utt2audio_path[utt])
181
+
182
+ single_feature["audio"] = audio
183
+
184
+ return single_feature
185
+
186
+ def get_metadata(self):
187
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
188
+ metadata = json.load(f)
189
+
190
+ return metadata
191
+
192
+ def get_dataset_name(self):
193
+ return self.metadata[0]["Dataset"]
194
+
195
+ def __len__(self):
196
+ return len(self.metadata)
197
+
198
+
199
+ class VocoderConcatDataset(ConcatDataset):
200
+ def __init__(self, datasets: Iterable[Dataset], full_audio_inference=False):
201
+ """Concatenate a series of datasets with their random inference audio merged."""
202
+ super().__init__(datasets)
203
+
204
+ self.cfg = self.datasets[0].cfg
205
+
206
+ self.metadata = []
207
+
208
+ # Merge metadata
209
+ for dataset in self.datasets:
210
+ self.metadata += dataset.metadata
211
+
212
+ # Merge random inference features
213
+ if full_audio_inference:
214
+ self.eval_audios = []
215
+ self.eval_dataset_names = []
216
+ if self.cfg.preprocess.use_mel:
217
+ self.eval_mels = []
218
+ if self.cfg.preprocess.use_frame_pitch:
219
+ self.eval_pitchs = []
220
+ for dataset in self.datasets:
221
+ self.eval_audios.append(dataset.eval_audio)
222
+ self.eval_dataset_names.append(dataset.get_dataset_name())
223
+ if self.cfg.preprocess.use_mel:
224
+ self.eval_mels.append(dataset.eval_mel)
225
+ if self.cfg.preprocess.use_frame_pitch:
226
+ self.eval_pitchs.append(dataset.eval_pitch)
227
+
228
+
229
+ class VocoderCollator(object):
230
+ """Zero-pads model inputs and targets based on number of frames per step"""
231
+
232
+ def __init__(self, cfg):
233
+ self.cfg = cfg
234
+
235
+ def __call__(self, batch):
236
+ packed_batch_features = dict()
237
+
238
+ # mel: [b, n_mels, frame]
239
+ # frame_pitch: [b, frame]
240
+ # audios: [b, frame * hop_size]
241
+
242
+ for key in batch[0].keys():
243
+ if key == "target_len":
244
+ packed_batch_features["target_len"] = torch.LongTensor(
245
+ [b["target_len"] for b in batch]
246
+ )
247
+ masks = [
248
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
249
+ ]
250
+ packed_batch_features["mask"] = pad_sequence(
251
+ masks, batch_first=True, padding_value=0
252
+ )
253
+ elif key == "mel":
254
+ values = [torch.from_numpy(b[key]).T for b in batch]
255
+ packed_batch_features[key] = pad_sequence(
256
+ values, batch_first=True, padding_value=0
257
+ )
258
+ else:
259
+ values = [torch.from_numpy(b[key]) for b in batch]
260
+ packed_batch_features[key] = pad_sequence(
261
+ values, batch_first=True, padding_value=0
262
+ )
263
+
264
+ return packed_batch_features
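
A hedged sketch of what `VocoderCollator` returns for a toy batch (feature sizes are made up; the collator never reads `cfg`, so a stub works):

```python
import numpy as np
from models.vocoders.vocoder_dataset import VocoderCollator

collator = VocoderCollator(cfg=None)
batch = [
    {"target_len": 100, "mel": np.random.randn(80, 100).astype(np.float32),
     "audio": np.random.randn(25600).astype(np.float32)},
    {"target_len": 80, "mel": np.random.randn(80, 80).astype(np.float32),
     "audio": np.random.randn(20480).astype(np.float32)},
]
packed = collator(batch)
print(packed["mel"].shape)    # torch.Size([2, 100, 80])  -- mel is transposed, then padded
print(packed["mask"].shape)   # torch.Size([2, 100, 1])
print(packed["audio"].shape)  # torch.Size([2, 25600])
```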
Amphion/models/vocoders/vocoder_sampler.py ADDED
@@ -0,0 +1,126 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import random
8
+
9
+ from torch.utils.data import ConcatDataset, Dataset
10
+ from torch.utils.data.sampler import (
11
+ BatchSampler,
12
+ RandomSampler,
13
+ Sampler,
14
+ SequentialSampler,
15
+ )
16
+
17
+
18
+ class ScheduledSampler(Sampler):
19
+ """A sampler that samples data from a given concat-dataset.
20
+
21
+ Args:
22
+ concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
23
+ batch_size (int): batch size
24
+ holistic_shuffle (bool): whether to shuffle the whole dataset or not
25
+ logger (logging.Logger): logger to print warning message
26
+
27
+ Usage:
28
+ For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
29
+ >>> list(ScheduledSampler(ConcatDataset([[0, 1, 2], [3, 4, 5], [6, 7, 8]])))
30
+ [3, 4, 5, 0, 1, 2, 6, 7, 8]
31
+ """
32
+
33
+ def __init__(
34
+ self, concat_dataset, batch_size, holistic_shuffle, logger=None, type="train"
35
+ ):
36
+ if not isinstance(concat_dataset, ConcatDataset):
37
+ raise ValueError(
38
+ "concat_dataset must be an instance of ConcatDataset, but got {}".format(
39
+ type(concat_dataset)
40
+ )
41
+ )
42
+ if not isinstance(batch_size, int):
43
+ raise ValueError(
44
+ "batch_size must be an integer, but got {}".format(type(batch_size))
45
+ )
46
+ if not isinstance(holistic_shuffle, bool):
47
+ raise ValueError(
48
+ "holistic_shuffle must be a boolean, but got {}".format(
49
+ type(holistic_shuffle)
50
+ )
51
+ )
52
+
53
+ self.concat_dataset = concat_dataset
54
+ self.batch_size = batch_size
55
+ self.holistic_shuffle = holistic_shuffle
56
+
57
+ affected_dataset_name = []
58
+ affected_dataset_len = []
59
+ for dataset in concat_dataset.datasets:
60
+ dataset_len = len(dataset)
61
+ dataset_name = dataset.get_dataset_name()
62
+ if dataset_len < batch_size:
63
+ affected_dataset_name.append(dataset_name)
64
+ affected_dataset_len.append(dataset_len)
65
+
66
+ self.type = type
67
+ for dataset_name, dataset_len in zip(
68
+ affected_dataset_name, affected_dataset_len
69
+ ):
70
+ if not type == "valid":
71
+ logger.warning(
72
+ "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
73
+ type, dataset_name, dataset_len, batch_size
74
+ )
75
+ )
76
+
77
+ def __len__(self):
78
+ # the number of batches with drop last
79
+ num_of_batches = sum(
80
+ [
81
+ math.floor(len(dataset) / self.batch_size)
82
+ for dataset in self.concat_dataset.datasets
83
+ ]
84
+ )
85
+ return num_of_batches * self.batch_size
86
+
87
+ def __iter__(self):
88
+ iters = []
89
+ for dataset in self.concat_dataset.datasets:
90
+ iters.append(
91
+ SequentialSampler(dataset).__iter__()
92
+ if self.holistic_shuffle
93
+ else RandomSampler(dataset).__iter__()
94
+ )
95
+ init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
96
+ output_batches = []
97
+ for dataset_idx in range(len(self.concat_dataset.datasets)):
98
+ cur_batch = []
99
+ for idx in iters[dataset_idx]:
100
+ cur_batch.append(idx + init_indices[dataset_idx])
101
+ if len(cur_batch) == self.batch_size:
102
+ output_batches.append(cur_batch)
103
+ cur_batch = []
104
+ if self.type == "valid" and len(cur_batch) > 0:
105
+ output_batches.append(cur_batch)
106
+ cur_batch = []
107
+ # force drop last in training
108
+ random.shuffle(output_batches)
109
+ output_indices = [item for sublist in output_batches for item in sublist]
110
+ return iter(output_indices)
111
+
112
+
113
+ def build_samplers(concat_dataset: Dataset, cfg, logger, type):
114
+ sampler = ScheduledSampler(
115
+ concat_dataset,
116
+ cfg.train.batch_size,
117
+ cfg.train.sampler.holistic_shuffle,
118
+ logger,
119
+ type,
120
+ )
121
+ batch_sampler = BatchSampler(
122
+ sampler,
123
+ cfg.train.batch_size,
124
+ cfg.train.sampler.drop_last if not type == "valid" else False,
125
+ )
126
+ return sampler, batch_sampler
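
A hedged end-to-end sketch of `build_samplers` with toy datasets (`ToyDataset` is a stand-in for `VocoderDataset`; only the config fields actually read by the sampler are provided):

```python
import logging
from types import SimpleNamespace
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from models.vocoders.vocoder_sampler import build_samplers

class ToyDataset(Dataset):
    """Stand-in for a VocoderDataset: integer items plus a dataset name."""
    def __init__(self, name, n):
        self.name, self.n = name, n
    def __len__(self):
        return self.n
    def __getitem__(self, idx):
        return idx
    def get_dataset_name(self):
        return self.name

cfg = SimpleNamespace(
    train=SimpleNamespace(
        batch_size=3,
        sampler=SimpleNamespace(holistic_shuffle=False, drop_last=True),
    )
)
concat = ConcatDataset([ToyDataset("a", 7), ToyDataset("b", 5)])
sampler, batch_sampler = build_samplers(concat, cfg, logging.getLogger(__name__), "train")
loader = DataLoader(concat, batch_sampler=batch_sampler)
print([batch.tolist() for batch in loader])  # each batch stays within one toy dataset
```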
Amphion/modules/activation_functions/snake.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from torch import nn, pow, sin
8
+ from torch.nn import Parameter
9
+
10
+
11
+ class Snake(nn.Module):
12
+ r"""Implementation of a sine-based periodic activation function.
13
+ Alpha is initialized to 1 by default; higher values mean a higher frequency.
14
+ It will be trained along with the rest of your model.
15
+
16
+ Args:
17
+ in_features: shape of the input
18
+ alpha: trainable parameter
19
+
20
+ Shape:
21
+ - Input: (B, C, T)
22
+ - Output: (B, C, T), same shape as the input
23
+
24
+ References:
25
+ This activation function is from this paper by Liu Ziyin, Tilman Hartwig,
26
+ Masahito Ueda: https://arxiv.org/abs/2006.08195
27
+
28
+ Examples:
29
+ >>> a1 = Snake(256)
30
+ >>> x = torch.randn(256)
31
+ >>> x = a1(x)
32
+ """
33
+
34
+ def __init__(
35
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
36
+ ):
37
+ super(Snake, self).__init__()
38
+ self.in_features = in_features
39
+
40
+ # initialize alpha
41
+ self.alpha_logscale = alpha_logscale
42
+ if self.alpha_logscale: # log scale alphas initialized to zeros
43
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
44
+ else: # linear scale alphas initialized to ones
45
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
46
+
47
+ self.alpha.requires_grad = alpha_trainable
48
+
49
+ self.no_div_by_zero = 0.000000001
50
+
51
+ def forward(self, x):
52
+ r"""Forward pass of the function. Applies the function to the input elementwise.
53
+ Snake ∶= x + 1/a * sin^2 (ax)
54
+ """
55
+
56
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
57
+ if self.alpha_logscale:
58
+ alpha = torch.exp(alpha)
59
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
60
+
61
+ return x
62
+
63
+
64
+ class SnakeBeta(nn.Module):
65
+ r"""A modified Snake function which uses separate parameters for the magnitude
66
+ of the periodic components. Alpha is initialized to 1 by default,
67
+ higher values mean a higher frequency. Beta is initialized to 1 by default;
68
+ higher values mean a higher magnitude. Both will be trained along with the
69
+ rest of your model.
70
+
71
+ Args:
72
+ in_features: shape of the input
73
+ alpha: trainable parameter that controls frequency
74
+ beta: trainable parameter that controls magnitude
75
+
76
+ Shape:
77
+ - Input: (B, C, T)
78
+ - Output: (B, C, T), same shape as the input
79
+
80
+ References:
81
+ This activation function is a modified version based on this paper by Liu Ziyin,
82
+ Tilman Hartwig, Masahito Ueda: https://arxiv.org/abs/2006.08195
83
+
84
+ Examples:
85
+ >>> a1 = SnakeBeta(256)
86
+ >>> x = torch.randn(2, 256, 128)  # (B, C, T)
87
+ >>> x = a1(x)
88
+ """
89
+
90
+ def __init__(
91
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
92
+ ):
93
+ super(SnakeBeta, self).__init__()
94
+ self.in_features = in_features
95
+
96
+ # initialize alpha
97
+ self.alpha_logscale = alpha_logscale
98
+ if self.alpha_logscale: # log scale alphas initialized to zeros
99
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
100
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
101
+ else: # linear scale alphas initialized to ones
102
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
103
+ self.beta = Parameter(torch.ones(in_features) * alpha)
104
+
105
+ self.alpha.requires_grad = alpha_trainable
106
+ self.beta.requires_grad = alpha_trainable
107
+
108
+ self.no_div_by_zero = 0.000000001
109
+
110
+ def forward(self, x):
111
+ r"""Forward pass of the function. Applies the function to the input elementwise.
112
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
113
+ """
114
+
115
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
116
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
117
+ if self.alpha_logscale:
118
+ alpha = torch.exp(alpha)
119
+ beta = torch.exp(beta)
120
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
121
+
122
+ return x
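
A short usage sketch for the two activations above. The surrounding convolutions are illustrative (BigVGAN-style vocoders interleave snake activations with 1-D convolutions); only the `(B, C, T)` layout and the per-channel `alpha`/`beta` parameters come from the code itself.

```python
import torch
import torch.nn as nn

# Channel-wise periodic activation inside a small 1-D conv stack.
block = nn.Sequential(
    nn.Conv1d(80, 256, kernel_size=3, padding=1),
    SnakeBeta(256, alpha_logscale=True),   # one alpha/beta per channel
    nn.Conv1d(256, 256, kernel_size=3, padding=1),
    Snake(256),
)
mel = torch.randn(2, 80, 200)   # (B, C, T)
out = block(mel)                # (2, 256, 200)
```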
Amphion/modules/diffusion/bidilconv/bidilated_conv.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import torch.nn as nn
9
+
10
+ from modules.general.utils import Conv1d, zero_module
11
+ from .residual_block import ResidualBlock
12
+
13
+
14
+ class BiDilConv(nn.Module):
15
+ r"""Dilated CNN architecture with residual connections, default diffusion decoder.
16
+
17
+ Args:
18
+ input_channel: The number of input channels.
19
+ base_channel: The number of base channels.
20
+ n_res_block: The number of residual blocks.
21
+ conv_kernel_size: The kernel size of convolutional layers.
22
+ dilation_cycle_length: The cycle length of dilation.
23
+ conditioner_size: The size of conditioner.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ input_channel,
29
+ base_channel,
30
+ n_res_block,
31
+ conv_kernel_size,
32
+ dilation_cycle_length,
33
+ conditioner_size,
34
+ output_channel: int = -1,
35
+ ):
36
+ super().__init__()
37
+
38
+ self.input_channel = input_channel
39
+ self.base_channel = base_channel
40
+ self.n_res_block = n_res_block
41
+ self.conv_kernel_size = conv_kernel_size
42
+ self.dilation_cycle_length = dilation_cycle_length
43
+ self.conditioner_size = conditioner_size
44
+ self.output_channel = output_channel if output_channel > 0 else input_channel
45
+
46
+ self.input = nn.Sequential(
47
+ Conv1d(
48
+ input_channel,
49
+ base_channel,
50
+ 1,
51
+ ),
52
+ nn.ReLU(),
53
+ )
54
+
55
+ self.residual_blocks = nn.ModuleList(
56
+ [
57
+ ResidualBlock(
58
+ channels=base_channel,
59
+ kernel_size=conv_kernel_size,
60
+ dilation=2 ** (i % dilation_cycle_length),
61
+ d_context=conditioner_size,
62
+ )
63
+ for i in range(n_res_block)
64
+ ]
65
+ )
66
+
67
+ self.out_proj = nn.Sequential(
68
+ Conv1d(
69
+ base_channel,
70
+ base_channel,
71
+ 1,
72
+ ),
73
+ nn.ReLU(),
74
+ zero_module(
75
+ Conv1d(
76
+ base_channel,
77
+ self.output_channel,
78
+ 1,
79
+ ),
80
+ ),
81
+ )
82
+
83
+ def forward(self, x, y, context=None):
84
+ """
85
+ Args:
86
+ x: Noisy mel-spectrogram [B x ``n_mel`` x L]
87
+ y: FILM embeddings with the shape of (B, ``base_channel``)
88
+ context: Context with the shape of [B x ``d_context`` x L], default to None.
89
+ """
90
+
91
+ h = self.input(x)
92
+
93
+ skip = None
94
+ for i in range(self.n_res_block):
95
+ h, skip_connection = self.residual_blocks[i](h, y, context)
96
+ skip = skip_connection if skip is None else skip_connection + skip
97
+
98
+ out = skip / math.sqrt(self.n_res_block)
99
+
100
+ out = self.out_proj(out)
101
+
102
+ return out
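
A minimal forward-pass sketch for `BiDilConv`. The tensor shapes follow the `forward` docstring; the concrete sizes, and the assumption that `ResidualBlock` (imported above but not shown in this diff) accepts `(h, y, context)` exactly as used here, are illustrative.

```python
import torch

decoder = BiDilConv(
    input_channel=80,         # mel bins
    base_channel=256,
    n_res_block=20,
    conv_kernel_size=3,
    dilation_cycle_length=4,  # dilations cycle through 1, 2, 4, 8
    conditioner_size=512,
)
x = torch.randn(4, 80, 200)          # noisy mel, (B, n_mel, L)
y = torch.randn(4, 256)              # FiLM embedding, (B, base_channel)
context = torch.randn(4, 512, 200)   # conditioner, (B, d_context, L)
eps = decoder(x, y, context)         # (B, 80, L) output, e.g. predicted noise
```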
Amphion/modules/diffusion/karras/sample.py ADDED
@@ -0,0 +1,185 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from abc import ABC, abstractmethod
7
+
8
+ import numpy as np
9
+ import torch as th
10
+ from scipy.stats import norm
11
+ import torch.distributed as dist
12
+
13
+
14
+ def create_named_schedule_sampler(name, diffusion):
15
+ """
16
+ Create a ScheduleSampler from a library of pre-defined samplers.
17
+
18
+ :param name: the name of the sampler.
19
+ :param diffusion: the diffusion object to sample for.
20
+ """
21
+ if name == "uniform":
22
+ return UniformSampler(diffusion)
23
+ elif name == "loss-second-moment":
24
+ return LossSecondMomentResampler(diffusion)
25
+ elif name == "lognormal":
26
+ return LogNormalSampler()
27
+ else:
28
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
29
+
30
+
31
+ class ScheduleSampler(ABC):
32
+ """
33
+ A distribution over timesteps in the diffusion process, intended to reduce
34
+ variance of the objective.
35
+
36
+ By default, samplers perform unbiased importance sampling, in which the
37
+ objective's mean is unchanged.
38
+ However, subclasses may override sample() to change how the resampled
39
+ terms are reweighted, allowing for actual changes in the objective.
40
+ """
41
+
42
+ @abstractmethod
43
+ def weights(self):
44
+ """
45
+ Get a numpy array of weights, one per diffusion step.
46
+
47
+ The weights needn't be normalized, but must be positive.
48
+ """
49
+
50
+ def sample(self, batch_size, device):
51
+ """
52
+ Importance-sample timesteps for a batch.
53
+
54
+ :param batch_size: the number of timesteps.
55
+ :param device: the torch device to save to.
56
+ :return: a tuple (timesteps, weights):
57
+ - timesteps: a tensor of timestep indices.
58
+ - weights: a tensor of weights to scale the resulting losses.
59
+ """
60
+ w = self.weights()
61
+ p = w / np.sum(w)
62
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
63
+ indices = th.from_numpy(indices_np).long().to(device)
64
+ weights_np = 1 / (len(p) * p[indices_np])
65
+ weights = th.from_numpy(weights_np).float().to(device)
66
+ return indices, weights
67
+
68
+
69
+ class UniformSampler(ScheduleSampler):
70
+ def __init__(self, diffusion):
71
+ self.diffusion = diffusion
72
+ self._weights = np.ones([diffusion.num_timesteps])
73
+
74
+ def weights(self):
75
+ return self._weights
76
+
77
+
78
+ class LossAwareSampler(ScheduleSampler):
79
+ def update_with_local_losses(self, local_ts, local_losses):
80
+ """
81
+ Update the reweighting using losses from a model.
82
+
83
+ Call this method from each rank with a batch of timesteps and the
84
+ corresponding losses for each of those timesteps.
85
+ This method will perform synchronization to make sure all of the ranks
86
+ maintain the exact same reweighting.
87
+
88
+ :param local_ts: an integer Tensor of timesteps.
89
+ :param local_losses: a 1D Tensor of losses.
90
+ """
91
+ batch_sizes = [
92
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
93
+ for _ in range(dist.get_world_size())
94
+ ]
95
+ dist.all_gather(
96
+ batch_sizes,
97
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
98
+ )
99
+
100
+ # Pad all_gather batches to be the maximum batch size.
101
+ batch_sizes = [x.item() for x in batch_sizes]
102
+ max_bs = max(batch_sizes)
103
+
104
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
105
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
106
+ dist.all_gather(timestep_batches, local_ts)
107
+ dist.all_gather(loss_batches, local_losses)
108
+ timesteps = [
109
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
110
+ ]
111
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
112
+ self.update_with_all_losses(timesteps, losses)
113
+
114
+ @abstractmethod
115
+ def update_with_all_losses(self, ts, losses):
116
+ """
117
+ Update the reweighting using losses from a model.
118
+
119
+ Sub-classes should override this method to update the reweighting
120
+ using losses from the model.
121
+
122
+ This method directly updates the reweighting without synchronizing
123
+ between workers. It is called by update_with_local_losses from all
124
+ ranks with identical arguments. Thus, it should have deterministic
125
+ behavior to maintain state across workers.
126
+
127
+ :param ts: a list of int timesteps.
128
+ :param losses: a list of float losses, one per timestep.
129
+ """
130
+
131
+
132
+ class LossSecondMomentResampler(LossAwareSampler):
133
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
134
+ self.diffusion = diffusion
135
+ self.history_per_term = history_per_term
136
+ self.uniform_prob = uniform_prob
137
+ self._loss_history = np.zeros(
138
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
139
+ )
140
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int was removed in NumPy 1.24
141
+
142
+ def weights(self):
143
+ if not self._warmed_up():
144
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
145
+ weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
146
+ weights /= np.sum(weights)
147
+ weights *= 1 - self.uniform_prob
148
+ weights += self.uniform_prob / len(weights)
149
+ return weights
150
+
151
+ def update_with_all_losses(self, ts, losses):
152
+ for t, loss in zip(ts, losses):
153
+ if self._loss_counts[t] == self.history_per_term:
154
+ # Shift out the oldest loss term.
155
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
156
+ self._loss_history[t, -1] = loss
157
+ else:
158
+ self._loss_history[t, self._loss_counts[t]] = loss
159
+ self._loss_counts[t] += 1
160
+
161
+ def _warmed_up(self):
162
+ return (self._loss_counts == self.history_per_term).all()
163
+
164
+
165
+ class LogNormalSampler:
166
+ def __init__(self, p_mean=-1.2, p_std=1.2, even=False):
167
+ self.p_mean = p_mean
168
+ self.p_std = p_std
169
+ self.even = even
170
+ if self.even:
171
+ self.inv_cdf = lambda x: norm.ppf(x, loc=p_mean, scale=p_std)
172
+ self.rank, self.size = dist.get_rank(), dist.get_world_size()
173
+
174
+ def sample(self, bs, device):
175
+ if self.even:
176
+ # buckets = [1/G]
177
+ start_i, end_i = self.rank * bs, (self.rank + 1) * bs
178
+ global_batch_size = self.size * bs
179
+ locs = (th.arange(start_i, end_i) + th.rand(bs)) / global_batch_size
180
+ log_sigmas = th.tensor(self.inv_cdf(locs), dtype=th.float32, device=device)
181
+ else:
182
+ log_sigmas = self.p_mean + self.p_std * th.randn(bs, device=device)
183
+ sigmas = th.exp(log_sigmas)
184
+ weights = th.ones_like(sigmas)
185
+ return sigmas, weights
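
A small runnable sketch of the sampling contract. `UniformSampler` only reads `num_timesteps` from the diffusion object, so a stand-in namespace is enough here; in real training the returned `weights` rescale the per-example losses, and a `LossAwareSampler` would additionally be updated with those losses.

```python
import torch
from types import SimpleNamespace

diffusion = SimpleNamespace(num_timesteps=1000)   # stand-in; only num_timesteps is read
schedule_sampler = create_named_schedule_sampler("uniform", diffusion)

t, weights = schedule_sampler.sample(batch_size=16, device=torch.device("cpu"))
# t: LongTensor of 16 timestep indices in [0, 1000); weights: all ones for the
# uniform sampler. In a training step the weights rescale per-example losses:
#   loss = (per_example_loss * weights).mean()
```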
Amphion/modules/diffusion/unet/attention.py ADDED
@@ -0,0 +1,241 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from modules.general.utils import Conv1d, normalization, zero_module
11
+ from .basic import UNetBlock
12
+
13
+
14
+ class AttentionBlock(UNetBlock):
15
+ r"""A spatial transformer encoder block that allows spatial positions to attend
16
+ to each other. Reference from `latent diffusion repo
17
+ <https://github.com/Stability-AI/generative-models/blob/main/sgm/modules/attention.py#L531>`_.
18
+
19
+ Args:
20
+ channels: Number of channels in the input.
21
+ num_head_channels: Number of channels per attention head.
22
+ num_heads: Number of attention heads. Overrides ``num_head_channels`` if set.
23
+ encoder_channels: Number of channels in the encoder output for cross-attention.
24
+ If ``None``, then self-attention is performed.
25
+ use_self_attention: Whether to use self-attention before cross-attention, only applicable if encoder_channels is set.
26
+ dims: Number of spatial dimensions, i.e. 1 for temporal signals, 2 for images.
27
+ h_dim: The dimension of the height, would be applied if ``dims`` is 2.
28
+ encoder_hdim: The dimension of the height of the encoder output, would be applied if ``dims`` is 2.
29
+ p_dropout: Dropout probability.
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ channels: int,
35
+ num_head_channels: int = 32,
36
+ num_heads: int = -1,
37
+ encoder_channels: int = None,
38
+ use_self_attention: bool = False,
39
+ dims: int = 1,
40
+ h_dim: int = 100,
41
+ encoder_hdim: int = 384,
42
+ p_dropout: float = 0.0,
43
+ ):
44
+ super().__init__()
45
+
46
+ self.channels = channels
47
+ self.p_dropout = p_dropout
48
+ self.dims = dims
49
+
50
+ if dims == 1:
51
+ self.channels = channels
52
+ elif dims == 2:
53
+ # We consider the channel as product of channel and height, i.e. C x H
54
+ # This is because we want to apply attention on the audio signal, which is 1D
55
+ self.channels = channels * h_dim
56
+ else:
57
+ raise ValueError(f"invalid number of dimensions: {dims}")
58
+
59
+ if num_head_channels == -1:
60
+ assert (
61
+ self.channels % num_heads == 0
62
+ ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
63
+ self.num_heads = num_heads
64
+ self.num_head_channels = self.channels // num_heads
65
+ else:
66
+ assert (
67
+ self.channels % num_head_channels == 0
68
+ ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
69
+ self.num_heads = self.channels // num_head_channels
70
+ self.num_head_channels = num_head_channels
71
+
72
+ if encoder_channels is not None:
73
+ self.use_self_attention = use_self_attention
74
+
75
+ if dims == 1:
76
+ self.encoder_channels = encoder_channels
77
+ elif dims == 2:
78
+ self.encoder_channels = encoder_channels * encoder_hdim
79
+ else:
80
+ raise ValueError(f"invalid number of dimensions: {dims}")
81
+
82
+ if use_self_attention:
83
+ self.self_attention = BasicAttentionBlock(
84
+ self.channels,
85
+ self.num_head_channels,
86
+ self.num_heads,
87
+ p_dropout=self.p_dropout,
88
+ )
89
+ self.cross_attention = BasicAttentionBlock(
90
+ self.channels,
91
+ self.num_head_channels,
92
+ self.num_heads,
93
+ self.encoder_channels,
94
+ p_dropout=self.p_dropout,
95
+ )
96
+ else:
97
+ self.encoder_channels = None
98
+ self.self_attention = BasicAttentionBlock(
99
+ self.channels,
100
+ self.num_head_channels,
101
+ self.num_heads,
102
+ p_dropout=self.p_dropout,
103
+ )
104
+
105
+ def forward(self, x: torch.Tensor, encoder_output: torch.Tensor = None):
106
+ r"""
107
+ Args:
108
+ x: input tensor with shape [B x ``channels`` x ...]
109
+ encoder_output: feature tensor with shape [B x ``encoder_channels`` x ...], if ``None``, then self-attention is performed.
110
+
111
+ Returns:
112
+ output tensor with shape [B x ``channels`` x ...]
113
+ """
114
+ shape = x.size()
115
+ x = x.reshape(shape[0], self.channels, -1).contiguous()
116
+
117
+ if self.encoder_channels is None:
118
+ assert (
119
+ encoder_output is None
120
+ ), "encoder_output must be None for self-attention."
121
+ h = self.self_attention(x)
122
+
123
+ else:
124
+ assert (
125
+ encoder_output is not None
126
+ ), "encoder_output must be given for cross-attention."
127
+ encoder_output = encoder_output.reshape(
128
+ shape[0], self.encoder_channels, -1
129
+ ).contiguous()
130
+
131
+ if self.use_self_attention:
132
+ x = self.self_attention(x)
133
+ h = self.cross_attention(x, encoder_output)
134
+
135
+ return h.reshape(*shape).contiguous()
136
+
137
+
138
+ class BasicAttentionBlock(nn.Module):
139
+ def __init__(
140
+ self,
141
+ channels: int,
142
+ num_head_channels: int = 32,
143
+ num_heads: int = -1,
144
+ context_channels: int = None,
145
+ p_dropout: float = 0.0,
146
+ ):
147
+ super().__init__()
148
+
149
+ self.channels = channels
150
+ self.p_dropout = p_dropout
151
+ self.context_channels = context_channels
152
+
153
+ if num_head_channels == -1:
154
+ assert (
155
+ self.channels % num_heads == 0
156
+ ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
157
+ self.num_heads = num_heads
158
+ self.num_head_channels = self.channels // num_heads
159
+ else:
160
+ assert (
161
+ self.channels % num_head_channels == 0
162
+ ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
163
+ self.num_heads = self.channels // num_head_channels
164
+ self.num_head_channels = num_head_channels
165
+
166
+ if context_channels is not None:
167
+ self.to_q = nn.Sequential(
168
+ normalization(self.channels),
169
+ Conv1d(self.channels, self.channels, 1),
170
+ )
171
+ self.to_kv = Conv1d(context_channels, 2 * self.channels, 1)
172
+ else:
173
+ self.to_qkv = nn.Sequential(
174
+ normalization(self.channels),
175
+ Conv1d(self.channels, 3 * self.channels, 1),
176
+ )
177
+
178
+ self.linear = Conv1d(self.channels, self.channels, 1)  # 1x1 output projection (kernel_size is required)
179
+
180
+ self.proj_out = nn.Sequential(
181
+ normalization(self.channels),
182
+ Conv1d(self.channels, self.channels, 1),
183
+ nn.GELU(),
184
+ nn.Dropout(p=self.p_dropout),
185
+ zero_module(Conv1d(self.channels, self.channels, 1)),
186
+ )
187
+
188
+ def forward(self, q: torch.Tensor, kv: torch.Tensor = None):
189
+ r"""
190
+ Args:
191
+ q: input tensor with shape [B, ``channels``, L]
192
+ kv: feature tensor with shape [B, ``context_channels``, T], if ``None``, then self-attention is performed.
193
+
194
+ Returns:
195
+ output tensor with shape [B, ``channels``, L]
196
+ """
197
+ N, C, L = q.size()
+ residual = q  # keep the (B, channels, L) input for the residual connections below
198
+
199
+ if self.context_channels is not None:
200
+ assert kv is not None, "kv must be given for cross-attention."
201
+
202
+ q = (
203
+ self.to_q(q)
204
+ .reshape(self.num_heads, self.num_head_channels, -1)
205
+ .transpose(-1, -2)
206
+ .contiguous()
207
+ )
208
+ kv = (
209
+ self.to_kv(kv)
210
+ .reshape(2, self.num_heads, self.num_head_channels, -1)
211
+ .transpose(-1, -2)
212
+ .chunk(2)
213
+ )
214
+ k, v = (
215
+ kv[0].squeeze(0).contiguous(),
216
+ kv[1].squeeze(0).contiguous(),
217
+ )
218
+
219
+ else:
220
+ qkv = (
221
+ self.to_qkv(q)
222
+ .reshape(3, self.num_heads, self.num_head_channels, -1)
223
+ .transpose(-1, -2)
224
+ .chunk(3)
225
+ )
226
+ q, k, v = (
227
+ qkv[0].squeeze(0).contiguous(),
228
+ qkv[1].squeeze(0).contiguous(),
229
+ qkv[2].squeeze(0).contiguous(),
230
+ )
231
+
232
+ h = F.scaled_dot_product_attention(q, k, v, dropout_p=self.p_dropout).transpose(
233
+ -1, -2
234
+ )
235
+ h = h.reshape(N, -1, L).contiguous()
236
+ h = self.linear(h)
237
+
238
+ x = residual + h  # residual over the original input, not the reshaped q
239
+ h = self.proj_out(x)
240
+
241
+ return x + h
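
A shape-level smoke test for the block above (relying on the 1x1 output-projection fix noted in `BasicAttentionBlock`). The concrete sizes are illustrative; cross-attention would additionally take a conditioning tensor of shape `(B, encoder_channels, T)`.

```python
import torch

attn = AttentionBlock(channels=256, num_head_channels=32, dims=1)
x = torch.randn(1, 256, 100)   # (B, channels, L)
out = attn(x)                  # (1, 256, 100): self-attention over the time axis
```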
Amphion/modules/diffusion/unet/resblock.py ADDED
@@ -0,0 +1,178 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from .basic import UNetBlock
10
+ from modules.general.utils import (
11
+ append_dims,
12
+ ConvNd,
13
+ normalization,
14
+ zero_module,
15
+ )
16
+
17
+
18
+ class ResBlock(UNetBlock):
19
+ r"""A residual block that can optionally change the number of channels.
20
+
21
+ Args:
22
+ channels: the number of input channels.
23
+ emb_channels: the number of timestep embedding channels.
24
+ dropout: the rate of dropout.
25
+ out_channels: if specified, the number of out channels.
26
+ use_conv: if True and out_channels is specified, use a spatial
27
+ convolution instead of a smaller 1x1 convolution to change the
28
+ channels in the skip connection.
29
+ dims: determines if the signal is 1D, 2D, or 3D.
30
+ up: if True, use this block for upsampling.
31
+ down: if True, use this block for downsampling.
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ channels,
37
+ emb_channels,
38
+ dropout: float = 0.0,
39
+ out_channels=None,
40
+ use_conv=False,
41
+ use_scale_shift_norm=False,
42
+ dims=2,
43
+ up=False,
44
+ down=False,
45
+ ):
46
+ super().__init__()
47
+ self.channels = channels
48
+ self.emb_channels = emb_channels
49
+ self.dropout = dropout
50
+ self.out_channels = out_channels or channels
51
+ self.use_conv = use_conv
52
+ self.use_scale_shift_norm = use_scale_shift_norm
53
+
54
+ self.in_layers = nn.Sequential(
55
+ normalization(channels),
56
+ nn.SiLU(),
57
+ ConvNd(dims, channels, self.out_channels, 3, padding=1),
58
+ )
59
+
60
+ self.updown = up or down
61
+
62
+ if up:
63
+ self.h_upd = Upsample(channels, False, dims)
64
+ self.x_upd = Upsample(channels, False, dims)
65
+ elif down:
66
+ self.h_upd = Downsample(channels, False, dims)
67
+ self.x_upd = Downsample(channels, False, dims)
68
+ else:
69
+ self.h_upd = self.x_upd = nn.Identity()
70
+
71
+ self.emb_layers = nn.Sequential(
72
+ nn.SiLU(),
73
+ ConvNd(
74
+ dims,
75
+ emb_channels,
76
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
77
+ 1,
78
+ ),
79
+ )
80
+ self.out_layers = nn.Sequential(
81
+ normalization(self.out_channels),
82
+ nn.SiLU(),
83
+ nn.Dropout(p=dropout),
84
+ zero_module(
85
+ ConvNd(dims, self.out_channels, self.out_channels, 3, padding=1)
86
+ ),
87
+ )
88
+
89
+ if self.out_channels == channels:
90
+ self.skip_connection = nn.Identity()
91
+ elif use_conv:
92
+ self.skip_connection = ConvNd(
93
+ dims, channels, self.out_channels, 3, padding=1
94
+ )
95
+ else:
96
+ self.skip_connection = ConvNd(dims, channels, self.out_channels, 1)
97
+
98
+ def forward(self, x, emb):
99
+ """
100
+ Apply the block to a Tensor, conditioned on a timestep embedding.
101
+
102
+ x: an [N x C x ...] Tensor of features.
103
+ emb: an [N x emb_channels x ...] Tensor of timestep embeddings.
104
+ :return: an [N x C x ...] Tensor of outputs.
105
+ """
106
+ if self.updown:
107
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
108
+ h = in_rest(x)
109
+ h = self.h_upd(h)
110
+ x = self.x_upd(x)
111
+ h = in_conv(h)
112
+ else:
113
+ h = self.in_layers(x)
114
+ emb_out = self.emb_layers(emb)
115
+ emb_out = append_dims(emb_out, h.dim())
116
+ if self.use_scale_shift_norm:
117
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
118
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
119
+ h = out_norm(h) * (1 + scale) + shift
120
+ h = out_rest(h)
121
+ else:
122
+ h = h + emb_out
123
+ h = self.out_layers(h)
124
+ return self.skip_connection(x) + h
125
+
126
+
127
+ class Upsample(nn.Module):
128
+ r"""An upsampling layer with an optional convolution.
129
+
130
+ Args:
131
+ channels: channels in the inputs and outputs.
132
+ dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
133
+ upsampling occurs in the inner-two dimensions.
134
+ out_channels: if specified, the number of out channels.
135
+ """
136
+
137
+ def __init__(self, channels, dims=2, out_channels=None):
138
+ super().__init__()
139
+ self.channels = channels
140
+ self.out_channels = out_channels or channels
141
+ self.dims = dims
142
+ self.conv = ConvNd(dims, self.channels, self.out_channels, 3, padding=1)
143
+
144
+ def forward(self, x):
145
+ assert x.shape[1] == self.channels
146
+ if self.dims == 3:
147
+ x = F.interpolate(
148
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
149
+ )
150
+ else:
151
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
152
+ x = self.conv(x)
153
+ return x
154
+
155
+
156
+ class Downsample(nn.Module):
157
+ r"""A downsampling layer with an optional convolution.
158
+
159
+ Args:
160
+ channels: channels in the inputs and outputs.
161
+ dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
162
+ downsampling occurs in the inner-two dimensions.
163
+ out_channels: if specified, the number of output channels.
164
+ """
165
+
166
+ def __init__(self, channels, dims=2, out_channels=None):
167
+ super().__init__()
168
+ self.channels = channels
169
+ self.out_channels = out_channels or channels
170
+ self.dims = dims
171
+ stride = 2 if dims != 3 else (1, 2, 2)
172
+ self.op = ConvNd(
173
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=1
174
+ )
175
+
176
+ def forward(self, x):
177
+ assert x.shape[1] == self.channels
178
+ return self.op(x)
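
A minimal sketch of the FiLM-conditioned residual block above in the 1-D setting. The only assumption beyond the code is the embedding layout: because `emb_layers` starts with a 1x1 `ConvNd`, the timestep embedding must already be channel-first, e.g. `(N, emb_channels, 1)`, so that it broadcasts over the time axis.

```python
import torch

block = ResBlock(
    channels=128, emb_channels=512, dropout=0.1,
    out_channels=256, dims=1, use_scale_shift_norm=True,
)
x = torch.randn(4, 128, 200)   # (N, C, L) features
emb = torch.randn(4, 512, 1)   # timestep embedding, broadcast over L
out = block(x, emb)            # (4, 256, 200): scale-shift (FiLM) conditioning inside
```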
Amphion/modules/diffusion/unet/unet.py ADDED
@@ -0,0 +1,310 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from modules.encoder.position_encoder import PositionEncoder
10
+ from modules.general.utils import append_dims, ConvNd, normalization, zero_module
11
+ from .attention import AttentionBlock
12
+ from .resblock import Downsample, ResBlock, Upsample
13
+
14
+
15
+ class UNet(nn.Module):
16
+ r"""The full UNet model with attention and timestep embedding.
17
+
18
+ Args:
19
+ dims: determines if the signal is 1D (temporal), 2D(spatial).
20
+ in_channels: channels in the input Tensor.
21
+ model_channels: base channel count for the model.
22
+ out_channels: channels in the output Tensor.
23
+ num_res_blocks: number of residual blocks per downsample.
24
+ channel_mult: channel multiplier for each level of the UNet.
25
+ num_attn_blocks: number of attention blocks at place.
26
+ attention_resolutions: a collection of downsample rates at which attention will
27
+ take place. May be a set, list, or tuple. For example, if this contains 4,
28
+ then at 4x downsampling, attention will be used.
29
+ num_heads: the number of attention heads in each attention layer.
30
+ num_head_channels: if specified, ignore num_heads and instead use a fixed
31
+ channel width per attention head.
32
+ d_context: if specified, use for cross-attention channel project.
33
+ p_dropout: the dropout probability.
34
+ use_self_attention: Apply self attention before cross attention.
35
+ num_classes: if specified (as an int), then this model will be class-conditional
36
+ with ``num_classes`` classes.
37
+ use_extra_film: if specified, use an extra FiLM-like conditioning mechanism.
38
+ d_emb: if specified, use for FiLM-like conditioning.
39
+ use_scale_shift_norm: use a FiLM-like conditioning mechanism.
40
+ resblock_updown: use residual blocks for up/downsampling.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ dims: int = 1,
46
+ in_channels: int = 100,
47
+ model_channels: int = 128,
48
+ out_channels: int = 100,
49
+ h_dim: int = 128,
50
+ num_res_blocks: int = 1,
51
+ channel_mult: tuple = (1, 2, 4),
52
+ num_attn_blocks: int = 1,
53
+ attention_resolutions: tuple = (1, 2, 4),
54
+ num_heads: int = 1,
55
+ num_head_channels: int = -1,
56
+ d_context: int = None,
57
+ context_hdim: int = 128,
58
+ p_dropout: float = 0.0,
59
+ num_classes: int = -1,
60
+ use_extra_film: str = None,
61
+ d_emb: int = None,
62
+ use_scale_shift_norm: bool = True,
63
+ resblock_updown: bool = False,
64
+ ):
65
+ super().__init__()
66
+
67
+ self.dims = dims
68
+ self.in_channels = in_channels
69
+ self.model_channels = model_channels
70
+ self.out_channels = out_channels
71
+ self.num_res_blocks = num_res_blocks
72
+ self.channel_mult = channel_mult
73
+ self.num_attn_blocks = num_attn_blocks
74
+ self.attention_resolutions = attention_resolutions
75
+ self.num_heads = num_heads
76
+ self.num_head_channels = num_head_channels
77
+ self.d_context = d_context
78
+ self.p_dropout = p_dropout
79
+ self.num_classes = num_classes
80
+ self.use_extra_film = use_extra_film
81
+ self.d_emb = d_emb
82
+ self.use_scale_shift_norm = use_scale_shift_norm
83
+ self.resblock_updown = resblock_updown
84
+
85
+ time_embed_dim = model_channels * 4
86
+ self.pos_enc = PositionEncoder(model_channels, time_embed_dim)
87
+
88
+ assert (
89
+ num_classes == -1 or use_extra_film is None
90
+ ), "You cannot set both num_classes and use_extra_film."
91
+
92
+ if self.num_classes > 0:
93
+ # TODO: if used for singer, norm should be 1, correct?
94
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim, max_norm=1.0)
95
+ elif use_extra_film is not None:
96
+ assert (
97
+ d_emb is not None
98
+ ), "d_emb must be specified if use_extra_film is not None"
99
+ assert use_extra_film in [
100
+ "add",
101
+ "concat",
102
+ ], f"use_extra_film only supported by add or concat. Your input is {use_extra_film}"
103
+ self.use_extra_film = use_extra_film
104
+ self.film_emb = ConvNd(dims, d_emb, time_embed_dim, 1)
105
+ if use_extra_film == "concat":
106
+ time_embed_dim *= 2
107
+
108
+ # Input blocks
109
+ ch = input_ch = int(channel_mult[0] * model_channels)
110
+ self.input_blocks = nn.ModuleList(
111
+ [UNetSequential(ConvNd(dims, in_channels, ch, 3, padding=1))]
112
+ )
113
+ self._feature_size = ch
114
+ input_block_chans = [ch]
115
+ ds = 1
116
+ for level, mult in enumerate(channel_mult):
117
+ for _ in range(num_res_blocks):
118
+ layers = [
119
+ ResBlock(
120
+ ch,
121
+ time_embed_dim,
122
+ p_dropout,
123
+ out_channels=int(mult * model_channels),
124
+ dims=dims,
125
+ use_scale_shift_norm=use_scale_shift_norm,
126
+ )
127
+ ]
128
+ ch = int(mult * model_channels)
129
+ if ds in attention_resolutions:
130
+ for _ in range(num_attn_blocks):
131
+ layers.append(
132
+ AttentionBlock(
133
+ ch,
134
+ num_heads=num_heads,
135
+ num_head_channels=num_head_channels,
136
+ encoder_channels=d_context,
137
+ dims=dims,
138
+ h_dim=h_dim // (level + 1),
139
+ encoder_hdim=context_hdim,
140
+ p_dropout=p_dropout,
141
+ )
142
+ )
143
+ self.input_blocks.append(UNetSequential(*layers))
144
+ self._feature_size += ch
145
+ input_block_chans.append(ch)
146
+ if level != len(channel_mult) - 1:
147
+ out_ch = ch
148
+ self.input_blocks.append(
149
+ UNetSequential(
150
+ ResBlock(
151
+ ch,
152
+ time_embed_dim,
153
+ p_dropout,
154
+ out_channels=out_ch,
155
+ dims=dims,
156
+ use_scale_shift_norm=use_scale_shift_norm,
157
+ down=True,
158
+ )
159
+ if resblock_updown
160
+ else Downsample(ch, dims=dims, out_channels=out_ch)
161
+ )
162
+ )
163
+ ch = out_ch
164
+ input_block_chans.append(ch)
165
+ ds *= 2
166
+ self._feature_size += ch
167
+
168
+ # Middle blocks
169
+ self.middle_block = UNetSequential(
170
+ ResBlock(
171
+ ch,
172
+ time_embed_dim,
173
+ p_dropout,
174
+ dims=dims,
175
+ use_scale_shift_norm=use_scale_shift_norm,
176
+ ),
177
+ AttentionBlock(
178
+ ch,
179
+ num_heads=num_heads,
180
+ num_head_channels=num_head_channels,
181
+ encoder_channels=d_context,
182
+ dims=dims,
183
+ h_dim=h_dim // (level + 1),
184
+ encoder_hdim=context_hdim,
185
+ p_dropout=p_dropout,
186
+ ),
187
+ ResBlock(
188
+ ch,
189
+ time_embed_dim,
190
+ p_dropout,
191
+ dims=dims,
192
+ use_scale_shift_norm=use_scale_shift_norm,
193
+ ),
194
+ )
195
+ self._feature_size += ch
196
+
197
+ # Output blocks
198
+ self.output_blocks = nn.ModuleList([])
199
+ for level, mult in tuple(enumerate(channel_mult))[::-1]:
200
+ for i in range(num_res_blocks + 1):
201
+ ich = input_block_chans.pop()
202
+ layers = [
203
+ ResBlock(
204
+ ch + ich,
205
+ time_embed_dim,
206
+ p_dropout,
207
+ out_channels=int(model_channels * mult),
208
+ dims=dims,
209
+ use_scale_shift_norm=use_scale_shift_norm,
210
+ )
211
+ ]
212
+ ch = int(model_channels * mult)
213
+ if ds in attention_resolutions:
214
+ for _ in range(num_attn_blocks):
215
+ layers.append(
216
+ AttentionBlock(
217
+ ch,
218
+ num_heads=num_heads,
219
+ num_head_channels=num_head_channels,
220
+ encoder_channels=d_context,
221
+ dims=dims,
222
+ h_dim=h_dim // (level + 1),
223
+ encoder_hdim=context_hdim,
224
+ p_dropout=p_dropout,
225
+ )
226
+ )
227
+ if level and i == num_res_blocks:
228
+ out_ch = ch
229
+ layers.append(
230
+ ResBlock(
231
+ ch,
232
+ time_embed_dim,
233
+ p_dropout,
234
+ out_channels=out_ch,
235
+ dims=dims,
236
+ use_scale_shift_norm=use_scale_shift_norm,
237
+ up=True,
238
+ )
239
+ if resblock_updown
240
+ else Upsample(ch, dims=dims, out_channels=out_ch)
241
+ )
242
+ ds //= 2
243
+ self.output_blocks.append(UNetSequential(*layers))
244
+ self._feature_size += ch
245
+
246
+ # Final proj out
247
+ self.out = nn.Sequential(
248
+ normalization(ch),
249
+ nn.SiLU(),
250
+ zero_module(ConvNd(dims, input_ch, out_channels, 3, padding=1)),
251
+ )
252
+
253
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
254
+ r"""Apply the model to an input batch.
255
+
256
+ Args:
257
+ x: an [N x C x ...] Tensor of inputs.
258
+ timesteps: a 1-D batch of timesteps, i.e. [N].
259
+ context: conditioning Tensor with shape of [N x ``d_context`` x ...] plugged
260
+ in via cross attention.
261
+ y: an [N] Tensor of labels, if **class-conditional**.
262
+ an [N x ``d_emb`` x ...] Tensor if **film-embed conditional**.
263
+
264
+ Returns:
265
+ an [N x C x ...] Tensor of outputs.
266
+ """
267
+ assert (y is None) or (
268
+ (y is not None)
269
+ and ((self.num_classes > 0) or (self.use_extra_film is not None))
270
+ ), f"y must be specified if num_classes or use_extra_film is not None. \nGot num_classes: {self.num_classes}\t\nuse_extra_film: {self.use_extra_film}\t\n"
271
+
272
+ hs = []
273
+ emb = self.pos_enc(timesteps)
274
+ emb = append_dims(emb, x.dim())
275
+
276
+ if self.num_classes > 0:
277
+ assert y.size() == (x.size(0),)
278
+ emb = emb + self.label_emb(y)
279
+ elif self.use_extra_film is not None:
280
+ assert y.size() == (x.size(0), self.d_emb, *x.size()[2:])
281
+ y = self.film_emb(y)
282
+ if self.use_extra_film == "add":
283
+ emb = emb + y
284
+ elif self.use_extra_film == "concat":
285
+ emb = torch.cat([emb, y], dim=1)
286
+
287
+ h = x
288
+ for module in self.input_blocks:
289
+ h = module(h, emb, context)
290
+ hs.append(h)
291
+ h = self.middle_block(h, emb, context)
292
+ for module in self.output_blocks:
293
+ h = torch.cat([h, hs.pop()], dim=1)
294
+ h = module(h, emb, context)
295
+
296
+ return self.out(h)
297
+
298
+
299
+ class UNetSequential(nn.Sequential):
300
+ r"""A sequential module that passes embeddings to the children that support it."""
301
+
302
+ def forward(self, x, emb=None, context=None):
303
+ for layer in self:
304
+ if isinstance(layer, ResBlock):
305
+ x = layer(x, emb)
306
+ elif isinstance(layer, AttentionBlock):
307
+ x = layer(x, context)
308
+ else:
309
+ x = layer(x)
310
+ return x
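
A smoke-test sketch for the 1-D UNet. It assumes `PositionEncoder` (imported above, not shown in this diff) maps a batch of timesteps to a `(N, 4 * model_channels)` embedding, and it uses a sequence length divisible by the total downsampling factor (4 here) so the skip connections line up.

```python
import torch

unet = UNet(
    dims=1, in_channels=100, model_channels=128, out_channels=100,
    h_dim=128, channel_mult=(1, 2, 4), attention_resolutions=(4,),
)
x = torch.randn(2, 100, 200)       # noisy mel-like input, (N, C, L)
t = torch.randint(0, 1000, (2,))   # diffusion timesteps
eps = unet(x, timesteps=t)         # (2, 100, 200)
```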
Amphion/modules/encoder/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .token_encoder import TokenEmbedding
Amphion/modules/general/utils.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ def normalization(channels: int, groups: int = 32):
11
+ r"""Make a standard normalization layer, i.e. GroupNorm.
12
+
13
+ Args:
14
+ channels: number of input channels.
15
+ groups: number of groups for group normalization.
16
+
17
+ Returns:
18
+ a ``nn.Module`` for normalization.
19
+ """
20
+ assert groups > 0, f"invalid number of groups: {groups}"
21
+ return nn.GroupNorm(groups, channels)
22
+
23
+
24
+ def Linear(*args, **kwargs):
25
+ r"""Wrapper of ``nn.Linear`` with kaiming_normal_ initialization."""
26
+ layer = nn.Linear(*args, **kwargs)
27
+ nn.init.kaiming_normal_(layer.weight)
28
+ return layer
29
+
30
+
31
+ def Conv1d(*args, **kwargs):
32
+ r"""Wrapper of ``nn.Conv1d`` with kaiming_normal_ initialization."""
33
+ layer = nn.Conv1d(*args, **kwargs)
34
+ nn.init.kaiming_normal_(layer.weight)
35
+ return layer
36
+
37
+
38
+ def Conv2d(*args, **kwargs):
39
+ r"""Wrapper of ``nn.Conv2d`` with kaiming_normal_ initialization."""
40
+ layer = nn.Conv2d(*args, **kwargs)
41
+ nn.init.kaiming_normal_(layer.weight)
42
+ return layer
43
+
44
+
45
+ def ConvNd(dims: int = 1, *args, **kwargs):
46
+ r"""Wrapper of N-dimension convolution with kaiming_normal_ initialization.
47
+
48
+ Args:
49
+ dims: number of dimensions of the convolution.
50
+ """
51
+ if dims == 1:
52
+ return Conv1d(*args, **kwargs)
53
+ elif dims == 2:
54
+ return Conv2d(*args, **kwargs)
55
+ else:
56
+ raise ValueError(f"invalid number of dimensions: {dims}")
57
+
58
+
59
+ def zero_module(module: nn.Module):
60
+ r"""Zero out the parameters of a module and return it."""
61
+ nn.init.zeros_(module.weight)
62
+ nn.init.zeros_(module.bias)
63
+ return module
64
+
65
+
66
+ def scale_module(module: nn.Module, scale):
67
+ r"""Scale the parameters of a module and return it."""
68
+ for p in module.parameters():
69
+ p.detach().mul_(scale)
70
+ return module
71
+
72
+
73
+ def mean_flat(tensor: torch.Tensor):
74
+ r"""Take the mean over all non-batch dimensions."""
75
+ return tensor.mean(dim=tuple(range(1, tensor.dim())))
76
+
77
+
78
+ def append_dims(x, target_dims):
79
+ r"""Appends dimensions to the end of a tensor until
80
+ it has target_dims dimensions.
81
+ """
82
+ dims_to_append = target_dims - x.dim()
83
+ if dims_to_append < 0:
84
+ raise ValueError(
85
+ f"input has {x.dim()} dims but target_dims is {target_dims}, which is less"
86
+ )
87
+ return x[(...,) + (None,) * dims_to_append]
88
+
89
+
90
+ def append_zero(x, count=1):
91
+ r"""Appends ``count`` zeros to the end of a tensor along the last dimension."""
92
+ assert count > 0, f"invalid count: {count}"
93
+ return torch.cat([x, x.new_zeros((*x.size()[:-1], count))], dim=-1)
94
+
95
+
96
+ class Transpose(nn.Identity):
97
+ """(N, T, D) -> (N, D, T)"""
98
+
99
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
100
+ return input.transpose(1, 2)
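
Two of these helpers carry most of the diffusion-specific conventions, so a short sketch: `append_dims` right-pads singleton dimensions so an embedding broadcasts over time, and `zero_module` makes the final projection of a block start as the zero map, a common initialization for diffusion decoders.

```python
import torch

emb = torch.randn(4, 512)
emb = append_dims(emb, 3)                  # (4, 512) -> (4, 512, 1)

proj = zero_module(Conv1d(256, 80, 1))     # weights and bias zeroed
y = proj(torch.randn(4, 256, 100))
assert torch.count_nonzero(y) == 0         # starts as the zero map
```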
Amphion/modules/norms/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .norm import AdaptiveLayerNorm, LayerNorm, BalancedBasicNorm, IdentityNorm