haoxiangsnr committed (verified)
Commit 53cf2e4 · 1 Parent(s): 9c38bf7

Add files using upload-large-folder tool

This view is limited to the first 50 files because the commit contains too many changes.

Files changed (50)
  1. Amphion/egs/metrics/README.md +174 -0
  2. Amphion/egs/metrics/run.sh +132 -0
  3. Amphion/egs/svc/TransformerSVC/exp_config.json +108 -0
  4. Amphion/egs/svc/VitsSVC/README.md +125 -0
  5. Amphion/egs/tta/README.md +19 -0
  6. Amphion/egs/tta/audioldm/exp_config.json +90 -0
  7. Amphion/egs/tta/audioldm/run_train.sh +26 -0
  8. Amphion/egs/tta/audioldm/run_train_latent_4_10_78.sh +26 -0
  9. Amphion/egs/tta/autoencoderkl/run_train_latent_4_10_78.sh +26 -0
  10. Amphion/egs/tts/FastSpeech2/prepare_mfa.sh +29 -0
  11. Amphion/egs/tts/FastSpeech2/run.sh +155 -0
  12. Amphion/egs/tts/NaturalSpeech2/run_inference.sh +49 -0
  13. Amphion/egs/tts/VALLE/README.md +207 -0
  14. Amphion/egs/tts/VALLE/prompt_examples/5142_33396_000002_000004.wav +0 -0
  15. Amphion/egs/tts/VALLE/prompt_examples/7176_92135_000004_000000.normalized.txt +1 -0
  16. Amphion/egs/tts/VITS/exp_config.json +34 -0
  17. Amphion/egs/vocoder/gan/bigvgan_large/exp_config.json +70 -0
  18. Amphion/egs/vocoder/gan/hifigan/exp_config.json +59 -0
  19. Amphion/egs/vocoder/gan/hifigan/run.sh +141 -0
  20. Amphion/egs/vocoder/gan/nsfhifigan/run.sh +141 -0
  21. Amphion/models/__pycache__/__init__.cpython-310.pyc +0 -0
  22. Amphion/models/base/__init__.py +7 -0
  23. Amphion/models/base/new_dataset.py +50 -0
  24. Amphion/models/base/new_trainer.py +727 -0
  25. Amphion/models/codec/ns3_codec/__pycache__/facodec.cpython-310.pyc +0 -0
  26. Amphion/models/codec/ns3_codec/alias_free_torch/__pycache__/__init__.cpython-310.pyc +0 -0
  27. Amphion/models/codec/ns3_codec/alias_free_torch/__pycache__/resample.cpython-310.pyc +0 -0
  28. Amphion/models/codec/ns3_codec/quantize/__pycache__/rvq.cpython-310.pyc +0 -0
  29. Amphion/models/codec/ns3_codec/quantize/rvq.py +87 -0
  30. Amphion/models/svc/transformer/transformer.py +82 -0
  31. Amphion/models/tts/fastspeech2/fs2.py +548 -0
  32. Amphion/models/tts/fastspeech2/fs2_inference.py +193 -0
  33. Amphion/models/tts/naturalspeech2/__init__.py +0 -0
  34. Amphion/models/tts/naturalspeech2/wavenet.py +206 -0
  35. Amphion/models/tts/valle/valle.py +794 -0
  36. Amphion/models/tts/valle/valle_inference.py +237 -0
  37. Amphion/models/tts/valle/valle_trainer.py +367 -0
  38. Amphion/models/tts/vits/__init__.py +0 -0
  39. Amphion/models/tts/vits/vits.py +379 -0
  40. Amphion/models/tts/vits/vits_dataset.py +140 -0
  41. Amphion/models/tts/vits/vits_trainer.py +439 -0
  42. Amphion/models/vocoders/autoregressive/autoregressive_vocoder_trainer.py +0 -0
  43. Amphion/modules/activation_functions/gated_activation_unit.py +61 -0
  44. Amphion/modules/base/base_module.py +75 -0
  45. Amphion/modules/diffusion/__init__.py +7 -0
  46. Amphion/modules/duration_predictor/__init__.py +0 -0
  47. Amphion/modules/duration_predictor/standard_duration_predictor.py +53 -0
  48. Amphion/modules/duration_predictor/stochastic_duration_predictor.py +120 -0
  49. Amphion/modules/general/scaling.py +1349 -0
  50. Amphion/modules/norms/norm.py +173 -0
Amphion/egs/metrics/README.md ADDED
@@ -0,0 +1,174 @@
+ # Amphion Evaluation Recipe
+
+ ## Supported Evaluation Metrics
+
+ So far, Amphion Evaluation supports the following objective metrics:
+
+ - **F0 Modeling**:
+   - F0 Pearson Coefficients (FPC)
+   - F0 Periodicity Root Mean Square Error (PeriodicityRMSE)
+   - F0 Root Mean Square Error (F0RMSE)
+   - Voiced/Unvoiced F1 Score (V/UV F1)
+ - **Energy Modeling**:
+   - Energy Root Mean Square Error (EnergyRMSE)
+   - Energy Pearson Coefficients (EnergyPC)
+ - **Intelligibility**:
+   - Character Error Rate (CER) based on [Whisper](https://github.com/openai/whisper)
+   - Word Error Rate (WER) based on [Whisper](https://github.com/openai/whisper)
+ - **Spectrogram Distortion**:
+   - Frechet Audio Distance (FAD)
+   - Mel Cepstral Distortion (MCD)
+   - Multi-Resolution STFT Distance (MSTFT)
+   - Perceptual Evaluation of Speech Quality (PESQ)
+   - Short Time Objective Intelligibility (STOI)
+   - Scale Invariant Signal to Distortion Ratio (SISDR)
+   - Scale Invariant Signal to Noise Ratio (SISNR)
+ - **Speaker Similarity**:
+   - Cosine similarity based on:
+     - [RawNet3](https://github.com/Jungjee/RawNet)
+     - [Resemblyzer](https://github.com/resemble-ai/Resemblyzer)
+     - [WavLM](https://huggingface.co/microsoft/wavlm-base-plus-sv)
+
+ We provide a recipe to demonstrate how to evaluate your generated audios objectively. There are three steps in total:
+
+ 1. Pretrained Models Preparation
+ 2. Audio Data Preparation
+ 3. Evaluation
+
+ ## 1. Pretrained Models Preparation
+
+ If you want to calculate `RawNet3`-based speaker similarity, you need to download the pretrained model first, as illustrated [here](../../pretrained/README.md).
+
+ ## 2. Audio Data Preparation
+
+ Prepare the reference audios and the generated audios in two folders: `ref_dir` contains the reference audios and `gen_dir` contains the generated audios. Here is an example.
+
+ ```plaintext
+ ┣ {ref_dir}
+ ┃ ┣ sample1.wav
+ ┃ ┣ sample2.wav
+ ┣ {gen_dir}
+ ┃ ┣ sample1.wav
+ ┃ ┣ sample2.wav
+ ```
+
+ You have to make sure that each pairwise **reference audio and generated audio are named the same**, as illustrated above (sample1 to sample1, sample2 to sample2).
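+
+ A quick way to check that the two folders pair up (a sketch; it assumes flat folders of audio files):
+
+ ```bash
+ # Compare the sorted file listings of the two folders; any output means a mismatch
+ diff <(ls {ref_dir} | sort) <(ls {gen_dir} | sort) && echo "All pairs matched."
+ ```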
+
+ ## 3. Evaluation
+
+ Run `run.sh` with the specified reference folder, generated folder, dump folder, and metrics.
+
+ ```bash
+ cd Amphion
+ sh egs/metrics/run.sh \
+     --reference_folder [Your path to the reference audios] \
+     --generated_folder [Your path to the generated audios] \
+     --dump_folder [Your path to dump the objective results] \
+     --metrics [The metrics you need] \
+     --fs [Optional. To calculate all metrics in the specified sampling rate] \
+     --similarity_model [Optional. To choose the model for calculating the speaker similarity. Currently "rawnet", "wavlm" and "resemblyzer" are available. Default to "wavlm"] \
+     --similarity_mode [Optional. To choose the mode for calculating the speaker similarity. "pairwith" for calculating a series of ground truth / prediction audio pairs to obtain the speaker similarity, and "overall" for computing the average score over all possible pairs between the reference folder and the generated folder. Default to "pairwith"] \
+     --intelligibility_mode [Optional. To choose the mode for computing CER and WER. "gt_audio" means using the recognition result of the reference audio as the target, "gt_content" means using the transcription as the target. Default to "gt_audio"] \
+     --ltr_path [Optional. Path to the transcription file] \
+     --language [Optional. Language for computing CER and WER. Default to "english"]
+ ```
+
+ As for the metrics, an example is provided below:
+
+ ```bash
+ --metrics "mcd pesq fad"
+ ```
+
+ All currently available metric keywords are listed below:
+
+ | Keys | Description |
+ | ------------------------- | ------------------------------------------ |
+ | `fpc` | F0 Pearson Coefficients |
+ | `f0_periodicity_rmse` | F0 Periodicity Root Mean Square Error |
+ | `f0rmse` | F0 Root Mean Square Error |
+ | `v_uv_f1` | Voiced/Unvoiced F1 Score |
+ | `energy_rmse` | Energy Root Mean Square Error |
+ | `energy_pc` | Energy Pearson Coefficients |
+ | `cer` | Character Error Rate |
+ | `wer` | Word Error Rate |
+ | `similarity` | Speaker Similarity |
+ | `fad` | Frechet Audio Distance |
+ | `mcd` | Mel Cepstral Distortion |
+ | `mstft` | Multi-Resolution STFT Distance |
+ | `pesq` | Perceptual Evaluation of Speech Quality |
+ | `si_sdr` | Scale Invariant Signal to Distortion Ratio |
+ | `si_snr` | Scale Invariant Signal to Noise Ratio |
+ | `stoi` | Short Time Objective Intelligibility |
+
+ For example, if you want to calculate the speaker similarity between the synthesized audio and the reference audio with the same content, run:
+
+ ```bash
+ sh egs/metrics/run.sh \
+     --reference_folder [Your path to the reference audios] \
+     --generated_folder [Your path to the generated audios] \
+     --dump_folder [Your path to dump the objective results] \
+     --metrics "similarity" \
+     --similarity_model [Optional. To choose the model for calculating the speaker similarity. Currently "rawnet", "wavlm" and "resemblyzer" are available. Default to "wavlm"] \
+     --similarity_mode "pairwith"
+ ```
+
+ If you don't have reference audio with the same content, run the following to get the content-free similarity score:
+
+ ```bash
+ sh egs/metrics/run.sh \
+     --reference_folder [Your path to the reference audios] \
+     --generated_folder [Your path to the generated audios] \
+     --dump_folder [Your path to dump the objective results] \
+     --metrics "similarity" \
+     --similarity_model [Optional. To choose the model for calculating the speaker similarity. Currently "rawnet", "wavlm" and "resemblyzer" are available. Default to "wavlm"] \
+     --similarity_mode "overall"
+ ```
+
+ ## Troubleshooting
+
+ ### FAD (Using Offline Models)
+
+ If your system is unable to access huggingface.co from the terminal, you might run into an error like "OSError: Can't load tokenizer for ...". To work around this, follow these steps to use local models:
+
+ 1. Download the [bert-base-uncased](https://huggingface.co/bert-base-uncased), [roberta-base](https://huggingface.co/roberta-base), and [facebook/bart-base](https://huggingface.co/facebook/bart-base) models from `huggingface.co`. Ensure that the models are complete and uncorrupted. Place these directories within `Amphion/pretrained`. For a detailed file structure reference, see [this README](../../pretrained/README.md#optional-model-dependencies-for-evaluation) under `Amphion/pretrained`.
+ 2. Inside the `Amphion/pretrained` directory, create a bash script with the content outlined below. This script will automatically update the tokenizer paths used by your system:
+ ```bash
+ #!/bin/bash
+
+ BERT_DIR="bert-base-uncased"
+ ROBERTA_DIR="roberta-base"
+ BART_DIR="facebook/bart-base"
+ PYTHON_SCRIPT="[YOUR ENV PATH]/lib/python3.9/site-packages/laion_clap/training/data.py"
+
+ update_tokenizer_path() {
+     local dir_name=$1
+     local tokenizer_variable=$2
+     local full_path
+
+     if [ -d "$dir_name" ]; then
+         full_path=$(realpath "$dir_name")
+         if [ -f "$PYTHON_SCRIPT" ]; then
+             sed -i "s|${tokenizer_variable}.from_pretrained(\".*\")|${tokenizer_variable}.from_pretrained(\"$full_path\")|" "$PYTHON_SCRIPT"
+             echo "Updated ${tokenizer_variable} path to $full_path."
+         else
+             echo "Error: The specified Python script does not exist."
+             exit 1
+         fi
+     else
+         echo "Error: The directory $dir_name does not exist in the current directory."
+         exit 1
+     fi
+ }
+
+ update_tokenizer_path "$BERT_DIR" "BertTokenizer"
+ update_tokenizer_path "$ROBERTA_DIR" "RobertaTokenizer"
+ update_tokenizer_path "$BART_DIR" "BartTokenizer"
+
+ echo "BERT, BART and RoBERTa Python script paths have been updated."
+ ```
+
+ 3. The script provided is intended to adjust the tokenizer paths in the `data.py` file, found under `/lib/python3.9/site-packages/laion_clap/training/` within your environment. For those using conda, you can determine your environment path by running `conda info --envs`. Then, substitute `[YOUR ENV PATH]` in the script with this path. If your environment is configured differently, you'll need to update the `PYTHON_SCRIPT` variable to correctly point to the `data.py` file.
+ 4. Run the script. If it executes successfully, the tokenizer paths will be updated, allowing them to be loaded locally.
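+
+ If you are unsure of the exact `data.py` path, one way to locate it (a sketch, assuming `laion_clap` is importable in the active environment) is:
+
+ ```bash
+ python -c "import laion_clap, os; print(os.path.join(os.path.dirname(laion_clap.__file__), 'training', 'data.py'))"
+ ```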
+
+ ### WavLM-based Speaker Similarity (Using Offline Models)
+
+ If your system is unable to access huggingface.co from the terminal and you want to calculate `WavLM`-based speaker similarity, you need to download the pretrained model first, as illustrated [here](../../pretrained/README.md).
Amphion/egs/metrics/run.sh ADDED
@@ -0,0 +1,132 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
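+ # Usage (a sketch; see egs/metrics/README.md for the full option list):
+ #   sh egs/metrics/run.sh \
+ #     --reference_folder <ref_dir> --generated_folder <gen_dir> \
+ #     --dump_folder <dump_dir> --metrics "mcd pesq stoi"
+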
+ ######## Build Experiment Environment ###########
+ exp_dir=$(cd `dirname $0`; pwd)
+ work_dir=$(dirname $(dirname $exp_dir))
+
+ export WORK_DIR=$work_dir
+ export PYTHONPATH=$work_dir
+ export PYTHONIOENCODING=UTF-8
+
+ ######## Parse the Given Parameters from the Command ###########
+ options=$(getopt -o c:n:s --long gpu:,reference_folder:,generated_folder:,dump_folder:,metrics:,fs:,align_method:,energy_db_scale:,f0_subtract_mean:,similarity_model:,similarity_mode:,ltr_path:,intelligibility_mode:,language: -- "$@")
+ eval set -- "$options"
+
+ while true; do
+   case $1 in
+   # Visible GPU machines. The default value is "0".
+   --gpu) shift; gpu=$1 ; shift ;;
+   # Reference Audio Folder
+   --reference_folder) shift; ref_dir=$1 ; shift ;;
+   # Generated Audio Folder
+   --generated_folder) shift; deg_dir=$1 ; shift ;;
+   # Result Dumping Folder
+   --dump_folder) shift; dump_dir=$1 ; shift ;;
+   # Metrics to Compute
+   --metrics) shift; metrics=$1 ; shift ;;
+   # Sampling Rate
+   --fs) shift; fs=$1 ; shift ;;
+
+   # Method for aligning F0. The default value is "dtw".
+   --align_method) shift; align_method=$1 ; shift ;;
+   # Whether to subtract the mean when normalizing F0. The default value is "True".
+   --f0_subtract_mean) shift; f0_subtract_mean=$1 ; shift ;;
+   # Whether to use dB scale when normalizing energy. The default value is "True".
+   --energy_db_scale) shift; energy_db_scale=$1 ; shift ;;
+
+   # Model for computing speaker similarity. The default value is "wavlm".
+   --similarity_model) shift; similarity_model=$1 ; shift ;;
+   # Mode for computing speaker similarity. The default value is "pairwith".
+   --similarity_mode) shift; similarity_mode=$1 ; shift ;;
+
+   # Path for the transcript.
+   --ltr_path) shift; ltr_path=$1 ; shift ;;
+   # Mode for computing CER and WER. The default value is "gt_audio".
+   --intelligibility_mode) shift; intelligibility_mode=$1 ; shift ;;
+   # Language for computing CER and WER. The default value is "english".
+   --language) shift; language=$1 ; shift ;;
+
+   --) shift ; break ;;
+   *) echo "Invalid option: $1"; exit 1 ;;
+   esac
+ done
+
+ ### Value check ###
+ if [ -z "$ref_dir" ]; then
+   echo "[Error] Please specify the reference_folder"
+   exit 1
+ fi
+
+ if [ -z "$deg_dir" ]; then
+   echo "[Error] Please specify the generated_folder"
+   exit 1
+ fi
+
+ if [ -z "$dump_dir" ]; then
+   echo "[Error] Please specify the dump_folder"
+   exit 1
+ fi
+
+ if [ -z "$metrics" ]; then
+   echo "[Error] Please specify the metrics"
+   exit 1
+ fi
+
+ if [ -z "$gpu" ]; then
+   gpu="0"
+ fi
+
+ if [ -z "$fs" ]; then
+   fs="None"
+ fi
+
+ if [ -z "$align_method" ]; then
+   align_method="dtw"
+ fi
+
+ if [ -z "$energy_db_scale" ]; then
+   energy_db_scale="True"
+ fi
+
+ if [ -z "$f0_subtract_mean" ]; then
+   f0_subtract_mean="True"
+ fi
+
+ if [ -z "$similarity_model" ]; then
+   similarity_model="wavlm"
+ fi
+
+ if [ -z "$similarity_mode" ]; then
+   similarity_mode="pairwith"
+ fi
+
+ if [ -z "$ltr_path" ]; then
+   ltr_path="None"
+ fi
+
+ if [ -z "$intelligibility_mode" ]; then
+   intelligibility_mode="gt_audio"
+ fi
+
+ if [ -z "$language" ]; then
+   language="english"
+ fi
+
+ ######## Calculate Objective Metrics ###########
+ CUDA_VISIBLE_DEVICES=$gpu python "$work_dir"/bins/calc_metrics.py \
+   --ref_dir $ref_dir \
+   --deg_dir $deg_dir \
+   --dump_dir $dump_dir \
+   --metrics $metrics \
+   --fs $fs \
+   --align_method $align_method \
+   --db_scale $energy_db_scale \
+   --f0_subtract_mean $f0_subtract_mean \
+   --similarity_model $similarity_model \
+   --similarity_mode $similarity_mode \
+   --ltr_path $ltr_path \
+   --intelligibility_mode $intelligibility_mode \
+   --language $language
Amphion/egs/svc/TransformerSVC/exp_config.json ADDED
@@ -0,0 +1,108 @@
+ {
+     "base_config": "config/transformer.json",
+     "model_type": "TransformerSVC",
+     "dataset": [
+         "m4singer",
+         "opencpop",
+         "opensinger",
+         "svcc",
+         "vctk"
+     ],
+     "dataset_path": {
+         // TODO: Fill in your dataset path
+         "m4singer": "[M4Singer dataset path]",
+         "opencpop": "[Opencpop dataset path]",
+         "opensinger": "[OpenSinger dataset path]",
+         "svcc": "[SVCC dataset path]",
+         "vctk": "[VCTK dataset path]"
+     },
+     // TODO: Fill in the output log path. The default value is "Amphion/ckpts/svc"
+     "log_dir": "ckpts/svc",
+     "preprocess": {
+         // TODO: Fill in the output data path. The default value is "Amphion/data"
+         "processed_dir": "data",
+         // Config for feature extraction
+         "extract_mel": true,
+         "extract_pitch": true,
+         "extract_energy": true,
+         "extract_whisper_feature": true,
+         "extract_contentvec_feature": true,
+         "extract_wenet_feature": false,
+         "whisper_batch_size": 30, // decrease it if your GPU is out of memory
+         "contentvec_batch_size": 1,
+         // Fill in the content-based pretrained model's path
+         "contentvec_file": "pretrained/contentvec/checkpoint_best_legacy_500.pt",
+         "wenet_model_path": "pretrained/wenet/20220506_u2pp_conformer_exp/final.pt",
+         "wenet_config": "pretrained/wenet/20220506_u2pp_conformer_exp/train.yaml",
+         "whisper_model": "medium",
+         "whisper_model_path": "pretrained/whisper/medium.pt",
+         // Config for feature usage
+         "use_mel": true,
+         "use_min_max_norm_mel": true,
+         "use_frame_pitch": true,
+         "use_frame_energy": true,
+         "use_spkid": true,
+         "use_whisper": true,
+         "use_contentvec": true,
+         "use_wenet": false,
+         "n_mel": 100,
+         "sample_rate": 24000
+     },
+     "model": {
+         "condition_encoder": {
+             // Config for feature usage
+             "use_whisper": true,
+             "use_contentvec": true,
+             "use_wenet": false,
+             "whisper_dim": 1024,
+             "contentvec_dim": 256,
+             "wenet_dim": 512,
+             "use_singer_encoder": false,
+             "pitch_min": 50,
+             "pitch_max": 1100
+         },
+         "transformer": {
+             // 'conformer' or 'transformer'
+             "type": "conformer",
+             "input_dim": 384,
+             "output_dim": 100,
+             "n_heads": 2,
+             "n_layers": 6,
+             "filter_channels": 512,
+             "dropout": 0.1
+         }
+     },
+     "train": {
+         "batch_size": 64,
+         "gradient_accumulation_step": 1,
+         "max_epoch": -1, // -1 means no limit
+         "save_checkpoint_stride": [50, 50],
+         "keep_last": [5, -1],
+         "run_eval": [false, true],
+         "adamw": {
+             "lr": 4.0e-4
+         },
+         "reducelronplateau": {
+             "factor": 0.8,
+             "patience": 10,
+             "min_lr": 1.0e-4
+         },
+         "dataloader": {
+             "num_worker": 8,
+             "pin_memory": true
+         },
+         "sampler": {
+             "holistic_shuffle": false,
+             "drop_last": true
+         }
+     }
+ }
Amphion/egs/svc/VitsSVC/README.md ADDED
@@ -0,0 +1,125 @@
+ # VITS for Singing Voice Conversion
+
+ This is an implementation of VITS as an acoustic model for end-to-end singing voice conversion. Adapted from [so-vits-svc](https://github.com/svc-develop-team/so-vits-svc), the SoftVC content encoder is used to extract content features from the source audio. These feature vectors are directly fed into VITS without the need for conversion to a text-based intermediate representation.
+
+ There are four stages in total:
+
+ 1. Data preparation
+ 2. Features extraction
+ 3. Training
+ 4. Inference/conversion
+
+ > **NOTE:** You need to run every command of this recipe in the `Amphion` root path:
+ > ```bash
+ > cd Amphion
+ > ```
+
+ ## 1. Data Preparation
+
+ ### Dataset Download
+
+ By default, we utilize five datasets for training: M4Singer, Opencpop, OpenSinger, SVCC, and VCTK. How to download them is detailed [here](../../datasets/README.md).
+
+ ### Configuration
+
+ Specify the dataset paths in `exp_config.json`. Note that you can change the `dataset` list to use your preferred datasets.
+
+ ```json
+     "dataset": [
+         "m4singer",
+         "opencpop",
+         "opensinger",
+         "svcc",
+         "vctk"
+     ],
+     "dataset_path": {
+         // TODO: Fill in your dataset path
+         "m4singer": "[M4Singer dataset path]",
+         "opencpop": "[Opencpop dataset path]",
+         "opensinger": "[OpenSinger dataset path]",
+         "svcc": "[SVCC dataset path]",
+         "vctk": "[VCTK dataset path]"
+     },
+ ```
+
+ ## 2. Features Extraction
+
+ ### Content-based Pretrained Models Download
+
+ By default, we utilize ContentVec and Whisper to extract content features. How to download them is detailed [here](../../../pretrained/README.md).
+
+ ### Configuration
+
+ Specify the dataset path and the output path for saving the processed data and the training model in `exp_config.json`:
+
+ ```json
+     // TODO: Fill in the output log path. The default value is "Amphion/ckpts/svc"
+     "log_dir": "ckpts/svc",
+     "preprocess": {
+         // TODO: Fill in the output data path. The default value is "Amphion/data"
+         "processed_dir": "data",
+         ...
+     },
+ ```
+
+ ### Run
+
+ Run the `run.sh` as the preprocess stage (set `--stage 1`):
+
+ ```bash
+ sh egs/svc/VitsSVC/run.sh --stage 1
+ ```
+
+ > **NOTE:** `CUDA_VISIBLE_DEVICES` is set as `"0"` by default. You can change it when running `run.sh` by specifying, e.g., `--gpu "1"`.
+
+ ## 3. Training
+
+ ### Configuration
+
+ We provide the default hyperparameters in `exp_config.json`. They work on a single NVIDIA 24GB GPU. You can adjust them based on your GPU machines.
+
+ ```json
+ "train": {
+     "batch_size": 32,
+     ...
+     "adamw": {
+         "lr": 2.0e-4
+     },
+     ...
+ }
+ ```
+
+ ### Run
+
+ Run the `run.sh` as the training stage (set `--stage 2`). Specify an experimental name to run the following command. The tensorboard logs and checkpoints will be saved in `Amphion/ckpts/svc/[YourExptName]`.
+
+ ```bash
+ sh egs/svc/VitsSVC/run.sh --stage 2 --name [YourExptName]
+ ```
+
+ > **NOTE:** `CUDA_VISIBLE_DEVICES` is set as `"0"` by default. You can change it when running `run.sh` by specifying, e.g., `--gpu "0,1,2,3"`.
+
+ ## 4. Inference/Conversion
+
+ ### Run
+
+ For inference/conversion, you need to specify the following configurations when running `run.sh`:
+
+ | Parameters | Description | Example |
+ | --------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+ | `--infer_expt_dir` | The experiment directory which contains `checkpoint` | `[Your path to save logs and checkpoints]/[YourExptName]` |
+ | `--infer_output_dir` | The output directory to save inferred audios. | `[Your path to save logs and checkpoints]/[YourExptName]/result` |
+ | `--infer_source_file` or `--infer_source_audio_dir` | The inference source (can be a JSON file or a directory). | The `infer_source_file` could be `[Your path to save processed data]/[YourDataset]/test.json`, and the `infer_source_audio_dir` is a folder which includes several audio files (*.wav, *.mp3 or *.flac). |
+ | `--infer_target_speaker` | The target speaker you want to convert into. You can refer to `[Your path to save logs and checkpoints]/[YourExptName]/singers.json` to choose a trained speaker. | For the Opencpop dataset, the speaker name would be `opencpop_female1`. |
+ | `--infer_key_shift` | How many semitones you want to transpose. | `"autoshift"` (by default), `3`, `-3`, etc. |
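+
+ To see which speakers a trained model supports, you can simply inspect that file (the path below assumes the default `log_dir`):
+
+ ```bash
+ cat ckpts/svc/[YourExptName]/singers.json
+ ```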
+
+ For example, if you want to make `opencpop_female1` sing the songs in `[Your Audios Folder]`, just run:
+
+ ```bash
+ sh egs/svc/VitsSVC/run.sh --stage 3 --gpu "0" \
+     --infer_expt_dir Amphion/ckpts/svc/[YourExptName] \
+     --infer_output_dir Amphion/ckpts/svc/[YourExptName]/result \
+     --infer_source_audio_dir [Your Audios Folder] \
+     --infer_target_speaker "opencpop_female1" \
+     --infer_key_shift "autoshift"
+ ```
Amphion/egs/tta/README.md ADDED
@@ -0,0 +1,19 @@
+ # Amphion Text-to-Audio (TTA) Recipe
+
+ ## Quick Start
+
+ We provide a **[beginner recipe](RECIPE.md)** to demonstrate how to train a cutting-edge TTA model. Specifically, it is designed as a latent diffusion model like [AudioLDM](https://arxiv.org/abs/2301.12503), [Make-an-Audio](https://arxiv.org/abs/2301.12661), and [AUDIT](https://arxiv.org/abs/2304.00830).
+
+ ## Supported Model Architectures
+
+ So far, Amphion has supported a latent-diffusion-based text-to-audio model:
+
+ <br>
+ <div align="center">
+ <img src="../../imgs/tta/DiffusionTTA.png" width="65%">
+ </div>
+ <br>
+
+ Similar to [AUDIT](https://arxiv.org/abs/2304.00830), we implement it with two-stage training:
+ 1. Training the VAE, which is called `AutoencoderKL` in Amphion.
+ 2. Training the conditional latent diffusion model, which is called `AudioLDM` in Amphion.
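+
+ For example, with the run scripts in this recipe (a sketch; the two scripts below ship with this commit, and the stage-2 config's `autoencoder_path` must point at the stage-1 checkpoint):
+
+ ```bash
+ cd Amphion
+ # Stage 1: train the VAE (AutoencoderKL)
+ sh egs/tta/autoencoderkl/run_train_latent_4_10_78.sh
+ # Stage 2: train the conditional latent diffusion model (AudioLDM)
+ sh egs/tta/audioldm/run_train_latent_4_10_78.sh
+ ```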
Amphion/egs/tta/audioldm/exp_config.json ADDED
@@ -0,0 +1,90 @@
+ {
+     "base_config": "egs/tta/audioldm/exp_config_base.json",
+     "dataset": [
+         "AudioCaps"
+     ],
+     "preprocess": {
+         // Specify the output root path to save the processed data
+         "processed_dir": "data",
+         // For example: "/home/TTADataset/processed_data"
+
+         // Features
+         "use_spkid": false,
+         "use_uv": false,
+         "use_frame_pitch": false,
+         "use_phone_pitch": false,
+         "use_frame_energy": false,
+         "use_phone_energy": false,
+         "use_mel": false,
+         "use_audio": false,
+         "use_label": false,
+         "use_one_hot": false,
+         // Features for text-to-audio
+         "use_caption": true,
+         "use_melspec": true,
+         "use_wav": false,
+         // Feature directories
+         "melspec_dir": "mel",
+         "wav_dir": "wav"
+     },
+     // Specify the output root path to save model ckpts and logs
+     "log_dir": "ckpts/tta",
+     // For example: "/home/TTADataset/processed_data/logs"
+
+     // Model
+     "model": {
+         "audioldm": {
+             "image_size": 32,
+             "in_channels": 4,
+             "out_channels": 4,
+             "model_channels": 256,
+             "attention_resolutions": [4, 2, 1],
+             "num_res_blocks": 2,
+             "channel_mult": [1, 2, 4],
+             "num_heads": 8,
+             "use_spatial_transformer": true,
+             "transformer_depth": 1,
+             "context_dim": 768,
+             "use_checkpoint": true,
+             "legacy": false
+         },
+         "autoencoderkl": {
+             "ch": 128,
+             "ch_mult": [1, 1, 2, 2, 4],
+             "num_res_blocks": 2,
+             "in_channels": 1,
+             "z_channels": 4,
+             "out_ch": 1,
+             "double_z": true
+         },
+         "noise_scheduler": {
+             "num_train_timesteps": 1000,
+             "beta_start": 0.00085,
+             "beta_end": 0.012,
+             "beta_schedule": "scaled_linear",
+             "clip_sample": false,
+             "steps_offset": 1,
+             "set_alpha_to_one": false,
+             "skip_prk_steps": true,
+             "prediction_type": "epsilon"
+         },
+         // Path to the trained stage-1 AutoencoderKL checkpoint
+         "autoencoder_path": "ckpts/tta/autoencoder_kl_debug/checkpoints/step-0445000_loss-0.3306.pt"
+     },
+
+     // Training
+     "train": {
+         "adam": {
+             "lr": 5.0e-5
+         },
+         "ddp": false,
+         "random_seed": 12345,
+         "batch_size": 12,
+         "epochs": 50000,
+         "max_steps": 1000000,
+         "total_training_steps": 800000,
+         "save_summary_steps": 1000,
+         "save_checkpoints_steps": 5000,
+         "valid_interval": 5000,
+         "keep_checkpoint_max": 100
+     }
+ }
Amphion/egs/tta/audioldm/run_train.sh ADDED
@@ -0,0 +1,26 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ ######## Build Experiment Environment ###########
+ exp_dir=$(cd `dirname $0`; pwd)
+ work_dir=$(dirname $(dirname $(dirname $exp_dir)))
+
+ export WORK_DIR=$work_dir
+ export PYTHONPATH=$work_dir
+ export PYTHONIOENCODING=UTF-8
+
+ ######## Set Experiment Configuration ###########
+ exp_config="$exp_dir/exp_config.json"
+ exp_name="audioldm_debug_latent_size_4_5_39"
+
+ num_workers=8
+ export CUDA_VISIBLE_DEVICES="0"
+
+ ######## Train Model ###########
+ python "${work_dir}"/bins/tta/train_tta.py \
+     --config=$exp_config \
+     --num_workers=$num_workers \
+     --exp_name=$exp_name \
+     --stdout_interval=25
Amphion/egs/tta/audioldm/run_train_latent_4_10_78.sh ADDED
@@ -0,0 +1,26 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ ######## Build Experiment Environment ###########
+ exp_dir=$(cd `dirname $0`; pwd)
+ work_dir=$(dirname $(dirname $(dirname $exp_dir)))
+
+ export WORK_DIR=$work_dir
+ export PYTHONPATH=$work_dir
+ export PYTHONIOENCODING=UTF-8
+
+ ######## Set Experiment Configuration ###########
+ exp_config="$exp_dir/exp_config_latent_4_10_78.json"
+ exp_name="audioldm_debug_latent_size_4_10_78"
+
+ num_workers=8
+ export CUDA_VISIBLE_DEVICES="0"
+
+ ######## Train Model ###########
+ python "${work_dir}"/bins/tta/train_tta.py \
+     --config=$exp_config \
+     --num_workers=$num_workers \
+     --exp_name=$exp_name \
+     --stdout_interval=25
Amphion/egs/tta/autoencoderkl/run_train_latent_4_10_78.sh ADDED
@@ -0,0 +1,26 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ ######## Build Experiment Environment ###########
+ exp_dir=$(cd `dirname $0`; pwd)
+ work_dir=$(dirname $(dirname $(dirname $exp_dir)))
+
+ export WORK_DIR=$work_dir
+ export PYTHONPATH=$work_dir
+ export PYTHONIOENCODING=UTF-8
+
+ ######## Set Experiment Configuration ###########
+ exp_config="$exp_dir/exp_config_latent_4_10_78.json"
+ exp_name="autoencoder_kl_debug_latent_size_4_10_78"
+
+ num_workers=8
+ export CUDA_VISIBLE_DEVICES="0"
+
+ ######## Train Model ###########
+ python "${work_dir}"/bins/tta/train_tta.py \
+     --config=$exp_config \
+     --num_workers=$num_workers \
+     --exp_name=$exp_name \
+     --stdout_interval=25
Amphion/egs/tts/FastSpeech2/prepare_mfa.sh ADDED
@@ -0,0 +1,29 @@
+ #!/bin/bash
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Note: this script is run from the Amphion root (egs/tts/FastSpeech2/run.sh invokes it
+ # automatically in stage 1) and expects a `pretrained` directory in the working directory.
+
+ # Navigate to the 'pretrained' directory
+ cd pretrained || { echo "Failed to change directory to 'pretrained'"; exit 1; }
+
+ # Create and navigate to the 'mfa' directory
+ mkdir -p mfa && cd mfa || { echo "Failed to create or change directory to 'mfa'"; exit 1; }
+
+ # Define the MFA file URL and the file name
+ mfa_url="https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner/releases/download/v1.1.0-beta.2/montreal-forced-aligner_linux.tar.gz"
+ mfa_file="montreal-forced-aligner_linux.tar.gz"
+
+ # Download MFA if it doesn't exist
+ if [ ! -f "$mfa_file" ]; then
+     wget "$mfa_url" || { echo "Failed to download MFA"; exit 1; }
+ fi
+
+ # Extract MFA
+ tar -zxvf "$mfa_file" || { echo "Failed to extract MFA"; exit 1; }
+
+ # Optionally, remove the tar.gz file after extraction
+ rm "$mfa_file"
+
+ echo "MFA setup completed successfully."
Amphion/egs/tts/FastSpeech2/run.sh ADDED
@@ -0,0 +1,155 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
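+ # Usage (a sketch):
+ #   sh egs/tts/FastSpeech2/run.sh --stage 1                          # preprocess (sets up MFA if needed)
+ #   sh egs/tts/FastSpeech2/run.sh --stage 2 --name [YourExptName]    # train
+ #   sh egs/tts/FastSpeech2/run.sh --stage 3 --infer_expt_dir [dir] \
+ #     --vocoder_dir [vocoder] --infer_mode "single" --infer_text "..."  # inference
+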
+ ######## Build Experiment Environment ###########
+ exp_dir=$(cd `dirname $0`; pwd)
+ work_dir=$(dirname $(dirname $(dirname $exp_dir)))
+
+ export WORK_DIR=$work_dir
+ export PYTHONPATH=$work_dir
+ export PYTHONIOENCODING=UTF-8
+
+ cd $work_dir/modules/monotonic_align
+ mkdir -p monotonic_align
+ python setup.py build_ext --inplace
+ cd $work_dir
+
+ mfa_dir=$work_dir/pretrained/mfa
+ echo $mfa_dir
+
+ ######## Parse the Given Parameters from the Command ###########
+ # options=$(getopt -o c:n:s --long gpu:,config:,infer_expt_dir:,infer_output_dir:,infer_source_file:,infer_source_audio_dir:,infer_target_speaker:,infer_key_shift:,infer_vocoder_dir:,name:,stage: -- "$@")
+ options=$(getopt -o c:n:s --long gpu:,config:,infer_expt_dir:,infer_output_dir:,infer_mode:,infer_dataset:,infer_testing_set:,infer_text:,name:,stage:,vocoder_dir: -- "$@")
+ eval set -- "$options"
+
+ while true; do
+   case $1 in
+   # Experimental Configuration File
+   -c | --config) shift; exp_config=$1 ; shift ;;
+   # Experimental Name
+   -n | --name) shift; exp_name=$1 ; shift ;;
+   # Running Stage
+   -s | --stage) shift; running_stage=$1 ; shift ;;
+   # Visible GPU machines. The default value is "0".
+   --gpu) shift; gpu=$1 ; shift ;;
+
+   # [Only for Inference] The experiment dir. The value is like "[Your path to save logs and checkpoints]/[YourExptName]"
+   --infer_expt_dir) shift; infer_expt_dir=$1 ; shift ;;
+   # [Only for Inference] The output dir to save inferred audios. Its default value is "$infer_expt_dir/result"
+   --infer_output_dir) shift; infer_output_dir=$1 ; shift ;;
+   # [Only for Inference] The inference mode. It can be "batch" to generate speech by batch, or "single" to generate a single clip of speech.
+   --infer_mode) shift; infer_mode=$1 ; shift ;;
+   # [Only for Inference] The inference dataset. It is only used when the inference mode is "batch".
+   --infer_dataset) shift; infer_dataset=$1 ; shift ;;
+   # [Only for Inference] The inference testing set. It is only used when the inference mode is "batch". It can be the "test" set split from the dataset, or the "golden_test" set carefully selected from the testing set.
+   --infer_testing_set) shift; infer_testing_set=$1 ; shift ;;
+   # [Only for Inference] The text to be synthesized. It is only used when the inference mode is "single".
+   --infer_text) shift; infer_text=$1 ; shift ;;
+   # [Only for Inference] The directory of the vocoder.
+   --vocoder_dir) shift; vocoder_dir=$1 ; shift ;;
+
+   --) shift ; break ;;
+   *) echo "Invalid option: $1"; exit 1 ;;
+   esac
+ done
+
+
+ ### Value check ###
+ if [ -z "$running_stage" ]; then
+   echo "[Error] Please specify the running stage"
+   exit 1
+ fi
+
+ if [ -z "$exp_config" ]; then
+   exp_config="${exp_dir}"/exp_config.json
+ fi
+ echo "Experimental Configuration File: $exp_config"
+
+ if [ -z "$gpu" ]; then
+   gpu="0"
+ fi
+
+ ######## Features Extraction ###########
+ if [ $running_stage -eq 1 ]; then
+   if [ ! -d "$mfa_dir/montreal-forced-aligner" ]; then
+     bash ${exp_dir}/prepare_mfa.sh
+   fi
+   CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/bins/tts/preprocess.py \
+     --config=$exp_config \
+     --num_workers=4 \
+     --prepare_alignment=true
+ fi
+
+ ######## Training ###########
+ if [ $running_stage -eq 2 ]; then
+   if [ -z "$exp_name" ]; then
+     echo "[Error] Please specify the experiments name"
+     exit 1
+   fi
+   echo "Experimental Name: $exp_name"
+
+   CUDA_VISIBLE_DEVICES=$gpu accelerate launch "${work_dir}"/bins/tts/train.py \
+     --config $exp_config \
+     --exp_name $exp_name \
+     --log_level debug
+ fi
+
+ ######## Inference ###########
+ if [ $running_stage -eq 3 ]; then
+   if [ -z "$infer_expt_dir" ]; then
+     echo "[Error] Please specify the experiment directory. The value is like [Your path to save logs and checkpoints]/[YourExptName]"
+     exit 1
+   fi
+
+   if [ -z "$infer_output_dir" ]; then
+     infer_output_dir="$infer_expt_dir/result"
+   fi
+
+   if [ -z "$vocoder_dir" ]; then
+     echo "[Error] Please specify the vocoder directory to reconstruct waveform from mel spectrogram."
+     exit 1
+   fi
+
+   if [ -z "$infer_mode" ]; then
+     echo "[Error] Please specify the inference mode, e.g., 'batch', 'single'"
+     exit 1
+   fi
+
+   if [ "$infer_mode" = "batch" ] && [ -z "$infer_dataset" ]; then
+     echo "[Error] Please specify the dataset used in inference when the inference mode is batch"
+     exit 1
+   fi
+
+   if [ "$infer_mode" = "batch" ] && [ -z "$infer_testing_set" ]; then
+     echo "[Error] Please specify the testing set used in inference when the inference mode is batch"
+     exit 1
+   fi
+
+   if [ "$infer_mode" = "single" ] && [ -z "$infer_text" ]; then
+     echo "[Error] Please specify the text to be synthesized when the inference mode is single"
+     exit 1
+   fi
+
+   if [ "$infer_mode" = "single" ]; then
+     echo "Text: ${infer_text}"
+     infer_dataset=None
+     infer_testing_set=None
+   elif [ "$infer_mode" = "batch" ]; then
+     infer_text=''
+   fi
+
+   CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/tts/inference.py \
+     --config $exp_config \
+     --acoustics_dir $infer_expt_dir \
+     --output_dir $infer_output_dir \
+     --mode $infer_mode \
+     --dataset $infer_dataset \
+     --testing_set $infer_testing_set \
+     --text "$infer_text" \
+     --log_level debug \
+     --vocoder_dir $vocoder_dir
+
+ fi
Amphion/egs/tts/NaturalSpeech2/run_inference.sh ADDED
@@ -0,0 +1,49 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
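+ # Usage (a sketch; the checkpoint and reference-audio paths below are defaults and may need editing):
+ #   sh egs/tts/NaturalSpeech2/run_inference.sh --text "Your text to synthesize."
+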
+ ######## Build Experiment Environment ###########
+ exp_dir=$(cd `dirname $0`; pwd)
+ work_dir=$(dirname $(dirname $(dirname $exp_dir)))
+
+ export WORK_DIR=$work_dir
+ export PYTHONPATH=$work_dir
+ export PYTHONIOENCODING=UTF-8
+
+ ######## Set Experiment Configuration ###########
+ exp_config="$exp_dir/exp_config.json"
+ exp_name="ns2_libritts"
+ ref_audio="$work_dir/egs/tts/NaturalSpeech2/prompt_example/ref_audio.wav"
+ checkpoint_path="$work_dir/ckpts/tts/naturalspeech2_libritts/checkpoint/epoch-0089_step-0512912_loss-6.367693"
+ output_dir="$work_dir/output"
+ mode="single"
+
+ export CUDA_VISIBLE_DEVICES="0"
+
+ ######## Parse Command Line Arguments ###########
+ while [[ $# -gt 0 ]]; do
+   key="$1"
+
+   case $key in
+     --text)
+       text="$2"
+       shift # past argument
+       shift # past value
+       ;;
+     *) # unknown option
+       shift # past argument
+       ;;
+   esac
+ done
+
+ ######## Run Inference ###########
+ python "${work_dir}"/bins/tts/inference.py \
+     --config=$exp_config \
+     --text="$text" \
+     --mode=$mode \
+     --checkpoint_path=$checkpoint_path \
+     --ref_audio=$ref_audio \
+     --output_dir=$output_dir
Amphion/egs/tts/VALLE/README.md ADDED
@@ -0,0 +1,207 @@
+ # VALL-E Recipe
+
+ In this recipe, we will show how to train [VALL-E](https://arxiv.org/abs/2301.02111) using Amphion's infrastructure. VALL-E is a zero-shot TTS architecture that uses a neural codec language model with discrete codes.
+
+ There are four stages in total:
+
+ 1. Data preparation
+ 2. Features extraction
+ 3. Training
+ 4. Inference
+
+ > **NOTE:** You need to run every command of this recipe in the `Amphion` root path:
+ > ```bash
+ > cd Amphion
+ > ```
+
+ ## 1. Data Preparation
+
+ ### Dataset Download
+
+ You can use commonly used TTS datasets to train the VALL-E model, e.g., LibriTTS. We strongly recommend using LibriTTS to train VALL-E for the first time. How to download the dataset is detailed [here](../../datasets/README.md).
+
+ ### Configuration
+
+ After downloading the dataset, you can set the dataset paths in `exp_config.json`. Note that you can change the `dataset` list to use your preferred datasets.
+
+ ```json
+     "dataset": [
+         "libritts",
+     ],
+     "dataset_path": {
+         // TODO: Fill in your dataset path
+         "libritts": "[LibriTTS dataset path]",
+     },
+ ```
+
+ ## 2. Features Extraction
+
+ ### Configuration
+
+ Specify the `processed_dir` and the `log_dir` for saving the processed data and the checkpoints in `exp_config.json`:
+
+ ```json
+     // TODO: Fill in the output log path. The default value is "Amphion/ckpts/tts"
+     "log_dir": "ckpts/tts",
+     "preprocess": {
+         // TODO: Fill in the output data path. The default value is "Amphion/data"
+         "processed_dir": "data",
+         ...
+     },
+ ```
+
+ ### Run
+
+ Run the `run.sh` as the preprocess stage (set `--stage 1`):
+
+ ```bash
+ sh egs/tts/VALLE/run.sh --stage 1
+ ```
+
+ > **NOTE:** `CUDA_VISIBLE_DEVICES` is set as `"0"` by default. You can change it when running `run.sh` by specifying, e.g., `--gpu "1"`.
+
+ ## 3. Training
+
+ ### Configuration
+
+ We provide the default hyperparameters in the `exp_config.json`. They can work on a single NVIDIA 24GB GPU. You can adjust them based on your GPU machines.
+
+ ```json
+ "train": {
+     "batch_size": 4,
+ }
+ ```
+
+ ### Train From Scratch
+
+ Run the `run.sh` as the training stage (set `--stage 2`). Specify an experimental name to run the following command. The tensorboard logs and checkpoints will be saved in `Amphion/ckpts/tts/[YourExptName]`.
+
+ Specifically, VALL-E needs to train an autoregressive (AR) model and then a non-autoregressive (NAR) model. So, you can set `--model_train_stage 1` to train the AR model, and set `--model_train_stage 2` to train the NAR model, where `--ar_model_ckpt_dir` should be set as the checkpoint path to the trained AR model.
+
+ To train an AR model, just run:
+
+ ```bash
+ sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 1 --name [YourExptName]
+ ```
+
+ To train a NAR model, just run:
+
+ ```bash
+ sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 2 --ar_model_ckpt_dir [ARModelPath] --name [YourExptName]
+ ```
+ <!-- > **NOTE:** To train a NAR model, `--checkpoint_path` should be set as the checkpoint path to the trained AR model. -->
+
+ ### Train From Existing Source
+
+ We support training from existing sources for various purposes. You can resume training the model from a checkpoint or fine-tune a model from another checkpoint.
+
+ By setting `--resume true`, the training will resume from the **latest checkpoint** of the current `[YourExptName]` by default. For example, if you want to resume training from the latest checkpoint in `Amphion/ckpts/tts/[YourExptName]/checkpoint`:
+
+ To resume training the AR model, just run:
+
+ ```bash
+ sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 1 --name [YourExptName] \
+     --resume true
+ ```
+
+ To resume training the NAR model, just run:
+
+ ```bash
+ sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 2 --ar_model_ckpt_dir [ARModelPath] --name [YourExptName] \
+     --resume true
+ ```
+
+ You can also choose a **specific checkpoint** for retraining with the `--resume_from_ckpt_path` argument. For example, if you want to resume training from the checkpoint `Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificCheckpoint]`:
+
+ To resume training the AR model, just run:
+
+ ```bash
+ sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 1 --name [YourExptName] \
+     --resume true \
+     --resume_from_ckpt_path "Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificARCheckpoint]"
+ ```
+
+ To resume training the NAR model, just run:
+
+ ```bash
+ sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 2 --ar_model_ckpt_dir [ARModelPath] --name [YourExptName] \
+     --resume true \
+     --resume_from_ckpt_path "Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificNARCheckpoint]"
+ ```
+
+ If you want to **fine-tune from another checkpoint**, just use `--resume_type` and set it to `"finetune"`. For example, if you want to fine-tune the model from the checkpoint `Amphion/ckpts/tts/[AnotherExperiment]/checkpoint/[SpecificCheckpoint]`:
+
+ To fine-tune the AR model, just run:
+
+ ```bash
+ sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 1 --name [YourExptName] \
+     --resume true \
+     --resume_from_ckpt_path "Amphion/ckpts/tts/[AnotherExperiment]/checkpoint/[SpecificARCheckpoint]" \
+     --resume_type "finetune"
+ ```
+
+ To fine-tune the NAR model, just run:
+
+ ```bash
+ sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 2 --ar_model_ckpt_dir [ARModelPath] --name [YourExptName] \
+     --resume true \
+     --resume_from_ckpt_path "Amphion/ckpts/tts/[AnotherExperiment]/checkpoint/[SpecificNARCheckpoint]" \
+     --resume_type "finetune"
+ ```
+
+ > **NOTE:** The `--resume_type` is set as `"resume"` by default. It's not necessary to specify it when resuming training.
+ >
+ > The difference between `"resume"` and `"finetune"` is that `"finetune"` will **only** load the pretrained model weights from the checkpoint, while `"resume"` will load all the training states (including optimizer, scheduler, etc.) from the checkpoint.
+
+ > **NOTE:** `CUDA_VISIBLE_DEVICES` is set as `"0"` by default. You can change it when running `run.sh` by specifying, e.g., `--gpu "0,1,2,3"`.
+
+ ## 4. Inference
+
+ ### Configuration
+
+ For inference, you need to specify the following configurations when running `run.sh`:
+
+ | Parameters | Description | Example |
+ | ---------------------- | ---------------------------------------------------------------------- | -------------------------------------------------------------------------------------------- |
+ | `--infer_expt_dir` | The experiment directory of the NAR model which contains `checkpoint`. | `Amphion/ckpts/tts/[YourExptName]` |
+ | `--infer_output_dir` | The output directory to save inferred audios. | `Amphion/ckpts/tts/[YourExptName]/result` |
+ | `--infer_mode` | The inference mode, e.g., "`single`", "`batch`". | "`single`" to generate a clip of speech, "`batch`" to generate a batch of speech at a time. |
+ | `--infer_text` | The text to be synthesized. | "`This is a clip of generated speech with the given text from a TTS model.`" |
+ | `--infer_text_prompt` | The text prompt for inference. | The text prompt should be aligned with the audio prompt. |
+ | `--infer_audio_prompt` | The audio prompt for inference. | The audio prompt should be aligned with the text prompt. |
+ | `--test_list_file` | The test list file used for batch inference. | The format of the test list file is `text\|text_prompt\|audio_prompt`. |
+
+ ### Run
+
+ For example, if you want to generate a single clip of speech, just run:
+
+ ```bash
+ sh egs/tts/VALLE/run.sh --stage 3 --gpu "0" \
+     --infer_expt_dir Amphion/ckpts/tts/[YourExptName] \
+     --infer_output_dir Amphion/ckpts/tts/[YourExptName]/result \
+     --infer_mode "single" \
+     --infer_text "This is a clip of generated speech with the given text from a TTS model." \
+     --infer_text_prompt "But even the unsuccessful dramatist has his moments." \
+     --infer_audio_prompt egs/tts/VALLE/prompt_examples/7176_92135_000004_000000.wav
+ ```
+
+ We have released pre-trained VALL-E models, so you can download a pre-trained model and then generate speech following the above inference instructions. Specifically,
+ 1. The pre-trained VALL-E trained on [LibriTTS](https://github.com/open-mmlab/Amphion/tree/main/egs/datasets#libritts) can be downloaded [here](https://huggingface.co/amphion/valle-libritts).
+ 2. The pre-trained VALL-E trained on part of [Libri-light](https://ai.meta.com/tools/libri-light/) (about 6k hours) can be downloaded [here](https://huggingface.co/amphion/valle_librilight_6k).
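+
+ One way to fetch a pre-trained checkpoint (a sketch, assuming `git-lfs` is installed; the target directory is just an example):
+
+ ```bash
+ git lfs install
+ git clone https://huggingface.co/amphion/valle-libritts ckpts/tts/valle_libritts
+ ```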
+
+ If you use VALL-E in your work, please cite:
+
+ ```bibtex
+ @article{wang2023neural,
+     title={Neural codec language models are zero-shot text to speech synthesizers},
+     author={Wang, Chengyi and Chen, Sanyuan and Wu, Yu and Zhang, Ziqiang and Zhou, Long and Liu, Shujie and Chen, Zhuo and Liu, Yanqing and Wang, Huaming and Li, Jinyu and others},
+     journal={arXiv preprint arXiv:2301.02111},
+     year={2023}
+ }
+ ```
Amphion/egs/tts/VALLE/prompt_examples/5142_33396_000002_000004.wav ADDED
Binary file (144 kB).
Amphion/egs/tts/VALLE/prompt_examples/7176_92135_000004_000000.normalized.txt ADDED
@@ -0,0 +1 @@
+ But even the unsuccessful dramatist has his moments.
Amphion/egs/tts/VITS/exp_config.json ADDED
@@ -0,0 +1,34 @@
+ {
+     "base_config": "config/vits.json",
+     "model_type": "VITS",
+     "dataset": [
+         "LJSpeech",
+         //"hifitts"
+     ],
+     "dataset_path": {
+         // TODO: Fill in your dataset path
+         "LJSpeech": "[LJSpeech dataset path]",
+         //"hifitts": "[Hi-Fi TTS dataset path]"
+     },
+     // TODO: Fill in the output log path. The default value is "Amphion/ckpts/tts"
+     "log_dir": "ckpts/tts",
+     "preprocess": {
+         //"extract_audio": true,
+         "use_phone": true,
+         // linguistic features
+         "extract_phone": true,
+         "phone_extractor": "espeak", // "espeak, pypinyin, pypinyin_initials_finals, lexicon (only for language=en-us right now)"
+         // TODO: Fill in the output data path. The default value is "Amphion/data"
+         "processed_dir": "data",
+         "sample_rate": 22050, // target sampling rate
+         "valid_file": "valid.json", // validation set
+         //"use_spkid": true // use speaker ID to train a multi-speaker TTS model
+     },
+     "model": {
+         //"n_speakers": 10 // number of speakers, greater than or equal to the number of speakers in the dataset(s) used. The default value is 0 if not specified.
+     },
+     "train": {
+         "batch_size": 16,
+         //"multi_speaker_training": true
+     }
+ }
Amphion/egs/vocoder/gan/bigvgan_large/exp_config.json ADDED
@@ -0,0 +1,70 @@
+ {
+     "base_config": "egs/vocoder/gan/exp_config_base.json",
+     "preprocess": {
+         // Acoustic features
+         "extract_mel": true,
+         "extract_audio": true,
+
+         // Features used for model training
+         "use_mel": true,
+         "use_audio": true
+     },
+     "model": {
+         "generator": "bigvgan",
+         "bigvgan": {
+             "resblock": "1",
+             "activation": "snakebeta",
+             "snake_logscale": true,
+             "upsample_rates": [4, 4, 2, 2, 2, 2],
+             "upsample_kernel_sizes": [8, 8, 4, 4, 4, 4],
+             "upsample_initial_channel": 1536,
+             "resblock_kernel_sizes": [3, 7, 11],
+             "resblock_dilation_sizes": [
+                 [1, 3, 5],
+                 [1, 3, 5],
+                 [1, 3, 5]
+             ]
+         }
+     },
+     "train": {
+         "criterions": [
+             "feature",
+             "discriminator",
+             "generator",
+             "mel"
+         ]
+     },
+     "inference": {
+         "batch_size": 1
+     }
+ }
Amphion/egs/vocoder/gan/hifigan/exp_config.json ADDED
@@ -0,0 +1,59 @@
+ {
+     "base_config": "egs/vocoder/gan/exp_config_base.json",
+     "preprocess": {
+         // Acoustic features
+         "extract_mel": true,
+         "extract_audio": true,
+
+         // Features used for model training
+         "use_mel": true,
+         "use_audio": true
+     },
+     "model": {
+         "generator": "hifigan",
+         "hifigan": {
+             "resblock": "2",
+             "upsample_rates": [8, 8, 4],
+             "upsample_kernel_sizes": [16, 16, 8],
+             "upsample_initial_channel": 256,
+             "resblock_kernel_sizes": [3, 5, 7],
+             "resblock_dilation_sizes": [
+                 [1, 2],
+                 [2, 6],
+                 [3, 12]
+             ]
+         }
+     },
+     "train": {
+         "criterions": [
+             "feature",
+             "discriminator",
+             "generator",
+             "mel"
+         ]
+     },
+     "inference": {
+         "batch_size": 1
+     }
+ }
Amphion/egs/vocoder/gan/hifigan/run.sh ADDED
@@ -0,0 +1,141 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
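+ # Usage (a sketch):
+ #   sh egs/vocoder/gan/hifigan/run.sh --stage 1                         # feature extraction
+ #   sh egs/vocoder/gan/hifigan/run.sh --stage 2 --name [YourExptName]   # training
+ #   sh egs/vocoder/gan/hifigan/run.sh --stage 3 --infer_expt_dir [dir] \
+ #     --infer_mode "infer_from_dataset"                                 # inference
+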
6
+ ######## Build Experiment Environment ###########
7
+ exp_dir=$(cd `dirname $0`; pwd)
8
+ work_dir=$(dirname $(dirname $(dirname $(dirname $exp_dir))))
9
+
10
+ export WORK_DIR=$work_dir
11
+ export PYTHONPATH=$work_dir
12
+ export PYTHONIOENCODING=UTF-8
13
+
14
+ ######## Parse the Given Parameters from the Commond ###########
15
+ options=$(getopt -o c:n:s --long gpu:,config:,name:,stage:,checkpoint:,resume_type:,main_process_port:,infer_mode:,infer_datasets:,infer_feature_dir:,infer_audio_dir:,infer_expt_dir:,infer_output_dir: -- "$@")
16
+ eval set -- "$options"
17
+
18
+ while true; do
19
+ case $1 in
20
+ # Experimental Configuration File
21
+ -c | --config) shift; exp_config=$1 ; shift ;;
22
+ # Experimental Name
23
+ -n | --name) shift; exp_name=$1 ; shift ;;
24
+ # Running Stage
25
+ -s | --stage) shift; running_stage=$1 ; shift ;;
26
+ # Visible GPU devices. The default value is "0".
27
+ --gpu) shift; gpu=$1 ; shift ;;
28
+
29
+ # [Only for Training] The specific checkpoint path that you want to resume from.
30
+ --checkpoint) shift; checkpoint=$1 ; shift ;;
31
+ # [Only for Training] `resume` for loading all the things (including model weights, optimizer, scheduler, and random states). `finetune` for loading only the model weights.
32
+ --resume_type) shift; resume_type=$1 ; shift ;;
33
+ # [Only for Training] `main_process_port` for multi-GPU training
34
+ --main_process_port) shift; main_process_port=$1 ; shift ;;
35
+
36
+ # [Only for Inference] The inference mode
37
+ --infer_mode) shift; infer_mode=$1 ; shift ;;
38
+ # [Only for Inference] The datasets to run inference on
39
+ --infer_datasets) shift; infer_datasets=$1 ; shift ;;
40
+ # [Only for Inference] The feature dir for inference
41
+ --infer_feature_dir) shift; infer_feature_dir=$1 ; shift ;;
42
+ # [Only for Inference] The audio dir for inference
43
+ --infer_audio_dir) shift; infer_audio_dir=$1 ; shift ;;
44
+ # [Only for Inference] The experiment dir. The value is like "[Your path to save logs and checkpoints]/[YourExptName]"
45
+ --infer_expt_dir) shift; infer_expt_dir=$1 ; shift ;;
46
+ # [Only for Inference] The output dir to save inferred audios. Its default value is "$expt_dir/result"
47
+ --infer_output_dir) shift; infer_output_dir=$1 ; shift ;;
48
+
49
+ --) shift ; break ;;
50
+ *) echo "Invalid option: $1" exit 1 ;;
51
+ esac
52
+ done
53
+
54
+
55
+ ### Value check ###
56
+ if [ -z "$running_stage" ]; then
57
+ echo "[Error] Please specify the running stage"
58
+ exit 1
59
+ fi
60
+
61
+ if [ -z "$exp_config" ]; then
62
+ exp_config="${exp_dir}"/exp_config.json
63
+ fi
64
+ echo "Exprimental Configuration File: $exp_config"
65
+
66
+ if [ -z "$gpu" ]; then
67
+ gpu="0"
68
+ fi
69
+
70
+ if [ -z "$main_process_port" ]; then
71
+ main_process_port=29500
72
+ fi
73
+ echo "Main Process Port: $main_process_port"
74
+
75
+ ######## Features Extraction ###########
76
+ if [ $running_stage -eq 1 ]; then
77
+ CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/bins/vocoder/preprocess.py \
78
+ --config $exp_config \
79
+ --num_workers 8
80
+ fi
81
+
82
+ ######## Training ###########
83
+ if [ $running_stage -eq 2 ]; then
84
+ if [ -z "$exp_name" ]; then
85
+ echo "[Error] Please specify the experiments name"
86
+ exit 1
87
+ fi
88
+ echo "Exprimental Name: $exp_name"
89
+
90
+ CUDA_VISIBLE_DEVICES=$gpu accelerate launch \
91
+ --main_process_port "$main_process_port" \
92
+ "${work_dir}"/bins/vocoder/train.py \
93
+ --config "$exp_config" \
94
+ --exp_name "$exp_name" \
95
+ --log_level info \
96
+ --checkpoint "$checkpoint" \
97
+ --resume_type "$resume_type"
98
+ fi
99
+
100
+ ######## Inference/Conversion ###########
101
+ if [ $running_stage -eq 3 ]; then
102
+ if [ -z "$infer_expt_dir" ]; then
103
+ echo "[Error] Please specify the experimental directionary. The value is like [Your path to save logs and checkpoints]/[YourExptName]"
104
+ exit 1
105
+ fi
106
+
107
+ if [ -z "$infer_output_dir" ]; then
108
+ infer_output_dir="$infer_expt_dir/result"
109
+ fi
110
+
111
+ if [ $infer_mode = "infer_from_dataset" ]; then
112
+ CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/vocoder/inference.py \
113
+ --config $exp_config \
114
+ --infer_mode $infer_mode \
115
+ --infer_datasets $infer_datasets \
116
+ --vocoder_dir $infer_expt_dir \
117
+ --output_dir $infer_output_dir \
118
+ --log_level debug
119
+ fi
120
+
121
+ if [ $infer_mode = "infer_from_feature" ]; then
122
+ CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/vocoder/inference.py \
123
+ --config $exp_config \
124
+ --infer_mode $infer_mode \
125
+ --feature_folder $infer_feature_dir \
126
+ --vocoder_dir $infer_expt_dir \
127
+ --output_dir $infer_output_dir \
128
+ --log_level debug
129
+ fi
130
+
131
+ if [ $infer_mode = "infer_from_audio" ]; then
132
+ CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/vocoder/inference.py \
133
+ --config $exp_config \
134
+ --infer_mode $infer_mode \
135
+ --audio_folder $infer_audio_dir \
136
+ --vocoder_dir $infer_expt_dir \
137
+ --output_dir $infer_output_dir \
138
+ --log_level debug
139
+ fi
140
+
141
+ fi
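
The script above dispatches on `--stage`: 1 extracts features, 2 trains, and 3 runs inference in one of three modes. Below is a hedged Python wrapper for driving those stages; the experiment name and paths are placeholders, and the wrapper itself is not part of Amphion:

```python
# Illustrative driver for the three run.sh stages above; not shipped with
# Amphion. All names and paths are placeholders.
import subprocess

RUN_SH = "egs/vocoder/gan/hifigan/run.sh"  # path taken from the diff header above

def run_stage(stage, extra=None):
    cmd = ["bash", RUN_SH, "--stage", str(stage)] + (extra or [])
    subprocess.run(cmd, check=True)

run_stage(1)  # feature extraction
run_stage(2, ["--name", "my_hifigan_exp", "--gpu", "0"])  # training
run_stage(3, [  # inference from raw audio
    "--infer_mode", "infer_from_audio",
    "--infer_audio_dir", "/path/to/wavs",
    "--infer_expt_dir", "ckpts/vocoder/my_hifigan_exp",
])
```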
Amphion/egs/vocoder/gan/nsfhifigan/run.sh ADDED
@@ -0,0 +1,141 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ ######## Build Experiment Environment ###########
7
+ exp_dir=$(cd `dirname $0`; pwd)
8
+ work_dir=$(dirname $(dirname $(dirname $(dirname $exp_dir))))
9
+
10
+ export WORK_DIR=$work_dir
11
+ export PYTHONPATH=$work_dir
12
+ export PYTHONIOENCODING=UTF-8
13
+
14
+ ######## Parse the Given Parameters from the Command ###########
15
+ options=$(getopt -o c:n:s: --long gpu:,config:,name:,stage:,checkpoint:,resume_type:,main_process_port:,infer_mode:,infer_datasets:,infer_feature_dir:,infer_audio_dir:,infer_expt_dir:,infer_output_dir: -- "$@")
16
+ eval set -- "$options"
17
+
18
+ while true; do
19
+ case $1 in
20
+ # Experimental Configuration File
21
+ -c | --config) shift; exp_config=$1 ; shift ;;
22
+ # Experimental Name
23
+ -n | --name) shift; exp_name=$1 ; shift ;;
24
+ # Running Stage
25
+ -s | --stage) shift; running_stage=$1 ; shift ;;
26
+ # Visible GPU devices. The default value is "0".
27
+ --gpu) shift; gpu=$1 ; shift ;;
28
+
29
+ # [Only for Training] The specific checkpoint path that you want to resume from.
30
+ --checkpoint) shift; checkpoint=$1 ; shift ;;
31
+ # [Only for Training] `resume` for loading all the things (including model weights, optimizer, scheduler, and random states). `finetune` for loading only the model weights.
32
+ --resume_type) shift; resume_type=$1 ; shift ;;
33
+ # [Only for Training] `main_process_port` for multi-GPU training
34
+ --main_process_port) shift; main_process_port=$1 ; shift ;;
35
+
36
+ # [Only for Inference] The inference mode
37
+ --infer_mode) shift; infer_mode=$1 ; shift ;;
38
+ # [Only for Inference] The datasets to run inference on
39
+ --infer_datasets) shift; infer_datasets=$1 ; shift ;;
40
+ # [Only for Inference] The feature dir for inference
41
+ --infer_feature_dir) shift; infer_feature_dir=$1 ; shift ;;
42
+ # [Only for Inference] The audio dir for inference
43
+ --infer_audio_dir) shift; infer_audio_dir=$1 ; shift ;;
44
+ # [Only for Inference] The experiment dir. The value is like "[Your path to save logs and checkpoints]/[YourExptName]"
45
+ --infer_expt_dir) shift; infer_expt_dir=$1 ; shift ;;
46
+ # [Only for Inference] The output dir to save inferred audios. Its default value is "$expt_dir/result"
47
+ --infer_output_dir) shift; infer_output_dir=$1 ; shift ;;
48
+
49
+ --) shift ; break ;;
50
+ *) echo "Invalid option: $1" exit 1 ;;
51
+ esac
52
+ done
53
+
54
+
55
+ ### Value check ###
56
+ if [ -z "$running_stage" ]; then
57
+ echo "[Error] Please specify the running stage"
58
+ exit 1
59
+ fi
60
+
61
+ if [ -z "$exp_config" ]; then
62
+ exp_config="${exp_dir}"/exp_config.json
63
+ fi
64
+ echo "Exprimental Configuration File: $exp_config"
65
+
66
+ if [ -z "$gpu" ]; then
67
+ gpu="0"
68
+ fi
69
+
70
+ if [ -z "$main_process_port" ]; then
71
+ main_process_port=29500
72
+ fi
73
+ echo "Main Process Port: $main_process_port"
74
+
75
+ ######## Features Extraction ###########
76
+ if [ $running_stage -eq 1 ]; then
77
+ CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/bins/vocoder/preprocess.py \
78
+ --config $exp_config \
79
+ --num_workers 8
80
+ fi
81
+
82
+ ######## Training ###########
83
+ if [ $running_stage -eq 2 ]; then
84
+ if [ -z "$exp_name" ]; then
85
+ echo "[Error] Please specify the experiments name"
86
+ exit 1
87
+ fi
88
+ echo "Exprimental Name: $exp_name"
89
+
90
+ CUDA_VISIBLE_DEVICES=$gpu accelerate launch \
91
+ --main_process_port "$main_process_port" \
92
+ "${work_dir}"/bins/vocoder/train.py \
93
+ --config "$exp_config" \
94
+ --exp_name "$exp_name" \
95
+ --log_level info \
96
+ --checkpoint "$checkpoint" \
97
+ --resume_type "$resume_type"
98
+ fi
99
+
100
+ ######## Inference/Conversion ###########
101
+ if [ $running_stage -eq 3 ]; then
102
+ if [ -z "$infer_expt_dir" ]; then
103
+ echo "[Error] Please specify the experimental directionary. The value is like [Your path to save logs and checkpoints]/[YourExptName]"
104
+ exit 1
105
+ fi
106
+
107
+ if [ -z "$infer_output_dir" ]; then
108
+ infer_output_dir="$infer_expt_dir/result"
109
+ fi
110
+
111
+ if [ $infer_mode = "infer_from_dataset" ]; then
112
+ CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/vocoder/inference.py \
113
+ --config $exp_config \
114
+ --infer_mode $infer_mode \
115
+ --infer_datasets $infer_datasets \
116
+ --vocoder_dir $infer_expt_dir \
117
+ --output_dir $infer_output_dir \
118
+ --log_level debug
119
+ fi
120
+
121
+ if [ $infer_mode = "infer_from_feature" ]; then
122
+ CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/vocoder/inference.py \
123
+ --config $exp_config \
124
+ --infer_mode $infer_mode \
125
+ --feature_folder $infer_feature_dir \
126
+ --vocoder_dir $infer_expt_dir \
127
+ --output_dir $infer_output_dir \
128
+ --log_level debug
129
+ fi
130
+
131
+ if [ $infer_mode = "infer_from_audio" ]; then
132
+ CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/vocoder/inference.py \
133
+ --config $exp_config \
134
+ --infer_mode $infer_mode \
135
+ --audio_folder $infer_audio_dir \
136
+ --vocoder_dir $infer_expt_dir \
137
+ --output_dir $infer_output_dir \
138
+ --log_level debug
139
+ fi
140
+
141
+ fi
Amphion/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (186 Bytes).
Amphion/models/base/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .new_trainer import BaseTrainer
7
+ from .new_inference import BaseInference
Amphion/models/base/new_dataset.py ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+ from abc import abstractmethod
9
+ from pathlib import Path
10
+
11
+ import json5
12
+ import torch
13
+ import yaml
14
+
15
+
16
+ # TODO: for training and validating
17
+ class BaseDataset(torch.utils.data.Dataset):
18
+ r"""Base dataset for training and validating."""
19
+
20
+ def __init__(self, args, cfg, is_valid=False):
21
+ pass
22
+
23
+
24
+ class BaseTestDataset(torch.utils.data.Dataset):
25
+ r"""Test dataset for inference."""
26
+
27
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
28
+ assert infer_type in ["from_dataset", "from_file"]
29
+
30
+ self.args = args
31
+ self.cfg = cfg
32
+ self.infer_type = infer_type
33
+
34
+ @abstractmethod
35
+ def __getitem__(self, index):
36
+ pass
37
+
38
+ def __len__(self):
39
+ return len(self.metadata)
40
+
41
+ def get_metadata(self):
42
+ path = Path(self.args.source)
43
+ if path.suffix == ".json" or path.suffix == ".jsonc":
44
+ metadata = json5.load(open(self.args.source, "r"))
45
+ elif path.suffix == ".yaml" or path.suffix == ".yml":
46
+ metadata = yaml.full_load(open(self.args.source, "r"))
47
+ else:
48
+ raise ValueError(f"Unsupported file type: {path.suffix}")
49
+
50
+ return metadata
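
`BaseTestDataset.__len__` assumes the subclass populates `self.metadata`, typically via `get_metadata()`, which loads `args.source` as JSON/JSON5 or YAML. A minimal sketch of a concrete subclass; the metadata schema (a top-level list of utterance entries) is an assumption:

```python
# Sketch of a concrete test dataset; assumes args.source points to a JSON or
# YAML file whose top level is a list of utterance entries.
from models.base.new_dataset import BaseTestDataset

class MyTestDataset(BaseTestDataset):
    def __init__(self, args=None, cfg=None, infer_type="from_file"):
        super().__init__(args, cfg, infer_type)
        self.metadata = self.get_metadata()  # required by __len__ in the base class

    def __getitem__(self, index):
        # Real code would load acoustic features here; we return the raw entry.
        return self.metadata[index]
```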
Amphion/models/base/new_trainer.py ADDED
@@ -0,0 +1,727 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+ import random
9
+ import shutil
10
+ import time
11
+ from abc import abstractmethod
12
+ from pathlib import Path
13
+
14
+ import accelerate
15
+ import json5
16
+ import numpy as np
17
+ import torch
18
+ from accelerate.logging import get_logger
19
+ from accelerate.utils import ProjectConfiguration
20
+ from torch.utils.data import ConcatDataset, DataLoader
21
+ from tqdm import tqdm
22
+
23
+ from models.base.base_sampler import build_samplers
24
+ from optimizer.optimizers import NoamLR
25
+
26
+
27
+ class BaseTrainer(object):
28
+ r"""The base trainer for all tasks. Any trainer should inherit from this class."""
29
+
30
+ def __init__(self, args=None, cfg=None):
31
+ super().__init__()
32
+
33
+ self.args = args
34
+ self.cfg = cfg
35
+
36
+ cfg.exp_name = args.exp_name
37
+
38
+ # init with accelerate
39
+ self._init_accelerator()
40
+ self.accelerator.wait_for_everyone()
41
+
42
+ # Use accelerate logger for distributed training
43
+ with self.accelerator.main_process_first():
44
+ self.logger = get_logger(args.exp_name, log_level=args.log_level)
45
+
46
+ # Log some info
47
+ self.logger.info("=" * 56)
48
+ self.logger.info("||\t\t" + "New training process started." + "\t\t||")
49
+ self.logger.info("=" * 56)
50
+ self.logger.info("\n")
51
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
52
+ self.logger.info(f"Experiment name: {args.exp_name}")
53
+ self.logger.info(f"Experiment directory: {self.exp_dir}")
54
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
55
+ if self.accelerator.is_main_process:
56
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
57
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
58
+
59
+ # init counts
60
+ self.batch_count: int = 0
61
+ self.step: int = 0
62
+ self.epoch: int = 0
63
+ self.max_epoch = (
64
+ self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
65
+ )
66
+ self.logger.info(
67
+ "Max epoch: {}".format(
68
+ self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
69
+ )
70
+ )
71
+
72
+ # Check values
73
+ if self.accelerator.is_main_process:
74
+ self.__check_basic_configs()
75
+ # Set runtime configs
76
+ self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
77
+ self.checkpoints_path = [
78
+ [] for _ in range(len(self.save_checkpoint_stride))
79
+ ]
80
+ self.keep_last = [
81
+ i if i > 0 else float("inf") for i in self.cfg.train.keep_last
82
+ ]
83
+ self.run_eval = self.cfg.train.run_eval
84
+
85
+ # set random seed
86
+ with self.accelerator.main_process_first():
87
+ start = time.monotonic_ns()
88
+ self._set_random_seed(self.cfg.train.random_seed)
89
+ end = time.monotonic_ns()
90
+ self.logger.debug(
91
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
92
+ )
93
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
94
+
95
+ # setup data_loader
96
+ with self.accelerator.main_process_first():
97
+ self.logger.info("Building dataset...")
98
+ start = time.monotonic_ns()
99
+ self.train_dataloader, self.valid_dataloader = self._build_dataloader()
100
+ end = time.monotonic_ns()
101
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
102
+
103
+ # setup model
104
+ with self.accelerator.main_process_first():
105
+ self.logger.info("Building model...")
106
+ start = time.monotonic_ns()
107
+ self.model = self._build_model()
108
+ end = time.monotonic_ns()
109
+ self.logger.debug(self.model)
110
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
111
+ self.logger.info(
112
+ f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
113
+ )
114
+ # optimizer & scheduler
115
+ with self.accelerator.main_process_first():
116
+ self.logger.info("Building optimizer and scheduler...")
117
+ start = time.monotonic_ns()
118
+ self.optimizer = self._build_optimizer()
119
+ self.scheduler = self._build_scheduler()
120
+ end = time.monotonic_ns()
121
+ self.logger.info(
122
+ f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
123
+ )
124
+
125
+ # accelerate prepare
126
+ self.logger.info("Initializing accelerate...")
127
+ start = time.monotonic_ns()
128
+ self._accelerator_prepare()
129
+ end = time.monotonic_ns()
130
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
131
+
132
+ # create criterion
133
+ with self.accelerator.main_process_first():
134
+ self.logger.info("Building criterion...")
135
+ start = time.monotonic_ns()
136
+ self.criterion = self._build_criterion()
137
+ end = time.monotonic_ns()
138
+ self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
139
+
140
+ # Resume or Finetune
141
+ with self.accelerator.main_process_first():
142
+ if args.resume:
143
+ if args.resume_from_ckpt_path == "":
144
+ ## Automatically resume according to the current experiment name
145
+ self.logger.info(
146
+ "Automatically resuming from latest checkpoint in {}...".format(
147
+ self.checkpoint_dir
148
+ )
149
+ )
150
+ start = time.monotonic_ns()
151
+ ckpt_path = self._load_model(
152
+ checkpoint_dir=self.checkpoint_dir, resume_type=args.resume_type
153
+ )
154
+ end = time.monotonic_ns()
155
+ self.logger.info(
156
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
157
+ )
158
+ self.checkpoints_path = json.load(
159
+ open(os.path.join(ckpt_path, "ckpts.json"), "r")
160
+ )
161
+ else:
162
+ ## Resume from the given checkpoint path
163
+ if not os.path.exists(args.resume_from_ckpt_path):
164
+ raise ValueError(
165
+ "[Error] The resumed checkpoint path {} don't exist.".format(
166
+ args.resume_from_ckpt_path
167
+ )
168
+ )
169
+ self.logger.info(
170
+ "Resuming from {}...".format(args.resume_from_ckpt_path)
171
+ )
172
+ start = time.monotonic_ns()
173
+ ckpt_path = self._load_model(
174
+ checkpoint_path=args.resume_from_ckpt_path,
175
+ resume_type=args.resume_type,
176
+ )
177
+ end = time.monotonic_ns()
178
+ self.logger.info(
179
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
180
+ )
181
+
182
+ # save config file path
183
+ self.config_save_path = os.path.join(self.exp_dir, "args.json")
184
+
185
+ def _accelerator_prepare(self):
186
+ (
187
+ self.train_dataloader,
188
+ self.valid_dataloader,
189
+ self.model,
190
+ self.optimizer,
191
+ self.scheduler,
192
+ ) = self.accelerator.prepare(
193
+ self.train_dataloader,
194
+ self.valid_dataloader,
195
+ self.model,
196
+ self.optimizer,
197
+ self.scheduler,
198
+ )
199
+
200
+ ### Following are abstract methods that should be implemented in child classes ###
201
+ @abstractmethod
202
+ def _build_dataset(self):
203
+ r"""Build dataset for model training/validating/evaluating."""
204
+ pass
205
+
206
+ @staticmethod
207
+ @abstractmethod
208
+ def _build_criterion():
209
+ r"""Build criterion function for model loss calculation."""
210
+ pass
211
+
212
+ @abstractmethod
213
+ def _build_model(self):
214
+ r"""Build model for training/validating/evaluating."""
215
+ pass
216
+
217
+ @abstractmethod
218
+ def _forward_step(self, batch):
219
+ r"""One forward step of the neural network. This abstract method is trying to
220
+ unify ``_train_step`` and ``_valid_step`` and avoid redundant implementation.
221
+ However, for special case that using different forward step pattern for
222
+ training and validating, you could just override this method with ``pass`` and
223
+ implement ``_train_step`` and ``_valid_step`` separately.
224
+ """
225
+ pass
226
+
227
+ @abstractmethod
228
+ def _save_auxiliary_states(self):
229
+ r"""To save some auxiliary states when saving model's ckpt"""
230
+ pass
231
+
232
+ ### Abstract methods end ###
233
+
234
+ ### THIS IS MAIN ENTRY ###
235
+ def train_loop(self):
236
+ r"""Training loop. The public entry of training process."""
237
+ # Wait everyone to prepare before we move on
238
+ self.accelerator.wait_for_everyone()
239
+ # dump config file
240
+ if self.accelerator.is_main_process:
241
+ self.__dump_cfg(self.config_save_path)
242
+ self.model.train()
243
+ self.optimizer.zero_grad()
244
+ # Wait to ensure good to go
245
+ self.accelerator.wait_for_everyone()
246
+ while self.epoch < self.max_epoch:
247
+ self.logger.info("\n")
248
+ self.logger.info("-" * 32)
249
+ self.logger.info("Epoch {}: ".format(self.epoch))
250
+
251
+ ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
252
+ ### It's inconvenient for the model with multiple losses
253
+ # Do training & validating epoch
254
+ train_loss = self._train_epoch()
255
+ self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss))
256
+ valid_loss = self._valid_epoch()
257
+ self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss))
258
+ self.accelerator.log(
259
+ {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
260
+ step=self.epoch,
261
+ )
262
+
263
+ self.accelerator.wait_for_everyone()
264
+ # TODO: what is scheduler?
265
+ self.scheduler.step(valid_loss)  # FIXME: is stepping once per epoch on valid_loss correct?
266
+
267
+ # Check if hit save_checkpoint_stride and run_eval
268
+ run_eval = False
269
+ if self.accelerator.is_main_process:
270
+ save_checkpoint = False
271
+ hit_idx = []
272
+ for i, num in enumerate(self.save_checkpoint_stride):
273
+ if self.epoch % num == 0:
274
+ save_checkpoint = True
275
+ hit_idx.append(i)
276
+ run_eval |= self.run_eval[i]
277
+
278
+ self.accelerator.wait_for_everyone()
279
+ if self.accelerator.is_main_process and save_checkpoint:
280
+ path = os.path.join(
281
+ self.checkpoint_dir,
282
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
283
+ self.epoch, self.step, train_loss
284
+ ),
285
+ )
286
+ self.tmp_checkpoint_save_path = path
287
+ self.accelerator.save_state(path)
288
+ print(f"save checkpoint in {path}")
289
+ json.dump(
290
+ self.checkpoints_path,
291
+ open(os.path.join(path, "ckpts.json"), "w"),
292
+ ensure_ascii=False,
293
+ indent=4,
294
+ )
295
+ self._save_auxiliary_states()
296
+
297
+ # Remove old checkpoints
298
+ to_remove = []
299
+ for idx in hit_idx:
300
+ self.checkpoints_path[idx].append(path)
301
+ while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
302
+ to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
303
+
304
+ # Search conflicts
305
+ total = set()
306
+ for i in self.checkpoints_path:
307
+ total |= set(i)
308
+ do_remove = set()
309
+ for idx, path in to_remove[::-1]:
310
+ if path in total:
311
+ self.checkpoints_path[idx].insert(0, path)
312
+ else:
313
+ do_remove.add(path)
314
+
315
+ # Remove old checkpoints
316
+ for path in do_remove:
317
+ shutil.rmtree(path, ignore_errors=True)
318
+ self.logger.debug(f"Remove old checkpoint: {path}")
319
+
320
+ self.accelerator.wait_for_everyone()
321
+ if run_eval:
322
+ # TODO: run evaluation
323
+ pass
324
+
325
+ # Update info for each epoch
326
+ self.epoch += 1
327
+
328
+ # Finish training and save final checkpoint
329
+ self.accelerator.wait_for_everyone()
330
+ if self.accelerator.is_main_process:
331
+ self.accelerator.save_state(
332
+ os.path.join(
333
+ self.checkpoint_dir,
334
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
335
+ self.epoch, self.step, valid_loss
336
+ ),
337
+ )
338
+ )
339
+ self._save_auxiliary_states()
340
+
341
+ self.accelerator.end_training()
342
+
343
+ ### Following are methods that can be used directly in child classes ###
344
+ def _train_epoch(self):
345
+ r"""Training epoch. Should return average loss of a batch (sample) over
346
+ one epoch. See ``train_loop`` for usage.
347
+ """
348
+ self.model.train()
349
+ epoch_sum_loss: float = 0.0
350
+ epoch_step: int = 0
351
+ for batch in tqdm(
352
+ self.train_dataloader,
353
+ desc=f"Training Epoch {self.epoch}",
354
+ unit="batch",
355
+ colour="GREEN",
356
+ leave=False,
357
+ dynamic_ncols=True,
358
+ smoothing=0.04,
359
+ disable=not self.accelerator.is_main_process,
360
+ ):
361
+ # Do training step and BP
362
+ with self.accelerator.accumulate(self.model):
363
+ loss = self._train_step(batch)
364
+ self.accelerator.backward(loss)
365
+ self.optimizer.step()
366
+ self.optimizer.zero_grad()
367
+ self.batch_count += 1
368
+
369
+ # Update info for each step
370
+ # TODO: does `step` count backward passes or batches?
371
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
372
+ epoch_sum_loss += loss
373
+ self.accelerator.log(
374
+ {
375
+ "Step/Train Loss": loss,
376
+ "Step/Learning Rate": self.optimizer.param_groups[0]["lr"],
377
+ },
378
+ step=self.step,
379
+ )
380
+ self.step += 1
381
+ epoch_step += 1
382
+
383
+ self.accelerator.wait_for_everyone()
384
+ return (
385
+ epoch_sum_loss
386
+ / len(self.train_dataloader)
387
+ * self.cfg.train.gradient_accumulation_step
388
+ )
389
+
390
+ @torch.inference_mode()
391
+ def _valid_epoch(self):
392
+ r"""Testing epoch. Should return average loss of a batch (sample) over
393
+ one epoch. See ``train_loop`` for usage.
394
+ """
395
+ self.model.eval()
396
+ epoch_sum_loss = 0.0
397
+ for batch in tqdm(
398
+ self.valid_dataloader,
399
+ desc=f"Validating Epoch {self.epoch}",
400
+ unit="batch",
401
+ colour="GREEN",
402
+ leave=False,
403
+ dynamic_ncols=True,
404
+ smoothing=0.04,
405
+ disable=not self.accelerator.is_main_process,
406
+ ):
407
+ batch_loss = self._valid_step(batch)
408
+ epoch_sum_loss += batch_loss.item()
409
+
410
+ self.accelerator.wait_for_everyone()
411
+ return epoch_sum_loss / len(self.valid_dataloader)
412
+
413
+ def _train_step(self, batch):
414
+ r"""Training forward step. Should return average loss of a sample over
415
+ one batch. Invoking ``_forward_step`` is recommended except for special cases.
416
+ See ``_train_epoch`` for usage.
417
+ """
418
+ return self._forward_step(batch)
419
+
420
+ @torch.inference_mode()
421
+ def _valid_step(self, batch):
422
+ r"""Testing forward step. Should return average loss of a sample over
423
+ one batch. Provoke ``_forward_step`` is recommended except for special case.
424
+ See ``_test_epoch`` for usage.
425
+ """
426
+ return self._forward_step(batch)
427
+
428
+ def _load_model(
429
+ self,
430
+ checkpoint_dir: str = None,
431
+ checkpoint_path: str = None,
432
+ resume_type: str = "",
433
+ ):
434
+ r"""Load model from checkpoint. If checkpoint_path is None, it will
435
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
436
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
437
+ method after** ``accelerator.prepare()``.
438
+ """
439
+ if checkpoint_path is None:
440
+ ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
441
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
442
+ checkpoint_path = ls[0]
443
+ self.logger.info("Resume from {}...".format(checkpoint_path))
444
+
445
+ if resume_type in ["resume", ""]:
446
+ # Load all the things, including model weights, optimizer, scheduler, and random states.
447
+ self.accelerator.load_state(input_dir=checkpoint_path)
448
+
449
+ # set epoch and step
450
+ self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
451
+ self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
452
+
453
+ elif resume_type == "finetune":
454
+ # Load only the model weights
455
+ accelerate.load_checkpoint_and_dispatch(
456
+ self.accelerator.unwrap_model(self.model),
457
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
458
+ )
459
+ self.logger.info("Load model weights for finetune...")
460
+
461
+ else:
462
+ raise ValueError("Resume_type must be `resume` or `finetune`.")
463
+
464
+ return checkpoint_path
465
+
466
+ def _build_dataloader(self):
467
+ Dataset, Collator = self._build_dataset()
468
+
469
+ # build dataset instance for each dataset and combine them by ConcatDataset
470
+ datasets_list = []
471
+ for dataset in self.cfg.dataset:
472
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
473
+ datasets_list.append(subdataset)
474
+ train_dataset = ConcatDataset(datasets_list)
475
+ train_collate = Collator(self.cfg)
476
+ _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
477
+ self.logger.debug(f"train batch_sampler: {list(batch_sampler)}")
478
+ self.logger.debug(f"length: {train_dataset.cumulative_sizes}")
479
+ # TODO: use config instead of (sampler, shuffle, drop_last, batch_size)
480
+ train_loader = DataLoader(
481
+ train_dataset,
482
+ # shuffle=True,
483
+ collate_fn=train_collate,
484
+ batch_sampler=batch_sampler,
485
+ num_workers=self.cfg.train.dataloader.num_worker,
486
+ pin_memory=self.cfg.train.dataloader.pin_memory,
487
+ )
488
+
489
+ # Build valid dataloader
490
+ datasets_list = []
491
+ for dataset in self.cfg.dataset:
492
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
493
+ datasets_list.append(subdataset)
494
+ valid_dataset = ConcatDataset(datasets_list)
495
+ valid_collate = Collator(self.cfg)
496
+ _, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
497
+ self.logger.debug(f"valid batch_sampler: {list(batch_sampler)}")
498
+ self.logger.debug(f"length: {valid_dataset.cumulative_sizes}")
499
+ valid_loader = DataLoader(
500
+ valid_dataset,
501
+ collate_fn=valid_collate,
502
+ batch_sampler=batch_sampler,
503
+ num_workers=self.cfg.train.dataloader.num_worker,
504
+ pin_memory=self.cfg.train.dataloader.pin_memory,
505
+ )
506
+ return train_loader, valid_loader
507
+
508
+ @staticmethod
509
+ def _set_random_seed(seed):
510
+ r"""Set random seed for all possible random modules."""
511
+ random.seed(seed)
512
+ np.random.seed(seed)
513
+ torch.random.manual_seed(seed)
514
+
515
+ def _check_nan(self, loss, y_pred, y_gt):
516
+ if torch.any(torch.isnan(loss)):
517
+ self.logger.error("Fatal Error: Training is down since loss has Nan!")
518
+ self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)
519
+
520
+ ### y_pred ###
521
+ if torch.any(torch.isnan(y_pred)):
522
+ self.logger.error(
523
+ f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
524
+ )
525
+ self.logger.error(f"y_pred: {y_pred}", in_order=True)
526
+ else:
527
+ self.logger.debug(
528
+ f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
529
+ )
530
+ self.logger.debug(f"y_pred: {y_pred}", in_order=True)
531
+
532
+ ### y_gt ###
533
+ if torch.any(torch.isnan(y_gt)):
534
+ self.logger.error(
535
+ f"y_gt has Nan: {torch.any(torch.isnan(y_gt))}", in_order=True
536
+ )
537
+ self.logger.error(f"y_gt: {y_gt}", in_order=True)
538
+ else:
539
+ self.logger.debug(
540
+ f"y_gt has nan: {torch.any(torch.isnan(y_gt))}", in_order=True
541
+ )
542
+ self.logger.debug(f"y_gt: {y_gt}", in_order=True)
543
+
544
+ self.accelerator.end_training()
545
+ raise RuntimeError("Loss has Nan! See log for more info.")
546
+
547
+ ### Protected methods end ###
548
+
549
+ ## Following are private methods ##
550
+ def _build_optimizer(self):
551
+ r"""Build optimizer for model."""
552
+ # Make case-insensitive matching
553
+ if self.cfg.train.optimizer.lower() == "adadelta":
554
+ optimizer = torch.optim.Adadelta(
555
+ self.model.parameters(), **self.cfg.train.adadelta
556
+ )
557
+ self.logger.info("Using Adadelta optimizer.")
558
+ elif self.cfg.train.optimizer.lower() == "adagrad":
559
+ optimizer = torch.optim.Adagrad(
560
+ self.model.parameters(), **self.cfg.train.adagrad
561
+ )
562
+ self.logger.info("Using Adagrad optimizer.")
563
+ elif self.cfg.train.optimizer.lower() == "adam":
564
+ optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam)
565
+ self.logger.info("Using Adam optimizer.")
566
+ elif self.cfg.train.optimizer.lower() == "adamw":
567
+ optimizer = torch.optim.AdamW(
568
+ self.model.parameters(), **self.cfg.train.adamw
569
+ )
570
+ elif self.cfg.train.optimizer.lower() == "sparseadam":
571
+ optimizer = torch.optim.SparseAdam(
572
+ self.model.parameters(), **self.cfg.train.sparseadam
573
+ )
574
+ elif self.cfg.train.optimizer.lower() == "adamax":
575
+ optimizer = torch.optim.Adamax(
576
+ self.model.parameters(), **self.cfg.train.adamax
577
+ )
578
+ elif self.cfg.train.optimizer.lower() == "asgd":
579
+ optimizer = torch.optim.ASGD(self.model.parameters(), **self.cfg.train.asgd)
580
+ elif self.cfg.train.optimizer.lower() == "lbfgs":
581
+ optimizer = torch.optim.LBFGS(
582
+ self.model.parameters(), **self.cfg.train.lbfgs
583
+ )
584
+ elif self.cfg.train.optimizer.lower() == "nadam":
585
+ optimizer = torch.optim.NAdam(
586
+ self.model.parameters(), **self.cfg.train.nadam
587
+ )
588
+ elif self.cfg.train.optimizer.lower() == "radam":
589
+ optimizer = torch.optim.RAdam(
590
+ self.model.parameters(), **self.cfg.train.radam
591
+ )
592
+ elif self.cfg.train.optimizer.lower() == "rmsprop":
593
+ optimizer = torch.optim.RMSprop(
594
+ self.model.parameters(), **self.cfg.train.rmsprop
595
+ )
596
+ elif self.cfg.train.optimizer.lower() == "rprop":
597
+ optimizer = torch.optim.Rprop(
598
+ self.model.parameters(), **self.cfg.train.rprop
599
+ )
600
+ elif self.cfg.train.optimizer.lower() == "sgd":
601
+ optimizer = torch.optim.SGD(self.model.parameters(), **self.cfg.train.sgd)
602
+ else:
603
+ raise NotImplementedError(
604
+ f"Optimizer {self.cfg.train.optimizer} not supported yet!"
605
+ )
606
+ return optimizer
607
+
608
+ def _build_scheduler(self):
609
+ r"""Build scheduler for optimizer."""
610
+ # Make case-insensitive matching
611
+ if self.cfg.train.scheduler.lower() == "lambdalr":
612
+ scheduler = torch.optim.lr_scheduler.LambdaLR(
613
+ self.optimizer, **self.cfg.train.lambdalr
614
+ )
615
+ elif self.cfg.train.scheduler.lower() == "multiplicativelr":
616
+ scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
617
+ self.optimizer, **self.cfg.train.multiplicativelr
618
+ )
619
+ elif self.cfg.train.scheduler.lower() == "steplr":
620
+ scheduler = torch.optim.lr_scheduler.StepLR(
621
+ self.optimizer, **self.cfg.train.steplr
622
+ )
623
+ elif self.cfg.train.scheduler.lower() == "multisteplr":
624
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(
625
+ self.optimizer, **self.cfg.train.multisteplr
626
+ )
627
+ elif self.cfg.train.scheduler.lower() == "constantlr":
628
+ scheduler = torch.optim.lr_scheduler.ConstantLR(
629
+ self.optimizer, **self.cfg.train.constantlr
630
+ )
631
+ elif self.cfg.train.scheduler.lower() == "linearlr":
632
+ scheduler = torch.optim.lr_scheduler.LinearLR(
633
+ self.optimizer, **self.cfg.train.linearlr
634
+ )
635
+ elif self.cfg.train.scheduler.lower() == "exponentiallr":
636
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(
637
+ self.optimizer, **self.cfg.train.exponentiallr
638
+ )
639
+ elif self.cfg.train.scheduler.lower() == "polynomiallr":
640
+ scheduler = torch.optim.lr_scheduler.PolynomialLR(
641
+ self.optimizer, **self.cfg.train.polynomiallr
642
+ )
643
+ elif self.cfg.train.scheduler.lower() == "cosineannealinglr":
644
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
645
+ self.optimizer, **self.cfg.train.cosineannealinglr
646
+ )
647
+ elif self.cfg.train.scheduler.lower() == "sequentiallr":
648
+ scheduler = torch.optim.lr_scheduler.SequentialLR(
649
+ self.optimizer, **self.cfg.train.sequentiallr
650
+ )
651
+ elif self.cfg.train.scheduler.lower() == "reducelronplateau":
652
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
653
+ self.optimizer, **self.cfg.train.reducelronplateau
654
+ )
655
+ elif self.cfg.train.scheduler.lower() == "cycliclr":
656
+ scheduler = torch.optim.lr_scheduler.CyclicLR(
657
+ self.optimizer, **self.cfg.train.cycliclr
658
+ )
659
+ elif self.cfg.train.scheduler.lower() == "onecyclelr":
660
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
661
+ self.optimizer, **self.cfg.train.onecyclelr
662
+ )
663
+ elif self.cfg.train.scheduler.lower() == "cosineannearingwarmrestarts":
664
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
665
+ self.optimizer, **self.cfg.train.cosineannearingwarmrestarts
666
+ )
667
+ elif self.cfg.train.scheduler.lower() == "noamlr":
668
+ scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
669
+ else:
670
+ raise NotImplementedError(
671
+ f"Scheduler {self.cfg.train.scheduler} not supported yet!"
672
+ )
673
+ return scheduler
674
+
675
+ def _init_accelerator(self):
676
+ self.exp_dir = os.path.join(
677
+ os.path.abspath(self.cfg.log_dir), self.args.exp_name
678
+ )
679
+ project_config = ProjectConfiguration(
680
+ project_dir=self.exp_dir,
681
+ logging_dir=os.path.join(self.exp_dir, "log"),
682
+ )
683
+ self.accelerator = accelerate.Accelerator(
684
+ gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
685
+ log_with=self.cfg.train.tracker,
686
+ project_config=project_config,
687
+ )
688
+ if self.accelerator.is_main_process:
689
+ os.makedirs(project_config.project_dir, exist_ok=True)
690
+ os.makedirs(project_config.logging_dir, exist_ok=True)
691
+ with self.accelerator.main_process_first():
692
+ self.accelerator.init_trackers(self.args.exp_name)
693
+
694
+ def __check_basic_configs(self):
695
+ if self.cfg.train.gradient_accumulation_step <= 0:
696
+ self.logger.fatal("Invalid gradient_accumulation_step value!")
697
+ self.logger.error(
698
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
699
+ )
700
+ self.accelerator.end_training()
701
+ raise ValueError(
702
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
703
+ )
704
+ # TODO: check other values
705
+
706
+ @staticmethod
707
+ def __count_parameters(model):
708
+ model_param = 0.0
709
+ if isinstance(model, dict):
710
+ for key, value in model.items():
711
+ model_param += sum(p.numel() for p in model[key].parameters())
712
+ else:
713
+ model_param = sum(p.numel() for p in model.parameters())
714
+ return model_param
715
+
716
+ def __dump_cfg(self, path):
717
+ os.makedirs(os.path.dirname(path), exist_ok=True)
718
+ json5.dump(
719
+ self.cfg,
720
+ open(path, "w"),
721
+ indent=4,
722
+ sort_keys=True,
723
+ ensure_ascii=False,
724
+ quote_keys=True,
725
+ )
726
+
727
+ ### Private methods end ###
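
To make the contract of `BaseTrainer` concrete, here is a minimal hypothetical subclass showing the pieces a task must supply; `MyDataset` and `MyCollator` are stand-ins for real dataset/collator classes, and the linear model is a toy placeholder:

```python
# Sketch (assumptions marked) of the minimum a task trainer must supply on
# top of BaseTrainer above: dataset/collator classes, a model, a criterion,
# and a _forward_step that returns a scalar loss.
import torch.nn as nn
from models.base.new_trainer import BaseTrainer

class ToyTrainer(BaseTrainer):
    def _build_dataset(self):
        # Must return (DatasetClass, CollatorClass); both are task-specific.
        return MyDataset, MyCollator  # hypothetical classes

    @staticmethod
    def _build_criterion():
        return nn.MSELoss()

    def _build_model(self):
        return nn.Linear(80, 80)  # stand-in for a real acoustic model

    def _forward_step(self, batch):
        pred = self.model(batch["input"])
        return self.criterion(pred, batch["target"])

    def _save_auxiliary_states(self):
        pass  # nothing extra to persist in this toy example

# trainer = ToyTrainer(args, cfg); trainer.train_loop()
```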
Amphion/models/codec/ns3_codec/__pycache__/facodec.cpython-310.pyc ADDED
Binary file (24.2 kB).
Amphion/models/codec/ns3_codec/alias_free_torch/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (277 Bytes).
Amphion/models/codec/ns3_codec/alias_free_torch/__pycache__/resample.cpython-310.pyc ADDED
Binary file (1.97 kB).
Amphion/models/codec/ns3_codec/quantize/__pycache__/rvq.cpython-310.pyc ADDED
Binary file (2.51 kB).
Amphion/models/codec/ns3_codec/quantize/rvq.py ADDED
@@ -0,0 +1,87 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import torch
8
+ from torch import nn
9
+ from .fvq import FactorizedVectorQuantize
10
+
11
+
12
+ class ResidualVQ(nn.Module):
13
+ """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
14
+
15
+ def __init__(self, *, num_quantizers, codebook_size, **kwargs):
16
+ super().__init__()
17
+ VQ = FactorizedVectorQuantize
18
+ if isinstance(codebook_size, int):
19
+ codebook_size = [codebook_size] * num_quantizers
20
+ self.layers = nn.ModuleList(
21
+ [VQ(codebook_size=2**size, **kwargs) for size in codebook_size]
22
+ )
23
+ self.num_quantizers = num_quantizers
24
+ self.quantizer_dropout = kwargs.get("quantizer_dropout", 0.0)
25
+ self.dropout_type = kwargs.get("dropout_type", None)
26
+
27
+ def forward(self, x, n_quantizers=None):
28
+ quantized_out = 0.0
29
+ residual = x
30
+
31
+ all_losses = []
32
+ all_indices = []
33
+ all_quantized = []
34
+
35
+ if n_quantizers is None:
36
+ n_quantizers = self.num_quantizers
37
+ if self.training:
38
+ n_quantizers = torch.ones((x.shape[0],)) * self.num_quantizers + 1
39
+ if self.dropout_type == "linear":
40
+ dropout = torch.randint(1, self.num_quantizers + 1, (x.shape[0],))
41
+ elif self.dropout_type == "exp":
42
+ dropout = torch.randint(
43
+ 1, int(math.log2(self.num_quantizers)), (x.shape[0],)
44
+ )
45
+ dropout = torch.pow(2, dropout)
46
+ n_dropout = int(x.shape[0] * self.quantizer_dropout)
47
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
48
+ n_quantizers = n_quantizers.to(x.device)
49
+
50
+ for idx, layer in enumerate(self.layers):
51
+ if not self.training and idx >= n_quantizers:
52
+ break
53
+ quantized, indices, loss = layer(residual)
54
+
55
+ mask = (
56
+ torch.full((x.shape[0],), fill_value=idx, device=x.device)
57
+ < n_quantizers
58
+ )
59
+
60
+ residual = residual - quantized
61
+
62
+ quantized_out = quantized_out + quantized * mask[:, None, None]
63
+
64
+ # loss
65
+ loss = (loss * mask).mean()
66
+
67
+ all_indices.append(indices)
68
+ all_losses.append(loss)
69
+ all_quantized.append(quantized)
70
+ all_losses, all_indices, all_quantized = map(
71
+ torch.stack, (all_losses, all_indices, all_quantized)
72
+ )
73
+ return quantized_out, all_indices, all_losses, all_quantized
74
+
75
+ def vq2emb(self, vq):
76
+ # vq: [n_quantizers, B, T]
77
+ quantized_out = 0.0
78
+ for idx, layer in enumerate(self.layers):
79
+ quantized = layer.vq2emb(vq[idx])
80
+ quantized_out += quantized
81
+ return quantized_out
82
+
83
+ def get_emb(self):
84
+ embs = []
85
+ for idx, layer in enumerate(self.layers):
86
+ embs.append(layer.get_emb())
87
+ return embs
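
A usage sketch for `ResidualVQ`. One gotcha worth noting: each entry of `codebook_size` is treated as an exponent (`2**size`), so passing `10` yields codebooks of 1024 entries. The remaining kwargs are forwarded to `FactorizedVectorQuantize`; the `dim` argument and the `(B, D, T)` input layout below are assumptions, so check `fvq.py` for the real signature:

```python
# Hedged usage sketch; kwargs beyond num_quantizers/codebook_size are
# assumptions forwarded to FactorizedVectorQuantize.
import torch

rvq = ResidualVQ(num_quantizers=4, codebook_size=10, dim=256)
rvq.eval()  # in eval mode, n_quantizers caps how many stages are applied

x = torch.randn(2, 256, 50)  # assumed (batch, dim, time) layout
quantized, indices, losses, per_stage = rvq(x, n_quantizers=2)
```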
Amphion/models/svc/transformer/transformer.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
10
+
11
+
12
+ class Transformer(nn.Module):
13
+ def __init__(self, cfg):
14
+ super().__init__()
15
+ self.cfg = cfg
16
+
17
+ dropout = self.cfg.dropout
18
+ nhead = self.cfg.n_heads
19
+ nlayers = self.cfg.n_layers
20
+ input_dim = self.cfg.input_dim
21
+ output_dim = self.cfg.output_dim
22
+
23
+ d_model = input_dim
24
+ self.pos_encoder = PositionalEncoding(d_model, dropout)
25
+ encoder_layers = TransformerEncoderLayer(
26
+ d_model, nhead, dropout=dropout, batch_first=True
27
+ )
28
+ self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
29
+
30
+ self.output_mlp = nn.Linear(d_model, output_dim)
31
+
32
+ def forward(self, x, mask=None):
33
+ """
34
+ Args:
35
+ x: (N, seq_len, input_dim)
36
+ Returns:
37
+ output: (N, seq_len, output_dim)
38
+ """
39
+ # (N, seq_len, d_model)
40
+ src = self.pos_encoder(x)
41
+ # model_stats["pos_embedding"] = x
42
+ # (N, seq_len, d_model)
43
+ output = self.transformer_encoder(src)
44
+ # (N, seq_len, output_dim)
45
+ output = self.output_mlp(output)
46
+ return output
47
+
48
+
49
+ class PositionalEncoding(nn.Module):
50
+ def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
51
+ super().__init__()
52
+ self.dropout = nn.Dropout(p=dropout)
53
+
54
+ position = torch.arange(max_len).unsqueeze(1)
55
+ div_term = torch.exp(
56
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
57
+ )
58
+
59
+ # Assume that x is (seq_len, N, d)
60
+ # pe = torch.zeros(max_len, 1, d_model)
61
+ # pe[:, 0, 0::2] = torch.sin(position * div_term)
62
+ # pe[:, 0, 1::2] = torch.cos(position * div_term)
63
+
64
+ # Assume that x in (N, seq_len, d)
65
+ pe = torch.zeros(1, max_len, d_model)
66
+ pe[0, :, 0::2] = torch.sin(position * div_term)
67
+ pe[0, :, 1::2] = torch.cos(position * div_term)
68
+
69
+ self.register_buffer("pe", pe)
70
+
71
+ def forward(self, x):
72
+ """
73
+ Args:
74
+ x: Tensor, shape [N, seq_len, d]
75
+ """
76
+ # Old: Assume that x is (seq_len, N, d), and self.pe is (max_len, 1, d_model)
77
+ # x = x + self.pe[: x.size(0)]
78
+
79
+ # Now: self.pe is (1, max_len, d)
80
+ x = x + self.pe[:, : x.size(1), :]
81
+
82
+ return self.dropout(x)
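
A quick shape check for the batch-first `PositionalEncoding` above: `pe` is built as `(1, max_len, d_model)` and sliced to the input's sequence length, so inputs must be `(N, seq_len, d)`, the layout the surrounding `Transformer` uses. A minimal sketch, assuming the class above is in scope:

```python
# Shape-check sketch for the batch-first PositionalEncoding defined above.
import torch

pos_enc = PositionalEncoding(d_model=64, dropout=0.0, max_len=128)
x = torch.zeros(8, 32, 64)         # (N, seq_len, d)
y = pos_enc(x)
assert y.shape == (8, 32, 64)
assert torch.allclose(y[0], y[1])  # the same encoding is added to every batch item
```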
Amphion/models/tts/fastspeech2/fs2.py ADDED
@@ -0,0 +1,548 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # This code is modified from https://github.com/ming024/FastSpeech2/blob/master/model/fastspeech2.py
7
+ import torch
8
+ import torch.nn as nn
9
+ import numpy as np
10
+ import torch.nn.functional as F
11
+
12
+ from modules.transformer.Models import Encoder, Decoder
13
+ from modules.transformer.Layers import PostNet
14
+ from collections import OrderedDict
15
+
16
+ import os
17
+ import json
18
+
19
+
20
+ def get_mask_from_lengths(lengths, max_len=None):
21
+ device = lengths.device
22
+ batch_size = lengths.shape[0]
23
+ if max_len is None:
24
+ max_len = torch.max(lengths).item()
25
+
26
+ ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
27
+ mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
28
+
29
+ return mask
30
+
31
+
32
+ def pad(input_ele, mel_max_length=None):
33
+ if mel_max_length:
34
+ max_len = mel_max_length
35
+ else:
36
+ max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])
37
+
38
+ out_list = list()
39
+ for i, batch in enumerate(input_ele):
40
+ if len(batch.shape) == 1:
41
+ one_batch_padded = F.pad(
42
+ batch, (0, max_len - batch.size(0)), "constant", 0.0
43
+ )
44
+ elif len(batch.shape) == 2:
45
+ one_batch_padded = F.pad(
46
+ batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
47
+ )
48
+ out_list.append(one_batch_padded)
49
+ out_padded = torch.stack(out_list)
50
+ return out_padded
51
+
52
+
53
+ class VarianceAdaptor(nn.Module):
54
+ """Variance Adaptor"""
55
+
56
+ def __init__(self, cfg):
57
+ super(VarianceAdaptor, self).__init__()
58
+ self.duration_predictor = VariancePredictor(cfg)
59
+ self.length_regulator = LengthRegulator()
60
+ self.pitch_predictor = VariancePredictor(cfg)
61
+ self.energy_predictor = VariancePredictor(cfg)
62
+
63
+ # assign the pitch/energy feature level
64
+ if cfg.preprocess.use_frame_pitch:
65
+ self.pitch_feature_level = "frame_level"
66
+ self.pitch_dir = cfg.preprocess.pitch_dir
67
+ else:
68
+ self.pitch_feature_level = "phoneme_level"
69
+ self.pitch_dir = cfg.preprocess.phone_pitch_dir
70
+
71
+ if cfg.preprocess.use_frame_energy:
72
+ self.energy_feature_level = "frame_level"
73
+ self.energy_dir = cfg.preprocess.energy_dir
74
+ else:
75
+ self.energy_feature_level = "phoneme_level"
76
+ self.energy_dir = cfg.preprocess.phone_energy_dir
77
+
78
+ assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
79
+ assert self.energy_feature_level in ["phoneme_level", "frame_level"]
80
+
81
+ pitch_quantization = cfg.model.variance_embedding.pitch_quantization
82
+ energy_quantization = cfg.model.variance_embedding.energy_quantization
83
+ n_bins = cfg.model.variance_embedding.n_bins
84
+ assert pitch_quantization in ["linear", "log"]
85
+ assert energy_quantization in ["linear", "log"]
86
+
87
+ with open(
88
+ os.path.join(
89
+ cfg.preprocess.processed_dir,
90
+ cfg.dataset[0],
91
+ self.energy_dir,
92
+ "statistics.json",
93
+ )
94
+ ) as f:
95
+ stats = json.load(f)
96
+ stats = stats[cfg.dataset[0] + "_" + cfg.dataset[0]]
97
+ mean, std = (
98
+ stats["voiced_positions"]["mean"],
99
+ stats["voiced_positions"]["std"],
100
+ )
101
+ energy_min = (stats["total_positions"]["min"] - mean) / std
102
+ energy_max = (stats["total_positions"]["max"] - mean) / std
103
+
104
+ with open(
105
+ os.path.join(
106
+ cfg.preprocess.processed_dir,
107
+ cfg.dataset[0],
108
+ self.pitch_dir,
109
+ "statistics.json",
110
+ )
111
+ ) as f:
112
+ stats = json.load(f)
113
+ stats = stats[cfg.dataset[0] + "_" + cfg.dataset[0]]
114
+ mean, std = (
115
+ stats["voiced_positions"]["mean"],
116
+ stats["voiced_positions"]["std"],
117
+ )
118
+ pitch_min = (stats["total_positions"]["min"] - mean) / std
119
+ pitch_max = (stats["total_positions"]["max"] - mean) / std
120
+
121
+ if pitch_quantization == "log":
122
+ self.pitch_bins = nn.Parameter(
123
+ torch.exp(
124
+ torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1)
125
+ ),
126
+ requires_grad=False,
127
+ )
128
+ else:
129
+ self.pitch_bins = nn.Parameter(
130
+ torch.linspace(pitch_min, pitch_max, n_bins - 1),
131
+ requires_grad=False,
132
+ )
133
+ if energy_quantization == "log":
134
+ self.energy_bins = nn.Parameter(
135
+ torch.exp(
136
+ torch.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1)
137
+ ),
138
+ requires_grad=False,
139
+ )
140
+ else:
141
+ self.energy_bins = nn.Parameter(
142
+ torch.linspace(energy_min, energy_max, n_bins - 1),
143
+ requires_grad=False,
144
+ )
145
+
146
+ self.pitch_embedding = nn.Embedding(
147
+ n_bins, cfg.model.transformer.encoder_hidden
148
+ )
149
+ self.energy_embedding = nn.Embedding(
150
+ n_bins, cfg.model.transformer.encoder_hidden
151
+ )
152
+
153
+ def get_pitch_embedding(self, x, target, mask, control):
154
+ prediction = self.pitch_predictor(x, mask)
155
+ if target is not None:
156
+ embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins))
157
+ else:
158
+ prediction = prediction * control
159
+ embedding = self.pitch_embedding(
160
+ torch.bucketize(prediction, self.pitch_bins)
161
+ )
162
+ return prediction, embedding
163
+
164
+ def get_energy_embedding(self, x, target, mask, control):
165
+ prediction = self.energy_predictor(x, mask)
166
+ if target is not None:
167
+ embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins))
168
+ else:
169
+ prediction = prediction * control
170
+ embedding = self.energy_embedding(
171
+ torch.bucketize(prediction, self.energy_bins)
172
+ )
173
+ return prediction, embedding
174
+
175
+ def forward(
176
+ self,
177
+ x,
178
+ src_mask,
179
+ mel_mask=None,
180
+ max_len=None,
181
+ pitch_target=None,
182
+ energy_target=None,
183
+ duration_target=None,
184
+ p_control=1.0,
185
+ e_control=1.0,
186
+ d_control=1.0,
187
+ ):
188
+ log_duration_prediction = self.duration_predictor(x, src_mask)
189
+ if self.pitch_feature_level == "phoneme_level":
190
+ pitch_prediction, pitch_embedding = self.get_pitch_embedding(
191
+ x, pitch_target, src_mask, p_control
192
+ )
193
+ x = x + pitch_embedding
194
+ if self.energy_feature_level == "phoneme_level":
195
+ energy_prediction, energy_embedding = self.get_energy_embedding(
196
+ x, energy_target, src_mask, e_control
197
+ )
198
+ x = x + energy_embedding
199
+
200
+ if duration_target is not None:
201
+ x, mel_len = self.length_regulator(x, duration_target, max_len)
202
+ duration_rounded = duration_target
203
+ else:
204
+ duration_rounded = torch.clamp(
205
+ (torch.round(torch.exp(log_duration_prediction) - 1) * d_control),
206
+ min=0,
207
+ )
208
+ x, mel_len = self.length_regulator(x, duration_rounded, max_len)
209
+ mel_mask = get_mask_from_lengths(mel_len)
210
+
211
+ if self.pitch_feature_level == "frame_level":
212
+ pitch_prediction, pitch_embedding = self.get_pitch_embedding(
213
+ x, pitch_target, mel_mask, p_control
214
+ )
215
+ x = x + pitch_embedding
216
+ if self.energy_feature_level == "frame_level":
217
+ energy_prediction, energy_embedding = self.get_energy_embedding(
218
+ x, energy_target, mel_mask, e_control
219
+ )
220
+ x = x + energy_embedding
221
+
222
+ return (
223
+ x,
224
+ pitch_prediction,
225
+ energy_prediction,
226
+ log_duration_prediction,
227
+ duration_rounded,
228
+ mel_len,
229
+ mel_mask,
230
+ )
231
+
232
+
233
+ class LengthRegulator(nn.Module):
234
+ """Length Regulator"""
235
+
236
+ def __init__(self):
237
+ super(LengthRegulator, self).__init__()
238
+
239
+ def LR(self, x, duration, max_len):
240
+ device = x.device
241
+ output = list()
242
+ mel_len = list()
243
+ for batch, expand_target in zip(x, duration):
244
+ expanded = self.expand(batch, expand_target)
245
+ output.append(expanded)
246
+ mel_len.append(expanded.shape[0])
247
+
248
+ if max_len is not None:
249
+ output = pad(output, max_len)
250
+ else:
251
+ output = pad(output)
252
+
253
+ return output, torch.LongTensor(mel_len).to(device)
254
+
255
+ def expand(self, batch, predicted):
256
+ out = list()
257
+
258
+ for i, vec in enumerate(batch):
259
+ expand_size = predicted[i].item()
260
+ out.append(vec.expand(max(int(expand_size), 0), -1))
261
+ out = torch.cat(out, 0)
262
+
263
+ return out
264
+
265
+ def forward(self, x, duration, max_len):
266
+ output, mel_len = self.LR(x, duration, max_len)
267
+ return output, mel_len
+
+
+class VariancePredictor(nn.Module):
+    """Duration, Pitch and Energy Predictor"""
+
+    def __init__(self, cfg):
+        super(VariancePredictor, self).__init__()
+
+        self.input_size = cfg.model.transformer.encoder_hidden
+        self.filter_size = cfg.model.variance_predictor.filter_size
+        self.kernel = cfg.model.variance_predictor.kernel_size
+        self.conv_output_size = cfg.model.variance_predictor.filter_size
+        self.dropout = cfg.model.variance_predictor.dropout
+
+        self.conv_layer = nn.Sequential(
+            OrderedDict(
+                [
+                    (
+                        "conv1d_1",
+                        Conv(
+                            self.input_size,
+                            self.filter_size,
+                            kernel_size=self.kernel,
+                            padding=(self.kernel - 1) // 2,
+                        ),
+                    ),
+                    ("relu_1", nn.ReLU()),
+                    ("layer_norm_1", nn.LayerNorm(self.filter_size)),
+                    ("dropout_1", nn.Dropout(self.dropout)),
+                    (
+                        "conv1d_2",
+                        Conv(
+                            self.filter_size,
+                            self.filter_size,
+                            kernel_size=self.kernel,
+                            padding=1,
+                        ),
+                    ),
+                    ("relu_2", nn.ReLU()),
+                    ("layer_norm_2", nn.LayerNorm(self.filter_size)),
+                    ("dropout_2", nn.Dropout(self.dropout)),
+                ]
+            )
+        )
+
+        self.linear_layer = nn.Linear(self.conv_output_size, 1)
+
+    def forward(self, encoder_output, mask):
+        out = self.conv_layer(encoder_output)
+        out = self.linear_layer(out)
+        out = out.squeeze(-1)
+
+        if mask is not None:
+            out = out.masked_fill(mask, 0.0)
+
+        return out
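Each variance predictor maps hidden states (B, T, H) to one scalar per position (B, T). A quick shape check under a hypothetical config (the field names mirror the `cfg.model.*` accesses above; the values are made up):

    import torch
    from types import SimpleNamespace

    # hypothetical config, only the fields VariancePredictor reads
    cfg = SimpleNamespace(
        model=SimpleNamespace(
            transformer=SimpleNamespace(encoder_hidden=256),
            variance_predictor=SimpleNamespace(
                filter_size=256, kernel_size=3, dropout=0.5
            ),
        )
    )
    predictor = VariancePredictor(cfg)
    hidden = torch.randn(2, 17, 256)             # (B, T, H)
    mask = torch.zeros(2, 17, dtype=torch.bool)  # True marks padded positions
    print(predictor(hidden, mask).shape)         # torch.Size([2, 17])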
+
+
+class Conv(nn.Module):
+    """
+    Convolution Module
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size=1,
+        stride=1,
+        padding=0,
+        dilation=1,
+        bias=True,
+        w_init="linear",
+    ):
+        """
+        :param in_channels: dimension of input
+        :param out_channels: dimension of output
+        :param kernel_size: size of kernel
+        :param stride: size of stride
+        :param padding: size of padding
+        :param dilation: dilation rate
+        :param bias: boolean. if True, bias is included.
+        :param w_init: str. weight inits with xavier initialization.
+        """
+        super(Conv, self).__init__()
+
+        self.conv = nn.Conv1d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=bias,
+        )
+
+    def forward(self, x):
+        x = x.contiguous().transpose(1, 2)
+        x = self.conv(x)
+        x = x.contiguous().transpose(1, 2)
+
+        return x
+
+
+class FastSpeech2(nn.Module):
+    def __init__(self, cfg) -> None:
+        super(FastSpeech2, self).__init__()
+        self.cfg = cfg
+        self.encoder = Encoder(cfg.model)
+        self.variance_adaptor = VarianceAdaptor(cfg)
+        self.decoder = Decoder(cfg.model)
+        self.mel_linear = nn.Linear(
+            cfg.model.transformer.decoder_hidden,
+            cfg.preprocess.n_mel,
+        )
+        self.postnet = PostNet(n_mel_channels=cfg.preprocess.n_mel)
+
+        self.speaker_emb = None
+        if cfg.train.multi_speaker_training:
+            with open(
+                os.path.join(
+                    cfg.preprocess.processed_dir, cfg.dataset[0], "spk2id.json"
+                ),
+                "r",
+            ) as f:
+                n_speaker = len(json.load(f))
+            self.speaker_emb = nn.Embedding(
+                n_speaker,
+                cfg.model.transformer.encoder_hidden,
+            )
+
+    def forward(self, data, p_control=1.0, e_control=1.0, d_control=1.0):
+        speakers = data["spk_id"]
+        texts = data["texts"]
+        src_lens = data["text_len"]
+        max_src_len = max(src_lens)
+        mel_lens = data["target_len"] if "target_len" in data else None
+        max_mel_len = max(mel_lens) if "target_len" in data else None
+        p_targets = data["pitch"] if "pitch" in data else None
+        e_targets = data["energy"] if "energy" in data else None
+        d_targets = data["durations"] if "durations" in data else None
+        src_masks = get_mask_from_lengths(src_lens, max_src_len)
+        mel_masks = (
+            get_mask_from_lengths(mel_lens, max_mel_len)
+            if mel_lens is not None
+            else None
+        )
+
+        output = self.encoder(texts, src_masks)
+
+        if self.speaker_emb is not None:
+            output = output + self.speaker_emb(speakers).unsqueeze(1).expand(
+                -1, max_src_len, -1
+            )
+
+        (
+            output,
+            p_predictions,
+            e_predictions,
+            log_d_predictions,
+            d_rounded,
+            mel_lens,
+            mel_masks,
+        ) = self.variance_adaptor(
+            output,
+            src_masks,
+            mel_masks,
+            max_mel_len,
+            p_targets,
+            e_targets,
+            d_targets,
+            p_control,
+            e_control,
+            d_control,
+        )
+
+        output, mel_masks = self.decoder(output, mel_masks)
+        output = self.mel_linear(output)
+
+        postnet_output = self.postnet(output) + output
+
+        return {
+            "output": output,
+            "postnet_output": postnet_output,
+            "p_predictions": p_predictions,
+            "e_predictions": e_predictions,
+            "log_d_predictions": log_d_predictions,
+            "d_rounded": d_rounded,
+            "src_masks": src_masks,
+            "mel_masks": mel_masks,
+            "src_lens": src_lens,
+            "mel_lens": mel_lens,
+        }
+
+
+class FastSpeech2Loss(nn.Module):
+    """FastSpeech2 Loss"""
+
+    def __init__(self, cfg):
+        super(FastSpeech2Loss, self).__init__()
+        if cfg.preprocess.use_frame_pitch:
+            self.pitch_feature_level = "frame_level"
+        else:
+            self.pitch_feature_level = "phoneme_level"
+
+        if cfg.preprocess.use_frame_energy:
+            self.energy_feature_level = "frame_level"
+        else:
+            self.energy_feature_level = "phoneme_level"
+
+        self.mse_loss = nn.MSELoss()
+        self.mae_loss = nn.L1Loss()
+
+    def forward(self, data, predictions):
+        mel_targets = data["mel"]
+        pitch_targets = data["pitch"].float()
+        energy_targets = data["energy"].float()
+        duration_targets = data["durations"]
+
+        mel_predictions = predictions["output"]
+        postnet_mel_predictions = predictions["postnet_output"]
+        pitch_predictions = predictions["p_predictions"]
+        energy_predictions = predictions["e_predictions"]
+        log_duration_predictions = predictions["log_d_predictions"]
+        src_masks = predictions["src_masks"]
+        mel_masks = predictions["mel_masks"]
+
+        src_masks = ~src_masks
+        mel_masks = ~mel_masks
+
+        log_duration_targets = torch.log(duration_targets.float() + 1)
+        mel_targets = mel_targets[:, : mel_masks.shape[1], :]
+        mel_masks = mel_masks[:, : mel_masks.shape[1]]
+
+        log_duration_targets.requires_grad = False
+        pitch_targets.requires_grad = False
+        energy_targets.requires_grad = False
+        mel_targets.requires_grad = False
+
+        if self.pitch_feature_level == "phoneme_level":
+            pitch_predictions = pitch_predictions.masked_select(src_masks)
+            pitch_targets = pitch_targets.masked_select(src_masks)
+        elif self.pitch_feature_level == "frame_level":
+            pitch_predictions = pitch_predictions.masked_select(mel_masks)
+            pitch_targets = pitch_targets.masked_select(mel_masks)
+
+        if self.energy_feature_level == "phoneme_level":
+            energy_predictions = energy_predictions.masked_select(src_masks)
+            energy_targets = energy_targets.masked_select(src_masks)
+        elif self.energy_feature_level == "frame_level":
+            energy_predictions = energy_predictions.masked_select(mel_masks)
+            energy_targets = energy_targets.masked_select(mel_masks)
+
+        log_duration_predictions = log_duration_predictions.masked_select(src_masks)
+        log_duration_targets = log_duration_targets.masked_select(src_masks)
+
+        mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1))
+        postnet_mel_predictions = postnet_mel_predictions.masked_select(
+            mel_masks.unsqueeze(-1)
+        )
+        mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1))
+
+        mel_loss = self.mae_loss(mel_predictions, mel_targets)
+        postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets)
+
+        pitch_loss = self.mse_loss(pitch_predictions, pitch_targets)
+        energy_loss = self.mse_loss(energy_predictions, energy_targets)
+        duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets)
+
+        total_loss = (
+            mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss
+        )
+
+        return {
+            "loss": total_loss,
+            "mel_loss": mel_loss,
+            "postnet_mel_loss": postnet_mel_loss,
+            "pitch_loss": pitch_loss,
+            "energy_loss": energy_loss,
+            "duration_loss": duration_loss,
+        }
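Note that `get_mask_from_lengths` marks padding as True, which is why the loss first inverts the masks (`~src_masks`); `masked_select` then flattens only the valid positions into 1-D vectors before the L1/MSE terms. A toy illustration of that flattening:

    import torch

    pred = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    lens = torch.tensor([3, 1])
    valid = torch.arange(3)[None, :] < lens[:, None]  # True on real frames
    print(pred.masked_select(valid))  # tensor([1., 2., 3., 4.]) — padding excluded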
Amphion/models/tts/fastspeech2/fs2_inference.py ADDED
@@ -0,0 +1,193 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+from tqdm import tqdm
+from collections import OrderedDict
+
+from models.tts.base.tts_inferece import TTSInference
+from models.tts.fastspeech2.fs2_dataset import FS2TestDataset, FS2TestCollator
+from utils.util import load_config
+from utils.io import save_audio
+from models.tts.fastspeech2.fs2 import FastSpeech2
+from models.vocoders.vocoder_inference import synthesis
+from pathlib import Path
+from processors.phone_extractor import phoneExtractor
+from text.text_token_collation import phoneIDCollation
+import numpy as np
+import json
+
+
+class FastSpeech2Inference(TTSInference):
+    def __init__(self, args, cfg):
+        TTSInference.__init__(self, args, cfg)
+        self.args = args
+        self.cfg = cfg
+        self.infer_type = args.mode
+
+    def _build_model(self):
+        self.model = FastSpeech2(self.cfg)
+        return self.model
+
+    def load_model(self, state_dict):
+        raw_dict = state_dict["model"]
+        clean_dict = OrderedDict()
+        for k, v in raw_dict.items():
+            # strip the "module." prefix left by DistributedDataParallel wrappers
+            if k.startswith("module."):
+                clean_dict[k[7:]] = v
+            else:
+                clean_dict[k] = v
+
+        self.model.load_state_dict(clean_dict)
+
+    def _build_test_dataset(self):
+        return FS2TestDataset, FS2TestCollator
+
+    @staticmethod
+    def _parse_vocoder(vocoder_dir):
+        r"""Parse vocoder config"""
+        vocoder_dir = os.path.abspath(vocoder_dir)
+        ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
+        # sort by the last training step (the base class sorts by *int(x.stem)* instead)
+        ckpt_list.sort(
+            key=lambda x: int(x.stem.split("_")[-2].split("-")[-1]), reverse=True
+        )
+        ckpt_path = str(ckpt_list[0])
+        vocoder_cfg = load_config(
+            os.path.join(vocoder_dir, "args.json"), lowercase=True
+        )
+        return vocoder_cfg, ckpt_path
+
+    @torch.inference_mode()
+    def inference_for_batches(self):
+        y_pred = []
+        for i, batch in tqdm(enumerate(self.test_dataloader)):
+            y_pred, mel_lens, _ = self._inference_each_batch(batch)
+            y_ls = y_pred.chunk(self.test_batch_size)
+            tgt_ls = mel_lens.chunk(self.test_batch_size)
+            j = 0
+            for it, l in zip(y_ls, tgt_ls):
+                l = l.item()
+                it = it.squeeze(0)[:l].detach().cpu()
+
+                uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
+                torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
+                j += 1
+
+        vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
+        res = synthesis(
+            cfg=vocoder_cfg,
+            vocoder_weight_file=vocoder_ckpt,
+            n_samples=None,
+            pred=[
+                torch.load(
+                    os.path.join(self.args.output_dir, "{}.pt".format(item["Uid"]))
+                ).numpy()
+                for item in self.test_dataset.metadata
+            ],
+        )
+        for it, wav in zip(self.test_dataset.metadata, res):
+            uid = it["Uid"]
+            save_audio(
+                os.path.join(self.args.output_dir, f"{uid}.wav"),
+                wav.numpy(),
+                self.cfg.preprocess.sample_rate,
+                add_silence=True,
+                turn_up=True,
+            )
+            os.remove(os.path.join(self.args.output_dir, f"{uid}.pt"))
+
+    @torch.inference_mode()
+    def _inference_each_batch(self, batch_data):
+        device = self.accelerator.device
+        control_values = (
+            self.args.pitch_control,
+            self.args.energy_control,
+            self.args.duration_control,
+        )
+        for k, v in batch_data.items():
+            batch_data[k] = v.to(device)
+
+        pitch_control, energy_control, duration_control = control_values
+
+        output = self.model(
+            batch_data,
+            p_control=pitch_control,
+            e_control=energy_control,
+            d_control=duration_control,
+        )
+        pred_res = output["postnet_output"]
+        mel_lens = output["mel_lens"].cpu()
+        return pred_res, mel_lens, 0
+
+    def inference_for_single_utterance(self):
+        text = self.args.text
+        control_values = (
+            self.args.pitch_control,
+            self.args.energy_control,
+            self.args.duration_control,
+        )
+        pitch_control, energy_control, duration_control = control_values
+
+        # get phone symbol file
+        phone_symbol_file = None
+        if self.cfg.preprocess.phone_extractor != "lexicon":
+            phone_symbol_file = os.path.join(
+                self.exp_dir, self.cfg.preprocess.symbols_dict
+            )
+            assert os.path.exists(phone_symbol_file)
+        # convert text to phone sequence
+        phone_extractor = phoneExtractor(self.cfg)
+
+        phone_seq = phone_extractor.extract_phone(text)  # phone_seq: list
+        # convert phone sequence to phone id sequence
+        phon_id_collator = phoneIDCollation(
+            self.cfg, symbols_dict_file=phone_symbol_file
+        )
+        phone_seq = ["{"] + phone_seq + ["}"]
+        phone_id_seq = phon_id_collator.get_phone_id_sequence(self.cfg, phone_seq)
+
+        # convert phone sequence to phone id sequence
+        phone_id_seq = np.array(phone_id_seq)
+        phone_id_seq = torch.from_numpy(phone_id_seq)
+
+        # get speaker id if multi-speaker training and use speaker id
+        speaker_id = None
+        if self.cfg.preprocess.use_spkid and self.cfg.train.multi_speaker_training:
+            spk2id_file = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
+            with open(spk2id_file, "r") as f:
+                spk2id = json.load(f)
+                speaker_id = spk2id[self.args.speaker_name]
+                speaker_id = torch.from_numpy(np.array([speaker_id], dtype=np.int32))
+        else:
+            speaker_id = torch.Tensor(0).view(-1)
+
+        with torch.no_grad():
+            x_tst = phone_id_seq.to(self.device).unsqueeze(0)
+            x_tst_lengths = torch.LongTensor([phone_id_seq.size(0)]).to(self.device)
+            if speaker_id is not None:
+                speaker_id = speaker_id.to(self.device)
+
+            data = {}
+            data["texts"] = x_tst
+            data["text_len"] = x_tst_lengths
+            data["spk_id"] = speaker_id
+
+            output = self.model(
+                data,
+                p_control=pitch_control,
+                e_control=energy_control,
+                d_control=duration_control,
+            )
+            pred_res = output["postnet_output"]
+            vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
+            audio = synthesis(
+                cfg=vocoder_cfg,
+                vocoder_weight_file=vocoder_ckpt,
+                n_samples=None,
+                pred=pred_res,
+            )
+        return audio[0]
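The sort key in `_parse_vocoder` assumes checkpoint stems whose second-to-last underscore field carries the training step, e.g. `epoch-0012_step-0500000_loss-0.25` (a hypothetical name, shown only to illustrate the parsing):

    stems = ["epoch-0005_step-0200000_loss-0.31", "epoch-0012_step-0500000_loss-0.25"]
    step = lambda s: int(s.split("_")[-2].split("-")[-1])  # -> 200000, 500000
    print(sorted(stems, key=step, reverse=True)[0])  # the latest checkpoint first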
Amphion/models/tts/naturalspeech2/__init__.py ADDED
File without changes
Amphion/models/tts/naturalspeech2/wavenet.py ADDED
@@ -0,0 +1,206 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+import math
+
+
+class FiLM(nn.Module):
+    def __init__(self, in_dim, cond_dim):
+        super().__init__()
+
+        self.gain = Linear(cond_dim, in_dim)
+        self.bias = Linear(cond_dim, in_dim)
+
+        nn.init.xavier_uniform_(self.gain.weight)
+        nn.init.constant_(self.gain.bias, 1)
+
+        nn.init.xavier_uniform_(self.bias.weight)
+        nn.init.constant_(self.bias.bias, 0)
+
+    def forward(self, x, condition):
+        gain = self.gain(condition)
+        bias = self.bias(condition)
+        if gain.dim() == 2:
+            gain = gain.unsqueeze(-1)
+        if bias.dim() == 2:
+            bias = bias.unsqueeze(-1)
+        return x * gain + bias
+
+
+class Mish(nn.Module):
+    def forward(self, x):
+        return x * torch.tanh(F.softplus(x))
+
+
+def Conv1d(*args, **kwargs):
+    layer = nn.Conv1d(*args, **kwargs)
+    nn.init.kaiming_normal_(layer.weight)
+    return layer
+
+
+def Linear(*args, **kwargs):
+    layer = nn.Linear(*args, **kwargs)
+    layer.weight.data.normal_(0.0, 0.02)
+    return layer
+
+
+class SinusoidalPosEmb(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+        emb = x[:, None] * emb[None, :]
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
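The embedding maps a batch of scalar diffusion steps (B,) to (B, dim), the first half sine and the second half cosine at geometrically spaced frequencies. A quick check, assuming the class above is in scope:

    import torch

    emb = SinusoidalPosEmb(128)
    t = torch.tensor([0.0, 10.0, 100.0])  # three diffusion timesteps
    print(emb(t).shape)  # torch.Size([3, 128])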
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, hidden_dim, attn_head, dilation, drop_out, has_cattn=False):
+        super().__init__()
+
+        self.hidden_dim = hidden_dim
+        self.dilation = dilation
+        self.has_cattn = has_cattn
+        self.attn_head = attn_head
+        self.drop_out = drop_out
+
+        self.dilated_conv = Conv1d(
+            hidden_dim, 2 * hidden_dim, 3, padding=dilation, dilation=dilation
+        )
+        self.diffusion_proj = Linear(hidden_dim, hidden_dim)
+
+        self.cond_proj = Conv1d(hidden_dim, hidden_dim * 2, 1)
+        self.out_proj = Conv1d(hidden_dim, hidden_dim * 2, 1)
+
+        if self.has_cattn:
+            self.attn = nn.MultiheadAttention(
+                hidden_dim, attn_head, 0.1, batch_first=True
+            )
+            self.film = FiLM(hidden_dim * 2, hidden_dim)
+
+            self.ln = nn.LayerNorm(hidden_dim)
+
+        self.dropout = nn.Dropout(self.drop_out)
+
+    def forward(self, x, x_mask, cond, diffusion_step, spk_query_emb):
+        diffusion_step = self.diffusion_proj(diffusion_step).unsqueeze(-1)  # (B, d, 1)
+        cond = self.cond_proj(cond)  # (B, 2*d, T)
+
+        y = x + diffusion_step
+        if x_mask is not None:
+            y = y * x_mask.to(y.dtype)[:, None, :]  # (B, d, T)
+
+        if self.has_cattn:
+            y_ = y.transpose(1, 2)
+            y_ = self.ln(y_)
+
+            y_, _ = self.attn(y_, spk_query_emb, spk_query_emb)  # (B, T, d)
+
+        y = self.dilated_conv(y) + cond  # (B, 2*d, T)
+
+        if self.has_cattn:
+            y = self.film(y.transpose(1, 2), y_)  # (B, T, 2*d)
+            y = y.transpose(1, 2)  # (B, 2*d, T)
+
+        # gated activation unit: split channels into gate and filter halves
+        gate, filter_ = torch.chunk(y, 2, dim=1)
+        y = torch.sigmoid(gate) * torch.tanh(filter_)
+
+        y = self.out_proj(y)
+
+        residual, skip = torch.chunk(y, 2, dim=1)
+
+        if x_mask is not None:
+            residual = residual * x_mask.to(y.dtype)[:, None, :]
+            skip = skip * x_mask.to(y.dtype)[:, None, :]
+
+        return (x + residual) / math.sqrt(2.0), skip
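The `sigmoid(gate) * tanh(filter)` line is the standard WaveNet gated activation: the sigmoid half decides how much of the tanh half passes through. In isolation:

    import torch

    y = torch.randn(2, 8, 50)                 # (B, 2*d, T) with d = 4
    gate, filter_ = torch.chunk(y, 2, dim=1)  # two (B, 4, T) halves
    out = torch.sigmoid(gate) * torch.tanh(filter_)
    print(out.shape)  # torch.Size([2, 4, 50]); values bounded in (-1, 1)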
+
+
+class WaveNet(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+
+        self.cfg = cfg
+        self.in_dim = cfg.input_size
+        self.hidden_dim = cfg.hidden_size
+        self.out_dim = cfg.out_size
+        self.num_layers = cfg.num_layers
+        self.cross_attn_per_layer = cfg.cross_attn_per_layer
+        self.dilation_cycle = cfg.dilation_cycle
+        self.attn_head = cfg.attn_head
+        self.drop_out = cfg.drop_out
+
+        self.in_proj = Conv1d(self.in_dim, self.hidden_dim, 1)
+        self.diffusion_embedding = SinusoidalPosEmb(self.hidden_dim)
+
+        self.mlp = nn.Sequential(
+            Linear(self.hidden_dim, self.hidden_dim * 4),
+            Mish(),
+            Linear(self.hidden_dim * 4, self.hidden_dim),
+        )
+
+        self.cond_ln = nn.LayerNorm(self.hidden_dim)
+
+        self.layers = nn.ModuleList(
+            [
+                ResidualBlock(
+                    self.hidden_dim,
+                    self.attn_head,
+                    2 ** (i % self.dilation_cycle),
+                    self.drop_out,
+                    has_cattn=(i % self.cross_attn_per_layer == 0),
+                )
+                for i in range(self.num_layers)
+            ]
+        )
+
+        self.skip_proj = Conv1d(self.hidden_dim, self.hidden_dim, 1)
+        self.out_proj = Conv1d(self.hidden_dim, self.out_dim, 1)
+
+        nn.init.zeros_(self.out_proj.weight)
+
+    def forward(self, x, x_mask, cond, diffusion_step, spk_query_emb):
+        """
+        x: (B, 128, T)
+        x_mask: (B, T), mask is 0
+        cond: (B, T, 512)
+        diffusion_step: (B,)
+        spk_query_emb: (B, 32, 512)
+        """
+        cond = self.cond_ln(cond)
+        cond_input = cond.transpose(1, 2)
+
+        x_input = self.in_proj(x)
+
+        x_input = F.relu(x_input)
+
+        diffusion_step = self.diffusion_embedding(diffusion_step).to(x.dtype)
+        diffusion_step = self.mlp(diffusion_step)
+
+        skip = []
+        for _, layer in enumerate(self.layers):
+            x_input, skip_connection = layer(
+                x_input, x_mask, cond_input, diffusion_step, spk_query_emb
+            )
+            skip.append(skip_connection)
+
+        x_input = torch.sum(torch.stack(skip), dim=0) / math.sqrt(self.num_layers)
+
+        x_out = self.skip_proj(x_input)
+
+        x_out = F.relu(x_out)
+
+        x_out = self.out_proj(x_out)  # (B, 128, T)
+
+        return x_out
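Putting the pieces together, a smoke test with a hypothetical config (the field values below are illustrative, not the ones Amphion ships):

    import torch
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        input_size=128, hidden_size=512, out_size=128, num_layers=8,
        cross_attn_per_layer=3, dilation_cycle=2, attn_head=8, drop_out=0.1,
    )
    net = WaveNet(cfg)
    B, T = 2, 100
    out = net(
        x=torch.randn(B, 128, T),
        x_mask=torch.ones(B, T),
        cond=torch.randn(B, T, 512),
        diffusion_step=torch.randint(0, 1000, (B,)).float(),
        spk_query_emb=torch.randn(B, 32, 512),
    )
    print(out.shape)  # torch.Size([2, 128, 100])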
Amphion/models/tts/valle/valle.py ADDED
@@ -0,0 +1,794 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/lifeiteng/vall-e/blob/main/valle/models/valle.py
+
+import random
+from typing import Dict, Iterator, List, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchmetrics.classification import MulticlassAccuracy
+from utils.util import make_pad_mask
+from utils.topk_sampling import topk_sampling
+from modules.general import Transpose
+from modules.encoder import TokenEmbedding
+from modules.general import PromptedFeatures
+from modules.transformer import SinePositionalEmbedding
+from modules.norms import AdaptiveLayerNorm, LayerNorm
+from modules.transformer.transformer import TransformerEncoder, TransformerEncoderLayer
+
+
+class VALLE(nn.Module):
+    def __init__(
+        self,
+        cfg,
+        decoder_cls=TransformerEncoder,
+        decoder_layer_cls=TransformerEncoderLayer,
+    ):
+        super().__init__()
+        decoder_dim = cfg.decoder_dim
+        nhead = cfg.nhead
+        nar_scale_factor = cfg.nar_scale_factor
+        num_quantizers = cfg.num_quantizers
+        num_decoder_layers = cfg.num_decoder_layers
+        nar_decoder_dim = int(decoder_dim * nar_scale_factor)
+
+        self.ar_text_embedding = TokenEmbedding(decoder_dim, cfg.text_token_num)
+        self.nar_text_embedding = TokenEmbedding(nar_decoder_dim, cfg.text_token_num)
+
+        self.ar_audio_prepend_bos = cfg.prepend_bos
+        self.ar_audio_embedding = TokenEmbedding(
+            decoder_dim, cfg.audio_token_num + 1 + int(cfg.prepend_bos)
+        )
+        self.audio_token_num = cfg.audio_token_num
+
+        # PreNet of AR
+        if cfg.add_prenet:
+            self.ar_text_prenet = nn.Sequential(
+                Transpose(),
+                nn.Conv1d(decoder_dim, decoder_dim, kernel_size=5, padding="same"),
+                nn.BatchNorm1d(decoder_dim),
+                nn.ReLU(),
+                nn.Dropout(0.5),
+                nn.Conv1d(decoder_dim, decoder_dim, kernel_size=5, padding="same"),
+                nn.BatchNorm1d(decoder_dim),
+                nn.ReLU(),
+                nn.Dropout(0.5),
+                nn.Conv1d(decoder_dim, decoder_dim, kernel_size=5, padding="same"),
+                nn.BatchNorm1d(decoder_dim),
+                nn.ReLU(),
+                nn.Dropout(0.5),
+                Transpose(),
+                nn.Linear(decoder_dim, decoder_dim),
+            )
+
+            self.ar_audio_prenet = nn.Sequential(
+                nn.Linear(decoder_dim, 256),
+                nn.ReLU(),
+                nn.Dropout(0.25),
+                nn.Linear(256, 256),
+                nn.ReLU(),
+                nn.Dropout(0.25),
+                nn.Linear(256, decoder_dim),
+            )
+        else:
+            self.ar_text_prenet = nn.Identity()
+            self.ar_audio_prenet = nn.Identity()
+
+        self.ar_text_position = SinePositionalEmbedding(
+            decoder_dim,
+            dropout=0.1,
+            scale=False,
+            alpha=True,
+        )
+        self.ar_audio_position = SinePositionalEmbedding(
+            decoder_dim,
+            dropout=0.1,
+            scale=False,
+            alpha=True,
+        )
+
+        self.ar_decoder = decoder_cls(
+            decoder_layer_cls(
+                decoder_dim,
+                nhead,
+                dim_feedforward=decoder_dim * 4,
+                dropout=0.1,
+                batch_first=True,
+                norm_first=cfg.norm_first,
+            ),
+            num_layers=num_decoder_layers,
+            norm=LayerNorm(decoder_dim) if cfg.norm_first else None,
+        )
+        self.ar_predict_layer = nn.Linear(
+            decoder_dim, cfg.audio_token_num + 1, bias=False
+        )
+
+        self.ar_accuracy_metric = MulticlassAccuracy(
+            cfg.audio_token_num + 1,
+            top_k=10,
+            average="micro",
+            multidim_average="global",
+            ignore_index=cfg.audio_token_num,
+        )
+
+        self.rng = random.Random(0)
+        self.num_heads = nhead
+        self.prefix_mode = cfg.prefix_mode
+        self.num_quantizers = num_quantizers
+
+        assert num_quantizers >= 1
+        if num_quantizers > 1:
+            self.nar_audio_embeddings = nn.ModuleList(
+                [
+                    TokenEmbedding(nar_decoder_dim, cfg.audio_token_num + 1)
+                ]  # the first layer needs one extra id for the EOS token appended in pad_y_eos
+                + [
+                    TokenEmbedding(nar_decoder_dim, cfg.audio_token_num)
+                    for i in range(num_quantizers - 1)
+                ]
+            )
+
+            if cfg.add_prenet:
+                self.nar_text_prenet = nn.Sequential(
+                    Transpose(),
+                    nn.Conv1d(
+                        nar_decoder_dim, nar_decoder_dim, kernel_size=5, padding="same"
+                    ),
+                    nn.BatchNorm1d(nar_decoder_dim),
+                    nn.ReLU(),
+                    nn.Dropout(0.5),
+                    nn.Conv1d(
+                        nar_decoder_dim, nar_decoder_dim, kernel_size=5, padding="same"
+                    ),
+                    nn.BatchNorm1d(nar_decoder_dim),
+                    nn.ReLU(),
+                    nn.Dropout(0.5),
+                    nn.Conv1d(
+                        nar_decoder_dim, nar_decoder_dim, kernel_size=5, padding="same"
+                    ),
+                    nn.BatchNorm1d(nar_decoder_dim),
+                    nn.ReLU(),
+                    nn.Dropout(0.5),
+                    Transpose(),
+                    nn.Linear(nar_decoder_dim, nar_decoder_dim),
+                )
+                self.nar_audio_prenet = nn.Sequential(
+                    nn.Linear(nar_decoder_dim, 256),
+                    nn.ReLU(),
+                    nn.Dropout(0.25),
+                    nn.Linear(256, 256),
+                    nn.ReLU(),
+                    nn.Dropout(0.25),
+                    nn.Linear(256, nar_decoder_dim),
+                )
+            else:
+                self.nar_text_prenet = nn.Identity()
+                self.nar_audio_prenet = nn.Identity()
+
+            self.nar_text_position = SinePositionalEmbedding(
+                nar_decoder_dim,
+                dropout=0.0,
+                scale=False,
+                alpha=False,
+            )
+            self.nar_audio_position = SinePositionalEmbedding(
+                nar_decoder_dim,
+                dropout=0.1,
+                scale=False,
+                alpha=False,
+            )
+
+            self.nar_decoder = decoder_cls(
+                decoder_layer_cls(
+                    nar_decoder_dim,
+                    int(nhead * nar_scale_factor),
+                    dim_feedforward=nar_decoder_dim * 4,
+                    dropout=0.1,
+                    batch_first=True,
+                    norm_first=cfg.norm_first,
+                    adaptive_layer_norm=True,
+                ),
+                num_layers=int(num_decoder_layers * nar_scale_factor),
+                norm=(
+                    AdaptiveLayerNorm(
+                        nar_decoder_dim, norm=nn.LayerNorm(nar_decoder_dim)
+                    )
+                    if cfg.norm_first
+                    else None
+                ),
+            )
+            self.nar_predict_layers = nn.ModuleList(
+                [
+                    nn.Linear(nar_decoder_dim, cfg.audio_token_num, bias=False)
+                    for i in range(num_quantizers - 1)
+                ]
+            )
+            self.nar_stage_embeddings = nn.ModuleList(
+                [TokenEmbedding(nar_decoder_dim, 1) for i in range(num_quantizers - 1)]
+            )
+
+            if cfg.share_embedding:
+                for j in range(0, num_quantizers - 2):
+                    self.nar_predict_layers[j].weight = self.nar_audio_embeddings[
+                        j + 2
+                    ].weight
+
+            self.nar_accuracy_metric = MulticlassAccuracy(
+                cfg.audio_token_num + 1,
+                top_k=10,
+                average="micro",
+                multidim_average="global",
+                ignore_index=cfg.audio_token_num,
+            )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        y: Union[torch.Tensor, PromptedFeatures],
+        y_lens: Union[torch.Tensor, PromptedFeatures],
+        reduction: str = "sum",
+        train_stage: int = 0,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
+        """
+        Args:
+          x:
+            A 2-D tensor of shape (N, S).
+          x_lens:
+            A 1-D tensor of shape (N,). It contains the number of tokens in `x`
+            before padding.
+          y:
+            A 3-D tensor of shape (N, T, 8).
+          y_lens:
+            A 1-D tensor of shape (N,). It contains the number of tokens in `y`
+            before padding.
+          train_stage:
+            0: AR & NAR modules, 1: AR modules, 2: NAR modules
+        Returns:
+          Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy.
+        """
+        assert x.ndim == 2, x.shape
+        assert x_lens.ndim == 1, x_lens.shape
+
+        y_prompts_codes = None
+        if isinstance(y, PromptedFeatures):
+            y_prompts_codes, y = y.data
+            prompts_len, y_lens = y_lens.data
+            assert prompts_len.min() == prompts_len.max()
+            assert self.prefix_mode == 4
+            y_prompts_codes = y_prompts_codes.type(torch.int64)
+
+        assert y.ndim == 3, y.shape
+        assert y_lens.ndim == 1, y_lens.shape
+
+        x_mask = make_pad_mask(x_lens).to(x.device)
+        y_mask = make_pad_mask(y_lens).to(y.device)
+        y_mask_int = y_mask.type(torch.int64)
+
+        text = x
+        codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1))
+
+        y, targets = self.pad_y_eos(
+            codes[..., 0], y_mask_int, eos_id=self.audio_token_num
+        )
+        self.y_mask_int = y_mask_int
+
+        metrics = {}
+        total_loss = 0.0
+
+        xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
+        if self.ar_audio_prepend_bos:
+            ar_xy_padding_mask = torch.concat(
+                [x_mask, F.pad(y_mask, (1, 0), value=False)], dim=1
+            )
+        else:
+            ar_xy_padding_mask = xy_padding_mask
+        self.xy_padding_mask = xy_padding_mask
+        self.ar_xy_padding_mask = ar_xy_padding_mask
+
+        # AR Decoder
+        if train_stage in [0, 1]:
+            ar_loss, ar_metrics = self._forward_ar_decoder(
+                text, x_lens.max(), y, y_lens.max(), targets, x_mask, y_mask, reduction
+            )
+            total_loss += ar_loss
+            metrics["AR_Top100Acc"] = ar_metrics
+
+        # NAR Decoder
+        if self.ar_audio_prepend_bos:
+            y = y[:, 1:]
+
+        if self.num_quantizers > 1 and train_stage in [0, 2]:
+            nar_loss, nar_metrics = self._forward_nar_decoder(
+                text,
+                x_lens,
+                y,
+                y_lens,
+                codes,
+                y_prompts_codes,
+                x_mask,
+                y_mask,
+                reduction,
+            )
+            total_loss += nar_loss
+            metrics["NAR_Top100Acc"] = nar_metrics
+
+        if train_stage == 0:
+            total_loss = total_loss / 2.0
+
+        return total_loss, metrics
+
+    def _forward_ar_decoder(
+        self, x, x_len, y, y_lens, targets, x_mask, y_mask, reduction
+    ):
+        x = self.ar_text_embedding(x)
+        x = self.ar_text_prenet(x)
+        x = self.ar_text_position(x)
+
+        y_len = y_lens.max() + int(self.ar_audio_prepend_bos)
+
+        x_attn_mask = F.pad(
+            torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
+            (0, y_len),
+            value=True,
+        )
+        y_attn_mask = F.pad(
+            torch.triu(
+                torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
+                diagonal=1,
+            ),
+            (x_len, 0),
+            value=False,
+        )
+        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
+
+        bsz, src_len = x.shape[0], x_len + y_len
+        _xy_padding_mask = (
+            self.ar_xy_padding_mask.view(bsz, 1, 1, src_len)
+            .expand(-1, self.num_heads, -1, -1)
+            .reshape(bsz * self.num_heads, 1, src_len)
+        )
+        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
+
+        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
+        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
+        xy_attn_mask = new_attn_mask
+
+        y_emb = self.ar_audio_embedding(y)
+        y_emb = self.ar_audio_prenet(y_emb)
+        y_pos = self.ar_audio_position(y_emb)
+
+        xy_pos = torch.concat([x, y_pos], dim=1)
+
+        xy_dec, _ = self.ar_decoder(
+            (xy_pos, None),
+            mask=xy_attn_mask,
+        )
+        logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
+        ar_loss = F.cross_entropy(logits, targets, reduction=reduction)
+
+        ar_metrics = self.ar_accuracy_metric(
+            logits.detach(), targets
+        ).item() * y_lens.sum().type(torch.float32)
+
+        return ar_loss, ar_metrics
+
+    def _forward_nar_decoder(
+        self, x, x_lens, y, y_lens, codes, y_prompts_codes, x_mask, y_mask, reduction
+    ):
+        num_nar_layers = self.num_quantizers - 1
+        nar_stage = self.rng.choices(
+            [_k for _k in range(1, self.num_quantizers)],
+            weights=[1.0 / num_nar_layers] * num_nar_layers,
+            k=1,
+        )[0]
+
+        x = self.nar_text_embedding(x)
+        x = self.nar_text_prenet(x)
+        x = self.nar_text_position(x)
+
+        y_emb, prefix_len = self._prepare_prompts(
+            y, y_lens, codes, nar_stage, y_prompts_codes
+        )
+
+        y_len = y_lens.max()
+        targets = codes[..., nar_stage] + self.audio_token_num * self.y_mask_int
+        if self.prefix_mode in [2, 4]:
+            xy_padding_mask = torch.concat(
+                [
+                    x_mask,
+                    F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False),
+                ],
+                dim=1,
+            )
+        elif self.prefix_mode == 1:
+            targets = targets[:, prefix_len:]
+
+        y_pos = self.nar_audio_prenet(y_emb)
+        y_pos = self.nar_audio_position(y_pos)
+        xy_pos = torch.concat([x, y_pos], dim=1)
+        xy_dec, _ = self.nar_decoder(
+            (xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight),
+            src_key_padding_mask=self.xy_padding_mask,
+        )
+        xy_dec = xy_dec[:, x_lens.max() + prefix_len :]
+        if self.prefix_mode == 4:
+            prefix_len = 0
+        logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(0, 2, 1)
+
+        total_length = (y_lens).sum().type(torch.float32)
+        nar_loss = F.cross_entropy(
+            logits,
+            targets,
+            ignore_index=self.audio_token_num,
+            reduction=reduction,
+        ) * (total_length / (total_length - prefix_len * x.shape[0]))
+        nar_metrics = (
+            self.nar_accuracy_metric(
+                F.pad(
+                    logits.detach(),
+                    (0, 0, 0, 1, 0, 0),
+                    value=logits.min().cpu().item(),
+                ),
+                targets,
+            ).item()
+            * total_length
+        )
+        return nar_loss, nar_metrics
+
+    def inference(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        y: torch.Tensor,
+        enroll_x_lens: torch.Tensor,
+        top_k: int = -100,
+        temperature: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        Args:
+          x:
+            A 2-D tensor of shape (1, S).
+          x_lens:
+            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
+            before padding.
+          y:
+            A 3-D tensor of shape (1, T, 8).
+          top_k: (`optional`) int
+            The number of highest probability tokens to keep for top-k filtering. Defaults to -100.
+          temperature: (`optional`) float
+            The value used to modulate the next token probabilities. Must be strictly positive. Defaults to 1.0.
+        Returns:
+          Return the predicted audio code matrix.
+        """
+        assert x.ndim == 2, x.shape
+        assert x_lens.ndim == 1, x_lens.shape
+        assert y.ndim == 3, y.shape
+        assert y.shape[0] == 1, y.shape
+
+        assert torch.all(x_lens > 0)
+
+        text = x
+        x = self.ar_text_embedding(text)
+        x = self.ar_text_prenet(x)
+        x = self.ar_text_position(x)
+
+        text_len = x_lens.max()
+        prompts = y
+        prefix_len = y.shape[1]
+
+        # AR Decoder
+        y = prompts[..., 0]
+        if self.ar_audio_prepend_bos:
+            y = F.pad(y, (1, 0), value=self.audio_token_num + 1)
+
+        x_len = x_lens.max()
+        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
+
+        while True:
+            y_emb = self.ar_audio_embedding(y)
+            y_emb = self.ar_audio_prenet(y_emb)
+            y_pos = self.ar_audio_position(y_emb)
+            xy_pos = torch.concat([x, y_pos], dim=1)
+
+            y_len = y.shape[1]
+            x_attn_mask_pad = F.pad(
+                x_attn_mask,
+                (0, y_len),
+                value=True,
+            )
+            y_attn_mask = F.pad(
+                torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
+                (x_len, 0),
+                value=False,
+            )
+            xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
+                y.device
+            )
+
+            xy_dec, _ = self.ar_decoder(
+                (xy_pos, None),
+                mask=xy_attn_mask,
+            )
+            logits = self.ar_predict_layer(xy_dec[:, -1])
+            samples = topk_sampling(
+                logits, top_k=top_k, top_p=1.0, temperature=temperature
+            )
+
+            if (
+                torch.argmax(logits, dim=-1)[0] == self.audio_token_num
+                or samples[0, 0] == self.audio_token_num
+                or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
+            ):
+                if prompts.shape[1] == y.shape[1]:
+                    raise RuntimeError("a well-trained model shouldn't reach here.")
+
+                break
+
+            y = torch.concat([y, samples], dim=1)
+
+        codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
+        if self.num_quantizers == 1:
+            return torch.stack(codes, dim=-1)
+
+        # Non-AR Decoders
+        y_emb = self.nar_audio_embeddings[0](y[:, int(self.ar_audio_prepend_bos) :])
+
+        if self.prefix_mode in [2, 4]:
+            enrolled_len = enroll_x_lens.max().item()
+            # SOS + Synthesis Text + EOS
+            text = torch.concat(
+                [
+                    text[:, :1],
+                    text[:, enrolled_len - 1 :],
+                ],
+                dim=1,
+            )
+            text_len = text_len - (enrolled_len - 2)
+            assert text.shape[0] == 1
+
+        x = self.nar_text_embedding(text)
+        x = self.nar_text_prenet(x)
+        x = self.nar_text_position(x)
+
+        if self.prefix_mode == 0:
+            for i, (predict_layer, embedding_layer) in enumerate(
+                zip(
+                    self.nar_predict_layers,
+                    self.nar_audio_embeddings[1:],
+                )
+            ):
+                y_pos = self.nar_audio_prenet(y_emb)
+                y_pos = self.nar_audio_position(y_pos)
+                xy_pos = torch.concat([x, y_pos], dim=1)
+
+                xy_dec, _ = self.nar_decoder(
+                    (xy_pos, self.nar_stage_embeddings[i].weight)
+                )
+                logits = predict_layer(xy_dec[:, text_len + prefix_len :])
+
+                samples = torch.argmax(logits, dim=-1)
+                codes.append(samples)
+
+                if i < self.num_quantizers - 2:
+                    y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
+                    y_emb[:, prefix_len:] += embedding_layer(samples)
+        else:
+            for j in range(1, self.num_quantizers):
+                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])
+
+            for i, (predict_layer, embedding_layer) in enumerate(
+                zip(
+                    self.nar_predict_layers,
+                    self.nar_audio_embeddings[1:],
+                )
+            ):
+                y_pos = self.nar_audio_prenet(y_emb)
+                y_pos = self.nar_audio_position(y_pos)
+                xy_pos = torch.concat([x, y_pos], dim=1)
+
+                xy_dec, _ = self.nar_decoder(
+                    (xy_pos, self.nar_stage_embeddings[i].weight)
+                )
+                logits = predict_layer(xy_dec[:, text_len + prefix_len :])
+
+                samples = torch.argmax(logits, dim=-1)
+                codes.append(samples)
+
+                if i < self.num_quantizers - 2:
+                    y_emb[:, prefix_len:] += embedding_layer(samples)
+
+        assert len(codes) == self.num_quantizers
+        return torch.stack(codes, dim=-1)
+
+    def continual(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        y: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Args:
+          x:
+            A 2-D tensor of shape (1, S).
+          x_lens:
+            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
+            before padding.
+          y:
+            A 3-D tensor of shape (1, T, 8).
+        Returns:
+          Return the predicted audio code matrix.
+        """
+        assert x.ndim == 2, x.shape
+        assert x_lens.ndim == 1, x_lens.shape
+        assert y.ndim == 3, y.shape
+        assert y.shape[0] == 1, y.shape
+
+        assert torch.all(x_lens > 0)
+        assert self.num_quantizers == 8
+
+        text = x
+        x = self.ar_text_embedding(text)
+        x = self.ar_text_prenet(x)
+        x = self.ar_text_position(x)
+
+        text_len = x_lens.max()
+
+        prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)
+
+        # AR Decoder
+        prompts = y[:, :prefix_len]
+
+        codes = [y[:, prefix_len:, 0]]
+        # Non-AR Decoders
+        x = self.nar_text_embedding(text)
+        x = self.nar_text_prenet(x)
+        x = self.nar_text_position(x)
+
+        y_emb = self.nar_audio_embeddings[0](y[..., 0])
+
+        if self.prefix_mode == 0:
+            for i, (predict_layer, embedding_layer) in enumerate(
+                zip(
+                    self.nar_predict_layers,
+                    self.nar_audio_embeddings[1:],
+                )
+            ):
+                y_pos = self.nar_audio_position(y_emb)
+                y_pos = self.nar_audio_prenet(y_pos)
+                xy_pos = torch.concat([x, y_pos], dim=1)
+
+                xy_dec, _ = self.nar_decoder(
+                    (xy_pos, self.nar_stage_embeddings[i].weight)
+                )
+                logits = predict_layer(xy_dec[:, text_len + prefix_len :])
+
+                samples = torch.argmax(logits, dim=-1)
+                codes.append(samples)
+
+                if i < 6:
+                    y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
+                    y_emb[:, prefix_len:] += embedding_layer(samples)
+        else:
+            for j in range(1, 8):
+                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])
+
+            for i, (predict_layer, embedding_layer) in enumerate(
+                zip(
+                    self.nar_predict_layers,
+                    self.nar_audio_embeddings[1:],
+                )
+            ):
+                y_pos = self.nar_audio_prenet(y_emb)
+                y_pos = self.nar_audio_position(y_pos)
+                xy_pos = torch.concat([x, y_pos], dim=1)
+
+                xy_dec, _ = self.nar_decoder(
+                    (xy_pos, self.nar_stage_embeddings[i].weight)
+                )
+                logits = predict_layer(xy_dec[:, text_len + prefix_len :])
+
+                samples = torch.argmax(logits, dim=-1)
+                codes.append(samples)
+
+                if i < 6:
+                    y_emb[:, prefix_len:] += embedding_layer(samples)
+
+        assert len(codes) == 8
+        return torch.stack(codes, dim=-1)
+
+    def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
+        assert stage > 0
+        if stage == 1:
+            for name, param in self.named_parameters():
+                if name.startswith("ar_"):
+                    yield param
+
+        if stage == 2:
+            for name, param in self.named_parameters():
+                if name.startswith("nar_"):
+                    yield param
+
+    def stage_named_parameters(
+        self, stage: int = 1
+    ) -> Iterator[Tuple[str, nn.Parameter]]:
+        assert stage > 0
+        if stage == 1:
+            for pair in self.named_parameters():
+                if pair[0].startswith("ar_"):
+                    yield pair
+
+        if stage == 2:
+            for pair in self.named_parameters():
+                if pair[0].startswith("nar_"):
+                    yield pair
+
+    def pad_y_eos(self, y, y_mask_int, eos_id):
+        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
+            y_mask_int, (0, 1), value=1
+        )
+        if self.ar_audio_prepend_bos:
+            return (
+                F.pad(targets[:, :-1], (1, 0), value=self.audio_token_num + 1),
+                targets,
+            )
+
+        return targets[:, :-1], targets[:, 1:]
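`pad_y_eos` turns a code sequence into (input, target) pairs for teacher forcing: EOS overwrites the padded positions plus the one appended slot, and with `prepend_bos` the input is the target shifted right behind a BOS id (audio_token_num + 1). A toy trace with audio_token_num = 1024:

    import torch
    import torch.nn.functional as F

    y = torch.tensor([[7, 8, 9, 0]])           # last position is padding (already zeroed)
    y_mask_int = torch.tensor([[0, 0, 0, 1]])  # 1 marks padding
    eos_id = 1024

    targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
    # targets = [[7, 8, 9, 1024, 1024]]
    inputs, shifted = targets[:, :-1], targets[:, 1:]
    # without BOS: input [[7, 8, 9, 1024]] predicts [[8, 9, 1024, 1024]]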
+
+    def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes):
+        # 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds
+        # from the same utterance.
+        # We implement this differently.
+        if self.prefix_mode == 0:
+            # no prefix
+            prefix_len = 0
+            y_emb = self.nar_audio_embeddings[0](y)
+            for j in range(1, nar_stage):
+                # Formula (4) (5)
+                y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
+        elif self.prefix_mode == 1:
+            # prefix at beginning
+            int_low = (0.25 * y_lens.min()).type(torch.int64).item()
+            prefix_len = torch.randint(int_low, int_low * 2, size=()).item()
+            prefix_len = min(prefix_len, 225)  # 24000/320 * 3s = 225 frames
+
+            y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
+            y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
+            for j in range(1, self.num_quantizers):
+                y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j])
+                if j < nar_stage:
+                    y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j])
+            y_emb = torch.concat([y_prompts, y_emb], dim=1)
+        elif self.prefix_mode in [2, 4]:
+            if self.prefix_mode == 2:
+                # random prefix
+                prefix_len = min(225, int(0.25 * y_lens.min().item()))
+
+                y_prompts_codes = []
+                for b in range(codes.shape[0]):
+                    start = self.rng.randint(0, y_lens[b].item() - prefix_len)
+                    y_prompts_codes.append(
+                        torch.clone(codes[b, start : start + prefix_len])
+                    )
+                    codes[b, start : start + prefix_len, nar_stage] = (
+                        self.audio_token_num
+                    )
+                y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
+            else:
+                prefix_len = y_prompts_codes.shape[1]
+
+            y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
+            y_emb = self.nar_audio_embeddings[0](y)
+            for j in range(1, self.num_quantizers):
+                y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j])
+                if j < nar_stage:
+                    y_emb += self.nar_audio_embeddings[j](codes[..., j])
+            y_emb = torch.concat([y_prompts, y_emb], dim=1)
+        else:
+            raise ValueError
+
+        return y_emb, prefix_len
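For reference, the AR decoder samples from top-k-filtered logits; the actual helper lives in `utils/topk_sampling.py`. The sketch below is a common formulation of that technique, not a copy of the repo's implementation:

    import torch
    import torch.nn.functional as F

    def topk_sample(logits, top_k=10, temperature=1.0):
        # keep only the k largest logits, renormalize, and sample one token
        logits = logits / temperature
        if top_k > 0:
            kth = torch.topk(logits, top_k)[0][..., -1, None]
            logits = logits.masked_fill(logits < kth, float("-inf"))
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    sample = topk_sample(torch.randn(1, 1025), top_k=10)  # one id in [0, 1024]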
Amphion/models/tts/valle/valle_inference.py ADDED
@@ -0,0 +1,237 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import numpy as np
+import torch
+import torchaudio
+import argparse
+
+
+from text.g2p_module import G2PModule
+from utils.tokenizer import AudioTokenizer, tokenize_audio
+from models.tts.valle.valle import VALLE
+from models.tts.base.tts_inferece import TTSInference
+from models.tts.valle.valle_dataset import VALLETestDataset, VALLETestCollator
+from processors.phone_extractor import phoneExtractor
+from text.text_token_collation import phoneIDCollation
+
+
+class VALLEInference(TTSInference):
+    def __init__(self, args=None, cfg=None):
+        TTSInference.__init__(self, args, cfg)
+
+        self.g2p_module = G2PModule(backend=self.cfg.preprocess.phone_extractor)
+        text_token_path = os.path.join(
+            cfg.preprocess.processed_dir, cfg.dataset[0], cfg.preprocess.symbols_dict
+        )
+        self.audio_tokenizer = AudioTokenizer()
+
+    def _build_model(self):
+        model = VALLE(self.cfg.model)
+        return model
+
+    def _build_test_dataset(self):
+        return VALLETestDataset, VALLETestCollator
+
+    def inference_one_clip(self, text, text_prompt, audio_file, save_name="pred"):
+        # get phone symbol file
+        phone_symbol_file = None
+        if self.cfg.preprocess.phone_extractor != "lexicon":
+            phone_symbol_file = os.path.join(
+                self.exp_dir, self.cfg.preprocess.symbols_dict
+            )
+            assert os.path.exists(phone_symbol_file)
+        # convert text to phone sequence
+        phone_extractor = phoneExtractor(self.cfg)
+        # convert phone sequence to phone id sequence
+        phon_id_collator = phoneIDCollation(
+            self.cfg, symbols_dict_file=phone_symbol_file
+        )
+
+        text = f"{text_prompt} {text}".strip()
+        phone_seq = phone_extractor.extract_phone(text)  # phone_seq: list
+        phone_id_seq = phon_id_collator.get_phone_id_sequence(self.cfg, phone_seq)
+        phone_id_seq_len = torch.IntTensor([len(phone_id_seq)]).to(self.device)
+
+        # convert phone sequence to phone id sequence
+        phone_id_seq = np.array([phone_id_seq])
+        phone_id_seq = torch.from_numpy(phone_id_seq).to(self.device)
+
+        # extract acoustic token
+        encoded_frames = tokenize_audio(self.audio_tokenizer, audio_file)
+        audio_prompt_token = encoded_frames[0][0].transpose(2, 1).to(self.device)
+
+        # copysyn
+        if self.args.copysyn:
+            samples = self.audio_tokenizer.decode(encoded_frames)
+            audio_copysyn = samples[0].cpu().detach()
+
+            out_path = os.path.join(
+                self.args.output_dir, self.infer_type, f"{save_name}_copysyn.wav"
+            )
+            torchaudio.save(out_path, audio_copysyn, self.cfg.preprocess.sampling_rate)
+
+        if self.args.continual:
+            encoded_frames = self.model.continual(
+                phone_id_seq,
+                phone_id_seq_len,
+                audio_prompt_token,
+            )
+        else:
+            enroll_x_lens = None
+            if text_prompt:
+                # prompt_phone_seq = tokenize_text(self.g2p_module, text=f"{text_prompt}".strip())
+                # _, enroll_x_lens = self.text_tokenizer.get_token_id_seq(prompt_phone_seq)
+
+                text = f"{text_prompt}".strip()
+                prompt_phone_seq = phone_extractor.extract_phone(
+                    text
+                )  # phone_seq: list
+                prompt_phone_id_seq = phon_id_collator.get_phone_id_sequence(
+                    self.cfg, prompt_phone_seq
+                )
+                prompt_phone_id_seq_len = torch.IntTensor(
+                    [len(prompt_phone_id_seq)]
+                ).to(self.device)
+
+            encoded_frames = self.model.inference(
+                phone_id_seq,
+                phone_id_seq_len,
+                audio_prompt_token,
+                enroll_x_lens=prompt_phone_id_seq_len,
+                top_k=self.args.top_k,
+                temperature=self.args.temperature,
+            )
+
+        samples = self.audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)])
+
+        audio = samples[0].squeeze(0).cpu().detach()
+
+        return audio
+
+    def inference_for_single_utterance(self):
+        text = self.args.text
+        text_prompt = self.args.text_prompt
+        audio_file = self.args.audio_prompt
+
+        if not self.args.continual:
+            assert text != ""
+        else:
+            text = ""
+        assert text_prompt != ""
+        assert audio_file != ""
+
+        audio = self.inference_one_clip(text, text_prompt, audio_file)
+
+        return audio
+
+    def inference_for_batches(self):
+        test_list_file = self.args.test_list_file
+        assert test_list_file is not None
+
+        pred_res = []
+        with open(test_list_file, "r") as fin:
+            for idx, line in enumerate(fin.readlines()):
+                fields = line.strip().split("|")
+                if self.args.continual:
+                    assert len(fields) == 2
+                    text_prompt, audio_prompt_path = fields
+                    text = ""
+                else:
+                    assert len(fields) == 3
+                    text_prompt, audio_prompt_path, text = fields
+
+                audio = self.inference_one_clip(
+                    text, text_prompt, audio_prompt_path, str(idx)
+                )
+                pred_res.append(audio)
+
+        return pred_res
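Each line of `--test_list_file` is split on `|`: two fields in continual mode, three otherwise. An illustrative file (the paths are placeholders, not real repo assets):

    This is the transcript of the prompt.|prompts/speaker1.wav|And this is the text to synthesize.
    Another prompt transcript.|prompts/speaker2.wav|Another target sentence.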
+
+    """
+    TODO: batch inference
+    ###### Construct test_batch ######
+    n_batch = len(self.test_dataloader)
+    now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
+    print(
+        "Model eval time: {}, batch_size = {}, n_batch = {}".format(
+            now, self.test_batch_size, n_batch
+        )
+    )
+
+    ###### Inference for each batch ######
+    pred_res = []
+    with torch.no_grad():
+        for i, batch_data in enumerate(
+            self.test_dataloader if n_batch == 1 else tqdm(self.test_dataloader)
+        ):
+            if self.args.continual:
+                encoded_frames = self.model.continual(
+                    batch_data["phone_seq"],
+                    batch_data["phone_len"],
+                    batch_data["acoustic_token"],
+                )
+            else:
+                encoded_frames = self.model.inference(
+                    batch_data["phone_seq"],
+                    batch_data["phone_len"],
+                    batch_data["acoustic_token"],
+                    enroll_x_lens=batch_data["pmt_phone_len"],
+                    top_k=self.args.top_k,
+                    temperature=self.args.temperature,
+                )
+
+            samples = self.audio_tokenizer.decode(
+                [(encoded_frames.transpose(2, 1), None)]
+            )
+
+            for idx in range(samples.size(0)):
+                audio = samples[idx].cpu()
+                pred_res.append(audio)
+
+    return pred_res
+    """
+
+    def add_arguments(parser: argparse.ArgumentParser):
+        parser.add_argument(
+            "--text_prompt",
+            type=str,
+            default="",
+            help="Text prompt that should be aligned with --audio_prompt.",
+        )
+
+        parser.add_argument(
+            "--audio_prompt",
+            type=str,
+            default="",
+            help="Audio prompt that should be aligned with --text_prompt.",
+        )
+        parser.add_argument(
+            "--top-k",
+            type=int,
+            default=-100,
+            help="Whether the AR decoder does top_k (if > 0) sampling.",
+        )
+
+        parser.add_argument(
+            "--temperature",
+            type=float,
+            default=1.0,
+            help="The temperature of AR decoder top_k sampling.",
+        )
+
+        parser.add_argument(
+            "--continual",
+            action="store_true",
+            help="Inference for the continual task.",
+        )
+
+        parser.add_argument(
+            "--copysyn",
+            action="store_true",
+            help="Copysyn: generate audio with the decoder of the original audio tokenizer.",
+        )
Amphion/models/tts/valle/valle_trainer.py ADDED
@@ -0,0 +1,367 @@
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ from tqdm import tqdm
8
+ import torch
9
+ import numpy as np
10
+ from torch.utils.data import DataLoader
11
+ from torch.nn.parallel import DistributedDataParallel
12
+ from optimizer.optimizers import Eve, ScaledAdam
13
+ from schedulers.scheduler import NoamScheduler, Eden
14
+ from models.tts.valle.valle_dataset import (
15
+ VALLEDataset,
16
+ VALLECollator,
17
+ batch_by_size,
18
+ )
19
+ from models.base.base_sampler import VariableSampler
20
+ from models.tts.base import TTSTrainer
21
+ from models.tts.valle.valle import VALLE
22
+ import diffusers
23
+
24
+
25
+ class VALLETrainer(TTSTrainer):
26
+ def __init__(self, args, cfg):
27
+ TTSTrainer.__init__(self, args, cfg)
28
+
29
+ def _build_model(self):
30
+ model = VALLE(self.cfg.model)
31
+
32
+ return model
33
+
34
+ def _build_dataset(self):
35
+ return VALLEDataset, VALLECollator
36
+
37
+ def _build_optimizer(self):
38
+ if self.args.train_stage:
39
+ if isinstance(self.model, DistributedDataParallel):
40
+ model = self.model.module
41
+ else:
42
+ model = self.model
43
+ model_parameters = model.stage_parameters(self.args.train_stage)
44
+ else:
45
+ model_parameters = self.model.parameters()
46
+
47
+ if self.cfg.train.optimizer == "ScaledAdam":
48
+ parameters_names = []
49
+ if self.args.train_stage != 0:
50
+ parameters_names.append(
51
+ [
52
+ name_param_pair[0]
53
+ for name_param_pair in model.stage_named_parameters(
54
+ self.args.train_stage
55
+ )
56
+ ]
57
+ )
58
+ else:
59
+ parameters_names.append(
60
+ [name_param_pair[0] for name_param_pair in model.named_parameters()]
61
+ )
62
+
63
+ optimizer = ScaledAdam(
64
+ model_parameters,
65
+ lr=self.cfg.train.base_lr,
66
+ betas=(0.9, 0.95),
67
+ clipping_scale=2.0,
68
+ parameters_names=parameters_names,
69
+ show_dominant_parameters=False,
70
+ clipping_update_period=1000,
71
+ )
72
+ elif self.cfg.train.optimizer == "Eve":
73
+ optimizer = Eve(
74
+ model_parameters,
75
+ lr=self.cfg.train.base_lr,
76
+ betas=(0.9, 0.98),
77
+ target_rms=0.1,
78
+ )
79
+ elif self.cfg.train.optimizer == "AdamW":
80
+ optimizer = torch.optim.AdamW(
81
+ model_parameters,
82
+ lr=self.cfg.train.base_lr,
83
+ betas=(0.9, 0.95),
84
+ weight_decay=1e-2,
85
+ eps=1e-8,
86
+ )
87
+ elif self.cfg.train.optimizer == "Adam":
88
+ optimizer = torch.optim.Adam(
89
+ model_parameters,
90
+ lr=self.cfg.train.base_lr,
91
+ betas=(0.9, 0.95),
92
+ eps=1e-8,
93
+ )
94
+ else:
95
+ raise NotImplementedError()
96
+
97
+ return optimizer
98
+
99
+ def _build_scheduler(self):
100
+ if self.cfg.train.scheduler.lower() == "eden":
101
+ scheduler = Eden(
102
+ self.optimizer, 5000, 4, warmup_batches=self.cfg.train.warmup_steps
103
+ )
104
+ elif self.cfg.train.scheduler.lower() == "noam":
105
+ scheduler = NoamScheduler(
106
+ self.cfg.train.base_lr,
107
+ self.optimizer,
108
+ self.cfg.model.decoder_dim,
109
+ warmup_steps=self.cfg.train.warmup_steps,
110
+ )
111
+ elif self.cfg.train.scheduler.lower() == "cosine":
112
+ from diffusers.optimization import get_cosine_schedule_with_warmup
113
+
114
+ scheduler = get_cosine_schedule_with_warmup(
115
+ self.optimizer,
116
+ num_warmup_steps=self.cfg.train.warmup_steps
117
+ * self.accelerator.num_processes,
118
+ num_training_steps=self.cfg.train.total_training_steps
119
+ * self.accelerator.num_processes,
120
+ )
121
+ else:
122
+ raise NotImplementedError(f"{self.cfg.train.scheduler}")
123
+
124
+ return scheduler
125
+
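Assuming the imported NoamScheduler implements the standard Transformer schedule (an assumption; schedulers.scheduler is outside this diff, and decoder_dim plays the role of the model width), the learning-rate curve the "noam" branch above would produce is simple to sketch:

# Standard Noam schedule: linear warmup, then step**-0.5 decay,
# scaled by d_model**-0.5.
def noam_lr(step: int, base_lr: float, d_model: int, warmup_steps: int) -> float:
    return base_lr * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)

for s in (100, 200, 2000):  # with warmup_steps=200: peak at s=200, decay after
    print(s, noam_lr(s, base_lr=1.0, d_model=1024, warmup_steps=200))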
126
+ def _train_epoch(self):
127
+ r"""Training epoch. Should return average loss of a batch (sample) over
128
+ one epoch. See ``train_loop`` for usage.
129
+ """
130
+ if isinstance(self.model, dict):
131
+ for key in self.model.keys():
132
+ self.model[key].train()
133
+ else:
134
+ self.model.train()
135
+
136
+ epoch_sum_loss: float = 0.0
137
+ epoch_losses: dict = {}
138
+ epoch_step: int = 0
139
+ for batch in tqdm(
140
+ self.train_dataloader,
141
+ desc=f"Training Epoch {self.epoch}",
142
+ unit="batch",
143
+ colour="GREEN",
144
+ leave=False,
145
+ dynamic_ncols=True,
146
+ smoothing=0.04,
147
+ disable=not self.accelerator.is_main_process,
148
+ ):
149
+ # Do training step and BP
150
+ with self.accelerator.accumulate(self.model):
151
+ total_loss, train_losses = self._train_step(batch)
152
+ self.accelerator.backward(total_loss)
153
+ self.optimizer.step()
154
+ self.optimizer.zero_grad()
155
+ self.batch_count += 1
156
+
157
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
158
+ if self.cfg.train.optimizer not in ["ScaledAdam", "Eve"]:
159
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
160
+
161
+ for k in range(self.cfg.train.gradient_accumulation_step):
162
+ if isinstance(self.scheduler, Eden):
163
+ self.scheduler.step_batch(self.step)
164
+ else:
165
+ self.scheduler.step()
166
+
167
+ epoch_sum_loss += total_loss.detach().cpu().item()
168
+
169
+ if isinstance(train_losses, dict):
170
+ for key, value in train_losses.items():
171
+ if key not in epoch_losses.keys():
172
+ epoch_losses[key] = value
173
+ else:
174
+ epoch_losses[key] += value
175
+
176
+ if isinstance(train_losses, dict):
177
+ for key, loss in train_losses.items():
178
+ self.accelerator.log(
179
+ {"Step/Train {}".format(key): "{:.6f}".format(loss)},
180
+ step=self.step,
181
+ )
182
+ else:
183
+ self.accelerator.log(
184
+ {"Step/Train Loss": loss},
185
+ step=self.step,
186
+ )
187
+
188
+ self.accelerator.log(
189
+ {"Step/lr": self.scheduler.get_last_lr()[0]},
190
+ step=self.step,
191
+ )
192
+
193
+ # print loss every log_epoch_step steps
194
+ # if epoch_step % self.cfg.train.log_epoch_step == 0:
195
+ # for key, loss in train_losses.items():
196
+ # self.logger.info("Step/Train {}: {:.6f}".format(key, loss))
197
+ # print("Step/Train {}: {:.6f}".format(key, loss))
198
+
199
+ self.step += 1
200
+ epoch_step += 1
201
+
202
+ self.accelerator.wait_for_everyone()
203
+
204
+ epoch_sum_loss = (
205
+ epoch_sum_loss
206
+ / len(self.train_dataloader)
207
+ * self.cfg.train.gradient_accumulation_step
208
+ )
209
+
210
+ for key in epoch_losses.keys():
211
+ epoch_losses[key] = (
212
+ epoch_losses[key]
213
+ / len(self.train_dataloader)
214
+ * self.cfg.train.gradient_accumulation_step
215
+ )
216
+
217
+ return epoch_sum_loss, epoch_losses
218
+
219
+ def _train_step(self, batch, is_training=True):
220
+ text_tokens = batch["phone_seq"].to(self.device)
221
+ text_tokens_lens = batch["phone_len"].to(self.device)
222
+ assert text_tokens.ndim == 2
223
+
224
+ audio_features = batch["acoustic_token"].to(self.device)
225
+ audio_features_lens = batch["target_len"].to(self.device)
226
+ assert audio_features.ndim == 3
227
+
228
+ with torch.set_grad_enabled(is_training):
229
+ loss, losses = self.model(
230
+ x=text_tokens,
231
+ x_lens=text_tokens_lens,
232
+ y=audio_features,
233
+ y_lens=audio_features_lens,
234
+ train_stage=self.args.train_stage,
235
+ )
236
+
237
+ assert loss.requires_grad == is_training
238
+
239
+ loss_dict = {}
240
+ frames_sum = (audio_features_lens).sum()
241
+
242
+ avg_loss = loss / frames_sum
243
+
244
+ loss_dict["loss"] = avg_loss.detach().cpu().item()
245
+ for l in losses:
246
+ loss_dict[l] = losses[l].detach().cpu().item() / frames_sum.item()
247
+
248
+ return avg_loss, loss_dict
249
+
250
+ def _valid_step(self, batch):
251
+ valid_losses = {}
252
+ total_loss = 0
253
+ valid_stats = {}
254
+
255
+ total_loss, valid_losses = self._train_step(
256
+ batch=batch,
257
+ is_training=False,
258
+ )
259
+ assert total_loss.requires_grad is False
260
+
261
+ total_loss = total_loss.detach().cpu().item()
262
+
263
+ return total_loss, valid_losses, valid_stats
264
+
265
+ def _build_dataloader(self):
266
+ if not self.cfg.train.use_dynamic_batchsize:
267
+ return super()._build_dataloader()
268
+ if len(self.cfg.dataset) > 1:
269
+ raise Exception("use_dynamic_batchsize only supports single dataset now.")
270
+ Dataset, Collator = self._build_dataset()
271
+ train_dataset = Dataset(
272
+ self.cfg, self.cfg.dataset[0], is_valid=False
273
+ ) # TODO: support use_dynamic_batchsize for more than one datasets.
274
+ train_collate = Collator(self.cfg)
275
+ batch_sampler = batch_by_size(
276
+ train_dataset.num_frame_indices,
277
+ train_dataset.get_num_frames,
278
+ max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
279
+ max_sentences=self.cfg.train.max_sentences * self.accelerator.num_processes,
280
+ required_batch_size_multiple=self.accelerator.num_processes,
281
+ )
282
+ np.random.seed(1234)
283
+ np.random.shuffle(batch_sampler)
284
+ print(batch_sampler[:1])
285
+ batches = [
286
+ x[self.accelerator.local_process_index :: self.accelerator.num_processes]
287
+ for x in batch_sampler
288
+ if len(x) % self.accelerator.num_processes == 0
289
+ ]
290
+
291
+ train_loader = DataLoader(
292
+ train_dataset,
293
+ collate_fn=train_collate,
294
+ num_workers=self.cfg.train.dataloader.num_worker,
295
+ batch_sampler=VariableSampler(
296
+ batches, drop_last=False, use_random_sampler=True
297
+ ),
298
+ pin_memory=False,
299
+ )
300
+ self.accelerator.wait_for_everyone()
301
+
302
+ valid_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=True)
303
+ valid_collate = Collator(self.cfg)
304
+ batch_sampler = batch_by_size(
305
+ valid_dataset.num_frame_indices,
306
+ valid_dataset.get_num_frames,
307
+ max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
308
+ max_sentences=self.cfg.train.max_sentences * self.accelerator.num_processes,
309
+ required_batch_size_multiple=self.accelerator.num_processes,
310
+ )
311
+ batches = [
312
+ x[self.accelerator.local_process_index :: self.accelerator.num_processes]
313
+ for x in batch_sampler
314
+ if len(x) % self.accelerator.num_processes == 0
315
+ ]
316
+ valid_loader = DataLoader(
317
+ valid_dataset,
318
+ collate_fn=valid_collate,
319
+ num_workers=self.cfg.train.dataloader.num_worker,
320
+ batch_sampler=VariableSampler(batches, drop_last=False),
321
+ pin_memory=False,
322
+ )
323
+ self.accelerator.wait_for_everyone()
324
+
325
+ return train_loader, valid_loader
326
+
327
+ def _accelerator_prepare(self):
328
+ if not self.cfg.train.use_dynamic_batchsize:
329
+ (
330
+ self.train_dataloader,
331
+ self.valid_dataloader,
332
+ ) = self.accelerator.prepare(
333
+ self.train_dataloader,
334
+ self.valid_dataloader,
335
+ )
336
+
337
+ if isinstance(self.model, dict):
338
+ for key in self.model.keys():
339
+ self.model[key] = self.accelerator.prepare(self.model[key])
340
+ else:
341
+ self.model = self.accelerator.prepare(self.model)
342
+
343
+ if isinstance(self.optimizer, dict):
344
+ for key in self.optimizer.keys():
345
+ self.optimizer[key] = self.accelerator.prepare(self.optimizer[key])
346
+ else:
347
+ self.optimizer = self.accelerator.prepare(self.optimizer)
348
+
349
+ if isinstance(self.scheduler, dict):
350
+ for key in self.scheduler.keys():
351
+ self.scheduler[key] = self.accelerator.prepare(self.scheduler[key])
352
+ else:
353
+ self.scheduler = self.accelerator.prepare(self.scheduler)
354
+
355
+ def add_arguments(parser: argparse.ArgumentParser):
356
+ parser.add_argument(
357
+ "--train_stage",
358
+ type=int,
359
+ default="1",
360
+ help="0: train all modules, 1: AR Decoder, 2: NAR Decoder",
361
+ )
362
+ parser.add_argument(
363
+ "--ar_model_ckpt_dir",
364
+ type=str,
365
+ default=None,
366
+ help="Checkpoint for ar model ckeckpoint in the first training stage.",
367
+ )
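A side note on `_build_dataloader` above: after `batch_by_size` pre-groups utterances into variable-size batches, every rank takes a strided slice of each batch, and batches whose length is not divisible by the world size are dropped so all ranks stay in lockstep. A toy sketch of that sharding (the index batches are made up):

# Round-robin sharding of variable-size batches, mirroring
# x[local_process_index::num_processes] in _build_dataloader above.
num_processes = 2
batch_sampler = [[0, 1, 2, 3], [4, 5, 6], [7, 8]]  # toy utterance-index batches

for rank in range(num_processes):
    shard = [
        x[rank::num_processes]
        for x in batch_sampler
        if len(x) % num_processes == 0  # uneven batches are dropped
    ]
    print(rank, shard)
# 0 [[0, 2], [7]]
# 1 [[1, 3], [8]]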
Amphion/models/tts/vits/__init__.py ADDED
File without changes
Amphion/models/tts/vits/vits.py ADDED
@@ -0,0 +1,379 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # This code is modified from https://github.com/jaywalnut310/vits/blob/main/models.py
7
+ import math
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+
12
+ from utils.util import *
13
+ from modules.flow.modules import *
14
+ from modules.base.base_module import *
15
+ from modules.transformer.attentions import Encoder
16
+ from modules.duration_predictor.standard_duration_predictor import DurationPredictor
17
+ from modules.duration_predictor.stochastic_duration_predictor import (
18
+ StochasticDurationPredictor,
19
+ )
20
+ from models.vocoders.gan.generator.hifigan import HiFiGAN_vits as Generator
21
+
22
+ try:
23
+ from modules import monotonic_align
24
+ except ImportError:
25
+ print("Monotonic align not found. Please make sure you have compiled it.")
26
+
27
+
28
+ class TextEncoder(nn.Module):
29
+ def __init__(
30
+ self,
31
+ n_vocab,
32
+ out_channels,
33
+ hidden_channels,
34
+ filter_channels,
35
+ n_heads,
36
+ n_layers,
37
+ kernel_size,
38
+ p_dropout,
39
+ ):
40
+ super().__init__()
41
+ self.n_vocab = n_vocab
42
+ self.out_channels = out_channels
43
+ self.hidden_channels = hidden_channels
44
+ self.filter_channels = filter_channels
45
+ self.n_heads = n_heads
46
+ self.n_layers = n_layers
47
+ self.kernel_size = kernel_size
48
+ self.p_dropout = p_dropout
49
+
50
+ self.emb = nn.Embedding(n_vocab, hidden_channels)
51
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
52
+
53
+ self.encoder = Encoder(
54
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
55
+ )
56
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
57
+
58
+ def forward(self, x, x_lengths):
59
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
60
+ x = torch.transpose(x, 1, -1) # [b, h, t]
61
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
62
+
63
+ x = self.encoder(x * x_mask, x_mask)
64
+ stats = self.proj(x) * x_mask
65
+
66
+ m, logs = torch.split(stats, self.out_channels, dim=1)
67
+ return x, m, logs, x_mask
68
+
69
+
70
+ class ResidualCouplingBlock(nn.Module):
71
+ def __init__(
72
+ self,
73
+ channels,
74
+ hidden_channels,
75
+ kernel_size,
76
+ dilation_rate,
77
+ n_layers,
78
+ n_flows=4,
79
+ gin_channels=0,
80
+ ):
81
+ super().__init__()
82
+ self.channels = channels
83
+ self.hidden_channels = hidden_channels
84
+ self.kernel_size = kernel_size
85
+ self.dilation_rate = dilation_rate
86
+ self.n_layers = n_layers
87
+ self.n_flows = n_flows
88
+ self.gin_channels = gin_channels
89
+
90
+ self.flows = nn.ModuleList()
91
+ for i in range(n_flows):
92
+ self.flows.append(
93
+ ResidualCouplingLayer(
94
+ channels,
95
+ hidden_channels,
96
+ kernel_size,
97
+ dilation_rate,
98
+ n_layers,
99
+ gin_channels=gin_channels,
100
+ mean_only=True,
101
+ )
102
+ )
103
+ self.flows.append(Flip())
104
+
105
+ def forward(self, x, x_mask, g=None, reverse=False):
106
+ if not reverse:
107
+ for flow in self.flows:
108
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
109
+ else:
110
+ for flow in reversed(self.flows):
111
+ x = flow(x, x_mask, g=g, reverse=reverse)
112
+ return x
113
+
114
+
115
+ class PosteriorEncoder(nn.Module):
116
+ def __init__(
117
+ self,
118
+ in_channels,
119
+ out_channels,
120
+ hidden_channels,
121
+ kernel_size,
122
+ dilation_rate,
123
+ n_layers,
124
+ gin_channels=0,
125
+ ):
126
+ super().__init__()
127
+ self.in_channels = in_channels
128
+ self.out_channels = out_channels
129
+ self.hidden_channels = hidden_channels
130
+ self.kernel_size = kernel_size
131
+ self.dilation_rate = dilation_rate
132
+ self.n_layers = n_layers
133
+ self.gin_channels = gin_channels
134
+
135
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
136
+ self.enc = WN(
137
+ hidden_channels,
138
+ kernel_size,
139
+ dilation_rate,
140
+ n_layers,
141
+ gin_channels=gin_channels,
142
+ )
143
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
144
+
145
+ def forward(self, x, x_lengths, g=None):
146
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
147
+ x = self.pre(x) * x_mask
148
+ x = self.enc(x, x_mask, g=g)
149
+ stats = self.proj(x) * x_mask
150
+ m, logs = torch.split(stats, self.out_channels, dim=1)
151
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
152
+ return z, m, logs, x_mask
153
+
154
+
155
+ class SynthesizerTrn(nn.Module):
156
+ """
157
+ Synthesizer for Training
158
+ """
159
+
160
+ def __init__(
161
+ self,
162
+ n_vocab,
163
+ spec_channels,
164
+ segment_size,
165
+ inter_channels,
166
+ hidden_channels,
167
+ filter_channels,
168
+ n_heads,
169
+ n_layers,
170
+ kernel_size,
171
+ p_dropout,
172
+ resblock,
173
+ resblock_kernel_sizes,
174
+ resblock_dilation_sizes,
175
+ upsample_rates,
176
+ upsample_initial_channel,
177
+ upsample_kernel_sizes,
178
+ n_speakers=0,
179
+ gin_channels=0,
180
+ use_sdp=True,
181
+ **kwargs,
182
+ ):
183
+ super().__init__()
184
+ self.n_vocab = n_vocab
185
+ self.spec_channels = spec_channels
186
+ self.inter_channels = inter_channels
187
+ self.hidden_channels = hidden_channels
188
+ self.filter_channels = filter_channels
189
+ self.n_heads = n_heads
190
+ self.n_layers = n_layers
191
+ self.kernel_size = kernel_size
192
+ self.p_dropout = p_dropout
193
+ self.resblock = resblock
194
+ self.resblock_kernel_sizes = resblock_kernel_sizes
195
+ self.resblock_dilation_sizes = resblock_dilation_sizes
196
+ self.upsample_rates = upsample_rates
197
+ self.upsample_initial_channel = upsample_initial_channel
198
+ self.upsample_kernel_sizes = upsample_kernel_sizes
199
+ self.segment_size = segment_size
200
+ self.n_speakers = n_speakers
201
+ self.gin_channels = gin_channels
202
+
203
+ self.use_sdp = use_sdp
204
+
205
+ self.enc_p = TextEncoder(
206
+ n_vocab,
207
+ inter_channels,
208
+ hidden_channels,
209
+ filter_channels,
210
+ n_heads,
211
+ n_layers,
212
+ kernel_size,
213
+ p_dropout,
214
+ )
215
+ self.dec = Generator(
216
+ inter_channels,
217
+ resblock,
218
+ resblock_kernel_sizes,
219
+ resblock_dilation_sizes,
220
+ upsample_rates,
221
+ upsample_initial_channel,
222
+ upsample_kernel_sizes,
223
+ gin_channels=gin_channels,
224
+ )
225
+ self.enc_q = PosteriorEncoder(
226
+ spec_channels,
227
+ inter_channels,
228
+ hidden_channels,
229
+ 5,
230
+ 1,
231
+ 16,
232
+ gin_channels=gin_channels,
233
+ )
234
+ self.flow = ResidualCouplingBlock(
235
+ inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
236
+ )
237
+
238
+ if use_sdp:
239
+ self.dp = StochasticDurationPredictor(
240
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
241
+ )
242
+ else:
243
+ self.dp = DurationPredictor(
244
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
245
+ )
246
+
247
+ if n_speakers >= 1:
248
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
249
+
250
+ def forward(self, data):
251
+ x = data["phone_seq"]
252
+ x_lengths = data["phone_len"]
253
+ y = data["linear"]
254
+ y_lengths = data["target_len"]
255
+
256
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
257
+ if self.n_speakers > 0:
258
+ g = self.emb_g(data["spk_id"].squeeze(-1)).unsqueeze(-1) # [b, h, 1]
259
+ else:
260
+ g = None
261
+
262
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
263
+ z_p = self.flow(z, y_mask, g=g)
264
+
265
+ with torch.no_grad():
266
+ # negative cross-entropy
267
+ s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
268
+ neg_cent1 = torch.sum(
269
+ -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
270
+ ) # [b, 1, t_s]
271
+ neg_cent2 = torch.matmul(
272
+ -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
273
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
274
+ neg_cent3 = torch.matmul(
275
+ z_p.transpose(1, 2), (m_p * s_p_sq_r)
276
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
277
+ neg_cent4 = torch.sum(
278
+ -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
279
+ ) # [b, 1, t_s]
280
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
281
+
282
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
283
+ attn = (
284
+ monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
285
+ .unsqueeze(1)
286
+ .detach()
287
+ )
288
+
289
+ w = attn.sum(2)
290
+ if self.use_sdp:
291
+ l_length = self.dp(x, x_mask, w, g=g)
292
+ l_length = l_length / torch.sum(x_mask)
293
+ else:
294
+ logw_ = torch.log(w + 1e-6) * x_mask
295
+ logw = self.dp(x, x_mask, g=g)
296
+ l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
297
+
298
+ # expand prior
299
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
300
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
301
+
302
+ z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size)
303
+ o = self.dec(z_slice, g=g)
304
+ outputs = {
305
+ "y_hat": o,
306
+ "l_length": l_length,
307
+ "attn": attn,
308
+ "ids_slice": ids_slice,
309
+ "x_mask": x_mask,
310
+ "z_mask": y_mask,
311
+ "z": z,
312
+ "z_p": z_p,
313
+ "m_p": m_p,
314
+ "logs_p": logs_p,
315
+ "m_q": m_q,
316
+ "logs_q": logs_q,
317
+ }
318
+ return outputs
319
+
320
+ def infer(
321
+ self,
322
+ x,
323
+ x_lengths,
324
+ sid=None,
325
+ noise_scale=1,
326
+ length_scale=1,
327
+ noise_scale_w=1.0,
328
+ max_len=None,
329
+ ):
330
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
331
+ if self.n_speakers > 0:
332
+ sid = sid.squeeze(-1)
333
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
334
+ else:
335
+ g = None
336
+
337
+ if self.use_sdp:
338
+ logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
339
+ else:
340
+ logw = self.dp(x, x_mask, g=g)
341
+ w = torch.exp(logw) * x_mask * length_scale
342
+ w_ceil = torch.ceil(w)
343
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
344
+ y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
345
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
346
+ attn = generate_path(w_ceil, attn_mask)
347
+
348
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
349
+ 1, 2
350
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
351
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
352
+ 1, 2
353
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
354
+
355
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
356
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
357
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
358
+
359
+ outputs = {
360
+ "y_hat": o,
361
+ "attn": attn,
362
+ "mask": y_mask,
363
+ "z": z,
364
+ "z_p": z_p,
365
+ "m_p": m_p,
366
+ "logs_p": logs_p,
367
+ }
368
+
369
+ return outputs
370
+
371
+ def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
372
+ assert self.n_speakers > 0, "n_speakers has to be larger than 0."
373
+ g_src = self.emb_g(sid_src).unsqueeze(-1)
374
+ g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
375
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
376
+ z_p = self.flow(z, y_mask, g=g_src)
377
+ z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
378
+ o_hat = self.dec(z_hat * y_mask, g=g_tgt)
379
+ return o_hat, y_mask, (z, z_p, z_hat)
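One remark on the `torch.no_grad()` block in `SynthesizerTrn.forward`: the four `neg_cent` terms are an algebraic expansion of the diagonal-Gaussian log-density log N(z_p; m_p, exp(2 * logs_p)), split so the cross terms become two batched matmuls. A small numerical check (not part of the model, shapes made up):

import math
import torch

b, d, t_t, t_s = 2, 3, 5, 4
z_p = torch.randn(b, d, t_t)
m_p, logs_p = torch.randn(b, d, t_s), torch.randn(b, d, t_s)

s_p_sq_r = torch.exp(-2 * logs_p)
neg_cent = (
    torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, dim=1, keepdim=True)
    + torch.matmul(-0.5 * (z_p**2).transpose(1, 2), s_p_sq_r)
    + torch.matmul(z_p.transpose(1, 2), m_p * s_p_sq_r)
    + torch.sum(-0.5 * (m_p**2) * s_p_sq_r, dim=1, keepdim=True)
)  # [b, t_t, t_s], same as neg_cent1 + ... + neg_cent4 in forward()

# Direct evaluation of the Gaussian log-density for one (t, s) pair:
t, s = 1, 2
direct = torch.sum(
    -0.5 * math.log(2 * math.pi)
    - logs_p[0, :, s]
    - 0.5 * (z_p[0, :, t] - m_p[0, :, s]) ** 2 * torch.exp(-2 * logs_p[0, :, s])
)
assert torch.allclose(neg_cent[0, t, s], direct, atol=1e-5)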
Amphion/models/tts/vits/vits_dataset.py ADDED
@@ -0,0 +1,140 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import json
8
+ import numpy as np
9
+ from text import text_to_sequence
10
+ from text.text_token_collation import phoneIDCollation
11
+ from models.tts.base.tts_dataset import (
12
+ TTSDataset,
13
+ TTSCollator,
14
+ TTSTestDataset,
15
+ TTSTestCollator,
16
+ )
17
+
18
+
19
+ class VITSDataset(TTSDataset):
20
+ def __init__(self, cfg, dataset, is_valid):
21
+ super().__init__(cfg, dataset, is_valid=is_valid)
22
+
23
+ def __getitem__(self, index):
24
+ single_feature = super().__getitem__(index)
25
+ return single_feature
26
+
27
+ def __len__(self):
28
+ return super().__len__()
29
+
30
+ def get_metadata(self):
31
+ metadata_filter = []
32
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
33
+ metadata = json.load(f)
34
+ for utt_info in metadata:
35
+ duration = utt_info["Duration"]
36
+ frame_len = (
37
+ duration
38
+ * self.cfg.preprocess.sample_rate
39
+ // self.cfg.preprocess.hop_size
40
+ )
41
+ if (
42
+ frame_len
43
+ < self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size
44
+ ):
45
+ continue
46
+ metadata_filter.append(utt_info)
47
+
48
+ return metadata_filter
49
+
50
+
51
+ class VITSCollator(TTSCollator):
52
+ """Zero-pads model inputs and targets based on number of frames per step"""
53
+
54
+ def __init__(self, cfg):
55
+ super().__init__(cfg)
56
+
57
+ def __call__(self, batch):
58
+ parsed_batch_features = super().__call__(batch)
59
+ return parsed_batch_features
60
+
61
+
62
+ class VITSTestDataset(TTSTestDataset):
63
+ def __init__(self, args, cfg):
64
+ super().__init__(args, cfg)
65
+ processed_data_dir = os.path.join(cfg.preprocess.processed_dir, args.dataset)
66
+ if cfg.preprocess.use_spkid:
67
+ spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
68
+ with open(spk2id_path, "r") as f:
69
+ self.spk2id = json.load(f)
70
+
71
+ utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
72
+ self.utt2spk = dict()
73
+ with open(utt2spk_path, "r") as f:
74
+ for line in f.readlines():
75
+ utt, spk = line.strip().split("\t")
76
+ self.utt2spk[utt] = spk
77
+
78
+ if cfg.preprocess.use_text or cfg.preprocess.use_phone:
79
+ self.utt2seq = {}
80
+ for utt_info in self.metadata:
81
+ dataset = utt_info["Dataset"]
82
+ uid = utt_info["Uid"]
83
+ utt = "{}_{}".format(dataset, uid)
84
+
85
+ if cfg.preprocess.use_text:
86
+ text = utt_info["Text"]
87
+ sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
88
+ elif cfg.preprocess.use_phone:
89
+ # load the phoneme sequence from the phone file
90
+ phone_path = os.path.join(
91
+ processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone"
92
+ )
93
+ with open(phone_path, "r") as fin:
94
+ phones = fin.readlines()
95
+ assert len(phones) == 1
96
+ phones = phones[0].strip()
97
+ phones_seq = phones.split(" ")
98
+
99
+ phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
100
+ sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)
101
+
102
+ self.utt2seq[utt] = sequence
103
+
104
+ def __getitem__(self, index):
105
+ utt_info = self.metadata[index]
106
+
107
+ dataset = utt_info["Dataset"]
108
+ uid = utt_info["Uid"]
109
+ utt = "{}_{}".format(dataset, uid)
110
+
111
+ single_feature = dict()
112
+
113
+ if self.cfg.preprocess.use_spkid:
114
+ single_feature["spk_id"] = np.array(
115
+ [self.spk2id[self.utt2spk[utt]]], dtype=np.int32
116
+ )
117
+
118
+ if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
119
+ single_feature["phone_seq"] = np.array(self.utt2seq[utt])
120
+ single_feature["phone_len"] = len(self.utt2seq[utt])
121
+
122
+ return single_feature
123
+
124
+ def get_metadata(self):
125
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
126
+ metadata = json.load(f)
127
+ return metadata
128
+
129
+ def __len__(self):
130
+ return len(self.metadata)
131
+
132
+
133
+ class VITSTestCollator(TTSTestCollator):
134
+ """Zero-pads model inputs and targets based on number of frames per step"""
135
+
136
+ def __init__(self, cfg):
137
+ self.cfg = cfg
138
+
139
+ def __call__(self, batch):
140
+ return super().__call__(batch)
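The filter in `VITSDataset.get_metadata` above drops utterances shorter than one training segment. With made-up config values, the arithmetic looks like this:

# Hypothetical preprocess config: sample_rate=22050, hop_size=256, segment_size=8192.
sample_rate, hop_size, segment_size = 22050, 256, 8192
min_frames = segment_size // hop_size  # 32 frames per training segment

duration = 0.3  # seconds, a hypothetical short utterance
frame_len = duration * sample_rate // hop_size  # 6615 // 256 -> 25.0 frames
print(frame_len < min_frames)  # True: this utterance would be filtered out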
Amphion/models/tts/vits/vits_trainer.py ADDED
@@ -0,0 +1,439 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.optim.lr_scheduler import ExponentialLR
9
+
10
+ from tqdm import tqdm
11
+
12
+ from utils.util import *
13
+ from utils.mel import mel_spectrogram_torch
14
+ from models.tts.base import TTSTrainer
15
+ from models.tts.vits.vits import SynthesizerTrn
16
+ from models.tts.vits.vits_dataset import VITSDataset, VITSCollator
17
+ from models.vocoders.gan.discriminator.mpd import (
18
+ MultiPeriodDiscriminator_vits as MultiPeriodDiscriminator,
19
+ )
20
+
21
+
22
+ class VITSTrainer(TTSTrainer):
23
+ def __init__(self, args, cfg):
24
+ TTSTrainer.__init__(self, args, cfg)
25
+
26
+ if cfg.preprocess.use_spkid and cfg.train.multi_speaker_training:
27
+ if cfg.model.n_speakers == 0:
28
+ cfg.model.n_speakers = len(self.speakers)
29
+
30
+ def _build_model(self):
31
+ net_g = SynthesizerTrn(
32
+ self.cfg.model.text_token_num,
33
+ self.cfg.preprocess.n_fft // 2 + 1,
34
+ self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
35
+ **self.cfg.model,
36
+ )
37
+ net_d = MultiPeriodDiscriminator(self.cfg.model.use_spectral_norm)
38
+ model = {"generator": net_g, "discriminator": net_d}
39
+
40
+ return model
41
+
42
+ def _build_dataset(self):
43
+ return VITSDataset, VITSCollator
44
+
45
+ def _build_optimizer(self):
46
+ optimizer_g = torch.optim.AdamW(
47
+ self.model["generator"].parameters(),
48
+ self.cfg.train.learning_rate,
49
+ betas=self.cfg.train.AdamW.betas,
50
+ eps=self.cfg.train.AdamW.eps,
51
+ )
52
+ optimizer_d = torch.optim.AdamW(
53
+ self.model["discriminator"].parameters(),
54
+ self.cfg.train.learning_rate,
55
+ betas=self.cfg.train.AdamW.betas,
56
+ eps=self.cfg.train.AdamW.eps,
57
+ )
58
+ optimizer = {"optimizer_g": optimizer_g, "optimizer_d": optimizer_d}
59
+
60
+ return optimizer
61
+
62
+ def _build_scheduler(self):
63
+ scheduler_g = ExponentialLR(
64
+ self.optimizer["optimizer_g"],
65
+ gamma=self.cfg.train.lr_decay,
66
+ last_epoch=self.epoch - 1,
67
+ )
68
+ scheduler_d = ExponentialLR(
69
+ self.optimizer["optimizer_d"],
70
+ gamma=self.cfg.train.lr_decay,
71
+ last_epoch=self.epoch - 1,
72
+ )
73
+
74
+ scheduler = {"scheduler_g": scheduler_g, "scheduler_d": scheduler_d}
75
+ return scheduler
76
+
77
+ def _build_criterion(self):
78
+ class GeneratorLoss(nn.Module):
79
+ def __init__(self, cfg):
80
+ super(GeneratorLoss, self).__init__()
81
+ self.cfg = cfg
82
+ self.l1_loss = nn.L1Loss()
83
+
84
+ def generator_loss(self, disc_outputs):
85
+ loss = 0
86
+ gen_losses = []
87
+ for dg in disc_outputs:
88
+ dg = dg.float()
89
+ l = torch.mean((1 - dg) ** 2)
90
+ gen_losses.append(l)
91
+ loss += l
92
+
93
+ return loss, gen_losses
94
+
95
+ def feature_loss(self, fmap_r, fmap_g):
96
+ loss = 0
97
+ for dr, dg in zip(fmap_r, fmap_g):
98
+ for rl, gl in zip(dr, dg):
99
+ rl = rl.float().detach()
100
+ gl = gl.float()
101
+ loss += torch.mean(torch.abs(rl - gl))
102
+
103
+ return loss * 2
104
+
105
+ def kl_loss(self, z_p, logs_q, m_p, logs_p, z_mask):
106
+ """
107
+ z_p, logs_q: [b, h, t_t]
108
+ m_p, logs_p: [b, h, t_t]
109
+ """
110
+ z_p = z_p.float()
111
+ logs_q = logs_q.float()
112
+ m_p = m_p.float()
113
+ logs_p = logs_p.float()
114
+ z_mask = z_mask.float()
115
+
116
+ kl = logs_p - logs_q - 0.5
117
+ kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
118
+ kl = torch.sum(kl * z_mask)
119
+ l = kl / torch.sum(z_mask)
120
+ return l
121
+
122
+ def forward(
123
+ self,
124
+ outputs_g,
125
+ outputs_d,
126
+ y_mel,
127
+ y_hat_mel,
128
+ ):
129
+ loss_g = {}
130
+
131
+ # duration loss
132
+ loss_dur = torch.sum(outputs_g["l_length"].float())
133
+ loss_g["loss_dur"] = loss_dur
134
+
135
+ # mel loss
136
+ loss_mel = self.l1_loss(y_mel, y_hat_mel) * self.cfg.train.c_mel
137
+ loss_g["loss_mel"] = loss_mel
138
+
139
+ # kl loss
140
+ loss_kl = (
141
+ self.kl_loss(
142
+ outputs_g["z_p"],
143
+ outputs_g["logs_q"],
144
+ outputs_g["m_p"],
145
+ outputs_g["logs_p"],
146
+ outputs_g["z_mask"],
147
+ )
148
+ * self.cfg.train.c_kl
149
+ )
150
+ loss_g["loss_kl"] = loss_kl
151
+
152
+ # feature loss
153
+ loss_fm = self.feature_loss(outputs_d["fmap_rs"], outputs_d["fmap_gs"])
154
+ loss_g["loss_fm"] = loss_fm
155
+
156
+ # gan loss
157
+ loss_gen, losses_gen = self.generator_loss(outputs_d["y_d_hat_g"])
158
+ loss_g["loss_gen"] = loss_gen
159
+ loss_g["loss_gen_all"] = (
160
+ loss_dur + loss_mel + loss_kl + loss_fm + loss_gen
161
+ )
162
+
163
+ return loss_g
164
+
165
+ class DiscriminatorLoss(nn.Module):
166
+ def __init__(self, cfg):
167
+ super(DiscriminatorLoss, self).__init__()
168
+ self.cfg = cfg
169
+ self.l1Loss = torch.nn.L1Loss(reduction="mean")
170
+
171
+ def __call__(self, disc_real_outputs, disc_generated_outputs):
172
+ loss_d = {}
173
+
174
+ loss = 0
175
+ r_losses = []
176
+ g_losses = []
177
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
178
+ dr = dr.float()
179
+ dg = dg.float()
180
+ r_loss = torch.mean((1 - dr) ** 2)
181
+ g_loss = torch.mean(dg**2)
182
+ loss += r_loss + g_loss
183
+ r_losses.append(r_loss.item())
184
+ g_losses.append(g_loss.item())
185
+
186
+ loss_d["loss_disc_all"] = loss
187
+
188
+ return loss_d
189
+
190
+ criterion = {
191
+ "generator": GeneratorLoss(self.cfg),
192
+ "discriminator": DiscriminatorLoss(self.cfg),
193
+ }
194
+ return criterion
195
+
196
+ def write_summary(
197
+ self,
198
+ losses,
199
+ stats,
200
+ images={},
201
+ audios={},
202
+ audio_sampling_rate=24000,
203
+ tag="train",
204
+ ):
205
+ for key, value in losses.items():
206
+ self.sw.add_scalar(tag + "/" + key, value, self.step)
207
+ self.sw.add_scalar(
208
+ "learning_rate",
209
+ self.optimizer["optimizer_g"].param_groups[0]["lr"],
210
+ self.step,
211
+ )
212
+
213
+ if len(images) != 0:
214
+ for key, value in images.items():
215
+ self.sw.add_image(key, value, self.global_step, batchformats="HWC")
216
+ if len(audios) != 0:
217
+ for key, value in audios.items():
218
+ self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)
219
+
220
+ def write_valid_summary(
221
+ self, losses, stats, images={}, audios={}, audio_sampling_rate=24000, tag="val"
222
+ ):
223
+ for key, value in losses.items():
224
+ self.sw.add_scalar(tag + "/" + key, value, self.step)
225
+
226
+ if len(images) != 0:
227
+ for key, value in images.items():
228
+ self.sw.add_image(key, value, self.global_step, batchformats="HWC")
229
+ if len(audios) != 0:
230
+ for key, value in audios.items():
231
+ self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)
232
+
233
+ def get_state_dict(self):
234
+ state_dict = {
235
+ "generator": self.model["generator"].state_dict(),
236
+ "discriminator": self.model["discriminator"].state_dict(),
237
+ "optimizer_g": self.optimizer["optimizer_g"].state_dict(),
238
+ "optimizer_d": self.optimizer["optimizer_d"].state_dict(),
239
+ "scheduler_g": self.scheduler["scheduler_g"].state_dict(),
240
+ "scheduler_d": self.scheduler["scheduler_d"].state_dict(),
241
+ "step": self.step,
242
+ "epoch": self.epoch,
243
+ "batch_size": self.cfg.train.batch_size,
244
+ }
245
+ return state_dict
246
+
247
+ def load_model(self, checkpoint):
248
+ self.step = checkpoint["step"]
249
+ self.epoch = checkpoint["epoch"]
250
+ self.model["generator"].load_state_dict(checkpoint["generator"])
251
+ self.model["discriminator"].load_state_dict(checkpoint["discriminator"])
252
+ self.optimizer["optimizer_g"].load_state_dict(checkpoint["optimizer_g"])
253
+ self.optimizer["optimizer_d"].load_state_dict(checkpoint["optimizer_d"])
254
+ self.scheduler["scheduler_g"].load_state_dict(checkpoint["scheduler_g"])
255
+ self.scheduler["scheduler_d"].load_state_dict(checkpoint["scheduler_d"])
256
+
257
+ @torch.inference_mode()
258
+ def _valid_step(self, batch):
259
+ r"""Testing forward step. Should return average loss of a sample over
260
+ one batch. Invoking ``_forward_step`` is recommended except for special cases.
261
+ See ``_test_epoch`` for usage.
262
+ """
263
+
264
+ valid_losses = {}
265
+ total_loss = 0
266
+ valid_stats = {}
267
+
268
+ batch["linear"] = batch["linear"].transpose(2, 1) # [b, d, t]
269
+ batch["mel"] = batch["mel"].transpose(2, 1) # [b, d, t]
270
+ batch["audio"] = batch["audio"].unsqueeze(1) # [b, d, t]
271
+
272
+ # Discriminator
273
+ # Generator output
274
+ outputs_g = self.model["generator"](batch)
275
+
276
+ y_mel = slice_segments(
277
+ batch["mel"],
278
+ outputs_g["ids_slice"],
279
+ self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
280
+ )
281
+ y_hat_mel = mel_spectrogram_torch(
282
+ outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
283
+ )
284
+ y = slice_segments(
285
+ batch["audio"],
286
+ outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
287
+ self.cfg.preprocess.segment_size,
288
+ )
289
+
290
+ # Discriminator output
291
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
292
+ ## Discriminator loss
293
+ loss_d = self.criterion["discriminator"](
294
+ outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
295
+ )
296
+ valid_losses.update(loss_d)
297
+
298
+ ## Generator
299
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
300
+ loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
301
+ valid_losses.update(loss_g)
302
+
303
+ for item in valid_losses:
304
+ valid_losses[item] = valid_losses[item].item()
305
+
306
+ total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]
307
+
308
+ return (
309
+ total_loss.item(),
310
+ valid_losses,
311
+ valid_stats,
312
+ )
313
+
314
+ def _train_step(self, batch):
315
+ r"""Forward step for training and inference. This function is called
316
+ in ``_train_step`` & ``_test_step`` function.
317
+ """
318
+
319
+ train_losses = {}
320
+ total_loss = 0
321
+ training_stats = {}
322
+
323
+ batch["linear"] = batch["linear"].transpose(2, 1) # [b, d, t]
324
+ batch["mel"] = batch["mel"].transpose(2, 1) # [b, d, t]
325
+ batch["audio"] = batch["audio"].unsqueeze(1) # [b, d, t]
326
+
327
+ # Train Discriminator
328
+ # Generator output
329
+ outputs_g = self.model["generator"](batch)
330
+
331
+ y_mel = slice_segments(
332
+ batch["mel"],
333
+ outputs_g["ids_slice"],
334
+ self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
335
+ )
336
+ y_hat_mel = mel_spectrogram_torch(
337
+ outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
338
+ )
339
+ y = slice_segments(
340
+ batch["audio"],
341
+ outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
342
+ self.cfg.preprocess.segment_size,
343
+ )
344
+
345
+ # Discriminator output
346
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
347
+ ## Discriminator loss
348
+ loss_d = self.criterion["discriminator"](
349
+ outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
350
+ )
351
+ train_losses.update(loss_d)
352
+
353
+ # BP and Grad Updated
354
+ self.optimizer["optimizer_d"].zero_grad()
355
+ self.accelerator.backward(loss_d["loss_disc_all"])
356
+ self.optimizer["optimizer_d"].step()
357
+
358
+ ## Train Generator
359
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
360
+ loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
361
+ train_losses.update(loss_g)
362
+
363
+ # BP and Grad Updated
364
+ self.optimizer["optimizer_g"].zero_grad()
365
+ self.accelerator.backward(loss_g["loss_gen_all"])
366
+ self.optimizer["optimizer_g"].step()
367
+
368
+ for item in train_losses:
369
+ train_losses[item] = train_losses[item].item()
370
+
371
+ total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]
372
+
373
+ return (
374
+ total_loss.item(),
375
+ train_losses,
376
+ training_stats,
377
+ )
378
+
379
+ def _train_epoch(self):
380
+ r"""Training epoch. Should return average loss of a batch (sample) over
381
+ one epoch. See ``train_loop`` for usage.
382
+ """
383
+ epoch_sum_loss: float = 0.0
384
+ epoch_losses: dict = {}
385
+ epoch_step: int = 0
386
+ for batch in tqdm(
387
+ self.train_dataloader,
388
+ desc=f"Training Epoch {self.epoch}",
389
+ unit="batch",
390
+ colour="GREEN",
391
+ leave=False,
392
+ dynamic_ncols=True,
393
+ smoothing=0.04,
394
+ disable=not self.accelerator.is_main_process,
395
+ ):
396
+ with self.accelerator.accumulate(self.model):
397
+ total_loss, train_losses, training_stats = self._train_step(batch)
398
+ self.batch_count += 1
399
+
400
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
401
+ epoch_sum_loss += total_loss
402
+ for key, value in train_losses.items():
403
+ if key not in epoch_losses.keys():
404
+ epoch_losses[key] = value
405
+ else:
406
+ epoch_losses[key] += value
407
+
408
+ self.accelerator.log(
409
+ {
410
+ "Step/Generator Loss": train_losses["loss_gen_all"],
411
+ "Step/Discriminator Loss": train_losses["loss_disc_all"],
412
+ "Step/Generator Learning Rate": self.optimizer[
413
+ "optimizer_d"
414
+ ].param_groups[0]["lr"],
415
+ "Step/Discriminator Learning Rate": self.optimizer[
416
+ "optimizer_g"
417
+ ].param_groups[0]["lr"],
418
+ },
419
+ step=self.step,
420
+ )
421
+ self.step += 1
422
+ epoch_step += 1
423
+
424
+ self.accelerator.wait_for_everyone()
425
+
426
+ epoch_sum_loss = (
427
+ epoch_sum_loss
428
+ / len(self.train_dataloader)
429
+ * self.cfg.train.gradient_accumulation_step
430
+ )
431
+
432
+ for key in epoch_losses.keys():
433
+ epoch_losses[key] = (
434
+ epoch_losses[key]
435
+ / len(self.train_dataloader)
436
+ * self.cfg.train.gradient_accumulation_step
437
+ )
438
+
439
+ return epoch_sum_loss, epoch_losses
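A note on `kl_loss` inside `_build_criterion` above: it is a single-sample estimator of KL(q || p) for diagonal Gaussians, evaluated at the posterior sample z_p. Averaged over many samples it matches the closed form, which a short check (not part of the trainer, values made up) confirms:

import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
m_q, logs_q = torch.tensor(0.3), torch.tensor(-0.2)   # posterior q
m_p, logs_p = torch.tensor(-0.1), torch.tensor(0.4)   # prior p

z_p = m_q + torch.randn(200_000) * torch.exp(logs_q)  # z_p ~ q
kl_est = (
    logs_p - logs_q - 0.5
    + 0.5 * (z_p - m_p) ** 2 * torch.exp(-2.0 * logs_p)
).mean()  # the per-element term used in kl_loss, averaged over samples

kl_ref = kl_divergence(Normal(m_q, logs_q.exp()), Normal(m_p, logs_p.exp()))
print(kl_est.item(), kl_ref.item())  # agree to roughly three decimal places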
Amphion/models/vocoders/autoregressive/autoregressive_vocoder_trainer.py ADDED
File without changes
Amphion/modules/activation_functions/gated_activation_unit.py ADDED
@@ -0,0 +1,61 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from modules.general.utils import Conv1d
10
+
11
+
12
+ class GaU(nn.Module):
13
+ r"""Gated Activation Unit (GaU) proposed in `Gated Activation Units for Neural
14
+ Networks <https://arxiv.org/pdf/1606.05328.pdf>`_.
15
+
16
+ Args:
17
+ channels: number of input channels.
18
+ kernel_size: kernel size of the convolution.
19
+ dilation: dilation rate of the convolution.
20
+ d_context: dimension of context tensor, None if don't use context.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ channels: int,
26
+ kernel_size: int = 3,
27
+ dilation: int = 1,
28
+ d_context: int = None,
29
+ ):
30
+ super().__init__()
31
+
32
+ self.context = d_context
33
+
34
+ self.conv = Conv1d(
35
+ channels,
36
+ channels * 2,
37
+ kernel_size,
38
+ dilation=dilation,
39
+ padding=dilation * (kernel_size - 1) // 2,
40
+ )
41
+
42
+ if self.context:
43
+ self.context_proj = Conv1d(d_context, channels * 2, 1)
44
+
45
+ def forward(self, x: torch.Tensor, context: torch.Tensor = None):
46
+ r"""Calculate forward propagation.
47
+
48
+ Args:
49
+ x: input tensor with shape [B, C, T].
50
+ context: context tensor with shape [B, ``d_context``, T], default to None.
51
+ """
52
+
53
+ h = self.conv(x)
54
+
55
+ if self.context:
56
+ h = h + self.context_proj(context)
57
+
58
+ h1, h2 = h.chunk(2, 1)
59
+ h = torch.tanh(h1) * torch.sigmoid(h2)
60
+
61
+ return h
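Stripped of the convolutions, the gating in `GaU.forward` splits a 2C-channel feature map into a content half and a gate half:

import torch

B, C, T = 2, 4, 8
h = torch.randn(B, 2 * C, T)        # stand-in for self.conv(x) (+ context proj)
h1, h2 = h.chunk(2, 1)              # two [B, C, T] halves
out = torch.tanh(h1) * torch.sigmoid(h2)  # content gated by a (0, 1) mask
print(out.shape)                    # torch.Size([2, 4, 8])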
Amphion/modules/base/base_module.py ADDED
@@ -0,0 +1,75 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+
11
+ class LayerNorm(nn.Module):
12
+ def __init__(self, channels, eps=1e-5):
13
+ super().__init__()
14
+ self.channels = channels
15
+ self.eps = eps
16
+
17
+ self.gamma = nn.Parameter(torch.ones(channels))
18
+ self.beta = nn.Parameter(torch.zeros(channels))
19
+
20
+ def forward(self, x):
21
+ x = x.transpose(1, -1)
22
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
23
+ return x.transpose(1, -1)
24
+
25
+
26
+ class ConvReluNorm(nn.Module):
27
+ def __init__(
28
+ self,
29
+ in_channels,
30
+ hidden_channels,
31
+ out_channels,
32
+ kernel_size,
33
+ n_layers,
34
+ p_dropout,
35
+ ):
36
+ super().__init__()
37
+ self.in_channels = in_channels
38
+ self.hidden_channels = hidden_channels
39
+ self.out_channels = out_channels
40
+ self.kernel_size = kernel_size
41
+ self.n_layers = n_layers
42
+ self.p_dropout = p_dropout
43
+ assert n_layers > 1, "Number of layers should be larger than 1."
44
+
45
+ self.conv_layers = nn.ModuleList()
46
+ self.norm_layers = nn.ModuleList()
47
+ self.conv_layers.append(
48
+ nn.Conv1d(
49
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
50
+ )
51
+ )
52
+ self.norm_layers.append(LayerNorm(hidden_channels))
53
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
54
+ for _ in range(n_layers - 1):
55
+ self.conv_layers.append(
56
+ nn.Conv1d(
57
+ hidden_channels,
58
+ hidden_channels,
59
+ kernel_size,
60
+ padding=kernel_size // 2,
61
+ )
62
+ )
63
+ self.norm_layers.append(LayerNorm(hidden_channels))
64
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
65
+ self.proj.weight.data.zero_()
66
+ self.proj.bias.data.zero_()
67
+
68
+ def forward(self, x, x_mask):
69
+ x_org = x
70
+ for i in range(self.n_layers):
71
+ x = self.conv_layers[i](x * x_mask)
72
+ x = self.norm_layers[i](x)
73
+ x = self.relu_drop(x)
74
+ x = x_org + self.proj(x)
75
+ return x * x_mask
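The `LayerNorm` above normalizes over the channel axis of a [B, C, T] tensor by moving C to the last dimension for `F.layer_norm`. A quick equivalence check against a manual per-position normalization (with unit gamma and zero beta):

import torch
import torch.nn.functional as F

B, C, T = 2, 6, 5
x = torch.randn(B, C, T)
gamma, beta, eps = torch.ones(C), torch.zeros(C), 1e-5

y = F.layer_norm(x.transpose(1, -1), (C,), gamma, beta, eps).transpose(1, -1)

mean = x.mean(dim=1, keepdim=True)                 # statistics over channels
var = x.var(dim=1, unbiased=False, keepdim=True)   # biased, as layer_norm uses
assert torch.allclose(y, (x - mean) / torch.sqrt(var + eps), atol=1e-5)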
Amphion/modules/diffusion/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .bidilconv.bidilated_conv import BiDilConv
7
+ from .unet.unet import UNet
Amphion/modules/duration_predictor/__init__.py ADDED
File without changes
Amphion/modules/duration_predictor/standard_duration_predictor.py ADDED
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # This code is modified from https://github.com/jaywalnut310/vits/blob/main/models.py
7
+
8
+ import torch
9
+ from torch import nn
10
+ from modules.base.base_module import LayerNorm
11
+
12
+
13
+ class DurationPredictor(nn.Module):
14
+ def __init__(
15
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
16
+ ):
17
+ super().__init__()
18
+
19
+ self.in_channels = in_channels
20
+ self.filter_channels = filter_channels
21
+ self.kernel_size = kernel_size
22
+ self.p_dropout = p_dropout
23
+ self.gin_channels = gin_channels
24
+
25
+ self.drop = nn.Dropout(p_dropout)
26
+ self.conv_1 = nn.Conv1d(
27
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
28
+ )
29
+ self.norm_1 = LayerNorm(filter_channels)
30
+ self.conv_2 = nn.Conv1d(
31
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
32
+ )
33
+ self.norm_2 = LayerNorm(filter_channels)
34
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
35
+
36
+ if gin_channels != 0:
37
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
38
+
39
+ def forward(self, x, x_mask, g=None):
40
+ x = torch.detach(x)
41
+ if g is not None:
42
+ g = torch.detach(g)
43
+ x = x + self.cond(g)
44
+ x = self.conv_1(x * x_mask)
45
+ x = torch.relu(x)
46
+ x = self.norm_1(x)
47
+ x = self.drop(x)
48
+ x = self.conv_2(x * x_mask)
49
+ x = torch.relu(x)
50
+ x = self.norm_2(x)
51
+ x = self.drop(x)
52
+ x = self.proj(x * x_mask)
53
+ return x * x_mask
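At inference, the log durations predicted by this module are turned into integer frame counts, as in `SynthesizerTrn.infer` earlier in this commit. A toy walk-through with made-up predictions:

import torch

logw = torch.tensor([[[0.0, 0.7, 1.2]]])  # hypothetical [b, 1, t_text] output
x_mask, length_scale = torch.ones_like(logw), 1.0

w = torch.exp(logw) * x_mask * length_scale  # [[1.00, 2.01, 3.32]] frames
w_ceil = torch.ceil(w)                       # [[1., 3., 4.]]
y_length = torch.clamp_min(torch.sum(w_ceil), 1).long()
print(y_length)  # tensor(8) total output frames for these three phones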
Amphion/modules/duration_predictor/stochastic_duration_predictor.py ADDED
@@ -0,0 +1,120 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # This code is modified from https://github.com/jaywalnut310/vits/blob/main/models.py
+ import torch
7
+
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ import math
11
+ from modules.flow.modules import *
12
+
13
+
14
+ class StochasticDurationPredictor(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_channels,
18
+ filter_channels,
19
+ kernel_size,
20
+ p_dropout,
21
+ n_flows=4,
22
+ gin_channels=0,
23
+ ):
24
+ super().__init__()
25
+ filter_channels = in_channels
26
+ self.in_channels = in_channels
27
+ self.filter_channels = filter_channels
28
+ self.kernel_size = kernel_size
29
+ self.p_dropout = p_dropout
30
+ self.n_flows = n_flows
31
+ self.gin_channels = gin_channels
32
+
33
+ self.log_flow = Log()
34
+ self.flows = nn.ModuleList()
35
+ self.flows.append(ElementwiseAffine(2))
36
+ for i in range(n_flows):
37
+ self.flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3))
38
+ self.flows.append(Flip())
39
+
40
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
41
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
42
+ self.post_convs = DDSConv(
43
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
44
+ )
45
+ self.post_flows = nn.ModuleList()
46
+ self.post_flows.append(ElementwiseAffine(2))
47
+ for i in range(4):
48
+ self.post_flows.append(
49
+ ConvFlow(2, filter_channels, kernel_size, n_layers=3)
50
+ )
51
+ self.post_flows.append(Flip())
52
+
53
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
54
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
55
+ self.convs = DDSConv(
56
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
57
+ )
58
+ if gin_channels != 0:
59
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
60
+
61
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
62
+ x = torch.detach(x)
63
+ x = self.pre(x)
64
+ if g is not None:
65
+ g = torch.detach(g)
66
+ x = x + self.cond(g)
67
+ x = self.convs(x, x_mask)
68
+ x = self.proj(x) * x_mask
69
+
70
+ if not reverse:
71
+ flows = self.flows
72
+ assert w is not None
73
+
74
+ logdet_tot_q = 0
75
+ h_w = self.post_pre(w)
76
+ h_w = self.post_convs(h_w, x_mask)
77
+ h_w = self.post_proj(h_w) * x_mask
78
+ e_q = (
79
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
80
+ * x_mask
81
+ )
82
+ z_q = e_q
83
+ for flow in self.post_flows:
84
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
85
+ logdet_tot_q += logdet_q
86
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
87
+ u = torch.sigmoid(z_u) * x_mask
88
+ z0 = (w - u) * x_mask
89
+ logdet_tot_q += torch.sum(
90
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
91
+ )
92
+ logq = (
93
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
94
+ - logdet_tot_q
95
+ )
96
+
97
+ logdet_tot = 0
98
+ z0, logdet = self.log_flow(z0, x_mask)
99
+ logdet_tot += logdet
100
+ z = torch.cat([z0, z1], 1)
101
+ for flow in flows:
102
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
103
+ logdet_tot = logdet_tot + logdet
104
+ nll = (
105
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
106
+ - logdet_tot
107
+ )
108
+ return nll + logq
109
+ else:
110
+ flows = list(reversed(self.flows))
111
+ flows = flows[:-2] + [flows[-1]]
112
+ z = (
113
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
114
+ * noise_scale
115
+ )
116
+ for flow in flows:
117
+ z = flow(z, x_mask, g=x, reverse=reverse)
118
+ z0, z1 = torch.split(z, [1, 1], 1)
119
+ logw = z0
120
+ return logw
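The bookkeeping in the forward (training) branch follows the change-of-variables formula: each flow contributes a log-determinant, and the NLL is the negative standard-normal log-density of the final z minus the accumulated log-determinant. A one-dimensional affine toy flow makes the accounting concrete:

import math

w = 2.5
a, b = 0.5, -1.0                  # toy flow z = a * w + b
z = a * w + b                     # z = 0.25
logdet = math.log(abs(a))         # log |dz/dw|

# -log p(w) = -log N(z; 0, 1) - logdet, mirroring
# nll = sum(0.5 * (log 2*pi + z**2) * mask) - logdet_tot above.
nll = 0.5 * (math.log(2 * math.pi) + z**2) - logdet
print(nll)  # ~1.643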
Amphion/modules/general/scaling.py ADDED
@@ -0,0 +1,1349 @@
1
+ # This module is modified from https://github.com/Plachtaa/VALL-E-X/blob/3faaf8ccadb154d63b38070caf518ce9309ea0f4/modules/scaling.py
2
+
3
+
4
+ import logging
5
+ import random
6
+ import math
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch import Tensor
12
+
13
+
14
+ class Transpose(nn.Identity):
15
+ """(N, T, D) -> (N, D, T)"""
16
+
17
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
18
+ return input.transpose(1, 2)
19
+
20
+
21
+ class ActivationBalancerFunction(torch.autograd.Function):
22
+ @staticmethod
23
+ def forward(
24
+ ctx,
25
+ x: Tensor,
26
+ scale_factor: Tensor,
27
+ sign_factor: Optional[Tensor],
28
+ channel_dim: int,
29
+ ) -> Tensor:
30
+ if channel_dim < 0:
31
+ channel_dim += x.ndim
32
+ ctx.channel_dim = channel_dim
33
+ xgt0 = x > 0
34
+ if sign_factor is None:
35
+ ctx.save_for_backward(xgt0, scale_factor)
36
+ else:
37
+ ctx.save_for_backward(xgt0, scale_factor, sign_factor)
38
+ return x
39
+
40
+ @staticmethod
41
+ def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
42
+ if len(ctx.saved_tensors) == 3:
43
+ xgt0, scale_factor, sign_factor = ctx.saved_tensors
44
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
45
+ scale_factor = scale_factor.unsqueeze(-1)
46
+ sign_factor = sign_factor.unsqueeze(-1)
47
+ factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
48
+ else:
49
+ xgt0, scale_factor = ctx.saved_tensors
50
+ for _ in range(ctx.channel_dim, x_grad.ndim - 1):
51
+ scale_factor = scale_factor.unsqueeze(-1)
52
+ factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
53
+ neg_delta_grad = x_grad.abs() * factor
54
+ return (
55
+ x_grad - neg_delta_grad,
56
+ None,
57
+ None,
58
+ None,
59
+ )
60
+
61
+
def _compute_scale_factor(
    x: Tensor,
    channel_dim: int,
    min_abs: float,
    max_abs: float,
    gain_factor: float,
    max_factor: float,
) -> Tensor:
    if channel_dim < 0:
        channel_dim += x.ndim
    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
    x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)

    if min_abs == 0.0:
        below_threshold = 0.0
    else:
        # below_threshold is 0 if x_abs_mean > min_abs; it can be as large as
        # max_factor when x_abs_mean is far below min_abs.
        below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
            min=0, max=max_factor
        )

    above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
        min=0, max=max_factor
    )

    return below_threshold - above_threshold


def _compute_sign_factor(
    x: Tensor,
    channel_dim: int,
    min_positive: float,
    max_positive: float,
    gain_factor: float,
    max_factor: float,
) -> Tensor:
    if channel_dim < 0:
        channel_dim += x.ndim
    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
    proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
    if min_positive == 0.0:
        factor1 = 0.0
    else:
        # 0 if proportion_positive >= min_positive, else can be
        # as large as max_factor.
        factor1 = (
            (min_positive - proportion_positive) * (gain_factor / min_positive)
        ).clamp_(min=0, max=max_factor)

    if max_positive == 1.0:
        factor2 = 0.0
    else:
        # 0 if proportion_positive <= max_positive, else can be
        # as large as max_factor (it enters sign_factor with a minus sign).
        factor2 = (
            (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
        ).clamp_(min=0, max=max_factor)
    sign_factor = factor1 - factor2
    # require min_positive != 0 or max_positive != 1:
    assert not isinstance(sign_factor, float)
    return sign_factor


class ActivationScaleBalancerFunction(torch.autograd.Function):
    """
    This object is used in class ActivationBalancer when the user specified
    min_positive=0, max_positive=1, so there are no constraints on the signs
    of the activations and only the absolute value has a constraint.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        sign_factor: Tensor,
        scale_factor: Tensor,
        channel_dim: int,
    ) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        xgt0 = x > 0
        ctx.save_for_backward(xgt0, sign_factor, scale_factor)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
        xgt0, sign_factor, scale_factor = ctx.saved_tensors
        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
            sign_factor = sign_factor.unsqueeze(-1)
            scale_factor = scale_factor.unsqueeze(-1)

        factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        neg_delta_grad = x_grad.abs() * factor
        return (
            x_grad - neg_delta_grad,
            None,
            None,
            None,
        )


class RandomClampFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        min: Optional[float],
        max: Optional[float],
        prob: float,
        reflect: float,
    ) -> Tensor:
        x_clamped = torch.clamp(x, min=min, max=max)
        mask = torch.rand_like(x) < prob
        ans = torch.where(mask, x_clamped, x)
        if x.requires_grad:
            ctx.save_for_backward(ans == x)
            ctx.reflect = reflect
        if reflect != 0.0:
            ans = ans * (1.0 + reflect) - (x * reflect)
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
        (is_same,) = ctx.saved_tensors
        x_grad = ans_grad * is_same.to(ans_grad.dtype)
        reflect = ctx.reflect
        if reflect != 0.0:
            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
        return x_grad, None, None, None, None


def random_clamp(
    x: Tensor,
    min: Optional[float] = None,
    max: Optional[float] = None,
    prob: float = 0.5,
    reflect: float = 0.0,
):
    return RandomClampFunction.apply(x, min, max, prob, reflect)


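# A minimal usage sketch of random_clamp (a hypothetical `_demo_*` helper added
# for illustration; it is not part of the upstream module): with prob=1.0 every
# element is clamped, so the result matches torch.clamp exactly; with prob=0.0
# the input passes through unchanged.
def _demo_random_clamp():
    x = torch.randn(4, 5)
    assert torch.equal(random_clamp(x, min=-1.0, max=1.0, prob=1.0), x.clamp(-1.0, 1.0))
    assert torch.equal(random_clamp(x, min=-1.0, max=1.0, prob=0.0), x)

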
def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
    """
    A randomized way of casting a floating point value to half precision.
    """
    if x.dtype == torch.float16:
        return x
    x_abs = x.abs()
    is_too_small = x_abs < min_abs
    # for elements where is_too_small is true, random_val will contain +-min_abs with
    # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
    # for those elements].
    random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
    return torch.where(is_too_small, random_val, x).to(torch.float16)


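# A small sketch of the expectation-preserving behaviour (a hypothetical
# `_demo_*` helper for illustration, not part of the upstream module): values
# below min_abs are rounded to +-min_abs or zero with probabilities chosen so
# that the mean over many draws approaches the original value.
def _demo_random_cast_to_half():
    x = torch.full((100000,), 2.5e-06)  # below the default min_abs of 5e-06
    mean = random_cast_to_half(x).to(torch.float32).mean().item()
    # the mean should be close to 2.5e-06, up to sampling noise and
    # float16 subnormal rounding
    assert abs(mean - 2.5e-06) < 1.0e-06

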
class RandomGradFunction(torch.autograd.Function):
    """
    Does nothing in the forward pass; in the backward pass, gets rid of very
    small grads using a randomized approach that preserves expectations
    (intended to reduce roundoff).
    """

    @staticmethod
    def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
        ctx.min_abs = min_abs
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
        if ans_grad.dtype == torch.float16:
            return (
                random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs),
                None,
            )
        else:
            return ans_grad, None


class RandomGrad(torch.nn.Module):
    """
    Gets rid of very small gradients using an expectation-preserving method,
    intended to increase accuracy of training when using amp (automatic mixed
    precision).
    """

    def __init__(self, min_abs: float = 5.0e-06):
        super(RandomGrad, self).__init__()
        self.min_abs = min_abs

    def forward(self, x: Tensor):
        if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
            return x
        else:
            return RandomGradFunction.apply(x, self.min_abs)


class SoftmaxFunction(torch.autograd.Function):
    """
    Tries to handle half-precision derivatives in a randomized way that should
    be more accurate for training than the default behavior.
    """

    @staticmethod
    def forward(ctx, x: Tensor, dim: int):
        ans = x.softmax(dim=dim)
        # if x dtype is float16, x.softmax() returns a float32 because
        # (presumably) that op does not support float16, and autocast
        # is enabled.
        if torch.is_autocast_enabled():
            ans = ans.to(torch.float16)
        ctx.save_for_backward(ans)
        ctx.x_dtype = x.dtype
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (ans,) = ctx.saved_tensors
        with torch.cuda.amp.autocast(enabled=False):
            ans_grad = ans_grad.to(torch.float32)
            ans = ans.to(torch.float32)
            x_grad = ans_grad * ans
            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
            return x_grad, None


def softmax(x: Tensor, dim: int):
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        return x.softmax(dim)

    return SoftmaxFunction.apply(x, dim)


class MaxEigLimiterFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        coeffs: Tensor,
        direction: Tensor,
        channel_dim: int,
        grad_scale: float,
    ) -> Tensor:
        ctx.channel_dim = channel_dim
        ctx.grad_scale = grad_scale
        ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
        return x

    @staticmethod
    def backward(ctx, x_grad, *args):
        with torch.enable_grad():
            (x_orig, coeffs, new_direction) = ctx.saved_tensors
            x_orig.requires_grad = True
            num_channels = x_orig.shape[ctx.channel_dim]
            x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
            new_direction.requires_grad = False
            x = x - x.mean(dim=0)
            x_var = (x**2).mean()
            x_residual = x - coeffs * new_direction
            x_residual_var = (x_residual**2).mean()
            # `variance_proportion` is the proportion of the variance accounted for
            # by the top eigen-direction. This is to be minimized.
            variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
            variance_proportion.backward()
        x_orig_grad = x_orig.grad
        x_extra_grad = (
            x_orig.grad
            * ctx.grad_scale
            * x_grad.norm()
            / (x_orig_grad.norm() + 1.0e-20)
        )
        return x_grad + x_extra_grad.detach(), None, None, None, None


class BasicNorm(torch.nn.Module):
    """
    This is intended to be a simpler, and hopefully cheaper, replacement for
    LayerNorm. The observation this is based on, is that Transformer-type
    networks, especially with pre-norm, sometimes seem to set one of the
    feature dimensions to a large constant value (e.g. 50), which "defeats"
    the LayerNorm because the output magnitude is then not strongly dependent
    on the other (useful) features. Presumably the weight and bias of the
    LayerNorm are required to allow it to do this.

    So the idea is to introduce this large constant value as an explicit
    parameter, that takes the role of the "eps" in LayerNorm, so the network
    doesn't have to do this trick. We make the "eps" learnable.

    Args:
       num_channels: the number of channels, e.g. 512.
       channel_dim: the axis/dimension corresponding to the channel,
          interpreted as an offset from the input's ndim if negative.
          This is NOT the num_channels; it should typically be one of
          {-2, -1, 0, 1, 2, 3}.
       eps: the initial "epsilon" that we add as ballast in:
             scale = ((input_vec**2).mean() + epsilon)**-0.5
          Note: our epsilon is actually large, but we keep the name
          to indicate the connection with conventional LayerNorm.
       learn_eps: if true, we learn epsilon; if false, we keep it
          at the initial value.
       eps_min: float
       eps_max: float
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int = -1,  # CAUTION: see documentation.
        eps: float = 0.25,
        learn_eps: bool = True,
        eps_min: float = -3.0,
        eps_max: float = 3.0,
    ) -> None:
        super(BasicNorm, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        if learn_eps:
            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
        else:
            self.register_buffer("eps", torch.tensor(eps).log().detach())
        self.eps_min = eps_min
        self.eps_max = eps_max

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        eps = self.eps
        if self.training and random.random() < 0.25:
            # with probability 0.25, in training mode, clamp eps between the min
            # and max; this will encourage it to learn parameters within the
            # allowed range by making parameters that are outside the allowed
            # range noisy, while the iterations without clamping still provide
            # gradients to allow the parameter to get back into the allowed
            # region if it happens to exit it.
            eps = eps.clamp(min=self.eps_min, max=self.eps_max)
        scales = (
            torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp()
        ) ** -0.5
        return x * scales


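# A minimal usage sketch for BasicNorm (a hypothetical `_demo_*` helper for
# illustration, not part of the upstream module): the output is the input
# scaled per frame by (mean(x**2) + exp(eps)) ** -0.5; unlike LayerNorm there
# is no learnable weight/bias, only the log-epsilon.
def _demo_basic_norm():
    m = BasicNorm(num_channels=8, channel_dim=-1, eps=0.25)
    x = torch.randn(3, 8)
    y = m(x)
    # recompute the scale by hand for comparison
    scales = (torch.mean(x**2, dim=-1, keepdim=True) + m.eps.exp()) ** -0.5
    assert torch.allclose(y, x * scales)

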
def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
    """
    Behaves like a constructor of a modified version of nn.Linear
    that gives an easy way to set the default initial parameter scale.

    Args:
        Accepts the standard args and kwargs that nn.Linear accepts
        e.g. in_features, out_features, bias=False.

        initial_scale: you can override this if you want to increase
           or decrease the initial magnitude of the module's output
           (affects the initialization of weight_scale and bias_scale).
           Another option, if you want to do something like this, is
           to re-initialize the parameters.
    """
    ans = nn.Linear(*args, **kwargs)
    with torch.no_grad():
        ans.weight[:] *= initial_scale
        if ans.bias is not None:
            torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
    return ans


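# A small sketch of the initial_scale behaviour (a hypothetical `_demo_*`
# helper for illustration, not part of the upstream module): halving
# initial_scale halves the initial weights relative to a default nn.Linear
# initialisation, shrinking the module's initial output magnitude.
def _demo_scaled_linear():
    torch.manual_seed(0)
    a = ScaledLinear(16, 4, initial_scale=1.0)
    torch.manual_seed(0)
    b = ScaledLinear(16, 4, initial_scale=0.5)
    assert torch.allclose(b.weight, 0.5 * a.weight)

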
def ScaledConv1d(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Conv1d:
    """
    Behaves like a constructor of a modified version of nn.Conv1d
    that gives an easy way to set the default initial parameter scale.

    Args:
        Accepts the standard args and kwargs that nn.Conv1d accepts
        e.g. in_channels, out_channels, bias=False.

        initial_scale: you can override this if you want to increase
           or decrease the initial magnitude of the module's output
           (affects the initialization of weight_scale and bias_scale).
           Another option, if you want to do something like this, is
           to re-initialize the parameters.
    """
    ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
    with torch.no_grad():
        ans.weight[:] *= initial_scale
        if ans.bias is not None:
            torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
    return ans


def TransposeScaledConv1d(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Sequential:
    """
    Transpose -> ScaledConv1d
    """
    return nn.Sequential(
        Transpose(),
        ScaledConv1d(
            *args,
            initial_scale=initial_scale,
            kernel_size=kernel_size,
            padding=padding,
            **kwargs,
        ),
    )


def ScaledConv1dTranspose(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Sequential:
    """
    ScaledConv1d -> Transpose
    """
    return nn.Sequential(
        ScaledConv1d(
            *args,
            initial_scale=initial_scale,
            kernel_size=kernel_size,
            padding=padding,
            **kwargs,
        ),
        Transpose(),
    )


def TransposeConv1d(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    Transpose -> Conv1d
    """
    return nn.Sequential(
        Transpose(),
        nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
    )


def Conv1dTranspose(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    Conv1d -> Transpose
    """
    return nn.Sequential(
        nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
        Transpose(),
    )


class SRLinear(nn.Linear):
    """https://arxiv.org/abs/2303.06296
    Stabilizing Transformer Training by Preventing Attention Entropy Collapse
    """

    def __init__(self, in_features, out_features, bias=True, **kwargs):
        super().__init__(in_features, out_features, bias=bias, **kwargs)
        self.register_buffer(
            "u", nn.functional.normalize(torch.randn(in_features), dim=0)
        )
        with torch.no_grad():
            sigma = self.get_sigma()
            self.register_buffer("spectral_norm", sigma)
        self.sigma = nn.Parameter(torch.ones(1))

    def get_sigma(self):
        with torch.no_grad():
            u = self.u
            v = self.weight.mv(u)
            v = nn.functional.normalize(v, dim=0)
            u = self.weight.T.mv(v)
            u = nn.functional.normalize(u, dim=0)
            self.u.data.copy_(u)
        return torch.einsum("c,cd,d->", v, self.weight, u)

    def get_weight(self):
        sigma = self.get_sigma()
        if self.training:
            self.spectral_norm.data.copy_(sigma)
        weight = (self.sigma / sigma) * self.weight
        return weight

    def forward(self, x):
        return nn.functional.linear(x, self.get_weight(), self.bias)


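# A minimal sketch of the spectral reparameterisation (a hypothetical
# `_demo_*` helper for illustration, not part of the upstream module): each
# get_sigma() call runs one power-iteration step, so the estimate approaches
# the true largest singular value of the weight from below; since v and u are
# unit vectors, v @ W @ u can never exceed sigma_max.
def _demo_sr_linear():
    m = SRLinear(32, 32)
    for _ in range(100):  # extra power-iteration refinement steps
        m.get_sigma()
    est = m.get_sigma().item()
    with torch.no_grad():
        true_sigma = torch.linalg.svdvals(m.weight)[0].item()
    assert 0.0 < est <= true_sigma + 1.0e-5

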
class SRConv1d(SRLinear):
    def __init__(
        self,
        in_features,
        out_features,
        kernel_size,
        stride: int = 1,
        padding: str = "same",
        bias: bool = True,
        **kwargs,
    ):
        in_features = in_features * kernel_size
        super().__init__(in_features, out_features, bias=bias, **kwargs)
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        in_features = self.in_features // self.kernel_size
        weight = self.get_weight().view(
            self.out_features, in_features, self.kernel_size
        )
        return nn.functional.conv1d(
            x, weight, bias=self.bias, stride=self.stride, padding=self.padding
        )


def TransposeSRConv1d(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    Transpose -> SRConv1d
    """
    return nn.Sequential(
        Transpose(),
        SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
    )


def SRConv1dTranspose(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    SRConv1d -> Transpose
    """
    return nn.Sequential(
        SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
        Transpose(),
    )


class ActivationBalancer(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to encourage, for
    each channel, that it is positive at least a proportion `threshold` (i.e.
    `min_positive`) of the time. It does this by multiplying negative derivative
    values by up to (1+max_factor), and positive derivative values by up to
    (1-max_factor), interpolated from 1 at the threshold to those extremal
    values when none of the inputs are positive.

    Args:
       num_channels: the number of channels
       channel_dim: the dimension/axis corresponding to the channel, e.g.
           -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
       min_positive: the minimum, per channel, of the proportion of the time
           that (x > 0), below which we start to modify the derivatives.
       max_positive: the maximum, per channel, of the proportion of the time
           that (x > 0), above which we start to modify the derivatives.
       max_factor: the maximum factor by which we modify the derivatives for
          either the sign constraint or the magnitude constraint;
          e.g. with max_factor=0.02, the derivatives would be multiplied by
          values in the range [0.98..1.02].
       sign_gain_factor: determines the 'gain' with which we increase the
          change in gradient once the constraints on min_positive and max_positive
          are violated.
       scale_gain_factor: determines the 'gain' with which we increase the
          change in gradient once the constraints on min_abs and max_abs
          are violated.
       min_abs: the minimum average-absolute-value difference from the mean
           value per channel, which we allow, before we start to modify
           the derivatives to prevent this.
       max_abs: the maximum average-absolute-value difference from the mean
           value per channel, which we allow, before we start to modify
           the derivatives to prevent this.
       min_prob: determines the minimum probability with which we modify the
          gradients for the {min,max}_positive and {min,max}_abs constraints,
          on each forward(). This is done randomly to prevent all layers
          from doing it at the same time. Early in training we may use
          higher probabilities than this; it will decay to this value.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int,
        min_positive: float = 0.05,
        max_positive: float = 0.95,
        max_factor: float = 0.04,
        sign_gain_factor: float = 0.01,
        scale_gain_factor: float = 0.02,
        min_abs: float = 0.2,
        max_abs: float = 100.0,
        min_prob: float = 0.1,
    ):
        super(ActivationBalancer, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.min_positive = min_positive
        self.max_positive = max_positive
        self.max_factor = max_factor
        self.min_abs = min_abs
        self.max_abs = max_abs
        self.min_prob = min_prob
        self.sign_gain_factor = sign_gain_factor
        self.scale_gain_factor = scale_gain_factor

        # cpu_count measures how many times the forward() function has been called.
        # We occasionally sync this to a tensor called `count`, that exists to
        # make sure it is synced to disk when we load and save the model.
        self.cpu_count = 0
        self.register_buffer("count", torch.tensor(0, dtype=torch.int64))

    def forward(self, x: Tensor) -> Tensor:
        if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
            return _no_op(x)

        count = self.cpu_count
        self.cpu_count += 1

        if random.random() < 0.01:
            # Occasionally sync self.cpu_count with self.count.
            # count affects the decay of 'prob'. don't do this on every iter,
            # because syncing with the GPU is slow.
            self.cpu_count = max(self.cpu_count, self.count.item())
            self.count.fill_(self.cpu_count)

        # the prob of doing some work exponentially decreases from 0.5 till it hits
        # a floor at min_prob (==0.1, by default)
        prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))

        if random.random() < prob:
            if self.min_positive != 0.0 or self.max_positive != 1.0:
                sign_factor = _compute_sign_factor(
                    x,
                    self.channel_dim,
                    self.min_positive,
                    self.max_positive,
                    gain_factor=self.sign_gain_factor / prob,
                    max_factor=self.max_factor,
                )
            else:
                sign_factor = None

            scale_factor = _compute_scale_factor(
                x.detach(),
                self.channel_dim,
                min_abs=self.min_abs,
                max_abs=self.max_abs,
                gain_factor=self.scale_gain_factor / prob,
                max_factor=self.max_factor,
            )
            return ActivationBalancerFunction.apply(
                x,
                scale_factor,
                sign_factor,
                self.channel_dim,
            )
        else:
            return _no_op(x)


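# A minimal usage sketch (a hypothetical `_demo_*` helper for illustration,
# not part of the upstream module): the balancer is an identity in the forward
# direction and only nudges gradients, so it can be dropped between any two
# layers without changing the forward computation.
def _demo_activation_balancer():
    layer = nn.Sequential(
        nn.Linear(16, 32),
        ActivationBalancer(32, channel_dim=-1),
        nn.ReLU(),
    )
    x = torch.randn(4, 16)
    y = layer(x)
    assert y.shape == (4, 32)

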
def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
    """
    Returns x unmodified, but in backprop will put a penalty for the excess of
    the absolute values of elements of x over the limit "limit". E.g. if
    limit == 10.0, then if x has any values over 10 it will get a penalty.

    Caution: the value of this penalty will be affected by grad scaling used
    in automatic mixed precision training. For the purposes we use this for,
    that shouldn't really matter, or may even be helpful; we just use this
    to disallow really implausible values of scores to be given to softmax.
    """
    x_sign = x.sign()
    over_limit = (x.abs() - limit) > 0
    # The following is a memory efficient way to penalize the absolute values of
    # x that's over the limit. (The memory efficiency comes when you think
    # about which items torch needs to cache for the autograd, and which ones it
    # can throw away). The numerical value of aux_loss as computed here will
    # actually be larger than it should be, by limit * over_limit.sum(), but it
    # has the same derivative as the real aux_loss which is penalty * (x.abs() -
    # limit).relu().
    aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
    # note: we don't do sum() here on aux_loss, but it's as if we had done
    # sum() due to how with_loss() works.
    x = with_loss(x, aux_loss)
    # you must use x for something, or this will be ineffective.
    return x


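# A small gradient sketch (a hypothetical `_demo_*` helper for illustration,
# not part of the upstream module): the forward value is unchanged, but
# elements beyond the limit receive an extra penalty term in backprop.
def _demo_penalize_abs_values_gt():
    x = torch.tensor([0.5, 20.0], requires_grad=True)
    y = penalize_abs_values_gt(x, limit=10.0, penalty=1.0)
    y.sum().backward()
    # grad is 1 from the sum, plus 1 extra on the element over the limit
    assert torch.allclose(x.grad, torch.tensor([1.0, 2.0]))

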
def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
    if x.ndim == 2:
        return x.diag()
    else:
        (batch, dim, dim) = x.shape
        x = x.reshape(batch, dim * dim)
        x = x[:, :: dim + 1]
        assert x.shape == (batch, dim)
        return x


def _whitening_metric(x: Tensor, num_groups: int):
    """
    Computes the "whitening metric", a value which will be 1.0 if all the
    eigenvalues of the centered feature covariance are the same within each
    group's covariance matrix and also between groups.
    Args:
        x: a Tensor of shape (*, num_channels)
        num_groups: the number of groups of channels, a number >=1 that divides num_channels
    Returns:
        Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
        greater than 1.0 otherwise.
    """
    assert x.dtype != torch.float16
    x = x.reshape(-1, x.shape[-1])
    (num_frames, num_channels) = x.shape
    assert num_channels % num_groups == 0
    channels_per_group = num_channels // num_groups
    x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
    # x now has shape (num_groups, num_frames, channels_per_group)
    # subtract the mean so we use the centered, not uncentered, covariance.
    # My experience has been that when we "mess with the gradients" like this,
    # it's better not do anything that tries to move the mean around, because
    # that can easily cause instability.
    x = x - x.mean(dim=1, keepdim=True)
    # x_covar: (num_groups, channels_per_group, channels_per_group)
    x_covar = torch.matmul(x.transpose(1, 2), x)
    x_covar_mean_diag = _diag(x_covar).mean()
    # the following expression is what we'd get if we took the matrix product
    # of each covariance and measured the mean of its trace, i.e.
    # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
    x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group)
    # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
    metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20)
    return metric


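# A small numeric sketch (a hypothetical `_demo_*` helper for illustration,
# not part of the upstream module): for roughly white Gaussian data the metric
# sits near 1.0, and it grows once one direction dominates the covariance.
def _demo_whitening_metric():
    x = torch.randn(10000, 16)
    white = _whitening_metric(x, num_groups=1)
    skewed = _whitening_metric(x * torch.tensor([10.0] + [1.0] * 15), num_groups=1)
    assert white < skewed
    assert abs(white.item() - 1.0) < 0.1

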
class WhiteningPenaltyFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        num_groups: int,
        whitening_limit: float,
        grad_scale: float,
    ) -> Tensor:
        ctx.save_for_backward(x)
        ctx.num_groups = num_groups
        ctx.whitening_limit = whitening_limit
        ctx.grad_scale = grad_scale
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        (x_orig,) = ctx.saved_tensors
        with torch.enable_grad():
            with torch.cuda.amp.autocast(enabled=False):
                x_detached = x_orig.to(torch.float32).detach()
                x_detached.requires_grad = True

                metric = _whitening_metric(x_detached, ctx.num_groups)

                if random.random() < 0.005 or __name__ == "__main__":
                    logging.info(
                        f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
                        f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
                    )

                (metric - ctx.whitening_limit).relu().backward()
                penalty_grad = x_detached.grad
                scale = ctx.grad_scale * (
                    x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20)
                )
                penalty_grad = penalty_grad * scale
        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None


class Whiten(nn.Module):
    def __init__(
        self,
        num_groups: int,
        whitening_limit: float,
        prob: Union[float, Tuple[float, float]],
        grad_scale: float,
    ):
        """
        Args:
          num_groups: the number of groups to divide the channel dim into before
            whitening. We will attempt to make the feature covariance
            within each group, after mean subtraction, as "white" as possible,
            while having the same trace across all groups.
          whitening_limit: a value greater than 1.0, that dictates how much
            freedom we have to violate the constraints. 1.0 would mean perfectly
            white, with exactly the same trace across groups; larger values
            give more freedom. E.g. 2.0.
          prob: the probability with which we apply the gradient modification
            (also affects the grad scale). May be supplied as a float,
            or as a pair (min_prob, max_prob)

          grad_scale: determines the scale on the gradient term from this object,
            relative to the rest of the gradient on the attention weights.
            E.g. 0.02 (you may want to use smaller values than this if prob is large)
        """
        super(Whiten, self).__init__()
        assert num_groups >= 1
        assert whitening_limit >= 1
        assert grad_scale >= 0
        self.num_groups = num_groups
        self.whitening_limit = whitening_limit
        if isinstance(prob, float):
            assert 0 < prob <= 1
            self.prob = prob
        else:
            (self.min_prob, self.max_prob) = prob
            assert 0 < self.min_prob < self.max_prob <= 1
            self.prob = self.max_prob

        self.grad_scale = grad_scale

    def forward(self, x: Tensor) -> Tensor:
        """
        In the forward pass, this function just returns the input unmodified.
        In the backward pass, it will modify the gradients to ensure that the
        distribution in each group has close to (lambda times I) as the covariance
        after mean subtraction, with the same lambda across groups.
        For whitening_limit > 1, there will be more freedom to violate this
        constraint.

        Args:
           x: the input of shape (*, num_channels)

        Returns:
            x, unmodified. You should make sure
            you use the returned value, or the graph will be freed
            and nothing will happen in backprop.
        """
        if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
            return _no_op(x)
        else:
            if hasattr(self, "min_prob") and random.random() < 0.25:
                # occasionally switch between min_prob and max_prob, based on whether
                # we are above or below the threshold.
                if (
                    _whitening_metric(x.to(torch.float32), self.num_groups)
                    > self.whitening_limit
                ):
                    # there would be a change to the grad.
                    self.prob = self.max_prob
                else:
                    self.prob = self.min_prob

            return WhiteningPenaltyFunction.apply(
                x, self.num_groups, self.whitening_limit, self.grad_scale
            )


class WithLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor):
        ctx.y_shape = y.shape
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        return ans_grad, torch.ones(
            ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
        )


def with_loss(x, y):
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        return x
    # returns x but adds y.sum() to the loss function.
    return WithLoss.apply(x, y)


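# A minimal gradient sketch (a hypothetical `_demo_*` helper for illustration,
# not part of the upstream module): with_loss(x, y) forwards x unchanged but
# behaves as if y.sum() had been added to the final loss.
def _demo_with_loss():
    x = torch.ones(3, requires_grad=True)
    y = x * 2.0
    out = with_loss(x, y)
    out.sum().backward()
    # d/dx [x.sum() + (2x).sum()] = 1 + 2 = 3 per element
    assert torch.allclose(x.grad, torch.full((3,), 3.0))

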
def _no_op(x: Tensor) -> Tensor:
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        return x
    else:
        # a no-op function that will have a node in the autograd graph,
        # to avoid certain bugs relating to backward hooks
        return x.chunk(1, dim=-1)[0]


class Identity(torch.nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return _no_op(x)


class MaxEig(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to discourage
    that any given direction in activation space accounts for more than
    a specified proportion of the covariance (e.g. 0.2).


    Args:
       num_channels: the number of channels
       channel_dim: the dimension/axis corresponding to the channel, e.g.
           -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
       max_var_per_eig: the maximum proportion of the variance of the
           features/channels, after mean subtraction, that can come from
           any given eigenvalue.
       min_prob: the minimum probability with which we apply this during any invocation
           of forward(), assuming last time we applied the constraint it was
           not active; supplied for speed.
       scale: determines the scale with which we modify the gradients, relative
           to the existing / unmodified gradients
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int,
        max_var_per_eig: float = 0.2,
        min_prob: float = 0.01,
        scale: float = 0.01,
    ):
        super(MaxEig, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.scale = scale
        assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
        self.max_var_per_eig = max_var_per_eig

        # we figure out the dominant direction using the power method: starting with
        # a random vector, keep multiplying by the covariance and renormalizing.
        with torch.no_grad():
            # arbitrary.. would use randn() but want to leave the rest of the model's
            # random parameters unchanged for comparison
            direction = torch.arange(num_channels).to(torch.float)
            direction = direction / direction.norm()
            self.register_buffer("max_eig_direction", direction)

        self.min_prob = min_prob
        # cur_prob is the current probability we'll use to apply the ActivationBalancer.
        # We'll regress this towards min_prob each time we try to apply it and it is
        # not active.
        self.cur_prob = 1.0

    def forward(self, x: Tensor) -> Tensor:
        if (
            torch.jit.is_scripting()
            or self.max_var_per_eig <= 0
            or random.random() > self.cur_prob
            or torch.jit.is_tracing()
        ):
            return _no_op(x)

        with torch.cuda.amp.autocast(enabled=False):
            eps = 1.0e-20
            orig_x = x
            x = x.to(torch.float32)
            with torch.no_grad():
                x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels)
                x = x - x.mean(dim=0)
                new_direction, coeffs = self._find_direction_coeffs(
                    x, self.max_eig_direction
                )
                x_var = (x**2).mean()
                x_residual = x - coeffs * new_direction
                x_residual_var = (x_residual**2).mean()

                # `variance_proportion` is the proportion of the variance accounted for
                # by the top eigen-direction.
                variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)

                # ensure new direction is nonzero even if x == 0, by including `direction`.
                self._set_direction(0.1 * self.max_eig_direction + new_direction)

            if random.random() < 0.01 or __name__ == "__main__":
                logging.info(
                    f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
                )

            if variance_proportion >= self.max_var_per_eig:
                # The constraint is active. Note, we should quite rarely
                # reach here, only near the beginning of training if we are
                # starting to diverge, should this constraint be active.
                self.cur_prob = 1.0  # next time, do the update with probability 1.0.
                return MaxEigLimiterFunction.apply(
                    orig_x, coeffs, new_direction, self.channel_dim, self.scale
                )
            else:
                # let self.cur_prob exponentially approach self.min_prob, as
                # long as the constraint is inactive.
                self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
                return orig_x

    def _set_direction(self, direction: Tensor):
        """
        Sets self.max_eig_direction to a normalized version of `direction`
        """
        direction = direction.detach()
        direction = direction / direction.norm()
        direction_sum = direction.sum().item()
        if direction_sum - direction_sum == 0:  # no inf/nan
            self.max_eig_direction[:] = direction
        else:
            logging.info(
                f"Warning: sum of direction in MaxEig is {direction_sum}, "
                f"num_channels={self.num_channels}, channel_dim={self.channel_dim}"
            )

    def _find_direction_coeffs(
        self, x: Tensor, prev_direction: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Figure out (an approximation to) the proportion of the variance of a set of
        feature vectors that can be attributed to the top eigen-direction.
        Args:
         x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
         prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
               of the top eigen-direction, or a random direction if this is the first
               iteration. Does not have to be normalized, but should be nonzero.

        Returns: (cur_direction, coeffs), where:
             cur_direction: a Tensor of shape (num_channels,) that is the current
                  estimate of the top eigen-direction.
             coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
                  approximately minimizes, (x - coeffs * cur_direction).norm()
        """
        (num_frames, num_channels) = x.shape
        assert num_channels > 1 and num_frames > 1
        assert prev_direction.shape == (num_channels,)
        # `coeffs` are the coefficients of `prev_direction` in x;
        # they actually represent the coeffs up to a constant positive factor.
        coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
        cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20)
        return cur_direction, coeffs


class DoubleSwishFunction(torch.autograd.Function):
    """
    double_swish(x) = x * torch.sigmoid(x-1)
    This is a definition, originally motivated by its close numerical
    similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).

    Memory-efficient derivative computation:
     double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
     double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
     Now, s'(x) = s(x) * (1-s(x)).
     double_swish'(x) = x * s'(x) + s(x).
                      = x * s(x) * (1-s(x)) + s(x).
                      = double_swish(x) * (1-s(x)) + s(x)
     ... so we just need to remember s(x) but not x itself.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        requires_grad = x.requires_grad
        x_dtype = x.dtype
        if x.dtype == torch.float16:
            x = x.to(torch.float32)

        s = torch.sigmoid(x - 1.0)
        y = x * s

        if requires_grad:
            deriv = y * (1 - s) + s
            # notes on derivative of x * sigmoid(x - 1):
            # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
            # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bound.
            # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound.
            # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
            # floors), should be expectation-preserving.
            floor = -0.043637
            ceil = 1.2
            d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
                deriv
            )
            if __name__ == "__main__":
                # for self-testing only.
                assert d_scaled.min() >= 0.0
                assert d_scaled.max() < 256.0
            d_int = d_scaled.to(torch.uint8)
            ctx.save_for_backward(d_int)
        # cast back to half precision if the original input was float16 or we
        # are under autocast (x itself was promoted to float32 above, so we
        # check the saved x_dtype rather than x.dtype).
        if x_dtype == torch.float16 or torch.is_autocast_enabled():
            y = y.to(torch.float16)
        return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        (d,) = ctx.saved_tensors
        # the same constants as used in forward pass.
        floor = -0.043637
        ceil = 1.2
        d = d * ((ceil - floor) / 255.0) + floor
        return y_grad * d


class DoubleSwish(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        """Return double-swish activation function which is an approximation to Swish(Swish(x)),
        that we approximate closely with x * sigmoid(x-1).
        """
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return x * torch.sigmoid(x - 1.0)
        return DoubleSwishFunction.apply(x)


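# A small numeric sketch (a hypothetical `_demo_*` helper for illustration,
# not part of the upstream module): on inputs that do not require grad, the
# custom autograd path computes x * sigmoid(x - 1) exactly; the quantised,
# expectation-preserving approximation only affects the backward pass.
def _demo_double_swish():
    x = torch.linspace(-4.0, 4.0, 9)
    y = DoubleSwish()(x)
    assert torch.allclose(y, x * torch.sigmoid(x - 1.0))

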
def BalancedDoubleSwish(
    d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
    """
    ActivationBalancer -> DoubleSwish
    """
    balancer = ActivationBalancer(
        d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
    )
    return nn.Sequential(
        balancer,
        DoubleSwish(),
    )


def _test_max_eig():
    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"proportion = {proportion}")
        x = torch.randn(100, 128)
        direction = torch.randn(128)
        coeffs = torch.randn(100, 1)
        x += proportion * direction * coeffs

        x.requires_grad = True

        num_channels = 128
        m = MaxEig(
            num_channels,
            1,  # channel_dim
            0.5,  # max_var_per_eig
            scale=0.1,  # scale of the gradient modification
        )

        for _ in range(4):
            y = m(x)

        y_grad = torch.randn_like(x)
        y.backward(gradient=y_grad)

        if proportion < 0.2:
            assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
        elif proportion > 1.0:
            assert not torch.allclose(x.grad, y_grad)


def _test_whiten():
    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"_test_whiten(): proportion = {proportion}")
        x = torch.randn(100, 128)
        direction = torch.randn(128)
        coeffs = torch.randn(100, 1)
        x += proportion * direction * coeffs

        x.requires_grad = True

        num_channels = 128
        m = Whiten(
            1,  # num_groups
            5.0,  # whitening_limit
            prob=1.0,
            grad_scale=0.1,
        )

        for _ in range(4):
            y = m(x)

        y_grad = torch.randn_like(x)
        y.backward(gradient=y_grad)

        if proportion < 0.2:
            assert torch.allclose(x.grad, y_grad)
        elif proportion > 1.0:
            assert not torch.allclose(x.grad, y_grad)


def _test_activation_balancer_sign():
    probs = torch.arange(0, 1, 0.01)
    N = 1000
    x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
    x = x.detach()
    x.requires_grad = True
    m = ActivationBalancer(
        probs.numel(),
        channel_dim=0,
        min_positive=0.05,
        max_positive=0.95,
        max_factor=0.2,
        min_abs=0.0,
    )

    y_grad = torch.sign(torch.randn(probs.numel(), N))

    y = m(x)
    y.backward(gradient=y_grad)
    print("_test_activation_balancer_sign: x = ", x)
    print("_test_activation_balancer_sign: y grad = ", y_grad)
    print("_test_activation_balancer_sign: x grad = ", x.grad)


def _test_activation_balancer_magnitude():
    magnitudes = torch.arange(0, 1, 0.01)
    N = 1000
    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
    x = x.detach()
    x.requires_grad = True
    m = ActivationBalancer(
        magnitudes.numel(),
        channel_dim=0,
        min_positive=0.0,
        max_positive=1.0,
        max_factor=0.2,
        min_abs=0.2,
        max_abs=0.8,
        min_prob=1.0,
    )

    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))

    y = m(x)
    y.backward(gradient=y_grad)
    print("_test_activation_balancer_magnitude: x = ", x)
    print("_test_activation_balancer_magnitude: y grad = ", y_grad)
    print("_test_activation_balancer_magnitude: x grad = ", x.grad)


def _test_basic_norm():
    num_channels = 128
    m = BasicNorm(num_channels=num_channels, channel_dim=1)

    x = torch.randn(500, num_channels)

    y = m(x)

    assert y.shape == x.shape
    x_rms = (x**2).mean().sqrt()
    y_rms = (y**2).mean().sqrt()
    print("x rms = ", x_rms)
    print("y rms = ", y_rms)
    assert y_rms < x_rms
    assert y_rms > 0.5 * x_rms


def _test_double_swish_deriv():
    x = torch.randn(10, 12, dtype=torch.double) * 3.0
    x.requires_grad = True
    m = DoubleSwish()

    tol = (1.2 - (-0.043637)) / 255.0
    torch.autograd.gradcheck(m, x, atol=tol)

    # for self-test.
    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
    x.requires_grad = True
    y = m(x)


def _test_softmax():
    a = torch.randn(2, 10, dtype=torch.float64)
    b = a.clone()
    a.requires_grad = True
    b.requires_grad = True
    a.softmax(dim=1)[:, 0].sum().backward()
    print("a grad = ", a.grad)
    softmax(b, dim=1)[:, 0].sum().backward()
    print("b grad = ", b.grad)
    assert torch.allclose(a.grad, b.grad)


if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    _test_softmax()
    _test_whiten()
    _test_max_eig()
    _test_activation_balancer_sign()
    _test_activation_balancer_magnitude()
    _test_basic_norm()
    _test_double_swish_deriv()
Amphion/modules/norms/norm.py ADDED
@@ -0,0 +1,173 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import copy
import numbers
from typing import Any, List, Tuple, Union

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from modules.general.scaling import ActivationBalancer
from modules.general.scaling import BasicNorm as _BasicNorm


_shape_t = Union[int, List[int], torch.Size]


class LayerNorm(nn.Module):
    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
            self.bias = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            input, embedding = input
            output = F.layer_norm(
                input, self.normalized_shape, self.weight, self.bias, self.eps
            )
            return output, embedding

        assert embedding is None
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )

    def extra_repr(self) -> str:
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )


class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization"""

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        if isinstance(input, tuple):
            input, embedding = input
            weight, bias = torch.split(
                self.project_layer(embedding),
                split_size_or_sections=self.d_model,
                dim=-1,
            )
            return (weight * self.norm(input) + bias, embedding)

        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        return weight * self.norm(input) + bias


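# A minimal usage sketch (a hypothetical `_demo_*` helper for illustration,
# not part of the upstream module): the conditioning embedding is projected
# to a per-feature (weight, bias) pair that modulates the wrapped norm's
# output, which is how conditioning information enters the normalization.
def _demo_adaptive_layer_norm():
    d_model = 8
    ada = AdaptiveLayerNorm(d_model, LayerNorm(d_model))
    x = torch.randn(2, 5, d_model)
    emb = torch.randn(2, 5, d_model)
    y = ada(x, emb)
    assert y.shape == x.shape

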
class BasicNorm(_BasicNorm):
    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        super(BasicNorm, self).__init__(d_model, eps=eps)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            input, embedding = input
            return (
                super(BasicNorm, self).forward(input),
                embedding,
            )

        assert embedding is None
        return super(BasicNorm, self).forward(input)


class BalancedBasicNorm(nn.Module):
    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        super(BalancedBasicNorm, self).__init__()
        self.balancer = ActivationBalancer(
            d_model,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            max_abs=6.0,
        )
        self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            input, embedding = input
            return self.norm((self.balancer(input), embedding))

        assert embedding is None
        return self.norm(self.balancer(input))


class IdentityNorm(nn.Module):
    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        super(IdentityNorm, self).__init__()

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            return input

        assert embedding is None
        return input