Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes. See raw diff.
- Amphion/egs/tts/VITS/README.md +221 -0
- Amphion/egs/vocoder/diffusion/exp_config_base.json +71 -0
- Amphion/egs/vocoder/gan/bigvgan_large/run.sh +141 -0
- Amphion/evaluation/metrics/similarity/__init__.py +0 -0
- Amphion/models/base/base_dataset.py +464 -0
- Amphion/models/base/base_inference.py +220 -0
- Amphion/models/base/new_inference.py +253 -0
- Amphion/models/codec/ns3_codec/alias_free_torch/__init__.py +5 -0
- Amphion/models/codec/ns3_codec/alias_free_torch/__pycache__/act.cpython-310.pyc +0 -0
- Amphion/models/codec/ns3_codec/alias_free_torch/__pycache__/filter.cpython-310.pyc +0 -0
- Amphion/models/codec/ns3_codec/alias_free_torch/act.py +29 -0
- Amphion/models/codec/ns3_codec/facodec.py +1163 -0
- Amphion/models/codec/ns3_codec/gradient_reversal.py +35 -0
- Amphion/models/codec/ns3_codec/melspec.py +102 -0
- Amphion/models/codec/ns3_codec/quantize/__pycache__/__init__.cpython-310.pyc +0 -0
- Amphion/models/codec/ns3_codec/quantize/fvq.py +116 -0
- Amphion/models/svc/transformer/transformer_inference.py +45 -0
- Amphion/models/svc/vits/vits_trainer.py +704 -0
- Amphion/models/tta/autoencoder/autoencoder_dataset.py +112 -0
- Amphion/models/tta/ldm/__init__.py +0 -0
- Amphion/models/tta/ldm/audioldm_dataset.py +151 -0
- Amphion/models/tta/ldm/audioldm_trainer.py +251 -0
- Amphion/models/tta/ldm/inference_utils/vocoder.py +408 -0
- Amphion/models/tts/base/__init__.py +7 -0
- Amphion/models/tts/base/tts_trainer.py +721 -0
- Amphion/models/tts/fastspeech2/fs2_trainer.py +155 -0
- Amphion/models/tts/naturalspeech2/ns2.py +259 -0
- Amphion/models/tts/naturalspeech2/ns2_dataset.py +524 -0
- Amphion/models/tts/naturalspeech2/ns2_inference.py +128 -0
- Amphion/models/tts/naturalspeech2/ns2_trainer.py +798 -0
- Amphion/models/tts/valle/__init__.py +0 -0
- Amphion/models/vocoders/autoregressive/autoregressive_vocoder_inference.py +0 -0
- Amphion/models/vocoders/autoregressive/wavenet/conv.py +66 -0
- Amphion/models/vocoders/autoregressive/wavenet/wavenet.py +170 -0
- Amphion/models/vocoders/diffusion/diffusion_vocoder_inference.py +131 -0
- Amphion/models/vocoders/flow/flow_vocoder_dataset.py +0 -0
- Amphion/models/vocoders/flow/flow_vocoder_inference.py +0 -0
- Amphion/models/vocoders/gan/discriminator/msd.py +88 -0
- Amphion/models/vocoders/gan/gan_vocoder_inference.py +96 -0
- Amphion/models/vocoders/vocoder_dataset.py +264 -0
- Amphion/models/vocoders/vocoder_sampler.py +126 -0
- Amphion/modules/activation_functions/snake.py +122 -0
- Amphion/modules/diffusion/bidilconv/bidilated_conv.py +102 -0
- Amphion/modules/diffusion/karras/sample.py +185 -0
- Amphion/modules/diffusion/unet/attention.py +241 -0
- Amphion/modules/diffusion/unet/resblock.py +178 -0
- Amphion/modules/diffusion/unet/unet.py +310 -0
- Amphion/modules/encoder/__init__.py +1 -0
- Amphion/modules/general/utils.py +100 -0
- Amphion/modules/norms/__init__.py +1 -0
Amphion/egs/tts/VITS/README.md
ADDED
@@ -0,0 +1,221 @@
# VITS Recipe

[HuggingFace Space](https://huggingface.co/spaces/amphion/Text-to-Speech)
[OpenXLab App](https://openxlab.org.cn/apps/detail/Amphion/Text-to-Speech)

In this recipe, we will show how to train VITS using Amphion's infrastructure. [VITS](https://arxiv.org/abs/2106.06103) is an end-to-end TTS architecture that utilizes a conditional variational autoencoder with adversarial learning.

There are four stages in total:

1. Data preparation
2. Feature extraction
3. Training
4. Inference

> **NOTE:** You need to run every command of this recipe in the `Amphion` root path:
> ```bash
> cd Amphion
> ```

## 1. Data Preparation

### Dataset Download
You can use any commonly used TTS dataset to train the TTS model, e.g., LJSpeech, VCTK, Hi-Fi TTS, LibriTTS, etc. We strongly recommend LJSpeech when training a single-speaker TTS model for the first time, and Hi-Fi TTS when training a multi-speaker TTS model for the first time. The process of downloading the datasets is detailed [here](../../datasets/README.md).

### Configuration

After downloading the dataset, set the dataset paths in `exp_config.json`. Note that you can change the `dataset` list to use your preferred datasets.

```json
"dataset": [
    "LJSpeech",
    //"hifitts"
],
"dataset_path": {
    // TODO: Fill in your dataset path
    "LJSpeech": "[LJSpeech dataset path]",
    //"hifitts": "[Hi-Fi TTS dataset path]"
},
```

## 2. Feature Extraction

### Configuration

In `exp_config.json`, specify the `log_dir` for saving checkpoints and logs, and specify the `processed_dir` for saving processed data. For preprocessing a multi-speaker TTS dataset, set `extract_audio` and `use_spkid` to `true`:

```json
// TODO: Fill in the output log path. The default value is "Amphion/ckpts/tts"
"log_dir": "ckpts/tts",
"preprocess": {
    //"extract_audio": true,
    "use_phone": true,
    // linguistic features
    "extract_phone": true,
    "phone_extractor": "espeak", // "espeak", "pypinyin", "pypinyin_initials_finals", or "lexicon" (the last only for language=en-us right now)
    // TODO: Fill in the output data path. The default value is "Amphion/data"
    "processed_dir": "data",
    "sample_rate": 22050, // target sampling rate
    "valid_file": "valid.json", // validation set
    //"use_spkid": true, // use speaker IDs to train a multi-speaker TTS model
},
```

### Run

Run `run.sh` as the preprocessing stage (set `--stage 1`):

```bash
sh egs/tts/VITS/run.sh --stage 1
```

> **NOTE:** `CUDA_VISIBLE_DEVICES` is set to `"0"` by default. You can change it when running `run.sh` by specifying, e.g., `--gpu "1"`.

## 3. Training

### Configuration

We provide default hyperparameters in `exp_config.json`. They work on a single NVIDIA 24 GB GPU; you can adjust them for your GPU machines.
For training a multi-speaker TTS model, set `n_speakers` to a value greater than or equal to the number of speakers in your dataset(s) (a larger value leaves room for fine-tuning on new speakers), and set `multi_speaker_training` to `true`.

```json
"model": {
    //"n_speakers": 10 // Number of speakers in the dataset(s) used. The default value is 0 if not specified.
},
"train": {
    "batch_size": 16,
    //"multi_speaker_training": true,
}
```

### Train From Scratch

Run `run.sh` as the training stage (set `--stage 2`) and specify an experiment name. The TensorBoard logs and checkpoints will be saved in `Amphion/ckpts/tts/[YourExptName]`.

```bash
sh egs/tts/VITS/run.sh --stage 2 --name [YourExptName]
```

### Train From Existing Source

We support training from an existing source for various purposes: you can resume training from a checkpoint, or fine-tune a model from another checkpoint.

By setting `--resume true`, training resumes from the **latest checkpoint** of the current `[YourExptName]` by default. For example, to resume training from the latest checkpoint in `Amphion/ckpts/tts/[YourExptName]/checkpoint`, run:

```bash
sh egs/tts/VITS/run.sh --stage 2 --name [YourExptName] \
    --resume true
```

You can also choose a **specific checkpoint** to restart training from with the `--resume_from_ckpt_path` argument. For example, to resume training from the checkpoint `Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificCheckpoint]`, run:

```bash
sh egs/tts/VITS/run.sh --stage 2 --name [YourExptName] \
    --resume true \
    --resume_from_ckpt_path "Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificCheckpoint]"
```

If you want to **fine-tune from another checkpoint**, set `--resume_type` to `"finetune"`. For example, to fine-tune the model from the checkpoint `Amphion/ckpts/tts/[AnotherExperiment]/checkpoint/[SpecificCheckpoint]`, run:

```bash
sh egs/tts/VITS/run.sh --stage 2 --name [YourExptName] \
    --resume true \
    --resume_from_ckpt_path "Amphion/ckpts/tts/[AnotherExperiment]/checkpoint/[SpecificCheckpoint]" \
    --resume_type "finetune"
```

> **NOTE:** `--resume_type` is set to `"resume"` by default, so it is not necessary to specify it when resuming training.
>
> The difference between `"resume"` and `"finetune"` is that `"finetune"` loads **only** the pretrained model weights from the checkpoint, while `"resume"` loads all the training states (including optimizer, scheduler, etc.) from the checkpoint.

Here are some example scenarios to better understand how to use these arguments:

| Scenario | `--resume` | `--resume_from_ckpt_path` | `--resume_type` |
| ------ | -------- | ----------------------- | ------------- |
| You want to train from scratch | no | no | no |
| The machine breaks down during training and you want to resume training from the latest checkpoint | `true` | no | no |
| You find the latest model is overfitting and you want to re-train from an earlier checkpoint | `true` | `SpecificCheckpoint Path` | no |
| You want to fine-tune a model from another checkpoint | `true` | `SpecificCheckpoint Path` | `"finetune"` |

> **NOTE:** `CUDA_VISIBLE_DEVICES` is set to `"0"` by default. You can change it when running `run.sh` by specifying, e.g., `--gpu "0,1,2,3"`.

## 4. Inference

### Pre-trained Model Download

We have released a pre-trained Amphion VITS model trained on LJSpeech, so you can download the pre-trained model [here](https://huggingface.co/amphion/vits-ljspeech) and generate speech according to the following inference instructions. The pre-trained multi-speaker VITS model trained on Hi-Fi TTS will be released soon. Stay tuned.

### Configuration

For inference, you need to specify the following configurations when running `run.sh`:

| Parameters | Description | Example |
| ---------- | ----------- | ------- |
| `--infer_expt_dir` | The experiment directory which contains `checkpoint` | `Amphion/ckpts/tts/[YourExptName]` |
| `--infer_output_dir` | The output directory to save inferred audios | `Amphion/ckpts/tts/[YourExptName]/result` |
| `--infer_mode` | The inference mode, e.g., "`single`", "`batch`" | "`single`" to generate one clip of speech, "`batch`" to generate a batch of speech at a time |
| `--infer_dataset` | The dataset used for inference | For the LJSpeech dataset, the inference dataset is `LJSpeech`.<br>For the Hi-Fi TTS dataset, the inference dataset is `hifitts`. |
| `--infer_testing_set` | The subset of the inference dataset used for inference, e.g., train, test, golden_test | For the LJSpeech dataset, the testing set can be "`test`" (split from LJSpeech during feature extraction) or "`golden_test`" (cherry-picked from the test set as a template testing set).<br>For the Hi-Fi TTS dataset, the testing set is "`test`" (split from Hi-Fi TTS during feature extraction). |
| `--infer_text` | The text to be synthesized | "`This is a clip of generated speech with the given text from a TTS model.`" |
| `--infer_speaker_name` | The target speaker whose voice is to be synthesized.<br>(***Note: only applicable to the multi-speaker TTS model***) | For the Hi-Fi TTS dataset, the list of available speakers includes: "`hifitts_11614`", "`hifitts_11697`", "`hifitts_12787`", "`hifitts_6097`", "`hifitts_6670`", "`hifitts_6671`", "`hifitts_8051`", "`hifitts_9017`", "`hifitts_9136`", "`hifitts_92`".<br>You may find the list of available speakers in the `spk2id.json` file generated in the `log_dir/[YourExptName]` that you specified in `exp_config.json` (see the Python snippet at the end of this README). |

### Run

#### Single text inference

For the single-speaker TTS model, to generate a single clip of speech from a given text, run:

```bash
sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \
    --infer_expt_dir Amphion/ckpts/tts/[YourExptName] \
    --infer_output_dir Amphion/ckpts/tts/[YourExptName]/result \
    --infer_mode "single" \
    --infer_text "This is a clip of generated speech with the given text from a TTS model."
```

For the multi-speaker TTS model, in addition to the arguments above, you need to add the `--infer_speaker_name` argument:

```bash
sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \
    --infer_expt_dir Amphion/ckpts/tts/[YourExptName] \
    --infer_output_dir Amphion/ckpts/tts/[YourExptName]/result \
    --infer_mode "single" \
    --infer_text "This is a clip of generated speech with the given text from a TTS model." \
    --infer_speaker_name "hifitts_92"
```

#### Batch inference

For the single-speaker TTS model, to generate speech for the whole testing set split from LJSpeech, run:

```bash
sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \
    --infer_expt_dir Amphion/ckpts/tts/[YourExptName] \
    --infer_output_dir Amphion/ckpts/tts/[YourExptName]/result \
    --infer_mode "batch" \
    --infer_dataset "LJSpeech" \
    --infer_testing_set "test"
```

For the multi-speaker TTS model, to generate speech for the whole testing set split from Hi-Fi TTS, the same procedure applies, with `LJSpeech` replaced by `hifitts`:

```bash
sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \
    --infer_expt_dir Amphion/ckpts/tts/[YourExptName] \
    --infer_output_dir Amphion/ckpts/tts/[YourExptName]/result \
    --infer_mode "batch" \
    --infer_dataset "hifitts" \
    --infer_testing_set "test"
```

```bibtex
@inproceedings{kim2021conditional,
  title={Conditional variational autoencoder with adversarial learning for end-to-end text-to-speech},
  author={Kim, Jaehyeon and Kong, Jungil and Son, Juhee},
  booktitle={International Conference on Machine Learning},
  pages={5530--5540},
  year={2021}
}
```
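As referenced in the inference configuration table above, the speaker names accepted by `--infer_speaker_name` come from the `spk2id.json` file written during preprocessing. A minimal sketch for listing them, assuming the default `log_dir` and your experiment name (adjust the path to your own configuration):

```python
import json

# Assumed location: depends on the "log_dir" and experiment name you configured.
with open("ckpts/tts/[YourExptName]/spk2id.json") as f:
    spk2id = json.load(f)  # e.g., {"hifitts_92": 0, "hifitts_6097": 1, ...}

print(sorted(spk2id))  # each key is a valid value for --infer_speaker_name
```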
Amphion/egs/vocoder/diffusion/exp_config_base.json
ADDED
@@ -0,0 +1,71 @@
{
    "base_config": "config/vocoder.json",
    "model_type": "DiffusionVocoder",
    // TODO: Choose your needed datasets
    "dataset": [
        "csd",
        "kising",
        "m4singer",
        "nus48e",
        "opencpop",
        "opensinger",
        "opera",
        "pjs",
        "popbutfy",
        "popcs",
        "ljspeech",
        "vctk",
        "libritts",
    ],
    "dataset_path": {
        // TODO: Fill in your dataset path
        "csd": "[dataset path]",
        "kising": "[dataset path]",
        "m4singer": "[dataset path]",
        "nus48e": "[dataset path]",
        "opencpop": "[dataset path]",
        "opensinger": "[dataset path]",
        "opera": "[dataset path]",
        "pjs": "[dataset path]",
        "popbutfy": "[dataset path]",
        "popcs": "[dataset path]",
        "ljspeech": "[dataset path]",
        "vctk": "[dataset path]",
        "libritts": "[dataset path]",
    },
    // TODO: Fill in the output log path
    "log_dir": "ckpts/vocoder",
    "preprocess": {
        // Acoustic features
        "extract_mel": true,
        "extract_audio": true,
        "extract_pitch": false,
        "extract_uv": false,
        "pitch_extractor": "parselmouth",

        // Features used for model training
        "use_mel": true,
        "use_frame_pitch": false,
        "use_uv": false,
        "use_audio": true,

        // TODO: Fill in the output data path
        "processed_dir": "data/",
        "n_mel": 100,
        "sample_rate": 24000
    },
    "train": {
        // TODO: Choose a suitable batch size, training epoch, and save stride
        "batch_size": 32,
        "max_epoch": 1000000,
        "save_checkpoint_stride": [20],
        "adamw": {
            "lr": 2.0e-4,
            "adam_b1": 0.8,
            "adam_b2": 0.99
        },
        "exponential_lr": {
            "lr_decay": 0.999
        },
    }
}
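This config is JSON with comments and trailing commas, so it is parsed with a JSON5-compatible loader rather than the standard `json` module (`new_inference.py` below imports `json5`). A minimal sketch of reading the file directly, run from the Amphion root; the repository's own `load_config` utility (referenced in the inference scripts) is assumed to additionally merge the `base_config` file underneath these overrides:

```python
import json5  # tolerates the // comments and trailing commas used above

with open("egs/vocoder/diffusion/exp_config_base.json") as f:
    cfg = json5.load(f)

print(cfg["model_type"])                 # "DiffusionVocoder"
print(cfg["preprocess"]["sample_rate"])  # 24000
print(cfg["train"]["adamw"]["lr"])       # 0.0002
```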
Amphion/egs/vocoder/gan/bigvgan_large/run.sh
ADDED
@@ -0,0 +1,141 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

######## Build Experiment Environment ###########
exp_dir=$(cd `dirname $0`; pwd)
work_dir=$(dirname $(dirname $(dirname $(dirname $exp_dir))))

export WORK_DIR=$work_dir
export PYTHONPATH=$work_dir
export PYTHONIOENCODING=UTF-8

######## Parse the Given Parameters from the Command ###########
options=$(getopt -o c:n:s: --long gpu:,config:,name:,stage:,checkpoint:,resume_type:,main_process_port:,infer_mode:,infer_datasets:,infer_feature_dir:,infer_audio_dir:,infer_expt_dir:,infer_output_dir: -- "$@")
eval set -- "$options"

while true; do
  case $1 in
    # Experimental Configuration File
    -c | --config) shift; exp_config=$1 ; shift ;;
    # Experimental Name
    -n | --name) shift; exp_name=$1 ; shift ;;
    # Running Stage
    -s | --stage) shift; running_stage=$1 ; shift ;;
    # Visible GPU machines. The default value is "0".
    --gpu) shift; gpu=$1 ; shift ;;

    # [Only for Training] The specific checkpoint path that you want to resume from.
    --checkpoint) shift; checkpoint=$1 ; shift ;;
    # [Only for Training] `resume` for loading all the things (including model weights, optimizer, scheduler, and random states). `finetune` for loading only the model weights.
    --resume_type) shift; resume_type=$1 ; shift ;;
    # [Only for Training] `main_process_port` for multi-GPU training
    --main_process_port) shift; main_process_port=$1 ; shift ;;

    # [Only for Inference] The inference mode
    --infer_mode) shift; infer_mode=$1 ; shift ;;
    # [Only for Inference] The datasets to run inference on
    --infer_datasets) shift; infer_datasets=$1 ; shift ;;
    # [Only for Inference] The feature dir for inference
    --infer_feature_dir) shift; infer_feature_dir=$1 ; shift ;;
    # [Only for Inference] The audio dir for inference
    --infer_audio_dir) shift; infer_audio_dir=$1 ; shift ;;
    # [Only for Inference] The experiment dir. The value is like "[Your path to save logs and checkpoints]/[YourExptName]"
    --infer_expt_dir) shift; infer_expt_dir=$1 ; shift ;;
    # [Only for Inference] The output dir to save inferred audios. Its default value is "$expt_dir/result"
    --infer_output_dir) shift; infer_output_dir=$1 ; shift ;;

    --) shift ; break ;;
    *) echo "Invalid option: $1"; exit 1 ;;
  esac
done


### Value check ###
if [ -z "$running_stage" ]; then
    echo "[Error] Please specify the running stage"
    exit 1
fi

if [ -z "$exp_config" ]; then
    exp_config="${exp_dir}"/exp_config.json
fi
echo "Experimental Configuration File: $exp_config"

if [ -z "$gpu" ]; then
    gpu="0"
fi

if [ -z "$main_process_port" ]; then
    main_process_port=29500
fi
echo "Main Process Port: $main_process_port"

######## Features Extraction ###########
if [ $running_stage -eq 1 ]; then
    CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/bins/vocoder/preprocess.py \
        --config $exp_config \
        --num_workers 8
fi

######## Training ###########
if [ $running_stage -eq 2 ]; then
    if [ -z "$exp_name" ]; then
        echo "[Error] Please specify the experiment name"
        exit 1
    fi
    echo "Experimental Name: $exp_name"

    CUDA_VISIBLE_DEVICES=$gpu accelerate launch \
        --main_process_port "$main_process_port" \
        "${work_dir}"/bins/vocoder/train.py \
        --config "$exp_config" \
        --exp_name "$exp_name" \
        --log_level info \
        --checkpoint "$checkpoint" \
        --resume_type "$resume_type"
fi

######## Inference/Conversion ###########
if [ $running_stage -eq 3 ]; then
    if [ -z "$infer_expt_dir" ]; then
        echo "[Error] Please specify the experiment directory. The value is like [Your path to save logs and checkpoints]/[YourExptName]"
        exit 1
    fi

    if [ -z "$infer_output_dir" ]; then
        infer_output_dir="$infer_expt_dir/result"
    fi

    if [ "$infer_mode" = "infer_from_dataset" ]; then
        CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/vocoder/inference.py \
            --config $exp_config \
            --infer_mode $infer_mode \
            --infer_datasets $infer_datasets \
            --vocoder_dir $infer_expt_dir \
            --output_dir $infer_output_dir \
            --log_level debug
    fi

    if [ "$infer_mode" = "infer_from_feature" ]; then
        CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/vocoder/inference.py \
            --config $exp_config \
            --infer_mode $infer_mode \
            --feature_folder $infer_feature_dir \
            --vocoder_dir $infer_expt_dir \
            --output_dir $infer_output_dir \
            --log_level debug
    fi

    if [ "$infer_mode" = "infer_from_audio" ]; then
        CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/vocoder/inference.py \
            --config $exp_config \
            --infer_mode $infer_mode \
            --audio_folder $infer_audio_dir \
            --vocoder_dir $infer_expt_dir \
            --output_dir $infer_output_dir \
            --log_level debug
    fi

fi
Amphion/evaluation/metrics/similarity/__init__.py
ADDED
(empty file)
Amphion/models/base/base_dataset.py
ADDED
@@ -0,0 +1,464 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import json

import torch
import numpy as np
import torch.utils.data
from torch.nn.utils.rnn import pad_sequence
import librosa

from utils.data_utils import *
from processors.acoustic_extractor import cal_normalized_mel
from text import text_to_sequence
from text.text_token_collation import phoneIDCollation


class BaseOfflineDataset(torch.utils.data.Dataset):
    def __init__(self, cfg, dataset, is_valid=False):
        """
        Args:
            cfg: config
            dataset: dataset name
            is_valid: whether to use the train or valid dataset
        """

        assert isinstance(dataset, str)

        # self.data_root = processed_data_dir
        self.cfg = cfg

        processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
        meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
        self.metafile_path = os.path.join(processed_data_dir, meta_file)
        self.metadata = self.get_metadata()

        """
        load spk2id and utt2spk from json file
            spk2id: {spk1: 0, spk2: 1, ...}
            utt2spk: {dataset_uid: spk1, ...}
        """
        if cfg.preprocess.use_spkid:
            spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
            with open(spk2id_path, "r") as f:
                self.spk2id = json.load(f)

            utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
            self.utt2spk = dict()
            with open(utt2spk_path, "r") as f:
                for line in f.readlines():
                    utt, spk = line.strip().split("\t")
                    self.utt2spk[utt] = spk

        if cfg.preprocess.use_uv:
            self.utt2uv_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)
                self.utt2uv_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.uv_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_frame_pitch:
            self.utt2frame_pitch_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2frame_pitch_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.pitch_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_frame_energy:
            self.utt2frame_energy_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2frame_energy_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.energy_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_mel:
            self.utt2mel_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2mel_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.mel_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_linear:
            self.utt2linear_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2linear_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.linear_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_audio:
            self.utt2audio_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2audio_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.audio_dir,
                    uid + ".npy",
                )
        elif cfg.preprocess.use_label:
            self.utt2label_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2label_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.label_dir,
                    uid + ".npy",
                )
        elif cfg.preprocess.use_one_hot:
            self.utt2one_hot_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2one_hot_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.one_hot_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_text or cfg.preprocess.use_phone:
            self.utt2seq = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                if cfg.preprocess.use_text:
                    text = utt_info["Text"]
                    sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
                elif cfg.preprocess.use_phone:
                    # load the phoneme sequence from the phone file
                    phone_path = os.path.join(
                        processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone"
                    )
                    with open(phone_path, "r") as fin:
                        phones = fin.readlines()
                        assert len(phones) == 1
                        phones = phones[0].strip()
                    phones_seq = phones.split(" ")

                    phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
                    sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)

                self.utt2seq[utt] = sequence

    def get_metadata(self):
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        return metadata

    def get_dataset_name(self):
        return self.metadata[0]["Dataset"]

    def __getitem__(self, index):
        utt_info = self.metadata[index]

        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        single_feature = dict()

        if self.cfg.preprocess.use_spkid:
            single_feature["spk_id"] = np.array(
                [self.spk2id[self.utt2spk[utt]]], dtype=np.int32
            )

        if self.cfg.preprocess.use_mel:
            mel = np.load(self.utt2mel_path[utt])
            assert mel.shape[0] == self.cfg.preprocess.n_mel  # [n_mels, T]
            if self.cfg.preprocess.use_min_max_norm_mel:
                # do mel norm
                mel = cal_normalized_mel(mel, utt_info["Dataset"], self.cfg.preprocess)

            if "target_len" not in single_feature.keys():
                single_feature["target_len"] = mel.shape[1]
            single_feature["mel"] = mel.T  # [T, n_mels]

        if self.cfg.preprocess.use_linear:
            linear = np.load(self.utt2linear_path[utt])
            if "target_len" not in single_feature.keys():
                single_feature["target_len"] = linear.shape[1]
            single_feature["linear"] = linear.T  # [T, n_linear]

        if self.cfg.preprocess.use_frame_pitch:
            frame_pitch_path = self.utt2frame_pitch_path[utt]
            frame_pitch = np.load(frame_pitch_path)
            if "target_len" not in single_feature.keys():
                single_feature["target_len"] = len(frame_pitch)
            aligned_frame_pitch = align_length(
                frame_pitch, single_feature["target_len"]
            )
            single_feature["frame_pitch"] = aligned_frame_pitch

        if self.cfg.preprocess.use_uv:
            frame_uv_path = self.utt2uv_path[utt]
            frame_uv = np.load(frame_uv_path)
            # assumes "target_len" has been set by an earlier feature (e.g., mel or pitch)
            aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
            # invert the binary voiced/unvoiced flag
            aligned_frame_uv = [0 if frame_uv else 1 for frame_uv in aligned_frame_uv]
            aligned_frame_uv = np.array(aligned_frame_uv)
            single_feature["frame_uv"] = aligned_frame_uv

        if self.cfg.preprocess.use_frame_energy:
            frame_energy_path = self.utt2frame_energy_path[utt]
            frame_energy = np.load(frame_energy_path)
            if "target_len" not in single_feature.keys():
                single_feature["target_len"] = len(frame_energy)
            aligned_frame_energy = align_length(
                frame_energy, single_feature["target_len"]
            )
            single_feature["frame_energy"] = aligned_frame_energy

        if self.cfg.preprocess.use_audio:
            audio = np.load(self.utt2audio_path[utt])
            single_feature["audio"] = audio
            single_feature["audio_len"] = audio.shape[0]

        if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
            single_feature["phone_seq"] = np.array(self.utt2seq[utt])
            single_feature["phone_len"] = len(self.utt2seq[utt])

        return single_feature

    def __len__(self):
        return len(self.metadata)


class BaseOfflineCollator(object):
    """Zero-pads model inputs and targets based on the number of frames per step"""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        packed_batch_features = dict()

        # mel: [b, T, n_mels]
        # frame_pitch, frame_energy: [1, T]
        # target_len: [b]
        # spk_id: [b, 1]
        # mask: [b, T, 1]

        for key in batch[0].keys():
            if key == "target_len":
                packed_batch_features["target_len"] = torch.LongTensor(
                    [b["target_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "phone_len":
                packed_batch_features["phone_len"] = torch.LongTensor(
                    [b["phone_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["phn_mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "audio_len":
                packed_batch_features["audio_len"] = torch.LongTensor(
                    [b["audio_len"] for b in batch]
                )
                # masks for the raw audio are built here but not added to the batch;
                # the audio itself is padded by the generic branch below
                masks = [
                    torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch
                ]
            else:
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )
        return packed_batch_features


class BaseOnlineDataset(torch.utils.data.Dataset):
    def __init__(self, cfg, dataset, is_valid=False):
        """
        Args:
            cfg: config
            dataset: dataset name
            is_valid: whether to use the train or valid dataset
        """
        assert isinstance(dataset, str)

        self.cfg = cfg
        self.sample_rate = cfg.preprocess.sample_rate
        self.hop_size = self.cfg.preprocess.hop_size

        processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
        meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
        self.metafile_path = os.path.join(processed_data_dir, meta_file)
        self.metadata = self.get_metadata()

        """
        load spk2id and utt2spk from json file
            spk2id: {spk1: 0, spk2: 1, ...}
            utt2spk: {dataset_uid: spk1, ...}
        """
        if cfg.preprocess.use_spkid:
            spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
            with open(spk2id_path, "r") as f:
                self.spk2id = json.load(f)

            utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
            self.utt2spk = dict()
            with open(utt2spk_path, "r") as f:
                for line in f.readlines():
                    utt, spk = line.strip().split("\t")
                    self.utt2spk[utt] = spk

    def get_metadata(self):
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        return metadata

    def get_dataset_name(self):
        return self.metadata[0]["Dataset"]

    def __getitem__(self, index):
        """
        single_feature:
            wav: (T,)
            wav_len: int
            target_len: int
            mask: (n_frames, 1)
            spk_id: (1)
        """
        utt_item = self.metadata[index]

        wav_path = utt_item["Path"]
        wav, _ = librosa.load(wav_path, sr=self.sample_rate)
        # wav: (T,)
        wav = torch.as_tensor(wav, dtype=torch.float32)
        wav_len = len(wav)
        # mask: (n_frames, 1)
        frame_len = wav_len // self.hop_size
        mask = torch.ones(frame_len, 1, dtype=torch.long)

        single_feature = {
            "wav": wav,
            "wav_len": wav_len,
            "target_len": frame_len,
            "mask": mask,
        }

        if self.cfg.preprocess.use_spkid:
            utt = "{}_{}".format(utt_item["Dataset"], utt_item["Uid"])
            single_feature["spk_id"] = torch.tensor(
                [self.spk2id[self.utt2spk[utt]]], dtype=torch.int32
            )

        return single_feature

    def __len__(self):
        return len(self.metadata)


class BaseOnlineCollator(object):
    """Zero-pads model inputs and targets based on the number of frames per step
    (for on-the-fly feature extraction, where each item contains only wavs)."""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        """
        BaseOnlineDataset.__getitem__:
            wav: (T,)
            wav_len: int
            target_len: int
            mask: (n_frames, 1)
            spk_id: (1)

        Returns:
            wav: (B, T), torch.float32
            wav_len: (B), torch.long
            target_len: (B), torch.long
            mask: (B, n_frames, 1), torch.long
            spk_id: (B, 1), torch.int32
        """
        packed_batch_features = dict()

        for key in batch[0].keys():
            if key in ["wav_len", "target_len"]:
                packed_batch_features[key] = torch.LongTensor([b[key] for b in batch])
            else:
                packed_batch_features[key] = pad_sequence(
                    [b[key] for b in batch], batch_first=True, padding_value=0
                )
        return packed_batch_features


class BaseTestDataset(torch.utils.data.Dataset):
    def __init__(self, cfg, args):
        raise NotImplementedError

    def get_metadata(self):
        raise NotImplementedError

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        return len(self.metadata)


class BaseTestCollator(object):
    """Zero-pads model inputs and targets based on the number of frames per step"""

    def __init__(self, cfg):
        raise NotImplementedError

    def __call__(self, batch):
        raise NotImplementedError
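To show how the classes in this file fit together, here is a minimal usage sketch wiring `BaseOfflineDataset` and `BaseOfflineCollator` into a PyTorch `DataLoader`. The config path and dataset name are placeholders; `cfg` is assumed to be an Amphion experiment config loaded via the `load_config` utility used elsewhere in this diff:

```python
from torch.utils.data import DataLoader

from models.base.base_dataset import BaseOfflineDataset, BaseOfflineCollator
from utils.util import load_config  # same loader the inference scripts use

cfg = load_config("egs/tts/VITS/exp_config.json")  # hypothetical experiment config

train_set = BaseOfflineDataset(cfg, "LJSpeech", is_valid=False)
collator = BaseOfflineCollator(cfg)

loader = DataLoader(
    train_set,
    batch_size=cfg.train.batch_size,
    shuffle=True,
    collate_fn=collator,  # zero-pads variable-length mel/pitch/phone features
)

batch = next(iter(loader))
# e.g., batch["mel"]: (B, T_max, n_mels); batch["target_len"]: (B,); batch["mask"]: (B, T_max, 1)
```

The collator derives `mask`/`phn_mask` from the per-utterance lengths, so downstream losses can ignore padded frames.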
Amphion/models/base/base_inference.py
ADDED
@@ -0,0 +1,220 @@
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import os
|
| 8 |
+
import re
|
| 9 |
+
import time
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch.utils.data import DataLoader
|
| 14 |
+
from tqdm import tqdm
|
| 15 |
+
|
| 16 |
+
from models.vocoders.vocoder_inference import synthesis
|
| 17 |
+
from torch.utils.data import DataLoader
|
| 18 |
+
from utils.util import set_all_random_seed
|
| 19 |
+
from utils.util import load_config
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def parse_vocoder(vocoder_dir):
|
| 23 |
+
r"""Parse vocoder config"""
|
| 24 |
+
vocoder_dir = os.path.abspath(vocoder_dir)
|
| 25 |
+
ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
|
| 26 |
+
ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
|
| 27 |
+
ckpt_path = str(ckpt_list[0])
|
| 28 |
+
vocoder_cfg = load_config(os.path.join(vocoder_dir, "args.json"), lowercase=True)
|
| 29 |
+
vocoder_cfg.model.bigvgan = vocoder_cfg.vocoder
|
| 30 |
+
return vocoder_cfg, ckpt_path
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class BaseInference(object):
|
| 34 |
+
def __init__(self, cfg, args):
|
| 35 |
+
self.cfg = cfg
|
| 36 |
+
self.args = args
|
| 37 |
+
self.model_type = cfg.model_type
|
| 38 |
+
self.avg_rtf = list()
|
| 39 |
+
set_all_random_seed(10086)
|
| 40 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 41 |
+
|
| 42 |
+
if torch.cuda.is_available():
|
| 43 |
+
self.device = torch.device("cuda")
|
| 44 |
+
else:
|
| 45 |
+
self.device = torch.device("cpu")
|
| 46 |
+
torch.set_num_threads(10) # inference on 1 core cpu.
|
| 47 |
+
|
| 48 |
+
# Load acoustic model
|
| 49 |
+
self.model = self.create_model().to(self.device)
|
| 50 |
+
state_dict = self.load_state_dict()
|
| 51 |
+
self.load_model(state_dict)
|
| 52 |
+
self.model.eval()
|
| 53 |
+
|
| 54 |
+
# Load vocoder model if necessary
|
| 55 |
+
if self.args.checkpoint_dir_vocoder is not None:
|
| 56 |
+
self.get_vocoder_info()
|
| 57 |
+
|
| 58 |
+
def create_model(self):
|
| 59 |
+
raise NotImplementedError
|
| 60 |
+
|
| 61 |
+
def load_state_dict(self):
|
| 62 |
+
self.checkpoint_file = self.args.checkpoint_file
|
| 63 |
+
if self.checkpoint_file is None:
|
| 64 |
+
assert self.args.checkpoint_dir is not None
|
| 65 |
+
checkpoint_path = os.path.join(self.args.checkpoint_dir, "checkpoint")
|
| 66 |
+
checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
|
| 67 |
+
self.checkpoint_file = os.path.join(
|
| 68 |
+
self.args.checkpoint_dir, checkpoint_filename
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
self.checkpoint_dir = os.path.split(self.checkpoint_file)[0]
|
| 72 |
+
|
| 73 |
+
print("Restore acoustic model from {}".format(self.checkpoint_file))
|
| 74 |
+
raw_state_dict = torch.load(self.checkpoint_file, map_location=self.device)
|
| 75 |
+
self.am_restore_step = re.findall(r"step-(.+?)_loss", self.checkpoint_file)[0]
|
| 76 |
+
|
| 77 |
+
return raw_state_dict
|
| 78 |
+
|
| 79 |
+
def load_model(self, model):
|
| 80 |
+
raise NotImplementedError
|
| 81 |
+
|
| 82 |
+
def get_vocoder_info(self):
|
| 83 |
+
self.checkpoint_dir_vocoder = self.args.checkpoint_dir_vocoder
|
| 84 |
+
self.vocoder_cfg = os.path.join(
|
| 85 |
+
os.path.dirname(self.checkpoint_dir_vocoder), "args.json"
|
| 86 |
+
)
|
| 87 |
+
self.cfg.vocoder = load_config(self.vocoder_cfg, lowercase=True)
|
| 88 |
+
self.vocoder_tag = self.checkpoint_dir_vocoder.split("/")[-2].split(":")[-1]
|
| 89 |
+
self.vocoder_steps = self.checkpoint_dir_vocoder.split("/")[-1].split(".")[0]
|
| 90 |
+
|
| 91 |
+
def build_test_utt_data(self):
|
| 92 |
+
raise NotImplementedError
|
| 93 |
+
|
| 94 |
+
def build_testdata_loader(self, args, target_speaker=None):
|
| 95 |
+
datasets, collate = self.build_test_dataset()
|
| 96 |
+
self.test_dataset = datasets(self.cfg, args, target_speaker)
|
| 97 |
+
self.test_collate = collate(self.cfg)
|
| 98 |
+
self.test_batch_size = min(
|
| 99 |
+
self.cfg.train.batch_size, len(self.test_dataset.metadata)
|
| 100 |
+
)
|
| 101 |
+
test_loader = DataLoader(
|
| 102 |
+
self.test_dataset,
|
| 103 |
+
collate_fn=self.test_collate,
|
| 104 |
+
num_workers=self.args.num_workers,
|
| 105 |
+
batch_size=self.test_batch_size,
|
| 106 |
+
            shuffle=False,
        )
        return test_loader

    def inference_each_batch(self, batch_data):
        raise NotImplementedError

    def inference_for_batches(self, args, target_speaker=None):
        ###### Construct test_batch ######
        loader = self.build_testdata_loader(args, target_speaker)

        n_batch = len(loader)
        now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
        print(
            "Model eval time: {}, batch_size = {}, n_batch = {}".format(
                now, self.test_batch_size, n_batch
            )
        )
        self.model.eval()

        ###### Inference for each batch ######
        pred_res = []
        with torch.no_grad():
            for i, batch_data in enumerate(loader if n_batch == 1 else tqdm(loader)):
                # Put the data to device
                for k, v in batch_data.items():
                    batch_data[k] = batch_data[k].to(self.device)

                y_pred, stats = self.inference_each_batch(batch_data)

                pred_res += y_pred

        return pred_res

    def inference(self, feature):
        raise NotImplementedError

    def synthesis_by_vocoder(self, pred):
        audios_pred = synthesis(
            self.vocoder_cfg,
            self.checkpoint_dir_vocoder,
            len(pred),
            pred,
        )
        return audios_pred

    def __call__(self, utt):
        feature = self.build_test_utt_data(utt)
        start_time = time.time()
        with torch.no_grad():
            outputs = self.inference(feature)[0]
        time_used = time.time() - start_time
        rtf = time_used / (
            outputs.shape[1]
            * self.cfg.preprocess.hop_size
            / self.cfg.preprocess.sample_rate
        )
        print("Time used: {:.3f}, RTF: {:.4f}".format(time_used, rtf))
        self.avg_rtf.append(rtf)
        audios = outputs.cpu().squeeze().numpy().reshape(-1, 1)
        return audios


def base_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", default="config.json", help="json files for configurations."
    )
    parser.add_argument("--use_ddp_inference", default=False)
    parser.add_argument("--n_workers", default=1, type=int)
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument(
        "--batch_size", default=1, type=int, help="Batch size for inference"
    )
    parser.add_argument(
        "--num_workers",
        default=1,
        type=int,
        help="Worker number for inference dataloader",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default=None,
        help="Checkpoint dir including model file and configuration",
    )
    parser.add_argument(
        "--checkpoint_file", help="checkpoint file", type=str, default=None
    )
    parser.add_argument(
        "--test_list", help="test utterance list for testing", type=str, default=None
    )
    parser.add_argument(
        "--checkpoint_dir_vocoder",
        help="Vocoder's checkpoint dir including model file and configuration",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Output dir for saving generated results",
    )
    return parser


if __name__ == "__main__":
    parser = base_parser()
    args = parser.parse_args()
    cfg = load_config(args.config)

    # Build inference
    inference = BaseInference(cfg, args)
    inference()
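
A note on the real-time factor printed by `__call__` above: the denominator converts the predicted frame count into seconds of audio (frames x hop size / sample rate), so an RTF below 1.0 means faster-than-real-time synthesis. A minimal standalone sketch of that formula; the hop size and sample rate here are illustrative assumptions, not values read from any Amphion config:

def compute_rtf(time_used, n_frames, hop_size=256, sample_rate=24000):
    # Seconds of audio that n_frames of acoustic features represent.
    audio_seconds = n_frames * hop_size / sample_rate
    return time_used / audio_seconds

# e.g. 2.0 s of compute for 400 frames -> 400 * 256 / 24000 ~= 4.27 s of audio -> RTF ~= 0.47
print("{:.2f}".format(compute_rtf(2.0, 400)))
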
Amphion/models/base/new_inference.py
ADDED
@@ -0,0 +1,253 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import random
import re
import time
from abc import abstractmethod
from pathlib import Path

import accelerate
import json5
import numpy as np
import torch
from accelerate.logging import get_logger
from torch.utils.data import DataLoader

from models.vocoders.vocoder_inference import synthesis
from utils.io import save_audio
from utils.util import load_config
from utils.audio_slicer import is_silence

EPS = 1.0e-12


class BaseInference(object):
    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        super().__init__()

        start = time.monotonic_ns()
        self.args = args
        self.cfg = cfg

        assert infer_type in ["from_dataset", "from_file"]
        self.infer_type = infer_type

        # init with accelerate
        self.accelerator = accelerate.Accelerator()
        self.accelerator.wait_for_everyone()

        # Use accelerate logger for distributed inference
        with self.accelerator.main_process_first():
            self.logger = get_logger("inference", log_level=args.log_level)

        # Log some info
        self.logger.info("=" * 56)
        self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
        self.logger.info("=" * 56)
        self.logger.info("\n")
        self.logger.debug(f"Using {args.log_level.upper()} logging level.")

        self.acoustics_dir = args.acoustics_dir
        self.logger.debug(f"Acoustic dir: {args.acoustics_dir}")
        self.vocoder_dir = args.vocoder_dir
        self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
        # should be in svc inferencer
        # self.target_singer = args.target_singer
        # self.logger.info(f"Target singers: {args.target_singer}")
        # self.trans_key = args.trans_key
        # self.logger.info(f"Trans key: {args.trans_key}")

        os.makedirs(args.output_dir, exist_ok=True)

        # set random seed
        with self.accelerator.main_process_first():
            start = time.monotonic_ns()
            self._set_random_seed(self.cfg.train.random_seed)
            end = time.monotonic_ns()
            self.logger.debug(
                f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
            )
            self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")

        # setup data_loader
        with self.accelerator.main_process_first():
            self.logger.info("Building dataset...")
            start = time.monotonic_ns()
            self.test_dataloader = self._build_dataloader()
            end = time.monotonic_ns()
            self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")

        # setup model
        with self.accelerator.main_process_first():
            self.logger.info("Building model...")
            start = time.monotonic_ns()
            self.model = self._build_model()
            end = time.monotonic_ns()
            # self.logger.debug(self.model)
            self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")

        # init with accelerate
        self.logger.info("Initializing accelerate...")
        start = time.monotonic_ns()
        self.accelerator = accelerate.Accelerator()
        self.model = self.accelerator.prepare(self.model)
        end = time.monotonic_ns()
        self.accelerator.wait_for_everyone()
        self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")

        with self.accelerator.main_process_first():
            self.logger.info("Loading checkpoint...")
            start = time.monotonic_ns()
            # TODO: Also, suppose only use latest one yet
            self.__load_model(os.path.join(args.acoustics_dir, "checkpoint"))
            end = time.monotonic_ns()
            self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")

        self.model.eval()
        self.accelerator.wait_for_everyone()

    ### Abstract methods ###
    @abstractmethod
    def _build_test_dataset(self):
        pass

    @abstractmethod
    def _build_model(self):
        pass

    @abstractmethod
    @torch.inference_mode()
    def _inference_each_batch(self, batch_data):
        pass

    ### Abstract methods end ###

    @torch.inference_mode()
    def inference(self):
        for i, batch in enumerate(self.test_dataloader):
            y_pred = self._inference_each_batch(batch).cpu()

            # Judge whether the min-max normalization is used
            if self.cfg.preprocess.use_min_max_norm_mel:
                mel_min, mel_max = self.test_dataset.target_mel_extrema
                y_pred = (y_pred + 1.0) / 2.0 * (mel_max - mel_min + EPS) + mel_min

            y_ls = y_pred.chunk(self.test_batch_size)
            tgt_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
            j = 0
            for it, l in zip(y_ls, tgt_ls):
                l = l.item()
                it = it.squeeze(0)[:l]
                uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
                torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
                j += 1

        vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)

        res = synthesis(
            cfg=vocoder_cfg,
            vocoder_weight_file=vocoder_ckpt,
            n_samples=None,
            pred=[
                torch.load(
                    os.path.join(self.args.output_dir, "{}.pt".format(i["Uid"]))
                ).numpy(force=True)
                for i in self.test_dataset.metadata
            ],
        )

        output_audio_files = []
        for it, wav in zip(self.test_dataset.metadata, res):
            uid = it["Uid"]
            file = os.path.join(self.args.output_dir, f"{uid}.wav")
            output_audio_files.append(file)

            wav = wav.numpy(force=True)
            save_audio(
                file,
                wav,
                self.cfg.preprocess.sample_rate,
                add_silence=False,
                turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate),
            )
            os.remove(os.path.join(self.args.output_dir, f"{uid}.pt"))

        return sorted(output_audio_files)

    # TODO: LEGACY CODE
    def _build_dataloader(self):
        datasets, collate = self._build_test_dataset()
        self.test_dataset = datasets(self.args, self.cfg, self.infer_type)
        self.test_collate = collate(self.cfg)
        self.test_batch_size = min(
            self.cfg.train.batch_size, len(self.test_dataset.metadata)
        )
        test_dataloader = DataLoader(
            self.test_dataset,
            collate_fn=self.test_collate,
            num_workers=1,
            batch_size=self.test_batch_size,
            shuffle=False,
        )
        return test_dataloader

    def __load_model(self, checkpoint_dir: str = None, checkpoint_path: str = None):
        r"""Load model from checkpoint. If checkpoint_path is None, it will
        load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
        None, it will load the checkpoint specified by checkpoint_path. **Only use this
        method after** ``accelerator.prepare()``.
        """
        if checkpoint_path is None:
            ls = []
            for i in Path(checkpoint_dir).iterdir():
                if re.match(r"epoch-\d+_step-\d+_loss-[\d.]+", str(i.stem)):
                    ls.append(i)
            ls.sort(
                key=lambda x: int(x.stem.split("_")[-3].split("-")[-1]), reverse=True
            )
            checkpoint_path = ls[0]
        else:
            checkpoint_path = Path(checkpoint_path)
        self.accelerator.load_state(str(checkpoint_path))
        # set epoch and step
        self.epoch = int(checkpoint_path.stem.split("_")[-3].split("-")[-1])
        self.step = int(checkpoint_path.stem.split("_")[-2].split("-")[-1])
        return str(checkpoint_path)

    @staticmethod
    def _set_random_seed(seed):
        r"""Set random seed for all possible random modules."""
        random.seed(seed)
        np.random.seed(seed)
        torch.random.manual_seed(seed)

    @staticmethod
    def _parse_vocoder(vocoder_dir):
        r"""Parse vocoder config"""
        vocoder_dir = os.path.abspath(vocoder_dir)
        ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
        ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
        ckpt_path = str(ckpt_list[0])
        vocoder_cfg = load_config(
            os.path.join(vocoder_dir, "args.json"), lowercase=True
        )
        return vocoder_cfg, ckpt_path

    @staticmethod
    def __count_parameters(model):
        return sum(p.numel() for p in model.parameters())

    def __dump_cfg(self, path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        json5.dump(
            self.cfg,
            open(path, "w"),
            indent=4,
            sort_keys=True,
            ensure_ascii=False,
            quote_keys=True,
        )
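
The de-normalization step in `inference()` above maps a prediction scaled to [-1, 1] back onto the dataset's mel extrema before vocoding. A minimal sketch of that mapping; the extrema values below are illustrative, not taken from a real dataset:

import torch

EPS = 1.0e-12  # same epsilon as in the module above

def denorm_mel(y_pred, mel_min, mel_max):
    # Inverse of min-max normalization: [-1, 1] -> [mel_min, mel_max].
    return (y_pred + 1.0) / 2.0 * (mel_max - mel_min + EPS) + mel_min

y = torch.tensor([-1.0, 0.0, 1.0])
print(denorm_mel(y, mel_min=-12.0, mel_max=2.0))  # tensor([-12., -5., 2.])
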
Amphion/models/codec/ns3_codec/alias_free_torch/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

from .filter import *
from .resample import *
from .act import *
Amphion/models/codec/ns3_codec/alias_free_torch/__pycache__/act.cpython-310.pyc
ADDED
Binary file (1.11 kB).
Amphion/models/codec/ns3_codec/alias_free_torch/__pycache__/filter.cpython-310.pyc
ADDED
Binary file (2.69 kB).
Amphion/models/codec/ns3_codec/alias_free_torch/act.py
ADDED
@@ -0,0 +1,29 @@
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0

import torch.nn as nn
from .resample import UpSample1d, DownSample1d


class Activation1d(nn.Module):
    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        self.up_ratio = up_ratio
        self.down_ratio = down_ratio
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    # x: [B,C,T]
    def forward(self, x):
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)

        return x
Amphion/models/codec/ns3_codec/facodec.py
ADDED
@@ -0,0 +1,1163 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import nn, pow, sin
from torch.nn import Parameter
from torch.nn.utils import weight_norm

from .alias_free_torch import *
from .gradient_reversal import GradientReversal
from .melspec import MelSpectrogram
from .quantize import *
from .transformer import TransformerEncoder


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


class CNNLSTM(nn.Module):
    def __init__(self, indim, outdim, head, global_pred=False):
        super().__init__()
        self.global_pred = global_pred
        self.model = nn.Sequential(
            ResidualUnit(indim, dilation=1),
            ResidualUnit(indim, dilation=2),
            ResidualUnit(indim, dilation=3),
            Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
            Rearrange("b c t -> b t c"),
        )
        self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])

    def forward(self, x):
        # x: [B, C, T]
        x = self.model(x)
        if self.global_pred:
            x = torch.mean(x, dim=1, keepdim=False)
        outs = [head(x) for head in self.heads]
        return outs


class SnakeBeta(nn.Module):
    """
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snakebeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha - trainable parameter that controls frequency
            - beta - trainable parameter that controls magnitude
            alpha is initialized to 1 by default, higher values = higher-frequency.
            beta is initialized to 1 by default, higher values = higher-magnitude.
            alpha will be trained along with the rest of your model.
        """
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x

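In other words, the forward pass above computes y = x + (1 / (beta + eps)) * sin^2(alpha * x) per channel; with alpha_logscale=True the zero-initialized parameters are exponentiated, so training starts from alpha = beta = 1. A quick numeric sketch of the formula, outside the module:

import torch

# SnakeBeta with alpha = beta = 1 (the alpha_logscale=True starting point).
x = torch.tensor(0.5)
alpha, beta = 1.0, 1.0
y = x + (1.0 / beta) * torch.sin(alpha * x) ** 2
print(y)  # ~0.7299, since sin(0.5)^2 ~= 0.2298
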
class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        return x + self.block(x)


class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            Activation1d(activation=SnakeBeta(dim // 2, alpha_logscale=True)),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=stride // 2 + stride % 2,
            ),
        )

    def forward(self, x):
        return self.block(x)


class FACodecEncoder(nn.Module):
    def __init__(
        self,
        ngf=32,
        up_ratios=(2, 4, 5, 5),
        out_channels=1024,
    ):
        super().__init__()
        self.hop_length = np.prod(up_ratios)
        self.up_ratios = up_ratios

        # Create first convolution
        d_model = ngf
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in up_ratios:
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride)]

        # Create last convolution
        self.block += [
            Activation1d(activation=SnakeBeta(d_model, alpha_logscale=True)),
            WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
        ]

        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

        self.reset_parameters()

    def forward(self, x):
        out = self.block(x)
        return out

    def inference(self, x):
        return self.block(x)

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module to all of the layers."""

        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        self.apply(init_weights)

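With the default up_ratios=(2, 4, 5, 5), FACodecEncoder downsamples by np.prod(up_ratios) = 200, i.e. one latent frame per 200 samples (80 frames per second for 16 kHz input), while channels double at each stage from ngf=32 up to 512 before the final projection. A sketch of that arithmetic, using the defaults from the class above:

import numpy as np

up_ratios = (2, 4, 5, 5)
hop_length = int(np.prod(up_ratios))  # 200 samples per latent frame
print(hop_length, 16000 / hop_length)  # 200, 80.0 latent frames per second

d_model = 32  # ngf
for _ in up_ratios:  # channels double at each EncoderBlock
    d_model *= 2
print(d_model)  # 512, then projected to out_channels by the last convolution
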
+
class DecoderBlock(nn.Module):
|
| 219 |
+
def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
|
| 220 |
+
super().__init__()
|
| 221 |
+
self.block = nn.Sequential(
|
| 222 |
+
Activation1d(activation=SnakeBeta(input_dim, alpha_logscale=True)),
|
| 223 |
+
WNConvTranspose1d(
|
| 224 |
+
input_dim,
|
| 225 |
+
output_dim,
|
| 226 |
+
kernel_size=2 * stride,
|
| 227 |
+
stride=stride,
|
| 228 |
+
padding=stride // 2 + stride % 2,
|
| 229 |
+
output_padding=stride % 2,
|
| 230 |
+
),
|
| 231 |
+
ResidualUnit(output_dim, dilation=1),
|
| 232 |
+
ResidualUnit(output_dim, dilation=3),
|
| 233 |
+
ResidualUnit(output_dim, dilation=9),
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
def forward(self, x):
|
| 237 |
+
return self.block(x)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class FACodecDecoder(nn.Module):
|
| 241 |
+
def __init__(
|
| 242 |
+
self,
|
| 243 |
+
in_channels=256,
|
| 244 |
+
upsample_initial_channel=1536,
|
| 245 |
+
ngf=32,
|
| 246 |
+
up_ratios=(5, 5, 4, 2),
|
| 247 |
+
vq_num_q_c=2,
|
| 248 |
+
vq_num_q_p=1,
|
| 249 |
+
vq_num_q_r=3,
|
| 250 |
+
vq_dim=1024,
|
| 251 |
+
vq_commit_weight=0.005,
|
| 252 |
+
vq_weight_init=False,
|
| 253 |
+
vq_full_commit_loss=False,
|
| 254 |
+
codebook_dim=8,
|
| 255 |
+
codebook_size_prosody=10, # true codebook size is equal to 2^codebook_size
|
| 256 |
+
codebook_size_content=10,
|
| 257 |
+
codebook_size_residual=10,
|
| 258 |
+
quantizer_dropout=0.0,
|
| 259 |
+
dropout_type="linear",
|
| 260 |
+
use_gr_content_f0=False,
|
| 261 |
+
use_gr_prosody_phone=False,
|
| 262 |
+
use_gr_residual_f0=False,
|
| 263 |
+
use_gr_residual_phone=False,
|
| 264 |
+
use_gr_x_timbre=False,
|
| 265 |
+
use_random_mask_residual=True,
|
| 266 |
+
prob_random_mask_residual=0.75,
|
| 267 |
+
):
|
| 268 |
+
super().__init__()
|
| 269 |
+
self.hop_length = np.prod(up_ratios)
|
| 270 |
+
self.ngf = ngf
|
| 271 |
+
self.up_ratios = up_ratios
|
| 272 |
+
|
| 273 |
+
self.use_random_mask_residual = use_random_mask_residual
|
| 274 |
+
self.prob_random_mask_residual = prob_random_mask_residual
|
| 275 |
+
|
| 276 |
+
self.vq_num_q_p = vq_num_q_p
|
| 277 |
+
self.vq_num_q_c = vq_num_q_c
|
| 278 |
+
self.vq_num_q_r = vq_num_q_r
|
| 279 |
+
|
| 280 |
+
self.codebook_size_prosody = codebook_size_prosody
|
| 281 |
+
self.codebook_size_content = codebook_size_content
|
| 282 |
+
self.codebook_size_residual = codebook_size_residual
|
| 283 |
+
|
| 284 |
+
quantizer_class = ResidualVQ
|
| 285 |
+
|
| 286 |
+
self.quantizer = nn.ModuleList()
|
| 287 |
+
|
| 288 |
+
# prosody
|
| 289 |
+
quantizer = quantizer_class(
|
| 290 |
+
num_quantizers=vq_num_q_p,
|
| 291 |
+
dim=vq_dim,
|
| 292 |
+
codebook_size=codebook_size_prosody,
|
| 293 |
+
codebook_dim=codebook_dim,
|
| 294 |
+
threshold_ema_dead_code=2,
|
| 295 |
+
commitment=vq_commit_weight,
|
| 296 |
+
weight_init=vq_weight_init,
|
| 297 |
+
full_commit_loss=vq_full_commit_loss,
|
| 298 |
+
quantizer_dropout=quantizer_dropout,
|
| 299 |
+
dropout_type=dropout_type,
|
| 300 |
+
)
|
| 301 |
+
self.quantizer.append(quantizer)
|
| 302 |
+
|
| 303 |
+
# phone
|
| 304 |
+
quantizer = quantizer_class(
|
| 305 |
+
num_quantizers=vq_num_q_c,
|
| 306 |
+
dim=vq_dim,
|
| 307 |
+
codebook_size=codebook_size_content,
|
| 308 |
+
codebook_dim=codebook_dim,
|
| 309 |
+
threshold_ema_dead_code=2,
|
| 310 |
+
commitment=vq_commit_weight,
|
| 311 |
+
weight_init=vq_weight_init,
|
| 312 |
+
full_commit_loss=vq_full_commit_loss,
|
| 313 |
+
quantizer_dropout=quantizer_dropout,
|
| 314 |
+
dropout_type=dropout_type,
|
| 315 |
+
)
|
| 316 |
+
self.quantizer.append(quantizer)
|
| 317 |
+
|
| 318 |
+
# residual
|
| 319 |
+
if self.vq_num_q_r > 0:
|
| 320 |
+
quantizer = quantizer_class(
|
| 321 |
+
num_quantizers=vq_num_q_r,
|
| 322 |
+
dim=vq_dim,
|
| 323 |
+
codebook_size=codebook_size_residual,
|
| 324 |
+
codebook_dim=codebook_dim,
|
| 325 |
+
threshold_ema_dead_code=2,
|
| 326 |
+
commitment=vq_commit_weight,
|
| 327 |
+
weight_init=vq_weight_init,
|
| 328 |
+
full_commit_loss=vq_full_commit_loss,
|
| 329 |
+
quantizer_dropout=quantizer_dropout,
|
| 330 |
+
dropout_type=dropout_type,
|
| 331 |
+
)
|
| 332 |
+
self.quantizer.append(quantizer)
|
| 333 |
+
|
| 334 |
+
# Add first conv layer
|
| 335 |
+
channels = upsample_initial_channel
|
| 336 |
+
layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
|
| 337 |
+
|
| 338 |
+
# Add upsampling + MRF blocks
|
| 339 |
+
for i, stride in enumerate(up_ratios):
|
| 340 |
+
input_dim = channels // 2**i
|
| 341 |
+
output_dim = channels // 2 ** (i + 1)
|
| 342 |
+
layers += [DecoderBlock(input_dim, output_dim, stride)]
|
| 343 |
+
|
| 344 |
+
# Add final conv layer
|
| 345 |
+
layers += [
|
| 346 |
+
Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)),
|
| 347 |
+
WNConv1d(output_dim, 1, kernel_size=7, padding=3),
|
| 348 |
+
nn.Tanh(),
|
| 349 |
+
]
|
| 350 |
+
|
| 351 |
+
self.model = nn.Sequential(*layers)
|
| 352 |
+
|
| 353 |
+
self.timbre_encoder = TransformerEncoder(
|
| 354 |
+
enc_emb_tokens=None,
|
| 355 |
+
encoder_layer=4,
|
| 356 |
+
encoder_hidden=256,
|
| 357 |
+
encoder_head=4,
|
| 358 |
+
conv_filter_size=1024,
|
| 359 |
+
conv_kernel_size=5,
|
| 360 |
+
encoder_dropout=0.1,
|
| 361 |
+
use_cln=False,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
self.timbre_linear = nn.Linear(in_channels, in_channels * 2)
|
| 365 |
+
self.timbre_linear.bias.data[:in_channels] = 1
|
| 366 |
+
self.timbre_linear.bias.data[in_channels:] = 0
|
| 367 |
+
self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False)
|
| 368 |
+
|
| 369 |
+
self.f0_predictor = CNNLSTM(in_channels, 1, 2)
|
| 370 |
+
self.phone_predictor = CNNLSTM(in_channels, 5003, 1)
|
| 371 |
+
|
| 372 |
+
self.use_gr_content_f0 = use_gr_content_f0
|
| 373 |
+
self.use_gr_prosody_phone = use_gr_prosody_phone
|
| 374 |
+
self.use_gr_residual_f0 = use_gr_residual_f0
|
| 375 |
+
self.use_gr_residual_phone = use_gr_residual_phone
|
| 376 |
+
self.use_gr_x_timbre = use_gr_x_timbre
|
| 377 |
+
|
| 378 |
+
if self.vq_num_q_r > 0 and self.use_gr_residual_f0:
|
| 379 |
+
self.res_f0_predictor = nn.Sequential(GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2))
|
| 380 |
+
|
| 381 |
+
if self.vq_num_q_r > 0 and self.use_gr_residual_phone > 0:
|
| 382 |
+
self.res_phone_predictor = nn.Sequential(GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1))
|
| 383 |
+
|
| 384 |
+
if self.use_gr_content_f0:
|
| 385 |
+
self.content_f0_predictor = nn.Sequential(GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2))
|
| 386 |
+
|
| 387 |
+
if self.use_gr_prosody_phone:
|
| 388 |
+
self.prosody_phone_predictor = nn.Sequential(GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1))
|
| 389 |
+
|
| 390 |
+
if self.use_gr_x_timbre:
|
| 391 |
+
self.x_timbre_predictor = nn.Sequential(
|
| 392 |
+
GradientReversal(alpha=1),
|
| 393 |
+
CNNLSTM(in_channels, 245200, 1, global_pred=True),
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
self.reset_parameters()
|
| 397 |
+
|
| 398 |
+
def quantize(self, x, n_quantizers=None):
|
| 399 |
+
outs, qs, commit_loss, quantized_buf = 0, [], [], []
|
| 400 |
+
|
| 401 |
+
# prosody
|
| 402 |
+
f0_input = x # (B, d, T)
|
| 403 |
+
f0_quantizer = self.quantizer[0]
|
| 404 |
+
out, q, commit, quantized = f0_quantizer(f0_input, n_quantizers=n_quantizers)
|
| 405 |
+
outs += out
|
| 406 |
+
qs.append(q)
|
| 407 |
+
quantized_buf.append(quantized.sum(0))
|
| 408 |
+
commit_loss.append(commit)
|
| 409 |
+
|
| 410 |
+
# phone
|
| 411 |
+
phone_input = x
|
| 412 |
+
phone_quantizer = self.quantizer[1]
|
| 413 |
+
out, q, commit, quantized = phone_quantizer(phone_input, n_quantizers=n_quantizers)
|
| 414 |
+
outs += out
|
| 415 |
+
qs.append(q)
|
| 416 |
+
quantized_buf.append(quantized.sum(0))
|
| 417 |
+
commit_loss.append(commit)
|
| 418 |
+
|
| 419 |
+
# residual
|
| 420 |
+
if self.vq_num_q_r > 0:
|
| 421 |
+
residual_quantizer = self.quantizer[2]
|
| 422 |
+
residual_input = x - (quantized_buf[0] + quantized_buf[1]).detach()
|
| 423 |
+
out, q, commit, quantized = residual_quantizer(residual_input, n_quantizers=n_quantizers)
|
| 424 |
+
outs += out
|
| 425 |
+
qs.append(q)
|
| 426 |
+
quantized_buf.append(quantized.sum(0)) # [L, B, C, T] -> [B, C, T]
|
| 427 |
+
commit_loss.append(commit)
|
| 428 |
+
|
| 429 |
+
qs = torch.cat(qs, dim=0)
|
| 430 |
+
commit_loss = torch.cat(commit_loss, dim=0)
|
| 431 |
+
return outs, qs, commit_loss, quantized_buf
|
| 432 |
+
|
| 433 |
+
def forward(
|
| 434 |
+
self,
|
| 435 |
+
x,
|
| 436 |
+
vq=True,
|
| 437 |
+
get_vq=False,
|
| 438 |
+
eval_vq=True,
|
| 439 |
+
speaker_embedding=None,
|
| 440 |
+
n_quantizers=None,
|
| 441 |
+
quantized=None,
|
| 442 |
+
):
|
| 443 |
+
if get_vq:
|
| 444 |
+
return self.quantizer.get_emb()
|
| 445 |
+
if vq is True:
|
| 446 |
+
if eval_vq:
|
| 447 |
+
self.quantizer.eval()
|
| 448 |
+
x_timbre = x
|
| 449 |
+
outs, qs, commit_loss, quantized_buf = self.quantize(x, n_quantizers=n_quantizers)
|
| 450 |
+
|
| 451 |
+
x_timbre = x_timbre.transpose(1, 2)
|
| 452 |
+
x_timbre = self.timbre_encoder(x_timbre, None, None)
|
| 453 |
+
x_timbre = x_timbre.transpose(1, 2)
|
| 454 |
+
spk_embs = torch.mean(x_timbre, dim=2)
|
| 455 |
+
return outs, qs, commit_loss, quantized_buf, spk_embs
|
| 456 |
+
|
| 457 |
+
out = {}
|
| 458 |
+
|
| 459 |
+
layer_0 = quantized[0]
|
| 460 |
+
f0, uv = self.f0_predictor(layer_0)
|
| 461 |
+
f0 = rearrange(f0, "... 1 -> ...")
|
| 462 |
+
uv = rearrange(uv, "... 1 -> ...")
|
| 463 |
+
|
| 464 |
+
layer_1 = quantized[1]
|
| 465 |
+
(phone,) = self.phone_predictor(layer_1)
|
| 466 |
+
|
| 467 |
+
out = {"f0": f0, "uv": uv, "phone": phone}
|
| 468 |
+
|
| 469 |
+
if self.use_gr_prosody_phone:
|
| 470 |
+
(prosody_phone,) = self.prosody_phone_predictor(layer_0)
|
| 471 |
+
out["prosody_phone"] = prosody_phone
|
| 472 |
+
|
| 473 |
+
if self.use_gr_content_f0:
|
| 474 |
+
content_f0, content_uv = self.content_f0_predictor(layer_1)
|
| 475 |
+
content_f0 = rearrange(content_f0, "... 1 -> ...")
|
| 476 |
+
content_uv = rearrange(content_uv, "... 1 -> ...")
|
| 477 |
+
out["content_f0"] = content_f0
|
| 478 |
+
out["content_uv"] = content_uv
|
| 479 |
+
|
| 480 |
+
if self.vq_num_q_r > 0:
|
| 481 |
+
layer_2 = quantized[2]
|
| 482 |
+
|
| 483 |
+
if self.use_gr_residual_f0:
|
| 484 |
+
res_f0, res_uv = self.res_f0_predictor(layer_2)
|
| 485 |
+
res_f0 = rearrange(res_f0, "... 1 -> ...")
|
| 486 |
+
res_uv = rearrange(res_uv, "... 1 -> ...")
|
| 487 |
+
out["res_f0"] = res_f0
|
| 488 |
+
out["res_uv"] = res_uv
|
| 489 |
+
|
| 490 |
+
if self.use_gr_residual_phone:
|
| 491 |
+
(res_phone,) = self.res_phone_predictor(layer_2)
|
| 492 |
+
out["res_phone"] = res_phone
|
| 493 |
+
|
| 494 |
+
style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
|
| 495 |
+
gamma, beta = style.chunk(2, 1) # (B, d, 1)
|
| 496 |
+
if self.vq_num_q_r > 0:
|
| 497 |
+
if self.use_random_mask_residual:
|
| 498 |
+
bsz = quantized[2].shape[0]
|
| 499 |
+
res_mask = np.random.choice(
|
| 500 |
+
[0, 1],
|
| 501 |
+
size=bsz,
|
| 502 |
+
p=[
|
| 503 |
+
self.prob_random_mask_residual,
|
| 504 |
+
1 - self.prob_random_mask_residual,
|
| 505 |
+
],
|
| 506 |
+
)
|
| 507 |
+
res_mask = torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1) # (B, 1, 1)
|
| 508 |
+
res_mask = res_mask.to(device=quantized[2].device, dtype=quantized[2].dtype)
|
| 509 |
+
x = quantized[0].detach() + quantized[1].detach() + quantized[2] * res_mask
|
| 510 |
+
# x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2] * res_mask
|
| 511 |
+
else:
|
| 512 |
+
x = quantized[0].detach() + quantized[1].detach() + quantized[2]
|
| 513 |
+
# x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2]
|
| 514 |
+
else:
|
| 515 |
+
x = quantized[0].detach() + quantized[1].detach()
|
| 516 |
+
# x = quantized_perturbe[0].detach() + quantized[1].detach()
|
| 517 |
+
|
| 518 |
+
if self.use_gr_x_timbre:
|
| 519 |
+
(x_timbre,) = self.x_timbre_predictor(x)
|
| 520 |
+
out["x_timbre"] = x_timbre
|
| 521 |
+
|
| 522 |
+
x = x.transpose(1, 2)
|
| 523 |
+
x = self.timbre_norm(x)
|
| 524 |
+
x = x.transpose(1, 2)
|
| 525 |
+
x = x * gamma + beta
|
| 526 |
+
|
| 527 |
+
x = self.model(x)
|
| 528 |
+
out["audio"] = x
|
| 529 |
+
|
| 530 |
+
return out
|
| 531 |
+
|
| 532 |
+
def vq2emb(self, vq, use_residual_code=True):
|
| 533 |
+
# vq: [num_quantizer, B, T]
|
| 534 |
+
self.quantizer = self.quantizer.eval()
|
| 535 |
+
out = 0
|
| 536 |
+
out += self.quantizer[0].vq2emb(vq[0 : self.vq_num_q_p])
|
| 537 |
+
out += self.quantizer[1].vq2emb(vq[self.vq_num_q_p : self.vq_num_q_p + self.vq_num_q_c])
|
| 538 |
+
if self.vq_num_q_r > 0 and use_residual_code:
|
| 539 |
+
out += self.quantizer[2].vq2emb(vq[self.vq_num_q_p + self.vq_num_q_c :])
|
| 540 |
+
return out
|
| 541 |
+
|
| 542 |
+
def inference(self, x, speaker_embedding):
|
| 543 |
+
style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
|
| 544 |
+
gamma, beta = style.chunk(2, 1) # (B, d, 1)
|
| 545 |
+
x = x.transpose(1, 2)
|
| 546 |
+
x = self.timbre_norm(x)
|
| 547 |
+
x = x.transpose(1, 2)
|
| 548 |
+
x = x * gamma + beta
|
| 549 |
+
x = self.model(x)
|
| 550 |
+
return x
|
| 551 |
+
|
| 552 |
+
def remove_weight_norm(self):
|
| 553 |
+
"""Remove weight normalization module from all of the layers."""
|
| 554 |
+
|
| 555 |
+
def _remove_weight_norm(m):
|
| 556 |
+
try:
|
| 557 |
+
torch.nn.utils.remove_weight_norm(m)
|
| 558 |
+
except ValueError: # this module didn't have weight norm
|
| 559 |
+
return
|
| 560 |
+
|
| 561 |
+
self.apply(_remove_weight_norm)
|
| 562 |
+
|
| 563 |
+
def apply_weight_norm(self):
|
| 564 |
+
"""Apply weight normalization module from all of the layers."""
|
| 565 |
+
|
| 566 |
+
def _apply_weight_norm(m):
|
| 567 |
+
if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
|
| 568 |
+
torch.nn.utils.weight_norm(m)
|
| 569 |
+
|
| 570 |
+
self.apply(_apply_weight_norm)
|
| 571 |
+
|
| 572 |
+
def reset_parameters(self):
|
| 573 |
+
self.apply(init_weights)
|
| 574 |
+
|
| 575 |
+
|
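FACodecDecoder is driven in two passes: a first call with vq=True turns encoder latents into codes, per-stream quantized tensors, and a pooled speaker (timbre) embedding; a second call with vq=False feeds those quantized streams back in to predict the auxiliary targets and reconstruct the waveform. A hedged usage sketch; the 256-dim latent and the matching vq_dim are assumptions chosen so the shapes line up, and the surrounding ns3_codec package must be importable:

import torch

enc = FACodecEncoder(ngf=32, up_ratios=(2, 4, 5, 5), out_channels=256)
dec = FACodecDecoder(in_channels=256, vq_dim=256)

wav = torch.randn(1, 1, 16000)  # one second of 16 kHz audio
latent = enc(wav)  # (B, 256, T / 200)

# Pass 1: quantize the latent and pool a speaker embedding.
outs, codes, commit_loss, quantized, spk = dec(latent, vq=True)

# Pass 2: decode the quantized streams back to audio, conditioned on timbre.
out = dec(latent, vq=False, speaker_embedding=spk, quantized=quantized)
print(out["audio"].shape)  # roughly (1, 1, 16000)
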
| 576 |
+
class FACodecRedecoder(nn.Module):
|
| 577 |
+
def __init__(
|
| 578 |
+
self,
|
| 579 |
+
in_channels=256,
|
| 580 |
+
upsample_initial_channel=1280,
|
| 581 |
+
up_ratios=(5, 5, 4, 2),
|
| 582 |
+
vq_num_q_c=2,
|
| 583 |
+
vq_num_q_p=1,
|
| 584 |
+
vq_num_q_r=3,
|
| 585 |
+
vq_dim=256,
|
| 586 |
+
codebook_size_prosody=10,
|
| 587 |
+
codebook_size_content=10,
|
| 588 |
+
codebook_size_residual=10,
|
| 589 |
+
):
|
| 590 |
+
super().__init__()
|
| 591 |
+
self.hop_length = np.prod(up_ratios)
|
| 592 |
+
self.up_ratios = up_ratios
|
| 593 |
+
|
| 594 |
+
self.vq_num_q_p = vq_num_q_p
|
| 595 |
+
self.vq_num_q_c = vq_num_q_c
|
| 596 |
+
self.vq_num_q_r = vq_num_q_r
|
| 597 |
+
|
| 598 |
+
self.vq_dim = vq_dim
|
| 599 |
+
|
| 600 |
+
self.codebook_size_prosody = codebook_size_prosody
|
| 601 |
+
self.codebook_size_content = codebook_size_content
|
| 602 |
+
self.codebook_size_residual = codebook_size_residual
|
| 603 |
+
|
| 604 |
+
self.prosody_embs = nn.ModuleList()
|
| 605 |
+
for i in range(self.vq_num_q_p):
|
| 606 |
+
emb_tokens = nn.Embedding(
|
| 607 |
+
num_embeddings=2**self.codebook_size_prosody,
|
| 608 |
+
embedding_dim=self.vq_dim,
|
| 609 |
+
)
|
| 610 |
+
emb_tokens.weight.data.normal_(mean=0.0, std=1e-5)
|
| 611 |
+
self.prosody_embs.append(emb_tokens)
|
| 612 |
+
self.content_embs = nn.ModuleList()
|
| 613 |
+
for i in range(self.vq_num_q_c):
|
| 614 |
+
emb_tokens = nn.Embedding(
|
| 615 |
+
num_embeddings=2**self.codebook_size_content,
|
| 616 |
+
embedding_dim=self.vq_dim,
|
| 617 |
+
)
|
| 618 |
+
emb_tokens.weight.data.normal_(mean=0.0, std=1e-5)
|
| 619 |
+
self.content_embs.append(emb_tokens)
|
| 620 |
+
self.residual_embs = nn.ModuleList()
|
| 621 |
+
for i in range(self.vq_num_q_r):
|
| 622 |
+
emb_tokens = nn.Embedding(
|
| 623 |
+
num_embeddings=2**self.codebook_size_residual,
|
| 624 |
+
embedding_dim=self.vq_dim,
|
| 625 |
+
)
|
| 626 |
+
emb_tokens.weight.data.normal_(mean=0.0, std=1e-5)
|
| 627 |
+
self.residual_embs.append(emb_tokens)
|
| 628 |
+
|
| 629 |
+
# Add first conv layer
|
| 630 |
+
channels = upsample_initial_channel
|
| 631 |
+
layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]
|
| 632 |
+
|
| 633 |
+
# Add upsampling + MRF blocks
|
| 634 |
+
for i, stride in enumerate(up_ratios):
|
| 635 |
+
input_dim = channels // 2**i
|
| 636 |
+
output_dim = channels // 2 ** (i + 1)
|
| 637 |
+
layers += [DecoderBlock(input_dim, output_dim, stride)]
|
| 638 |
+
|
| 639 |
+
# Add final conv layer
|
| 640 |
+
layers += [
|
| 641 |
+
Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)),
|
| 642 |
+
WNConv1d(output_dim, 1, kernel_size=7, padding=3),
|
| 643 |
+
nn.Tanh(),
|
| 644 |
+
]
|
| 645 |
+
|
| 646 |
+
self.model = nn.Sequential(*layers)
|
| 647 |
+
|
| 648 |
+
self.timbre_linear = nn.Linear(in_channels, in_channels * 2)
|
| 649 |
+
self.timbre_linear.bias.data[:in_channels] = 1
|
| 650 |
+
self.timbre_linear.bias.data[in_channels:] = 0
|
| 651 |
+
self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False)
|
| 652 |
+
|
| 653 |
+
self.timbre_cond_prosody_enc = TransformerEncoder(
|
| 654 |
+
enc_emb_tokens=None,
|
| 655 |
+
encoder_layer=4,
|
| 656 |
+
encoder_hidden=256,
|
| 657 |
+
encoder_head=4,
|
| 658 |
+
conv_filter_size=1024,
|
| 659 |
+
conv_kernel_size=5,
|
| 660 |
+
encoder_dropout=0.1,
|
| 661 |
+
use_cln=True,
|
| 662 |
+
cfg=None,
|
| 663 |
+
)
|
| 664 |
+
|
| 665 |
+
def forward(
|
| 666 |
+
self,
|
| 667 |
+
vq,
|
| 668 |
+
speaker_embedding,
|
| 669 |
+
use_residual_code=False,
|
| 670 |
+
):
|
| 671 |
+
x = 0
|
| 672 |
+
|
| 673 |
+
x_p = 0
|
| 674 |
+
for i in range(self.vq_num_q_p):
|
| 675 |
+
x_p = x_p + self.prosody_embs[i](vq[i]) # (B, T, d)
|
| 676 |
+
spk_cond = speaker_embedding.unsqueeze(1).expand(-1, x_p.shape[1], -1)
|
| 677 |
+
x_p = self.timbre_cond_prosody_enc(x_p, key_padding_mask=None, condition=spk_cond)
|
| 678 |
+
x = x + x_p
|
| 679 |
+
|
| 680 |
+
x_c = 0
|
| 681 |
+
for i in range(self.vq_num_q_c):
|
| 682 |
+
x_c = x_c + self.content_embs[i](vq[self.vq_num_q_p + i])
|
| 683 |
+
|
| 684 |
+
x = x + x_c
|
| 685 |
+
|
| 686 |
+
if use_residual_code:
|
| 687 |
+
x_r = 0
|
| 688 |
+
for i in range(self.vq_num_q_r):
|
| 689 |
+
x_r = x_r + self.residual_embs[i](vq[self.vq_num_q_p + self.vq_num_q_c + i])
|
| 690 |
+
x = x + x_r
|
| 691 |
+
|
| 692 |
+
style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
|
| 693 |
+
gamma, beta = style.chunk(2, 1) # (B, d, 1)
|
| 694 |
+
x = x.transpose(1, 2)
|
| 695 |
+
x = self.timbre_norm(x)
|
| 696 |
+
x = x.transpose(1, 2)
|
| 697 |
+
x = x * gamma + beta
|
| 698 |
+
x = self.model(x)
|
| 699 |
+
|
| 700 |
+
return x
|
| 701 |
+
|
| 702 |
+
def vq2emb(self, vq, speaker_embedding, use_residual=True):
|
| 703 |
+
out = 0
|
| 704 |
+
|
| 705 |
+
x_t = 0
|
| 706 |
+
for i in range(self.vq_num_q_p):
|
| 707 |
+
x_t += self.prosody_embs[i](vq[i]) # (B, T, d)
|
| 708 |
+
spk_cond = speaker_embedding.unsqueeze(1).expand(-1, x_t.shape[1], -1)
|
| 709 |
+
x_t = self.timbre_cond_prosody_enc(x_t, key_padding_mask=None, condition=spk_cond)
|
| 710 |
+
|
| 711 |
+
# prosody
|
| 712 |
+
out += x_t
|
| 713 |
+
|
| 714 |
+
# content
|
| 715 |
+
for i in range(self.vq_num_q_c):
|
| 716 |
+
out += self.content_embs[i](vq[self.vq_num_q_p + i])
|
| 717 |
+
|
| 718 |
+
# residual
|
| 719 |
+
if use_residual:
|
| 720 |
+
for i in range(self.vq_num_q_r):
|
| 721 |
+
out += self.residual_embs[i](vq[self.vq_num_q_p + self.vq_num_q_c + i])
|
| 722 |
+
|
| 723 |
+
out = out.transpose(1, 2) # (B, T, d) -> (B, d, T)
|
| 724 |
+
return out
|
| 725 |
+
|
| 726 |
+
def inference(self, x, speaker_embedding):
|
| 727 |
+
style = self.timbre_linear(speaker_embedding).unsqueeze(2) # (B, 2d, 1)
|
| 728 |
+
gamma, beta = style.chunk(2, 1) # (B, d, 1)
|
| 729 |
+
x = x.transpose(1, 2)
|
| 730 |
+
x = self.timbre_norm(x)
|
| 731 |
+
x = x.transpose(1, 2)
|
| 732 |
+
x = x * gamma + beta
|
| 733 |
+
x = self.model(x)
|
| 734 |
+
return x
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
class FACodecEncoderV2(nn.Module):
|
| 738 |
+
def __init__(
|
| 739 |
+
self,
|
| 740 |
+
ngf=32,
|
| 741 |
+
up_ratios=(2, 4, 5, 5),
|
| 742 |
+
out_channels=1024,
|
| 743 |
+
):
|
| 744 |
+
super().__init__()
|
| 745 |
+
self.hop_length = np.prod(up_ratios)
|
| 746 |
+
self.up_ratios = up_ratios
|
| 747 |
+
|
| 748 |
+
# Create first convolution
|
| 749 |
+
d_model = ngf
|
| 750 |
+
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
|
| 751 |
+
|
| 752 |
+
# Create EncoderBlocks that double channels as they downsample by `stride`
|
| 753 |
+
for stride in up_ratios:
|
| 754 |
+
d_model *= 2
|
| 755 |
+
self.block += [EncoderBlock(d_model, stride=stride)]
|
| 756 |
+
|
| 757 |
+
# Create last convolution
|
| 758 |
+
self.block += [
|
| 759 |
+
Activation1d(activation=SnakeBeta(d_model, alpha_logscale=True)),
|
| 760 |
+
WNConv1d(d_model, out_channels, kernel_size=3, padding=1),
|
| 761 |
+
]
|
| 762 |
+
|
| 763 |
+
# Wrap black into nn.Sequential
|
| 764 |
+
self.block = nn.Sequential(*self.block)
|
| 765 |
+
self.enc_dim = d_model
|
| 766 |
+
|
| 767 |
+
self.mel_transform = MelSpectrogram(
|
| 768 |
+
n_fft=1024,
|
| 769 |
+
num_mels=80,
|
| 770 |
+
sampling_rate=16000,
|
| 771 |
+
hop_size=200,
|
| 772 |
+
win_size=800,
|
| 773 |
+
fmin=0,
|
| 774 |
+
fmax=8000,
|
| 775 |
+
)
|
| 776 |
+
|
| 777 |
+
self.reset_parameters()
|
| 778 |
+
|
| 779 |
+
def forward(self, x):
|
| 780 |
+
out = self.block(x)
|
| 781 |
+
return out
|
| 782 |
+
|
| 783 |
+
def inference(self, x):
|
| 784 |
+
return self.block(x)
|
| 785 |
+
|
| 786 |
+
def get_prosody_feature(self, x):
|
| 787 |
+
return self.mel_transform(x.squeeze(1))[:, :20, :]
|
| 788 |
+
|
| 789 |
+
def remove_weight_norm(self):
|
| 790 |
+
"""Remove weight normalization module from all of the layers."""
|
| 791 |
+
|
| 792 |
+
def _remove_weight_norm(m):
|
| 793 |
+
try:
|
| 794 |
+
torch.nn.utils.remove_weight_norm(m)
|
| 795 |
+
except ValueError: # this module didn't have weight norm
|
| 796 |
+
return
|
| 797 |
+
|
| 798 |
+
self.apply(_remove_weight_norm)
|
| 799 |
+
|
| 800 |
+
def apply_weight_norm(self):
|
| 801 |
+
"""Apply weight normalization module from all of the layers."""
|
| 802 |
+
|
| 803 |
+
def _apply_weight_norm(m):
|
| 804 |
+
if isinstance(m, nn.Conv1d):
|
| 805 |
+
torch.nn.utils.weight_norm(m)
|
| 806 |
+
|
| 807 |
+
self.apply(_apply_weight_norm)
|
| 808 |
+
|
| 809 |
+
def reset_parameters(self):
|
| 810 |
+
self.apply(init_weights)
|
| 811 |
+
|
| 812 |
+
|
| 813 |
+
class FACodecDecoderV2(nn.Module):
|
| 814 |
+
def __init__(
|
| 815 |
+
self,
|
| 816 |
+
in_channels=256,
|
| 817 |
+
upsample_initial_channel=1536,
|
| 818 |
+
ngf=32,
|
| 819 |
+
up_ratios=(5, 5, 4, 2),
|
| 820 |
+
vq_num_q_c=2,
|
| 821 |
+
vq_num_q_p=1,
|
| 822 |
+
vq_num_q_r=3,
|
| 823 |
+
vq_dim=1024,
|
| 824 |
+
vq_commit_weight=0.005,
|
| 825 |
+
vq_weight_init=False,
|
| 826 |
+
vq_full_commit_loss=False,
|
| 827 |
+
codebook_dim=8,
|
| 828 |
+
codebook_size_prosody=10, # true codebook size is equal to 2^codebook_size
|
| 829 |
+
codebook_size_content=10,
|
| 830 |
+
codebook_size_residual=10,
|
| 831 |
+
quantizer_dropout=0.0,
|
| 832 |
+
dropout_type="linear",
|
| 833 |
+
use_gr_content_f0=False,
|
| 834 |
+
use_gr_prosody_phone=False,
|
| 835 |
+
use_gr_residual_f0=False,
|
| 836 |
+
use_gr_residual_phone=False,
|
| 837 |
+
use_gr_x_timbre=False,
|
| 838 |
+
use_random_mask_residual=True,
|
| 839 |
+
prob_random_mask_residual=0.75,
|
| 840 |
+
):
|
| 841 |
+
super().__init__()
|
| 842 |
+
self.hop_length = np.prod(up_ratios)
|
| 843 |
+
self.ngf = ngf
|
| 844 |
+
self.up_ratios = up_ratios
|
| 845 |
+
|
| 846 |
+
self.use_random_mask_residual = use_random_mask_residual
|
| 847 |
+
self.prob_random_mask_residual = prob_random_mask_residual
|
| 848 |
+
|
| 849 |
+
self.vq_num_q_p = vq_num_q_p
|
| 850 |
+
self.vq_num_q_c = vq_num_q_c
|
| 851 |
+
self.vq_num_q_r = vq_num_q_r
|
| 852 |
+
|
| 853 |
+
        self.codebook_size_prosody = codebook_size_prosody
        self.codebook_size_content = codebook_size_content
        self.codebook_size_residual = codebook_size_residual

        quantizer_class = ResidualVQ

        self.quantizer = nn.ModuleList()

        # prosody
        quantizer = quantizer_class(
            num_quantizers=vq_num_q_p,
            dim=vq_dim,
            codebook_size=codebook_size_prosody,
            codebook_dim=codebook_dim,
            threshold_ema_dead_code=2,
            commitment=vq_commit_weight,
            weight_init=vq_weight_init,
            full_commit_loss=vq_full_commit_loss,
            quantizer_dropout=quantizer_dropout,
            dropout_type=dropout_type,
        )
        self.quantizer.append(quantizer)

        # phone
        quantizer = quantizer_class(
            num_quantizers=vq_num_q_c,
            dim=vq_dim,
            codebook_size=codebook_size_content,
            codebook_dim=codebook_dim,
            threshold_ema_dead_code=2,
            commitment=vq_commit_weight,
            weight_init=vq_weight_init,
            full_commit_loss=vq_full_commit_loss,
            quantizer_dropout=quantizer_dropout,
            dropout_type=dropout_type,
        )
        self.quantizer.append(quantizer)

        # residual
        if self.vq_num_q_r > 0:
            quantizer = quantizer_class(
                num_quantizers=vq_num_q_r,
                dim=vq_dim,
                codebook_size=codebook_size_residual,
                codebook_dim=codebook_dim,
                threshold_ema_dead_code=2,
                commitment=vq_commit_weight,
                weight_init=vq_weight_init,
                full_commit_loss=vq_full_commit_loss,
                quantizer_dropout=quantizer_dropout,
                dropout_type=dropout_type,
            )
            self.quantizer.append(quantizer)

        # Add first conv layer
        channels = upsample_initial_channel
        layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)]

        # Add upsampling + MRF blocks
        for i, stride in enumerate(up_ratios):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers += [DecoderBlock(input_dim, output_dim, stride)]

        # Add final conv layer
        layers += [
            Activation1d(activation=SnakeBeta(output_dim, alpha_logscale=True)),
            WNConv1d(output_dim, 1, kernel_size=7, padding=3),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

        self.timbre_encoder = TransformerEncoder(
            enc_emb_tokens=None,
            encoder_layer=4,
            encoder_hidden=256,
            encoder_head=4,
            conv_filter_size=1024,
            conv_kernel_size=5,
            encoder_dropout=0.1,
            use_cln=False,
        )

        self.timbre_linear = nn.Linear(in_channels, in_channels * 2)
        self.timbre_linear.bias.data[:in_channels] = 1
        self.timbre_linear.bias.data[in_channels:] = 0
        self.timbre_norm = nn.LayerNorm(in_channels, elementwise_affine=False)

        self.f0_predictor = CNNLSTM(in_channels, 1, 2)
        self.phone_predictor = CNNLSTM(in_channels, 5003, 1)

        self.use_gr_content_f0 = use_gr_content_f0
        self.use_gr_prosody_phone = use_gr_prosody_phone
        self.use_gr_residual_f0 = use_gr_residual_f0
        self.use_gr_residual_phone = use_gr_residual_phone
        self.use_gr_x_timbre = use_gr_x_timbre

        if self.vq_num_q_r > 0 and self.use_gr_residual_f0:
            self.res_f0_predictor = nn.Sequential(
                GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
            )

        if self.vq_num_q_r > 0 and self.use_gr_residual_phone > 0:
            self.res_phone_predictor = nn.Sequential(
                GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
            )

        if self.use_gr_content_f0:
            self.content_f0_predictor = nn.Sequential(
                GradientReversal(alpha=1.0), CNNLSTM(in_channels, 1, 2)
            )

        if self.use_gr_prosody_phone:
            self.prosody_phone_predictor = nn.Sequential(
                GradientReversal(alpha=1.0), CNNLSTM(in_channels, 5003, 1)
            )

        if self.use_gr_x_timbre:
            self.x_timbre_predictor = nn.Sequential(
                GradientReversal(alpha=1),
                CNNLSTM(in_channels, 245200, 1, global_pred=True),
            )

        self.melspec_linear = nn.Linear(20, 256)
        self.melspec_encoder = TransformerEncoder(
            enc_emb_tokens=None,
            encoder_layer=4,
            encoder_hidden=256,
            encoder_head=4,
            conv_filter_size=1024,
            conv_kernel_size=5,
            encoder_dropout=0.1,
            use_cln=False,
            cfg=None,
        )

        self.reset_parameters()

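    # A sketch of the decoder channel schedule built above; the numbers are
    # hypothetical (upsample_initial_channel = 1536, up_ratios = (5, 5, 4, 2)),
    # chosen only to illustrate the halving rule channels // 2**i:
    #   block 0: 1536 -> 768, stride 5
    #   block 1:  768 -> 384, stride 5
    #   block 2:  384 -> 192, stride 4
    #   block 3:  192 ->  96, stride 2
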
    def quantize(self, x, prosody_feature, n_quantizers=None):
        outs, qs, commit_loss, quantized_buf = 0, [], [], []

        # prosody
        f0_input = prosody_feature.transpose(1, 2)  # (B, T, 20)
        f0_input = self.melspec_linear(f0_input)
        f0_input = self.melspec_encoder(f0_input, None, None)
        f0_input = f0_input.transpose(1, 2)
        f0_quantizer = self.quantizer[0]
        out, q, commit, quantized = f0_quantizer(f0_input, n_quantizers=n_quantizers)
        outs += out
        qs.append(q)
        quantized_buf.append(quantized.sum(0))
        commit_loss.append(commit)

        # phone
        phone_input = x
        phone_quantizer = self.quantizer[1]
        out, q, commit, quantized = phone_quantizer(
            phone_input, n_quantizers=n_quantizers
        )
        outs += out
        qs.append(q)
        quantized_buf.append(quantized.sum(0))
        commit_loss.append(commit)

        # residual
        if self.vq_num_q_r > 0:
            residual_quantizer = self.quantizer[2]
            residual_input = x - (quantized_buf[0] + quantized_buf[1]).detach()
            out, q, commit, quantized = residual_quantizer(
                residual_input, n_quantizers=n_quantizers
            )
            outs += out
            qs.append(q)
            quantized_buf.append(quantized.sum(0))  # [L, B, C, T] -> [B, C, T]
            commit_loss.append(commit)

        qs = torch.cat(qs, dim=0)
        commit_loss = torch.cat(commit_loss, dim=0)
        return outs, qs, commit_loss, quantized_buf

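    # Note on the residual branch above: the prosody and content
    # reconstructions are summed and detached before subtraction, so the
    # residual quantizer is trained only on what the first two branches fail
    # to explain, and no gradient flows back into those branches through it.
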
    def forward(
        self,
        x,
        prosody_feature,
        vq=True,
        get_vq=False,
        eval_vq=True,
        speaker_embedding=None,
        n_quantizers=None,
        quantized=None,
    ):
        if get_vq:
            return self.quantizer.get_emb()
        if vq is True:
            if eval_vq:
                self.quantizer.eval()
            x_timbre = x
            outs, qs, commit_loss, quantized_buf = self.quantize(
                x, prosody_feature, n_quantizers=n_quantizers
            )

            x_timbre = x_timbre.transpose(1, 2)
            x_timbre = self.timbre_encoder(x_timbre, None, None)
            x_timbre = x_timbre.transpose(1, 2)
            spk_embs = torch.mean(x_timbre, dim=2)
            return outs, qs, commit_loss, quantized_buf, spk_embs

        out = {}

        layer_0 = quantized[0]
        f0, uv = self.f0_predictor(layer_0)
        f0 = rearrange(f0, "... 1 -> ...")
        uv = rearrange(uv, "... 1 -> ...")

        layer_1 = quantized[1]
        (phone,) = self.phone_predictor(layer_1)

        out = {"f0": f0, "uv": uv, "phone": phone}

        if self.use_gr_prosody_phone:
            (prosody_phone,) = self.prosody_phone_predictor(layer_0)
            out["prosody_phone"] = prosody_phone

        if self.use_gr_content_f0:
            content_f0, content_uv = self.content_f0_predictor(layer_1)
            content_f0 = rearrange(content_f0, "... 1 -> ...")
            content_uv = rearrange(content_uv, "... 1 -> ...")
            out["content_f0"] = content_f0
            out["content_uv"] = content_uv

        if self.vq_num_q_r > 0:
            layer_2 = quantized[2]

            if self.use_gr_residual_f0:
                res_f0, res_uv = self.res_f0_predictor(layer_2)
                res_f0 = rearrange(res_f0, "... 1 -> ...")
                res_uv = rearrange(res_uv, "... 1 -> ...")
                out["res_f0"] = res_f0
                out["res_uv"] = res_uv

            if self.use_gr_residual_phone:
                (res_phone,) = self.res_phone_predictor(layer_2)
                out["res_phone"] = res_phone

        style = self.timbre_linear(speaker_embedding).unsqueeze(2)  # (B, 2d, 1)
        gamma, beta = style.chunk(2, 1)  # (B, d, 1)
        if self.vq_num_q_r > 0:
            if self.use_random_mask_residual:
                bsz = quantized[2].shape[0]
                res_mask = np.random.choice(
                    [0, 1],
                    size=bsz,
                    p=[
                        self.prob_random_mask_residual,
                        1 - self.prob_random_mask_residual,
                    ],
                )
                res_mask = (
                    torch.from_numpy(res_mask).unsqueeze(1).unsqueeze(1)
                )  # (B, 1, 1)
                res_mask = res_mask.to(
                    device=quantized[2].device, dtype=quantized[2].dtype
                )
                x = (
                    quantized[0].detach()
                    + quantized[1].detach()
                    + quantized[2] * res_mask
                )
                # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2] * res_mask
            else:
                x = quantized[0].detach() + quantized[1].detach() + quantized[2]
                # x = quantized_perturbe[0].detach() + quantized[1].detach() + quantized[2]
        else:
            x = quantized[0].detach() + quantized[1].detach()
            # x = quantized_perturbe[0].detach() + quantized[1].detach()

        if self.use_gr_x_timbre:
            (x_timbre,) = self.x_timbre_predictor(x)
            out["x_timbre"] = x_timbre

        x = x.transpose(1, 2)
        x = self.timbre_norm(x)
        x = x.transpose(1, 2)
        x = x * gamma + beta

        x = self.model(x)
        out["audio"] = x

        return out

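    # Note on the timbre conditioning above: timbre_linear maps the speaker
    # embedding to a per-channel scale (gamma) and shift (beta), applied after
    # a non-affine LayerNorm (a FiLM-style modulation). Because
    # timbre_linear.bias is initialized to ones for gamma and zeros for beta
    # in __init__, the modulation starts out close to the identity.
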
    def vq2emb(self, vq, use_residual=True):
        # vq: [num_quantizer, B, T]
        self.quantizer = self.quantizer.eval()
        out = 0
        out += self.quantizer[0].vq2emb(vq[0 : self.vq_num_q_p])
        out += self.quantizer[1].vq2emb(
            vq[self.vq_num_q_p : self.vq_num_q_p + self.vq_num_q_c]
        )
        if self.vq_num_q_r > 0 and use_residual:
            out += self.quantizer[2].vq2emb(vq[self.vq_num_q_p + self.vq_num_q_c :])
        return out

    def inference(self, x, speaker_embedding):
        style = self.timbre_linear(speaker_embedding).unsqueeze(2)  # (B, 2d, 1)
        gamma, beta = style.chunk(2, 1)  # (B, d, 1)
        x = x.transpose(1, 2)
        x = self.timbre_norm(x)
        x = x.transpose(1, 2)
        x = x * gamma + beta
        x = self.model(x)
        return x

    def remove_weight_norm(self):
        """Remove weight normalization from all of the layers."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to all of the layers."""

        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
                torch.nn.utils.weight_norm(m)

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        self.apply(init_weights)
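`vq2emb` above assumes the code indices of all three branches are stacked along the first axis in prosody → content → residual order. A minimal sketch of that slicing, with hypothetical quantizer counts (1 prosody, 2 content, 3 residual):

```python
# Hypothetical counts, only to illustrate the index layout vq2emb expects.
vq_num_q_p, vq_num_q_c, vq_num_q_r = 1, 2, 3
num_q = vq_num_q_p + vq_num_q_c + vq_num_q_r   # vq has shape [num_q, B, T]

prosody = slice(0, vq_num_q_p)                          # vq[0:1]
content = slice(vq_num_q_p, vq_num_q_p + vq_num_q_c)    # vq[1:3]
residual = slice(vq_num_q_p + vq_num_q_c, num_q)        # vq[3:6]
print(prosody, content, residual)
```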
Amphion/models/codec/ns3_codec/gradient_reversal.py
ADDED
@@ -0,0 +1,35 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torch.autograd import Function
import torch
from torch import nn


class GradientReversal(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.save_for_backward(x, alpha)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        _, alpha = ctx.saved_tensors
        if ctx.needs_input_grad[0]:
            grad_input = -alpha * grad_output
        return grad_input, None


revgrad = GradientReversal.apply


class GradientReversal(nn.Module):
    def __init__(self, alpha):
        super().__init__()
        self.alpha = torch.tensor(alpha, requires_grad=False)

    def forward(self, x):
        return revgrad(x, self.alpha)
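A minimal usage sketch of the `nn.Module` version of `GradientReversal` defined above (assuming it is in scope): the forward pass is the identity, while the backward pass multiplies incoming gradients by `-alpha`, so an adversarial predictor placed behind it pushes the upstream encoder away from its objective.

```python
import torch

# Forward is the identity; backward flips the gradient sign (times alpha).
grl = GradientReversal(alpha=1.0)

x = torch.ones(3, requires_grad=True)
y = grl(x)          # same values as x in the forward direction
y.sum().backward()
print(x.grad)       # tensor([-1., -1., -1.]) — sign flipped by the GRL
```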
Amphion/models/codec/ns3_codec/melspec.py
ADDED
@@ -0,0 +1,102 @@
import torch
import pyworld as pw
import numpy as np
import soundfile as sf
import os
from torchaudio.functional import pitch_shift
import librosa
from librosa.filters import mel as librosa_mel_fn
import torch.nn as nn
import torch.nn.functional as F


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


class MelSpectrogram(nn.Module):
    def __init__(
        self,
        n_fft,
        num_mels,
        sampling_rate,
        hop_size,
        win_size,
        fmin,
        fmax,
        center=False,
    ):
        super(MelSpectrogram, self).__init__()
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.win_size = win_size
        self.sampling_rate = sampling_rate
        self.num_mels = num_mels
        self.fmin = fmin
        self.fmax = fmax
        self.center = center

        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis = torch.from_numpy(mel).float()
        hann_window = torch.hann_window(win_size)

        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("hann_window", hann_window)

    def forward(self, y):
        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (
                int((self.n_fft - self.hop_size) / 2),
                int((self.n_fft - self.hop_size) / 2),
            ),
            mode="reflect",
        )
        y = y.squeeze(1)
        spec = torch.stft(
            y,
            self.n_fft,
            hop_length=self.hop_size,
            win_length=self.win_size,
            window=self.hann_window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)

        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        spec = torch.matmul(self.mel_basis, spec)
        spec = spectral_normalize_torch(spec)

        return spec
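A minimal usage sketch for the `MelSpectrogram` module above. The parameter values here (1024/80/16000/256/1024, fmin=0, fmax=8000) are hypothetical, chosen only so the shapes work out; they are not taken from any Amphion config.

```python
import torch

mel_fn = MelSpectrogram(
    n_fft=1024,
    num_mels=80,
    sampling_rate=16000,
    hop_size=256,
    win_size=1024,
    fmin=0,
    fmax=8000,
)

wav = torch.randn(2, 16000)   # (B, T): one second of audio per item
mel = mel_fn(wav)
print(mel.shape)              # (2, 80, T_frames): log-compressed mel magnitudes
```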
Amphion/models/codec/ns3_codec/quantize/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (246 Bytes).
Amphion/models/codec/ns3_codec/quantize/fvq.py
ADDED
@@ -0,0 +1,116 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


class FactorizedVectorQuantize(nn.Module):
    def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment

        if dim != self.codebook_dim:
            self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim))
            self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim))
        else:
            self.in_proj = nn.Identity()
            self.out_proj = nn.Identity()
        self._codebook = nn.Embedding(codebook_size, self.codebook_dim)

    @property
    def codebook(self):
        return self._codebook

    def forward(self, z):
        """Quantizes the input tensor using a fixed codebook and returns
        the corresponding codebook vectors.

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B]
            Combined commitment loss (to train the encoder to predict vectors
            closer to codebook entries) and codebook loss (to update the
            codebook); zero outside of training
        """
        # transpose since we use linear
        z = rearrange(z, "b d t -> b t d")

        # Factorized codes project input into low-dimensional space
        z_e = self.in_proj(z)  # z_e : (B x T x D)
        z_e = rearrange(z_e, "b t d -> b d t")
        z_q, indices = self.decode_latents(z_e)

        if self.training:
            commitment_loss = (
                F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
                * self.commitment
            )
            codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
            commit_loss = commitment_loss + codebook_loss
        else:
            commit_loss = torch.zeros(z.shape[0], device=z.device)

        z_q = (
            z_e + (z_q - z_e).detach()
        )  # noop in forward pass, straight-through gradient estimator in backward pass

        z_q = rearrange(z_q, "b d t -> b t d")
        z_q = self.out_proj(z_q)
        z_q = rearrange(z_q, "b t d -> b d t")

        return z_q, indices, commit_loss

    def vq2emb(self, vq, proj=True):
        emb = self.embed_code(vq)
        if proj:
            emb = self.out_proj(emb)
        return emb.transpose(1, 2)

    def get_emb(self):
        return self.codebook.weight

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight  # codebook: (N x D)
        # L2 normalize encodings and codebook
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)
        return z_q, indices
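The straight-through line in `forward()` deserves a worked example. A minimal sketch with hypothetical values: `z_e + (z_q - z_e).detach()` returns the quantized values in the forward pass, but routes gradients to the continuous encoder output `z_e`, bypassing the non-differentiable nearest-neighbor lookup.

```python
import torch

z_e = torch.tensor([0.2, 0.9], requires_grad=True)
z_q = torch.tensor([0.0, 1.0])   # nearest codebook entries (hypothetical)

z_st = z_e + (z_q - z_e).detach()
print(z_st)                      # values [0., 1.] — the quantized outputs
z_st.sum().backward()
print(z_e.grad)                  # tensor([1., 1.]) — gradient bypasses the rounding
```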
Amphion/models/svc/transformer/transformer_inference.py
ADDED
@@ -0,0 +1,45 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
import numpy as np
import torch
from tqdm import tqdm
import torch.nn as nn
from collections import OrderedDict

from models.svc.base import SVCInference
from modules.encoder.condition_encoder import ConditionEncoder
from models.svc.transformer.transformer import Transformer
from models.svc.transformer.conformer import Conformer


class TransformerInference(SVCInference):
    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        SVCInference.__init__(self, args, cfg, infer_type)

    def _build_model(self):
        self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
        self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
        self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
        if self.cfg.model.transformer.type == "transformer":
            self.acoustic_mapper = Transformer(self.cfg.model.transformer)
        elif self.cfg.model.transformer.type == "conformer":
            self.acoustic_mapper = Conformer(self.cfg.model.transformer)
        else:
            raise NotImplementedError
        model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
        return model

    def _inference_each_batch(self, batch_data):
        device = self.accelerator.device
        for k, v in batch_data.items():
            batch_data[k] = v.to(device)

        condition = self.condition_encoder(batch_data)
        y_pred = self.acoustic_mapper(condition, batch_data["mask"])

        return y_pred
Amphion/models/svc/vits/vits_trainer.py
ADDED
@@ -0,0 +1,704 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import json
import json5
import shutil
import accelerate
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ExponentialLR
from tqdm import tqdm
from pathlib import Path

# from models.svc.base import SVCTrainer
from models.svc.base.svc_dataset import SVCOfflineCollator, SVCOfflineDataset
from models.svc.vits.vits import *
from models.svc.base import SVCTrainer

from utils.mel import mel_spectrogram_torch

from models.vocoders.gan.discriminator.mpd import (
    MultiPeriodDiscriminator_vits as MultiPeriodDiscriminator,
)


class VitsSVCTrainer(SVCTrainer):
    def __init__(self, args, cfg):
        self.args = args
        self.cfg = cfg
        SVCTrainer.__init__(self, args, cfg)

    def _accelerator_prepare(self):
        (
            self.train_dataloader,
            self.valid_dataloader,
        ) = self.accelerator.prepare(
            self.train_dataloader,
            self.valid_dataloader,
        )
        if isinstance(self.model, dict):
            for key in self.model.keys():
                self.model[key] = self.accelerator.prepare(self.model[key])
        else:
            self.model = self.accelerator.prepare(self.model)

        if isinstance(self.optimizer, dict):
            for key in self.optimizer.keys():
                self.optimizer[key] = self.accelerator.prepare(self.optimizer[key])
        else:
            self.optimizer = self.accelerator.prepare(self.optimizer)

        if isinstance(self.scheduler, dict):
            for key in self.scheduler.keys():
                self.scheduler[key] = self.accelerator.prepare(self.scheduler[key])
        else:
            self.scheduler = self.accelerator.prepare(self.scheduler)

    def _load_model(
        self,
        checkpoint_dir: str = None,
        checkpoint_path: str = None,
        resume_type: str = "",
    ):
        r"""Load model from checkpoint. If checkpoint_path is None, it will
        load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
        None, it will load the checkpoint specified by checkpoint_path. **Only use this
        method after** ``accelerator.prepare()``.
        """
        if checkpoint_path is None:
            ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
            ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
            checkpoint_path = ls[0]
        self.logger.info("Resume from {}...".format(checkpoint_path))

        if resume_type in ["resume", ""]:
            # Load all the things, including model weights, optimizer, scheduler, and random states.
            self.accelerator.load_state(input_dir=checkpoint_path)

            # set epoch and step
            self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
            self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1

        elif resume_type == "finetune":
            # Load only the model weights
            accelerate.load_checkpoint_and_dispatch(
                self.accelerator.unwrap_model(self.model["generator"]),
                os.path.join(checkpoint_path, "pytorch_model.bin"),
            )
            accelerate.load_checkpoint_and_dispatch(
                self.accelerator.unwrap_model(self.model["discriminator"]),
                os.path.join(checkpoint_path, "pytorch_model.bin"),
            )
            self.logger.info("Load model weights for finetune...")

        else:
            raise ValueError("Resume_type must be `resume` or `finetune`.")

        return checkpoint_path

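    # A worked example of the checkpoint-name convention _load_model parses
    # (the name below is hypothetical): for
    #   "epoch-0012_step-0034567_loss-1.234567"
    #   name.split("_")[-3].split("-")[-1]  ->  "0012"     (epoch 12)
    #   name.split("_")[-2].split("-")[-1]  ->  "0034567"  (step 34567)
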
    def _build_model(self):
        net_g = SynthesizerTrn(
            self.cfg.preprocess.n_fft // 2 + 1,
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
            # directly use cfg
            self.cfg,
        )
        net_d = MultiPeriodDiscriminator(self.cfg.model.vits.use_spectral_norm)
        model = {"generator": net_g, "discriminator": net_d}

        return model

    def _build_dataset(self):
        return SVCOfflineDataset, SVCOfflineCollator

    def _build_optimizer(self):
        optimizer_g = torch.optim.AdamW(
            self.model["generator"].parameters(),
            self.cfg.train.learning_rate,
            betas=self.cfg.train.AdamW.betas,
            eps=self.cfg.train.AdamW.eps,
        )
        optimizer_d = torch.optim.AdamW(
            self.model["discriminator"].parameters(),
            self.cfg.train.learning_rate,
            betas=self.cfg.train.AdamW.betas,
            eps=self.cfg.train.AdamW.eps,
        )
        optimizer = {"optimizer_g": optimizer_g, "optimizer_d": optimizer_d}

        return optimizer

    def _build_scheduler(self):
        scheduler_g = ExponentialLR(
            self.optimizer["optimizer_g"],
            gamma=self.cfg.train.lr_decay,
            last_epoch=self.epoch - 1,
        )
        scheduler_d = ExponentialLR(
            self.optimizer["optimizer_d"],
            gamma=self.cfg.train.lr_decay,
            last_epoch=self.epoch - 1,
        )

        scheduler = {"scheduler_g": scheduler_g, "scheduler_d": scheduler_d}
        return scheduler

    def _build_criterion(self):
        class GeneratorLoss(nn.Module):
            def __init__(self, cfg):
                super(GeneratorLoss, self).__init__()
                self.cfg = cfg
                self.l1_loss = nn.L1Loss()

            def generator_loss(self, disc_outputs):
                loss = 0
                gen_losses = []
                for dg in disc_outputs:
                    dg = dg.float()
                    l = torch.mean((1 - dg) ** 2)
                    gen_losses.append(l)
                    loss += l

                return loss, gen_losses

            def feature_loss(self, fmap_r, fmap_g):
                loss = 0
                for dr, dg in zip(fmap_r, fmap_g):
                    for rl, gl in zip(dr, dg):
                        rl = rl.float().detach()
                        gl = gl.float()
                        loss += torch.mean(torch.abs(rl - gl))

                return loss * 2

            def kl_loss(self, z_p, logs_q, m_p, logs_p, z_mask):
                """
                z_p, logs_q: [b, h, t_t]
                m_p, logs_p: [b, h, t_t]
                """
                z_p = z_p.float()
                logs_q = logs_q.float()
                m_p = m_p.float()
                logs_p = logs_p.float()
                z_mask = z_mask.float()

                kl = logs_p - logs_q - 0.5
                kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
                kl = torch.sum(kl * z_mask)
                l = kl / torch.sum(z_mask)
                return l

            def forward(
                self,
                outputs_g,
                outputs_d,
                y_mel,
                y_hat_mel,
            ):
                loss_g = {}

                # mel loss
                loss_mel = self.l1_loss(y_mel, y_hat_mel) * self.cfg.train.c_mel
                loss_g["loss_mel"] = loss_mel

                # kl loss
                loss_kl = (
                    self.kl_loss(
                        outputs_g["z_p"],
                        outputs_g["logs_q"],
                        outputs_g["m_p"],
                        outputs_g["logs_p"],
                        outputs_g["z_mask"],
                    )
                    * self.cfg.train.c_kl
                )
                loss_g["loss_kl"] = loss_kl

                # feature loss
                loss_fm = self.feature_loss(outputs_d["fmap_rs"], outputs_d["fmap_gs"])
                loss_g["loss_fm"] = loss_fm

                # gan loss
                loss_gen, losses_gen = self.generator_loss(outputs_d["y_d_hat_g"])
                loss_g["loss_gen"] = loss_gen
                loss_g["loss_gen_all"] = loss_mel + loss_kl + loss_fm + loss_gen

                return loss_g

        class DiscriminatorLoss(nn.Module):
            def __init__(self, cfg):
                super(DiscriminatorLoss, self).__init__()
                self.cfg = cfg
                self.l1Loss = torch.nn.L1Loss(reduction="mean")

            def __call__(self, disc_real_outputs, disc_generated_outputs):
                loss_d = {}

                loss = 0
                r_losses = []
                g_losses = []
                for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
                    dr = dr.float()
                    dg = dg.float()
                    r_loss = torch.mean((1 - dr) ** 2)
                    g_loss = torch.mean(dg**2)
                    loss += r_loss + g_loss
                    r_losses.append(r_loss.item())
                    g_losses.append(g_loss.item())

                loss_d["loss_disc_all"] = loss

                return loss_d

        criterion = {
            "generator": GeneratorLoss(self.cfg),
            "discriminator": DiscriminatorLoss(self.cfg),
        }
        return criterion

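    # A minimal numeric sketch of the least-squares GAN objectives above
    # (the scores are hypothetical): with d_real = [0.9, 0.8] and
    # d_fake = [0.1, 0.3],
    #   discriminator loss = mean((1 - d_real)**2) + mean(d_fake**2)
    #                      = 0.025 + 0.05 = 0.075
    #   generator loss     = mean((1 - d_fake)**2)
    #                      = mean([0.81, 0.49]) = 0.65
    # so the discriminator pushes real scores to 1 and fake scores to 0,
    # while the generator pushes fake scores toward 1.
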
    # Keep legacy unchanged
    def write_summary(
        self,
        losses,
        stats,
        images={},
        audios={},
        audio_sampling_rate=24000,
        tag="train",
    ):
        for key, value in losses.items():
            self.sw.add_scalar(tag + "/" + key, value, self.step)
        self.sw.add_scalar(
            "learning_rate",
            self.optimizer["optimizer_g"].param_groups[0]["lr"],
            self.step,
        )

        if len(images) != 0:
            for key, value in images.items():
                self.sw.add_image(key, value, self.global_step, batchformats="HWC")
        if len(audios) != 0:
            for key, value in audios.items():
                self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)

    def write_valid_summary(
        self, losses, stats, images={}, audios={}, audio_sampling_rate=24000, tag="val"
    ):
        for key, value in losses.items():
            self.sw.add_scalar(tag + "/" + key, value, self.step)

        if len(images) != 0:
            for key, value in images.items():
                self.sw.add_image(key, value, self.global_step, batchformats="HWC")
        if len(audios) != 0:
            for key, value in audios.items():
                self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)

    def _get_state_dict(self):
        state_dict = {
            "generator": self.model["generator"].state_dict(),
            "discriminator": self.model["discriminator"].state_dict(),
            "optimizer_g": self.optimizer["optimizer_g"].state_dict(),
            "optimizer_d": self.optimizer["optimizer_d"].state_dict(),
            "scheduler_g": self.scheduler["scheduler_g"].state_dict(),
            "scheduler_d": self.scheduler["scheduler_d"].state_dict(),
            "step": self.step,
            "epoch": self.epoch,
            "batch_size": self.cfg.train.batch_size,
        }
        return state_dict

    def get_state_dict(self):
        return self._get_state_dict()

    def load_model(self, checkpoint):
        self.step = checkpoint["step"]
        self.epoch = checkpoint["epoch"]
        self.model["generator"].load_state_dict(checkpoint["generator"])
        self.model["discriminator"].load_state_dict(checkpoint["discriminator"])
        self.optimizer["optimizer_g"].load_state_dict(checkpoint["optimizer_g"])
        self.optimizer["optimizer_d"].load_state_dict(checkpoint["optimizer_d"])
        self.scheduler["scheduler_g"].load_state_dict(checkpoint["scheduler_g"])
        self.scheduler["scheduler_d"].load_state_dict(checkpoint["scheduler_d"])

    @torch.inference_mode()
    def _valid_step(self, batch):
        r"""Validation forward step. Should return the average loss of a sample
        over one batch. Invoking ``_forward_step`` is recommended except for
        special cases. See ``_valid_epoch`` for usage.
        """

        valid_losses = {}
        total_loss = 0
        valid_stats = {}

        # Generator output
        outputs_g = self.model["generator"](batch)

        y_mel = slice_segments(
            batch["mel"].transpose(1, 2),
            outputs_g["ids_slice"],
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
        )
        y_hat_mel = mel_spectrogram_torch(
            outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
        )
        y = slice_segments(
            batch["audio"].unsqueeze(1),
            outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
            self.cfg.preprocess.segment_size,
        )

        # Discriminator output
        outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
        ## Discriminator loss
        loss_d = self.criterion["discriminator"](
            outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
        )
        valid_losses.update(loss_d)

        ## Generator
        outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
        loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
        valid_losses.update(loss_g)

        for item in valid_losses:
            valid_losses[item] = valid_losses[item].item()

        total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]

        return (
            total_loss.item(),
            valid_losses,
            valid_stats,
        )

    @torch.inference_mode()
    def _valid_epoch(self):
        r"""Validation epoch. Should return the average loss of a batch (sample)
        over one epoch. See ``train_loop`` for usage.
        """
        if isinstance(self.model, dict):
            for key in self.model.keys():
                self.model[key].eval()
        else:
            self.model.eval()

        epoch_sum_loss = 0.0
        epoch_losses = dict()
        for batch in tqdm(
            self.valid_dataloader,
            desc=f"Validating Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            total_loss, valid_losses, valid_stats = self._valid_step(batch)
            epoch_sum_loss += total_loss
            if isinstance(valid_losses, dict):
                for key, value in valid_losses.items():
                    if key not in epoch_losses.keys():
                        epoch_losses[key] = value
                    else:
                        epoch_losses[key] += value

        epoch_sum_loss = epoch_sum_loss / len(self.valid_dataloader)
        for key in epoch_losses.keys():
            epoch_losses[key] = epoch_losses[key] / len(self.valid_dataloader)

        self.accelerator.wait_for_everyone()

        return epoch_sum_loss, epoch_losses

    ### THIS IS MAIN ENTRY ###
    def train_loop(self):
        r"""Training loop. The public entry of training process."""
        # Wait everyone to prepare before we move on
        self.accelerator.wait_for_everyone()
        # dump config file
        if self.accelerator.is_main_process:
            self.__dump_cfg(self.config_save_path)

        # self.optimizer.zero_grad()
        # Wait to ensure good to go

        self.accelerator.wait_for_everyone()
        while self.epoch < self.max_epoch:
            self.logger.info("\n")
            self.logger.info("-" * 32)
            self.logger.info("Epoch {}: ".format(self.epoch))

            # Do training & validating epoch
            train_total_loss, train_losses = self._train_epoch()
            if isinstance(train_losses, dict):
                for key, loss in train_losses.items():
                    self.logger.info("  |- Train/{} Loss: {:.6f}".format(key, loss))
                    self.accelerator.log(
                        {"Epoch/Train {} Loss".format(key): loss},
                        step=self.epoch,
                    )

            valid_total_loss, valid_losses = self._valid_epoch()
            if isinstance(valid_losses, dict):
                for key, loss in valid_losses.items():
                    self.logger.info("  |- Valid/{} Loss: {:.6f}".format(key, loss))
                    self.accelerator.log(
                        {"Epoch/Valid {} Loss".format(key): loss},
                        step=self.epoch,
                    )

            self.logger.info("  |- Train/Loss: {:.6f}".format(train_total_loss))
            self.logger.info("  |- Valid/Loss: {:.6f}".format(valid_total_loss))
            self.accelerator.log(
                {
                    "Epoch/Train Loss": train_total_loss,
                    "Epoch/Valid Loss": valid_total_loss,
                },
                step=self.epoch,
            )

            self.accelerator.wait_for_everyone()

            # Check if hit save_checkpoint_stride and run_eval
            run_eval = False
            if self.accelerator.is_main_process:
                save_checkpoint = False
                hit_dix = []
                for i, num in enumerate(self.save_checkpoint_stride):
                    if self.epoch % num == 0:
                        save_checkpoint = True
                        hit_dix.append(i)
                        run_eval |= self.run_eval[i]

            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process and save_checkpoint:
                path = os.path.join(
                    self.checkpoint_dir,
                    "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                        self.epoch, self.step, train_total_loss
                    ),
                )
                self.tmp_checkpoint_save_path = path
                self.accelerator.save_state(path)

                json.dump(
                    self.checkpoints_path,
                    open(os.path.join(path, "ckpts.json"), "w"),
                    ensure_ascii=False,
                    indent=4,
                )
                self._save_auxiliary_states()

                # Remove old checkpoints
                to_remove = []
                for idx in hit_dix:
                    self.checkpoints_path[idx].append(path)
                    while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
                        to_remove.append((idx, self.checkpoints_path[idx].pop(0)))

                # Search conflicts
                total = set()
                for i in self.checkpoints_path:
                    total |= set(i)
                do_remove = set()
                for idx, path in to_remove[::-1]:
                    if path in total:
                        self.checkpoints_path[idx].insert(0, path)
                    else:
                        do_remove.add(path)

                # Remove old checkpoints
                for path in do_remove:
                    shutil.rmtree(path, ignore_errors=True)
                    self.logger.debug(f"Remove old checkpoint: {path}")

            self.accelerator.wait_for_everyone()
            if run_eval:
                # TODO: run evaluation
                pass

            # Update info for each epoch
            self.epoch += 1

        # Finish training and save final checkpoint
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            path = os.path.join(
                self.checkpoint_dir,
                "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                    self.epoch, self.step, valid_total_loss
                ),
            )
            self.tmp_checkpoint_save_path = path
            self.accelerator.save_state(path)

            json.dump(
                self.checkpoints_path,
                open(os.path.join(path, "ckpts.json"), "w"),
                ensure_ascii=False,
                indent=4,
            )
            self._save_auxiliary_states()

        self.accelerator.end_training()

    def _train_step(self, batch):
        r"""Training forward step. Performs one discriminator update and one
        generator update on the batch. See ``_train_epoch`` for usage.
        """

        train_losses = {}
        total_loss = 0
        training_stats = {}

        ## Train Discriminator
        # Generator output
        outputs_g = self.model["generator"](batch)

        y_mel = slice_segments(
            batch["mel"].transpose(1, 2),
            outputs_g["ids_slice"],
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
        )
        y_hat_mel = mel_spectrogram_torch(
            outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
        )

        y = slice_segments(
            # [1, 168418] -> [1, 1, 168418]
            batch["audio"].unsqueeze(1),
            outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
            self.cfg.preprocess.segment_size,
        )

        # Discriminator output
        outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
        # Discriminator loss
        loss_d = self.criterion["discriminator"](
            outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
        )
        train_losses.update(loss_d)

        # BP and Grad Updated
        self.optimizer["optimizer_d"].zero_grad()
        self.accelerator.backward(loss_d["loss_disc_all"])
        self.optimizer["optimizer_d"].step()

        ## Train Generator
        outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
        loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
        train_losses.update(loss_g)

        # BP and Grad Updated
        self.optimizer["optimizer_g"].zero_grad()
        self.accelerator.backward(loss_g["loss_gen_all"])
        self.optimizer["optimizer_g"].step()

        for item in train_losses:
            train_losses[item] = train_losses[item].item()

        total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]

        return (
            total_loss.item(),
            train_losses,
            training_stats,
        )

    def _train_epoch(self):
        r"""Training epoch. Should return the average loss of a batch (sample)
        over one epoch. See ``train_loop`` for usage.
        """
        epoch_sum_loss: float = 0.0
        epoch_losses: dict = {}
        epoch_step: int = 0
        for batch in tqdm(
            self.train_dataloader,
            desc=f"Training Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            # Do training step and BP
            with self.accelerator.accumulate(self.model):
                total_loss, train_losses, training_stats = self._train_step(batch)
            self.batch_count += 1

            # Update info for each step
            if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
                epoch_sum_loss += total_loss
                for key, value in train_losses.items():
                    if key not in epoch_losses.keys():
                        epoch_losses[key] = value
                    else:
                        epoch_losses[key] += value

                self.accelerator.log(
                    {
                        "Step/Generator Loss": train_losses["loss_gen_all"],
                        "Step/Discriminator Loss": train_losses["loss_disc_all"],
                        "Step/Generator Learning Rate": self.optimizer[
                            "optimizer_g"
                        ].param_groups[0]["lr"],
                        "Step/Discriminator Learning Rate": self.optimizer[
                            "optimizer_d"
                        ].param_groups[0]["lr"],
                    },
                    step=self.step,
                )
                self.step += 1
                epoch_step += 1

        self.accelerator.wait_for_everyone()

        epoch_sum_loss = (
            epoch_sum_loss
            / len(self.train_dataloader)
            * self.cfg.train.gradient_accumulation_step
        )

        for key in epoch_losses.keys():
            epoch_losses[key] = (
                epoch_losses[key]
                / len(self.train_dataloader)
                * self.cfg.train.gradient_accumulation_step
            )

        return epoch_sum_loss, epoch_losses

    def __dump_cfg(self, path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        json5.dump(
            self.cfg,
            open(path, "w"),
            indent=4,
            sort_keys=True,
            ensure_ascii=False,
            quote_keys=True,
        )
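A minimal sketch of the discriminator-then-generator alternation that `_train_step` above implements, using stand-in tensors instead of real networks; every name below is illustrative only, not part of the trainer.

```python
import torch

fake = torch.randn(4, requires_grad=True)       # pretend generator output
d_weight = torch.randn(1, requires_grad=True)   # pretend discriminator

opt_d = torch.optim.SGD([d_weight], lr=0.01)
opt_g = torch.optim.SGD([fake], lr=0.01)

# 1) Discriminator step: the fake sample is detached, so only d_weight moves.
loss_d = ((d_weight * fake.detach()) ** 2).mean()
opt_d.zero_grad()
loss_d.backward()
opt_d.step()

# 2) Generator step: scored again *without* detach, so gradients reach `fake`.
loss_g = ((1 - d_weight * fake) ** 2).mean()
opt_g.zero_grad()
loss_g.backward()
opt_g.step()
```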
Amphion/models/tta/autoencoder/autoencoder_dataset.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import random
import torch
from torch.nn.utils.rnn import pad_sequence
from utils.data_utils import *  # also provides os and np used below
from models.base.base_dataset import (
    BaseOfflineCollator,
    BaseOfflineDataset,
    BaseTestDataset,
    BaseTestCollator,
)
import librosa


class AutoencoderKLDataset(BaseOfflineDataset):
    def __init__(self, cfg, dataset, is_valid=False):
        BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)

        cfg = self.cfg

        # utt2melspec
        if cfg.preprocess.use_melspec:
            self.utt2melspec_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2melspec_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.melspec_dir,
                    uid + ".npy",
                )

        # utt2wav
        if cfg.preprocess.use_wav:
            self.utt2wav_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2wav_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.wav_dir,
                    uid + ".wav",
                )

    def __getitem__(self, index):
        # melspec: (n_mels, T)
        # wav: (T,)

        single_feature = BaseOfflineDataset.__getitem__(self, index)

        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        if self.cfg.preprocess.use_melspec:
            single_feature["melspec"] = np.load(self.utt2melspec_path[utt])

        if self.cfg.preprocess.use_wav:
            wav, sr = librosa.load(
                self.utt2wav_path[utt], sr=16000
            )  # hard-coded to 16 kHz
            single_feature["wav"] = wav

        return single_feature

    def __len__(self):
        # The original file defined this method twice verbatim; the duplicate
        # definition has been removed here.
        return len(self.metadata)


class AutoencoderKLCollator(BaseOfflineCollator):
    def __init__(self, cfg):
        BaseOfflineCollator.__init__(self, cfg)

    def __call__(self, batch):
        # mel: (B, n_mels, T)
        # wav (optional): (B, T)

        packed_batch_features = dict()

        for key in batch[0].keys():
            if key == "melspec":
                packed_batch_features["melspec"] = torch.from_numpy(
                    np.array([b["melspec"][:, :624] for b in batch])
                )

            if key == "wav":
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )

        return packed_batch_features


class AutoencoderKLTestDataset(BaseTestDataset): ...


class AutoencoderKLTestCollator(BaseTestCollator): ...
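For reference, a minimal sketch (not part of the repository) of what this collator produces for a toy batch; the 80-bin mels and the assumption that every clip has at least 624 frames are illustrative, not confirmed by the source:

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

batch = [
    {"melspec": np.random.randn(80, 700).astype(np.float32),
     "wav": np.random.randn(16000).astype(np.float32)},
    {"melspec": np.random.randn(80, 650).astype(np.float32),
     "wav": np.random.randn(12000).astype(np.float32)},
]

# Mels are truncated to a fixed 624 frames so they can be stacked...
melspec = torch.from_numpy(np.array([b["melspec"][:, :624] for b in batch]))
# ...while waveforms of different lengths are zero-padded to the longest one.
wav = pad_sequence([torch.from_numpy(b["wav"]) for b in batch],
                   batch_first=True, padding_value=0)

print(melspec.shape)  # torch.Size([2, 80, 624])
print(wav.shape)      # torch.Size([2, 16000])

Note that a clip shorter than 624 frames would leave ragged arrays that np.array cannot stack, so the preprocessing presumably guarantees a minimum length upstream.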
Amphion/models/tta/ldm/__init__.py
ADDED (file without changes)

Amphion/models/tta/ldm/audioldm_dataset.py
ADDED
@@ -0,0 +1,151 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import random
import torch
from torch.nn.utils.rnn import pad_sequence
from utils.data_utils import *  # also provides os and np used below

from models.base.base_dataset import (
    BaseOfflineCollator,
    BaseOfflineDataset,
    BaseTestDataset,
    BaseTestCollator,
)
import librosa

from transformers import AutoTokenizer


class AudioLDMDataset(BaseOfflineDataset):
    def __init__(self, cfg, dataset, is_valid=False):
        BaseOfflineDataset.__init__(self, cfg, dataset, is_valid=is_valid)

        self.cfg = cfg

        # utt2melspec
        if cfg.preprocess.use_melspec:
            self.utt2melspec_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2melspec_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.melspec_dir,
                    uid + ".npy",
                )

        # utt2wav
        if cfg.preprocess.use_wav:
            self.utt2wav_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2wav_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.wav_dir,
                    uid + ".wav",
                )

        # utt2caption
        if cfg.preprocess.use_caption:
            self.utt2caption = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2caption[utt] = utt_info["Caption"]

    def __getitem__(self, index):
        # melspec: (n_mels, T)
        # wav: (T,)

        single_feature = BaseOfflineDataset.__getitem__(self, index)

        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        if self.cfg.preprocess.use_melspec:
            single_feature["melspec"] = np.load(self.utt2melspec_path[utt])

        if self.cfg.preprocess.use_wav:
            wav, sr = librosa.load(
                self.utt2wav_path[utt], sr=16000
            )  # hard-coded to 16 kHz
            single_feature["wav"] = wav

        if self.cfg.preprocess.use_caption:
            # Drop the caption with probability cond_mask_prob, e.g. (0.1, 0.9)
            cond_mask = np.random.choice(
                [1, 0],
                p=[
                    self.cfg.preprocess.cond_mask_prob,
                    1 - self.cfg.preprocess.cond_mask_prob,
                ],
            )
            if cond_mask:
                single_feature["caption"] = ""
            else:
                single_feature["caption"] = self.utt2caption[utt]

        return single_feature

    def __len__(self):
        return len(self.metadata)


class AudioLDMCollator(BaseOfflineCollator):
    def __init__(self, cfg):
        BaseOfflineCollator.__init__(self, cfg)

        self.tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)

    def __call__(self, batch):
        # mel: (B, n_mels, T)
        # wav (optional): (B, T)
        # text_input_ids: (B, L)
        # text_attention_mask: (B, L)

        packed_batch_features = dict()

        for key in batch[0].keys():
            if key == "melspec":
                packed_batch_features["melspec"] = torch.from_numpy(
                    np.array([b["melspec"][:, :624] for b in batch])
                )

            if key == "wav":
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )

            if key == "caption":
                captions = [b[key] for b in batch]
                text_input = self.tokenizer(
                    captions, return_tensors="pt", truncation=True, padding="longest"
                )
                text_input_ids = text_input["input_ids"]
                text_attention_mask = text_input["attention_mask"]

                packed_batch_features["text_input_ids"] = text_input_ids
                packed_batch_features["text_attention_mask"] = text_attention_mask

        return packed_batch_features


class AudioLDMTestDataset(BaseTestDataset): ...


class AudioLDMTestCollator(BaseTestCollator): ...
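The caption path above implements the condition dropout used for classifier-free guidance: a fraction of captions is replaced by the empty string so the model also learns the unconditional distribution. A standalone sketch; the cond_mask_prob value of 0.1 is an assumption suggested by the "(0.1, 0.9)" comment, not a confirmed config value:

import numpy as np
from transformers import AutoTokenizer

cond_mask_prob = 0.1  # assumed dropout probability
captions = ["a dog barking in the distance", "rain falling on a tin roof"]

# With probability cond_mask_prob a caption is blanked out before tokenization.
masked = ["" if np.random.rand() < cond_mask_prob else c for c in captions]

tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
text_input = tokenizer(masked, return_tensors="pt", truncation=True, padding="longest")
# padding="longest" pads every caption in the batch to the longest one,
# and the attention mask marks which positions are real tokens.
print(text_input["input_ids"].shape, text_input["attention_mask"].shape)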
Amphion/models/tta/ldm/audioldm_trainer.py
ADDED
@@ -0,0 +1,251 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from models.base.base_trainer import BaseTrainer
from diffusers import DDPMScheduler
from models.tta.ldm.audioldm_dataset import AudioLDMDataset, AudioLDMCollator
from models.tta.autoencoder.autoencoder import AutoencoderKL
from models.tta.ldm.audioldm import AudioLDM, UNetModel
import torch
import torch.nn as nn
from torch.nn import MSELoss, L1Loss
import torch.nn.functional as F
from torch.utils.data import ConcatDataset, DataLoader

from transformers import T5EncoderModel

# Note: the original file imported DDPMScheduler twice and consistently
# misspelled "noise_scheduler" as "nosie_scheduler"; both are fixed here.


class AudioLDMTrainer(BaseTrainer):
    def __init__(self, args, cfg):
        BaseTrainer.__init__(self, args, cfg)
        self.cfg = cfg

        self.build_autoencoderkl()
        self.build_textencoder()
        self.noise_scheduler = self.build_noise_scheduler()

        self.save_config_file()

    def build_autoencoderkl(self):
        # The pretrained VAE is frozen; only the latent diffusion model trains.
        self.autoencoderkl = AutoencoderKL(self.cfg.model.autoencoderkl)
        self.autoencoder_path = self.cfg.model.autoencoder_path
        checkpoint = torch.load(self.autoencoder_path, map_location="cpu")
        self.autoencoderkl.load_state_dict(checkpoint["model"])
        self.autoencoderkl.cuda(self.args.local_rank)
        self.autoencoderkl.requires_grad_(requires_grad=False)
        self.autoencoderkl.eval()

    def build_textencoder(self):
        # Frozen T5 encoder used to embed captions.
        self.text_encoder = T5EncoderModel.from_pretrained("t5-base")
        self.text_encoder.cuda(self.args.local_rank)
        self.text_encoder.requires_grad_(requires_grad=False)
        self.text_encoder.eval()

    def build_noise_scheduler(self):
        noise_scheduler = DDPMScheduler(
            num_train_timesteps=self.cfg.model.noise_scheduler.num_train_timesteps,
            beta_start=self.cfg.model.noise_scheduler.beta_start,
            beta_end=self.cfg.model.noise_scheduler.beta_end,
            beta_schedule=self.cfg.model.noise_scheduler.beta_schedule,
            clip_sample=self.cfg.model.noise_scheduler.clip_sample,
            # steps_offset=self.cfg.model.noise_scheduler.steps_offset,
            # set_alpha_to_one=self.cfg.model.noise_scheduler.set_alpha_to_one,
            # skip_prk_steps=self.cfg.model.noise_scheduler.skip_prk_steps,
            prediction_type=self.cfg.model.noise_scheduler.prediction_type,
        )
        return noise_scheduler

    def build_dataset(self):
        return AudioLDMDataset, AudioLDMCollator

    def build_data_loader(self):
        Dataset, Collator = self.build_dataset()
        # build dataset instance for each dataset and combine them by ConcatDataset
        datasets_list = []
        for dataset in self.cfg.dataset:
            subdataset = Dataset(self.cfg, dataset, is_valid=False)
            datasets_list.append(subdataset)
        train_dataset = ConcatDataset(datasets_list)

        train_collate = Collator(self.cfg)

        train_loader = DataLoader(
            train_dataset,
            collate_fn=train_collate,
            num_workers=self.args.num_workers,
            batch_size=self.cfg.train.batch_size,
            pin_memory=False,
        )
        if not self.cfg.train.ddp or self.args.local_rank == 0:
            datasets_list = []
            for dataset in self.cfg.dataset:
                subdataset = Dataset(self.cfg, dataset, is_valid=True)
                datasets_list.append(subdataset)
            valid_dataset = ConcatDataset(datasets_list)
            valid_collate = Collator(self.cfg)

            valid_loader = DataLoader(
                valid_dataset,
                collate_fn=valid_collate,
                num_workers=1,
                batch_size=self.cfg.train.batch_size,
            )
        else:
            raise NotImplementedError("DDP is not supported yet.")
        data_loader = {"train": train_loader, "valid": valid_loader}
        return data_loader

    def build_optimizer(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), **self.cfg.train.adam)
        return optimizer

    # TODO: check it...
    def build_scheduler(self):
        return None
        # return ReduceLROnPlateau(self.optimizer["opt_ae"], **self.cfg.train.lronPlateau)

    def write_summary(self, losses, stats):
        for key, value in losses.items():
            self.sw.add_scalar(key, value, self.step)

    def write_valid_summary(self, losses, stats):
        for key, value in losses.items():
            self.sw.add_scalar(key, value, self.step)

    def build_criterion(self):
        criterion = nn.MSELoss(reduction="mean")
        return criterion

    def get_state_dict(self):
        if self.scheduler is not None:
            state_dict = {
                "model": self.model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "scheduler": self.scheduler.state_dict(),
                "step": self.step,
                "epoch": self.epoch,
                "batch_size": self.cfg.train.batch_size,
            }
        else:
            state_dict = {
                "model": self.model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "step": self.step,
                "epoch": self.epoch,
                "batch_size": self.cfg.train.batch_size,
            }
        return state_dict

    def load_model(self, checkpoint):
        self.step = checkpoint["step"]
        self.epoch = checkpoint["epoch"]

        self.model.load_state_dict(checkpoint["model"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        if self.scheduler is not None:
            self.scheduler.load_state_dict(checkpoint["scheduler"])

    def build_model(self):
        self.model = AudioLDM(self.cfg.model.audioldm)
        return self.model

    @torch.no_grad()
    def mel_to_latent(self, melspec):
        posterior = self.autoencoderkl.encode(melspec)
        latent = posterior.sample()  # (B, 4, 5, 78)
        return latent

    @torch.no_grad()
    def get_text_embedding(self, text_input_ids, text_attention_mask):
        text_embedding = self.text_encoder(
            input_ids=text_input_ids, attention_mask=text_attention_mask
        ).last_hidden_state
        return text_embedding  # (B, T, 768)

    def train_step(self, data):
        train_losses = {}
        total_loss = 0
        train_stats = {}

        melspec = data["melspec"].unsqueeze(1)  # (B, 80, T) -> (B, 1, 80, T)
        latents = self.mel_to_latent(melspec)

        text_embedding = self.get_text_embedding(
            data["text_input_ids"], data["text_attention_mask"]
        )

        noise = torch.randn_like(latents).float()

        bsz = latents.shape[0]
        # Sample one random diffusion timestep per item in the batch
        timesteps = torch.randint(
            0,
            self.cfg.model.noise_scheduler.num_train_timesteps,
            (bsz,),
            device=latents.device,
        )
        timesteps = timesteps.long()

        with torch.no_grad():
            noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        model_pred = self.model(
            noisy_latents, timesteps=timesteps, context=text_embedding
        )

        loss = self.criterion(model_pred, noise)

        train_losses["loss"] = loss
        total_loss += loss

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        for item in train_losses:
            train_losses[item] = train_losses[item].item()

        return train_losses, train_stats, total_loss.item()

    # TODO: eval step
    @torch.no_grad()
    def eval_step(self, data, index):
        valid_loss = {}
        total_valid_loss = 0
        valid_stats = {}

        melspec = data["melspec"].unsqueeze(1)  # (B, 80, T) -> (B, 1, 80, T)
        latents = self.mel_to_latent(melspec)

        text_embedding = self.get_text_embedding(
            data["text_input_ids"], data["text_attention_mask"]
        )

        noise = torch.randn_like(latents).float()

        bsz = latents.shape[0]
        timesteps = torch.randint(
            0,
            self.cfg.model.noise_scheduler.num_train_timesteps,
            (bsz,),
            device=latents.device,
        )
        timesteps = timesteps.long()

        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        model_pred = self.model(noisy_latents, timesteps, text_embedding)

        loss = self.criterion(model_pred, noise)
        valid_loss["loss"] = loss

        total_valid_loss += loss

        for item in valid_loss:
            valid_loss[item] = valid_loss[item].item()

        return valid_loss, valid_stats, total_valid_loss.item()
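The training step above reduces to the standard epsilon-prediction objective: noise the clean latent with add_noise and regress the injected noise. A self-contained sketch with diffusers' DDPMScheduler; the (2, 4, 5, 78) latent shape follows the comment in mel_to_latent, and the placeholder prediction stands in for the UNet output:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type="epsilon")

latents = torch.randn(2, 4, 5, 78)          # clean VAE latents x_0
noise = torch.randn_like(latents)           # epsilon ~ N(0, I)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (2,)).long()

# q(x_t | x_0): mixes the clean latent and the noise per the beta schedule
noisy_latents = scheduler.add_noise(latents, noise, timesteps)

model_pred = noise + 0.1 * torch.randn_like(noise)  # placeholder for the UNet
loss = torch.nn.functional.mse_loss(model_pred, noise)
print(loss.item())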
Amphion/models/tta/ldm/inference_utils/vocoder.py
ADDED
@@ -0,0 +1,408 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from models.tta.ldm.inference_utils.utils import get_padding, init_weights

LRELU_SLOPE = 0.1


class ResBlock1(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h
        # Dilated convolutions paired with non-dilated ones; the pairs are
        # combined with residual connections in forward().
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in dilation
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class Generator(torch.nn.Module):
    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(
            Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
        )
        resblock = ResBlock1 if h.resblock == "1" else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        h.upsample_initial_channel // (2**i),
                        h.upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the outputs of the parallel multi-kernel residual blocks
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
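A hedged usage sketch of the Generator above; the hyperparameters are typical HiFi-GAN values and not necessarily the config Amphion ships. The product of upsample_rates is the hop size, so 100 mel frames become 100 × 256 waveform samples:

import torch
from types import SimpleNamespace

# Assumed config values; h only needs to expose these attributes.
h = SimpleNamespace(
    resblock="1",
    upsample_rates=[8, 8, 2, 2],            # product = 256 samples per frame
    upsample_kernel_sizes=[16, 16, 4, 4],
    upsample_initial_channel=512,
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
)

g = Generator(h)  # assumes the Generator class defined above is in scope
mel = torch.randn(1, 80, 100)  # (B, n_mels, frames)
with torch.no_grad():
    audio = g(mel)
print(audio.shape)  # torch.Size([1, 1, 25600]) = (1, 1, 100 * 256)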

class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        c_in,
                        c_out,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(5, 1), 0),
                    )
                )
                for c_in, c_out in [(1, 32), (32, 128), (128, 512), (512, 1024)]
            ]
            + [norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0)))]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d: fold the waveform into (T // period, period) frames
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        # Prime periods so each discriminator sees distinct periodic structure
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorP(2),
                DiscriminatorP(3),
                DiscriminatorP(5),
                DiscriminatorP(7),
                DiscriminatorP(11),
            ]
        )

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 128, 15, 1, padding=7)),
                norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
                norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
                norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorS(use_spectral_norm=True),
                DiscriminatorS(),
                DiscriminatorS(),
            ]
        )
        self.meanpools = nn.ModuleList(
            [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
        )

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                # Each later discriminator sees a 2x average-pooled signal
                y = self.meanpools[i - 1](y)
                y_hat = self.meanpools[i - 1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg**2)
        loss += r_loss + g_loss
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1 - dg) ** 2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses
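A toy numeric check of the least-squares GAN losses defined above (assumes the two functions are in scope): with a real output of 0.8 and a generated output of 0.2, the discriminator loss is (1 - 0.8)² + 0.2² = 0.08, and the generator loss pushes fakes toward 1 with (1 - 0.2)² = 0.64.

import torch

disc_real_outputs = [torch.tensor([0.8])]
disc_generated_outputs = [torch.tensor([0.2])]

# Discriminator: real outputs are pulled toward 1, fake outputs toward 0.
loss, r_losses, g_losses = discriminator_loss(disc_real_outputs, disc_generated_outputs)
print(round(loss.item(), 4))  # 0.08

# Generator: rewarded when the discriminator scores fakes near 1.
gen_loss, _ = generator_loss(disc_generated_outputs)
print(round(gen_loss.item(), 4))  # 0.64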
Amphion/models/tts/base/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# from .tts_inferece import TTSInference
from .tts_trainer import TTSTrainer
Amphion/models/tts/base/tts_trainer.py
ADDED
@@ -0,0 +1,721 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import shutil
import time
from pathlib import Path

import torch  # the original file imported torch twice; deduplicated here
from tqdm import tqdm
import re
import logging
import json5
import accelerate
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from torch.utils.data import ConcatDataset, DataLoader
from accelerate import DistributedDataParallelKwargs
from schedulers.scheduler import Eden
from models.base.base_sampler import build_samplers
from models.base.new_trainer import BaseTrainer


class TTSTrainer(BaseTrainer):
    r"""The base trainer for all TTS models. It inherits from BaseTrainer and implements
    ``build_criterion``, ``_build_dataset`` and ``_build_singer_lut`` methods. You can inherit from this
    class, and implement ``_build_model`` and ``_forward_step``.
    """

    def __init__(self, args=None, cfg=None):
        self.args = args
        self.cfg = cfg

        cfg.exp_name = args.exp_name

        # init with accelerate
        self._init_accelerator()
        self.accelerator.wait_for_everyone()

        with self.accelerator.main_process_first():
            self.logger = get_logger(args.exp_name, log_level="INFO")

        # Log some info
        self.logger.info("=" * 56)
        self.logger.info("||\t\t" + "New training process started." + "\t\t||")
        self.logger.info("=" * 56)
        self.logger.info("\n")
        self.logger.debug(f"Using {args.log_level.upper()} logging level.")
        self.logger.info(f"Experiment name: {args.exp_name}")
        self.logger.info(f"Experiment directory: {self.exp_dir}")
        self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
        if self.accelerator.is_main_process:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")

        # init counts
        self.batch_count: int = 0
        self.step: int = 0
        self.epoch: int = 0
        self.max_epoch = (
            self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
        )
        self.logger.info(
            "Max epoch: {}".format(
                self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
            )
        )

        # Check values
        if self.accelerator.is_main_process:
            self.__check_basic_configs()
            # Set runtime configs
            self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
            self.checkpoints_path = [
                [] for _ in range(len(self.save_checkpoint_stride))
            ]
            self.keep_last = [
                i if i > 0 else float("inf") for i in self.cfg.train.keep_last
            ]
            self.run_eval = self.cfg.train.run_eval

        # set random seed
        with self.accelerator.main_process_first():
            start = time.monotonic_ns()
            self._set_random_seed(self.cfg.train.random_seed)
            end = time.monotonic_ns()
            self.logger.debug(
                f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
            )
            self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")

        # setup data_loader
        with self.accelerator.main_process_first():
            self.logger.info("Building dataset...")
            start = time.monotonic_ns()
            self.train_dataloader, self.valid_dataloader = self._build_dataloader()
            end = time.monotonic_ns()
            self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")

        # Save the phone table to the exp dir. This must happen before building
        # the model, because the model loads the phone table when it is built.
        if cfg.preprocess.use_phone and cfg.preprocess.phone_extractor != "lexicon":
            self._save_phone_symbols_file_to_exp_path()

        # setup model
        with self.accelerator.main_process_first():
            self.logger.info("Building model...")
            start = time.monotonic_ns()
            self.model = self._build_model()
            end = time.monotonic_ns()
            self.logger.debug(self.model)
            self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
            self.logger.info(
                f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
            )

        # optimizer & scheduler
        with self.accelerator.main_process_first():
            self.logger.info("Building optimizer and scheduler...")
            start = time.monotonic_ns()
            self.optimizer = self._build_optimizer()
            self.scheduler = self._build_scheduler()
            end = time.monotonic_ns()
            self.logger.info(
                f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
            )

        # create criterion
        with self.accelerator.main_process_first():
            self.logger.info("Building criterion...")
            start = time.monotonic_ns()
            self.criterion = self._build_criterion()
            end = time.monotonic_ns()
            self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")

        # Resume or Finetune
        with self.accelerator.main_process_first():
            self._check_resume()

        # accelerate prepare
        self.logger.info("Initializing accelerate...")
        start = time.monotonic_ns()
        self._accelerator_prepare()
        end = time.monotonic_ns()
        self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")

        # save config file path
        self.config_save_path = os.path.join(self.exp_dir, "args.json")
        self.device = self.accelerator.device

        if cfg.preprocess.use_spkid and cfg.train.multi_speaker_training:
            self.speakers = self._build_speaker_lut()
            self.utt2spk_dict = self._build_utt2spk_dict()

        # Only for TTS tasks
        self.task_type = "TTS"
        self.logger.info("Task type: {}".format(self.task_type))

    def _check_resume(self):
        # if args.resume:
        if self.args.resume or (
            self.cfg.model_type == "VALLE" and self.args.train_stage == 2
        ):
            checkpoint_dir = self.checkpoint_dir
            if self.cfg.model_type == "VALLE" and self.args.train_stage == 2:
                ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
                if (
                    self.args.checkpoint_path is None or len(ls) == 0
                ):  # Train stage 2 from scratch using the checkpoint of stage 1
                    assert (
                        self.args.ar_model_ckpt_dir is not None
                    ), "Error: ar_model_ckpt_dir should be set to train nar model."
                    self.args.resume_type = "finetune"
                    checkpoint_dir = self.args.ar_model_ckpt_dir
                    self.logger.info(
                        f"Training NAR model at stage 2 using the checkpoint of AR model at stage 1."
                    )

            self.logger.info(f"Resuming from checkpoint: {checkpoint_dir}")
            start = time.monotonic_ns()
            self.ckpt_path = self._load_model(
                checkpoint_dir, self.args.checkpoint_path, self.args.resume_type
            )
            self.logger.info(f"Checkpoint path: {self.ckpt_path}")
            end = time.monotonic_ns()
            self.logger.info(
                f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
            )
            self.checkpoints_path = json.load(
                open(os.path.join(self.ckpt_path, "ckpts.json"), "r")
            )

    def _init_accelerator(self):
        self.exp_dir = os.path.join(
            os.path.abspath(self.cfg.log_dir), self.args.exp_name
        )
        project_config = ProjectConfiguration(
            project_dir=self.exp_dir,
            logging_dir=os.path.join(self.exp_dir, "log"),
        )
        kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        self.accelerator = accelerate.Accelerator(
            gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
            log_with=self.cfg.train.tracker,
            project_config=project_config,
            kwargs_handlers=[kwargs],
        )
        if self.accelerator.is_main_process:
            os.makedirs(project_config.project_dir, exist_ok=True)
            os.makedirs(project_config.logging_dir, exist_ok=True)
        with self.accelerator.main_process_first():
            self.accelerator.init_trackers(self.args.exp_name)

    def _accelerator_prepare(self):
        (
            self.train_dataloader,
            self.valid_dataloader,
        ) = self.accelerator.prepare(
            self.train_dataloader,
            self.valid_dataloader,
        )

        if isinstance(self.model, dict):
            for key in self.model.keys():
                self.model[key] = self.accelerator.prepare(self.model[key])
        else:
            self.model = self.accelerator.prepare(self.model)

        if isinstance(self.optimizer, dict):
            for key in self.optimizer.keys():
                self.optimizer[key] = self.accelerator.prepare(self.optimizer[key])
        else:
            self.optimizer = self.accelerator.prepare(self.optimizer)

        if isinstance(self.scheduler, dict):
            for key in self.scheduler.keys():
                self.scheduler[key] = self.accelerator.prepare(self.scheduler[key])
        else:
            self.scheduler = self.accelerator.prepare(self.scheduler)

    ### Following are methods only for TTS tasks ###
    def _build_dataset(self):
        pass

    def _build_criterion(self):
        pass

    def _build_model(self):
        pass

    def _build_dataloader(self):
        """Build dataloader which merges a series of datasets."""
        # Build dataset instance for each dataset and combine them by ConcatDataset
        Dataset, Collator = self._build_dataset()

        # Build train set
        datasets_list = []
        for dataset in self.cfg.dataset:
            subdataset = Dataset(self.cfg, dataset, is_valid=False)
            datasets_list.append(subdataset)
        train_dataset = ConcatDataset(datasets_list)
        train_collate = Collator(self.cfg)
        _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
        train_loader = DataLoader(
            train_dataset,
            collate_fn=train_collate,
            batch_sampler=batch_sampler,
            num_workers=self.cfg.train.dataloader.num_worker,
            pin_memory=self.cfg.train.dataloader.pin_memory,
        )

        # Build test set
        datasets_list = []
        for dataset in self.cfg.dataset:
            subdataset = Dataset(self.cfg, dataset, is_valid=True)
            datasets_list.append(subdataset)
        valid_dataset = ConcatDataset(datasets_list)
        valid_collate = Collator(self.cfg)
        _, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
        valid_loader = DataLoader(
            valid_dataset,
            collate_fn=valid_collate,
            batch_sampler=batch_sampler,
            num_workers=self.cfg.train.dataloader.num_worker,
            pin_memory=self.cfg.train.dataloader.pin_memory,
        )
        return train_loader, valid_loader

    def _build_optimizer(self):
        pass

    def _build_scheduler(self):
        pass

    def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
        """Load model from checkpoint. If a folder is given, it will
        load the latest checkpoint in checkpoint_dir. If a path is given,
        it will load the checkpoint specified by checkpoint_path.
        **Only use this method after** ``accelerator.prepare()``.
        """
        if checkpoint_path is None or checkpoint_path == "":
            ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
            ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
            checkpoint_path = ls[0]
        self.logger.info("Load model from {}".format(checkpoint_path))
        print("Load model from {}".format(checkpoint_path))
        if resume_type == "resume":
            self.accelerator.load_state(checkpoint_path)
            self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
            self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
        elif resume_type == "finetune":
            if isinstance(self.model, dict):
                for idx, sub_model in enumerate(self.model.keys()):
                    if idx == 0:
                        ckpt_name = "pytorch_model.bin"
                    else:
                        ckpt_name = "pytorch_model_{}.bin".format(idx)

                    self.model[sub_model].load_state_dict(
                        torch.load(os.path.join(checkpoint_path, ckpt_name))
                    )
                    self.model[sub_model].cuda(self.accelerator.device)
            else:
                self.model.load_state_dict(
                    torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
                )
                self.model.cuda(self.accelerator.device)
            self.logger.info("Load model weights for finetune SUCCESS!")

        else:
            raise ValueError("Unsupported resume type: {}".format(resume_type))

        return checkpoint_path
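The sorting key and the epoch/step recovery above both rely on the checkpoint naming convention used by train_loop below ("epoch-XXXX_step-XXXXXXX_loss-X.XXXXXX"). A small worked example of the parsing, assuming directory names contain no extra underscores:

name = "epoch-0012_step-0034567_loss-1.234567"
# split("_") -> ["epoch-0012", "step-0034567", "loss-1.234567"]
epoch = int(name.split("_")[-3].split("-")[-1])  # "epoch-0012" -> 12
step = int(name.split("_")[-2].split("-")[-1])   # "step-0034567" -> 34567
print(epoch, step)  # 12 34567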
    ### THIS IS MAIN ENTRY ###
    def train_loop(self):
        r"""Training loop. The public entry of training process."""
        # Wait everyone to prepare before we move on
        self.accelerator.wait_for_everyone()
        # dump config file
        if self.accelerator.is_main_process:
            self.__dump_cfg(self.config_save_path)

        # self.optimizer.zero_grad()
        # Wait to ensure good to go

        self.accelerator.wait_for_everyone()
        while self.epoch < self.max_epoch:
            self.logger.info("\n")
            self.logger.info("-" * 32)
            self.logger.info("Epoch {}: ".format(self.epoch))

            # Do training & validating epoch
            train_total_loss, train_losses = self._train_epoch()
            if isinstance(train_losses, dict):
                for key, loss in train_losses.items():
                    self.logger.info("  |- Train/{} Loss: {:.6f}".format(key, loss))
                    self.accelerator.log(
                        {"Epoch/Train {} Loss".format(key): loss},
                        step=self.epoch,
                    )

            valid_total_loss, valid_losses = self._valid_epoch()
            if isinstance(valid_losses, dict):
                for key, loss in valid_losses.items():
                    self.logger.info("  |- Valid/{} Loss: {:.6f}".format(key, loss))
                    # The original logged valid losses under "Epoch/Train {} Loss",
                    # which silently overwrote the training entries; fixed here.
                    self.accelerator.log(
                        {"Epoch/Valid {} Loss".format(key): loss},
                        step=self.epoch,
                    )

            self.logger.info("  |- Train/Loss: {:.6f}".format(train_total_loss))
            self.logger.info("  |- Valid/Loss: {:.6f}".format(valid_total_loss))
            self.accelerator.log(
                {
                    "Epoch/Train Loss": train_total_loss,
                    "Epoch/Valid Loss": valid_total_loss,
                },
                step=self.epoch,
            )

            self.accelerator.wait_for_everyone()

            # Check if hit save_checkpoint_stride and run_eval
            run_eval = False
            if self.accelerator.is_main_process:
                save_checkpoint = False
                hit_idx = []
                for i, num in enumerate(self.save_checkpoint_stride):
                    if self.epoch % num == 0:
                        save_checkpoint = True
                        hit_idx.append(i)
                        run_eval |= self.run_eval[i]

            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process and save_checkpoint:
                path = os.path.join(
                    self.checkpoint_dir,
                    "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                        self.epoch, self.step, train_total_loss
                    ),
                )
                self.accelerator.save_state(path)

                json.dump(
                    self.checkpoints_path,
                    open(os.path.join(path, "ckpts.json"), "w"),
                    ensure_ascii=False,
                    indent=4,
                )

                # Track new checkpoint and mark the oldest ones for removal
                to_remove = []
                for idx in hit_idx:
                    self.checkpoints_path[idx].append(path)
                    while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
                        to_remove.append((idx, self.checkpoints_path[idx].pop(0)))

                # Search conflicts: keep a path if any stride list still needs it
                total = set()
                for i in self.checkpoints_path:
                    total |= set(i)
                do_remove = set()
                for idx, path in to_remove[::-1]:
                    if path in total:
                        self.checkpoints_path[idx].insert(0, path)
                    else:
                        do_remove.add(path)

                # Remove old checkpoints
                for path in do_remove:
                    shutil.rmtree(path, ignore_errors=True)
                    self.logger.debug(f"Remove old checkpoint: {path}")

            self.accelerator.wait_for_everyone()
            if run_eval:
                # TODO: run evaluation
                pass

            # Update info for each epoch
            self.epoch += 1

        # Finish training and save final checkpoint
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            path = os.path.join(
                self.checkpoint_dir,
                "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                    self.epoch, self.step, valid_total_loss
                ),
            )
            self.accelerator.save_state(
                os.path.join(
                    self.checkpoint_dir,
                    "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                        self.epoch, self.step, valid_total_loss
                    ),
                )
            )

            json.dump(
                self.checkpoints_path,
                open(os.path.join(path, "ckpts.json"), "w"),
                ensure_ascii=False,
                indent=4,
            )

        self.accelerator.end_training()

    ### Following are methods that can be used directly in child classes ###
    def _train_epoch(self):
        r"""Training epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.
        """
        if isinstance(self.model, dict):
            for key in self.model.keys():
                self.model[key].train()
        else:
            self.model.train()

        epoch_sum_loss: float = 0.0
        epoch_losses: dict = {}
        epoch_step: int = 0
        for batch in tqdm(
            self.train_dataloader,
            desc=f"Training Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            # Do training step and BP
            with self.accelerator.accumulate(self.model):
                total_loss, train_losses, _ = self._train_step(batch)
            self.batch_count += 1

            # Update info for each step
            # TODO: step means BP counts or batch counts?
            if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
                if isinstance(self.scheduler, dict):
                    for key in self.scheduler.keys():
                        self.scheduler[key].step()
                else:
                    if isinstance(self.scheduler, Eden):
                        self.scheduler.step_batch(self.step)
                    else:
                        self.scheduler.step()

                epoch_sum_loss += total_loss

                if isinstance(train_losses, dict):
                    for key, value in train_losses.items():
                        # Initialize missing keys to 0 before accumulating; the
                        # original `epoch_losses[key] += value` raised KeyError
                        # on the first batch.
                        epoch_losses[key] = epoch_losses.get(key, 0.0) + value

                if isinstance(train_losses, dict):
                    for key, loss in train_losses.items():
                        self.accelerator.log(
                            {"Epoch/Train {} Loss".format(key): loss},
                            step=self.step,
                        )

                self.step += 1
                epoch_step += 1

        self.accelerator.wait_for_everyone()

        epoch_sum_loss = (
            epoch_sum_loss
            / len(self.train_dataloader)
            * self.cfg.train.gradient_accumulation_step
        )

        for key in epoch_losses.keys():
            epoch_losses[key] = (
                epoch_losses[key]
                / len(self.train_dataloader)
                * self.cfg.train.gradient_accumulation_step
            )

        return epoch_sum_loss, epoch_losses

    @torch.inference_mode()
    def _valid_epoch(self):
        r"""Testing epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.
|
| 550 |
+
"""
|
| 551 |
+
if isinstance(self.model, dict):
|
| 552 |
+
for key in self.model.keys():
|
| 553 |
+
self.model[key].eval()
|
| 554 |
+
else:
|
| 555 |
+
self.model.eval()
|
| 556 |
+
|
| 557 |
+
epoch_sum_loss = 0.0
|
| 558 |
+
epoch_losses = dict()
|
| 559 |
+
for batch in tqdm(
|
| 560 |
+
self.valid_dataloader,
|
| 561 |
+
desc=f"Validating Epoch {self.epoch}",
|
| 562 |
+
unit="batch",
|
| 563 |
+
colour="GREEN",
|
| 564 |
+
leave=False,
|
| 565 |
+
dynamic_ncols=True,
|
| 566 |
+
smoothing=0.04,
|
| 567 |
+
disable=not self.accelerator.is_main_process,
|
| 568 |
+
):
|
| 569 |
+
total_loss, valid_losses, valid_stats = self._valid_step(batch)
|
| 570 |
+
epoch_sum_loss += total_loss
|
| 571 |
+
if isinstance(valid_losses, dict):
|
| 572 |
+
for key, value in valid_losses.items():
|
| 573 |
+
if key not in epoch_losses.keys():
|
| 574 |
+
epoch_losses[key] = value
|
| 575 |
+
else:
|
| 576 |
+
epoch_losses[key] += value
|
| 577 |
+
|
| 578 |
+
epoch_sum_loss = epoch_sum_loss / len(self.valid_dataloader)
|
| 579 |
+
for key in epoch_losses.keys():
|
| 580 |
+
epoch_losses[key] = epoch_losses[key] / len(self.valid_dataloader)
|
| 581 |
+
|
| 582 |
+
self.accelerator.wait_for_everyone()
|
| 583 |
+
|
| 584 |
+
return epoch_sum_loss, epoch_losses
|
| 585 |
+
|
| 586 |
+
def _train_step(self):
|
| 587 |
+
pass
|
| 588 |
+
|
| 589 |
+
def _valid_step(self, batch):
|
| 590 |
+
pass
|
| 591 |
+
|
| 592 |
+
def _inference(self):
|
| 593 |
+
pass
|
| 594 |
+
|
| 595 |
+
def _is_valid_pattern(self, directory_name):
|
| 596 |
+
directory_name = str(directory_name)
|
| 597 |
+
pattern = r"^epoch-\d{4}_step-\d{7}_loss-\d{1}\.\d{6}"
|
| 598 |
+
return re.match(pattern, directory_name) is not None
|
| 599 |
+
|
| 600 |
+
def _check_basic_configs(self):
|
| 601 |
+
if self.cfg.train.gradient_accumulation_step <= 0:
|
| 602 |
+
self.logger.fatal("Invalid gradient_accumulation_step value!")
|
| 603 |
+
self.logger.error(
|
| 604 |
+
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
|
| 605 |
+
)
|
| 606 |
+
self.accelerator.end_training()
|
| 607 |
+
raise ValueError(
|
| 608 |
+
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
def __dump_cfg(self, path):
|
| 612 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
| 613 |
+
json5.dump(
|
| 614 |
+
self.cfg,
|
| 615 |
+
open(path, "w"),
|
| 616 |
+
indent=4,
|
| 617 |
+
sort_keys=True,
|
| 618 |
+
ensure_ascii=False,
|
| 619 |
+
quote_keys=True,
|
| 620 |
+
)
|
| 621 |
+
|
| 622 |
+
def __check_basic_configs(self):
|
| 623 |
+
if self.cfg.train.gradient_accumulation_step <= 0:
|
| 624 |
+
self.logger.fatal("Invalid gradient_accumulation_step value!")
|
| 625 |
+
self.logger.error(
|
| 626 |
+
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
|
| 627 |
+
)
|
| 628 |
+
self.accelerator.end_training()
|
| 629 |
+
raise ValueError(
|
| 630 |
+
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
|
| 631 |
+
)
|
| 632 |
+
# TODO: check other values
|
| 633 |
+
|
| 634 |
+
@staticmethod
|
| 635 |
+
def __count_parameters(model):
|
| 636 |
+
model_param = 0.0
|
| 637 |
+
if isinstance(model, dict):
|
| 638 |
+
for key, value in model.items():
|
| 639 |
+
model_param += sum(p.numel() for p in model[key].parameters())
|
| 640 |
+
else:
|
| 641 |
+
model_param = sum(p.numel() for p in model.parameters())
|
| 642 |
+
return model_param
|
| 643 |
+
|
| 644 |
+
def _build_speaker_lut(self):
|
| 645 |
+
# combine speakers
|
| 646 |
+
if not os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)):
|
| 647 |
+
speakers = {}
|
| 648 |
+
else:
|
| 649 |
+
with open(
|
| 650 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "r"
|
| 651 |
+
) as speaker_file:
|
| 652 |
+
speakers = json.load(speaker_file)
|
| 653 |
+
for dataset in self.cfg.dataset:
|
| 654 |
+
speaker_lut_path = os.path.join(
|
| 655 |
+
self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
|
| 656 |
+
)
|
| 657 |
+
with open(speaker_lut_path, "r") as speaker_lut_path:
|
| 658 |
+
singer_lut = json.load(speaker_lut_path)
|
| 659 |
+
for singer in singer_lut.keys():
|
| 660 |
+
if singer not in speakers:
|
| 661 |
+
speakers[singer] = len(speakers)
|
| 662 |
+
with open(
|
| 663 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w"
|
| 664 |
+
) as speaker_file:
|
| 665 |
+
json.dump(speakers, speaker_file, indent=4, ensure_ascii=False)
|
| 666 |
+
print(
|
| 667 |
+
"speakers have been dumped to {}".format(
|
| 668 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
|
| 669 |
+
)
|
| 670 |
+
)
|
| 671 |
+
return speakers
|
| 672 |
+
|
| 673 |
+
def _build_utt2spk_dict(self):
|
| 674 |
+
# combine speakers
|
| 675 |
+
utt2spk = {}
|
| 676 |
+
if not os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk)):
|
| 677 |
+
utt2spk = {}
|
| 678 |
+
else:
|
| 679 |
+
with open(
|
| 680 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk), "r"
|
| 681 |
+
) as utt2spk_file:
|
| 682 |
+
for line in utt2spk_file.readlines():
|
| 683 |
+
utt, spk = line.strip().split("\t")
|
| 684 |
+
utt2spk[utt] = spk
|
| 685 |
+
for dataset in self.cfg.dataset:
|
| 686 |
+
utt2spk_dict_path = os.path.join(
|
| 687 |
+
self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.utt2spk
|
| 688 |
+
)
|
| 689 |
+
with open(utt2spk_dict_path, "r") as utt2spk_dict:
|
| 690 |
+
for line in utt2spk_dict.readlines():
|
| 691 |
+
utt, spk = line.strip().split("\t")
|
| 692 |
+
if utt not in utt2spk.keys():
|
| 693 |
+
utt2spk[utt] = spk
|
| 694 |
+
with open(
|
| 695 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk), "w"
|
| 696 |
+
) as utt2spk_file:
|
| 697 |
+
for utt, spk in utt2spk.items():
|
| 698 |
+
utt2spk_file.write(utt + "\t" + spk + "\n")
|
| 699 |
+
print(
|
| 700 |
+
"utterance and speaker mapper have been dumped to {}".format(
|
| 701 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.utt2spk)
|
| 702 |
+
)
|
| 703 |
+
)
|
| 704 |
+
return utt2spk
|
| 705 |
+
|
| 706 |
+
def _save_phone_symbols_file_to_exp_path(self):
|
| 707 |
+
phone_symbols_file = os.path.join(
|
| 708 |
+
self.cfg.preprocess.processed_dir,
|
| 709 |
+
self.cfg.dataset[0],
|
| 710 |
+
self.cfg.preprocess.symbols_dict,
|
| 711 |
+
)
|
| 712 |
+
phone_symbols_file_to_exp_path = os.path.join(
|
| 713 |
+
self.exp_dir, self.cfg.preprocess.symbols_dict
|
| 714 |
+
)
|
| 715 |
+
shutil.copy(phone_symbols_file, phone_symbols_file_to_exp_path)
|
| 716 |
+
os.chmod(phone_symbols_file_to_exp_path, 0o666)
|
| 717 |
+
print(
|
| 718 |
+
"phone symbols been dumped to {}".format(
|
| 719 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.symbols_dict)
|
| 720 |
+
)
|
| 721 |
+
)
|
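The two stubs `_train_step` / `_valid_step` above are the hooks each concrete trainer fills in. A minimal sketch of the contract they must satisfy (the subclass name, model call, and criterion below are hypothetical, for illustration only, not part of this commit):

class MyTrainer(TTSTrainer):
    def _train_step(self, batch):
        preds = self.model(batch)
        losses = self.criterion(batch, preds)  # hypothetical criterion: dict of scalar tensors incl. "loss"
        self.optimizer.zero_grad()
        self.accelerator.backward(losses["loss"])  # BP happens inside the step; _train_epoch only accumulates
        self.optimizer.step()
        # the base _train_epoch unpacks a 3-tuple: total_loss, train_losses, _
        return losses["loss"].item(), {k: v.item() for k, v in losses.items()}, {}

    @torch.no_grad()
    def _valid_step(self, batch):
        preds = self.model(batch)
        losses = self.criterion(batch, preds)
        # _valid_epoch unpacks: total_loss, valid_losses, valid_stats
        return losses["loss"].item(), {k: v.item() for k, v in losses.items()}, {}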
Amphion/models/tts/fastspeech2/fs2_trainer.py
ADDED
@@ -0,0 +1,155 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from tqdm import tqdm
from models.tts.base import TTSTrainer
from models.tts.fastspeech2.fs2 import FastSpeech2, FastSpeech2Loss
from models.tts.fastspeech2.fs2_dataset import FS2Dataset, FS2Collator
from optimizer.optimizers import NoamLR


class FastSpeech2Trainer(TTSTrainer):
    def __init__(self, args, cfg):
        TTSTrainer.__init__(self, args, cfg)
        self.cfg = cfg

    def _build_dataset(self):
        return FS2Dataset, FS2Collator

    def __build_scheduler(self):
        return NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)

    def _write_summary(self, losses, stats):
        for key, value in losses.items():
            self.sw.add_scalar("train/" + key, value, self.step)
        lr = self.optimizer.state_dict()["param_groups"][0]["lr"]
        self.sw.add_scalar("learning_rate", lr, self.step)

    def _write_valid_summary(self, losses, stats):
        for key, value in losses.items():
            self.sw.add_scalar("val/" + key, value, self.step)

    def _build_criterion(self):
        return FastSpeech2Loss(self.cfg)

    def get_state_dict(self):
        state_dict = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "step": self.step,
            "epoch": self.epoch,
            "batch_size": self.cfg.train.batch_size,
        }
        return state_dict

    def _build_optimizer(self):
        optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam)
        return optimizer

    def _build_scheduler(self):
        scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
        return scheduler

    def _build_model(self):
        self.model = FastSpeech2(self.cfg)
        return self.model

    def _train_epoch(self):
        r"""Training epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.
        """
        self.model.train()
        epoch_sum_loss: float = 0.0
        epoch_step: int = 0
        epoch_losses: dict = {}
        for batch in tqdm(
            self.train_dataloader,
            desc=f"Training Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            # Do training step and BP
            with self.accelerator.accumulate(self.model):
                loss, train_losses = self._train_step(batch)
                self.accelerator.backward(loss)
                grad_clip_thresh = self.cfg.train.grad_clip_thresh
                nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip_thresh)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
            self.batch_count += 1

            # Update info for each step
            if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
                epoch_sum_loss += loss
                for key, value in train_losses.items():
                    if key not in epoch_losses.keys():
                        epoch_losses[key] = value
                    else:
                        epoch_losses[key] += value

                self.accelerator.log(
                    {
                        "Step/Train Loss": loss,
                        "Step/Learning Rate": self.optimizer.param_groups[0]["lr"],
                    },
                    step=self.step,
                )
                self.step += 1
                epoch_step += 1

        self.accelerator.wait_for_everyone()

        epoch_sum_loss = (
            epoch_sum_loss
            / len(self.train_dataloader)
            * self.cfg.train.gradient_accumulation_step
        )

        for key in epoch_losses.keys():
            epoch_losses[key] = (
                epoch_losses[key]
                / len(self.train_dataloader)
                * self.cfg.train.gradient_accumulation_step
            )
        return epoch_sum_loss, epoch_losses

    def _train_step(self, data):
        train_losses = {}
        total_loss = 0
        train_stats = {}

        preds = self.model(data)

        train_losses = self.criterion(data, preds)

        total_loss = train_losses["loss"]
        for key, value in train_losses.items():
            train_losses[key] = value.item()

        return total_loss, train_losses

    @torch.no_grad()
    def _valid_step(self, data):
        valid_loss = {}
        total_valid_loss = 0
        valid_stats = {}

        preds = self.model(data)

        valid_losses = self.criterion(data, preds)

        total_valid_loss = valid_losses["loss"]
        for key, value in valid_losses.items():
            valid_losses[key] = value.item()

        return total_valid_loss, valid_losses, valid_stats
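For orientation, these are the `cfg.train` fields this trainer actually reads; the values below are placeholders (only the key names come from the code above — real runs set them in the experiment's exp_config.json, and NoamLR's keyword names are whatever its constructor defines):

train_cfg = {
    "batch_size": 16,                 # recorded in checkpoints via get_state_dict()
    "gradient_accumulation_step": 1,  # step/loss bookkeeping in _train_epoch()
    "grad_clip_thresh": 1.0,          # threshold passed to nn.utils.clip_grad_norm_
    "adam": {"lr": 1e-4},             # expanded into torch.optim.Adam(params, **cfg.train.adam)
    "lr_scheduler": {},               # expanded into NoamLR(optimizer, **cfg.train.lr_scheduler)
}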
Amphion/models/tts/naturalspeech2/ns2.py
ADDED
@@ -0,0 +1,259 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from models.tts.naturalspeech2.diffusion import Diffusion
from models.tts.naturalspeech2.diffusion_flow import DiffusionFlow
from models.tts.naturalspeech2.wavenet import WaveNet
from models.tts.naturalspeech2.prior_encoder import PriorEncoder
from modules.naturalpseech2.transformers import TransformerEncoder
from encodec import EncodecModel
from einops import rearrange, repeat

import os
import json


class NaturalSpeech2(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.latent_dim = cfg.latent_dim
        self.query_emb_num = cfg.query_emb.query_token_num

        self.prior_encoder = PriorEncoder(cfg.prior_encoder)
        if cfg.diffusion.diffusion_type == "diffusion":
            self.diffusion = Diffusion(cfg.diffusion)
        elif cfg.diffusion.diffusion_type == "flow":
            self.diffusion = DiffusionFlow(cfg.diffusion)

        self.prompt_encoder = TransformerEncoder(cfg=cfg.prompt_encoder)
        if self.latent_dim != cfg.prompt_encoder.encoder_hidden:
            self.prompt_lin = nn.Linear(
                self.latent_dim, cfg.prompt_encoder.encoder_hidden
            )
            self.prompt_lin.weight.data.normal_(0.0, 0.02)
        else:
            self.prompt_lin = None

        self.query_emb = nn.Embedding(self.query_emb_num, cfg.query_emb.hidden_size)
        self.query_attn = nn.MultiheadAttention(
            cfg.query_emb.hidden_size, cfg.query_emb.head_num, batch_first=True
        )

        codec_model = EncodecModel.encodec_model_24khz()
        codec_model.set_target_bandwidth(12.0)
        codec_model.requires_grad_(False)
        self.quantizer = codec_model.quantizer

    @torch.no_grad()
    def code_to_latent(self, code):
        latent = self.quantizer.decode(code.transpose(0, 1))
        return latent

    def latent_to_code(self, latent, nq=16):
        residual = latent
        all_indices = []
        all_dist = []
        for i in range(nq):
            layer = self.quantizer.vq.layers[i]
            x = rearrange(residual, "b d n -> b n d")
            x = layer.project_in(x)
            shape = x.shape
            x = layer._codebook.preprocess(x)
            embed = layer._codebook.embed.t()
            dist = -(
                x.pow(2).sum(1, keepdim=True)
                - 2 * x @ embed
                + embed.pow(2).sum(0, keepdim=True)
            )
            indices = dist.max(dim=-1).indices
            indices = layer._codebook.postprocess_emb(indices, shape)
            dist = dist.reshape(*shape[:-1], dist.shape[-1])
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
            all_dist.append(dist)

        out_indices = torch.stack(all_indices)
        out_dist = torch.stack(all_dist)

        return out_indices, out_dist  # (nq, B, T); (nq, B, T, 1024)

    @torch.no_grad()
    def latent_to_latent(self, latent, nq=16):
        codes, _ = self.latent_to_code(latent, nq)
        latent = self.quantizer.vq.decode(codes)
        return latent

    def forward(
        self,
        code=None,
        pitch=None,
        duration=None,
        phone_id=None,
        phone_id_frame=None,
        frame_nums=None,
        ref_code=None,
        ref_frame_nums=None,
        phone_mask=None,
        mask=None,
        ref_mask=None,
    ):
        ref_latent = self.code_to_latent(ref_code)
        latent = self.code_to_latent(code)

        if self.latent_dim is not None:
            ref_latent = self.prompt_lin(ref_latent.transpose(1, 2))

        ref_latent = self.prompt_encoder(ref_latent, ref_mask, condition=None)
        spk_emb = ref_latent.transpose(1, 2)  # (B, d, T')

        spk_query_emb = self.query_emb(
            torch.arange(self.query_emb_num).to(latent.device)
        ).repeat(
            latent.shape[0], 1, 1
        )  # (B, query_emb_num, d)
        spk_query_emb, _ = self.query_attn(
            spk_query_emb,
            spk_emb.transpose(1, 2),
            spk_emb.transpose(1, 2),
            key_padding_mask=~(ref_mask.bool()),
        )  # (B, query_emb_num, d)

        prior_out = self.prior_encoder(
            phone_id=phone_id,
            duration=duration,
            pitch=pitch,
            phone_mask=phone_mask,
            mask=mask,
            ref_emb=spk_emb,
            ref_mask=ref_mask,
            is_inference=False,
        )
        prior_condition = prior_out["prior_out"]  # (B, T, d)

        diff_out = self.diffusion(latent, mask, prior_condition, spk_query_emb)

        return diff_out, prior_out

    @torch.no_grad()
    def inference(
        self, ref_code=None, phone_id=None, ref_mask=None, inference_steps=1000
    ):
        ref_latent = self.code_to_latent(ref_code)

        if self.latent_dim is not None:
            ref_latent = self.prompt_lin(ref_latent.transpose(1, 2))

        ref_latent = self.prompt_encoder(ref_latent, ref_mask, condition=None)
        spk_emb = ref_latent.transpose(1, 2)  # (B, d, T')

        spk_query_emb = self.query_emb(
            torch.arange(self.query_emb_num).to(ref_latent.device)
        ).repeat(
            ref_latent.shape[0], 1, 1
        )  # (B, query_emb_num, d)
        spk_query_emb, _ = self.query_attn(
            spk_query_emb,
            spk_emb.transpose(1, 2),
            spk_emb.transpose(1, 2),
            key_padding_mask=~(ref_mask.bool()),
        )  # (B, query_emb_num, d)

        prior_out = self.prior_encoder(
            phone_id=phone_id,
            duration=None,
            pitch=None,
            phone_mask=None,
            mask=None,
            ref_emb=spk_emb,
            ref_mask=ref_mask,
            is_inference=True,
        )
        prior_condition = prior_out["prior_out"]  # (B, T, d)

        z = torch.randn(
            prior_condition.shape[0], self.latent_dim, prior_condition.shape[1]
        ).to(ref_latent.device) / (1.20)
        x0 = self.diffusion.reverse_diffusion(
            z, None, prior_condition, inference_steps, spk_query_emb
        )

        return x0, prior_out

    @torch.no_grad()
    def reverse_diffusion_from_t(
        self,
        code=None,
        pitch=None,
        duration=None,
        phone_id=None,
        ref_code=None,
        phone_mask=None,
        mask=None,
        ref_mask=None,
        n_timesteps=None,
        t=None,
    ):
        # Only for debug
        ref_latent = self.code_to_latent(ref_code)
        latent = self.code_to_latent(code)

        if self.latent_dim is not None:
            ref_latent = self.prompt_lin(ref_latent.transpose(1, 2))

        ref_latent = self.prompt_encoder(ref_latent, ref_mask, condition=None)
        spk_emb = ref_latent.transpose(1, 2)  # (B, d, T')

        spk_query_emb = self.query_emb(
            torch.arange(self.query_emb_num).to(latent.device)
        ).repeat(
            latent.shape[0], 1, 1
        )  # (B, query_emb_num, d)
        spk_query_emb, _ = self.query_attn(
            spk_query_emb,
            spk_emb.transpose(1, 2),
            spk_emb.transpose(1, 2),
            key_padding_mask=~(ref_mask.bool()),
        )  # (B, query_emb_num, d)

        prior_out = self.prior_encoder(
            phone_id=phone_id,
            duration=duration,
            pitch=pitch,
            phone_mask=phone_mask,
            mask=mask,
            ref_emb=spk_emb,
            ref_mask=ref_mask,
            is_inference=False,
        )
        prior_condition = prior_out["prior_out"]  # (B, T, d)

        diffusion_step = (
            torch.ones(
                latent.shape[0],
                dtype=latent.dtype,
                device=latent.device,
                requires_grad=False,
            )
            * t
        )
        diffusion_step = torch.clamp(diffusion_step, 1e-5, 1.0 - 1e-5)
        xt, _ = self.diffusion.forward_diffusion(
            x0=latent, diffusion_step=diffusion_step
        )
        # print(torch.abs(xt-latent).max(), torch.abs(xt-latent).mean(), torch.abs(xt-latent).std())

        x0 = self.diffusion.reverse_diffusion_from_t(
            xt, mask, prior_condition, n_timesteps, spk_query_emb, t_start=t
        )

        return x0, prior_out, xt
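`latent_to_code` above re-implements Encodec's residual vector quantization search by hand so that, besides the code indices, it can also return the per-layer (negated squared) distance matrices. The idea in isolation: each of the `nq` codebooks quantizes whatever residual the previous layers left behind. A self-contained toy sketch of the same greedy residual search (codebook shapes, sizes, and names here are illustrative, not Encodec's internals):

import torch

def rvq_encode(latent, codebooks):
    # latent: (B, D, T); each codebook: (K, D). Greedy residual quantization.
    residual = latent
    all_indices = []
    for embed in codebooks:
        x = residual.transpose(1, 2)               # (B, T, D)
        dist = torch.cdist(x, embed.unsqueeze(0))  # (B, T, K) distance to every code
        idx = dist.argmin(dim=-1)                  # nearest code per frame
        quantized = embed[idx].transpose(1, 2)     # (B, D, T)
        residual = residual - quantized            # next layer models what is left
        all_indices.append(idx)
    return torch.stack(all_indices)                # (nq, B, T), like out_indices above

codebooks = [torch.randn(1024, 128) for _ in range(16)]
codes = rvq_encode(torch.randn(2, 128, 50), codebooks)  # -> torch.Size([16, 2, 50])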
Amphion/models/tts/naturalspeech2/ns2_dataset.py
ADDED
@@ -0,0 +1,524 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import random
import torch
from torch.nn.utils.rnn import pad_sequence
from utils.data_utils import *
from processors.acoustic_extractor import cal_normalized_mel
from processors.acoustic_extractor import load_normalized
from models.base.base_dataset import (
    BaseOfflineCollator,
    BaseOfflineDataset,
    BaseTestDataset,
    BaseTestCollator,
)
from text import text_to_sequence
from text.cmudict import valid_symbols
from tqdm import tqdm
import pickle


class NS2Dataset(torch.utils.data.Dataset):
    def __init__(self, cfg, dataset, is_valid=False):
        assert isinstance(dataset, str)

        processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)

        meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
        # train.json

        self.metafile_path = os.path.join(processed_data_dir, meta_file)

        self.metadata = self.get_metadata()

        self.cfg = cfg

        assert cfg.preprocess.use_mel == False
        if cfg.preprocess.use_mel:
            self.utt2melspec_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2melspec_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.melspec_dir,  # mel
                    utt_info["speaker"],
                    uid + ".npy",
                )

        assert cfg.preprocess.use_code == True
        if cfg.preprocess.use_code:
            self.utt2code_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2code_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.code_dir,  # code
                    utt_info["speaker"],
                    uid + ".npy",
                )

        assert cfg.preprocess.use_spkid == True
        if cfg.preprocess.use_spkid:
            self.utt2spkid = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2spkid[utt] = utt_info["speaker"]

        assert cfg.preprocess.use_pitch == True
        if cfg.preprocess.use_pitch:
            self.utt2pitch_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2pitch_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.pitch_dir,  # pitch
                    utt_info["speaker"],
                    uid + ".npy",
                )

        assert cfg.preprocess.use_duration == True
        if cfg.preprocess.use_duration:
            self.utt2duration_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2duration_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.duration_dir,  # duration
                    utt_info["speaker"],
                    uid + ".npy",
                )

        assert cfg.preprocess.use_phone == True
        if cfg.preprocess.use_phone:
            self.utt2phone = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2phone[utt] = utt_info["phones"]

        assert cfg.preprocess.use_len == True
        if cfg.preprocess.use_len:
            self.utt2len = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2len[utt] = utt_info["num_frames"]

        # for cross reference
        if cfg.preprocess.use_cross_reference:
            self.spkid2utt = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)
                spkid = utt_info["speaker"]
                if spkid not in self.spkid2utt:
                    self.spkid2utt[spkid] = []
                self.spkid2utt[spkid].append(utt)

        # get phone to id / id to phone map
        self.phone2id, self.id2phone = self.get_phone_map()

        self.all_num_frames = []
        for i in range(len(self.metadata)):
            self.all_num_frames.append(self.metadata[i]["num_frames"])
        self.num_frame_sorted = np.array(sorted(self.all_num_frames))
        self.num_frame_indices = np.array(
            sorted(
                range(len(self.all_num_frames)), key=lambda k: self.all_num_frames[k]
            )
        )

    def __len__(self):
        return len(self.metadata)

    def get_dataset_name(self):
        return self.metadata[0]["Dataset"]

    def get_metadata(self):
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        print("metadata len: ", len(metadata))

        return metadata

    def get_phone_map(self):
        symbols = valid_symbols + ["sp", "spn", "sil"] + ["<s>", "</s>"]
        phone2id = {s: i for i, s in enumerate(symbols)}
        id2phone = {i: s for s, i in phone2id.items()}
        return phone2id, id2phone

    def __getitem__(self, index):
        utt_info = self.metadata[index]

        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        single_feature = dict()

        if self.cfg.preprocess.read_metadata:
            metadata_uid_path = os.path.join(
                self.cfg.preprocess.processed_dir,
                self.cfg.preprocess.metadata_dir,
                dataset,
                # utt_info["speaker"],
                uid + ".pkl",
            )
            with open(metadata_uid_path, "rb") as f:
                metadata_uid = pickle.load(f)
            # code
            code = metadata_uid["code"]
            # frame_nums
            frame_nums = code.shape[1]
            # pitch
            pitch = metadata_uid["pitch"]
            # duration
            duration = metadata_uid["duration"]
            # phone_id
            phone_id = np.array(
                [
                    *map(
                        self.phone2id.get,
                        self.utt2phone[utt].replace("{", "").replace("}", "").split(),
                    )
                ]
            )

        else:
            # code
            code = np.load(self.utt2code_path[utt])
            # frame_nums
            frame_nums = code.shape[1]
            # pitch
            pitch = np.load(self.utt2pitch_path[utt])
            # duration
            duration = np.load(self.utt2duration_path[utt])
            # phone_id
            phone_id = np.array(
                [
                    *map(
                        self.phone2id.get,
                        self.utt2phone[utt].replace("{", "").replace("}", "").split(),
                    )
                ]
            )

        # align length
        code, pitch, duration, phone_id, frame_nums = self.align_length(
            code, pitch, duration, phone_id, frame_nums
        )

        # spkid
        spkid = self.utt2spkid[utt]

        # get target and reference
        out = self.get_target_and_reference(code, pitch, duration, phone_id, frame_nums)
        code, ref_code = out["code"], out["ref_code"]
        pitch, ref_pitch = out["pitch"], out["ref_pitch"]
        duration, ref_duration = out["duration"], out["ref_duration"]
        phone_id, ref_phone_id = out["phone_id"], out["ref_phone_id"]
        frame_nums, ref_frame_nums = out["frame_nums"], out["ref_frame_nums"]

        # phone_id_frame
        assert len(phone_id) == len(duration)
        phone_id_frame = []
        for i in range(len(phone_id)):
            phone_id_frame.extend([phone_id[i] for _ in range(duration[i])])
        phone_id_frame = np.array(phone_id_frame)

        # ref_phone_id_frame
        assert len(ref_phone_id) == len(ref_duration)
        ref_phone_id_frame = []
        for i in range(len(ref_phone_id)):
            ref_phone_id_frame.extend([ref_phone_id[i] for _ in range(ref_duration[i])])
        ref_phone_id_frame = np.array(ref_phone_id_frame)

        single_feature.update(
            {
                "code": code,
                "frame_nums": frame_nums,
                "pitch": pitch,
                "duration": duration,
                "phone_id": phone_id,
                "phone_id_frame": phone_id_frame,
                "ref_code": ref_code,
                "ref_frame_nums": ref_frame_nums,
                "ref_pitch": ref_pitch,
                "ref_duration": ref_duration,
                "ref_phone_id": ref_phone_id,
                "ref_phone_id_frame": ref_phone_id_frame,
                "spkid": spkid,
            }
        )

        return single_feature

    def get_num_frames(self, index):
        utt_info = self.metadata[index]
        return utt_info["num_frames"]

    def align_length(self, code, pitch, duration, phone_id, frame_nums):
        # align the lengths of code, pitch, duration, phone_id, and frame nums
        code_len = code.shape[1]
        pitch_len = len(pitch)
        dur_sum = sum(duration)
        min_len = min(code_len, dur_sum)
        code = code[:, :min_len]
        if pitch_len >= min_len:
            pitch = pitch[:min_len]
        else:
            pitch = np.pad(pitch, (0, min_len - pitch_len), mode="edge")
        frame_nums = min_len
        if dur_sum > min_len:
            assert (duration[-1] - (dur_sum - min_len)) >= 0
            duration[-1] = duration[-1] - (dur_sum - min_len)
            assert duration[-1] >= 0

        return code, pitch, duration, phone_id, frame_nums

    def get_target_and_reference(self, code, pitch, duration, phone_id, frame_nums):
        phone_nums = len(phone_id)
        clip_phone_nums = np.random.randint(
            int(phone_nums * 0.1), int(phone_nums * 0.5) + 1
        )
        clip_phone_nums = max(clip_phone_nums, 1)
        assert clip_phone_nums < phone_nums and clip_phone_nums >= 1
        if self.cfg.preprocess.clip_mode == "mid":
            start_idx = np.random.randint(0, phone_nums - clip_phone_nums)
        elif self.cfg.preprocess.clip_mode == "start":
            if duration[0] == 0 and clip_phone_nums == 1:
                start_idx = 1
            else:
                start_idx = 0
        else:
            assert self.cfg.preprocess.clip_mode in ["mid", "start"]
        end_idx = start_idx + clip_phone_nums
        start_frames = sum(duration[:start_idx])
        end_frames = sum(duration[:end_idx])

        new_code = np.concatenate(
            (code[:, :start_frames], code[:, end_frames:]), axis=1
        )
        ref_code = code[:, start_frames:end_frames]

        new_pitch = np.append(pitch[:start_frames], pitch[end_frames:])
        ref_pitch = pitch[start_frames:end_frames]

        new_duration = np.append(duration[:start_idx], duration[end_idx:])
        ref_duration = duration[start_idx:end_idx]

        new_phone_id = np.append(phone_id[:start_idx], phone_id[end_idx:])
        ref_phone_id = phone_id[start_idx:end_idx]

        new_frame_nums = frame_nums - (end_frames - start_frames)
        ref_frame_nums = end_frames - start_frames

        return {
            "code": new_code,
            "ref_code": ref_code,
            "pitch": new_pitch,
            "ref_pitch": ref_pitch,
            "duration": new_duration,
            "ref_duration": ref_duration,
            "phone_id": new_phone_id,
            "ref_phone_id": ref_phone_id,
            "frame_nums": new_frame_nums,
            "ref_frame_nums": ref_frame_nums,
        }


class NS2Collator(BaseOfflineCollator):
    def __init__(self, cfg):
        BaseOfflineCollator.__init__(self, cfg)

    def __call__(self, batch):
        packed_batch_features = dict()

        # code: (B, 16, T)
        # frame_nums: (B,) not used
        # pitch: (B, T)
        # duration: (B, N)
        # phone_id: (B, N)
        # phone_id_frame: (B, T)
        # ref_code: (B, 16, T')
        # ref_frame_nums: (B,) not used
        # ref_pitch: (B, T) not used
        # ref_duration: (B, N') not used
        # ref_phone_id: (B, N') not used
        # ref_phone_frame: (B, T') not used
        # spkid: (B,) not used
        # phone_mask: (B, N)
        # mask: (B, T)
        # ref_mask: (B, T')

        for key in batch[0].keys():
            if key == "phone_id":
                phone_ids = [torch.LongTensor(b["phone_id"]) for b in batch]
                phone_masks = [torch.ones(len(b["phone_id"])) for b in batch]
                packed_batch_features["phone_id"] = pad_sequence(
                    phone_ids,
                    batch_first=True,
                    padding_value=0,
                )
                packed_batch_features["phone_mask"] = pad_sequence(
                    phone_masks,
                    batch_first=True,
                    padding_value=0,
                )
            elif key == "phone_id_frame":
                phone_id_frames = [torch.LongTensor(b["phone_id_frame"]) for b in batch]
                masks = [torch.ones(len(b["phone_id_frame"])) for b in batch]
                packed_batch_features["phone_id_frame"] = pad_sequence(
                    phone_id_frames,
                    batch_first=True,
                    padding_value=0,
                )
                packed_batch_features["mask"] = pad_sequence(
                    masks,
                    batch_first=True,
                    padding_value=0,
                )
            elif key == "ref_code":
                ref_codes = [
                    torch.from_numpy(b["ref_code"]).transpose(0, 1) for b in batch
                ]
                ref_masks = [torch.ones(max(b["ref_code"].shape[1], 1)) for b in batch]
                packed_batch_features["ref_code"] = pad_sequence(
                    ref_codes,
                    batch_first=True,
                    padding_value=0,
                ).transpose(1, 2)
                packed_batch_features["ref_mask"] = pad_sequence(
                    ref_masks,
                    batch_first=True,
                    padding_value=0,
                )
            elif key == "code":
                codes = [torch.from_numpy(b["code"]).transpose(0, 1) for b in batch]
                masks = [torch.ones(max(b["code"].shape[1], 1)) for b in batch]
                packed_batch_features["code"] = pad_sequence(
                    codes,
                    batch_first=True,
                    padding_value=0,
                ).transpose(1, 2)
                packed_batch_features["mask"] = pad_sequence(
                    masks,
                    batch_first=True,
                    padding_value=0,
                )
            elif key == "pitch":
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=50.0
                )
            elif key == "duration":
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )
            elif key == "frame_nums":
                packed_batch_features["frame_nums"] = torch.LongTensor(
                    [b["frame_nums"] for b in batch]
                )
            elif key == "ref_frame_nums":
                packed_batch_features["ref_frame_nums"] = torch.LongTensor(
                    [b["ref_frame_nums"] for b in batch]
                )
            else:
                pass

        return packed_batch_features


def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
    if len(batch) == 0:
        return 0
    if len(batch) == max_sentences:
        return 1
    if num_tokens > max_tokens:
        return 1
    return 0


def batch_by_size(
    indices,
    num_tokens_fn,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
):
    """
    Yield mini-batches of indices bucketed by size. Batches may contain
    sequences of different lengths.

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N (default: 1).
    """
    bsz_mult = required_batch_size_multiple

    sample_len = 0
    sample_lens = []
    batch = []
    batches = []
    for i in range(len(indices)):
        idx = indices[i]
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        assert (
            sample_len <= max_tokens
        ), "sentence at index {} of size {} exceeds max_tokens limit of {}!".format(
            idx, sample_len, max_tokens
        )
        num_tokens = (len(batch) + 1) * sample_len

        if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
            mod_len = max(
                bsz_mult * (len(batch) // bsz_mult),
                len(batch) % bsz_mult,
            )
            batches.append(batch[:mod_len])
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
        batch.append(idx)
    if len(batch) > 0:
        batches.append(batch)
    return batches
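A quick toy run of `batch_by_size`; in `NS2Dataset` the `indices` would come pre-sorted as `num_frame_indices` and `num_tokens_fn` would be `get_num_frames`:

lengths = [1, 2, 3, 3, 4, 4, 5, 6, 7, 9]  # toy token counts, pre-sorted
batches = batch_by_size(
    indices=list(range(len(lengths))),
    num_tokens_fn=lambda i: lengths[i],
    max_tokens=12,
    max_sentences=4,
)
# -> [[0, 1, 2, 3], [4, 5], [6, 7], [8], [9]]
# Each batch keeps len(batch) * (max length in batch) <= max_tokens, since
# padding every member to the longest one is what actually costs memory.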
Amphion/models/tts/naturalspeech2/ns2_inference.py
ADDED
@@ -0,0 +1,128 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import torch
import soundfile as sf
import numpy as np

from models.tts.naturalspeech2.ns2 import NaturalSpeech2
from encodec import EncodecModel
from encodec.utils import convert_audio
from utils.util import load_config

from text import text_to_sequence
from text.cmudict import valid_symbols
from text.g2p import preprocess_english, read_lexicon

import torchaudio


class NS2Inference:
    def __init__(self, args, cfg):
        self.cfg = cfg
        self.args = args

        self.model = self.build_model()
        self.codec = self.build_codec()

        self.symbols = valid_symbols + ["sp", "spn", "sil"] + ["<s>", "</s>"]
        self.phone2id = {s: i for i, s in enumerate(self.symbols)}
        self.id2phone = {i: s for s, i in self.phone2id.items()}

    def build_model(self):
        model = NaturalSpeech2(self.cfg.model)
        model.load_state_dict(
            torch.load(
                os.path.join(self.args.checkpoint_path, "pytorch_model.bin"),
                map_location="cpu",
            )
        )
        model = model.to(self.args.device)
        return model

    def build_codec(self):
        encodec_model = EncodecModel.encodec_model_24khz()
        encodec_model = encodec_model.to(device=self.args.device)
        encodec_model.set_target_bandwidth(12.0)
        return encodec_model

    def get_ref_code(self):
        ref_wav_path = self.args.ref_audio
        ref_wav, sr = torchaudio.load(ref_wav_path)
        ref_wav = convert_audio(
            ref_wav, sr, self.codec.sample_rate, self.codec.channels
        )
        ref_wav = ref_wav.unsqueeze(0).to(device=self.args.device)

        with torch.no_grad():
            encoded_frames = self.codec.encode(ref_wav)
            ref_code = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)
        # print(ref_code.shape)

        ref_mask = torch.ones(ref_code.shape[0], ref_code.shape[-1]).to(ref_code.device)
        # print(ref_mask.shape)

        return ref_code, ref_mask

    def inference(self):
        ref_code, ref_mask = self.get_ref_code()

        lexicon = read_lexicon(self.cfg.preprocess.lexicon_path)
        phone_seq = preprocess_english(self.args.text, lexicon)
        print(phone_seq)

        phone_id = np.array(
            [
                *map(
                    self.phone2id.get,
                    phone_seq.replace("{", "").replace("}", "").split(),
                )
            ]
        )
        phone_id = torch.from_numpy(phone_id).unsqueeze(0).to(device=self.args.device)
        print(phone_id)

        x0, prior_out = self.model.inference(
            ref_code, phone_id, ref_mask, self.args.inference_step
        )
        print(prior_out["dur_pred"])
        print(prior_out["dur_pred_round"])
        print(torch.sum(prior_out["dur_pred_round"]))

        latent_ref = self.codec.quantizer.vq.decode(ref_code.transpose(0, 1))

        rec_wav = self.codec.decoder(x0)
        # ref_wav = self.codec.decoder(latent_ref)

        os.makedirs(self.args.output_dir, exist_ok=True)

        sf.write(
            "{}/{}.wav".format(
                self.args.output_dir, self.args.text.replace(" ", "_", 100)
            ),
            rec_wav[0, 0].detach().cpu().numpy(),
            samplerate=24000,
        )

    def add_arguments(parser: argparse.ArgumentParser):
        parser.add_argument(
            "--ref_audio",
            type=str,
            default="",
            help="Reference audio path",
        )
        parser.add_argument(
            "--device",
            type=str,
            default="cuda",
        )
        parser.add_argument(
            "--inference_step",
            type=int,
            default=200,
            help="Total inference steps for the diffusion model",
        )
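A hedged sketch of how this class is driven: `--checkpoint_path`, `--text`, and `--output_dir` are read by `NS2Inference` but appear to be registered by Amphion's entry script rather than by `add_arguments`, so this sketch adds them by hand (the paths, text, and config filename are illustrative placeholders):

import argparse

parser = argparse.ArgumentParser()
NS2Inference.add_arguments(parser)
parser.add_argument("--checkpoint_path", type=str)   # dir containing pytorch_model.bin
parser.add_argument("--text", type=str)
parser.add_argument("--output_dir", type=str, default="output")
args = parser.parse_args(
    ["--ref_audio", "ref.wav", "--checkpoint_path", "ckpt_dir",
     "--text", "Hello world", "--inference_step", "200", "--device", "cpu"]
)

cfg = load_config("egs/tts/NaturalSpeech2/exp_config.json")  # illustrative path
NS2Inference(args, cfg).inference()  # writes output/Hello_world.wav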
Amphion/models/tts/naturalspeech2/ns2_trainer.py
ADDED
|
@@ -0,0 +1,798 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
import json
import time
import torch
import numpy as np
from utils.util import Logger, ValueWindow
from torch.utils.data import ConcatDataset, DataLoader
from models.tts.base.tts_trainer import TTSTrainer
from models.base.base_trainer import BaseTrainer
from models.base.base_sampler import VariableSampler
from models.tts.naturalspeech2.ns2_dataset import NS2Dataset, NS2Collator, batch_by_size
from models.tts.naturalspeech2.ns2_loss import (
    log_pitch_loss,
    log_dur_loss,
    diff_loss,
    diff_ce_loss,
)
from torch.utils.data.sampler import BatchSampler, SequentialSampler
from models.tts.naturalspeech2.ns2 import NaturalSpeech2
from torch.optim import Adam, AdamW
from torch.nn import MSELoss, L1Loss
import torch.nn.functional as F
from diffusers import get_scheduler

import accelerate
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration


class NS2Trainer(TTSTrainer):
    def __init__(self, args, cfg):
        self.args = args
        self.cfg = cfg

        cfg.exp_name = args.exp_name

        self._init_accelerator()
        self.accelerator.wait_for_everyone()

        # Init logger
        with self.accelerator.main_process_first():
            if self.accelerator.is_main_process:
                os.makedirs(os.path.join(self.exp_dir, "checkpoint"), exist_ok=True)
                self.log_file = os.path.join(
                    os.path.join(self.exp_dir, "checkpoint"), "train.log"
                )
                self.logger = Logger(self.log_file, level=self.args.log_level).logger

        self.time_window = ValueWindow(50)

        if self.accelerator.is_main_process:
            # Log some info
            self.logger.info("=" * 56)
            self.logger.info("||\t\t" + "New training process started." + "\t\t||")
            self.logger.info("=" * 56)
            self.logger.info("\n")
            self.logger.debug(f"Using {args.log_level.upper()} logging level.")
            self.logger.info(f"Experiment name: {args.exp_name}")
            self.logger.info(f"Experiment directory: {self.exp_dir}")

        self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
        if self.accelerator.is_main_process:
            os.makedirs(self.checkpoint_dir, exist_ok=True)

        if self.accelerator.is_main_process:
            self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")

        # init counts
        self.batch_count: int = 0
        self.step: int = 0
        self.epoch: int = 0
        self.max_epoch = (
            self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
        )
        if self.accelerator.is_main_process:
            self.logger.info(
                "Max epoch: {}".format(
                    self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
                )
            )

        # Check values
        if self.accelerator.is_main_process:
            self._check_basic_configs()
            # Set runtime configs
            self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
            self.checkpoints_path = [
                [] for _ in range(len(self.save_checkpoint_stride))
            ]
            self.keep_last = [
                i if i > 0 else float("inf") for i in self.cfg.train.keep_last
            ]
            self.run_eval = self.cfg.train.run_eval

        # set random seed
        with self.accelerator.main_process_first():
            start = time.monotonic_ns()
            self._set_random_seed(self.cfg.train.random_seed)
            end = time.monotonic_ns()
            if self.accelerator.is_main_process:
                self.logger.debug(
                    f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
                )
                self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")

        # setup data_loader
        with self.accelerator.main_process_first():
            if self.accelerator.is_main_process:
                self.logger.info("Building dataset...")
            start = time.monotonic_ns()
            self.train_dataloader, self.valid_dataloader = self._build_dataloader()
            end = time.monotonic_ns()
            if self.accelerator.is_main_process:
                self.logger.info(
                    f"Building dataset done in {(end - start) / 1e6:.2f}ms"
                )

        # setup model
        with self.accelerator.main_process_first():
            if self.accelerator.is_main_process:
                self.logger.info("Building model...")
            start = time.monotonic_ns()
            self.model = self._build_model()
            end = time.monotonic_ns()
            if self.accelerator.is_main_process:
                self.logger.debug(self.model)
                self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
                self.logger.info(
                    f"Model parameters: {self._count_parameters(self.model)/1e6:.2f}M"
                )

        # optimizer & scheduler
        with self.accelerator.main_process_first():
            if self.accelerator.is_main_process:
                self.logger.info("Building optimizer and scheduler...")
            start = time.monotonic_ns()
            self.optimizer = self._build_optimizer()
            self.scheduler = self._build_scheduler()
            end = time.monotonic_ns()
            if self.accelerator.is_main_process:
                self.logger.info(
                    f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
                )

        # accelerate prepare
        if not self.cfg.train.use_dynamic_batchsize:
            if self.accelerator.is_main_process:
                self.logger.info("Initializing accelerate...")
            start = time.monotonic_ns()
            (
                self.train_dataloader,
                self.valid_dataloader,
            ) = self.accelerator.prepare(
                self.train_dataloader,
                self.valid_dataloader,
            )

        if isinstance(self.model, dict):
            for key in self.model.keys():
                self.model[key] = self.accelerator.prepare(self.model[key])
        else:
            self.model = self.accelerator.prepare(self.model)

        if isinstance(self.optimizer, dict):
            for key in self.optimizer.keys():
                self.optimizer[key] = self.accelerator.prepare(self.optimizer[key])
        else:
            self.optimizer = self.accelerator.prepare(self.optimizer)

        if isinstance(self.scheduler, dict):
            for key in self.scheduler.keys():
                self.scheduler[key] = self.accelerator.prepare(self.scheduler[key])
        else:
            self.scheduler = self.accelerator.prepare(self.scheduler)

        end = time.monotonic_ns()
        if self.accelerator.is_main_process:
            self.logger.info(
                f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms"
            )

        # create criterion
        with self.accelerator.main_process_first():
            if self.accelerator.is_main_process:
                self.logger.info("Building criterion...")
            start = time.monotonic_ns()
            self.criterion = self._build_criterion()
            end = time.monotonic_ns()
            if self.accelerator.is_main_process:
                self.logger.info(
                    f"Building criterion done in {(end - start) / 1e6:.2f}ms"
                )

        # TODO: Resume from ckpt need test/debug
        with self.accelerator.main_process_first():
            if args.resume:
                if self.accelerator.is_main_process:
                    self.logger.info("Resuming from checkpoint...")
                start = time.monotonic_ns()
                ckpt_path = self._load_model(
                    self.checkpoint_dir,
                    args.checkpoint_path,
                    resume_type=args.resume_type,
                )
                end = time.monotonic_ns()
                if self.accelerator.is_main_process:
                    self.logger.info(
                        f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
                    )
                self.checkpoints_path = json.load(
                    open(os.path.join(ckpt_path, "ckpts.json"), "r")
                )

            self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
            if self.accelerator.is_main_process:
                os.makedirs(self.checkpoint_dir, exist_ok=True)
            if self.accelerator.is_main_process:
                self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")

        # save config file path
        self.config_save_path = os.path.join(self.exp_dir, "args.json")

        # Only for TTS tasks
        self.task_type = "TTS"
        if self.accelerator.is_main_process:
            self.logger.info("Task type: {}".format(self.task_type))

    def _init_accelerator(self):
        self.exp_dir = os.path.join(
            os.path.abspath(self.cfg.log_dir), self.args.exp_name
        )
        project_config = ProjectConfiguration(
            project_dir=self.exp_dir,
            logging_dir=os.path.join(self.exp_dir, "log"),
        )
        # ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        self.accelerator = accelerate.Accelerator(
            gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
            log_with=self.cfg.train.tracker,
            project_config=project_config,
            # kwargs_handlers=[ddp_kwargs]
        )
        if self.accelerator.is_main_process:
            os.makedirs(project_config.project_dir, exist_ok=True)
            os.makedirs(project_config.logging_dir, exist_ok=True)
        with self.accelerator.main_process_first():
            self.accelerator.init_trackers(self.args.exp_name)

    def _build_model(self):
        model = NaturalSpeech2(cfg=self.cfg.model)
        return model

    def _build_dataset(self):
        return NS2Dataset, NS2Collator

    def _build_dataloader(self):
        if self.cfg.train.use_dynamic_batchsize:
            print("Use Dynamic Batchsize......")
            Dataset, Collator = self._build_dataset()
            train_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=False)
            train_collate = Collator(self.cfg)
            batch_sampler = batch_by_size(
                train_dataset.num_frame_indices,
                train_dataset.get_num_frames,
                max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
                max_sentences=self.cfg.train.max_sentences
                * self.accelerator.num_processes,
                required_batch_size_multiple=self.accelerator.num_processes,
            )
            np.random.seed(980205)
            np.random.shuffle(batch_sampler)
            print(batch_sampler[:1])
            batches = [
                x[
                    self.accelerator.local_process_index :: self.accelerator.num_processes
                ]
                for x in batch_sampler
                if len(x) % self.accelerator.num_processes == 0
            ]

            train_loader = DataLoader(
                train_dataset,
                collate_fn=train_collate,
                num_workers=self.cfg.train.dataloader.num_worker,
                batch_sampler=VariableSampler(
                    batches, drop_last=False, use_random_sampler=True
                ),
                pin_memory=self.cfg.train.dataloader.pin_memory,
            )
            self.accelerator.wait_for_everyone()

            valid_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=True)
            valid_collate = Collator(self.cfg)
            batch_sampler = batch_by_size(
                valid_dataset.num_frame_indices,
                valid_dataset.get_num_frames,
                max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
                max_sentences=self.cfg.train.max_sentences
                * self.accelerator.num_processes,
                required_batch_size_multiple=self.accelerator.num_processes,
            )
            batches = [
                x[
                    self.accelerator.local_process_index :: self.accelerator.num_processes
                ]
                for x in batch_sampler
                if len(x) % self.accelerator.num_processes == 0
            ]
            valid_loader = DataLoader(
                valid_dataset,
                collate_fn=valid_collate,
                num_workers=self.cfg.train.dataloader.num_worker,
                batch_sampler=VariableSampler(batches, drop_last=False),
                pin_memory=self.cfg.train.dataloader.pin_memory,
            )
            self.accelerator.wait_for_everyone()

        else:
            print("Use Normal Batchsize......")
            Dataset, Collator = self._build_dataset()
            train_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=False)
            train_collate = Collator(self.cfg)

            train_loader = DataLoader(
                train_dataset,
                shuffle=True,
                collate_fn=train_collate,
                batch_size=self.cfg.train.batch_size,
                num_workers=self.cfg.train.dataloader.num_worker,
                pin_memory=self.cfg.train.dataloader.pin_memory,
            )

            valid_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=True)
            valid_collate = Collator(self.cfg)

            valid_loader = DataLoader(
                valid_dataset,
                shuffle=True,
                collate_fn=valid_collate,
                batch_size=self.cfg.train.batch_size,
                num_workers=self.cfg.train.dataloader.num_worker,
                pin_memory=self.cfg.train.dataloader.pin_memory,
            )
            self.accelerator.wait_for_everyone()

        return train_loader, valid_loader

    def _build_optimizer(self):
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.model.parameters()),
            **self.cfg.train.adam,
        )
        return optimizer

    def _build_scheduler(self):
        lr_scheduler = get_scheduler(
            self.cfg.train.lr_scheduler,
            optimizer=self.optimizer,
            num_warmup_steps=self.cfg.train.lr_warmup_steps,
            num_training_steps=self.cfg.train.num_train_steps,
        )
        return lr_scheduler

    def _build_criterion(self):
        criterion = torch.nn.L1Loss(reduction="mean")
        return criterion

    def write_summary(self, losses, stats):
        for key, value in losses.items():
            self.sw.add_scalar(key, value, self.step)

    def write_valid_summary(self, losses, stats):
        for key, value in losses.items():
            self.sw.add_scalar(key, value, self.step)

    def get_state_dict(self):
        state_dict = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "step": self.step,
            "epoch": self.epoch,
            "batch_size": self.cfg.train.batch_size,
        }
        return state_dict

    def load_model(self, checkpoint):
        self.step = checkpoint["step"]
        self.epoch = checkpoint["epoch"]

        self.model.load_state_dict(checkpoint["model"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.scheduler.load_state_dict(checkpoint["scheduler"])

    def _train_step(self, batch):
        train_losses = {}
        total_loss = 0
        train_stats = {}

        code = batch["code"]  # (B, 16, T)
        pitch = batch["pitch"]  # (B, T)
        duration = batch["duration"]  # (B, N)
        phone_id = batch["phone_id"]  # (B, N)
        ref_code = batch["ref_code"]  # (B, 16, T')
        phone_mask = batch["phone_mask"]  # (B, N)
        mask = batch["mask"]  # (B, T)
        ref_mask = batch["ref_mask"]  # (B, T')

        diff_out, prior_out = self.model(
            code=code,
            pitch=pitch,
            duration=duration,
            phone_id=phone_id,
            ref_code=ref_code,
            phone_mask=phone_mask,
            mask=mask,
            ref_mask=ref_mask,
        )

        # pitch loss
        pitch_loss = log_pitch_loss(prior_out["pitch_pred_log"], pitch, mask=mask)
        total_loss += pitch_loss
        train_losses["pitch_loss"] = pitch_loss

        # duration loss
        dur_loss = log_dur_loss(prior_out["dur_pred_log"], duration, mask=phone_mask)
        total_loss += dur_loss
        train_losses["dur_loss"] = dur_loss

        x0 = self.model.module.code_to_latent(code)
        if self.cfg.model.diffusion.diffusion_type == "diffusion":
            # diff loss x0
            diff_loss_x0 = diff_loss(diff_out["x0_pred"], x0, mask=mask)
            total_loss += diff_loss_x0
            train_losses["diff_loss_x0"] = diff_loss_x0

            # diff loss noise
            diff_loss_noise = diff_loss(
                diff_out["noise_pred"], diff_out["noise"], mask=mask
            )
            total_loss += diff_loss_noise * self.cfg.train.diff_noise_loss_lambda
            train_losses["diff_loss_noise"] = diff_loss_noise

        elif self.cfg.model.diffusion.diffusion_type == "flow":
            # diff flow matching loss
            flow_gt = diff_out["noise"] - x0
            diff_loss_flow = diff_loss(diff_out["flow_pred"], flow_gt, mask=mask)
            total_loss += diff_loss_flow
            train_losses["diff_loss_flow"] = diff_loss_flow

        # diff loss ce

        # (nq, B, T); (nq, B, T, 1024)
        if self.cfg.train.diff_ce_loss_lambda > 0:
            pred_indices, pred_dist = self.model.module.latent_to_code(
                diff_out["x0_pred"], nq=code.shape[1]
            )
            gt_indices, _ = self.model.module.latent_to_code(x0, nq=code.shape[1])
            diff_loss_ce = diff_ce_loss(pred_dist, gt_indices, mask=mask)
            total_loss += diff_loss_ce * self.cfg.train.diff_ce_loss_lambda
            train_losses["diff_loss_ce"] = diff_loss_ce

        self.optimizer.zero_grad()
        # total_loss.backward()
        self.accelerator.backward(total_loss)
        if self.accelerator.sync_gradients:
            self.accelerator.clip_grad_norm_(
                filter(lambda p: p.requires_grad, self.model.parameters()), 0.5
            )
        self.optimizer.step()
        self.scheduler.step()

        for item in train_losses:
            train_losses[item] = train_losses[item].item()

        if self.cfg.train.diff_ce_loss_lambda > 0:
            pred_indices_list = pred_indices.long().detach().cpu().numpy()
            gt_indices_list = gt_indices.long().detach().cpu().numpy()
            mask_list = batch["mask"].detach().cpu().numpy()

            for i in range(pred_indices_list.shape[0]):
                pred_acc = np.sum(
                    (pred_indices_list[i] == gt_indices_list[i]) * mask_list
                ) / np.sum(mask_list)
                train_losses["pred_acc_{}".format(str(i))] = pred_acc

        train_losses["batch_size"] = code.shape[0]
        train_losses["max_frame_nums"] = np.max(
            batch["frame_nums"].detach().cpu().numpy()
        )

        return (total_loss.item(), train_losses, train_stats)

    @torch.inference_mode()
    def _valid_step(self, batch):
        valid_losses = {}
        total_loss = 0
        valid_stats = {}

        code = batch["code"]  # (B, 16, T)
        pitch = batch["pitch"]  # (B, T)
        duration = batch["duration"]  # (B, N)
        phone_id = batch["phone_id"]  # (B, N)
        ref_code = batch["ref_code"]  # (B, 16, T')
        phone_mask = batch["phone_mask"]  # (B, N)
        mask = batch["mask"]  # (B, T)
        ref_mask = batch["ref_mask"]  # (B, T')

        diff_out, prior_out = self.model(
            code=code,
            pitch=pitch,
            duration=duration,
            phone_id=phone_id,
            ref_code=ref_code,
            phone_mask=phone_mask,
            mask=mask,
            ref_mask=ref_mask,
        )

        # pitch loss
        pitch_loss = log_pitch_loss(prior_out["pitch_pred_log"], pitch, mask=mask)
        total_loss += pitch_loss
        valid_losses["pitch_loss"] = pitch_loss

        # duration loss
        dur_loss = log_dur_loss(prior_out["dur_pred_log"], duration, mask=phone_mask)
        total_loss += dur_loss
        valid_losses["dur_loss"] = dur_loss

        x0 = self.model.module.code_to_latent(code)
        if self.cfg.model.diffusion.diffusion_type == "diffusion":
            # diff loss x0
            diff_loss_x0 = diff_loss(diff_out["x0_pred"], x0, mask=mask)
            total_loss += diff_loss_x0
            valid_losses["diff_loss_x0"] = diff_loss_x0

            # diff loss noise
            diff_loss_noise = diff_loss(
                diff_out["noise_pred"], diff_out["noise"], mask=mask
            )
            total_loss += diff_loss_noise * self.cfg.train.diff_noise_loss_lambda
            valid_losses["diff_loss_noise"] = diff_loss_noise

        elif self.cfg.model.diffusion.diffusion_type == "flow":
            # diff flow matching loss
            flow_gt = diff_out["noise"] - x0
            diff_loss_flow = diff_loss(diff_out["flow_pred"], flow_gt, mask=mask)
            total_loss += diff_loss_flow
            valid_losses["diff_loss_flow"] = diff_loss_flow

        # diff loss ce

        # (nq, B, T); (nq, B, T, 1024)
        if self.cfg.train.diff_ce_loss_lambda > 0:
            pred_indices, pred_dist = self.model.module.latent_to_code(
                diff_out["x0_pred"], nq=code.shape[1]
            )
            gt_indices, _ = self.model.module.latent_to_code(x0, nq=code.shape[1])
            diff_loss_ce = diff_ce_loss(pred_dist, gt_indices, mask=mask)
            total_loss += diff_loss_ce * self.cfg.train.diff_ce_loss_lambda
            valid_losses["diff_loss_ce"] = diff_loss_ce

        for item in valid_losses:
            valid_losses[item] = valid_losses[item].item()

        if self.cfg.train.diff_ce_loss_lambda > 0:
            pred_indices_list = pred_indices.long().detach().cpu().numpy()
            gt_indices_list = gt_indices.long().detach().cpu().numpy()
            mask_list = batch["mask"].detach().cpu().numpy()

            for i in range(pred_indices_list.shape[0]):
                pred_acc = np.sum(
                    (pred_indices_list[i] == gt_indices_list[i]) * mask_list
                ) / np.sum(mask_list)
                valid_losses["pred_acc_{}".format(str(i))] = pred_acc

        return (total_loss.item(), valid_losses, valid_stats)

    @torch.inference_mode()
    def _valid_epoch(self):
        r"""Testing epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.
        """
        if isinstance(self.model, dict):
            for key in self.model.keys():
                self.model[key].eval()
        else:
            self.model.eval()

        epoch_sum_loss = 0.0
        epoch_losses = dict()

        for batch in self.valid_dataloader:
            # Put the data to cuda device
            device = self.accelerator.device
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to(device)

            total_loss, valid_losses, valid_stats = self._valid_step(batch)
            epoch_sum_loss = total_loss
            for key, value in valid_losses.items():
                epoch_losses[key] = value

        self.accelerator.wait_for_everyone()

        return epoch_sum_loss, epoch_losses

    def _train_epoch(self):
        r"""Training epoch. Should return average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.
        """
        if isinstance(self.model, dict):
            for key in self.model.keys():
                self.model[key].train()
        else:
            self.model.train()

        epoch_sum_loss: float = 0.0
        epoch_losses: dict = {}
        epoch_step: int = 0

        for batch in self.train_dataloader:
            # Put the data to cuda device
            device = self.accelerator.device
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to(device)

            # Do training step and BP
            with self.accelerator.accumulate(self.model):
                total_loss, train_losses, training_stats = self._train_step(batch)
            self.batch_count += 1

            # Update info for each step
            # TODO: step means BP counts or batch counts?
            if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
                epoch_sum_loss = total_loss
                for key, value in train_losses.items():
                    epoch_losses[key] = value

                if isinstance(train_losses, dict):
                    for key, loss in train_losses.items():
                        self.accelerator.log(
                            {"Epoch/Train {} Loss".format(key): loss},
                            step=self.step,
                        )

                if (
                    self.accelerator.is_main_process
                    and self.batch_count
                    % (1 * self.cfg.train.gradient_accumulation_step)
                    == 0
                ):
                    self.echo_log(train_losses, mode="Training")

                self.step += 1
                epoch_step += 1

        self.accelerator.wait_for_everyone()

        return epoch_sum_loss, epoch_losses

    def train_loop(self):
        r"""Training loop. The public entry of training process."""
        # Wait everyone to prepare before we move on
        self.accelerator.wait_for_everyone()
        # dump config file
        if self.accelerator.is_main_process:
            self._dump_cfg(self.config_save_path)

        # self.optimizer.zero_grad()

        # Wait to ensure good to go
        self.accelerator.wait_for_everyone()
        while self.epoch < self.max_epoch:
            if self.accelerator.is_main_process:
                self.logger.info("\n")
                self.logger.info("-" * 32)
                self.logger.info("Epoch {}: ".format(self.epoch))

            # Do training & validating epoch
            train_total_loss, train_losses = self._train_epoch()
            if isinstance(train_losses, dict):
                for key, loss in train_losses.items():
                    if self.accelerator.is_main_process:
                        self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss))
                    self.accelerator.log(
                        {"Epoch/Train {} Loss".format(key): loss},
                        step=self.epoch,
                    )

            valid_total_loss, valid_losses = self._valid_epoch()
            if isinstance(valid_losses, dict):
                for key, loss in valid_losses.items():
                    if self.accelerator.is_main_process:
                        self.logger.info(" |- Valid/{} Loss: {:.6f}".format(key, loss))
                    self.accelerator.log(
                        {"Epoch/Train {} Loss".format(key): loss},
                        step=self.epoch,
                    )

            if self.accelerator.is_main_process:
                self.logger.info(" |- Train/Loss: {:.6f}".format(train_total_loss))
                self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_total_loss))
            self.accelerator.log(
                {
                    "Epoch/Train Loss": train_total_loss,
                    "Epoch/Valid Loss": valid_total_loss,
                },
                step=self.epoch,
            )

            self.accelerator.wait_for_everyone()
            if isinstance(self.scheduler, dict):
                for key in self.scheduler.keys():
                    self.scheduler[key].step()
            else:
                self.scheduler.step()

            # Check if hit save_checkpoint_stride and run_eval
            run_eval = False
            if self.accelerator.is_main_process:
                save_checkpoint = False
                hit_dix = []
                for i, num in enumerate(self.save_checkpoint_stride):
                    if self.epoch % num == 0:
                        save_checkpoint = True
                        hit_dix.append(i)
                        run_eval |= self.run_eval[i]

            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process and save_checkpoint:
                path = os.path.join(
                    self.checkpoint_dir,
                    "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                        self.epoch, self.step, train_total_loss
                    ),
                )
                print("save state......")
                self.accelerator.save_state(path)
                print("finish saving state......")
                json.dump(
                    self.checkpoints_path,
                    open(os.path.join(path, "ckpts.json"), "w"),
                    ensure_ascii=False,
                    indent=4,
                )
                # Remove old checkpoints
                to_remove = []
                for idx in hit_dix:
                    self.checkpoints_path[idx].append(path)
                    while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
                        to_remove.append((idx, self.checkpoints_path[idx].pop(0)))

                # Search conflicts
                total = set()
                for i in self.checkpoints_path:
                    total |= set(i)
                do_remove = set()
                for idx, path in to_remove[::-1]:
                    if path in total:
                        self.checkpoints_path[idx].insert(0, path)
                    else:
                        do_remove.add(path)

                # Remove old checkpoints
                for path in do_remove:
                    shutil.rmtree(path, ignore_errors=True)
                    if self.accelerator.is_main_process:
                        self.logger.debug(f"Remove old checkpoint: {path}")

            self.accelerator.wait_for_everyone()
            if run_eval:
                # TODO: run evaluation
                pass

            # Update info for each epoch
            self.epoch += 1

        # Finish training and save final checkpoint
        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            self.accelerator.save_state(
                os.path.join(
                    self.checkpoint_dir,
                    "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
                        self.epoch, self.step, valid_total_loss
                    ),
                )
            )
        self.accelerator.end_training()
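The dynamic-batchsize path in _build_dataloader above builds one global list of variable-size batches with batch_by_size, then hands each process every num_processes-th element of every batch. A self-contained sketch of just that sharding rule, with invented batch contents:

num_processes = 2
batches = [[0, 1], [2, 3, 4, 5], [6, 7, 8], [9, 10]]

# Keep only batches whose size divides evenly across processes, then stride:
# rank r takes elements r, r + num_processes, r + 2 * num_processes, ...
shards = [
    [x[rank::num_processes] for x in batches if len(x) % num_processes == 0]
    for rank in range(num_processes)
]
print(shards[0])  # [[0], [2, 4], [9]]   -> rank 0's batches
print(shards[1])  # [[1], [3, 5], [10]]  -> rank 1's batches; [6, 7, 8] is dropped

Dropping the non-divisible batch keeps every rank on the same number of steps, which appears to be what the wait_for_everyone barriers in the trainer rely on.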
Amphion/models/tts/valle/__init__.py
ADDED
File without changes

Amphion/models/vocoders/autoregressive/autoregressive_vocoder_inference.py
ADDED
File without changes

Amphion/models/vocoders/autoregressive/wavenet/conv.py
ADDED
@@ -0,0 +1,66 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn
from torch.nn import functional as F


class Conv1d(nn.Conv1d):
    """Extended nn.Conv1d for incremental dilated convolutions"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.clear_buffer()
        self._linearized_weight = None
        self.register_backward_hook(self._clear_linearized_weight)

    def incremental_forward(self, input):
        # input (B, T, C)
        # run forward pre hooks
        for hook in self._forward_pre_hooks.values():
            hook(self, input)

        # reshape weight
        weight = self._get_linearized_weight()
        kw = self.kernel_size[0]
        dilation = self.dilation[0]

        bsz = input.size(0)
        if kw > 1:
            input = input.data
            if self.input_buffer is None:
                self.input_buffer = input.new(
                    bsz, kw + (kw - 1) * (dilation - 1), input.size(2)
                )
                self.input_buffer.zero_()
            else:
                # shift buffer
                self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
            # append next input
            self.input_buffer[:, -1, :] = input[:, -1, :]
            input = self.input_buffer
            if dilation > 1:
                input = input[:, 0::dilation, :].contiguous()
        output = F.linear(input.view(bsz, -1), weight, self.bias)
        return output.view(bsz, 1, -1)

    def clear_buffer(self):
        self.input_buffer = None

    def _get_linearized_weight(self):
        if self._linearized_weight is None:
            kw = self.kernel_size[0]
            # nn.Conv1d
            if self.weight.size() == (self.out_channels, self.in_channels, kw):
                weight = self.weight.transpose(1, 2).contiguous()
            else:
                # fairseq.modules.conv_tbc.ConvTBC
                weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
            assert weight.size() == (self.out_channels, kw, self.in_channels)
            self._linearized_weight = weight.view(self.out_channels, -1)
        return self._linearized_weight

    def _clear_linearized_weight(self, *args):
        self._linearized_weight = None
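A small usage sketch of the incremental path (all dimensions invented for illustration): during autoregressive generation the layer is fed one (B, 1, C) frame per call and caches its own dilated history, so each step reduces to a single matrix multiply against the pre-linearized weight instead of a full convolution over the sequence.

import torch

conv = Conv1d(in_channels=4, out_channels=8, kernel_size=3, dilation=2)
conv.eval()

with torch.no_grad():
    for t in range(10):
        frame = torch.randn(1, 1, 4)           # (B, T=1, C_in), newest sample only
        out = conv.incremental_forward(frame)  # (B, 1, C_out), uses cached history
conv.clear_buffer()  # reset the cache before starting a new utterance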
Amphion/models/vocoders/autoregressive/wavenet/wavenet.py
ADDED
@@ -0,0 +1,170 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

from torch import nn
from torch.nn import functional as F

from .modules import Conv1d1x1, ResidualConv1dGLU
from .upsample import ConvInUpsampleNetwork


def receptive_field_size(
    total_layers, num_cycles, kernel_size, dilation=lambda x: 2**x
):
    """Compute receptive field size

    Args:
        total_layers (int): total layers
        num_cycles (int): cycles
        kernel_size (int): kernel size
        dilation (lambda): lambda to compute dilation factor. ``lambda x: 1``
            to disable dilated convolution.

    Returns:
        int: receptive field size in samples

    """
    assert total_layers % num_cycles == 0

    layers_per_cycle = total_layers // num_cycles
    dilations = [dilation(i % layers_per_cycle) for i in range(total_layers)]
    return (kernel_size - 1) * sum(dilations) + 1


class WaveNet(nn.Module):
    """The WaveNet model that supports local and global conditioning.

    Args:
        out_channels (int): Output channels. If the input type is a mu-law
            quantized one-hot vector, this must equal the number of quantize
            channels; otherwise it is num_mixtures x 3 (pi, mu, log_scale).
        layers (int): Number of total layers.
        stacks (int): Number of dilation cycles.
        residual_channels (int): Residual input / output channels.
        gate_channels (int): Gated activation channels.
        skip_out_channels (int): Skip connection channels.
        kernel_size (int): Kernel size of convolution layers.
        dropout (float): Dropout probability.
        input_dim (int): Number of mel-spec dimensions.
        upsample_scales (list): List of upsample scales.
            ``np.prod(upsample_scales)`` must equal the hop size. Used only if
            upsample_conditional_features is enabled.
        freq_axis_kernel_size (int): Freq-axis kernel_size for the transposed
            convolution layers used for upsampling. If you only care about
            time-axis upsampling, set this to 1.
        scalar_input (bool): If True, scalar input ([-1, 1]) is expected;
            otherwise a quantized one-hot vector is expected.
    """

    def __init__(self, cfg):
        super(WaveNet, self).__init__()
        self.cfg = cfg
        self.scalar_input = self.cfg.VOCODER.SCALAR_INPUT
        self.out_channels = self.cfg.VOCODER.OUT_CHANNELS
        self.cin_channels = self.cfg.VOCODER.INPUT_DIM
        self.residual_channels = self.cfg.VOCODER.RESIDUAL_CHANNELS
        self.layers = self.cfg.VOCODER.LAYERS
        self.stacks = self.cfg.VOCODER.STACKS
        self.gate_channels = self.cfg.VOCODER.GATE_CHANNELS
        self.kernel_size = self.cfg.VOCODER.KERNEL_SIZE
        self.skip_out_channels = self.cfg.VOCODER.SKIP_OUT_CHANNELS
        self.dropout = self.cfg.VOCODER.DROPOUT
        self.upsample_scales = self.cfg.VOCODER.UPSAMPLE_SCALES
        self.mel_frame_pad = self.cfg.VOCODER.MEL_FRAME_PAD

        assert self.layers % self.stacks == 0

        layers_per_stack = self.layers // self.stacks
        if self.scalar_input:
            self.first_conv = Conv1d1x1(1, self.residual_channels)
        else:
            self.first_conv = Conv1d1x1(self.out_channels, self.residual_channels)

        self.conv_layers = nn.ModuleList()
        for layer in range(self.layers):
            dilation = 2 ** (layer % layers_per_stack)
            conv = ResidualConv1dGLU(
                self.residual_channels,
                self.gate_channels,
                kernel_size=self.kernel_size,
                skip_out_channels=self.skip_out_channels,
                bias=True,
                dilation=dilation,
                dropout=self.dropout,
                cin_channels=self.cin_channels,
            )
            self.conv_layers.append(conv)

        self.last_conv_layers = nn.ModuleList(
            [
                nn.ReLU(inplace=True),
                Conv1d1x1(self.skip_out_channels, self.skip_out_channels),
                nn.ReLU(inplace=True),
                Conv1d1x1(self.skip_out_channels, self.out_channels),
            ]
        )

        self.upsample_net = ConvInUpsampleNetwork(
            upsample_scales=self.upsample_scales,
            cin_pad=self.mel_frame_pad,
            cin_channels=self.cin_channels,
        )

        self.receptive_field = receptive_field_size(
            self.layers, self.stacks, self.kernel_size
        )

    def forward(self, x, mel, softmax=False):
        """Forward step

        Args:
            x (Tensor): One-hot encoded audio signal, shape (B x C x T)
            mel (Tensor): Local conditioning features,
                shape (B x cin_channels x T)
            softmax (bool): Whether to apply softmax.

        Returns:
            Tensor: output, shape B x out_channels x T
        """
        B, _, T = x.size()

        mel = self.upsample_net(mel)
        assert mel.shape[-1] == x.shape[-1]

        x = self.first_conv(x)
        skips = 0
        for f in self.conv_layers:
            x, h = f(x, mel)
            skips += h
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        x = skips
        for f in self.last_conv_layers:
            x = f(x)

        x = F.softmax(x, dim=1) if softmax else x

        return x

    def clear_buffer(self):
        self.first_conv.clear_buffer()
        for f in self.conv_layers:
            f.clear_buffer()
        for f in self.last_conv_layers:
            try:
                f.clear_buffer()
            except AttributeError:
                pass

    def make_generation_fast_(self):
        def remove_weight_norm(m):
            try:
                nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(remove_weight_norm)
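A quick worked example of receptive_field_size, with numbers chosen for illustration rather than taken from this repo's configs (run in the module above so the function is in scope):

# layers=24, stacks=4 -> 6 layers per cycle with dilations [1, 2, 4, 8, 16, 32].
rf = receptive_field_size(total_layers=24, num_cycles=4, kernel_size=3)
# sum(dilations) = 4 * (1 + 2 + 4 + 8 + 16 + 32) = 252,
# so rf = (3 - 1) * 252 + 1 = 505 samples.
assert rf == 505

Separately, forward asserts that the upsampled conditioning matches the waveform length, so the product of cfg.VOCODER.UPSAMPLE_SCALES must equal the hop size: for example, a (B, num_mels, 40) mel at hop 256 must pair with a 10240-sample signal.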
Amphion/models/vocoders/diffusion/diffusion_vocoder_inference.py
ADDED
@@ -0,0 +1,131 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import numpy as np

from tqdm import tqdm
from utils.util import pad_mels_to_tensors, pad_f0_to_tensors


def vocoder_inference(cfg, model, mels, f0s=None, device=None, fast_inference=False):
    """Run the vocoder
    Args:
        mels: A tensor of mel-specs with the shape (batch_size, num_mels, frames)
    Returns:
        audios: A tensor of audios with the shape (batch_size, seq_len)
    """
    model.eval()

    with torch.no_grad():
        training_noise_schedule = np.array(cfg.model.diffwave.noise_schedule)
        inference_noise_schedule = (
            np.array(cfg.model.diffwave.inference_noise_schedule)
            if fast_inference
            else np.array(cfg.model.diffwave.noise_schedule)
        )

        talpha = 1 - training_noise_schedule
        talpha_cum = np.cumprod(talpha)

        beta = inference_noise_schedule
        alpha = 1 - beta
        alpha_cum = np.cumprod(alpha)

        T = []
        for s in range(len(inference_noise_schedule)):
            for t in range(len(training_noise_schedule) - 1):
                if talpha_cum[t + 1] <= alpha_cum[s] <= talpha_cum[t]:
                    twiddle = (talpha_cum[t] ** 0.5 - alpha_cum[s] ** 0.5) / (
                        talpha_cum[t] ** 0.5 - talpha_cum[t + 1] ** 0.5
                    )
                    T.append(t + twiddle)
                    break
        T = np.array(T, dtype=np.float32)

        mels = mels.to(device)
        audio = torch.randn(
            mels.shape[0],
            cfg.preprocess.hop_size * mels.shape[-1],
            device=device,
        )

        for n in tqdm(range(len(alpha) - 1, -1, -1)):
            c1 = 1 / alpha[n] ** 0.5
            c2 = beta[n] / (1 - alpha_cum[n]) ** 0.5
            audio = c1 * (
                audio
                - c2
                * model(audio, torch.tensor([T[n]], device=audio.device), mels).squeeze(
                    1
                )
            )
            if n > 0:
                noise = torch.randn_like(audio)
                sigma = (
                    (1.0 - alpha_cum[n - 1]) / (1.0 - alpha_cum[n]) * beta[n]
                ) ** 0.5
                audio += sigma * noise
            audio = torch.clamp(audio, -1.0, 1.0)

    return audio.detach().cpu()


def synthesis_audios(cfg, model, mels, f0s=None, batch_size=None, fast_inference=False):
    """Run the vocoder over a list of mel-specs
    Args:
        mels: A list of mel-specs
    Returns:
        audios: A list of audios
    """
    # Get the device
    device = next(model.parameters()).device

    audios = []

    # Pad the given list into tensors
    mel_batches, mel_frames = pad_mels_to_tensors(mels, batch_size)
    if f0s is not None:
        f0_batches = pad_f0_to_tensors(f0s, batch_size)

    if f0s is None:
        for mel_batch, mel_frame in zip(mel_batches, mel_frames):
            for i in range(mel_batch.shape[0]):
                mel = mel_batch[i]
                frame = mel_frame[i]
                audio = vocoder_inference(
                    cfg,
                    model,
                    mel.unsqueeze(0),
                    device=device,
                    fast_inference=fast_inference,
                ).squeeze(0)

                # calculate the audio length
                audio_length = frame * cfg.preprocess.hop_size
                audio = audio[:audio_length]

                audios.append(audio)
    else:
        for mel_batch, f0_batch, mel_frame in zip(mel_batches, f0_batches, mel_frames):
            for i in range(mel_batch.shape[0]):
                mel = mel_batch[i]
                f0 = f0_batch[i]
                frame = mel_frame[i]
                audio = vocoder_inference(
                    cfg,
                    model,
                    mel.unsqueeze(0),
                    f0s=f0.unsqueeze(0),
                    device=device,
                    fast_inference=fast_inference,
                ).squeeze(0)

                # calculate the audio length
                audio_length = frame * cfg.preprocess.hop_size
                audio = audio[:audio_length]

                audios.append(audio)
    return audios
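The subtle step above is mapping each entry of the short fast-inference schedule back onto the training schedule's time axis: for inference step s it finds the training step t whose cumulative signal level brackets alpha_cum[s], then interpolates between t and t + 1 by the square-root ("twiddle") ratio. The same logic as a standalone sketch, with toy schedules rather than the config's real values:

import numpy as np

training = np.linspace(1e-4, 0.05, 50)                 # toy training noise schedule
inference = np.array([0.001, 0.01, 0.05, 0.2, 0.5])    # toy fast schedule

talpha_cum = np.cumprod(1 - training)
alpha_cum = np.cumprod(1 - inference)

T = []
for s in range(len(inference)):
    for t in range(len(training) - 1):
        if talpha_cum[t + 1] <= alpha_cum[s] <= talpha_cum[t]:
            twiddle = (talpha_cum[t] ** 0.5 - alpha_cum[s] ** 0.5) / (
                talpha_cum[t] ** 0.5 - talpha_cum[t + 1] ** 0.5
            )
            T.append(t + twiddle)  # fractional training timestep for fast step s
            break
# T now holds one fractional timestep per fast step; these are the conditioning
# values the denoiser receives in place of integer training timesteps.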
Amphion/models/vocoders/flow/flow_vocoder_dataset.py
ADDED
File without changes

Amphion/models/vocoders/flow/flow_vocoder_inference.py
ADDED
File without changes

Amphion/models/vocoders/gan/discriminator/msd.py
ADDED
@@ -0,0 +1,88 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, AvgPool1d
from torch.nn.utils import weight_norm, spectral_norm
from modules.vocoder_blocks import *


LRELU_SLOPE = 0.1


class DiscriminatorS(nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()

        norm_f = weight_norm if not use_spectral_norm else spectral_norm

        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 128, 15, 1, padding=7)),
                norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
                norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
                norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )

        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)

        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiScaleDiscriminator(nn.Module):
    def __init__(self, cfg):
        super(MultiScaleDiscriminator, self).__init__()

        self.cfg = cfg

        self.discriminators = nn.ModuleList(
            [
                DiscriminatorS(use_spectral_norm=True),
                DiscriminatorS(),
                DiscriminatorS(),
            ]
        )

        self.meanpools = nn.ModuleList(
            [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
        )

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i - 1](y)
                y_hat = self.meanpools[i - 1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
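A shape walk-through (batch size and length invented; this assumes the repo's modules package is importable, since modules.vocoder_blocks is star-imported above): three identical sub-discriminators score the waveform at 1x, 2x, and 4x average-pooled resolutions, each returning flattened logits plus its eight intermediate feature maps (seven strided conv blocks plus conv_post) for feature-matching losses.

import torch

msd = MultiScaleDiscriminator(cfg=None)  # cfg is stored but unused by the layers
y = torch.randn(2, 1, 8192)       # real waveform       (B, 1, T)
y_hat = torch.randn(2, 1, 8192)   # generated waveform  (B, 1, T)
y_d_rs, y_d_gs, fmap_rs, fmap_gs = msd(y, y_hat)
assert len(y_d_rs) == 3           # one flattened logit tensor per scale
assert len(fmap_rs[0]) == 8       # 7 conv blocks + conv_post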
Amphion/models/vocoders/gan/gan_vocoder_inference.py
ADDED
@@ -0,0 +1,96 @@
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from utils.util import pad_mels_to_tensors, pad_f0_to_tensors
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def vocoder_inference(cfg, model, mels, f0s=None, device=None, fast_inference=False):
|
| 12 |
+
"""Inference the vocoder
|
| 13 |
+
Args:
|
| 14 |
+
mels: A tensor of mel-specs with the shape (batch_size, num_mels, frames)
|
| 15 |
+
Returns:
|
| 16 |
+
audios: A tensor of audios with the shape (batch_size, seq_len)
|
| 17 |
+
"""
|
| 18 |
+
model.eval()
|
| 19 |
+
|
| 20 |
+
with torch.no_grad():
|
| 21 |
+
mels = mels.to(device)
|
| 22 |
+
if f0s != None:
|
| 23 |
+
f0s = f0s.to(device)
|
| 24 |
+
|
| 25 |
+
if f0s == None and not cfg.preprocess.extract_amplitude_phase:
|
| 26 |
+
output = model.forward(mels)
|
| 27 |
+
elif cfg.preprocess.extract_amplitude_phase:
|
| 28 |
+
(
|
| 29 |
+
_,
|
| 30 |
+
_,
|
| 31 |
+
_,
|
| 32 |
+
_,
|
| 33 |
+
output,
|
| 34 |
+
) = model.forward(mels)
|
| 35 |
+
else:
|
| 36 |
+
output = model.forward(mels, f0s)
|
| 37 |
+
|
| 38 |
+
return output.squeeze(1).detach().cpu()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def synthesis_audios(cfg, model, mels, f0s=None, batch_size=None, fast_inference=False):
|
| 42 |
+
"""Inference the vocoder
|
| 43 |
+
Args:
|
| 44 |
+
mels: A list of mel-specs
|
| 45 |
+
Returns:
|
| 46 |
+
audios: A list of audios
|
| 47 |
+
"""
|
| 48 |
+
# Get the device
|
| 49 |
+
device = next(model.parameters()).device
|
| 50 |
+
|
| 51 |
+
audios = []
|
| 52 |
+
|
| 53 |
+
# Pad the given list into tensors
|
| 54 |
+
mel_batches, mel_frames = pad_mels_to_tensors(mels, batch_size)
|
| 55 |
+
if f0s != None:
|
| 56 |
+
f0_batches = pad_f0_to_tensors(f0s, batch_size)
|
| 57 |
+
|
| 58 |
+
if f0s == None:
|
| 59 |
+
for mel_batch, mel_frame in zip(mel_batches, mel_frames):
|
| 60 |
+
for i in range(mel_batch.shape[0]):
|
| 61 |
+
mel = mel_batch[i]
|
| 62 |
+
frame = mel_frame[i]
|
| 63 |
+
audio = vocoder_inference(
|
| 64 |
+
cfg,
|
| 65 |
+
model,
|
| 66 |
+
mel.unsqueeze(0),
|
| 67 |
+
device=device,
|
| 68 |
+
fast_inference=fast_inference,
|
| 69 |
+
).squeeze(0)
|
| 70 |
+
|
| 71 |
+
# calculate the audio length
|
| 72 |
+
audio_length = frame * model.cfg.preprocess.hop_size
|
| 73 |
+
audio = audio[:audio_length]
|
| 74 |
+
|
| 75 |
+
audios.append(audio)
|
| 76 |
+
else:
|
| 77 |
+
for mel_batch, f0_batch, mel_frame in zip(mel_batches, f0_batches, mel_frames):
|
| 78 |
+
for i in range(mel_batch.shape[0]):
|
| 79 |
+
mel = mel_batch[i]
|
| 80 |
+
f0 = f0_batch[i]
|
| 81 |
+
frame = mel_frame[i]
|
| 82 |
+
audio = vocoder_inference(
|
| 83 |
+
cfg,
|
| 84 |
+
model,
|
| 85 |
+
mel.unsqueeze(0),
|
| 86 |
+
f0s=f0.unsqueeze(0),
|
| 87 |
+
device=device,
|
| 88 |
+
fast_inference=fast_inference,
|
| 89 |
+
).squeeze(0)
|
| 90 |
+
|
| 91 |
+
# calculate the audio length
|
| 92 |
+
audio_length = frame * model.cfg.preprocess.hop_size
|
| 93 |
+
audio = audio[:audio_length]
|
| 94 |
+
|
| 95 |
+
audios.append(audio)
|
| 96 |
+
return audios
|
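A hedged usage sketch of `synthesis_audios` follows. The file paths, the loaded `model`, and the `cfg` object are assumptions; their schema comes from the egs recipes and checkpoint-loading code, not from this file.

import numpy as np
import torch

# Hypothetical driver (not part of the repo): `model` and `cfg` are assumed
# to come from the usual Amphion checkpoint-loading path; file names are
# placeholders for preprocessed mel .npy files, each of shape [n_mels, T_i].
mels = [torch.from_numpy(np.load(p)) for p in ["utt1.npy", "utt2.npy"]]
audios = synthesis_audios(cfg, model, mels, batch_size=16)
for audio in audios:
    print(audio.shape)  # 1-D waveform of length T_i * hop_size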
Amphion/models/vocoders/vocoder_dataset.py
ADDED
@@ -0,0 +1,264 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import json
from typing import Iterable

import torch
import numpy as np
import torch.utils.data
from torch.nn.utils.rnn import pad_sequence
from utils.data_utils import *
from torch.utils.data import ConcatDataset, Dataset


class VocoderDataset(torch.utils.data.Dataset):
    def __init__(self, cfg, dataset, is_valid=False):
        """
        Args:
            cfg: config
            dataset: dataset name
            is_valid: whether to load the valid split (otherwise the train split)
        """
        assert isinstance(dataset, str)

        processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)

        meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
        self.metafile_path = os.path.join(processed_data_dir, meta_file)
        self.metadata = self.get_metadata()

        self.data_root = processed_data_dir
        self.cfg = cfg

        if cfg.preprocess.use_audio:
            self.utt2audio_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2audio_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.audio_dir,
                    uid + ".npy",
                )
        elif cfg.preprocess.use_label:
            self.utt2label_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2label_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.label_dir,
                    uid + ".npy",
                )
        elif cfg.preprocess.use_one_hot:
            self.utt2one_hot_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2one_hot_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.one_hot_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_mel:
            self.utt2mel_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2mel_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.mel_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_frame_pitch:
            self.utt2frame_pitch_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                self.utt2frame_pitch_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.pitch_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_uv:
            self.utt2uv_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)
                self.utt2uv_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.uv_dir,
                    uid + ".npy",
                )

        if cfg.preprocess.use_amplitude_phase:
            self.utt2logamp_path = {}
            self.utt2pha_path = {}
            self.utt2rea_path = {}
            self.utt2imag_path = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)
                self.utt2logamp_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.log_amplitude_dir,
                    uid + ".npy",
                )
                self.utt2pha_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.phase_dir,
                    uid + ".npy",
                )
                self.utt2rea_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.real_dir,
                    uid + ".npy",
                )
                self.utt2imag_path[utt] = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.imaginary_dir,
                    uid + ".npy",
                )

    def __getitem__(self, index):
        utt_info = self.metadata[index]

        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        single_feature = dict()

        if self.cfg.preprocess.use_mel:
            mel = np.load(self.utt2mel_path[utt])
            assert mel.shape[0] == self.cfg.preprocess.n_mel  # [n_mels, T]

            if "target_len" not in single_feature.keys():
                single_feature["target_len"] = mel.shape[1]

            single_feature["mel"] = mel

        if self.cfg.preprocess.use_frame_pitch:
            frame_pitch = np.load(self.utt2frame_pitch_path[utt])

            if "target_len" not in single_feature.keys():
                single_feature["target_len"] = len(frame_pitch)

            aligned_frame_pitch = align_length(
                frame_pitch, single_feature["target_len"]
            )

            single_feature["frame_pitch"] = aligned_frame_pitch

        if self.cfg.preprocess.use_audio:
            audio = np.load(self.utt2audio_path[utt])

            single_feature["audio"] = audio

        return single_feature

    def get_metadata(self):
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        return metadata

    def get_dataset_name(self):
        return self.metadata[0]["Dataset"]

    def __len__(self):
        return len(self.metadata)


class VocoderConcatDataset(ConcatDataset):
    def __init__(self, datasets: Iterable[Dataset], full_audio_inference=False):
        """Concatenate a series of datasets with their random inference audio merged."""
        super().__init__(datasets)

        self.cfg = self.datasets[0].cfg

        self.metadata = []

        # Merge metadata
        for dataset in self.datasets:
            self.metadata += dataset.metadata

        # Merge random inference features
        if full_audio_inference:
            self.eval_audios = []
            self.eval_dataset_names = []
            if self.cfg.preprocess.use_mel:
                self.eval_mels = []
            if self.cfg.preprocess.use_frame_pitch:
                self.eval_pitchs = []
            for dataset in self.datasets:
                self.eval_audios.append(dataset.eval_audio)
                self.eval_dataset_names.append(dataset.get_dataset_name())
                if self.cfg.preprocess.use_mel:
                    self.eval_mels.append(dataset.eval_mel)
                if self.cfg.preprocess.use_frame_pitch:
                    self.eval_pitchs.append(dataset.eval_pitch)


class VocoderCollator(object):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        packed_batch_features = dict()

        # After padding (mel is transposed before pad_sequence):
        # mel: [b, frame, n_mels]
        # frame_pitch: [b, frame]
        # audios: [b, frame * hop_size]

        for key in batch[0].keys():
            if key == "target_len":
                packed_batch_features["target_len"] = torch.LongTensor(
                    [b["target_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "mel":
                values = [torch.from_numpy(b[key]).T for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )
            else:
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )

        return packed_batch_features
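A minimal wiring sketch for the dataset and collator above. The dataset name and `cfg` fields (processed_dir, train_file, use_mel, ...) are assumptions following the preprocess schema referenced in the code; this is not part of the commit.

from torch.utils.data import DataLoader

# Hypothetical wiring (not part of the repo).
train_set = VocoderDataset(cfg, "ljspeech", is_valid=False)  # dataset name is a placeholder
collator = VocoderCollator(cfg)
loader = DataLoader(train_set, batch_size=16, shuffle=True, collate_fn=collator)
batch = next(iter(loader))
print(batch["mel"].shape)       # [B, T_max, n_mels] after padding
print(batch["target_len"][:4])  # true frame counts before padding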
Amphion/models/vocoders/vocoder_sampler.py
ADDED
@@ -0,0 +1,126 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import random

from torch.utils.data import ConcatDataset, Dataset
from torch.utils.data.sampler import (
    BatchSampler,
    RandomSampler,
    Sampler,
    SequentialSampler,
)


class ScheduledSampler(Sampler):
    """A sampler that samples data from a given concat-dataset.

    Args:
        concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
        batch_size (int): batch size
        holistic_shuffle (bool): whether to shuffle the whole dataset or not
        logger (logging.Logger): logger to print warning messages

    Usage:
        For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
        >>> list(ScheduledSampler(ConcatDataset([[0, 1, 2], [3, 4, 5], [6, 7, 8]])))
        [3, 4, 5, 0, 1, 2, 6, 7, 8]
    """

    def __init__(
        self, concat_dataset, batch_size, holistic_shuffle, logger=None, type="train"
    ):
        # NB: the ``type`` parameter shadows the builtin of the same name, so
        # the error messages below read the class via ``__class__`` instead.
        if not isinstance(concat_dataset, ConcatDataset):
            raise ValueError(
                "concat_dataset must be an instance of ConcatDataset, but got {}".format(
                    concat_dataset.__class__
                )
            )
        if not isinstance(batch_size, int):
            raise ValueError(
                "batch_size must be an integer, but got {}".format(batch_size.__class__)
            )
        if not isinstance(holistic_shuffle, bool):
            raise ValueError(
                "holistic_shuffle must be a boolean, but got {}".format(
                    holistic_shuffle.__class__
                )
            )

        self.concat_dataset = concat_dataset
        self.batch_size = batch_size
        self.holistic_shuffle = holistic_shuffle

        affected_dataset_name = []
        affected_dataset_len = []
        for dataset in concat_dataset.datasets:
            dataset_len = len(dataset)
            dataset_name = dataset.get_dataset_name()
            if dataset_len < batch_size:
                affected_dataset_name.append(dataset_name)
                affected_dataset_len.append(dataset_len)

        self.type = type
        for dataset_name, dataset_len in zip(
            affected_dataset_name, affected_dataset_len
        ):
            if type != "valid":
                logger.warning(
                    "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
                        type, dataset_name, dataset_len, batch_size
                    )
                )

    def __len__(self):
        # the number of samples kept when the last incomplete batch is dropped
        num_of_batches = sum(
            [
                math.floor(len(dataset) / self.batch_size)
                for dataset in self.concat_dataset.datasets
            ]
        )
        return num_of_batches * self.batch_size

    def __iter__(self):
        iters = []
        for dataset in self.concat_dataset.datasets:
            iters.append(
                SequentialSampler(dataset).__iter__()
                if self.holistic_shuffle
                else RandomSampler(dataset).__iter__()
            )
        init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
        output_batches = []
        for dataset_idx in range(len(self.concat_dataset.datasets)):
            cur_batch = []
            for idx in iters[dataset_idx]:
                cur_batch.append(idx + init_indices[dataset_idx])
                if len(cur_batch) == self.batch_size:
                    output_batches.append(cur_batch)
                    cur_batch = []
            if self.type == "valid" and len(cur_batch) > 0:
                output_batches.append(cur_batch)
                cur_batch = []
        # force drop last in training
        random.shuffle(output_batches)
        output_indices = [item for sublist in output_batches for item in sublist]
        return iter(output_indices)


def build_samplers(concat_dataset: Dataset, cfg, logger, type):
    sampler = ScheduledSampler(
        concat_dataset,
        cfg.train.batch_size,
        cfg.train.sampler.holistic_shuffle,
        logger,
        type,
    )
    batch_sampler = BatchSampler(
        sampler,
        cfg.train.batch_size,
        cfg.train.sampler.drop_last if not type == "valid" else False,
    )
    return sampler, batch_sampler
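A hedged sketch of how `build_samplers` plugs into a DataLoader. The `full_dataset`, `cfg`, and `logger` objects are assumptions (a VocoderConcatDataset whose members expose `get_dataset_name()`, and a config with `cfg.train.sampler.*` fields).

from torch.utils.data import DataLoader

# Hypothetical wiring (not part of the repo).
sampler, batch_sampler = build_samplers(full_dataset, cfg, logger, "train")
loader = DataLoader(
    full_dataset,
    collate_fn=VocoderCollator(cfg),
    batch_sampler=batch_sampler,  # each batch stays within one sub-dataset
)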
Amphion/modules/activation_functions/snake.py
ADDED
@@ -0,0 +1,122 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn, pow, sin
from torch.nn import Parameter


class Snake(nn.Module):
    r"""Implementation of a sine-based periodic activation function.
    Alpha is initialized to 1 by default; higher values mean a higher frequency.
    It will be trained along with the rest of your model.

    Args:
        in_features: shape of the input
        alpha: trainable parameter

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input

    References:
        This activation function is from this paper by Liu Ziyin, Tilman Hartwig,
        Masahito Ueda: https://arxiv.org/abs/2006.08195

    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(4, 256, 100)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        super(Snake, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        r"""Forward pass of the function. Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2(a * x)
        """

        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x


class SnakeBeta(nn.Module):
    r"""A modified Snake function which uses separate parameters for the magnitude
    of the periodic components. Alpha is initialized to 1 by default;
    higher values mean a higher frequency. Beta is initialized to 1 by default;
    higher values mean a higher magnitude. Both will be trained along with the
    rest of your model.

    Args:
        in_features: shape of the input
        alpha: trainable parameter that controls frequency
        beta: trainable parameter that controls magnitude

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input

    References:
        This activation function is a modified version based on this paper by Liu Ziyin,
        Tilman Hartwig, Masahito Ueda: https://arxiv.org/abs/2006.08195

    Examples:
        >>> a1 = SnakeBeta(256)
        >>> x = torch.randn(4, 256, 100)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha and beta
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        r"""Forward pass of the function. Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2(x * a)
        """

        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
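A minimal standalone sketch of the activation in isolation, following the (B, C, T) shape convention documented above:

import torch

# Sketch (not part of the repo): SnakeBeta preserves shape and acts per channel.
act = SnakeBeta(in_features=64, alpha_logscale=True)
x = torch.randn(8, 64, 200)
y = act(x)
print(y.shape)  # torch.Size([8, 64, 200])
# With log-scale parameters initialized to zero, exp(0) = 1, so the initial
# transform is roughly x + sin^2(x): a gentle periodic perturbation of identity.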
Amphion/modules/diffusion/bidilconv/bidilated_conv.py
ADDED
@@ -0,0 +1,102 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch.nn as nn

from modules.general.utils import Conv1d, zero_module
from .residual_block import ResidualBlock


class BiDilConv(nn.Module):
    r"""Dilated CNN architecture with residual connections, the default diffusion decoder.

    Args:
        input_channel: The number of input channels.
        base_channel: The number of base channels.
        n_res_block: The number of residual blocks.
        conv_kernel_size: The kernel size of convolutional layers.
        dilation_cycle_length: The cycle length of dilation.
        conditioner_size: The size of the conditioner.
    """

    def __init__(
        self,
        input_channel,
        base_channel,
        n_res_block,
        conv_kernel_size,
        dilation_cycle_length,
        conditioner_size,
        output_channel: int = -1,
    ):
        super().__init__()

        self.input_channel = input_channel
        self.base_channel = base_channel
        self.n_res_block = n_res_block
        self.conv_kernel_size = conv_kernel_size
        self.dilation_cycle_length = dilation_cycle_length
        self.conditioner_size = conditioner_size
        self.output_channel = output_channel if output_channel > 0 else input_channel

        self.input = nn.Sequential(
            Conv1d(
                input_channel,
                base_channel,
                1,
            ),
            nn.ReLU(),
        )

        self.residual_blocks = nn.ModuleList(
            [
                ResidualBlock(
                    channels=base_channel,
                    kernel_size=conv_kernel_size,
                    dilation=2 ** (i % dilation_cycle_length),
                    d_context=conditioner_size,
                )
                for i in range(n_res_block)
            ]
        )

        self.out_proj = nn.Sequential(
            Conv1d(
                base_channel,
                base_channel,
                1,
            ),
            nn.ReLU(),
            zero_module(
                Conv1d(
                    base_channel,
                    self.output_channel,
                    1,
                ),
            ),
        )

    def forward(self, x, y, context=None):
        """
        Args:
            x: Noisy mel-spectrogram [B x ``n_mel`` x L]
            y: FiLM embeddings with the shape of (B, ``base_channel``)
            context: Context with the shape of [B x ``d_context`` x L], defaults to None.
        """

        h = self.input(x)

        skip = None
        for i in range(self.n_res_block):
            h, skip_connection = self.residual_blocks[i](h, y, context)
            skip = skip_connection if skip is None else skip_connection + skip

        out = skip / math.sqrt(self.n_res_block)

        out = self.out_proj(out)

        return out
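A hedged forward-pass sketch for the decoder above; shapes follow its own docstring, and it assumes the `ResidualBlock` and `Conv1d` helpers imported at the top of the file are importable.

import torch

# Hypothetical forward pass (not part of the repo).
net = BiDilConv(
    input_channel=80,           # n_mel
    base_channel=256,
    n_res_block=10,
    conv_kernel_size=3,
    dilation_cycle_length=4,    # dilations cycle through 1, 2, 4, 8
    conditioner_size=512,       # d_context
)
x = torch.randn(2, 80, 200)         # noisy mel, [B, n_mel, L]
y = torch.randn(2, 256)             # FiLM embedding, [B, base_channel]
context = torch.randn(2, 512, 200)  # conditioner, [B, d_context, L]
out = net(x, y, context)
print(out.shape)  # [2, 80, 200]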
Amphion/modules/diffusion/karras/sample.py
ADDED
@@ -0,0 +1,185 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod

import numpy as np
import torch as th
from scipy.stats import norm
import torch.distributed as dist


def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.

    :param name: the name of the sampler.
    :param diffusion: the diffusion object to sample for.
    """
    if name == "uniform":
        return UniformSampler(diffusion)
    elif name == "loss-second-moment":
        return LossSecondMomentResampler(diffusion)
    elif name == "lognormal":
        return LogNormalSampler()
    else:
        raise NotImplementedError(f"unknown schedule sampler: {name}")


class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.

    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.
    """

    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.

        The weights needn't be normalized, but must be positive.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()
        p = w / np.sum(w)
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = th.from_numpy(indices_np).long().to(device)
        weights_np = 1 / (len(p) * p[indices_np])
        weights = th.from_numpy(weights_np).float().to(device)
        return indices, weights


class UniformSampler(ScheduleSampler):
    def __init__(self, diffusion):
        self.diffusion = diffusion
        self._weights = np.ones([diffusion.num_timesteps])

    def weights(self):
        return self._weights


class LossAwareSampler(ScheduleSampler):
    def update_with_local_losses(self, local_ts, local_losses):
        """
        Update the reweighting using losses from a model.

        Call this method from each rank with a batch of timesteps and the
        corresponding losses for each of those timesteps.
        This method will perform synchronization to make sure all of the ranks
        maintain the exact same reweighting.

        :param local_ts: an integer Tensor of timesteps.
        :param local_losses: a 1D Tensor of losses.
        """
        batch_sizes = [
            th.tensor([0], dtype=th.int32, device=local_ts.device)
            for _ in range(dist.get_world_size())
        ]
        dist.all_gather(
            batch_sizes,
            th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
        )

        # Pad all_gather batches to be the maximum batch size.
        batch_sizes = [x.item() for x in batch_sizes]
        max_bs = max(batch_sizes)

        timestep_batches = [th.zeros(max_bs).to(local_ts) for _ in batch_sizes]
        loss_batches = [th.zeros(max_bs).to(local_losses) for _ in batch_sizes]
        dist.all_gather(timestep_batches, local_ts)
        dist.all_gather(loss_batches, local_losses)
        timesteps = [
            x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
        ]
        losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
        self.update_with_all_losses(timesteps, losses)

    @abstractmethod
    def update_with_all_losses(self, ts, losses):
        """
        Update the reweighting using losses from a model.

        Sub-classes should override this method to update the reweighting
        using losses from the model.

        This method directly updates the reweighting without synchronizing
        between workers. It is called by update_with_local_losses from all
        ranks with identical arguments. Thus, it should have deterministic
        behavior to maintain state across workers.

        :param ts: a list of int timesteps.
        :param losses: a list of float losses, one per timestep.
        """


class LossSecondMomentResampler(LossAwareSampler):
    def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
        self.diffusion = diffusion
        self.history_per_term = history_per_term
        self.uniform_prob = uniform_prob
        self._loss_history = np.zeros(
            [diffusion.num_timesteps, history_per_term], dtype=np.float64
        )
        # np.int was removed from NumPy; the builtin int dtype is equivalent.
        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=int)

    def weights(self):
        if not self._warmed_up():
            return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
        weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
        weights /= np.sum(weights)
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
        return weights

    def update_with_all_losses(self, ts, losses):
        for t, loss in zip(ts, losses):
            if self._loss_counts[t] == self.history_per_term:
                # Shift out the oldest loss term.
                self._loss_history[t, :-1] = self._loss_history[t, 1:]
                self._loss_history[t, -1] = loss
            else:
                self._loss_history[t, self._loss_counts[t]] = loss
                self._loss_counts[t] += 1

    def _warmed_up(self):
        return (self._loss_counts == self.history_per_term).all()


class LogNormalSampler:
    def __init__(self, p_mean=-1.2, p_std=1.2, even=False):
        self.p_mean = p_mean
        self.p_std = p_std
        self.even = even
        if self.even:
            self.inv_cdf = lambda x: norm.ppf(x, loc=p_mean, scale=p_std)
            self.rank, self.size = dist.get_rank(), dist.get_world_size()

    def sample(self, bs, device):
        if self.even:
            # stratified sampling: each rank covers its own slice of [0, 1)
            start_i, end_i = self.rank * bs, (self.rank + 1) * bs
            global_batch_size = self.size * bs
            locs = (th.arange(start_i, end_i) + th.rand(bs)) / global_batch_size
            log_sigmas = th.tensor(self.inv_cdf(locs), dtype=th.float32, device=device)
        else:
            log_sigmas = self.p_mean + self.p_std * th.randn(bs, device=device)
        sigmas = th.exp(log_sigmas)
        weights = th.ones_like(sigmas)
        return sigmas, weights
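A minimal sketch of the sampler factory above. The stub diffusion class is an assumption: any object exposing a `num_timesteps` attribute is sufficient for the uniform sampler.

import torch

# Hypothetical usage (not part of the repo).
class _Diffusion:
    num_timesteps = 1000

sampler = create_named_schedule_sampler("uniform", _Diffusion())
t, weights = sampler.sample(batch_size=16, device=torch.device("cpu"))
print(t.shape, weights.shape)  # both torch.Size([16]); uniform => weights are all 1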
Amphion/modules/diffusion/unet/attention.py
ADDED
@@ -0,0 +1,241 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.general.utils import Conv1d, normalization, zero_module
from .basic import UNetBlock


class AttentionBlock(UNetBlock):
    r"""A spatial transformer encoder block that allows spatial positions to attend
    to each other. Reference from `latent diffusion repo
    <https://github.com/Stability-AI/generative-models/blob/main/sgm/modules/attention.py#L531>`_.

    Args:
        channels: Number of channels in the input.
        num_head_channels: Number of channels per attention head.
        num_heads: Number of attention heads. Overrides ``num_head_channels`` if set.
        encoder_channels: Number of channels in the encoder output for cross-attention.
            If ``None``, then self-attention is performed.
        use_self_attention: Whether to use self-attention before cross-attention, only applicable if encoder_channels is set.
        dims: Number of spatial dimensions, i.e. 1 for temporal signals, 2 for images.
        h_dim: The dimension of the height, would be applied if ``dims`` is 2.
        encoder_hdim: The dimension of the height of the encoder output, would be applied if ``dims`` is 2.
        p_dropout: Dropout probability.
    """

    def __init__(
        self,
        channels: int,
        num_head_channels: int = 32,
        num_heads: int = -1,
        encoder_channels: int = None,
        use_self_attention: bool = False,
        dims: int = 1,
        h_dim: int = 100,
        encoder_hdim: int = 384,
        p_dropout: float = 0.0,
    ):
        super().__init__()

        self.channels = channels
        self.p_dropout = p_dropout
        self.dims = dims

        if dims == 1:
            self.channels = channels
        elif dims == 2:
            # We consider the channel as product of channel and height, i.e. C x H
            # This is because we want to apply attention on the audio signal, which is 1D
            self.channels = channels * h_dim
        else:
            raise ValueError(f"invalid number of dimensions: {dims}")

        if num_head_channels == -1:
            assert (
                self.channels % num_heads == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
            self.num_heads = num_heads
            self.num_head_channels = self.channels // num_heads
        else:
            assert (
                self.channels % num_head_channels == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = self.channels // num_head_channels
            self.num_head_channels = num_head_channels

        if encoder_channels is not None:
            self.use_self_attention = use_self_attention

            if dims == 1:
                self.encoder_channels = encoder_channels
            elif dims == 2:
                self.encoder_channels = encoder_channels * encoder_hdim
            else:
                raise ValueError(f"invalid number of dimensions: {dims}")

            if use_self_attention:
                self.self_attention = BasicAttentionBlock(
                    self.channels,
                    self.num_head_channels,
                    self.num_heads,
                    p_dropout=self.p_dropout,
                )
            self.cross_attention = BasicAttentionBlock(
                self.channels,
                self.num_head_channels,
                self.num_heads,
                self.encoder_channels,
                p_dropout=self.p_dropout,
            )
        else:
            self.encoder_channels = None
            self.self_attention = BasicAttentionBlock(
                self.channels,
                self.num_head_channels,
                self.num_heads,
                p_dropout=self.p_dropout,
            )

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor = None):
        r"""
        Args:
            x: input tensor with shape [B x ``channels`` x ...]
            encoder_output: feature tensor with shape [B x ``encoder_channels`` x ...], if ``None``, then self-attention is performed.

        Returns:
            output tensor with shape [B x ``channels`` x ...]
        """
        shape = x.size()
        x = x.reshape(shape[0], self.channels, -1).contiguous()

        if self.encoder_channels is None:
            assert (
                encoder_output is None
            ), "encoder_output must be None for self-attention."
            h = self.self_attention(x)

        else:
            assert (
                encoder_output is not None
            ), "encoder_output must be given for cross-attention."
            encoder_output = encoder_output.reshape(
                shape[0], self.encoder_channels, -1
            ).contiguous()

            if self.use_self_attention:
                x = self.self_attention(x)
            h = self.cross_attention(x, encoder_output)

        return h.reshape(*shape).contiguous()


class BasicAttentionBlock(nn.Module):
    def __init__(
        self,
        channels: int,
        num_head_channels: int = 32,
        num_heads: int = -1,
        context_channels: int = None,
        p_dropout: float = 0.0,
    ):
        super().__init__()

        self.channels = channels
        self.p_dropout = p_dropout
        self.context_channels = context_channels

        if num_head_channels == -1:
            assert (
                self.channels % num_heads == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
            self.num_heads = num_heads
            self.num_head_channels = self.channels // num_heads
        else:
            assert (
                self.channels % num_head_channels == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = self.channels // num_head_channels
            self.num_head_channels = num_head_channels

        if context_channels is not None:
            self.to_q = nn.Sequential(
                normalization(self.channels),
                Conv1d(self.channels, self.channels, 1),
            )
            self.to_kv = Conv1d(context_channels, 2 * self.channels, 1)
        else:
            self.to_qkv = nn.Sequential(
                normalization(self.channels),
                Conv1d(self.channels, 3 * self.channels, 1),
            )

        self.linear = Conv1d(self.channels, self.channels)

        self.proj_out = nn.Sequential(
            normalization(self.channels),
            Conv1d(self.channels, self.channels, 1),
            nn.GELU(),
            nn.Dropout(p=self.p_dropout),
            zero_module(Conv1d(self.channels, self.channels, 1)),
        )

    def forward(self, q: torch.Tensor, kv: torch.Tensor = None):
        r"""
        Args:
            q: input tensor with shape [B, ``channels``, L]
            kv: feature tensor with shape [B, ``context_channels``, T], if ``None``, then self-attention is performed.

        Returns:
            output tensor with shape [B, ``channels``, L]
        """
        N, C, L = q.size()

        if self.context_channels is not None:
            assert kv is not None, "kv must be given for cross-attention."

            q = (
                self.to_q(q)
                .reshape(self.num_heads, self.num_head_channels, -1)
                .transpose(-1, -2)
                .contiguous()
            )
            kv = (
                self.to_kv(kv)
                .reshape(2, self.num_heads, self.num_head_channels, -1)
                .transpose(-1, -2)
                .chunk(2)
            )
            k, v = (
                kv[0].squeeze(0).contiguous(),
                kv[1].squeeze(0).contiguous(),
            )

        else:
            qkv = (
                self.to_qkv(q)
                .reshape(3, self.num_heads, self.num_head_channels, -1)
                .transpose(-1, -2)
                .chunk(3)
            )
            q, k, v = (
                qkv[0].squeeze(0).contiguous(),
                qkv[1].squeeze(0).contiguous(),
                qkv[2].squeeze(0).contiguous(),
            )

        h = F.scaled_dot_product_attention(q, k, v, dropout_p=self.p_dropout).transpose(
            -1, -2
        )
        h = h.reshape(N, -1, L).contiguous()
        h = self.linear(h)

        x = q + h
        h = self.proj_out(x)

        return x + h
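The block above ultimately reduces to `F.scaled_dot_product_attention` over (heads, length, head_channels) tensors. A standalone sketch of that core call (requires PyTorch 2.0 or newer; the dimensions are illustrative):

import torch
import torch.nn.functional as F

# Standalone sketch: after the reshape/transpose in the block above,
# tensors are (num_heads, seq_len, head_channels).
num_heads, head_channels, L = 4, 32, 100
q = torch.randn(num_heads, L, head_channels)
k = torch.randn(num_heads, L, head_channels)
v = torch.randn(num_heads, L, head_channels)
h = F.scaled_dot_product_attention(q, k, v)  # softmax(q k^T / sqrt(d)) @ v
print(h.shape)  # torch.Size([4, 100, 32])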
Amphion/modules/diffusion/unet/resblock.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from .basic import UNetBlock
|
| 10 |
+
from modules.general.utils import (
|
| 11 |
+
append_dims,
|
| 12 |
+
ConvNd,
|
| 13 |
+
normalization,
|
| 14 |
+
zero_module,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ResBlock(UNetBlock):
|
| 19 |
+
r"""A residual block that can optionally change the number of channels.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
channels: the number of input channels.
|
| 23 |
+
emb_channels: the number of timestep embedding channels.
|
| 24 |
+
dropout: the rate of dropout.
|
| 25 |
+
out_channels: if specified, the number of out channels.
|
| 26 |
+
use_conv: if True and out_channels is specified, use a spatial
|
| 27 |
+
convolution instead of a smaller 1x1 convolution to change the
|
| 28 |
+
channels in the skip connection.
|
| 29 |
+
dims: determines if the signal is 1D, 2D, or 3D.
|
| 30 |
+
up: if True, use this block for upsampling.
|
| 31 |
+
down: if True, use this block for downsampling.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
channels,
|
| 37 |
+
emb_channels,
|
| 38 |
+
dropout: float = 0.0,
|
| 39 |
+
out_channels=None,
|
| 40 |
+
use_conv=False,
|
| 41 |
+
use_scale_shift_norm=False,
|
| 42 |
+
dims=2,
|
| 43 |
+
up=False,
|
| 44 |
+
down=False,
|
| 45 |
+
):
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.channels = channels
|
| 48 |
+
self.emb_channels = emb_channels
|
| 49 |
+
self.dropout = dropout
|
| 50 |
+
self.out_channels = out_channels or channels
|
| 51 |
+
self.use_conv = use_conv
|
| 52 |
+
self.use_scale_shift_norm = use_scale_shift_norm
|
| 53 |
+
|
| 54 |
+
self.in_layers = nn.Sequential(
|
| 55 |
+
normalization(channels),
|
| 56 |
+
nn.SiLU(),
|
| 57 |
+
ConvNd(dims, channels, self.out_channels, 3, padding=1),
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
self.updown = up or down
|
| 61 |
+
|
| 62 |
+
if up:
|
| 63 |
+
self.h_upd = Upsample(channels, False, dims)
|
| 64 |
+
self.x_upd = Upsample(channels, False, dims)
|
| 65 |
+
elif down:
|
| 66 |
+
self.h_upd = Downsample(channels, False, dims)
|
| 67 |
+
self.x_upd = Downsample(channels, False, dims)
|
| 68 |
+
else:
|
| 69 |
+
        self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            ConvNd(
                dims,
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
                1,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                ConvNd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = ConvNd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = ConvNd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        x: an [N x C x ...] Tensor of features.
        emb: an [N x emb_channels x ...] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb)
        emb_out = append_dims(emb_out, h.dim())
        if self.use_scale_shift_norm:
            # FiLM-style conditioning: split the embedding into scale and shift,
            # then modulate the normalized features.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h


class Upsample(nn.Module):
    r"""An upsampling layer with an optional convolution.

    Args:
        channels: channels in the inputs and outputs.
        dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
            upsampling occurs in the inner-two dimensions.
        out_channels: if specified, the number of out channels.
    """

    def __init__(self, channels, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.dims = dims
        self.conv = ConvNd(dims, self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        x = self.conv(x)
        return x


class Downsample(nn.Module):
    r"""A downsampling layer with an optional convolution.

    Args:
        channels: channels in the inputs and outputs.
        dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
            downsampling occurs in the inner-two dimensions.
        out_channels: if specified, the number of output channels.
    """

    def __init__(self, channels, dims=2, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        self.op = ConvNd(
            dims, self.channels, self.out_channels, 3, stride=stride, padding=1
        )

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
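For orientation, here is a minimal usage sketch of the two resampling blocks above. It assumes the Amphion repository root is on PYTHONPATH so the file resolves as modules.diffusion.unet.resblock; the batch and channel sizes are illustrative only.

import torch
from modules.diffusion.unet.resblock import Downsample, Upsample

x = torch.randn(2, 64, 128)          # [N, C, T] for dims=1 (temporal signals)
up = Upsample(64, dims=1)            # nearest-neighbor 2x, then a width-3 conv
down = Downsample(64, dims=1)        # width-3 conv with stride 2
assert up(x).shape == (2, 64, 256)   # T doubled; conv with padding=1 keeps length
assert down(x).shape == (2, 64, 64)  # T halved by the stride-2 convolution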
Amphion/modules/diffusion/unet/unet.py
ADDED
@@ -0,0 +1,310 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from modules.encoder.position_encoder import PositionEncoder
from modules.general.utils import append_dims, ConvNd, normalization, zero_module
from .attention import AttentionBlock
from .resblock import Downsample, ResBlock, Upsample


class UNet(nn.Module):
    r"""The full UNet model with attention and timestep embedding.

    Args:
        dims: determines whether the signal is 1D (temporal) or 2D (spatial).
        in_channels: channels in the input Tensor.
        model_channels: base channel count for the model.
        out_channels: channels in the output Tensor.
        num_res_blocks: number of residual blocks per downsample.
        channel_mult: channel multiplier for each level of the UNet.
        num_attn_blocks: number of attention blocks at each attention location.
        attention_resolutions: a collection of downsample rates at which attention
            takes place. May be a set, list, or tuple. For example, if this
            contains 4, then attention is used at 4x downsampling.
        num_heads: the number of attention heads in each attention layer.
        num_head_channels: if specified, ignore num_heads and instead use a fixed
            channel width per attention head.
        d_context: if specified, the channel size used for the cross-attention
            projection.
        p_dropout: the dropout probability.
        use_self_attention: apply self-attention before cross-attention.
        num_classes: if specified (as an int), then this model will be
            class-conditional with ``num_classes`` classes.
        use_extra_film: if specified, use an extra FiLM-like conditioning
            mechanism ("add" or "concat").
        d_emb: if specified, the embedding size used for FiLM-like conditioning.
        use_scale_shift_norm: use a FiLM-like conditioning mechanism inside the
            residual blocks.
        resblock_updown: use residual blocks for up/downsampling.
    """

    def __init__(
        self,
        dims: int = 1,
        in_channels: int = 100,
        model_channels: int = 128,
        out_channels: int = 100,
        h_dim: int = 128,
        num_res_blocks: int = 1,
        channel_mult: tuple = (1, 2, 4),
        num_attn_blocks: int = 1,
        attention_resolutions: tuple = (1, 2, 4),
        num_heads: int = 1,
        num_head_channels: int = -1,
        d_context: int = None,
        context_hdim: int = 128,
        p_dropout: float = 0.0,
        num_classes: int = -1,
        use_extra_film: str = None,
        d_emb: int = None,
        use_scale_shift_norm: bool = True,
        resblock_updown: bool = False,
    ):
        super().__init__()

        self.dims = dims
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.channel_mult = channel_mult
        self.num_attn_blocks = num_attn_blocks
        self.attention_resolutions = attention_resolutions
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.d_context = d_context
        self.p_dropout = p_dropout
        self.num_classes = num_classes
        self.use_extra_film = use_extra_film
        self.d_emb = d_emb
        self.use_scale_shift_norm = use_scale_shift_norm
        self.resblock_updown = resblock_updown

        time_embed_dim = model_channels * 4
        self.pos_enc = PositionEncoder(model_channels, time_embed_dim)

        assert (
            num_classes == -1 or use_extra_film is None
        ), "You cannot set both num_classes and use_extra_film."

        if self.num_classes > 0:
            # TODO: if used for singer, norm should be 1, correct?
            self.label_emb = nn.Embedding(num_classes, time_embed_dim, max_norm=1.0)
        elif use_extra_film is not None:
            assert (
                d_emb is not None
            ), "d_emb must be specified if use_extra_film is not None"
            assert use_extra_film in [
                "add",
                "concat",
            ], f"use_extra_film only supports 'add' or 'concat'. Your input is {use_extra_film}"
            self.use_extra_film = use_extra_film
            self.film_emb = ConvNd(dims, d_emb, time_embed_dim, 1)
            if use_extra_film == "concat":
                time_embed_dim *= 2

        # Input blocks
        ch = input_ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList(
            [UNetSequential(ConvNd(dims, in_channels, ch, 3, padding=1))]
        )
        self._feature_size = ch
        input_block_chans = [ch]
        ds = 1  # current downsampling rate
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        p_dropout,
                        out_channels=int(mult * model_channels),
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    for _ in range(num_attn_blocks):
                        layers.append(
                            AttentionBlock(
                                ch,
                                num_heads=num_heads,
                                num_head_channels=num_head_channels,
                                encoder_channels=d_context,
                                dims=dims,
                                h_dim=h_dim // (level + 1),
                                encoder_hdim=context_hdim,
                                p_dropout=p_dropout,
                            )
                        )
                self.input_blocks.append(UNetSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    UNetSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            p_dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(ch, dims=dims, out_channels=out_ch)
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # Middle blocks
        self.middle_block = UNetSequential(
            ResBlock(
                ch,
                time_embed_dim,
                p_dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                encoder_channels=d_context,
                dims=dims,
                # `level` keeps its final value from the loop above,
                # i.e. the deepest level of the UNet.
                h_dim=h_dim // (level + 1),
                encoder_hdim=context_hdim,
                p_dropout=p_dropout,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                p_dropout,
                dims=dims,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # Output blocks
        self.output_blocks = nn.ModuleList([])
        for level, mult in tuple(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        p_dropout,
                        out_channels=int(model_channels * mult),
                        dims=dims,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(model_channels * mult)
                if ds in attention_resolutions:
                    for _ in range(num_attn_blocks):
                        layers.append(
                            AttentionBlock(
                                ch,
                                num_heads=num_heads,
                                num_head_channels=num_head_channels,
                                encoder_channels=d_context,
                                dims=dims,
                                h_dim=h_dim // (level + 1),
                                encoder_hdim=context_hdim,
                                p_dropout=p_dropout,
                            )
                        )
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            p_dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(UNetSequential(*layers))
                self._feature_size += ch

        # Final proj out
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(ConvNd(dims, input_ch, out_channels, 3, padding=1)),
        )

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        r"""Apply the model to an input batch.

        Args:
            x: an [N x C x ...] Tensor of inputs.
            timesteps: a 1-D batch of timesteps, i.e. [N].
            context: conditioning Tensor with shape of [N x ``d_context`` x ...],
                plugged in via cross-attention.
            y: an [N] Tensor of labels, if **class-conditional**;
                an [N x ``d_emb`` x ...] Tensor, if **film-embed conditional**.

        Returns:
            an [N x C x ...] Tensor of outputs.
        """
        assert (y is None) or (
            (y is not None)
            and ((self.num_classes > 0) or (self.use_extra_film is not None))
        ), f"y must be specified if num_classes or use_extra_film is not None.\nGot num_classes: {self.num_classes}\nuse_extra_film: {self.use_extra_film}"

        hs = []
        emb = self.pos_enc(timesteps)
        emb = append_dims(emb, x.dim())

        if self.num_classes > 0:
            assert y.size() == (x.size(0),)
            emb = emb + self.label_emb(y)
        elif self.use_extra_film is not None:
            assert y.size() == (x.size(0), self.d_emb, *x.size()[2:])
            y = self.film_emb(y)
            if self.use_extra_film == "add":
                emb = emb + y
            elif self.use_extra_film == "concat":
                emb = torch.cat([emb, y], dim=1)

        h = x
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)

        return self.out(h)


class UNetSequential(nn.Sequential):
    r"""A sequential module that passes embeddings to the children that support it."""

    def forward(self, x, emb=None, context=None):
        for layer in self:
            if isinstance(layer, ResBlock):
                x = layer(x, emb)
            elif isinstance(layer, AttentionBlock):
                x = layer(x, context)
            else:
                x = layer(x)
        return x
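As a hedged smoke-test sketch of the UNet above (sizes illustrative; it assumes the repository root is on PYTHONPATH and that the neighboring PositionEncoder and AttentionBlock modules accept these defaults): with dims=1 the model maps a mel-like [N, in_channels, T] tensor to [N, out_channels, T], where T should be divisible by the deepest downsampling factor (4 under the default channel_mult=(1, 2, 4)).

import torch
from modules.diffusion.unet.unet import UNet

model = UNet(dims=1, in_channels=100, model_channels=128, out_channels=100)
x = torch.randn(2, 100, 64)         # [N, in_channels, T], T divisible by 4
t = torch.randint(0, 1000, (2,))    # one diffusion timestep per batch item
y_hat = model(x, timesteps=t)       # unconditional: no context, no y
assert y_hat.shape == x.shape       # out_channels == in_channels here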
Amphion/modules/encoder/__init__.py
ADDED
@@ -0,0 +1 @@
from .token_encoder import TokenEmbedding
Amphion/modules/general/utils.py
ADDED
@@ -0,0 +1,100 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


def normalization(channels: int, groups: int = 32):
    r"""Make a standard normalization layer, i.e. GroupNorm.

    Args:
        channels: number of input channels.
        groups: number of groups for group normalization.

    Returns:
        a ``nn.Module`` for normalization.
    """
    assert groups > 0, f"invalid number of groups: {groups}"
    return nn.GroupNorm(groups, channels)


def Linear(*args, **kwargs):
    r"""Wrapper of ``nn.Linear`` with kaiming_normal_ initialization."""
    layer = nn.Linear(*args, **kwargs)
    nn.init.kaiming_normal_(layer.weight)
    return layer


def Conv1d(*args, **kwargs):
    r"""Wrapper of ``nn.Conv1d`` with kaiming_normal_ initialization."""
    layer = nn.Conv1d(*args, **kwargs)
    nn.init.kaiming_normal_(layer.weight)
    return layer


def Conv2d(*args, **kwargs):
    r"""Wrapper of ``nn.Conv2d`` with kaiming_normal_ initialization."""
    layer = nn.Conv2d(*args, **kwargs)
    nn.init.kaiming_normal_(layer.weight)
    return layer


def ConvNd(dims: int = 1, *args, **kwargs):
    r"""Wrapper of N-dimension convolution with kaiming_normal_ initialization.

    Args:
        dims: number of dimensions of the convolution.
    """
    if dims == 1:
        return Conv1d(*args, **kwargs)
    elif dims == 2:
        return Conv2d(*args, **kwargs)
    else:
        raise ValueError(f"invalid number of dimensions: {dims}")


def zero_module(module: nn.Module):
    r"""Zero out the parameters of a module and return it."""
    nn.init.zeros_(module.weight)
    nn.init.zeros_(module.bias)
    return module


def scale_module(module: nn.Module, scale):
    r"""Scale the parameters of a module and return it."""
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor: torch.Tensor):
    r"""Take the mean over all non-batch dimensions."""
    return tensor.mean(dim=tuple(range(1, tensor.dim())))


def append_dims(x, target_dims):
    r"""Appends dimensions to the end of a tensor until
    it has target_dims dimensions.
    """
    dims_to_append = target_dims - x.dim()
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.dim()} dims but target_dims is {target_dims}, which is less"
        )
    return x[(...,) + (None,) * dims_to_append]


def append_zero(x, count=1):
    r"""Appends ``count`` zeros to the end of a tensor along the last dimension."""
    assert count > 0, f"invalid count: {count}"
    return torch.cat([x, x.new_zeros((*x.size()[:-1], count))], dim=-1)


class Transpose(nn.Identity):
    """(N, T, D) -> (N, D, T)"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)
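A few worked examples of the tensor helpers above, with illustrative values (assuming the file resolves as modules.general.utils):

import torch
from modules.general.utils import append_dims, append_zero, mean_flat

t = torch.randn(4)
assert append_dims(t, 3).shape == (4, 1, 1)           # trailing singleton dims added
x = torch.randn(2, 3)
assert append_zero(x).shape == (2, 4)                 # one zero appended along dim=-1
assert mean_flat(torch.ones(2, 3, 5)).shape == (2,)   # mean over all non-batch dims

Note that zero_module as written assumes the wrapped module exposes both weight and bias attributes, which holds for the Conv and Linear wrappers defined here.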
Amphion/modules/norms/__init__.py
ADDED
@@ -0,0 +1 @@
from .norm import AdaptiveLayerNorm, LayerNorm, BalancedBasicNorm, IdentityNorm