Add files using upload-large-folder tool
- Amphion/egs/metrics/README.md +174 -0
- Amphion/egs/metrics/run.sh +132 -0
- Amphion/egs/svc/TransformerSVC/exp_config.json +108 -0
- Amphion/egs/svc/VitsSVC/README.md +125 -0
- Amphion/egs/tta/README.md +19 -0
- Amphion/egs/tta/audioldm/exp_config.json +90 -0
- Amphion/egs/tta/audioldm/run_train.sh +26 -0
- Amphion/egs/tta/audioldm/run_train_latent_4_10_78.sh +26 -0
- Amphion/egs/tta/autoencoderkl/run_train_latent_4_10_78.sh +26 -0
- Amphion/egs/tts/FastSpeech2/prepare_mfa.sh +29 -0
- Amphion/egs/tts/FastSpeech2/run.sh +155 -0
- Amphion/egs/tts/NaturalSpeech2/run_inference.sh +49 -0
- Amphion/egs/tts/VALLE/README.md +207 -0
- Amphion/egs/tts/VALLE/prompt_examples/5142_33396_000002_000004.wav +0 -0
- Amphion/egs/tts/VALLE/prompt_examples/7176_92135_000004_000000.normalized.txt +1 -0
- Amphion/egs/tts/VITS/exp_config.json +34 -0
- Amphion/egs/vocoder/gan/bigvgan_large/exp_config.json +70 -0
- Amphion/egs/vocoder/gan/hifigan/exp_config.json +59 -0
- Amphion/egs/vocoder/gan/hifigan/run.sh +141 -0
- Amphion/egs/vocoder/gan/nsfhifigan/run.sh +141 -0
- Amphion/models/__pycache__/__init__.cpython-310.pyc +0 -0
- Amphion/models/base/__init__.py +7 -0
- Amphion/models/base/new_dataset.py +50 -0
- Amphion/models/base/new_trainer.py +727 -0
- Amphion/models/codec/ns3_codec/__pycache__/facodec.cpython-310.pyc +0 -0
- Amphion/models/codec/ns3_codec/alias_free_torch/__pycache__/__init__.cpython-310.pyc +0 -0
- Amphion/models/codec/ns3_codec/alias_free_torch/__pycache__/resample.cpython-310.pyc +0 -0
- Amphion/models/codec/ns3_codec/quantize/__pycache__/rvq.cpython-310.pyc +0 -0
- Amphion/models/codec/ns3_codec/quantize/rvq.py +87 -0
- Amphion/models/svc/transformer/transformer.py +82 -0
- Amphion/models/tts/fastspeech2/fs2.py +548 -0
- Amphion/models/tts/fastspeech2/fs2_inference.py +193 -0
- Amphion/models/tts/naturalspeech2/__init__.py +0 -0
- Amphion/models/tts/naturalspeech2/wavenet.py +206 -0
- Amphion/models/tts/valle/valle.py +794 -0
- Amphion/models/tts/valle/valle_inference.py +237 -0
- Amphion/models/tts/valle/valle_trainer.py +367 -0
- Amphion/models/tts/vits/__init__.py +0 -0
- Amphion/models/tts/vits/vits.py +379 -0
- Amphion/models/tts/vits/vits_dataset.py +140 -0
- Amphion/models/tts/vits/vits_trainer.py +439 -0
- Amphion/models/vocoders/autoregressive/autoregressive_vocoder_trainer.py +0 -0
- Amphion/modules/activation_functions/gated_activation_unit.py +61 -0
- Amphion/modules/base/base_module.py +75 -0
- Amphion/modules/diffusion/__init__.py +7 -0
- Amphion/modules/duration_predictor/__init__.py +0 -0
- Amphion/modules/duration_predictor/standard_duration_predictor.py +53 -0
- Amphion/modules/duration_predictor/stochastic_duration_predictor.py +120 -0
- Amphion/modules/general/scaling.py +1349 -0
- Amphion/modules/norms/norm.py +173 -0
Amphion/egs/metrics/README.md
ADDED
@@ -0,0 +1,174 @@
# Amphion Evaluation Recipe

## Supported Evaluation Metrics

Until now, Amphion Evaluation has supported the following objective metrics:

- **F0 Modeling**:
  - F0 Pearson Coefficients (FPC)
  - F0 Periodicity Root Mean Square Error (PeriodicityRMSE)
  - F0 Root Mean Square Error (F0RMSE)
  - Voiced/Unvoiced F1 Score (V/UV F1)
- **Energy Modeling**:
  - Energy Root Mean Square Error (EnergyRMSE)
  - Energy Pearson Coefficients (EnergyPC)
- **Intelligibility**:
  - Character Error Rate (CER) based on [Whisper](https://github.com/openai/whisper)
  - Word Error Rate (WER) based on [Whisper](https://github.com/openai/whisper)
- **Spectrogram Distortion**:
  - Frechet Audio Distance (FAD)
  - Mel Cepstral Distortion (MCD)
  - Multi-Resolution STFT Distance (MSTFT)
  - Perceptual Evaluation of Speech Quality (PESQ)
  - Short Time Objective Intelligibility (STOI)
  - Scale Invariant Signal to Distortion Ratio (SISDR)
  - Scale Invariant Signal to Noise Ratio (SISNR)
- **Speaker Similarity**:
  - Cosine similarity based on:
    - [Rawnet3](https://github.com/Jungjee/RawNet)
    - [Resemblyzer](https://github.com/resemble-ai/Resemblyzer)
    - [WavLM](https://huggingface.co/microsoft/wavlm-base-plus-sv)

We provide a recipe to demonstrate how to objectively evaluate your generated audios. There are three steps in total:

1. Pretrained Models Preparation
2. Audio Data Preparation
3. Evaluation

## 1. Pretrained Models Preparation

If you want to calculate `RawNet3` based speaker similarity, you need to download the pretrained model first, as illustrated [here](../../pretrained/README.md).

## 2. Audio Data Preparation

Prepare the reference audios and generated audios in two folders: `ref_dir` contains the reference audio and `gen_dir` contains the generated audio. Here is an example.

```plaintext
┣ {ref_dir}
┃ ┣ sample1.wav
┃ ┣ sample2.wav
┣ {gen_dir}
┃ ┣ sample1.wav
┃ ┣ sample2.wav
```

You have to make sure that each pairwise **reference audio and generated audio are named the same**, as illustrated above (sample1 to sample1, sample2 to sample2).

## 3. Evaluation

Run `run.sh` with the specified reference folder, generated folder, dump folder, and metrics.

```bash
cd Amphion
sh egs/metrics/run.sh \
	--reference_folder [Your path to the reference audios] \
	--generated_folder [Your path to the generated audios] \
	--dump_folder [Your path to dump the objective results] \
	--metrics [The metrics you need] \
	--fs [Optional. To calculate all metrics in the specified sampling rate] \
	--similarity_model [Optional. To choose the model for calculating the speaker similarity. Currently "rawnet", "wavlm" and "resemblyzer" are available. Default to "wavlm"] \
	--similarity_mode [Optional. To choose the mode for calculating the speaker similarity. "pairwith" for calculating a series of ground truth / prediction audio pairs to obtain the speaker similarity, and "overall" for computing the average score with all possible pairs between the reference folder and generated folder. Default to "pairwith"] \
	--intelligibility_mode [Optional. To choose the mode for computing CER and WER. "gt_audio" means selecting the recognition content of the reference audio as the target, "gt_content" means using the transcription as the target. Default to "gt_audio"] \
	--ltr_path [Optional. Path to the transcription file] \
	--language [Optional. Language for computing CER and WER. Default to "english"]
```

As for the metrics, an example is provided below:

```bash
--metrics "mcd pesq fad"
```

All currently available metrics keywords are listed below:

| Keys                  | Description                                |
| --------------------- | ------------------------------------------ |
| `fpc`                 | F0 Pearson Coefficients                    |
| `f0_periodicity_rmse` | F0 Periodicity Root Mean Square Error      |
| `f0rmse`              | F0 Root Mean Square Error                  |
| `v_uv_f1`             | Voiced/Unvoiced F1 Score                   |
| `energy_rmse`         | Energy Root Mean Square Error              |
| `energy_pc`           | Energy Pearson Coefficients                |
| `cer`                 | Character Error Rate                       |
| `wer`                 | Word Error Rate                            |
| `similarity`          | Speaker Similarity                         |
| `fad`                 | Frechet Audio Distance                     |
| `mcd`                 | Mel Cepstral Distortion                    |
| `mstft`               | Multi-Resolution STFT Distance             |
| `pesq`                | Perceptual Evaluation of Speech Quality    |
| `si_sdr`              | Scale Invariant Signal to Distortion Ratio |
| `si_snr`              | Scale Invariant Signal to Noise Ratio      |
| `stoi`                | Short Time Objective Intelligibility       |

For example, if you want to calculate the speaker similarity between the synthesized audio and the reference audio with the same content, run:

```bash
sh egs/metrics/run.sh \
	--reference_folder [Your path to the reference audios] \
	--generated_folder [Your path to the generated audios] \
	--dump_folder [Your path to dump the objective results] \
	--metrics "similarity" \
	--similarity_model [Optional. To choose the model for calculating the speaker similarity. Currently "rawnet", "wavlm" and "resemblyzer" are available. Default to "wavlm"] \
	--similarity_mode "pairwith" \
```

If you don't have reference audio with the same content, run the following to get the content-free similarity score:

```bash
sh egs/metrics/run.sh \
	--reference_folder [Your path to the reference audios] \
	--generated_folder [Your path to the generated audios] \
	--dump_folder [Your path to dump the objective results] \
	--metrics "similarity" \
	--similarity_model [Optional. To choose the model for calculating the speaker similarity. Currently "rawnet", "wavlm" and "resemblyzer" are available. Default to "wavlm"] \
	--similarity_mode "overall" \
```

## Troubleshooting
### FAD (Using Offline Models)
If your system is unable to access huggingface.co from the terminal, you might run into an error like "OSError: Can't load tokenizer for ...". To work around this, follow these steps to use local models:

1. Download the [bert-base-uncased](https://huggingface.co/bert-base-uncased), [roberta-base](https://huggingface.co/roberta-base), and [facebook/bart-base](https://huggingface.co/facebook/bart-base) models from `huggingface.co`. Ensure that the models are complete and uncorrupted. Place these directories within `Amphion/pretrained`. For a detailed file structure reference, see [this README](../../pretrained/README.md#optional-model-dependencies-for-evaluation) under `Amphion/pretrained`.
2. Inside the `Amphion/pretrained` directory, create a bash script with the content outlined below. This script will automatically update the tokenizer paths used by your system:
   ```bash
   #!/bin/bash

   BERT_DIR="bert-base-uncased"
   ROBERTA_DIR="roberta-base"
   BART_DIR="facebook/bart-base"
   PYTHON_SCRIPT="[YOUR ENV PATH]/lib/python3.9/site-packages/laion_clap/training/data.py"

   update_tokenizer_path() {
       local dir_name=$1
       local tokenizer_variable=$2
       local full_path

       if [ -d "$dir_name" ]; then
           full_path=$(realpath "$dir_name")
           if [ -f "$PYTHON_SCRIPT" ]; then
               sed -i "s|${tokenizer_variable}.from_pretrained(\".*\")|${tokenizer_variable}.from_pretrained(\"$full_path\")|" "$PYTHON_SCRIPT"
               echo "Updated ${tokenizer_variable} path to $full_path."
           else
               echo "Error: The specified Python script does not exist."
               exit 1
           fi
       else
           echo "Error: The directory $dir_name does not exist in the current directory."
           exit 1
       fi
   }

   update_tokenizer_path "$BERT_DIR" "BertTokenizer"
   update_tokenizer_path "$ROBERTA_DIR" "RobertaTokenizer"
   update_tokenizer_path "$BART_DIR" "BartTokenizer"

   echo "BERT, BART and RoBERTa Python script paths have been updated."
   ```
3. The script provided is intended to adjust the tokenizer paths in the `data.py` file, found under `/lib/python3.9/site-packages/laion_clap/training/`, within your specific environment. For those utilizing conda, you can determine your environment path by running `conda info --envs`. Then, substitute `[YOUR ENV PATH]` in the script with this path. If your environment is configured differently, you'll need to update the `PYTHON_SCRIPT` variable to correctly point to the `data.py` file.
4. Run the script. If it executes successfully, the tokenizer paths will be updated, allowing them to be loaded locally.

### WavLM-based Speaker Similarity (Using Offline Models)

If your system is unable to access huggingface.co from the terminal and you want to calculate `WavLM` based speaker similarity, you need to download the pretrained model first, as illustrated [here](../../pretrained/README.md).
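Because every metric is computed over same-named file pairs, a quick sanity check of the two folders can save a failed run. Below is a minimal, hypothetical shell sketch (not part of Amphion; the argument order and `.wav`-only filter are assumptions) that flags generated files with no matching reference before launching `run.sh`:

```bash
#!/bin/bash
# Hypothetical helper: list generated audios that have no same-named reference file.
ref_dir="$1"   # e.g., path/to/ref_dir
gen_dir="$2"   # e.g., path/to/gen_dir

for f in "$gen_dir"/*.wav; do
    name=$(basename "$f")
    if [ ! -f "$ref_dir/$name" ]; then
        echo "No reference found for: $name"
    fi
done
```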
Amphion/egs/metrics/run.sh
ADDED
@@ -0,0 +1,132 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

######## Build Experiment Environment ###########
exp_dir=$(cd `dirname $0`; pwd)
work_dir=$(dirname $(dirname $exp_dir))

export WORK_DIR=$work_dir
export PYTHONPATH=$work_dir
export PYTHONIOENCODING=UTF-8

######## Parse the Given Parameters from the Command ###########
options=$(getopt -o c:n:s --long gpu:,reference_folder:,generated_folder:,dump_folder:,metrics:,fs:,align_method:,energy_db_scale:,f0_subtract_mean:,similarity_model:,similarity_mode:,ltr_path:,intelligibility_mode:,language: -- "$@")
eval set -- "$options"

while true; do
  case $1 in
    # Visible GPU machines. The default value is "0".
    --gpu) shift; gpu=$1 ; shift ;;
    # Reference Audio Folder
    --reference_folder) shift; ref_dir=$1 ; shift ;;
    # Generated Audio Folder
    --generated_folder) shift; deg_dir=$1 ; shift ;;
    # Result Dumping Folder
    --dump_folder) shift; dump_dir=$1 ; shift ;;
    # Metrics to Compute
    --metrics) shift; metrics=$1 ; shift ;;
    # Sampling Rate
    --fs) shift; fs=$1 ; shift ;;

    # Method for aligning F0. The default value is "dtw"
    --align_method) shift; align_method=$1 ; shift ;;
    # Method for normalizing F0. The default value is "True"
    --f0_subtract_mean) shift; f0_subtract_mean=$1 ; shift ;;
    # Method for normalizing Energy. The default value is "True"
    --energy_db_scale) shift; energy_db_scale=$1 ; shift ;;

    # Model for computing speaker similarity. The default value is "wavlm"
    --similarity_model) shift; similarity_model=$1 ; shift ;;
    # Mode for computing speaker similarity. The default value is "pairwith"
    --similarity_mode) shift; similarity_mode=$1 ; shift ;;

    # Path for the transcript.
    --ltr_path) shift; ltr_path=$1 ; shift ;;
    # Mode for computing CER and WER. The default value is "gt_audio"
    --intelligibility_mode) shift; intelligibility_mode=$1 ; shift ;;
    # Language for computing CER and WER. The default value is "english"
    --language) shift; language=$1 ; shift ;;

    --) shift ; break ;;
    *) echo "Invalid option: $1"; exit 1 ;;
  esac
done

### Value check ###
if [ -z "$ref_dir" ]; then
  echo "[Error] Please specify the reference_folder"
  exit 1
fi

if [ -z "$deg_dir" ]; then
  echo "[Error] Please specify the generated_folder"
  exit 1
fi

if [ -z "$dump_dir" ]; then
  echo "[Error] Please specify the dump_folder"
  exit 1
fi

if [ -z "$metrics" ]; then
  echo "[Error] Please specify the metrics"
  exit 1
fi

if [ -z "$gpu" ]; then
  gpu="0"
fi

if [ -z "$fs" ]; then
  fs="None"
fi

if [ -z "$align_method" ]; then
  align_method="dtw"
fi

if [ -z "$energy_db_scale" ]; then
  energy_db_scale="True"
fi

if [ -z "$f0_subtract_mean" ]; then
  f0_subtract_mean="True"
fi

if [ -z "$similarity_model" ]; then
  similarity_model="wavlm"
fi

if [ -z "$similarity_mode" ]; then
  similarity_mode="pairwith"
fi

if [ -z "$ltr_path" ]; then
  ltr_path="None"
fi

if [ -z "$intelligibility_mode" ]; then
  intelligibility_mode="gt_audio"
fi

if [ -z "$language" ]; then
  language="english"
fi

######## Calculate Objective Metrics ###########
CUDA_VISIBLE_DEVICES=$gpu python "$work_dir"/bins/calc_metrics.py \
  --ref_dir $ref_dir \
  --deg_dir $deg_dir \
  --dump_dir $dump_dir \
  --metrics $metrics \
  --fs $fs \
  --align_method $align_method \
  --db_scale $energy_db_scale \
  --f0_subtract_mean $f0_subtract_mean \
  --similarity_model $similarity_model \
  --similarity_mode $similarity_mode \
  --ltr_path $ltr_path \
  --intelligibility_mode $intelligibility_mode \
  --language $language
Amphion/egs/svc/TransformerSVC/exp_config.json
ADDED
@@ -0,0 +1,108 @@
{
    "base_config": "config/transformer.json",
    "model_type": "TransformerSVC",
    "dataset": [
        "m4singer",
        "opencpop",
        "opensinger",
        "svcc",
        "vctk"
    ],
    "dataset_path": {
        // TODO: Fill in your dataset path
        "m4singer": "[M4Singer dataset path]",
        "opencpop": "[Opencpop dataset path]",
        "opensinger": "[OpenSinger dataset path]",
        "svcc": "[SVCC dataset path]",
        "vctk": "[VCTK dataset path]"
    },
    // TODO: Fill in the output log path. The default value is "Amphion/ckpts/svc"
    "log_dir": "ckpts/svc",
    "preprocess": {
        // TODO: Fill in the output data path. The default value is "Amphion/data"
        "processed_dir": "data",
        // Config for features extraction
        "extract_mel": true,
        "extract_pitch": true,
        "extract_energy": true,
        "extract_whisper_feature": true,
        "extract_contentvec_feature": true,
        "extract_wenet_feature": false,
        "whisper_batch_size": 30, // decrease it if your GPU is out of memory
        "contentvec_batch_size": 1,
        // Fill in the content-based pretrained model's path
        "contentvec_file": "pretrained/contentvec/checkpoint_best_legacy_500.pt",
        "wenet_model_path": "pretrained/wenet/20220506_u2pp_conformer_exp/final.pt",
        "wenet_config": "pretrained/wenet/20220506_u2pp_conformer_exp/train.yaml",
        "whisper_model": "medium",
        "whisper_model_path": "pretrained/whisper/medium.pt",
        // Config for features usage
        "use_mel": true,
        "use_min_max_norm_mel": true,
        "use_frame_pitch": true,
        "use_frame_energy": true,
        "use_spkid": true,
        "use_whisper": true,
        "use_contentvec": true,
        "use_wenet": false,
        "n_mel": 100,
        "sample_rate": 24000
    },
    "model": {
        "condition_encoder": {
            // Config for features usage
            "use_whisper": true,
            "use_contentvec": true,
            "use_wenet": false,
            "whisper_dim": 1024,
            "contentvec_dim": 256,
            "wenet_dim": 512,
            "use_singer_encoder": false,
            "pitch_min": 50,
            "pitch_max": 1100
        },
        "transformer": {
            // 'conformer' or 'transformer'
            "type": "conformer",
            "input_dim": 384,
            "output_dim": 100,
            "n_heads": 2,
            "n_layers": 6,
            "filter_channels": 512,
            "dropout": 0.1
        }
    },
    "train": {
        "batch_size": 64,
        "gradient_accumulation_step": 1,
        "max_epoch": -1, // -1 means no limit
        "save_checkpoint_stride": [
            50,
            50
        ],
        "keep_last": [
            5,
            -1
        ],
        "run_eval": [
            false,
            true
        ],
        "adamw": {
            "lr": 4.0e-4
        },
        "reducelronplateau": {
            "factor": 0.8,
            "patience": 10,
            "min_lr": 1.0e-4
        },
        "dataloader": {
            "num_worker": 8,
            "pin_memory": true
        },
        "sampler": {
            "holistic_shuffle": false,
            "drop_last": true
        }
    }
}
Amphion/egs/svc/VitsSVC/README.md
ADDED
@@ -0,0 +1,125 @@
# VITS for Singing Voice Conversion

This is an implementation of VITS as an acoustic model for end-to-end singing voice conversion. Adapted from [so-vits-svc](https://github.com/svc-develop-team/so-vits-svc), the SoftVC content encoder is used to extract content features from the source audio. These feature vectors are directly fed into VITS without the need for conversion to a text-based intermediate representation.

There are four stages in total:

1. Data preparation
2. Features extraction
3. Training
4. Inference/conversion

> **NOTE:** You need to run every command of this recipe in the `Amphion` root path:
> ```bash
> cd Amphion
> ```

## 1. Data Preparation

### Dataset Download

By default, we utilize five datasets for training: M4Singer, Opencpop, OpenSinger, SVCC, and VCTK. How to download them is detailed [here](../../datasets/README.md).

### Configuration

Specify the dataset paths in `exp_config.json`. Note that you can change the `dataset` list to use your preferred datasets.

```json
    "dataset": [
        "m4singer",
        "opencpop",
        "opensinger",
        "svcc",
        "vctk"
    ],
    "dataset_path": {
        // TODO: Fill in your dataset path
        "m4singer": "[M4Singer dataset path]",
        "opencpop": "[Opencpop dataset path]",
        "opensinger": "[OpenSinger dataset path]",
        "svcc": "[SVCC dataset path]",
        "vctk": "[VCTK dataset path]"
    },
```

## 2. Features Extraction

### Content-based Pretrained Models Download

By default, we utilize ContentVec and Whisper to extract content features. How to download them is detailed [here](../../../pretrained/README.md).

### Configuration

Specify the dataset path and the output path for saving the processed data and the training model in `exp_config.json`:

```json
    // TODO: Fill in the output log path. The default value is "Amphion/ckpts/svc"
    "log_dir": "ckpts/svc",
    "preprocess": {
        // TODO: Fill in the output data path. The default value is "Amphion/data"
        "processed_dir": "data",
        ...
    },
```

### Run

Run the `run.sh` as the preprocess stage (set `--stage 1`).

```bash
sh egs/svc/VitsSVC/run.sh --stage 1
```

> **NOTE:** The `CUDA_VISIBLE_DEVICES` is set as `"0"` by default. You can change it when running `run.sh` by specifying, for example, `--gpu "1"`.

## 3. Training

### Configuration

We provide the default hyperparameters in `exp_config.json`. They can work on a single NVIDIA 24GB GPU. You can adjust them based on your GPU machines.

```json
    "train": {
        "batch_size": 32,
        ...
        "adamw": {
            "lr": 2.0e-4
        },
        ...
    }
```

### Run

Run the `run.sh` as the training stage (set `--stage 2`). Specify an experimental name to run the following command. The tensorboard logs and checkpoints will be saved in `Amphion/ckpts/svc/[YourExptName]`.

```bash
sh egs/svc/VitsSVC/run.sh --stage 2 --name [YourExptName]
```

> **NOTE:** The `CUDA_VISIBLE_DEVICES` is set as `"0"` by default. You can change it when running `run.sh` by specifying, for example, `--gpu "0,1,2,3"`.

## 4. Inference/Conversion

### Run

For inference/conversion, you need to specify the following configurations when running `run.sh`:

| Parameters                                          | Description                                                                                                                                                        | Example                                                                                                                                                                                                  |
| --------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `--infer_expt_dir`                                  | The experimental directory which contains `checkpoint`                                                                                                             | `[Your path to save logs and checkpoints]/[YourExptName]`                                                                                                                                                |
| `--infer_output_dir`                                | The output directory to save inferred audios.                                                                                                                      | `[Your path to save logs and checkpoints]/[YourExptName]/result`                                                                                                                                         |
| `--infer_source_file` or `--infer_source_audio_dir` | The inference source (can be a json file or a dir).                                                                                                                | The `infer_source_file` could be `[Your path to save processed data]/[YourDataset]/test.json`, and the `infer_source_audio_dir` is a folder which includes several audio files (*.wav, *.mp3 or *.flac). |
| `--infer_target_speaker`                            | The target speaker you want to convert into. You can refer to `[Your path to save logs and checkpoints]/[YourExptName]/singers.json` to choose a trained speaker. | For the Opencpop dataset, the speaker name would be `opencpop_female1`.                                                                                                                                  |
| `--infer_key_shift`                                 | How many semitones you want to transpose.                                                                                                                          | `"autoshift"` (by default), `3`, `-3`, etc.                                                                                                                                                              |

For example, if you want to make `opencpop_female1` sing the songs in `[Your Audios Folder]`, just run:

```bash
sh egs/svc/VitsSVC/run.sh --stage 3 --gpu "0" \
	--infer_expt_dir Amphion/ckpts/svc/[YourExptName] \
	--infer_output_dir Amphion/ckpts/svc/[YourExptName]/result \
	--infer_source_audio_dir [Your Audios Folder] \
	--infer_target_speaker "opencpop_female1" \
	--infer_key_shift "autoshift"
```
Amphion/egs/tta/README.md
ADDED
@@ -0,0 +1,19 @@
# Amphion Text-to-Audio (TTA) Recipe

## Quick Start

We provide a **[beginner recipe](RECIPE.md)** to demonstrate how to train a cutting-edge TTA model. Specifically, it is designed as a latent diffusion model like [AudioLDM](https://arxiv.org/abs/2301.12503), [Make-an-Audio](https://arxiv.org/abs/2301.12661), and [AUDIT](https://arxiv.org/abs/2304.00830).

## Supported Model Architectures

Until now, Amphion has supported a latent diffusion based text-to-audio model:

<br>
<div align="center">
  <img src="../../imgs/tta/DiffusionTTA.png" width="65%">
</div>
<br>

Similar to [AUDIT](https://arxiv.org/abs/2304.00830), we implement it in two-stage training:
1. Training the VAE, which is called `AutoencoderKL` in Amphion.
2. Training the conditional latent diffusion model, which is called `AudioLDM` in Amphion.
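The two stages map onto the run scripts added later in this diff. A hypothetical training order is sketched below; it assumes the stage-2 config's `autoencoder_path` points at the checkpoint produced by stage 1, and that the `latent_4_10_78` variants are the ones you intend to pair:

```bash
cd Amphion
# Stage 1: train the VAE (AutoencoderKL) on mel spectrograms
sh egs/tta/autoencoderkl/run_train_latent_4_10_78.sh
# Stage 2: train the conditional latent diffusion model (AudioLDM) in the VAE latent space
sh egs/tta/audioldm/run_train_latent_4_10_78.sh
```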
Amphion/egs/tta/audioldm/exp_config.json
ADDED
@@ -0,0 +1,90 @@
{
    "base_config": "egs/tta/audioldm/exp_config_base.json",
    "dataset": [
        "AudioCaps"
    ],
    "preprocess": {
        // Specify the output root path to save the processed data
        "processed_dir": "data",
        // For example: "/home/TTADataset/processed_data"

        // feature
        "use_spkid": false,
        "use_uv": false,
        "use_frame_pitch": false,
        "use_phone_pitch": false,
        "use_frame_energy": false,
        "use_phone_energy": false,
        "use_mel": false,
        "use_audio": false,
        "use_label": false,
        "use_one_hot": false,
        // feature for text to audio
        "use_caption": true,
        "use_melspec": true,
        "use_wav": false,
        // feature dir
        "melspec_dir": "mel",
        "wav_dir": "wav"
    },
    // Specify the output root path to save model ckpts and logs
    "log_dir": "ckpts/tta",
    // For example: "/home/TTADataset/processed_data/logs"

    // model
    "model": {
        "audioldm": {
            "image_size": 32,
            "in_channels": 4,
            "out_channels": 4,
            "model_channels": 256,
            "attention_resolutions": [4, 2, 1],
            "num_res_blocks": 2,
            "channel_mult": [1, 2, 4],
            "num_heads": 8,
            "use_spatial_transformer": true,
            "transformer_depth": 1,
            "context_dim": 768,
            "use_checkpoint": true,
            "legacy": false
        },
        "autoencoderkl": {
            "ch": 128,
            "ch_mult": [1,1,2,2,4],
            "num_res_blocks": 2,
            "in_channels": 1,
            "z_channels": 4,
            "out_ch": 1,
            "double_z": true
        },
        "noise_scheduler": {
            "num_train_timesteps": 1000,
            "beta_start": 0.00085,
            "beta_end": 0.012,
            "beta_schedule": "scaled_linear",
            "clip_sample": false,
            "steps_offset": 1,
            "set_alpha_to_one": false,
            "skip_prk_steps": true,
            "prediction_type": "epsilon"
        },
        "autoencoder_path": "ckpts/tta/autoencoder_kl_debug/checkpoints/step-0445000_loss-0.3306.pt"
    },

    // train
    "train": {
        "adam": {
            "lr": 5.0e-5
        },
        "ddp": false,
        "random_seed": 12345,
        "batch_size": 12,
        "epochs": 50000,
        "max_steps": 1000000,
        "total_training_steps": 800000,
        "save_summary_steps": 1000,
        "save_checkpoints_steps": 5000,
        "valid_interval": 5000,
        "keep_checkpoint_max": 100
    }
}
Amphion/egs/tta/audioldm/run_train.sh
ADDED
@@ -0,0 +1,26 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

######## Build Experiment Environment ###########
exp_dir=$(cd `dirname $0`; pwd)
work_dir=$(dirname $(dirname $(dirname $exp_dir)))

export WORK_DIR=$work_dir
export PYTHONPATH=$work_dir
export PYTHONIOENCODING=UTF-8

######## Set Experiment Configuration ###########
exp_config="$exp_dir/exp_config.json"
exp_name="audioldm_debug_latent_size_4_5_39"

num_workers=8
export CUDA_VISIBLE_DEVICES="0"

######## Train Model ###########
python "${work_dir}"/bins/tta/train_tta.py \
    --config=$exp_config \
    --num_workers=$num_workers \
    --exp_name=$exp_name \
    --stdout_interval=25 \
Amphion/egs/tta/audioldm/run_train_latent_4_10_78.sh
ADDED
@@ -0,0 +1,26 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

######## Build Experiment Environment ###########
exp_dir=$(cd `dirname $0`; pwd)
work_dir=$(dirname $(dirname $(dirname $exp_dir)))

export WORK_DIR=$work_dir
export PYTHONPATH=$work_dir
export PYTHONIOENCODING=UTF-8

######## Set Experiment Configuration ###########
exp_config="$exp_dir/exp_config_latent_4_10_78.json"
exp_name="audioldm_debug_latent_size_4_10_78"

num_workers=8
export CUDA_VISIBLE_DEVICES="0"

######## Train Model ###########
python "${work_dir}"/bins/tta/train_tta.py \
    --config=$exp_config \
    --num_workers=$num_workers \
    --exp_name=$exp_name \
    --stdout_interval=25 \
Amphion/egs/tta/autoencoderkl/run_train_latent_4_10_78.sh
ADDED
@@ -0,0 +1,26 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

######## Build Experiment Environment ###########
exp_dir=$(cd `dirname $0`; pwd)
work_dir=$(dirname $(dirname $(dirname $exp_dir)))

export WORK_DIR=$work_dir
export PYTHONPATH=$work_dir
export PYTHONIOENCODING=UTF-8

######## Set Experiment Configuration ###########
exp_config="$exp_dir/exp_config_latent_4_10_78.json"
exp_name="autoencoder_kl_debug_latent_size_4_10_78"

num_workers=8
export CUDA_VISIBLE_DEVICES="0"

######## Train Model ###########
python "${work_dir}"/bins/tta/train_tta.py \
    --config=$exp_config \
    --num_workers=$num_workers \
    --exp_name=$exp_name \
    --stdout_interval=25 \
Amphion/egs/tts/FastSpeech2/prepare_mfa.sh
ADDED
@@ -0,0 +1,29 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

#!/bin/bash

# Navigate to the 'pretrained' directory
cd pretrained || { echo "Failed to change directory to 'pretrained'"; exit 1; }

# Create and navigate to the 'mfa' directory
mkdir -p mfa && cd mfa || { echo "Failed to create or change directory to 'mfa'"; exit 1; }

# Define the MFA file URL and the file name
mfa_url="https://github.com/MontrealCorpusTools/Montreal-Forced-Aligner/releases/download/v1.1.0-beta.2/montreal-forced-aligner_linux.tar.gz"
mfa_file="montreal-forced-aligner_linux.tar.gz"

# Download MFA if it doesn't exist
if [ ! -f "$mfa_file" ]; then
    wget "$mfa_url" || { echo "Failed to download MFA"; exit 1; }
fi

# Extract MFA
tar -zxvf "$mfa_file" || { echo "Failed to extract MFA"; exit 1; }

# Optionally, remove the tar.gz file after extraction
rm "$mfa_file"

echo "MFA setup completed successfully."
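This script is normally triggered by the FastSpeech2 `run.sh` at stage 1 (see the next file), which calls it when `pretrained/mfa/montreal-forced-aligner` is missing. If you run it by hand, it expects to be launched from the Amphion root so that the relative `pretrained` directory resolves correctly; a hypothetical manual invocation:

```bash
cd Amphion
bash egs/tts/FastSpeech2/prepare_mfa.sh
```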
Amphion/egs/tts/FastSpeech2/run.sh
ADDED
@@ -0,0 +1,155 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

######## Build Experiment Environment ###########
exp_dir=$(cd `dirname $0`; pwd)
work_dir=$(dirname $(dirname $(dirname $exp_dir)))

export WORK_DIR=$work_dir
export PYTHONPATH=$work_dir
export PYTHONIOENCODING=UTF-8

cd $work_dir/modules/monotonic_align
mkdir -p monotonic_align
python setup.py build_ext --inplace
cd $work_dir

mfa_dir=$work_dir/pretrained/mfa
echo $mfa_dir

######## Parse the Given Parameters from the Command ###########
# options=$(getopt -o c:n:s --long gpu:,config:,infer_expt_dir:,infer_output_dir:,infer_source_file:,infer_source_audio_dir:,infer_target_speaker:,infer_key_shift:,infer_vocoder_dir:,name:,stage: -- "$@")
options=$(getopt -o c:n:s --long gpu:,config:,infer_expt_dir:,infer_output_dir:,infer_mode:,infer_dataset:,infer_testing_set:,infer_text:,name:,stage:,vocoder_dir: -- "$@")
eval set -- "$options"

while true; do
  case $1 in
    # Experimental Configuration File
    -c | --config) shift; exp_config=$1 ; shift ;;
    # Experimental Name
    -n | --name) shift; exp_name=$1 ; shift ;;
    # Running Stage
    -s | --stage) shift; running_stage=$1 ; shift ;;
    # Visible GPU machines. The default value is "0".
    --gpu) shift; gpu=$1 ; shift ;;

    # [Only for Inference] The experiment dir. The value is like "[Your path to save logs and checkpoints]/[YourExptName]"
    --infer_expt_dir) shift; infer_expt_dir=$1 ; shift ;;
    # [Only for Inference] The output dir to save inferred audios. Its default value is "$infer_expt_dir/result"
    --infer_output_dir) shift; infer_output_dir=$1 ; shift ;;
    # [Only for Inference] The inference mode. It can be "batch" to generate speech by batch, or "single" to generate a single clip of speech.
    --infer_mode) shift; infer_mode=$1 ; shift ;;
    # [Only for Inference] The inference dataset. It is only used when the inference mode is "batch".
    --infer_dataset) shift; infer_dataset=$1 ; shift ;;
    # [Only for Inference] The inference testing set. It is only used when the inference mode is "batch". It can be the "test" set split from the dataset, or the "golden_test" carefully selected from the testing set.
    --infer_testing_set) shift; infer_testing_set=$1 ; shift ;;
    # [Only for Inference] The text to be synthesized from. It is only used when the inference mode is "single".
    --infer_text) shift; infer_text=$1 ; shift ;;
    # [Only for Inference] The dir of the vocoder.
    --vocoder_dir) shift; vocoder_dir=$1 ; shift ;;

    --) shift ; break ;;
    *) echo "Invalid option: $1"; exit 1 ;;
  esac
done


### Value check ###
if [ -z "$running_stage" ]; then
  echo "[Error] Please specify the running stage"
  exit 1
fi

if [ -z "$exp_config" ]; then
  exp_config="${exp_dir}"/exp_config.json
fi
echo "Experimental Configuration File: $exp_config"

if [ -z "$gpu" ]; then
  gpu="0"
fi

######## Features Extraction ###########
if [ $running_stage -eq 1 ]; then
  if [ ! -d "$mfa_dir/montreal-forced-aligner" ]; then
    bash ${exp_dir}/prepare_mfa.sh
  fi
  CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/bins/tts/preprocess.py \
    --config=$exp_config \
    --num_workers=4 \
    --prepare_alignment=true
fi

######## Training ###########
if [ $running_stage -eq 2 ]; then
  if [ -z "$exp_name" ]; then
    echo "[Error] Please specify the experiment name"
    exit 1
  fi
  echo "Experimental Name: $exp_name"

  CUDA_VISIBLE_DEVICES=$gpu accelerate launch "${work_dir}"/bins/tts/train.py \
    --config $exp_config \
    --exp_name $exp_name \
    --log_level debug
fi

######## Inference ###########
if [ $running_stage -eq 3 ]; then
  if [ -z "$infer_expt_dir" ]; then
    echo "[Error] Please specify the experimental directory. The value is like [Your path to save logs and checkpoints]/[YourExptName]"
    exit 1
  fi

  if [ -z "$infer_output_dir" ]; then
    infer_output_dir="$infer_expt_dir/result"
  fi

  if [ -z "$vocoder_dir" ]; then
    echo "[Error] Please specify the vocoder directory to reconstruct waveform from mel spectrogram."
    exit 1
  fi

  if [ -z "$infer_mode" ]; then
    echo "[Error] Please specify the inference mode, e.g., \"batch\", \"single\""
    exit 1
  fi

  if [ "$infer_mode" = "batch" ] && [ -z "$infer_dataset" ]; then
    echo "[Error] Please specify the dataset used in inference when the inference mode is batch"
    exit 1
  fi

  if [ "$infer_mode" = "batch" ] && [ -z "$infer_testing_set" ]; then
    echo "[Error] Please specify the testing set used in inference when the inference mode is batch"
    exit 1
  fi

  if [ "$infer_mode" = "single" ] && [ -z "$infer_text" ]; then
    echo "[Error] Please specify the text to be synthesized when the inference mode is single"
    exit 1
  fi

  if [ "$infer_mode" = "single" ]; then
    echo 'Text: ' ${infer_text}
    infer_dataset=None
    infer_testing_set=None
  elif [ "$infer_mode" = "batch" ]; then
    infer_text=''
  fi


  CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/tts/inference.py \
    --config $exp_config \
    --acoustics_dir $infer_expt_dir \
    --output_dir $infer_output_dir \
    --mode $infer_mode \
    --dataset $infer_dataset \
    --testing_set $infer_testing_set \
    --text "$infer_text" \
    --log_level debug \
    --vocoder_dir $vocoder_dir

fi
Amphion/egs/tts/NaturalSpeech2/run_inference.sh
ADDED
@@ -0,0 +1,49 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


######## Build Experiment Environment ###########
exp_dir=$(cd `dirname $0`; pwd)
work_dir=$(dirname $(dirname $(dirname $exp_dir)))

export WORK_DIR=$work_dir
export PYTHONPATH=$work_dir
export PYTHONIOENCODING=UTF-8

######## Set Experiment Configuration ###########
exp_config="$exp_dir/exp_config.json"
exp_name="ns2_libritts"
ref_audio="$work_dir/egs/tts/NaturalSpeech2/prompt_example/ref_audio.wav"
checkpoint_path="$work_dir/ckpts/tts/naturalspeech2_libritts/checkpoint/epoch-0089_step-0512912_loss-6.367693"
output_dir="$work_dir/output"
mode="single"

export CUDA_VISIBLE_DEVICES="0"

######## Parse Command Line Arguments ###########
while [[ $# -gt 0 ]]
do
  key="$1"

  case $key in
    --text)
      text="$2"
      shift # past argument
      shift # past value
      ;;
    *)  # unknown option
      shift # past argument
      ;;
  esac
done

######## Run Inference ###########
python "${work_dir}"/bins/tts/inference.py \
    --config=$exp_config \
    --text="$text" \
    --mode=$mode \
    --checkpoint_path=$checkpoint_path \
    --ref_audio=$ref_audio \
    --output_dir=$output_dir \
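Since the script hard-codes the config, checkpoint, prompt audio, and output paths and only parses `--text`, a typical call (assuming those hard-coded paths already exist) looks like:

```bash
cd Amphion
sh egs/tts/NaturalSpeech2/run_inference.sh --text "Amphion is an open-source audio generation toolkit."
```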
Amphion/egs/tts/VALLE/README.md
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# VALL-E Recipe
|
| 2 |
+
|
| 3 |
+
In this recipe, we will show how to train [VALL-E](https://arxiv.org/abs/2301.02111) using Amphion's infrastructure. VALL-E is a zero-shot TTS architecture that uses a neural codec language model with discrete codes.
|
| 4 |
+
|
| 5 |
+
There are four stages in total:
|
| 6 |
+
|
| 7 |
+
1. Data preparation
|
| 8 |
+
2. Features extraction
|
| 9 |
+
3. Training
|
| 10 |
+
4. Inference
|
| 11 |
+
|
| 12 |
+
> **NOTE:** You need to run every command of this recipe in the `Amphion` root path:
|
| 13 |
+
> ```bash
|
| 14 |
+
> cd Amphion
|
| 15 |
+
> ```
|
| 16 |
+
|
| 17 |
+
## 1. Data Preparation
|
| 18 |
+
|
| 19 |
+
### Dataset Download
|
| 20 |
+
You can use the commonly used TTS dataset to train the VALL-E model, e.g., LibriTTS, etc. We strongly recommend you use LibriTTS to train the VALL-E model for the first time. How to download the dataset is detailed [here](../../datasets/README.md).
|
| 21 |
+
|
| 22 |
+
### Configuration
|
| 23 |
+
|
| 24 |
+
After downloading the dataset, you can set the dataset paths in `exp_config.json`. Note that you can change the `dataset` list to use your preferred datasets.
|
| 25 |
+
|
| 26 |
+
```json
|
| 27 |
+
"dataset": [
|
| 28 |
+
"libritts",
|
| 29 |
+
],
|
| 30 |
+
"dataset_path": {
|
| 31 |
+
// TODO: Fill in your dataset path
|
| 32 |
+
"libritts": "[LibriTTS dataset path]",
|
| 33 |
+
},
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
## 2. Features Extraction
|
| 37 |
+
|
| 38 |
+
### Configuration
|
| 39 |
+
|
| 40 |
+
Specify the `processed_dir` and the `log_dir` and for saving the processed data and the checkpoints in `exp_config.json`:
|
| 41 |
+
|
| 42 |
+
```json
|
| 43 |
+
// TODO: Fill in the output log path. The default value is "Amphion/ckpts/tts"
|
| 44 |
+
"log_dir": "ckpts/tts",
|
| 45 |
+
"preprocess": {
|
| 46 |
+
// TODO: Fill in the output data path. The default value is "Amphion/data"
|
| 47 |
+
"processed_dir": "data",
|
| 48 |
+
...
|
| 49 |
+
},
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
### Run
|
| 53 |
+
|
| 54 |
+
Run the `run.sh` as the preprocess stage (set `--stage 1`):
|
| 55 |
+
|
| 56 |
+
```bash
|
| 57 |
+
sh egs/tts/VALLE/run.sh --stage 1
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
> **NOTE:** The `CUDA_VISIBLE_DEVICES` is set as `"0"` in default. You can change it when running `run.sh` by specifying such as `--gpu "1"`.
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
## 3. Training
|
| 64 |
+
|
| 65 |
+
### Configuration
|
| 66 |
+
|
| 67 |
+
We provide the default hyperparameters in the `exp_config.json`. They can work on a single NVIDIA-24g GPU. You can adjust them based on your GPU machines.
|
| 68 |
+
|
| 69 |
+
```json
|
| 70 |
+
"train": {
|
| 71 |
+
"batch_size": 4,
|
| 72 |
+
}
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
### Train From Scratch
|
| 76 |
+
|
| 77 |
+
Run the `run.sh` as the training stage (set `--stage 2`). Specify an experimental name to run the following command. The tensorboard logs and checkpoints will be saved in `Amphion/ckpts/tts/[YourExptName]`.
|
| 78 |
+
|
| 79 |
+
Specifically, VALL-E needs to train an autoregressive (AR) model and then a non-autoregressive (NAR) model. So, you can set `--model_train_stage 1` to train AR model, and set `--model_train_stage 2` to train NAR model, where `--ar_model_ckpt_dir` should be set as the checkpoint path to the trained AR model.
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
Train an AR model, just run:
|
| 83 |
+
|
| 84 |
+
```bash
|
| 85 |
+
sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 1 --name [YourExptName]
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
Train a NAR model, just run:
|
| 89 |
+
```bash
|
| 90 |
+
sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 2 --ar_model_ckpt_dir [ARModelPath] --name [YourExptName]
|
| 91 |
+
```
|
| 92 |
+
<!-- > **NOTE:** To train a NAR model, `--checkpoint_path` should be set as the checkpoint path to the trained AR model. -->
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
### Train From Existing Source
|
| 96 |
+
|
| 97 |
+
We support training from existing sources for various purposes. You can resume training the model from a checkpoint or fine-tune a model from another checkpoint.
|
| 98 |
+
|
| 99 |
+
By setting `--resume true`, the training will resume from the **latest checkpoint** from the current `[YourExptName]` by default. For example, if you want to resume training from the latest checkpoint in `Amphion/ckpts/tts/[YourExptName]/checkpoint`,
|
| 100 |
+
|
| 101 |
+
To train an AR model, just run:
|
| 102 |
+
|
| 103 |
+
```bash
|
| 104 |
+
sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 1 --name [YourExptName] \
|
| 105 |
+
--resume true
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
To train a NAR model, just run:
|
| 109 |
+
```bash
|
| 110 |
+
sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 2 --ar_model_ckpt_dir [ARModelPath] --name [YourExptName] \
|
| 111 |
+
--resume true
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
You can also choose a **specific checkpoint** to resume from with the `--resume_from_ckpt_path` argument. For example, if you want to resume training from the checkpoint `Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificCheckpoint]`,
|
| 117 |
+
|
| 118 |
+
To train an AR model, just run:
|
| 119 |
+
|
| 120 |
+
```bash
|
| 121 |
+
sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 1 --name [YourExptName] \
|
| 122 |
+
--resume true \
|
| 123 |
+
--resume_from_ckpt_path "Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificARCheckpoint]"
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
To train a NAR model, just run:
|
| 127 |
+
```bash
|
| 128 |
+
sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 2 --ar_model_ckpt_dir [ARModelPath] --name [YourExptName] \
|
| 129 |
+
--resume true \
|
| 130 |
+
--resume_from_ckpt_path "Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificNARCheckpoint]"
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
If you want to **fine-tune from another checkpoint**, just use `--resume_type` and set it to `"finetune"`. For example, if you want to fine-tune the model from the checkpoint `Amphion/ckpts/tts/[AnotherExperiment]/checkpoint/[SpecificCheckpoint]`,
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
To train an AR model, just run:
|
| 138 |
+
|
| 139 |
+
```bash
|
| 140 |
+
sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 1 --name [YourExptName] \
|
| 141 |
+
--resume true \
|
| 142 |
+
--resume_from_ckpt_path "Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificARCheckpoint]" \
|
| 143 |
+
--resume_type "finetune"
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
To train a NAR model, just run:
|
| 147 |
+
```bash
|
| 148 |
+
sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 2 --ar_model_ckpt_dir [ARModelPath] --name [YourExptName] \
|
| 149 |
+
--resume true \
|
| 150 |
+
--resume_from_ckpt_path "Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificNARCheckpoint]" \
|
| 151 |
+
--resume_type "finetune"
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
> **NOTE:** `--resume_type` is set to `"resume"` by default, so it's not necessary to specify it when resuming training.
|
| 155 |
+
>
|
| 156 |
+
> The difference between `"resume"` and `"finetune"` is that the `"finetune"` will **only** load the pretrained model weights from the checkpoint, while the `"resume"` will load all the training states (including optimizer, scheduler, etc.) from the checkpoint.
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
> **NOTE:** `CUDA_VISIBLE_DEVICES` is set to `"0"` by default. You can change it when running `run.sh` by specifying, for example, `--gpu "0,1,2,3"`.
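For instance, here is a minimal sketch of launching the AR training stage on four GPUs; it is simply the AR training command from above with the `--gpu` option added:

```bash
# A minimal sketch: the same AR training command as above, but run on GPUs 0-3.
sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 1 --name [YourExptName] \
    --gpu "0,1,2,3"
```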
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
## 4. Inference
|
| 165 |
+
|
| 166 |
+
### Configuration
|
| 167 |
+
|
| 168 |
+
For inference, you need to specify the following configurations when running `run.sh`:
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
| Parameters | Description | Example |
|
| 173 |
+
| --------------------- | -------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| 174 |
+
| `--infer_expt_dir` | The experiment directory of the NAR model, which contains `checkpoint`. | `Amphion/ckpts/tts/[YourExptName]` |
|
| 175 |
+
| `--infer_output_dir` | The output directory to save inferred audios. | `Amphion/ckpts/tts/[YourExptName]/result` |
|
| 176 |
+
| `--infer_mode` | The inference mode, e.g., "`single`", "`batch`". | "`single`" to generate a clip of speech, "`batch`" to generate a batch of speech at a time. |
|
| 177 |
+
| `--infer_text` | The text to be synthesized. | "`This is a clip of generated speech with the given text from a TTS model.`" |
|
| 178 |
+
| `--infer_text_prompt` | The text prompt for inference. | The text prompt should be aligned with the audio prompt. |
|
| 179 |
+
| `--infer_audio_prompt` | The audio prompt for inference. | The audio prompt should be aligned with the text prompt. |
|
| 180 |
+
| `--test_list_file` | The test list file used for batch inference. | The format of the test list file is `text\|text_prompt\|audio_prompt` (see the batch example below). |
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
### Run
|
| 184 |
+
For example, if you want to generate a single clip of speech, just run:
|
| 185 |
+
|
| 186 |
+
```bash
|
| 187 |
+
sh egs/tts/VALLE/run.sh --stage 3 --gpu "0" \
|
| 188 |
+
--infer_expt_dir Amphion/ckpts/tts/[YourExptName] \
|
| 189 |
+
--infer_output_dir Amphion/ckpts/tts/[YourExptName]/result \
|
| 190 |
+
--infer_mode "single" \
|
| 191 |
+
--infer_text "This is a clip of generated speech with the given text from a TTS model." \
|
| 192 |
+
--infer_text_prompt "But even the unsuccessful dramatist has his moments." \
|
| 193 |
+
--infer_audio_prompt egs/tts/VALLE/prompt_examples/7176_92135_000004_000000.wav
|
| 194 |
+
```
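For batch inference (`--infer_mode "batch"`), the texts and prompts are read from the file passed via `--test_list_file`. Below is a minimal sketch; the file name `egs/tts/VALLE/test_list.txt` and the assumption of one utterance per line are only illustrative, while the `text|text_prompt|audio_prompt` format follows the parameter table above:

```bash
# Illustrative test list file (e.g., egs/tts/VALLE/test_list.txt), assuming one
# utterance per line in the documented `text|text_prompt|audio_prompt` format:
#   This is the first generated sentence.|But even the unsuccessful dramatist has his moments.|egs/tts/VALLE/prompt_examples/7176_92135_000004_000000.wav
#   This is the second generated sentence.|But even the unsuccessful dramatist has his moments.|egs/tts/VALLE/prompt_examples/7176_92135_000004_000000.wav

sh egs/tts/VALLE/run.sh --stage 3 --gpu "0" \
    --infer_expt_dir Amphion/ckpts/tts/[YourExptName] \
    --infer_output_dir Amphion/ckpts/tts/[YourExptName]/result \
    --infer_mode "batch" \
    --test_list_file egs/tts/VALLE/test_list.txt
```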
|
| 195 |
+
|
| 196 |
+
We have released pre-trained VALL-E models, so you can download a pre-trained model and then generate speech following the above inference instructions. Specifically,
|
| 197 |
+
1. The pre-trained VALL-E trained on [LibriTTS](https://github.com/open-mmlab/Amphion/tree/main/egs/datasets#libritts) can be downloaded [here](https://huggingface.co/amphion/valle-libritts).
|
| 198 |
+
2. The pre-trained VALL-E trained on the part of [Libri-light](https://ai.meta.com/tools/libri-light/) (about 6k hours) can be downloaded [here](https://huggingface.co/amphion/valle_librilight_6k).
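As a sketch of how to fetch one of these checkpoints locally (assuming `git-lfs` is installed; the target directory names are only illustrative), you can clone the Hugging Face repositories and then point `--infer_expt_dir` to the downloaded directory when running the inference command above:

```bash
# A minimal sketch, assuming git-lfs is installed. Local directory names are illustrative.
git lfs install

# VALL-E trained on LibriTTS
git clone https://huggingface.co/amphion/valle-libritts ckpts/tts/valle-libritts

# VALL-E trained on ~6k hours of Libri-light
git clone https://huggingface.co/amphion/valle_librilight_6k ckpts/tts/valle_librilight_6k
```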
|
| 199 |
+
|
| 200 |
+
```bibtex
|
| 201 |
+
@article{wang2023neural,
|
| 202 |
+
title={Neural codec language models are zero-shot text to speech synthesizers},
|
| 203 |
+
author={Wang, Chengyi and Chen, Sanyuan and Wu, Yu and Zhang, Ziqiang and Zhou, Long and Liu, Shujie and Chen, Zhuo and Liu, Yanqing and Wang, Huaming and Li, Jinyu and others},
|
| 204 |
+
journal={arXiv preprint arXiv:2301.02111},
|
| 205 |
+
year={2023}
|
| 206 |
+
}
|
| 207 |
+
```
|
Amphion/egs/tts/VALLE/prompt_examples/5142_33396_000002_000004.wav
ADDED
|
Binary file (144 kB).
|
|
|
Amphion/egs/tts/VALLE/prompt_examples/7176_92135_000004_000000.normalized.txt
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
But even the unsuccessful dramatist has his moments.
|
Amphion/egs/tts/VITS/exp_config.json
ADDED
|
@@ -0,0 +1,34 @@
|
| 1 |
+
{
|
| 2 |
+
"base_config": "config/vits.json",
|
| 3 |
+
"model_type": "VITS",
|
| 4 |
+
"dataset": [
|
| 5 |
+
"LJSpeech",
|
| 6 |
+
//"hifitts"
|
| 7 |
+
],
|
| 8 |
+
"dataset_path": {
|
| 9 |
+
// TODO: Fill in your dataset path
|
| 10 |
+
"LJSpeech": "[LJSpeech dataset path]",
|
| 11 |
+
//"hifitts": "[Hi-Fi TTS dataset path]",
|
| 12 |
+
},
|
| 13 |
+
// TODO: Fill in the output log path. The default value is "Amphion/ckpts/tts"
|
| 14 |
+
"log_dir": "ckpts/tts",
|
| 15 |
+
"preprocess": {
|
| 16 |
+
//"extract_audio":true,
|
| 17 |
+
"use_phone": true,
|
| 18 |
+
// linguistic features
|
| 19 |
+
"extract_phone": true,
|
| 20 |
+
"phone_extractor": "espeak", // "espeak, pypinyin, pypinyin_initials_finals, lexicon (only for language=en-us right now)"
|
| 21 |
+
// TODO: Fill in the output data path. The default value is "Amphion/data"
|
| 22 |
+
"processed_dir": "data",
|
| 23 |
+
"sample_rate": 22050, // target sampling rate
|
| 24 |
+
"valid_file": "valid.json", // validation set
|
| 25 |
+
//"use_spkid": true // use speaker ID to train multi-speaker TTS model
|
| 26 |
+
},
|
| 27 |
+
"model":{
|
| 28 |
+
//"n_speakers": 10 // number of speakers, greater than or equal to the number of speakers in the dataset(s) used. The default value is 0 if not specified.
|
| 29 |
+
},
|
| 30 |
+
"train": {
|
| 31 |
+
"batch_size": 16,
|
| 32 |
+
//"multi_speaker_training": true
|
| 33 |
+
}
|
| 34 |
+
}
|
Amphion/egs/vocoder/gan/bigvgan_large/exp_config.json
ADDED
|
@@ -0,0 +1,70 @@
|
| 1 |
+
{
|
| 2 |
+
"base_config": "egs/vocoder/gan/exp_config_base.json",
|
| 3 |
+
"preprocess": {
|
| 4 |
+
// acoustic features
|
| 5 |
+
"extract_mel": true,
|
| 6 |
+
"extract_audio": true,
|
| 7 |
+
|
| 8 |
+
// Features used for model training
|
| 9 |
+
"use_mel": true,
|
| 10 |
+
"use_audio": true
|
| 11 |
+
},
|
| 12 |
+
"model": {
|
| 13 |
+
"generator": "bigvgan",
|
| 14 |
+
"bigvgan": {
|
| 15 |
+
"resblock": "1",
|
| 16 |
+
"activation": "snakebeta",
|
| 17 |
+
"snake_logscale": true,
|
| 18 |
+
"upsample_rates": [
|
| 19 |
+
4,
|
| 20 |
+
4,
|
| 21 |
+
2,
|
| 22 |
+
2,
|
| 23 |
+
2,
|
| 24 |
+
2
|
| 25 |
+
],
|
| 26 |
+
"upsample_kernel_sizes": [
|
| 27 |
+
8,
|
| 28 |
+
8,
|
| 29 |
+
4,
|
| 30 |
+
4,
|
| 31 |
+
4,
|
| 32 |
+
4
|
| 33 |
+
],
|
| 34 |
+
"upsample_initial_channel": 1536,
|
| 35 |
+
"resblock_kernel_sizes": [
|
| 36 |
+
3,
|
| 37 |
+
7,
|
| 38 |
+
11
|
| 39 |
+
],
|
| 40 |
+
"resblock_dilation_sizes": [
|
| 41 |
+
[
|
| 42 |
+
1,
|
| 43 |
+
3,
|
| 44 |
+
5
|
| 45 |
+
],
|
| 46 |
+
[
|
| 47 |
+
1,
|
| 48 |
+
3,
|
| 49 |
+
5
|
| 50 |
+
],
|
| 51 |
+
[
|
| 52 |
+
1,
|
| 53 |
+
3,
|
| 54 |
+
5
|
| 55 |
+
]
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
},
|
| 59 |
+
"train": {
|
| 60 |
+
"criterions": [
|
| 61 |
+
"feature",
|
| 62 |
+
"discriminator",
|
| 63 |
+
"generator",
|
| 64 |
+
"mel",
|
| 65 |
+
]
|
| 66 |
+
},
|
| 67 |
+
"inference": {
|
| 68 |
+
"batch_size": 1,
|
| 69 |
+
}
|
| 70 |
+
}
|
Amphion/egs/vocoder/gan/hifigan/exp_config.json
ADDED
|
@@ -0,0 +1,59 @@
|
| 1 |
+
{
|
| 2 |
+
"base_config": "egs/vocoder/gan/exp_config_base.json",
|
| 3 |
+
"preprocess": {
|
| 4 |
+
// acoustic features
|
| 5 |
+
"extract_mel": true,
|
| 6 |
+
"extract_audio": true,
|
| 7 |
+
|
| 8 |
+
// Features used for model training
|
| 9 |
+
"use_mel": true,
|
| 10 |
+
"use_audio": true
|
| 11 |
+
},
|
| 12 |
+
"model": {
|
| 13 |
+
"generator": "hifigan",
|
| 14 |
+
"hifigan": {
|
| 15 |
+
"resblock": "2",
|
| 16 |
+
"upsample_rates": [
|
| 17 |
+
8,
|
| 18 |
+
8,
|
| 19 |
+
4
|
| 20 |
+
],
|
| 21 |
+
"upsample_kernel_sizes": [
|
| 22 |
+
16,
|
| 23 |
+
16,
|
| 24 |
+
8
|
| 25 |
+
],
|
| 26 |
+
"upsample_initial_channel": 256,
|
| 27 |
+
"resblock_kernel_sizes": [
|
| 28 |
+
3,
|
| 29 |
+
5,
|
| 30 |
+
7
|
| 31 |
+
],
|
| 32 |
+
"resblock_dilation_sizes": [
|
| 33 |
+
[
|
| 34 |
+
1,
|
| 35 |
+
2
|
| 36 |
+
],
|
| 37 |
+
[
|
| 38 |
+
2,
|
| 39 |
+
6
|
| 40 |
+
],
|
| 41 |
+
[
|
| 42 |
+
3,
|
| 43 |
+
12
|
| 44 |
+
]
|
| 45 |
+
]
|
| 46 |
+
}
|
| 47 |
+
},
|
| 48 |
+
"train": {
|
| 49 |
+
"criterions": [
|
| 50 |
+
"feature",
|
| 51 |
+
"discriminator",
|
| 52 |
+
"generator",
|
| 53 |
+
"mel",
|
| 54 |
+
]
|
| 55 |
+
},
|
| 56 |
+
"inference": {
|
| 57 |
+
"batch_size": 1,
|
| 58 |
+
}
|
| 59 |
+
}
|
Amphion/egs/vocoder/gan/hifigan/run.sh
ADDED
|
@@ -0,0 +1,141 @@
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
######## Build Experiment Environment ###########
|
| 7 |
+
exp_dir=$(cd `dirname $0`; pwd)
|
| 8 |
+
work_dir=$(dirname $(dirname $(dirname $(dirname $exp_dir))))
|
| 9 |
+
|
| 10 |
+
export WORK_DIR=$work_dir
|
| 11 |
+
export PYTHONPATH=$work_dir
|
| 12 |
+
export PYTHONIOENCODING=UTF-8
|
| 13 |
+
|
| 14 |
+
######## Parse the Given Parameters from the Command ###########
|
| 15 |
+
options=$(getopt -o c:n:s --long gpu:,config:,name:,stage:,checkpoint:,resume_type:,main_process_port:,infer_mode:,infer_datasets:,infer_feature_dir:,infer_audio_dir:,infer_expt_dir:,infer_output_dir: -- "$@")
|
| 16 |
+
eval set -- "$options"
|
| 17 |
+
|
| 18 |
+
while true; do
|
| 19 |
+
case $1 in
|
| 20 |
+
# Experimental Configuration File
|
| 21 |
+
-c | --config) shift; exp_config=$1 ; shift ;;
|
| 22 |
+
# Experimental Name
|
| 23 |
+
-n | --name) shift; exp_name=$1 ; shift ;;
|
| 24 |
+
# Running Stage
|
| 25 |
+
-s | --stage) shift; running_stage=$1 ; shift ;;
|
| 26 |
+
# Visible GPU machines. The default value is "0".
|
| 27 |
+
--gpu) shift; gpu=$1 ; shift ;;
|
| 28 |
+
|
| 29 |
+
# [Only for Training] The specific checkpoint path that you want to resume from.
|
| 30 |
+
--checkpoint) shift; checkpoint=$1 ; shift ;;
|
| 31 |
+
# [Only for Training] `resume` for loading all the things (including model weights, optimizer, scheduler, and random states). `finetune` for loading only the model weights.
|
| 32 |
+
--resume_type) shift; resume_type=$1 ; shift ;;
|
| 33 |
+
# [Only for Training] `main_process_port` for multi-GPU training
|
| 34 |
+
--main_process_port) shift; main_process_port=$1 ; shift ;;
|
| 35 |
+
|
| 36 |
+
# [Only for Inference] The inference mode
|
| 37 |
+
--infer_mode) shift; infer_mode=$1 ; shift ;;
|
| 38 |
+
# [Only for Inference] The datasets to run inference on
|
| 39 |
+
--infer_datasets) shift; infer_datasets=$1 ; shift ;;
|
| 40 |
+
# [Only for Inference] The feature dir for inference
|
| 41 |
+
--infer_feature_dir) shift; infer_feature_dir=$1 ; shift ;;
|
| 42 |
+
# [Only for Inference] The audio dir for inference
|
| 43 |
+
--infer_audio_dir) shift; infer_audio_dir=$1 ; shift ;;
|
| 44 |
+
# [Only for Inference] The experiment dir. The value is like "[Your path to save logs and checkpoints]/[YourExptName]"
|
| 45 |
+
--infer_expt_dir) shift; infer_expt_dir=$1 ; shift ;;
|
| 46 |
+
# [Only for Inference] The output dir to save inferred audios. Its default value is "$expt_dir/result"
|
| 47 |
+
--infer_output_dir) shift; infer_output_dir=$1 ; shift ;;
|
| 48 |
+
|
| 49 |
+
--) shift ; break ;;
|
| 50 |
+
*) echo "Invalid option: $1"; exit 1 ;;
|
| 51 |
+
esac
|
| 52 |
+
done
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
### Value check ###
|
| 56 |
+
if [ -z "$running_stage" ]; then
|
| 57 |
+
echo "[Error] Please specify the running stage"
|
| 58 |
+
exit 1
|
| 59 |
+
fi
|
| 60 |
+
|
| 61 |
+
if [ -z "$exp_config" ]; then
|
| 62 |
+
exp_config="${exp_dir}"/exp_config.json
|
| 63 |
+
fi
|
| 64 |
+
echo "Experimental Configuration File: $exp_config"
|
| 65 |
+
|
| 66 |
+
if [ -z "$gpu" ]; then
|
| 67 |
+
gpu="0"
|
| 68 |
+
fi
|
| 69 |
+
|
| 70 |
+
if [ -z "$main_process_port" ]; then
|
| 71 |
+
main_process_port=29500
|
| 72 |
+
fi
|
| 73 |
+
echo "Main Process Port: $main_process_port"
|
| 74 |
+
|
| 75 |
+
######## Features Extraction ###########
|
| 76 |
+
if [ $running_stage -eq 1 ]; then
|
| 77 |
+
CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/bins/vocoder/preprocess.py \
|
| 78 |
+
--config $exp_config \
|
| 79 |
+
--num_workers 8
|
| 80 |
+
fi
|
| 81 |
+
|
| 82 |
+
######## Training ###########
|
| 83 |
+
if [ $running_stage -eq 2 ]; then
|
| 84 |
+
if [ -z "$exp_name" ]; then
|
| 85 |
+
echo "[Error] Please specify the experiments name"
|
| 86 |
+
exit 1
|
| 87 |
+
fi
|
| 88 |
+
echo "Experimental Name: $exp_name"
|
| 89 |
+
|
| 90 |
+
CUDA_VISIBLE_DEVICES=$gpu accelerate launch \
|
| 91 |
+
--main_process_port "$main_process_port" \
|
| 92 |
+
"${work_dir}"/bins/vocoder/train.py \
|
| 93 |
+
--config "$exp_config" \
|
| 94 |
+
--exp_name "$exp_name" \
|
| 95 |
+
--log_level info \
|
| 96 |
+
--checkpoint "$checkpoint" \
|
| 97 |
+
--resume_type "$resume_type"
|
| 98 |
+
fi
|
| 99 |
+
|
| 100 |
+
######## Inference/Conversion ###########
|
| 101 |
+
if [ $running_stage -eq 3 ]; then
|
| 102 |
+
if [ -z "$infer_expt_dir" ]; then
|
| 103 |
+
echo "[Error] Please specify the experiment directory. The value is like [Your path to save logs and checkpoints]/[YourExptName]"
|
| 104 |
+
exit 1
|
| 105 |
+
fi
|
| 106 |
+
|
| 107 |
+
if [ -z "$infer_output_dir" ]; then
|
| 108 |
+
infer_output_dir="$infer_expt_dir/result"
|
| 109 |
+
fi
|
| 110 |
+
|
| 111 |
+
if [ $infer_mode = "infer_from_dataset" ]; then
|
| 112 |
+
CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/vocoder/inference.py \
|
| 113 |
+
--config $exp_config \
|
| 114 |
+
--infer_mode $infer_mode \
|
| 115 |
+
--infer_datasets $infer_datasets \
|
| 116 |
+
--vocoder_dir $infer_expt_dir \
|
| 117 |
+
--output_dir $infer_output_dir \
|
| 118 |
+
--log_level debug
|
| 119 |
+
fi
|
| 120 |
+
|
| 121 |
+
if [ $infer_mode = "infer_from_feature" ]; then
|
| 122 |
+
CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/vocoder/inference.py \
|
| 123 |
+
--config $exp_config \
|
| 124 |
+
--infer_mode $infer_mode \
|
| 125 |
+
--feature_folder $infer_feature_dir \
|
| 126 |
+
--vocoder_dir $infer_expt_dir \
|
| 127 |
+
--output_dir $infer_output_dir \
|
| 128 |
+
--log_level debug
|
| 129 |
+
fi
|
| 130 |
+
|
| 131 |
+
if [ $infer_mode = "infer_from_audio" ]; then
|
| 132 |
+
CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/vocoder/inference.py \
|
| 133 |
+
--config $exp_config \
|
| 134 |
+
--infer_mode $infer_mode \
|
| 135 |
+
--audio_folder $infer_audio_dir \
|
| 136 |
+
--vocoder_dir $infer_expt_dir \
|
| 137 |
+
--output_dir $infer_output_dir \
|
| 138 |
+
--log_level debug
|
| 139 |
+
fi
|
| 140 |
+
|
| 141 |
+
fi
|
Amphion/egs/vocoder/gan/nsfhifigan/run.sh
ADDED
|
@@ -0,0 +1,141 @@
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
######## Build Experiment Environment ###########
|
| 7 |
+
exp_dir=$(cd `dirname $0`; pwd)
|
| 8 |
+
work_dir=$(dirname $(dirname $(dirname $(dirname $exp_dir))))
|
| 9 |
+
|
| 10 |
+
export WORK_DIR=$work_dir
|
| 11 |
+
export PYTHONPATH=$work_dir
|
| 12 |
+
export PYTHONIOENCODING=UTF-8
|
| 13 |
+
|
| 14 |
+
######## Parse the Given Parameters from the Command ###########
|
| 15 |
+
options=$(getopt -o c:n:s --long gpu:,config:,name:,stage:,checkpoint:,resume_type:,main_process_port:,infer_mode:,infer_datasets:,infer_feature_dir:,infer_audio_dir:,infer_expt_dir:,infer_output_dir: -- "$@")
|
| 16 |
+
eval set -- "$options"
|
| 17 |
+
|
| 18 |
+
while true; do
|
| 19 |
+
case $1 in
|
| 20 |
+
# Experimental Configuration File
|
| 21 |
+
-c | --config) shift; exp_config=$1 ; shift ;;
|
| 22 |
+
# Experimental Name
|
| 23 |
+
-n | --name) shift; exp_name=$1 ; shift ;;
|
| 24 |
+
# Running Stage
|
| 25 |
+
-s | --stage) shift; running_stage=$1 ; shift ;;
|
| 26 |
+
# Visible GPU machines. The default value is "0".
|
| 27 |
+
--gpu) shift; gpu=$1 ; shift ;;
|
| 28 |
+
|
| 29 |
+
# [Only for Training] The specific checkpoint path that you want to resume from.
|
| 30 |
+
--checkpoint) shift; checkpoint=$1 ; shift ;;
|
| 31 |
+
# [Only for Training] `resume` for loading all the things (including model weights, optimizer, scheduler, and random states). `finetune` for loading only the model weights.
|
| 32 |
+
--resume_type) shift; resume_type=$1 ; shift ;;
|
| 33 |
+
# [Only for Training] `main_process_port` for multi-GPU training
|
| 34 |
+
--main_process_port) shift; main_process_port=$1 ; shift ;;
|
| 35 |
+
|
| 36 |
+
# [Only for Inference] The inference mode
|
| 37 |
+
--infer_mode) shift; infer_mode=$1 ; shift ;;
|
| 38 |
+
# [Only for Inference] The datasets to run inference on
|
| 39 |
+
--infer_datasets) shift; infer_datasets=$1 ; shift ;;
|
| 40 |
+
# [Only for Inference] The feature dir for inference
|
| 41 |
+
--infer_feature_dir) shift; infer_feature_dir=$1 ; shift ;;
|
| 42 |
+
# [Only for Inference] The audio dir for inference
|
| 43 |
+
--infer_audio_dir) shift; infer_audio_dir=$1 ; shift ;;
|
| 44 |
+
# [Only for Inference] The experiment dir. The value is like "[Your path to save logs and checkpoints]/[YourExptName]"
|
| 45 |
+
--infer_expt_dir) shift; infer_expt_dir=$1 ; shift ;;
|
| 46 |
+
# [Only for Inference] The output dir to save inferred audios. Its default value is "$expt_dir/result"
|
| 47 |
+
--infer_output_dir) shift; infer_output_dir=$1 ; shift ;;
|
| 48 |
+
|
| 49 |
+
--) shift ; break ;;
|
| 50 |
+
*) echo "Invalid option: $1"; exit 1 ;;
|
| 51 |
+
esac
|
| 52 |
+
done
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
### Value check ###
|
| 56 |
+
if [ -z "$running_stage" ]; then
|
| 57 |
+
echo "[Error] Please specify the running stage"
|
| 58 |
+
exit 1
|
| 59 |
+
fi
|
| 60 |
+
|
| 61 |
+
if [ -z "$exp_config" ]; then
|
| 62 |
+
exp_config="${exp_dir}"/exp_config.json
|
| 63 |
+
fi
|
| 64 |
+
echo "Experimental Configuration File: $exp_config"
|
| 65 |
+
|
| 66 |
+
if [ -z "$gpu" ]; then
|
| 67 |
+
gpu="0"
|
| 68 |
+
fi
|
| 69 |
+
|
| 70 |
+
if [ -z "$main_process_port" ]; then
|
| 71 |
+
main_process_port=29500
|
| 72 |
+
fi
|
| 73 |
+
echo "Main Process Port: $main_process_port"
|
| 74 |
+
|
| 75 |
+
######## Features Extraction ###########
|
| 76 |
+
if [ $running_stage -eq 1 ]; then
|
| 77 |
+
CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/bins/vocoder/preprocess.py \
|
| 78 |
+
--config $exp_config \
|
| 79 |
+
--num_workers 8
|
| 80 |
+
fi
|
| 81 |
+
|
| 82 |
+
######## Training ###########
|
| 83 |
+
if [ $running_stage -eq 2 ]; then
|
| 84 |
+
if [ -z "$exp_name" ]; then
|
| 85 |
+
echo "[Error] Please specify the experiments name"
|
| 86 |
+
exit 1
|
| 87 |
+
fi
|
| 88 |
+
echo "Experimental Name: $exp_name"
|
| 89 |
+
|
| 90 |
+
CUDA_VISIBLE_DEVICES=$gpu accelerate launch \
|
| 91 |
+
--main_process_port "$main_process_port" \
|
| 92 |
+
"${work_dir}"/bins/vocoder/train.py \
|
| 93 |
+
--config "$exp_config" \
|
| 94 |
+
--exp_name "$exp_name" \
|
| 95 |
+
--log_level info \
|
| 96 |
+
--checkpoint "$checkpoint" \
|
| 97 |
+
--resume_type "$resume_type"
|
| 98 |
+
fi
|
| 99 |
+
|
| 100 |
+
######## Inference/Conversion ###########
|
| 101 |
+
if [ $running_stage -eq 3 ]; then
|
| 102 |
+
if [ -z "$infer_expt_dir" ]; then
|
| 103 |
+
echo "[Error] Please specify the experiment directory. The value is like [Your path to save logs and checkpoints]/[YourExptName]"
|
| 104 |
+
exit 1
|
| 105 |
+
fi
|
| 106 |
+
|
| 107 |
+
if [ -z "$infer_output_dir" ]; then
|
| 108 |
+
infer_output_dir="$infer_expt_dir/result"
|
| 109 |
+
fi
|
| 110 |
+
|
| 111 |
+
if [ $infer_mode = "infer_from_dataset" ]; then
|
| 112 |
+
CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/vocoder/inference.py \
|
| 113 |
+
--config $exp_config \
|
| 114 |
+
--infer_mode $infer_mode \
|
| 115 |
+
--infer_datasets $infer_datasets \
|
| 116 |
+
--vocoder_dir $infer_expt_dir \
|
| 117 |
+
--output_dir $infer_output_dir \
|
| 118 |
+
--log_level debug
|
| 119 |
+
fi
|
| 120 |
+
|
| 121 |
+
if [ $infer_mode = "infer_from_feature" ]; then
|
| 122 |
+
CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/vocoder/inference.py \
|
| 123 |
+
--config $exp_config \
|
| 124 |
+
--infer_mode $infer_mode \
|
| 125 |
+
--feature_folder $infer_feature_dir \
|
| 126 |
+
--vocoder_dir $infer_expt_dir \
|
| 127 |
+
--output_dir $infer_output_dir \
|
| 128 |
+
--log_level debug
|
| 129 |
+
fi
|
| 130 |
+
|
| 131 |
+
if [ $infer_mode = "infer_from_audio" ]; then
|
| 132 |
+
CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/vocoder/inference.py \
|
| 133 |
+
--config $exp_config \
|
| 134 |
+
--infer_mode $infer_mode \
|
| 135 |
+
--audio_folder $infer_audio_dir \
|
| 136 |
+
--vocoder_dir $infer_expt_dir \
|
| 137 |
+
--output_dir $infer_output_dir \
|
| 138 |
+
--log_level debug
|
| 139 |
+
fi
|
| 140 |
+
|
| 141 |
+
fi
|
Amphion/models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (186 Bytes).
|
|
|
Amphion/models/base/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .new_trainer import BaseTrainer
|
| 7 |
+
from .new_inference import BaseInference
|
Amphion/models/base/new_dataset.py
ADDED
|
@@ -0,0 +1,50 @@
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from abc import abstractmethod
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import json5
|
| 12 |
+
import torch
|
| 13 |
+
import yaml
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# TODO: for training and validating
|
| 17 |
+
class BaseDataset(torch.utils.data.Dataset):
|
| 18 |
+
r"""Base dataset for training and validating."""
|
| 19 |
+
|
| 20 |
+
def __init__(self, args, cfg, is_valid=False):
|
| 21 |
+
pass
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class BaseTestDataset(torch.utils.data.Dataset):
|
| 25 |
+
r"""Test dataset for inference."""
|
| 26 |
+
|
| 27 |
+
def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
|
| 28 |
+
assert infer_type in ["from_dataset", "from_file"]
|
| 29 |
+
|
| 30 |
+
self.args = args
|
| 31 |
+
self.cfg = cfg
|
| 32 |
+
self.infer_type = infer_type
|
| 33 |
+
|
| 34 |
+
@abstractmethod
|
| 35 |
+
def __getitem__(self, index):
|
| 36 |
+
pass
|
| 37 |
+
|
| 38 |
+
def __len__(self):
|
| 39 |
+
return len(self.metadata)
|
| 40 |
+
|
| 41 |
+
def get_metadata(self):
|
| 42 |
+
path = Path(self.args.source)
|
| 43 |
+
if path.suffix == ".json" or path.suffix == ".jsonc":
|
| 44 |
+
metadata = json5.load(open(self.args.source, "r"))
|
| 45 |
+
elif path.suffix == ".yaml" or path.suffix == ".yml":
|
| 46 |
+
metadata = yaml.full_load(open(self.args.source, "r"))
|
| 47 |
+
else:
|
| 48 |
+
raise ValueError(f"Unsupported file type: {path.suffix}")
|
| 49 |
+
|
| 50 |
+
return metadata
|
Amphion/models/base/new_trainer.py
ADDED
|
@@ -0,0 +1,727 @@
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
import shutil
|
| 10 |
+
import time
|
| 11 |
+
from abc import abstractmethod
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
import accelerate
|
| 15 |
+
import json5
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
from accelerate.logging import get_logger
|
| 19 |
+
from accelerate.utils import ProjectConfiguration
|
| 20 |
+
from torch.utils.data import ConcatDataset, DataLoader
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
|
| 23 |
+
from models.base.base_sampler import build_samplers
|
| 24 |
+
from optimizer.optimizers import NoamLR
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class BaseTrainer(object):
|
| 28 |
+
r"""The base trainer for all tasks. Any trainer should inherit from this class."""
|
| 29 |
+
|
| 30 |
+
def __init__(self, args=None, cfg=None):
|
| 31 |
+
super().__init__()
|
| 32 |
+
|
| 33 |
+
self.args = args
|
| 34 |
+
self.cfg = cfg
|
| 35 |
+
|
| 36 |
+
cfg.exp_name = args.exp_name
|
| 37 |
+
|
| 38 |
+
# init with accelerate
|
| 39 |
+
self._init_accelerator()
|
| 40 |
+
self.accelerator.wait_for_everyone()
|
| 41 |
+
|
| 42 |
+
# Use accelerate logger for distributed training
|
| 43 |
+
with self.accelerator.main_process_first():
|
| 44 |
+
self.logger = get_logger(args.exp_name, log_level=args.log_level)
|
| 45 |
+
|
| 46 |
+
# Log some info
|
| 47 |
+
self.logger.info("=" * 56)
|
| 48 |
+
self.logger.info("||\t\t" + "New training process started." + "\t\t||")
|
| 49 |
+
self.logger.info("=" * 56)
|
| 50 |
+
self.logger.info("\n")
|
| 51 |
+
self.logger.debug(f"Using {args.log_level.upper()} logging level.")
|
| 52 |
+
self.logger.info(f"Experiment name: {args.exp_name}")
|
| 53 |
+
self.logger.info(f"Experiment directory: {self.exp_dir}")
|
| 54 |
+
self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
|
| 55 |
+
if self.accelerator.is_main_process:
|
| 56 |
+
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
| 57 |
+
self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
|
| 58 |
+
|
| 59 |
+
# init counts
|
| 60 |
+
self.batch_count: int = 0
|
| 61 |
+
self.step: int = 0
|
| 62 |
+
self.epoch: int = 0
|
| 63 |
+
self.max_epoch = (
|
| 64 |
+
self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
|
| 65 |
+
)
|
| 66 |
+
self.logger.info(
|
| 67 |
+
"Max epoch: {}".format(
|
| 68 |
+
self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
|
| 69 |
+
)
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
# Check values
|
| 73 |
+
if self.accelerator.is_main_process:
|
| 74 |
+
self.__check_basic_configs()
|
| 75 |
+
# Set runtime configs
|
| 76 |
+
self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
|
| 77 |
+
self.checkpoints_path = [
|
| 78 |
+
[] for _ in range(len(self.save_checkpoint_stride))
|
| 79 |
+
]
|
| 80 |
+
self.keep_last = [
|
| 81 |
+
i if i > 0 else float("inf") for i in self.cfg.train.keep_last
|
| 82 |
+
]
|
| 83 |
+
self.run_eval = self.cfg.train.run_eval
|
| 84 |
+
|
| 85 |
+
# set random seed
|
| 86 |
+
with self.accelerator.main_process_first():
|
| 87 |
+
start = time.monotonic_ns()
|
| 88 |
+
self._set_random_seed(self.cfg.train.random_seed)
|
| 89 |
+
end = time.monotonic_ns()
|
| 90 |
+
self.logger.debug(
|
| 91 |
+
f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
|
| 92 |
+
)
|
| 93 |
+
self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
|
| 94 |
+
|
| 95 |
+
# setup data_loader
|
| 96 |
+
with self.accelerator.main_process_first():
|
| 97 |
+
self.logger.info("Building dataset...")
|
| 98 |
+
start = time.monotonic_ns()
|
| 99 |
+
self.train_dataloader, self.valid_dataloader = self._build_dataloader()
|
| 100 |
+
end = time.monotonic_ns()
|
| 101 |
+
self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
|
| 102 |
+
|
| 103 |
+
# setup model
|
| 104 |
+
with self.accelerator.main_process_first():
|
| 105 |
+
self.logger.info("Building model...")
|
| 106 |
+
start = time.monotonic_ns()
|
| 107 |
+
self.model = self._build_model()
|
| 108 |
+
end = time.monotonic_ns()
|
| 109 |
+
self.logger.debug(self.model)
|
| 110 |
+
self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
|
| 111 |
+
self.logger.info(
|
| 112 |
+
f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
|
| 113 |
+
)
|
| 114 |
+
# optimizer & scheduler
|
| 115 |
+
with self.accelerator.main_process_first():
|
| 116 |
+
self.logger.info("Building optimizer and scheduler...")
|
| 117 |
+
start = time.monotonic_ns()
|
| 118 |
+
self.optimizer = self._build_optimizer()
|
| 119 |
+
self.scheduler = self._build_scheduler()
|
| 120 |
+
end = time.monotonic_ns()
|
| 121 |
+
self.logger.info(
|
| 122 |
+
f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# accelerate prepare
|
| 126 |
+
self.logger.info("Initializing accelerate...")
|
| 127 |
+
start = time.monotonic_ns()
|
| 128 |
+
self._accelerator_prepare()
|
| 129 |
+
end = time.monotonic_ns()
|
| 130 |
+
self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
|
| 131 |
+
|
| 132 |
+
# create criterion
|
| 133 |
+
with self.accelerator.main_process_first():
|
| 134 |
+
self.logger.info("Building criterion...")
|
| 135 |
+
start = time.monotonic_ns()
|
| 136 |
+
self.criterion = self._build_criterion()
|
| 137 |
+
end = time.monotonic_ns()
|
| 138 |
+
self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
|
| 139 |
+
|
| 140 |
+
# Resume or Finetune
|
| 141 |
+
with self.accelerator.main_process_first():
|
| 142 |
+
if args.resume:
|
| 143 |
+
if args.resume_from_ckpt_path == "":
|
| 144 |
+
## Automatically resume according to the current experimental name
|
| 145 |
+
self.logger.info(
|
| 146 |
+
"Automatically resuming from latest checkpoint in {}...".format(
|
| 147 |
+
self.checkpoint_dir
|
| 148 |
+
)
|
| 149 |
+
)
|
| 150 |
+
start = time.monotonic_ns()
|
| 151 |
+
ckpt_path = self._load_model(
|
| 152 |
+
checkpoint_dir=self.checkpoint_dir, resume_type=args.resume_type
|
| 153 |
+
)
|
| 154 |
+
end = time.monotonic_ns()
|
| 155 |
+
self.logger.info(
|
| 156 |
+
f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
|
| 157 |
+
)
|
| 158 |
+
self.checkpoints_path = json.load(
|
| 159 |
+
open(os.path.join(ckpt_path, "ckpts.json"), "r")
|
| 160 |
+
)
|
| 161 |
+
else:
|
| 162 |
+
## Resume from the given checkpoint path
|
| 163 |
+
if not os.path.exists(args.resume_from_ckpt_path):
|
| 164 |
+
raise ValueError(
|
| 165 |
+
"[Error] The resumed checkpoint path {} doesn't exist.".format(
|
| 166 |
+
args.resume_from_ckpt_path
|
| 167 |
+
)
|
| 168 |
+
)
|
| 169 |
+
self.logger.info(
|
| 170 |
+
"Resuming from {}...".format(args.resume_from_ckpt_path)
|
| 171 |
+
)
|
| 172 |
+
start = time.monotonic_ns()
|
| 173 |
+
ckpt_path = self._load_model(
|
| 174 |
+
checkpoint_path=args.resume_from_ckpt_path,
|
| 175 |
+
resume_type=args.resume_type,
|
| 176 |
+
)
|
| 177 |
+
end = time.monotonic_ns()
|
| 178 |
+
self.logger.info(
|
| 179 |
+
f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
# save config file path
|
| 183 |
+
self.config_save_path = os.path.join(self.exp_dir, "args.json")
|
| 184 |
+
|
| 185 |
+
def _accelerator_prepare(self):
|
| 186 |
+
(
|
| 187 |
+
self.train_dataloader,
|
| 188 |
+
self.valid_dataloader,
|
| 189 |
+
self.model,
|
| 190 |
+
self.optimizer,
|
| 191 |
+
self.scheduler,
|
| 192 |
+
) = self.accelerator.prepare(
|
| 193 |
+
self.train_dataloader,
|
| 194 |
+
self.valid_dataloader,
|
| 195 |
+
self.model,
|
| 196 |
+
self.optimizer,
|
| 197 |
+
self.scheduler,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
### Following are abstract methods that should be implemented in child classes ###
|
| 201 |
+
@abstractmethod
|
| 202 |
+
def _build_dataset(self):
|
| 203 |
+
r"""Build dataset for model training/validating/evaluating."""
|
| 204 |
+
pass
|
| 205 |
+
|
| 206 |
+
@staticmethod
|
| 207 |
+
@abstractmethod
|
| 208 |
+
def _build_criterion():
|
| 209 |
+
r"""Build criterion function for model loss calculation."""
|
| 210 |
+
pass
|
| 211 |
+
|
| 212 |
+
@abstractmethod
|
| 213 |
+
def _build_model(self):
|
| 214 |
+
r"""Build model for training/validating/evaluating."""
|
| 215 |
+
pass
|
| 216 |
+
|
| 217 |
+
@abstractmethod
|
| 218 |
+
def _forward_step(self, batch):
|
| 219 |
+
r"""One forward step of the neural network. This abstract method is trying to
|
| 220 |
+
unify ``_train_step`` and ``_valid_step`` and avoid redundant implementation.
|
| 221 |
+
However, for special case that using different forward step pattern for
|
| 222 |
+
training and validating, you could just override this method with ``pass`` and
|
| 223 |
+
implement ``_train_step`` and ``_valid_step`` separately.
|
| 224 |
+
"""
|
| 225 |
+
pass
|
| 226 |
+
|
| 227 |
+
@abstractmethod
|
| 228 |
+
def _save_auxiliary_states(self):
|
| 229 |
+
r"""To save some auxiliary states when saving model's ckpt"""
|
| 230 |
+
pass
|
| 231 |
+
|
| 232 |
+
### Abstract methods end ###
|
| 233 |
+
|
| 234 |
+
### THIS IS MAIN ENTRY ###
|
| 235 |
+
def train_loop(self):
|
| 236 |
+
r"""Training loop. The public entry of training process."""
|
| 237 |
+
# Wait everyone to prepare before we move on
|
| 238 |
+
self.accelerator.wait_for_everyone()
|
| 239 |
+
# dump config file
|
| 240 |
+
if self.accelerator.is_main_process:
|
| 241 |
+
self.__dump_cfg(self.config_save_path)
|
| 242 |
+
self.model.train()
|
| 243 |
+
self.optimizer.zero_grad()
|
| 244 |
+
# Wait to ensure good to go
|
| 245 |
+
self.accelerator.wait_for_everyone()
|
| 246 |
+
while self.epoch < self.max_epoch:
|
| 247 |
+
self.logger.info("\n")
|
| 248 |
+
self.logger.info("-" * 32)
|
| 249 |
+
self.logger.info("Epoch {}: ".format(self.epoch))
|
| 250 |
+
|
| 251 |
+
### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
|
| 252 |
+
### It's inconvenient for the model with multiple losses
|
| 253 |
+
# Do training & validating epoch
|
| 254 |
+
train_loss = self._train_epoch()
|
| 255 |
+
self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss))
|
| 256 |
+
valid_loss = self._valid_epoch()
|
| 257 |
+
self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss))
|
| 258 |
+
self.accelerator.log(
|
| 259 |
+
{"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
|
| 260 |
+
step=self.epoch,
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
self.accelerator.wait_for_everyone()
|
| 264 |
+
# TODO: what is scheduler?
|
| 265 |
+
self.scheduler.step(valid_loss) # FIXME: use epoch track correct?
|
| 266 |
+
|
| 267 |
+
# Check if hit save_checkpoint_stride and run_eval
|
| 268 |
+
run_eval = False
|
| 269 |
+
if self.accelerator.is_main_process:
|
| 270 |
+
save_checkpoint = False
|
| 271 |
+
hit_dix = []
|
| 272 |
+
for i, num in enumerate(self.save_checkpoint_stride):
|
| 273 |
+
if self.epoch % num == 0:
|
| 274 |
+
save_checkpoint = True
|
| 275 |
+
hit_dix.append(i)
|
| 276 |
+
run_eval |= self.run_eval[i]
|
| 277 |
+
|
| 278 |
+
self.accelerator.wait_for_everyone()
|
| 279 |
+
if self.accelerator.is_main_process and save_checkpoint:
|
| 280 |
+
path = os.path.join(
|
| 281 |
+
self.checkpoint_dir,
|
| 282 |
+
"epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
|
| 283 |
+
self.epoch, self.step, train_loss
|
| 284 |
+
),
|
| 285 |
+
)
|
| 286 |
+
self.tmp_checkpoint_save_path = path
|
| 287 |
+
self.accelerator.save_state(path)
|
| 288 |
+
print(f"save checkpoint in {path}")
|
| 289 |
+
json.dump(
|
| 290 |
+
self.checkpoints_path,
|
| 291 |
+
open(os.path.join(path, "ckpts.json"), "w"),
|
| 292 |
+
ensure_ascii=False,
|
| 293 |
+
indent=4,
|
| 294 |
+
)
|
| 295 |
+
self._save_auxiliary_states()
|
| 296 |
+
|
| 297 |
+
# Remove old checkpoints
|
| 298 |
+
to_remove = []
|
| 299 |
+
for idx in hit_dix:
|
| 300 |
+
self.checkpoints_path[idx].append(path)
|
| 301 |
+
while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
|
| 302 |
+
to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
|
| 303 |
+
|
| 304 |
+
# Search conflicts
|
| 305 |
+
total = set()
|
| 306 |
+
for i in self.checkpoints_path:
|
| 307 |
+
total |= set(i)
|
| 308 |
+
do_remove = set()
|
| 309 |
+
for idx, path in to_remove[::-1]:
|
| 310 |
+
if path in total:
|
| 311 |
+
self.checkpoints_path[idx].insert(0, path)
|
| 312 |
+
else:
|
| 313 |
+
do_remove.add(path)
|
| 314 |
+
|
| 315 |
+
# Remove old checkpoints
|
| 316 |
+
for path in do_remove:
|
| 317 |
+
shutil.rmtree(path, ignore_errors=True)
|
| 318 |
+
self.logger.debug(f"Remove old checkpoint: {path}")
|
| 319 |
+
|
| 320 |
+
self.accelerator.wait_for_everyone()
|
| 321 |
+
if run_eval:
|
| 322 |
+
# TODO: run evaluation
|
| 323 |
+
pass
|
| 324 |
+
|
| 325 |
+
# Update info for each epoch
|
| 326 |
+
self.epoch += 1
|
| 327 |
+
|
| 328 |
+
# Finish training and save final checkpoint
|
| 329 |
+
self.accelerator.wait_for_everyone()
|
| 330 |
+
if self.accelerator.is_main_process:
|
| 331 |
+
self.accelerator.save_state(
|
| 332 |
+
os.path.join(
|
| 333 |
+
self.checkpoint_dir,
|
| 334 |
+
"final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
|
| 335 |
+
self.epoch, self.step, valid_loss
|
| 336 |
+
),
|
| 337 |
+
)
|
| 338 |
+
)
|
| 339 |
+
self._save_auxiliary_states()
|
| 340 |
+
|
| 341 |
+
self.accelerator.end_training()
|
| 342 |
+
|
| 343 |
+
### Following are methods that can be used directly in child classes ###
|
| 344 |
+
def _train_epoch(self):
|
| 345 |
+
r"""Training epoch. Should return average loss of a batch (sample) over
|
| 346 |
+
one epoch. See ``train_loop`` for usage.
|
| 347 |
+
"""
|
| 348 |
+
self.model.train()
|
| 349 |
+
epoch_sum_loss: float = 0.0
|
| 350 |
+
epoch_step: int = 0
|
| 351 |
+
for batch in tqdm(
|
| 352 |
+
self.train_dataloader,
|
| 353 |
+
desc=f"Training Epoch {self.epoch}",
|
| 354 |
+
unit="batch",
|
| 355 |
+
colour="GREEN",
|
| 356 |
+
leave=False,
|
| 357 |
+
dynamic_ncols=True,
|
| 358 |
+
smoothing=0.04,
|
| 359 |
+
disable=not self.accelerator.is_main_process,
|
| 360 |
+
):
|
| 361 |
+
# Do training step and BP
|
| 362 |
+
with self.accelerator.accumulate(self.model):
|
| 363 |
+
loss = self._train_step(batch)
|
| 364 |
+
self.accelerator.backward(loss)
|
| 365 |
+
self.optimizer.step()
|
| 366 |
+
self.optimizer.zero_grad()
|
| 367 |
+
self.batch_count += 1
|
| 368 |
+
|
| 369 |
+
# Update info for each step
|
| 370 |
+
# TODO: step means BP counts or batch counts?
|
| 371 |
+
if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
|
| 372 |
+
epoch_sum_loss += loss
|
| 373 |
+
self.accelerator.log(
|
| 374 |
+
{
|
| 375 |
+
"Step/Train Loss": loss,
|
| 376 |
+
"Step/Learning Rate": self.optimizer.param_groups[0]["lr"],
|
| 377 |
+
},
|
| 378 |
+
step=self.step,
|
| 379 |
+
)
|
| 380 |
+
self.step += 1
|
| 381 |
+
epoch_step += 1
|
| 382 |
+
|
| 383 |
+
self.accelerator.wait_for_everyone()
|
| 384 |
+
return (
|
| 385 |
+
epoch_sum_loss
|
| 386 |
+
/ len(self.train_dataloader)
|
| 387 |
+
* self.cfg.train.gradient_accumulation_step
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
@torch.inference_mode()
|
| 391 |
+
def _valid_epoch(self):
|
| 392 |
+
r"""Validation epoch. Should return average loss of a batch (sample) over
|
| 393 |
+
one epoch. See ``train_loop`` for usage.
|
| 394 |
+
"""
|
| 395 |
+
self.model.eval()
|
| 396 |
+
epoch_sum_loss = 0.0
|
| 397 |
+
for batch in tqdm(
|
| 398 |
+
self.valid_dataloader,
|
| 399 |
+
desc=f"Validating Epoch {self.epoch}",
|
| 400 |
+
unit="batch",
|
| 401 |
+
colour="GREEN",
|
| 402 |
+
leave=False,
|
| 403 |
+
dynamic_ncols=True,
|
| 404 |
+
smoothing=0.04,
|
| 405 |
+
disable=not self.accelerator.is_main_process,
|
| 406 |
+
):
|
| 407 |
+
batch_loss = self._valid_step(batch)
|
| 408 |
+
epoch_sum_loss += batch_loss.item()
|
| 409 |
+
|
| 410 |
+
self.accelerator.wait_for_everyone()
|
| 411 |
+
return epoch_sum_loss / len(self.valid_dataloader)
|
| 412 |
+
|
| 413 |
+
def _train_step(self, batch):
|
| 414 |
+
r"""Training forward step. Should return average loss of a sample over
|
| 415 |
+
one batch. Invoking ``_forward_step`` is recommended except for special cases.
|
| 416 |
+
See ``_train_epoch`` for usage.
|
| 417 |
+
"""
|
| 418 |
+
return self._forward_step(batch)
|
| 419 |
+
|
| 420 |
+
@torch.inference_mode()
|
| 421 |
+
def _valid_step(self, batch):
|
| 422 |
+
r"""Validation forward step. Should return average loss of a sample over
|
| 423 |
+
one batch. Invoking ``_forward_step`` is recommended except for special cases.
|
| 424 |
+
See ``_valid_epoch`` for usage.
|
| 425 |
+
"""
|
| 426 |
+
return self._forward_step(batch)
|
| 427 |
+
|
| 428 |
+
def _load_model(
|
| 429 |
+
self,
|
| 430 |
+
checkpoint_dir: str = None,
|
| 431 |
+
checkpoint_path: str = None,
|
| 432 |
+
resume_type: str = "",
|
| 433 |
+
):
|
| 434 |
+
r"""Load model from checkpoint. If checkpoint_path is None, it will
|
| 435 |
+
load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
|
| 436 |
+
None, it will load the checkpoint specified by checkpoint_path. **Only use this
|
| 437 |
+
method after** ``accelerator.prepare()``.
|
| 438 |
+
"""
|
| 439 |
+
if checkpoint_path is None:
|
| 440 |
+
ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
|
| 441 |
+
ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
|
| 442 |
+
checkpoint_path = ls[0]
|
| 443 |
+
self.logger.info("Resume from {}...".format(checkpoint_path))
|
| 444 |
+
|
| 445 |
+
if resume_type in ["resume", ""]:
|
| 446 |
+
# Load all the things, including model weights, optimizer, scheduler, and random states.
|
| 447 |
+
self.accelerator.load_state(input_dir=checkpoint_path)
|
| 448 |
+
|
| 449 |
+
# set epoch and step
|
| 450 |
+
self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
|
| 451 |
+
self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
|
| 452 |
+
|
| 453 |
+
elif resume_type == "finetune":
|
| 454 |
+
# Load only the model weights
|
| 455 |
+
accelerate.load_checkpoint_and_dispatch(
|
| 456 |
+
self.accelerator.unwrap_model(self.model),
|
| 457 |
+
os.path.join(checkpoint_path, "pytorch_model.bin"),
|
| 458 |
+
)
|
| 459 |
+
self.logger.info("Load model weights for finetune...")
|
| 460 |
+
|
| 461 |
+
else:
|
| 462 |
+
raise ValueError("Resume_type must be `resume` or `finetune`.")
|
| 463 |
+
|
| 464 |
+
return checkpoint_path
|
| 465 |
+
|
| 466 |
+
def _build_dataloader(self):
|
| 467 |
+
Dataset, Collator = self._build_dataset()
|
| 468 |
+
|
| 469 |
+
# build dataset instance for each dataset and combine them by ConcatDataset
|
| 470 |
+
datasets_list = []
|
| 471 |
+
for dataset in self.cfg.dataset:
|
| 472 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=False)
|
| 473 |
+
datasets_list.append(subdataset)
|
| 474 |
+
train_dataset = ConcatDataset(datasets_list)
|
| 475 |
+
train_collate = Collator(self.cfg)
|
| 476 |
+
_, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
|
| 477 |
+
self.logger.debug(f"train batch_sampler: {list(batch_sampler)}")
|
| 478 |
+
self.logger.debug(f"length: {train_dataset.cumulative_sizes}")
|
| 479 |
+
# TODO: use config instead of (sampler, shuffle, drop_last, batch_size)
|
| 480 |
+
train_loader = DataLoader(
|
| 481 |
+
train_dataset,
|
| 482 |
+
# shuffle=True,
|
| 483 |
+
collate_fn=train_collate,
|
| 484 |
+
batch_sampler=batch_sampler,
|
| 485 |
+
num_workers=self.cfg.train.dataloader.num_worker,
|
| 486 |
+
pin_memory=self.cfg.train.dataloader.pin_memory,
|
| 487 |
+
)
|
| 488 |
+
|
| 489 |
+
# Build valid dataloader
|
| 490 |
+
datasets_list = []
|
| 491 |
+
for dataset in self.cfg.dataset:
|
| 492 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=True)
|
| 493 |
+
datasets_list.append(subdataset)
|
| 494 |
+
valid_dataset = ConcatDataset(datasets_list)
|
| 495 |
+
valid_collate = Collator(self.cfg)
|
| 496 |
+
_, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
|
| 497 |
+
self.logger.debug(f"valid batch_sampler: {list(batch_sampler)}")
|
| 498 |
+
self.logger.debug(f"length: {valid_dataset.cumulative_sizes}")
|
| 499 |
+
valid_loader = DataLoader(
|
| 500 |
+
valid_dataset,
|
| 501 |
+
collate_fn=valid_collate,
|
| 502 |
+
batch_sampler=batch_sampler,
|
| 503 |
+
num_workers=self.cfg.train.dataloader.num_worker,
|
| 504 |
+
pin_memory=self.cfg.train.dataloader.pin_memory,
|
| 505 |
+
)
|
| 506 |
+
return train_loader, valid_loader
|
| 507 |
+
|
| 508 |
+
@staticmethod
|
| 509 |
+
def _set_random_seed(seed):
|
| 510 |
+
r"""Set random seed for all possible random modules."""
|
| 511 |
+
random.seed(seed)
|
| 512 |
+
np.random.seed(seed)
|
| 513 |
+
torch.random.manual_seed(seed)
|
| 514 |
+
|
| 515 |
+
def _check_nan(self, loss, y_pred, y_gt):
|
| 516 |
+
if torch.any(torch.isnan(loss)):
|
| 517 |
+
self.logger.error("Fatal Error: Training stopped because the loss became NaN!")
|
| 518 |
+
self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)
|
| 519 |
+
|
| 520 |
+
### y_pred ###
|
| 521 |
+
if torch.any(torch.isnan(y_pred)):
|
| 522 |
+
self.logger.error(
|
| 523 |
+
f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
|
| 524 |
+
)
|
| 525 |
+
self.logger.error(f"y_pred: {y_pred}", in_order=True)
|
| 526 |
+
else:
|
| 527 |
+
self.logger.debug(
|
| 528 |
+
f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
|
| 529 |
+
)
|
| 530 |
+
self.logger.debug(f"y_pred: {y_pred}", in_order=True)
|
| 531 |
+
|
| 532 |
+
### y_gt ###
|
| 533 |
+
if torch.any(torch.isnan(y_gt)):
|
| 534 |
+
self.logger.error(
|
| 535 |
+
f"y_gt has Nan: {torch.any(torch.isnan(y_gt))}", in_order=True
|
| 536 |
+
)
|
| 537 |
+
self.logger.error(f"y_gt: {y_gt}", in_order=True)
|
| 538 |
+
else:
|
| 539 |
+
self.logger.debug(
|
| 540 |
+
f"y_gt has nan: {torch.any(torch.isnan(y_gt))}", in_order=True
|
| 541 |
+
)
|
| 542 |
+
self.logger.debug(f"y_gt: {y_gt}", in_order=True)
|
| 543 |
+
|
| 544 |
+
self.accelerator.end_training()
|
| 545 |
+
raise RuntimeError("Loss has Nan! See log for more info.")
|
| 546 |
+
|
| 547 |
+
### Protected methods end ###
|
| 548 |
+
|
| 549 |
+
## Following are private methods ##
|
| 550 |
+
def _build_optimizer(self):
|
| 551 |
+
r"""Build optimizer for model."""
|
| 552 |
+
# Make case-insensitive matching
|
| 553 |
+
if self.cfg.train.optimizer.lower() == "adadelta":
|
| 554 |
+
optimizer = torch.optim.Adadelta(
|
| 555 |
+
self.model.parameters(), **self.cfg.train.adadelta
|
| 556 |
+
)
|
| 557 |
+
self.logger.info("Using Adadelta optimizer.")
|
| 558 |
+
elif self.cfg.train.optimizer.lower() == "adagrad":
|
| 559 |
+
optimizer = torch.optim.Adagrad(
|
| 560 |
+
self.model.parameters(), **self.cfg.train.adagrad
|
| 561 |
+
)
|
| 562 |
+
self.logger.info("Using Adagrad optimizer.")
|
| 563 |
+
elif self.cfg.train.optimizer.lower() == "adam":
|
| 564 |
+
optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam)
|
| 565 |
+
self.logger.info("Using Adam optimizer.")
|
| 566 |
+
elif self.cfg.train.optimizer.lower() == "adamw":
|
| 567 |
+
optimizer = torch.optim.AdamW(
|
| 568 |
+
self.model.parameters(), **self.cfg.train.adamw
|
| 569 |
+
)
|
| 570 |
+
elif self.cfg.train.optimizer.lower() == "sparseadam":
|
| 571 |
+
optimizer = torch.optim.SparseAdam(
|
| 572 |
+
self.model.parameters(), **self.cfg.train.sparseadam
|
| 573 |
+
)
|
| 574 |
+
elif self.cfg.train.optimizer.lower() == "adamax":
|
| 575 |
+
optimizer = torch.optim.Adamax(
|
| 576 |
+
self.model.parameters(), **self.cfg.train.adamax
|
| 577 |
+
)
|
| 578 |
+
elif self.cfg.train.optimizer.lower() == "asgd":
|
| 579 |
+
optimizer = torch.optim.ASGD(self.model.parameters(), **self.cfg.train.asgd)
|
| 580 |
+
elif self.cfg.train.optimizer.lower() == "lbfgs":
|
| 581 |
+
optimizer = torch.optim.LBFGS(
|
| 582 |
+
self.model.parameters(), **self.cfg.train.lbfgs
|
| 583 |
+
)
|
| 584 |
+
elif self.cfg.train.optimizer.lower() == "nadam":
|
| 585 |
+
optimizer = torch.optim.NAdam(
|
| 586 |
+
self.model.parameters(), **self.cfg.train.nadam
|
| 587 |
+
)
|
| 588 |
+
elif self.cfg.train.optimizer.lower() == "radam":
|
| 589 |
+
optimizer = torch.optim.RAdam(
|
| 590 |
+
self.model.parameters(), **self.cfg.train.radam
|
| 591 |
+
)
|
| 592 |
+
elif self.cfg.train.optimizer.lower() == "rmsprop":
|
| 593 |
+
optimizer = torch.optim.RMSprop(
|
| 594 |
+
self.model.parameters(), **self.cfg.train.rmsprop
|
| 595 |
+
)
|
| 596 |
+
elif self.cfg.train.optimizer.lower() == "rprop":
|
| 597 |
+
optimizer = torch.optim.Rprop(
|
| 598 |
+
self.model.parameters(), **self.cfg.train.rprop
|
| 599 |
+
)
|
| 600 |
+
elif self.cfg.train.optimizer.lower() == "sgd":
|
| 601 |
+
optimizer = torch.optim.SGD(self.model.parameters(), **self.cfg.train.sgd)
|
| 602 |
+
else:
|
| 603 |
+
raise NotImplementedError(
|
| 604 |
+
f"Optimizer {self.cfg.train.optimizer} not supported yet!"
|
| 605 |
+
)
|
| 606 |
+
return optimizer
|
| 607 |
+
|
| 608 |
+
def _build_scheduler(self):
|
| 609 |
+
r"""Build scheduler for optimizer."""
|
| 610 |
+
# Make case-insensitive matching
|
| 611 |
+
if self.cfg.train.scheduler.lower() == "lambdalr":
|
| 612 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(
|
| 613 |
+
self.optimizer, **self.cfg.train.lambdalr
|
| 614 |
+
)
|
| 615 |
+
elif self.cfg.train.scheduler.lower() == "multiplicativelr":
|
| 616 |
+
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
|
| 617 |
+
self.optimizer, **self.cfg.train.multiplicativelr
|
| 618 |
+
)
|
| 619 |
+
elif self.cfg.train.scheduler.lower() == "steplr":
|
| 620 |
+
scheduler = torch.optim.lr_scheduler.StepLR(
|
| 621 |
+
self.optimizer, **self.cfg.train.steplr
|
| 622 |
+
)
|
| 623 |
+
elif self.cfg.train.scheduler.lower() == "multisteplr":
|
| 624 |
+
scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
| 625 |
+
self.optimizer, **self.cfg.train.multisteplr
|
| 626 |
+
)
|
| 627 |
+
elif self.cfg.train.scheduler.lower() == "constantlr":
|
| 628 |
+
scheduler = torch.optim.lr_scheduler.ConstantLR(
|
| 629 |
+
self.optimizer, **self.cfg.train.constantlr
|
| 630 |
+
)
|
| 631 |
+
elif self.cfg.train.scheduler.lower() == "linearlr":
|
| 632 |
+
scheduler = torch.optim.lr_scheduler.LinearLR(
|
| 633 |
+
self.optimizer, **self.cfg.train.linearlr
|
| 634 |
+
)
|
| 635 |
+
elif self.cfg.train.scheduler.lower() == "exponentiallr":
|
| 636 |
+
scheduler = torch.optim.lr_scheduler.ExponentialLR(
|
| 637 |
+
self.optimizer, **self.cfg.train.exponentiallr
|
| 638 |
+
)
|
| 639 |
+
elif self.cfg.train.scheduler.lower() == "polynomiallr":
|
| 640 |
+
scheduler = torch.optim.lr_scheduler.PolynomialLR(
|
| 641 |
+
self.optimizer, **self.cfg.train.polynomiallr
|
| 642 |
+
)
|
| 643 |
+
elif self.cfg.train.scheduler.lower() == "cosineannealinglr":
|
| 644 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
| 645 |
+
self.optimizer, **self.cfg.train.cosineannealinglr
|
| 646 |
+
)
|
| 647 |
+
elif self.cfg.train.scheduler.lower() == "sequentiallr":
|
| 648 |
+
scheduler = torch.optim.lr_scheduler.SequentialLR(
|
| 649 |
+
self.optimizer, **self.cfg.train.sequentiallr
|
| 650 |
+
)
|
| 651 |
+
elif self.cfg.train.scheduler.lower() == "reducelronplateau":
|
| 652 |
+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 653 |
+
self.optimizer, **self.cfg.train.reducelronplateau
|
| 654 |
+
)
|
| 655 |
+
elif self.cfg.train.scheduler.lower() == "cycliclr":
|
| 656 |
+
scheduler = torch.optim.lr_scheduler.CyclicLR(
|
| 657 |
+
self.optimizer, **self.cfg.train.cycliclr
|
| 658 |
+
)
|
| 659 |
+
elif self.cfg.train.scheduler.lower() == "onecyclelr":
|
| 660 |
+
scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
| 661 |
+
self.optimizer, **self.cfg.train.onecyclelr
|
| 662 |
+
)
|
| 663 |
+
elif self.cfg.train.scheduler.lower() == "cosineannearingwarmrestarts":
|
| 664 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
| 665 |
+
self.optimizer, **self.cfg.train.cosineannearingwarmrestarts
|
| 666 |
+
)
|
| 667 |
+
elif self.cfg.train.scheduler.lower() == "noamlr":
|
| 668 |
+
scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
|
| 669 |
+
else:
|
| 670 |
+
raise NotImplementedError(
|
| 671 |
+
f"Scheduler {self.cfg.train.scheduler} not supported yet!"
|
| 672 |
+
)
|
| 673 |
+
return scheduler
|
| 674 |
+
|
| 675 |
+
def _init_accelerator(self):
|
| 676 |
+
self.exp_dir = os.path.join(
|
| 677 |
+
os.path.abspath(self.cfg.log_dir), self.args.exp_name
|
| 678 |
+
)
|
| 679 |
+
project_config = ProjectConfiguration(
|
| 680 |
+
project_dir=self.exp_dir,
|
| 681 |
+
logging_dir=os.path.join(self.exp_dir, "log"),
|
| 682 |
+
)
|
| 683 |
+
self.accelerator = accelerate.Accelerator(
|
| 684 |
+
gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
|
| 685 |
+
log_with=self.cfg.train.tracker,
|
| 686 |
+
project_config=project_config,
|
| 687 |
+
)
|
| 688 |
+
if self.accelerator.is_main_process:
|
| 689 |
+
os.makedirs(project_config.project_dir, exist_ok=True)
|
| 690 |
+
os.makedirs(project_config.logging_dir, exist_ok=True)
|
| 691 |
+
with self.accelerator.main_process_first():
|
| 692 |
+
self.accelerator.init_trackers(self.args.exp_name)
|
| 693 |
+
|
| 694 |
+
def __check_basic_configs(self):
|
| 695 |
+
if self.cfg.train.gradient_accumulation_step <= 0:
|
| 696 |
+
self.logger.fatal("Invalid gradient_accumulation_step value!")
|
| 697 |
+
self.logger.error(
|
| 698 |
+
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
|
| 699 |
+
)
|
| 700 |
+
self.accelerator.end_training()
|
| 701 |
+
raise ValueError(
|
| 702 |
+
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
|
| 703 |
+
)
|
| 704 |
+
# TODO: check other values
|
| 705 |
+
|
| 706 |
+
@staticmethod
|
| 707 |
+
def __count_parameters(model):
|
| 708 |
+
model_param = 0.0
|
| 709 |
+
if isinstance(model, dict):
|
| 710 |
+
for key, value in model.items():
|
| 711 |
+
model_param += sum(p.numel() for p in model[key].parameters())
|
| 712 |
+
else:
|
| 713 |
+
model_param = sum(p.numel() for p in model.parameters())
|
| 714 |
+
return model_param
|
| 715 |
+
|
| 716 |
+
def __dump_cfg(self, path):
|
| 717 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
| 718 |
+
json5.dump(
|
| 719 |
+
self.cfg,
|
| 720 |
+
open(path, "w"),
|
| 721 |
+
indent=4,
|
| 722 |
+
sort_keys=True,
|
| 723 |
+
ensure_ascii=False,
|
| 724 |
+
quote_keys=True,
|
| 725 |
+
)
|
| 726 |
+
|
| 727 |
+
### Private methods end ###
|
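The optimizer and scheduler builders above dispatch on a lower-cased config string and forward the matching config sub-section as keyword arguments to the corresponding `torch.optim` constructor. Below is a minimal sketch of how such a config could be shaped and consumed; the field values are illustrative rather than Amphion defaults, and `SimpleNamespace` only stands in for the real config object.

```python
# Hypothetical sketch of the "adamw" branch of _build_optimizer().
# cfg.train.optimizer selects the branch (case-insensitively) and
# cfg.train.adamw supplies the constructor kwargs; values are made up.
from types import SimpleNamespace
import torch

cfg = SimpleNamespace(
    train=SimpleNamespace(
        optimizer="AdamW",
        adamw={"lr": 2e-4, "betas": (0.9, 0.98), "weight_decay": 1e-2},
    )
)

model = torch.nn.Linear(16, 16)
if cfg.train.optimizer.lower() == "adamw":
    optimizer = torch.optim.AdamW(model.parameters(), **cfg.train.adamw)
print(type(optimizer).__name__, optimizer.defaults["lr"])  # AdamW 0.0002
```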
Amphion/models/codec/ns3_codec/__pycache__/facodec.cpython-310.pyc
ADDED
|
Binary file (24.2 kB).
|
|
|
Amphion/models/codec/ns3_codec/alias_free_torch/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (277 Bytes).
|
|
|
Amphion/models/codec/ns3_codec/alias_free_torch/__pycache__/resample.cpython-310.pyc
ADDED
|
Binary file (1.97 kB).
|
|
|
Amphion/models/codec/ns3_codec/quantize/__pycache__/rvq.cpython-310.pyc
ADDED
|
Binary file (2.51 kB).
|
|
|
Amphion/models/codec/ns3_codec/quantize/rvq.py
ADDED
|
@@ -0,0 +1,87 @@
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from .fvq import FactorizedVectorQuantize
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ResidualVQ(nn.Module):
|
| 13 |
+
"""Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
|
| 14 |
+
|
| 15 |
+
def __init__(self, *, num_quantizers, codebook_size, **kwargs):
|
| 16 |
+
super().__init__()
|
| 17 |
+
VQ = FactorizedVectorQuantize
|
| 18 |
+
if type(codebook_size) == int:
|
| 19 |
+
codebook_size = [codebook_size] * num_quantizers
|
| 20 |
+
self.layers = nn.ModuleList(
|
| 21 |
+
[VQ(codebook_size=2**size, **kwargs) for size in codebook_size]
|
| 22 |
+
)
|
| 23 |
+
self.num_quantizers = num_quantizers
|
| 24 |
+
self.quantizer_dropout = kwargs.get("quantizer_dropout", 0.0)
|
| 25 |
+
self.dropout_type = kwargs.get("dropout_type", None)
|
| 26 |
+
|
| 27 |
+
def forward(self, x, n_quantizers=None):
|
| 28 |
+
quantized_out = 0.0
|
| 29 |
+
residual = x
|
| 30 |
+
|
| 31 |
+
all_losses = []
|
| 32 |
+
all_indices = []
|
| 33 |
+
all_quantized = []
|
| 34 |
+
|
| 35 |
+
if n_quantizers is None:
|
| 36 |
+
n_quantizers = self.num_quantizers
|
| 37 |
+
if self.training:
|
| 38 |
+
n_quantizers = torch.ones((x.shape[0],)) * self.num_quantizers + 1
|
| 39 |
+
if self.dropout_type == "linear":
|
| 40 |
+
dropout = torch.randint(1, self.num_quantizers + 1, (x.shape[0],))
|
| 41 |
+
elif self.dropout_type == "exp":
|
| 42 |
+
dropout = torch.randint(
|
| 43 |
+
1, int(math.log2(self.num_quantizers)), (x.shape[0],)
|
| 44 |
+
)
|
| 45 |
+
dropout = torch.pow(2, dropout)
|
| 46 |
+
n_dropout = int(x.shape[0] * self.quantizer_dropout)
|
| 47 |
+
n_quantizers[:n_dropout] = dropout[:n_dropout]
|
| 48 |
+
n_quantizers = n_quantizers.to(x.device)
|
| 49 |
+
|
| 50 |
+
for idx, layer in enumerate(self.layers):
|
| 51 |
+
if not self.training and idx >= n_quantizers:
|
| 52 |
+
break
|
| 53 |
+
quantized, indices, loss = layer(residual)
|
| 54 |
+
|
| 55 |
+
mask = (
|
| 56 |
+
torch.full((x.shape[0],), fill_value=idx, device=x.device)
|
| 57 |
+
< n_quantizers
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
residual = residual - quantized
|
| 61 |
+
|
| 62 |
+
quantized_out = quantized_out + quantized * mask[:, None, None]
|
| 63 |
+
|
| 64 |
+
# loss
|
| 65 |
+
loss = (loss * mask).mean()
|
| 66 |
+
|
| 67 |
+
all_indices.append(indices)
|
| 68 |
+
all_losses.append(loss)
|
| 69 |
+
all_quantized.append(quantized)
|
| 70 |
+
all_losses, all_indices, all_quantized = map(
|
| 71 |
+
torch.stack, (all_losses, all_indices, all_quantized)
|
| 72 |
+
)
|
| 73 |
+
return quantized_out, all_indices, all_losses, all_quantized
|
| 74 |
+
|
| 75 |
+
def vq2emb(self, vq):
|
| 76 |
+
# vq: [n_quantizers, B, T]
|
| 77 |
+
quantized_out = 0.0
|
| 78 |
+
for idx, layer in enumerate(self.layers):
|
| 79 |
+
quantized = layer.vq2emb(vq[idx])
|
| 80 |
+
quantized_out += quantized
|
| 81 |
+
return quantized_out
|
| 82 |
+
|
| 83 |
+
def get_emb(self):
|
| 84 |
+
embs = []
|
| 85 |
+
for idx, layer in enumerate(self.layers):
|
| 86 |
+
embs.append(layer.get_emb())
|
| 87 |
+
return embs
|
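`ResidualVQ.forward()` above quantizes the running residual stage by stage and accumulates each stage's output into the reconstruction. The sketch below illustrates that residual loop with a toy nearest-neighbour quantizer; it is a conceptual stand-in, not Amphion's `FactorizedVectorQuantize`, and all shapes and codebooks are fabricated.

```python
# Conceptual sketch of residual vector quantization (Algorithm 1 of SoundStream),
# i.e. the loop ResidualVQ.forward() implements. toy_quantize() is NOT the
# Amphion quantizer; it just snaps to the nearest entry of a random codebook.
import torch

def toy_quantize(x, codebook):
    # x: (B, T, D), codebook: (K, D) -> nearest-neighbour lookup
    d = torch.cdist(x.reshape(-1, x.shape[-1]), codebook)   # (B*T, K)
    idx = d.argmin(dim=-1)
    return codebook[idx].reshape(x.shape), idx.reshape(x.shape[:-1])

torch.manual_seed(0)
x = torch.randn(2, 5, 8)                              # (B, T, D) latents
codebooks = [torch.randn(16, 8) for _ in range(4)]    # 4 quantizer stages

quantized_out, residual = torch.zeros_like(x), x
for cb in codebooks:
    q, _ = toy_quantize(residual, cb)
    residual = residual - q              # next stage models what is left
    quantized_out = quantized_out + q    # running reconstruction
print(torch.dist(x, quantized_out))      # error shrinks as stages are added
```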
Amphion/models/svc/transformer/transformer.py
ADDED
|
@@ -0,0 +1,82 @@
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Transformer(nn.Module):
|
| 13 |
+
def __init__(self, cfg):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.cfg = cfg
|
| 16 |
+
|
| 17 |
+
dropout = self.cfg.dropout
|
| 18 |
+
nhead = self.cfg.n_heads
|
| 19 |
+
nlayers = self.cfg.n_layers
|
| 20 |
+
input_dim = self.cfg.input_dim
|
| 21 |
+
output_dim = self.cfg.output_dim
|
| 22 |
+
|
| 23 |
+
d_model = input_dim
|
| 24 |
+
self.pos_encoder = PositionalEncoding(d_model, dropout)
|
| 25 |
+
encoder_layers = TransformerEncoderLayer(
|
| 26 |
+
d_model, nhead, dropout=dropout, batch_first=True
|
| 27 |
+
)
|
| 28 |
+
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
|
| 29 |
+
|
| 30 |
+
self.output_mlp = nn.Linear(d_model, output_dim)
|
| 31 |
+
|
| 32 |
+
def forward(self, x, mask=None):
|
| 33 |
+
"""
|
| 34 |
+
Args:
|
| 35 |
+
x: (N, seq_len, input_dim)
|
| 36 |
+
Returns:
|
| 37 |
+
output: (N, seq_len, output_dim)
|
| 38 |
+
"""
|
| 39 |
+
# (N, seq_len, d_model)
|
| 40 |
+
src = self.pos_encoder(x)
|
| 41 |
+
# model_stats["pos_embedding"] = x
|
| 42 |
+
# (N, seq_len, d_model)
|
| 43 |
+
output = self.transformer_encoder(src)
|
| 44 |
+
# (N, seq_len, output_dim)
|
| 45 |
+
output = self.output_mlp(output)
|
| 46 |
+
return output
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class PositionalEncoding(nn.Module):
|
| 50 |
+
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
|
| 51 |
+
super().__init__()
|
| 52 |
+
self.dropout = nn.Dropout(p=dropout)
|
| 53 |
+
|
| 54 |
+
position = torch.arange(max_len).unsqueeze(1)
|
| 55 |
+
div_term = torch.exp(
|
| 56 |
+
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
# Assume that x is (seq_len, N, d)
|
| 60 |
+
# pe = torch.zeros(max_len, 1, d_model)
|
| 61 |
+
# pe[:, 0, 0::2] = torch.sin(position * div_term)
|
| 62 |
+
# pe[:, 0, 1::2] = torch.cos(position * div_term)
|
| 63 |
+
|
| 64 |
+
# Assume that x in (N, seq_len, d)
|
| 65 |
+
pe = torch.zeros(1, max_len, d_model)
|
| 66 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
| 67 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
| 68 |
+
|
| 69 |
+
self.register_buffer("pe", pe)
|
| 70 |
+
|
| 71 |
+
def forward(self, x):
|
| 72 |
+
"""
|
| 73 |
+
Args:
|
| 74 |
+
x: Tensor, shape [N, seq_len, d]
|
| 75 |
+
"""
|
| 76 |
+
# Old: Assume that x is (seq_len, N, d), and self.pe is (max_len, 1, d_model)
|
| 77 |
+
# x = x + self.pe[: x.size(0)]
|
| 78 |
+
|
| 79 |
+
# Now: self.pe is (1, max_len, d)
|
| 80 |
+
x = x + self.pe[:, : x.size(1), :]
|
| 81 |
+
|
| 82 |
+
return self.dropout(x)
|
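`PositionalEncoding` above builds its sinusoidal table batch-first, with shape `(1, max_len, d_model)`, and slices it by sequence length at call time. A self-contained sketch of that layout, with illustrative sizes:

```python
# Minimal sketch of the batch-first sinusoidal table used by PositionalEncoding:
# pe is (1, max_len, d_model) and broadcasts over the batch when added to x.
import math
import torch

d_model, max_len = 8, 50
position = torch.arange(max_len).unsqueeze(1)                       # (max_len, 1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)

x = torch.randn(4, 20, d_model)        # (N, seq_len, d) batch-first input
x = x + pe[:, : x.size(1), :]          # slice to seq_len, broadcast over N
print(x.shape)                         # torch.Size([4, 20, 8])
```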
Amphion/models/tts/fastspeech2/fs2.py
ADDED
|
@@ -0,0 +1,548 @@
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# This code is modified from https://github.com/ming024/FastSpeech2/blob/master/model/fastspeech2.py
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
from modules.transformer.Models import Encoder, Decoder
|
| 13 |
+
from modules.transformer.Layers import PostNet
|
| 14 |
+
from collections import OrderedDict
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import json
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_mask_from_lengths(lengths, max_len=None):
|
| 21 |
+
device = lengths.device
|
| 22 |
+
batch_size = lengths.shape[0]
|
| 23 |
+
if max_len is None:
|
| 24 |
+
max_len = torch.max(lengths).item()
|
| 25 |
+
|
| 26 |
+
ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
|
| 27 |
+
mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
|
| 28 |
+
|
| 29 |
+
return mask
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def pad(input_ele, mel_max_length=None):
|
| 33 |
+
if mel_max_length:
|
| 34 |
+
max_len = mel_max_length
|
| 35 |
+
else:
|
| 36 |
+
max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])
|
| 37 |
+
|
| 38 |
+
out_list = list()
|
| 39 |
+
for i, batch in enumerate(input_ele):
|
| 40 |
+
if len(batch.shape) == 1:
|
| 41 |
+
one_batch_padded = F.pad(
|
| 42 |
+
batch, (0, max_len - batch.size(0)), "constant", 0.0
|
| 43 |
+
)
|
| 44 |
+
elif len(batch.shape) == 2:
|
| 45 |
+
one_batch_padded = F.pad(
|
| 46 |
+
batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
|
| 47 |
+
)
|
| 48 |
+
out_list.append(one_batch_padded)
|
| 49 |
+
out_padded = torch.stack(out_list)
|
| 50 |
+
return out_padded
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class VarianceAdaptor(nn.Module):
|
| 54 |
+
"""Variance Adaptor"""
|
| 55 |
+
|
| 56 |
+
def __init__(self, cfg):
|
| 57 |
+
super(VarianceAdaptor, self).__init__()
|
| 58 |
+
self.duration_predictor = VariancePredictor(cfg)
|
| 59 |
+
self.length_regulator = LengthRegulator()
|
| 60 |
+
self.pitch_predictor = VariancePredictor(cfg)
|
| 61 |
+
self.energy_predictor = VariancePredictor(cfg)
|
| 62 |
+
|
| 63 |
+
# assign the pitch/energy feature level
|
| 64 |
+
if cfg.preprocess.use_frame_pitch:
|
| 65 |
+
self.pitch_feature_level = "frame_level"
|
| 66 |
+
self.pitch_dir = cfg.preprocess.pitch_dir
|
| 67 |
+
else:
|
| 68 |
+
self.pitch_feature_level = "phoneme_level"
|
| 69 |
+
self.pitch_dir = cfg.preprocess.phone_pitch_dir
|
| 70 |
+
|
| 71 |
+
if cfg.preprocess.use_frame_energy:
|
| 72 |
+
self.energy_feature_level = "frame_level"
|
| 73 |
+
self.energy_dir = cfg.preprocess.energy_dir
|
| 74 |
+
else:
|
| 75 |
+
self.energy_feature_level = "phoneme_level"
|
| 76 |
+
self.energy_dir = cfg.preprocess.phone_energy_dir
|
| 77 |
+
|
| 78 |
+
assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
|
| 79 |
+
assert self.energy_feature_level in ["phoneme_level", "frame_level"]
|
| 80 |
+
|
| 81 |
+
pitch_quantization = cfg.model.variance_embedding.pitch_quantization
|
| 82 |
+
energy_quantization = cfg.model.variance_embedding.energy_quantization
|
| 83 |
+
n_bins = cfg.model.variance_embedding.n_bins
|
| 84 |
+
assert pitch_quantization in ["linear", "log"]
|
| 85 |
+
assert energy_quantization in ["linear", "log"]
|
| 86 |
+
|
| 87 |
+
with open(
|
| 88 |
+
os.path.join(
|
| 89 |
+
cfg.preprocess.processed_dir,
|
| 90 |
+
cfg.dataset[0],
|
| 91 |
+
self.energy_dir,
|
| 92 |
+
"statistics.json",
|
| 93 |
+
)
|
| 94 |
+
) as f:
|
| 95 |
+
stats = json.load(f)
|
| 96 |
+
stats = stats[cfg.dataset[0] + "_" + cfg.dataset[0]]
|
| 97 |
+
mean, std = (
|
| 98 |
+
stats["voiced_positions"]["mean"],
|
| 99 |
+
stats["voiced_positions"]["std"],
|
| 100 |
+
)
|
| 101 |
+
energy_min = (stats["total_positions"]["min"] - mean) / std
|
| 102 |
+
energy_max = (stats["total_positions"]["max"] - mean) / std
|
| 103 |
+
|
| 104 |
+
with open(
|
| 105 |
+
os.path.join(
|
| 106 |
+
cfg.preprocess.processed_dir,
|
| 107 |
+
cfg.dataset[0],
|
| 108 |
+
self.pitch_dir,
|
| 109 |
+
"statistics.json",
|
| 110 |
+
)
|
| 111 |
+
) as f:
|
| 112 |
+
stats = json.load(f)
|
| 113 |
+
stats = stats[cfg.dataset[0] + "_" + cfg.dataset[0]]
|
| 114 |
+
mean, std = (
|
| 115 |
+
stats["voiced_positions"]["mean"],
|
| 116 |
+
stats["voiced_positions"]["std"],
|
| 117 |
+
)
|
| 118 |
+
pitch_min = (stats["total_positions"]["min"] - mean) / std
|
| 119 |
+
pitch_max = (stats["total_positions"]["max"] - mean) / std
|
| 120 |
+
|
| 121 |
+
if pitch_quantization == "log":
|
| 122 |
+
self.pitch_bins = nn.Parameter(
|
| 123 |
+
torch.exp(
|
| 124 |
+
torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1)
|
| 125 |
+
),
|
| 126 |
+
requires_grad=False,
|
| 127 |
+
)
|
| 128 |
+
else:
|
| 129 |
+
self.pitch_bins = nn.Parameter(
|
| 130 |
+
torch.linspace(pitch_min, pitch_max, n_bins - 1),
|
| 131 |
+
requires_grad=False,
|
| 132 |
+
)
|
| 133 |
+
if energy_quantization == "log":
|
| 134 |
+
self.energy_bins = nn.Parameter(
|
| 135 |
+
torch.exp(
|
| 136 |
+
torch.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1)
|
| 137 |
+
),
|
| 138 |
+
requires_grad=False,
|
| 139 |
+
)
|
| 140 |
+
else:
|
| 141 |
+
self.energy_bins = nn.Parameter(
|
| 142 |
+
torch.linspace(energy_min, energy_max, n_bins - 1),
|
| 143 |
+
requires_grad=False,
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
self.pitch_embedding = nn.Embedding(
|
| 147 |
+
n_bins, cfg.model.transformer.encoder_hidden
|
| 148 |
+
)
|
| 149 |
+
self.energy_embedding = nn.Embedding(
|
| 150 |
+
n_bins, cfg.model.transformer.encoder_hidden
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
def get_pitch_embedding(self, x, target, mask, control):
|
| 154 |
+
prediction = self.pitch_predictor(x, mask)
|
| 155 |
+
if target is not None:
|
| 156 |
+
embedding = self.pitch_embedding(torch.bucketize(target, self.pitch_bins))
|
| 157 |
+
else:
|
| 158 |
+
prediction = prediction * control
|
| 159 |
+
embedding = self.pitch_embedding(
|
| 160 |
+
torch.bucketize(prediction, self.pitch_bins)
|
| 161 |
+
)
|
| 162 |
+
return prediction, embedding
|
| 163 |
+
|
| 164 |
+
def get_energy_embedding(self, x, target, mask, control):
|
| 165 |
+
prediction = self.energy_predictor(x, mask)
|
| 166 |
+
if target is not None:
|
| 167 |
+
embedding = self.energy_embedding(torch.bucketize(target, self.energy_bins))
|
| 168 |
+
else:
|
| 169 |
+
prediction = prediction * control
|
| 170 |
+
embedding = self.energy_embedding(
|
| 171 |
+
torch.bucketize(prediction, self.energy_bins)
|
| 172 |
+
)
|
| 173 |
+
return prediction, embedding
|
| 174 |
+
|
| 175 |
+
def forward(
|
| 176 |
+
self,
|
| 177 |
+
x,
|
| 178 |
+
src_mask,
|
| 179 |
+
mel_mask=None,
|
| 180 |
+
max_len=None,
|
| 181 |
+
pitch_target=None,
|
| 182 |
+
energy_target=None,
|
| 183 |
+
duration_target=None,
|
| 184 |
+
p_control=1.0,
|
| 185 |
+
e_control=1.0,
|
| 186 |
+
d_control=1.0,
|
| 187 |
+
):
|
| 188 |
+
log_duration_prediction = self.duration_predictor(x, src_mask)
|
| 189 |
+
if self.pitch_feature_level == "phoneme_level":
|
| 190 |
+
pitch_prediction, pitch_embedding = self.get_pitch_embedding(
|
| 191 |
+
x, pitch_target, src_mask, p_control
|
| 192 |
+
)
|
| 193 |
+
x = x + pitch_embedding
|
| 194 |
+
if self.energy_feature_level == "phoneme_level":
|
| 195 |
+
energy_prediction, energy_embedding = self.get_energy_embedding(
|
| 196 |
+
x, energy_target, src_mask, e_control
|
| 197 |
+
)
|
| 198 |
+
x = x + energy_embedding
|
| 199 |
+
|
| 200 |
+
if duration_target is not None:
|
| 201 |
+
x, mel_len = self.length_regulator(x, duration_target, max_len)
|
| 202 |
+
duration_rounded = duration_target
|
| 203 |
+
else:
|
| 204 |
+
duration_rounded = torch.clamp(
|
| 205 |
+
(torch.round(torch.exp(log_duration_prediction) - 1) * d_control),
|
| 206 |
+
min=0,
|
| 207 |
+
)
|
| 208 |
+
x, mel_len = self.length_regulator(x, duration_rounded, max_len)
|
| 209 |
+
mel_mask = get_mask_from_lengths(mel_len)
|
| 210 |
+
|
| 211 |
+
if self.pitch_feature_level == "frame_level":
|
| 212 |
+
pitch_prediction, pitch_embedding = self.get_pitch_embedding(
|
| 213 |
+
x, pitch_target, mel_mask, p_control
|
| 214 |
+
)
|
| 215 |
+
x = x + pitch_embedding
|
| 216 |
+
if self.energy_feature_level == "frame_level":
|
| 217 |
+
energy_prediction, energy_embedding = self.get_energy_embedding(
|
| 218 |
+
x, energy_target, mel_mask, p_control
|
| 219 |
+
)
|
| 220 |
+
x = x + energy_embedding
|
| 221 |
+
|
| 222 |
+
return (
|
| 223 |
+
x,
|
| 224 |
+
pitch_prediction,
|
| 225 |
+
energy_prediction,
|
| 226 |
+
log_duration_prediction,
|
| 227 |
+
duration_rounded,
|
| 228 |
+
mel_len,
|
| 229 |
+
mel_mask,
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class LengthRegulator(nn.Module):
|
| 234 |
+
"""Length Regulator"""
|
| 235 |
+
|
| 236 |
+
def __init__(self):
|
| 237 |
+
super(LengthRegulator, self).__init__()
|
| 238 |
+
|
| 239 |
+
def LR(self, x, duration, max_len):
|
| 240 |
+
device = x.device
|
| 241 |
+
output = list()
|
| 242 |
+
mel_len = list()
|
| 243 |
+
for batch, expand_target in zip(x, duration):
|
| 244 |
+
expanded = self.expand(batch, expand_target)
|
| 245 |
+
output.append(expanded)
|
| 246 |
+
mel_len.append(expanded.shape[0])
|
| 247 |
+
|
| 248 |
+
if max_len is not None:
|
| 249 |
+
output = pad(output, max_len)
|
| 250 |
+
else:
|
| 251 |
+
output = pad(output)
|
| 252 |
+
|
| 253 |
+
return output, torch.LongTensor(mel_len).to(device)
|
| 254 |
+
|
| 255 |
+
def expand(self, batch, predicted):
|
| 256 |
+
out = list()
|
| 257 |
+
|
| 258 |
+
for i, vec in enumerate(batch):
|
| 259 |
+
expand_size = predicted[i].item()
|
| 260 |
+
out.append(vec.expand(max(int(expand_size), 0), -1))
|
| 261 |
+
out = torch.cat(out, 0)
|
| 262 |
+
|
| 263 |
+
return out
|
| 264 |
+
|
| 265 |
+
def forward(self, x, duration, max_len):
|
| 266 |
+
output, mel_len = self.LR(x, duration, max_len)
|
| 267 |
+
return output, mel_len
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class VariancePredictor(nn.Module):
|
| 271 |
+
"""Duration, Pitch and Energy Predictor"""
|
| 272 |
+
|
| 273 |
+
def __init__(self, cfg):
|
| 274 |
+
super(VariancePredictor, self).__init__()
|
| 275 |
+
|
| 276 |
+
self.input_size = cfg.model.transformer.encoder_hidden
|
| 277 |
+
self.filter_size = cfg.model.variance_predictor.filter_size
|
| 278 |
+
self.kernel = cfg.model.variance_predictor.kernel_size
|
| 279 |
+
self.conv_output_size = cfg.model.variance_predictor.filter_size
|
| 280 |
+
self.dropout = cfg.model.variance_predictor.dropout
|
| 281 |
+
|
| 282 |
+
self.conv_layer = nn.Sequential(
|
| 283 |
+
OrderedDict(
|
| 284 |
+
[
|
| 285 |
+
(
|
| 286 |
+
"conv1d_1",
|
| 287 |
+
Conv(
|
| 288 |
+
self.input_size,
|
| 289 |
+
self.filter_size,
|
| 290 |
+
kernel_size=self.kernel,
|
| 291 |
+
padding=(self.kernel - 1) // 2,
|
| 292 |
+
),
|
| 293 |
+
),
|
| 294 |
+
("relu_1", nn.ReLU()),
|
| 295 |
+
("layer_norm_1", nn.LayerNorm(self.filter_size)),
|
| 296 |
+
("dropout_1", nn.Dropout(self.dropout)),
|
| 297 |
+
(
|
| 298 |
+
"conv1d_2",
|
| 299 |
+
Conv(
|
| 300 |
+
self.filter_size,
|
| 301 |
+
self.filter_size,
|
| 302 |
+
kernel_size=self.kernel,
|
| 303 |
+
padding=1,
|
| 304 |
+
),
|
| 305 |
+
),
|
| 306 |
+
("relu_2", nn.ReLU()),
|
| 307 |
+
("layer_norm_2", nn.LayerNorm(self.filter_size)),
|
| 308 |
+
("dropout_2", nn.Dropout(self.dropout)),
|
| 309 |
+
]
|
| 310 |
+
)
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
self.linear_layer = nn.Linear(self.conv_output_size, 1)
|
| 314 |
+
|
| 315 |
+
def forward(self, encoder_output, mask):
|
| 316 |
+
out = self.conv_layer(encoder_output)
|
| 317 |
+
out = self.linear_layer(out)
|
| 318 |
+
out = out.squeeze(-1)
|
| 319 |
+
|
| 320 |
+
if mask is not None:
|
| 321 |
+
out = out.masked_fill(mask, 0.0)
|
| 322 |
+
|
| 323 |
+
return out
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class Conv(nn.Module):
|
| 327 |
+
"""
|
| 328 |
+
Convolution Module
|
| 329 |
+
"""
|
| 330 |
+
|
| 331 |
+
def __init__(
|
| 332 |
+
self,
|
| 333 |
+
in_channels,
|
| 334 |
+
out_channels,
|
| 335 |
+
kernel_size=1,
|
| 336 |
+
stride=1,
|
| 337 |
+
padding=0,
|
| 338 |
+
dilation=1,
|
| 339 |
+
bias=True,
|
| 340 |
+
w_init="linear",
|
| 341 |
+
):
|
| 342 |
+
"""
|
| 343 |
+
:param in_channels: dimension of input
|
| 344 |
+
:param out_channels: dimension of output
|
| 345 |
+
:param kernel_size: size of kernel
|
| 346 |
+
:param stride: size of stride
|
| 347 |
+
:param padding: size of padding
|
| 348 |
+
:param dilation: dilation rate
|
| 349 |
+
:param bias: boolean. if True, bias is included.
|
| 350 |
+
:param w_init: str. weight inits with xavier initialization.
|
| 351 |
+
"""
|
| 352 |
+
super(Conv, self).__init__()
|
| 353 |
+
|
| 354 |
+
self.conv = nn.Conv1d(
|
| 355 |
+
in_channels,
|
| 356 |
+
out_channels,
|
| 357 |
+
kernel_size=kernel_size,
|
| 358 |
+
stride=stride,
|
| 359 |
+
padding=padding,
|
| 360 |
+
dilation=dilation,
|
| 361 |
+
bias=bias,
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
def forward(self, x):
|
| 365 |
+
x = x.contiguous().transpose(1, 2)
|
| 366 |
+
x = self.conv(x)
|
| 367 |
+
x = x.contiguous().transpose(1, 2)
|
| 368 |
+
|
| 369 |
+
return x
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class FastSpeech2(nn.Module):
|
| 373 |
+
def __init__(self, cfg) -> None:
|
| 374 |
+
super(FastSpeech2, self).__init__()
|
| 375 |
+
self.cfg = cfg
|
| 376 |
+
self.encoder = Encoder(cfg.model)
|
| 377 |
+
self.variance_adaptor = VarianceAdaptor(cfg)
|
| 378 |
+
self.decoder = Decoder(cfg.model)
|
| 379 |
+
self.mel_linear = nn.Linear(
|
| 380 |
+
cfg.model.transformer.decoder_hidden,
|
| 381 |
+
cfg.preprocess.n_mel,
|
| 382 |
+
)
|
| 383 |
+
self.postnet = PostNet(n_mel_channels=cfg.preprocess.n_mel)
|
| 384 |
+
|
| 385 |
+
self.speaker_emb = None
|
| 386 |
+
if cfg.train.multi_speaker_training:
|
| 387 |
+
with open(
|
| 388 |
+
os.path.join(
|
| 389 |
+
cfg.preprocess.processed_dir, cfg.dataset[0], "spk2id.json"
|
| 390 |
+
),
|
| 391 |
+
"r",
|
| 392 |
+
) as f:
|
| 393 |
+
n_speaker = len(json.load(f))
|
| 394 |
+
self.speaker_emb = nn.Embedding(
|
| 395 |
+
n_speaker,
|
| 396 |
+
cfg.model.transformer.encoder_hidden,
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
def forward(self, data, p_control=1.0, e_control=1.0, d_control=1.0):
|
| 400 |
+
speakers = data["spk_id"]
|
| 401 |
+
texts = data["texts"]
|
| 402 |
+
src_lens = data["text_len"]
|
| 403 |
+
max_src_len = max(src_lens)
|
| 404 |
+
mel_lens = data["target_len"] if "target_len" in data else None
|
| 405 |
+
max_mel_len = max(mel_lens) if "target_len" in data else None
|
| 406 |
+
p_targets = data["pitch"] if "pitch" in data else None
|
| 407 |
+
e_targets = data["energy"] if "energy" in data else None
|
| 408 |
+
d_targets = data["durations"] if "durations" in data else None
|
| 409 |
+
src_masks = get_mask_from_lengths(src_lens, max_src_len)
|
| 410 |
+
mel_masks = (
|
| 411 |
+
get_mask_from_lengths(mel_lens, max_mel_len)
|
| 412 |
+
if mel_lens is not None
|
| 413 |
+
else None
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
output = self.encoder(texts, src_masks)
|
| 417 |
+
|
| 418 |
+
if self.speaker_emb is not None:
|
| 419 |
+
output = output + self.speaker_emb(speakers).unsqueeze(1).expand(
|
| 420 |
+
-1, max_src_len, -1
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
(
|
| 424 |
+
output,
|
| 425 |
+
p_predictions,
|
| 426 |
+
e_predictions,
|
| 427 |
+
log_d_predictions,
|
| 428 |
+
d_rounded,
|
| 429 |
+
mel_lens,
|
| 430 |
+
mel_masks,
|
| 431 |
+
) = self.variance_adaptor(
|
| 432 |
+
output,
|
| 433 |
+
src_masks,
|
| 434 |
+
mel_masks,
|
| 435 |
+
max_mel_len,
|
| 436 |
+
p_targets,
|
| 437 |
+
e_targets,
|
| 438 |
+
d_targets,
|
| 439 |
+
p_control,
|
| 440 |
+
e_control,
|
| 441 |
+
d_control,
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
output, mel_masks = self.decoder(output, mel_masks)
|
| 445 |
+
output = self.mel_linear(output)
|
| 446 |
+
|
| 447 |
+
postnet_output = self.postnet(output) + output
|
| 448 |
+
|
| 449 |
+
return {
|
| 450 |
+
"output": output,
|
| 451 |
+
"postnet_output": postnet_output,
|
| 452 |
+
"p_predictions": p_predictions,
|
| 453 |
+
"e_predictions": e_predictions,
|
| 454 |
+
"log_d_predictions": log_d_predictions,
|
| 455 |
+
"d_rounded": d_rounded,
|
| 456 |
+
"src_masks": src_masks,
|
| 457 |
+
"mel_masks": mel_masks,
|
| 458 |
+
"src_lens": src_lens,
|
| 459 |
+
"mel_lens": mel_lens,
|
| 460 |
+
}
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
class FastSpeech2Loss(nn.Module):
|
| 464 |
+
"""FastSpeech2 Loss"""
|
| 465 |
+
|
| 466 |
+
def __init__(self, cfg):
|
| 467 |
+
super(FastSpeech2Loss, self).__init__()
|
| 468 |
+
if cfg.preprocess.use_frame_pitch:
|
| 469 |
+
self.pitch_feature_level = "frame_level"
|
| 470 |
+
else:
|
| 471 |
+
self.pitch_feature_level = "phoneme_level"
|
| 472 |
+
|
| 473 |
+
if cfg.preprocess.use_frame_energy:
|
| 474 |
+
self.energy_feature_level = "frame_level"
|
| 475 |
+
else:
|
| 476 |
+
self.energy_feature_level = "phoneme_level"
|
| 477 |
+
|
| 478 |
+
self.mse_loss = nn.MSELoss()
|
| 479 |
+
self.mae_loss = nn.L1Loss()
|
| 480 |
+
|
| 481 |
+
def forward(self, data, predictions):
|
| 482 |
+
mel_targets = data["mel"]
|
| 483 |
+
pitch_targets = data["pitch"].float()
|
| 484 |
+
energy_targets = data["energy"].float()
|
| 485 |
+
duration_targets = data["durations"]
|
| 486 |
+
|
| 487 |
+
mel_predictions = predictions["output"]
|
| 488 |
+
postnet_mel_predictions = predictions["postnet_output"]
|
| 489 |
+
pitch_predictions = predictions["p_predictions"]
|
| 490 |
+
energy_predictions = predictions["e_predictions"]
|
| 491 |
+
log_duration_predictions = predictions["log_d_predictions"]
|
| 492 |
+
src_masks = predictions["src_masks"]
|
| 493 |
+
mel_masks = predictions["mel_masks"]
|
| 494 |
+
|
| 495 |
+
src_masks = ~src_masks
|
| 496 |
+
mel_masks = ~mel_masks
|
| 497 |
+
|
| 498 |
+
log_duration_targets = torch.log(duration_targets.float() + 1)
|
| 499 |
+
mel_targets = mel_targets[:, : mel_masks.shape[1], :]
|
| 500 |
+
mel_masks = mel_masks[:, : mel_masks.shape[1]]
|
| 501 |
+
|
| 502 |
+
log_duration_targets.requires_grad = False
|
| 503 |
+
pitch_targets.requires_grad = False
|
| 504 |
+
energy_targets.requires_grad = False
|
| 505 |
+
mel_targets.requires_grad = False
|
| 506 |
+
|
| 507 |
+
if self.pitch_feature_level == "phoneme_level":
|
| 508 |
+
pitch_predictions = pitch_predictions.masked_select(src_masks)
|
| 509 |
+
pitch_targets = pitch_targets.masked_select(src_masks)
|
| 510 |
+
elif self.pitch_feature_level == "frame_level":
|
| 511 |
+
pitch_predictions = pitch_predictions.masked_select(mel_masks)
|
| 512 |
+
pitch_targets = pitch_targets.masked_select(mel_masks)
|
| 513 |
+
|
| 514 |
+
if self.energy_feature_level == "phoneme_level":
|
| 515 |
+
energy_predictions = energy_predictions.masked_select(src_masks)
|
| 516 |
+
energy_targets = energy_targets.masked_select(src_masks)
|
| 517 |
+
if self.energy_feature_level == "frame_level":
|
| 518 |
+
energy_predictions = energy_predictions.masked_select(mel_masks)
|
| 519 |
+
energy_targets = energy_targets.masked_select(mel_masks)
|
| 520 |
+
|
| 521 |
+
log_duration_predictions = log_duration_predictions.masked_select(src_masks)
|
| 522 |
+
log_duration_targets = log_duration_targets.masked_select(src_masks)
|
| 523 |
+
|
| 524 |
+
mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1))
|
| 525 |
+
postnet_mel_predictions = postnet_mel_predictions.masked_select(
|
| 526 |
+
mel_masks.unsqueeze(-1)
|
| 527 |
+
)
|
| 528 |
+
mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1))
|
| 529 |
+
|
| 530 |
+
mel_loss = self.mae_loss(mel_predictions, mel_targets)
|
| 531 |
+
postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets)
|
| 532 |
+
|
| 533 |
+
pitch_loss = self.mse_loss(pitch_predictions, pitch_targets)
|
| 534 |
+
energy_loss = self.mse_loss(energy_predictions, energy_targets)
|
| 535 |
+
duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets)
|
| 536 |
+
|
| 537 |
+
total_loss = (
|
| 538 |
+
mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
return {
|
| 542 |
+
"loss": total_loss,
|
| 543 |
+
"mel_loss": mel_loss,
|
| 544 |
+
"postnet_mel_loss": postnet_mel_loss,
|
| 545 |
+
"pitch_loss": pitch_loss,
|
| 546 |
+
"energy_loss": energy_loss,
|
| 547 |
+
"duration_loss": duration_loss,
|
| 548 |
+
}
|
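Two small utilities carry much of the FastSpeech2 plumbing above: the boolean padding mask produced by `get_mask_from_lengths()` and the per-phoneme frame expansion performed inside `LengthRegulator`. The sketch below reproduces both ideas on made-up lengths and durations.

```python
# Sketch of the padding mask and the duration-driven expansion used above.
# Lengths, durations and hidden size are illustrative only.
import torch

lengths = torch.tensor([3, 5])
max_len = int(lengths.max())
ids = torch.arange(max_len).unsqueeze(0).expand(len(lengths), -1)
mask = ids >= lengths.unsqueeze(1)     # True where the position is padding
print(mask)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])

phones = torch.randn(3, 4)             # 3 phoneme encodings, hidden size 4
durations = torch.tensor([2, 1, 3])    # frames each phoneme should cover
frames = torch.cat(
    [phones[i].expand(int(d), -1) for i, d in enumerate(durations)], dim=0
)
print(frames.shape)                    # torch.Size([6, 4]) == durations.sum()
```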
Amphion/models/tts/fastspeech2/fs2_inference.py
ADDED
|
@@ -0,0 +1,193 @@
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import torch
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from collections import OrderedDict
|
| 10 |
+
|
| 11 |
+
from models.tts.base.tts_inferece import TTSInference
|
| 12 |
+
from models.tts.fastspeech2.fs2_dataset import FS2TestDataset, FS2TestCollator
|
| 13 |
+
from utils.util import load_config
|
| 14 |
+
from utils.io import save_audio
|
| 15 |
+
from models.tts.fastspeech2.fs2 import FastSpeech2
|
| 16 |
+
from models.vocoders.vocoder_inference import synthesis
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from processors.phone_extractor import phoneExtractor
|
| 19 |
+
from text.text_token_collation import phoneIDCollation
|
| 20 |
+
import numpy as np
|
| 21 |
+
import json
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class FastSpeech2Inference(TTSInference):
|
| 25 |
+
def __init__(self, args, cfg):
|
| 26 |
+
TTSInference.__init__(self, args, cfg)
|
| 27 |
+
self.args = args
|
| 28 |
+
self.cfg = cfg
|
| 29 |
+
self.infer_type = args.mode
|
| 30 |
+
|
| 31 |
+
def _build_model(self):
|
| 32 |
+
self.model = FastSpeech2(self.cfg)
|
| 33 |
+
return self.model
|
| 34 |
+
|
| 35 |
+
def load_model(self, state_dict):
|
| 36 |
+
raw_dict = state_dict["model"]
|
| 37 |
+
clean_dict = OrderedDict()
|
| 38 |
+
for k, v in raw_dict.items():
|
| 39 |
+
if k.startswith("module."):
|
| 40 |
+
clean_dict[k[7:]] = v
|
| 41 |
+
else:
|
| 42 |
+
clean_dict[k] = v
|
| 43 |
+
|
| 44 |
+
self.model.load_state_dict(clean_dict)
|
| 45 |
+
|
| 46 |
+
def _build_test_dataset(self):
|
| 47 |
+
return FS2TestDataset, FS2TestCollator
|
| 48 |
+
|
| 49 |
+
@staticmethod
|
| 50 |
+
def _parse_vocoder(vocoder_dir):
|
| 51 |
+
r"""Parse vocoder config"""
|
| 52 |
+
vocoder_dir = os.path.abspath(vocoder_dir)
|
| 53 |
+
ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
|
| 54 |
+
# last step (different from the base *int(x.stem)*)
|
| 55 |
+
ckpt_list.sort(
|
| 56 |
+
key=lambda x: int(x.stem.split("_")[-2].split("-")[-1]), reverse=True
|
| 57 |
+
)
|
| 58 |
+
ckpt_path = str(ckpt_list[0])
|
| 59 |
+
vocoder_cfg = load_config(
|
| 60 |
+
os.path.join(vocoder_dir, "args.json"), lowercase=True
|
| 61 |
+
)
|
| 62 |
+
return vocoder_cfg, ckpt_path
|
| 63 |
+
|
| 64 |
+
@torch.inference_mode()
|
| 65 |
+
def inference_for_batches(self):
|
| 66 |
+
y_pred = []
|
| 67 |
+
for i, batch in tqdm(enumerate(self.test_dataloader)):
|
| 68 |
+
y_pred, mel_lens, _ = self._inference_each_batch(batch)
|
| 69 |
+
y_ls = y_pred.chunk(self.test_batch_size)
|
| 70 |
+
tgt_ls = mel_lens.chunk(self.test_batch_size)
|
| 71 |
+
j = 0
|
| 72 |
+
for it, l in zip(y_ls, tgt_ls):
|
| 73 |
+
l = l.item()
|
| 74 |
+
it = it.squeeze(0)[:l].detach().cpu()
|
| 75 |
+
|
| 76 |
+
uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
|
| 77 |
+
torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
|
| 78 |
+
j += 1
|
| 79 |
+
|
| 80 |
+
vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
|
| 81 |
+
res = synthesis(
|
| 82 |
+
cfg=vocoder_cfg,
|
| 83 |
+
vocoder_weight_file=vocoder_ckpt,
|
| 84 |
+
n_samples=None,
|
| 85 |
+
pred=[
|
| 86 |
+
torch.load(
|
| 87 |
+
os.path.join(self.args.output_dir, "{}.pt".format(item["Uid"]))
|
| 88 |
+
).numpy()
|
| 89 |
+
for item in self.test_dataset.metadata
|
| 90 |
+
],
|
| 91 |
+
)
|
| 92 |
+
for it, wav in zip(self.test_dataset.metadata, res):
|
| 93 |
+
uid = it["Uid"]
|
| 94 |
+
save_audio(
|
| 95 |
+
os.path.join(self.args.output_dir, f"{uid}.wav"),
|
| 96 |
+
wav.numpy(),
|
| 97 |
+
self.cfg.preprocess.sample_rate,
|
| 98 |
+
add_silence=True,
|
| 99 |
+
turn_up=True,
|
| 100 |
+
)
|
| 101 |
+
os.remove(os.path.join(self.args.output_dir, f"{uid}.pt"))
|
| 102 |
+
|
| 103 |
+
@torch.inference_mode()
|
| 104 |
+
def _inference_each_batch(self, batch_data):
|
| 105 |
+
device = self.accelerator.device
|
| 106 |
+
control_values = (
|
| 107 |
+
self.args.pitch_control,
|
| 108 |
+
self.args.energy_control,
|
| 109 |
+
self.args.duration_control,
|
| 110 |
+
)
|
| 111 |
+
for k, v in batch_data.items():
|
| 112 |
+
batch_data[k] = v.to(device)
|
| 113 |
+
|
| 114 |
+
pitch_control, energy_control, duration_control = control_values
|
| 115 |
+
|
| 116 |
+
output = self.model(
|
| 117 |
+
batch_data,
|
| 118 |
+
p_control=pitch_control,
|
| 119 |
+
e_control=energy_control,
|
| 120 |
+
d_control=duration_control,
|
| 121 |
+
)
|
| 122 |
+
pred_res = output["postnet_output"]
|
| 123 |
+
mel_lens = output["mel_lens"].cpu()
|
| 124 |
+
return pred_res, mel_lens, 0
|
| 125 |
+
|
| 126 |
+
def inference_for_single_utterance(self):
|
| 127 |
+
text = self.args.text
|
| 128 |
+
control_values = (
|
| 129 |
+
self.args.pitch_control,
|
| 130 |
+
self.args.energy_control,
|
| 131 |
+
self.args.duration_control,
|
| 132 |
+
)
|
| 133 |
+
pitch_control, energy_control, duration_control = control_values
|
| 134 |
+
|
| 135 |
+
# get phone symbol file
|
| 136 |
+
phone_symbol_file = None
|
| 137 |
+
if self.cfg.preprocess.phone_extractor != "lexicon":
|
| 138 |
+
phone_symbol_file = os.path.join(
|
| 139 |
+
self.exp_dir, self.cfg.preprocess.symbols_dict
|
| 140 |
+
)
|
| 141 |
+
assert os.path.exists(phone_symbol_file)
|
| 142 |
+
# convert text to phone sequence
|
| 143 |
+
phone_extractor = phoneExtractor(self.cfg)
|
| 144 |
+
|
| 145 |
+
phone_seq = phone_extractor.extract_phone(text) # phone_seq: list
|
| 146 |
+
# convert phone sequence to phone id sequence
|
| 147 |
+
phon_id_collator = phoneIDCollation(
|
| 148 |
+
self.cfg, symbols_dict_file=phone_symbol_file
|
| 149 |
+
)
|
| 150 |
+
phone_seq = ["{"] + phone_seq + ["}"]
|
| 151 |
+
phone_id_seq = phon_id_collator.get_phone_id_sequence(self.cfg, phone_seq)
|
| 152 |
+
|
| 153 |
+
# convert phone sequence to phone id sequence
|
| 154 |
+
phone_id_seq = np.array(phone_id_seq)
|
| 155 |
+
phone_id_seq = torch.from_numpy(phone_id_seq)
|
| 156 |
+
|
| 157 |
+
# get speaker id if multi-speaker training and use speaker id
|
| 158 |
+
speaker_id = None
|
| 159 |
+
if self.cfg.preprocess.use_spkid and self.cfg.train.multi_speaker_training:
|
| 160 |
+
spk2id_file = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
|
| 161 |
+
with open(spk2id_file, "r") as f:
|
| 162 |
+
spk2id = json.load(f)
|
| 163 |
+
speaker_id = spk2id[self.args.speaker_name]
|
| 164 |
+
speaker_id = torch.from_numpy(np.array([speaker_id], dtype=np.int32))
|
| 165 |
+
else:
|
| 166 |
+
speaker_id = torch.Tensor(0).view(-1)
|
| 167 |
+
|
| 168 |
+
with torch.no_grad():
|
| 169 |
+
x_tst = phone_id_seq.to(self.device).unsqueeze(0)
|
| 170 |
+
x_tst_lengths = torch.LongTensor([phone_id_seq.size(0)]).to(self.device)
|
| 171 |
+
if speaker_id is not None:
|
| 172 |
+
speaker_id = speaker_id.to(self.device)
|
| 173 |
+
|
| 174 |
+
data = {}
|
| 175 |
+
data["texts"] = x_tst
|
| 176 |
+
data["text_len"] = x_tst_lengths
|
| 177 |
+
data["spk_id"] = speaker_id
|
| 178 |
+
|
| 179 |
+
output = self.model(
|
| 180 |
+
data,
|
| 181 |
+
p_control=pitch_control,
|
| 182 |
+
e_control=energy_control,
|
| 183 |
+
d_control=duration_control,
|
| 184 |
+
)
|
| 185 |
+
pred_res = output["postnet_output"]
|
| 186 |
+
vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
|
| 187 |
+
audio = synthesis(
|
| 188 |
+
cfg=vocoder_cfg,
|
| 189 |
+
vocoder_weight_file=vocoder_ckpt,
|
| 190 |
+
n_samples=None,
|
| 191 |
+
pred=pred_res,
|
| 192 |
+
)
|
| 193 |
+
return audio[0]
|
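`load_model()` above strips the `module.` prefix that a DistributedDataParallel-wrapped model leaves on checkpoint keys before calling `load_state_dict()`. A tiny sketch of that cleanup on a fabricated state dict:

```python
# Sketch of the "module." prefix cleanup; the keys and values are made up.
from collections import OrderedDict

raw_dict = {"module.mel_linear.weight": 0, "postnet.conv.bias": 1}
clean_dict = OrderedDict(
    (k[7:] if k.startswith("module.") else k, v) for k, v in raw_dict.items()
)
print(list(clean_dict))   # ['mel_linear.weight', 'postnet.conv.bias']
```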
Amphion/models/tts/naturalspeech2/__init__.py
ADDED
|
File without changes
|
Amphion/models/tts/naturalspeech2/wavenet.py
ADDED
|
@@ -0,0 +1,206 @@
|
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class FiLM(nn.Module):
|
| 14 |
+
def __init__(self, in_dim, cond_dim):
|
| 15 |
+
super().__init__()
|
| 16 |
+
|
| 17 |
+
self.gain = Linear(cond_dim, in_dim)
|
| 18 |
+
self.bias = Linear(cond_dim, in_dim)
|
| 19 |
+
|
| 20 |
+
nn.init.xavier_uniform_(self.gain.weight)
|
| 21 |
+
nn.init.constant_(self.gain.bias, 1)
|
| 22 |
+
|
| 23 |
+
nn.init.xavier_uniform_(self.bias.weight)
|
| 24 |
+
nn.init.constant_(self.bias.bias, 0)
|
| 25 |
+
|
| 26 |
+
def forward(self, x, condition):
|
| 27 |
+
gain = self.gain(condition)
|
| 28 |
+
bias = self.bias(condition)
|
| 29 |
+
if gain.dim() == 2:
|
| 30 |
+
gain = gain.unsqueeze(-1)
|
| 31 |
+
if bias.dim() == 2:
|
| 32 |
+
bias = bias.unsqueeze(-1)
|
| 33 |
+
return x * gain + bias
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Mish(nn.Module):
|
| 37 |
+
def forward(self, x):
|
| 38 |
+
return x * torch.tanh(F.softplus(x))
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def Conv1d(*args, **kwargs):
|
| 42 |
+
layer = nn.Conv1d(*args, **kwargs)
|
| 43 |
+
nn.init.kaiming_normal_(layer.weight)
|
| 44 |
+
return layer
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def Linear(*args, **kwargs):
|
| 48 |
+
layer = nn.Linear(*args, **kwargs)
|
| 49 |
+
layer.weight.data.normal_(0.0, 0.02)
|
| 50 |
+
return layer
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class SinusoidalPosEmb(nn.Module):
|
| 54 |
+
def __init__(self, dim):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.dim = dim
|
| 57 |
+
|
| 58 |
+
def forward(self, x):
|
| 59 |
+
device = x.device
|
| 60 |
+
half_dim = self.dim // 2
|
| 61 |
+
emb = math.log(10000) / (half_dim - 1)
|
| 62 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
| 63 |
+
emb = x[:, None] * emb[None, :]
|
| 64 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
| 65 |
+
return emb
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class ResidualBlock(nn.Module):
|
| 69 |
+
def __init__(self, hidden_dim, attn_head, dilation, drop_out, has_cattn=False):
|
| 70 |
+
super().__init__()
|
| 71 |
+
|
| 72 |
+
self.hidden_dim = hidden_dim
|
| 73 |
+
self.dilation = dilation
|
| 74 |
+
self.has_cattn = has_cattn
|
| 75 |
+
self.attn_head = attn_head
|
| 76 |
+
self.drop_out = drop_out
|
| 77 |
+
|
| 78 |
+
self.dilated_conv = Conv1d(
|
| 79 |
+
hidden_dim, 2 * hidden_dim, 3, padding=dilation, dilation=dilation
|
| 80 |
+
)
|
| 81 |
+
self.diffusion_proj = Linear(hidden_dim, hidden_dim)
|
| 82 |
+
|
| 83 |
+
self.cond_proj = Conv1d(hidden_dim, hidden_dim * 2, 1)
|
| 84 |
+
self.out_proj = Conv1d(hidden_dim, hidden_dim * 2, 1)
|
| 85 |
+
|
| 86 |
+
if self.has_cattn:
|
| 87 |
+
self.attn = nn.MultiheadAttention(
|
| 88 |
+
hidden_dim, attn_head, 0.1, batch_first=True
|
| 89 |
+
)
|
| 90 |
+
self.film = FiLM(hidden_dim * 2, hidden_dim)
|
| 91 |
+
|
| 92 |
+
self.ln = nn.LayerNorm(hidden_dim)
|
| 93 |
+
|
| 94 |
+
self.dropout = nn.Dropout(self.drop_out)
|
| 95 |
+
|
| 96 |
+
def forward(self, x, x_mask, cond, diffusion_step, spk_query_emb):
|
| 97 |
+
diffusion_step = self.diffusion_proj(diffusion_step).unsqueeze(-1) # (B, d, 1)
|
| 98 |
+
cond = self.cond_proj(cond) # (B, 2*d, T)
|
| 99 |
+
|
| 100 |
+
y = x + diffusion_step
|
| 101 |
+
if x_mask != None:
|
| 102 |
+
y = y * x_mask.to(y.dtype)[:, None, :] # (B, 2*d, T)
|
| 103 |
+
|
| 104 |
+
if self.has_cattn:
|
| 105 |
+
y_ = y.transpose(1, 2)
|
| 106 |
+
y_ = self.ln(y_)
|
| 107 |
+
|
| 108 |
+
y_, _ = self.attn(y_, spk_query_emb, spk_query_emb) # (B, T, d)
|
| 109 |
+
|
| 110 |
+
y = self.dilated_conv(y) + cond # (B, 2*d, T)
|
| 111 |
+
|
| 112 |
+
if self.has_cattn:
|
| 113 |
+
y = self.film(y.transpose(1, 2), y_) # (B, T, 2*d)
|
| 114 |
+
y = y.transpose(1, 2) # (B, 2*d, T)
|
| 115 |
+
|
| 116 |
+
gate, filter_ = torch.chunk(y, 2, dim=1)
|
| 117 |
+
y = torch.sigmoid(gate) * torch.tanh(filter_)
|
| 118 |
+
|
| 119 |
+
y = self.out_proj(y)
|
| 120 |
+
|
| 121 |
+
residual, skip = torch.chunk(y, 2, dim=1)
|
| 122 |
+
|
| 123 |
+
if x_mask != None:
|
| 124 |
+
residual = residual * x_mask.to(y.dtype)[:, None, :]
|
| 125 |
+
skip = skip * x_mask.to(y.dtype)[:, None, :]
|
| 126 |
+
|
| 127 |
+
return (x + residual) / math.sqrt(2.0), skip
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class WaveNet(nn.Module):
|
| 131 |
+
def __init__(self, cfg):
|
| 132 |
+
super().__init__()
|
| 133 |
+
|
| 134 |
+
self.cfg = cfg
|
| 135 |
+
self.in_dim = cfg.input_size
|
| 136 |
+
self.hidden_dim = cfg.hidden_size
|
| 137 |
+
self.out_dim = cfg.out_size
|
| 138 |
+
self.num_layers = cfg.num_layers
|
| 139 |
+
self.cross_attn_per_layer = cfg.cross_attn_per_layer
|
| 140 |
+
self.dilation_cycle = cfg.dilation_cycle
|
| 141 |
+
self.attn_head = cfg.attn_head
|
| 142 |
+
self.drop_out = cfg.drop_out
|
| 143 |
+
|
| 144 |
+
self.in_proj = Conv1d(self.in_dim, self.hidden_dim, 1)
|
| 145 |
+
self.diffusion_embedding = SinusoidalPosEmb(self.hidden_dim)
|
| 146 |
+
|
| 147 |
+
self.mlp = nn.Sequential(
|
| 148 |
+
Linear(self.hidden_dim, self.hidden_dim * 4),
|
| 149 |
+
Mish(),
|
| 150 |
+
Linear(self.hidden_dim * 4, self.hidden_dim),
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
self.cond_ln = nn.LayerNorm(self.hidden_dim)
|
| 154 |
+
|
| 155 |
+
self.layers = nn.ModuleList(
|
| 156 |
+
[
|
| 157 |
+
ResidualBlock(
|
| 158 |
+
self.hidden_dim,
|
| 159 |
+
self.attn_head,
|
| 160 |
+
2 ** (i % self.dilation_cycle),
|
| 161 |
+
self.drop_out,
|
| 162 |
+
has_cattn=(i % self.cross_attn_per_layer == 0),
|
| 163 |
+
)
|
| 164 |
+
for i in range(self.num_layers)
|
| 165 |
+
]
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
self.skip_proj = Conv1d(self.hidden_dim, self.hidden_dim, 1)
|
| 169 |
+
self.out_proj = Conv1d(self.hidden_dim, self.out_dim, 1)
|
| 170 |
+
|
| 171 |
+
nn.init.zeros_(self.out_proj.weight)
|
| 172 |
+
|
| 173 |
+
def forward(self, x, x_mask, cond, diffusion_step, spk_query_emb):
|
| 174 |
+
"""
|
| 175 |
+
x: (B, 128, T)
|
| 176 |
+
x_mask: (B, T), mask is 0
|
| 177 |
+
cond: (B, T, 512)
|
| 178 |
+
diffusion_step: (B,)
|
| 179 |
+
spk_query_emb: (B, 32, 512)
|
| 180 |
+
"""
|
| 181 |
+
cond = self.cond_ln(cond)
|
| 182 |
+
cond_input = cond.transpose(1, 2)
|
| 183 |
+
|
| 184 |
+
x_input = self.in_proj(x)
|
| 185 |
+
|
| 186 |
+
x_input = F.relu(x_input)
|
| 187 |
+
|
| 188 |
+
diffusion_step = self.diffusion_embedding(diffusion_step).to(x.dtype)
|
| 189 |
+
diffusion_step = self.mlp(diffusion_step)
|
| 190 |
+
|
| 191 |
+
skip = []
|
| 192 |
+
for _, layer in enumerate(self.layers):
|
| 193 |
+
x_input, skip_connection = layer(
|
| 194 |
+
x_input, x_mask, cond_input, diffusion_step, spk_query_emb
|
| 195 |
+
)
|
| 196 |
+
skip.append(skip_connection)
|
| 197 |
+
|
| 198 |
+
x_input = torch.sum(torch.stack(skip), dim=0) / math.sqrt(self.num_layers)
|
| 199 |
+
|
| 200 |
+
x_out = self.skip_proj(x_input)
|
| 201 |
+
|
| 202 |
+
x_out = F.relu(x_out)
|
| 203 |
+
|
| 204 |
+
x_out = self.out_proj(x_out) # (B, 128, T)
|
| 205 |
+
|
| 206 |
+
return x_out
|
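Inside `ResidualBlock.forward()` above, the `2*d`-channel tensor produced by the dilated convolution (plus condition) is split into gate and filter halves and fused with a sigmoid/tanh gate, in the usual WaveNet style. A minimal sketch with arbitrary tensor sizes:

```python
# Sketch of the gated activation: chunk the 2*d channels, gate with sigmoid/tanh.
import torch

y = torch.randn(2, 2 * 64, 100)                  # (B, 2*d, T) after dilated conv + cond
gate, filter_ = torch.chunk(y, 2, dim=1)         # two (B, d, T) halves
out = torch.sigmoid(gate) * torch.tanh(filter_)  # gated activation, (B, d, T)
print(out.shape)                                 # torch.Size([2, 64, 100])
```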
Amphion/models/tts/valle/valle.py
ADDED
|
@@ -0,0 +1,794 @@
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# This code is modified from https://github.com/lifeiteng/vall-e/blob/main/valle/models/valle.py
|
| 7 |
+
|
| 8 |
+
import random
|
| 9 |
+
from typing import Dict, Iterator, List, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from torchmetrics.classification import MulticlassAccuracy
|
| 15 |
+
from utils.util import make_pad_mask
|
| 16 |
+
from utils.topk_sampling import topk_sampling
|
| 17 |
+
from modules.general import Transpose
|
| 18 |
+
from modules.encoder import TokenEmbedding
|
| 19 |
+
from modules.general import PromptedFeatures
|
| 20 |
+
from modules.transformer import SinePositionalEmbedding
|
| 21 |
+
from modules.norms import AdaptiveLayerNorm, LayerNorm
|
| 22 |
+
from modules.transformer.transformer import TransformerEncoder, TransformerEncoderLayer
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class VALLE(nn.Module):
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
cfg,
|
| 29 |
+
decoder_cls=TransformerEncoder,
|
| 30 |
+
decoder_layer_cls=TransformerEncoderLayer,
|
| 31 |
+
):
|
| 32 |
+
super().__init__()
|
| 33 |
+
decoder_dim = cfg.decoder_dim
|
| 34 |
+
nhead = cfg.nhead
|
| 35 |
+
nar_scale_factor = cfg.nar_scale_factor
|
| 36 |
+
num_quantizers = cfg.num_quantizers
|
| 37 |
+
num_decoder_layers = cfg.num_decoder_layers
|
| 38 |
+
nar_decoder_dim = int(decoder_dim * nar_scale_factor)
|
| 39 |
+
|
| 40 |
+
self.ar_text_embedding = TokenEmbedding(decoder_dim, cfg.text_token_num)
|
| 41 |
+
self.nar_text_embedding = TokenEmbedding(nar_decoder_dim, cfg.text_token_num)
|
| 42 |
+
|
| 43 |
+
self.ar_audio_prepend_bos = cfg.prepend_bos
|
| 44 |
+
self.ar_audio_embedding = TokenEmbedding(
|
| 45 |
+
decoder_dim, cfg.audio_token_num + 1 + int(cfg.prepend_bos)
|
| 46 |
+
)
|
| 47 |
+
self.audio_token_num = cfg.audio_token_num
|
| 48 |
+
|
| 49 |
+
# PreNet of AR
|
| 50 |
+
if cfg.add_prenet:
|
| 51 |
+
self.ar_text_prenet = nn.Sequential(
|
| 52 |
+
Transpose(),
|
| 53 |
+
nn.Conv1d(decoder_dim, decoder_dim, kernel_size=5, padding="same"),
|
| 54 |
+
nn.BatchNorm1d(decoder_dim),
|
| 55 |
+
nn.ReLU(),
|
| 56 |
+
nn.Dropout(0.5),
|
| 57 |
+
nn.Conv1d(decoder_dim, decoder_dim, kernel_size=5, padding="same"),
|
| 58 |
+
nn.BatchNorm1d(decoder_dim),
|
| 59 |
+
nn.ReLU(),
|
| 60 |
+
nn.Dropout(0.5),
|
| 61 |
+
nn.Conv1d(decoder_dim, decoder_dim, kernel_size=5, padding="same"),
|
| 62 |
+
nn.BatchNorm1d(decoder_dim),
|
| 63 |
+
nn.ReLU(),
|
| 64 |
+
nn.Dropout(0.5),
|
| 65 |
+
Transpose(),
|
| 66 |
+
nn.Linear(decoder_dim, decoder_dim),
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
self.ar_audio_prenet = nn.Sequential(
|
| 70 |
+
nn.Linear(decoder_dim, 256),
|
| 71 |
+
nn.ReLU(),
|
| 72 |
+
nn.Dropout(0.25),
|
| 73 |
+
nn.Linear(256, 256),
|
| 74 |
+
nn.ReLU(),
|
| 75 |
+
nn.Dropout(0.25),
|
| 76 |
+
nn.Linear(256, decoder_dim),
|
| 77 |
+
)
|
| 78 |
+
else:
|
| 79 |
+
self.ar_text_prenet = nn.Identity()
|
| 80 |
+
self.ar_audio_prenet = nn.Identity()
|
| 81 |
+
|
| 82 |
+
self.ar_text_position = SinePositionalEmbedding(
|
| 83 |
+
decoder_dim,
|
| 84 |
+
dropout=0.1,
|
| 85 |
+
scale=False,
|
| 86 |
+
alpha=True,
|
| 87 |
+
)
|
| 88 |
+
self.ar_audio_position = SinePositionalEmbedding(
|
| 89 |
+
decoder_dim,
|
| 90 |
+
dropout=0.1,
|
| 91 |
+
scale=False,
|
| 92 |
+
alpha=True,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
self.ar_decoder = decoder_cls(
|
| 96 |
+
decoder_layer_cls(
|
| 97 |
+
decoder_dim,
|
| 98 |
+
nhead,
|
| 99 |
+
dim_feedforward=decoder_dim * 4, # *4?
|
| 100 |
+
dropout=0.1,
|
| 101 |
+
batch_first=True,
|
| 102 |
+
norm_first=cfg.norm_first,
|
| 103 |
+
),
|
| 104 |
+
num_layers=num_decoder_layers,
|
| 105 |
+
norm=LayerNorm(decoder_dim) if cfg.norm_first else None,
|
| 106 |
+
)
|
| 107 |
+
self.ar_predict_layer = nn.Linear(
|
| 108 |
+
decoder_dim, cfg.audio_token_num + 1, bias=False
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
self.ar_accuracy_metric = MulticlassAccuracy(
|
| 112 |
+
cfg.audio_token_num + 1,
|
| 113 |
+
top_k=10,
|
| 114 |
+
average="micro",
|
| 115 |
+
multidim_average="global",
|
| 116 |
+
ignore_index=cfg.audio_token_num,
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
self.rng = random.Random(0)
|
| 120 |
+
self.num_heads = nhead
|
| 121 |
+
self.prefix_mode = cfg.prefix_mode
|
| 122 |
+
self.num_quantizers = num_quantizers
|
| 123 |
+
|
| 124 |
+
assert num_quantizers >= 1
|
| 125 |
+
if num_quantizers > 1:
|
| 126 |
+
self.nar_audio_embeddings = nn.ModuleList(
|
| 127 |
+
[
|
| 128 |
+
TokenEmbedding(nar_decoder_dim, cfg.audio_token_num + 1)
|
| 129 |
+
] # Why the first layer is audio_token_num + 1?
|
| 130 |
+
+ [
|
| 131 |
+
TokenEmbedding(nar_decoder_dim, cfg.audio_token_num)
|
| 132 |
+
for i in range(num_quantizers - 1)
|
| 133 |
+
]
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
if cfg.add_prenet:
|
| 137 |
+
self.nar_text_prenet = nn.Sequential(
|
| 138 |
+
Transpose(),
|
| 139 |
+
nn.Conv1d(
|
| 140 |
+
nar_decoder_dim, nar_decoder_dim, kernel_size=5, padding="same"
|
| 141 |
+
),
|
| 142 |
+
nn.BatchNorm1d(nar_decoder_dim),
|
| 143 |
+
nn.ReLU(),
|
| 144 |
+
nn.Dropout(0.5),
|
| 145 |
+
nn.Conv1d(
|
| 146 |
+
nar_decoder_dim, nar_decoder_dim, kernel_size=5, padding="same"
|
| 147 |
+
),
|
| 148 |
+
nn.BatchNorm1d(nar_decoder_dim),
|
| 149 |
+
nn.ReLU(),
|
| 150 |
+
nn.Dropout(0.5),
|
| 151 |
+
nn.Conv1d(
|
| 152 |
+
nar_decoder_dim, nar_decoder_dim, kernel_size=5, padding="same"
|
| 153 |
+
),
|
| 154 |
+
nn.BatchNorm1d(nar_decoder_dim),
|
| 155 |
+
nn.ReLU(),
|
| 156 |
+
nn.Dropout(0.5),
|
| 157 |
+
Transpose(),
|
| 158 |
+
nn.Linear(nar_decoder_dim, nar_decoder_dim),
|
| 159 |
+
)
|
| 160 |
+
self.nar_audio_prenet = nn.Sequential(
|
| 161 |
+
nn.Linear(nar_decoder_dim, 256),
|
| 162 |
+
nn.ReLU(),
|
| 163 |
+
nn.Dropout(0.25),
|
| 164 |
+
nn.Linear(256, 256),
|
| 165 |
+
nn.ReLU(),
|
| 166 |
+
nn.Dropout(0.25),
|
| 167 |
+
nn.Linear(256, nar_decoder_dim),
|
| 168 |
+
)
|
| 169 |
+
else:
|
| 170 |
+
self.nar_text_prenet = nn.Identity()
|
| 171 |
+
self.nar_audio_prenet = nn.Identity()
|
| 172 |
+
|
| 173 |
+
self.nar_text_position = SinePositionalEmbedding(
|
| 174 |
+
nar_decoder_dim,
|
| 175 |
+
dropout=0.0,
|
| 176 |
+
scale=False,
|
| 177 |
+
alpha=False,
|
| 178 |
+
)
|
| 179 |
+
self.nar_audio_position = SinePositionalEmbedding(
|
| 180 |
+
nar_decoder_dim,
|
| 181 |
+
dropout=0.1,
|
| 182 |
+
scale=False,
|
| 183 |
+
alpha=False,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
self.nar_decoder = decoder_cls(
|
| 187 |
+
decoder_layer_cls(
|
| 188 |
+
nar_decoder_dim,
|
| 189 |
+
int(nhead * nar_scale_factor),
|
| 190 |
+
dim_feedforward=nar_decoder_dim * 4,
|
| 191 |
+
dropout=0.1,
|
| 192 |
+
batch_first=True,
|
| 193 |
+
norm_first=cfg.norm_first,
|
| 194 |
+
adaptive_layer_norm=True,
|
| 195 |
+
),
|
| 196 |
+
num_layers=int(num_decoder_layers * nar_scale_factor),
|
| 197 |
+
norm=(
|
| 198 |
+
AdaptiveLayerNorm(
|
| 199 |
+
nar_decoder_dim, norm=nn.LayerNorm(nar_decoder_dim)
|
| 200 |
+
)
|
| 201 |
+
if cfg.norm_first
|
| 202 |
+
else None
|
| 203 |
+
),
|
| 204 |
+
)
|
| 205 |
+
self.nar_predict_layers = nn.ModuleList(
|
| 206 |
+
[
|
| 207 |
+
nn.Linear(nar_decoder_dim, cfg.audio_token_num, bias=False)
|
| 208 |
+
for i in range(num_quantizers - 1)
|
| 209 |
+
]
|
| 210 |
+
)
|
| 211 |
+
self.nar_stage_embeddings = nn.ModuleList(
|
| 212 |
+
[TokenEmbedding(nar_decoder_dim, 1) for i in range(num_quantizers - 1)]
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
if cfg.share_embedding:
|
| 216 |
+
for j in range(0, num_quantizers - 2):
|
| 217 |
+
self.nar_predict_layers[j].weight = self.nar_audio_embeddings[
|
| 218 |
+
j + 2
|
| 219 |
+
].weight
|
| 220 |
+
|
| 221 |
+
self.nar_accuracy_metric = MulticlassAccuracy(
|
| 222 |
+
cfg.audio_token_num + 1,
|
| 223 |
+
top_k=10,
|
| 224 |
+
average="micro",
|
| 225 |
+
multidim_average="global",
|
| 226 |
+
ignore_index=cfg.audio_token_num,
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
def forward(
|
| 230 |
+
self,
|
| 231 |
+
x: torch.Tensor,
|
| 232 |
+
x_lens: torch.Tensor,
|
| 233 |
+
y: Union[torch.Tensor, PromptedFeatures],
|
| 234 |
+
y_lens: Union[torch.Tensor, PromptedFeatures],
|
| 235 |
+
reduction: str = "sum",
|
| 236 |
+
train_stage: int = 0,
|
| 237 |
+
**kwargs,
|
| 238 |
+
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
|
| 239 |
+
"""
|
| 240 |
+
Args:
|
| 241 |
+
x:
|
| 242 |
+
A 2-D tensor of shape (N, S).
|
| 243 |
+
x_lens:
|
| 244 |
+
A 1-D tensor of shape (N,). It contains the number of tokens in `x`
|
| 245 |
+
before padding.
|
| 246 |
+
y:
|
| 247 |
+
A 3-D tensor of shape (N, T, 8).
|
| 248 |
+
y_lens:
|
| 249 |
+
A 1-D tensor of shape (N,). It contains the number of tokens in `x`
|
| 250 |
+
before padding.
|
| 251 |
+
train_stage:
|
| 252 |
+
0: AR & NAR modules, 1: AR modules, 2: NAR modules
|
| 253 |
+
Returns:
|
| 254 |
+
Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy.
|
| 255 |
+
"""
|
| 256 |
+
assert x.ndim == 2, x.shape
|
| 257 |
+
assert x_lens.ndim == 1, x_lens.shape
|
| 258 |
+
|
| 259 |
+
y_prompts_codes = None
|
| 260 |
+
if isinstance(y, PromptedFeatures):
|
| 261 |
+
y_prompts_codes, y = y.data
|
| 262 |
+
prompts_len, y_lens = y_lens.data
|
| 263 |
+
assert prompts_len.min() == prompts_len.max()
|
| 264 |
+
assert self.prefix_mode == 4
|
| 265 |
+
y_prompts_codes = y_prompts_codes.type(torch.int64)
|
| 266 |
+
|
| 267 |
+
assert y.ndim == 3, y.shape
|
| 268 |
+
assert y_lens.ndim == 1, y_lens.shape
|
| 269 |
+
|
| 270 |
+
x_mask = make_pad_mask(x_lens).to(x.device)
|
| 271 |
+
y_mask = make_pad_mask(y_lens).to(y.device)
|
| 272 |
+
y_mask_int = y_mask.type(torch.int64)
|
| 273 |
+
|
| 274 |
+
text = x
|
| 275 |
+
codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1))
|
| 276 |
+
|
| 277 |
+
y, targets = self.pad_y_eos(
|
| 278 |
+
codes[..., 0], y_mask_int, eos_id=self.audio_token_num
|
| 279 |
+
)
|
| 280 |
+
self.y_mask_int = y_mask_int
|
| 281 |
+
|
| 282 |
+
metrics = {}
|
| 283 |
+
total_loss = 0.0
|
| 284 |
+
|
| 285 |
+
xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
|
| 286 |
+
if self.ar_audio_prepend_bos:
|
| 287 |
+
ar_xy_padding_mask = torch.concat(
|
| 288 |
+
[x_mask, F.pad(y_mask, (1, 0), value=False)], dim=1
|
| 289 |
+
)
|
| 290 |
+
else:
|
| 291 |
+
ar_xy_padding_mask = xy_padding_mask
|
| 292 |
+
self.xy_padding_mask = xy_padding_mask
|
| 293 |
+
self.ar_xy_padding_mask = ar_xy_padding_mask
|
| 294 |
+
|
| 295 |
+
# AR Decoder
|
| 296 |
+
if train_stage in [0, 1]:
|
| 297 |
+
ar_loss, ar_metrics = self._forward_ar_decoder(
|
| 298 |
+
text, x_lens.max(), y, y_lens.max(), targets, x_mask, y_mask, reduction
|
| 299 |
+
)
|
| 300 |
+
total_loss += ar_loss
|
| 301 |
+
metrics["AR_Top100Acc"] = ar_metrics
|
| 302 |
+
|
| 303 |
+
# NAR Decoder
|
| 304 |
+
if self.ar_audio_prepend_bos:
|
| 305 |
+
y = y[:, 1:]
|
| 306 |
+
|
| 307 |
+
if self.num_quantizers > 1 and train_stage in [0, 2]:
|
| 308 |
+
nar_loss, nar_metrics = self._forward_nar_decoder(
|
| 309 |
+
text,
|
| 310 |
+
x_lens,
|
| 311 |
+
y,
|
| 312 |
+
y_lens,
|
| 313 |
+
codes,
|
| 314 |
+
y_prompts_codes,
|
| 315 |
+
x_mask,
|
| 316 |
+
y_mask,
|
| 317 |
+
reduction,
|
| 318 |
+
)
|
| 319 |
+
total_loss += nar_loss
|
| 320 |
+
metrics["NAR_Top100Acc"] = nar_metrics
|
| 321 |
+
|
| 322 |
+
if train_stage == 0:
|
| 323 |
+
total_loss = total_loss / 2.0
|
| 324 |
+
|
| 325 |
+
return total_loss, metrics
|
| 326 |
+
|
| 327 |
+
def _forward_ar_decoder(
|
| 328 |
+
self, x, x_len, y, y_lens, targets, x_mask, y_mask, reduction
|
| 329 |
+
):
|
| 330 |
+
x = self.ar_text_embedding(x)
|
| 331 |
+
x = self.ar_text_prenet(x)
|
| 332 |
+
x = self.ar_text_position(x)
|
| 333 |
+
|
| 334 |
+
y_len = y_lens.max() + int(self.ar_audio_prepend_bos)
|
| 335 |
+
|
| 336 |
+
x_attn_mask = F.pad(
|
| 337 |
+
torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
|
| 338 |
+
(0, y_len),
|
| 339 |
+
value=True,
|
| 340 |
+
)
|
| 341 |
+
y_attn_mask = F.pad(
|
| 342 |
+
torch.triu(
|
| 343 |
+
torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
|
| 344 |
+
diagonal=1,
|
| 345 |
+
),
|
| 346 |
+
(x_len, 0),
|
| 347 |
+
value=False,
|
| 348 |
+
)
|
| 349 |
+
xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
|
| 350 |
+
|
| 351 |
+
bsz, src_len = x.shape[0], x_len + y_len
|
| 352 |
+
_xy_padding_mask = (
|
| 353 |
+
self.ar_xy_padding_mask.view(bsz, 1, 1, src_len)
|
| 354 |
+
.expand(-1, self.num_heads, -1, -1)
|
| 355 |
+
.reshape(bsz * self.num_heads, 1, src_len)
|
| 356 |
+
)
|
| 357 |
+
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
|
| 358 |
+
|
| 359 |
+
new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
|
| 360 |
+
new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
|
| 361 |
+
xy_attn_mask = new_attn_mask
|
| 362 |
+
|
| 363 |
+
y_emb = self.ar_audio_embedding(y)
|
| 364 |
+
y_emb = self.ar_audio_prenet(y_emb)
|
| 365 |
+
y_pos = self.ar_audio_position(y_emb)
|
| 366 |
+
|
| 367 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
| 368 |
+
|
| 369 |
+
xy_dec, _ = self.ar_decoder(
|
| 370 |
+
(xy_pos, None),
|
| 371 |
+
mask=xy_attn_mask,
|
| 372 |
+
)
|
| 373 |
+
logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
|
| 374 |
+
ar_loss = F.cross_entropy(logits, targets, reduction=reduction)
|
| 375 |
+
|
| 376 |
+
ar_metrics = self.ar_accuracy_metric(
|
| 377 |
+
logits.detach(), targets
|
| 378 |
+
).item() * y_lens.sum().type(torch.float32)
|
| 379 |
+
|
| 380 |
+
return ar_loss, ar_metrics
|
| 381 |
+
|
| 382 |
+
def _forward_nar_decoder(
|
| 383 |
+
self, x, x_lens, y, y_lens, codes, y_prompts_codes, x_mask, y_mask, reduction
|
| 384 |
+
):
|
| 385 |
+
num_nar_layers = self.num_quantizers - 1
|
| 386 |
+
nar_stage = self.rng.choices(
|
| 387 |
+
[_k for _k in range(1, self.num_quantizers)],
|
| 388 |
+
weights=[1.0 / num_nar_layers] * num_nar_layers,
|
| 389 |
+
k=1,
|
| 390 |
+
)[0]
|
| 391 |
+
|
| 392 |
+
x = self.nar_text_embedding(x)
|
| 393 |
+
x = self.nar_text_prenet(x)
|
| 394 |
+
x = self.nar_text_position(x)
|
| 395 |
+
|
| 396 |
+
y_emb, prefix_len = self._prepare_prompts(
|
| 397 |
+
y, y_lens, codes, nar_stage, y_prompts_codes
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
y_len = y_lens.max()
|
| 401 |
+
targets = codes[..., nar_stage] + self.audio_token_num * self.y_mask_int
|
| 402 |
+
if self.prefix_mode in [2, 4]:
|
| 403 |
+
xy_padding_mask = torch.concat(
|
| 404 |
+
[
|
| 405 |
+
x_mask,
|
| 406 |
+
F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False),
|
| 407 |
+
],
|
| 408 |
+
dim=1,
|
| 409 |
+
)
|
| 410 |
+
elif self.prefix_mode == 1:
|
| 411 |
+
targets = targets[:, prefix_len:]
|
| 412 |
+
|
| 413 |
+
y_pos = self.nar_audio_prenet(y_emb)
|
| 414 |
+
y_pos = self.nar_audio_position(y_pos)
|
| 415 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
| 416 |
+
xy_dec, _ = self.nar_decoder(
|
| 417 |
+
(xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight),
|
| 418 |
+
src_key_padding_mask=self.xy_padding_mask,
|
| 419 |
+
)
|
| 420 |
+
xy_dec = xy_dec[:, x_lens.max() + prefix_len :]
|
| 421 |
+
if self.prefix_mode == 4:
|
| 422 |
+
prefix_len = 0
|
| 423 |
+
logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(0, 2, 1)
|
| 424 |
+
|
| 425 |
+
total_length = (y_lens).sum().type(torch.float32)
|
| 426 |
+
nar_loss = F.cross_entropy(
|
| 427 |
+
logits,
|
| 428 |
+
targets,
|
| 429 |
+
ignore_index=self.audio_token_num,
|
| 430 |
+
reduction=reduction,
|
| 431 |
+
) * (total_length / (total_length - prefix_len * x.shape[0]))
|
| 432 |
+
nar_metrics = (
|
| 433 |
+
self.nar_accuracy_metric(
|
| 434 |
+
F.pad(
|
| 435 |
+
logits.detach(),
|
| 436 |
+
(0, 0, 0, 1, 0, 0),
|
| 437 |
+
value=logits.min().cpu().item(),
|
| 438 |
+
),
|
| 439 |
+
targets,
|
| 440 |
+
).item()
|
| 441 |
+
* total_length
|
| 442 |
+
)
|
| 443 |
+
return nar_loss, nar_metrics
|
| 444 |
+
|
| 445 |
+
def inference(
|
| 446 |
+
self,
|
| 447 |
+
x: torch.Tensor,
|
| 448 |
+
x_lens: torch.Tensor,
|
| 449 |
+
y: torch.Tensor,
|
| 450 |
+
enroll_x_lens: torch.Tensor,
|
| 451 |
+
top_k: int = -100,
|
| 452 |
+
temperature: float = 1.0,
|
| 453 |
+
) -> torch.Tensor:
|
| 454 |
+
"""
|
| 455 |
+
Args:
|
| 456 |
+
x:
|
| 457 |
+
A 2-D tensor of shape (1, S).
|
| 458 |
+
x_lens:
|
| 459 |
+
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
| 460 |
+
before padding.
|
| 461 |
+
y:
|
| 462 |
+
A 3-D tensor of shape (1, T, 8).
|
| 463 |
+
top_k: (`optional`) int
|
| 464 |
+
The number of highest probability tokens to keep for top-k-filtering. Default to -100.
|
| 465 |
+
temperature: (`optional`) float
|
| 466 |
+
The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
|
| 467 |
+
Returns:
|
| 468 |
+
Return the predicted audio code matrix.
|
| 469 |
+
"""
|
| 470 |
+
assert x.ndim == 2, x.shape
|
| 471 |
+
assert x_lens.ndim == 1, x_lens.shape
|
| 472 |
+
assert y.ndim == 3, y.shape
|
| 473 |
+
assert y.shape[0] == 1, y.shape
|
| 474 |
+
|
| 475 |
+
assert torch.all(x_lens > 0)
|
| 476 |
+
|
| 477 |
+
text = x
|
| 478 |
+
x = self.ar_text_embedding(text)
|
| 479 |
+
x = self.ar_text_prenet(x)
|
| 480 |
+
x = self.ar_text_position(x)
|
| 481 |
+
|
| 482 |
+
text_len = x_lens.max()
|
| 483 |
+
prompts = y
|
| 484 |
+
prefix_len = y.shape[1]
|
| 485 |
+
|
| 486 |
+
# AR Decoder
|
| 487 |
+
y = prompts[..., 0]
|
| 488 |
+
if self.ar_audio_prepend_bos:
|
| 489 |
+
y = F.pad(y, (1, 0), value=self.audio_token_num + 1)
|
| 490 |
+
|
| 491 |
+
x_len = x_lens.max()
|
| 492 |
+
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
| 493 |
+
|
| 494 |
+
while True:
|
| 495 |
+
y_emb = self.ar_audio_embedding(y)
|
| 496 |
+
y_emb = self.ar_audio_prenet(y_emb)
|
| 497 |
+
y_pos = self.ar_audio_position(y_emb)
|
| 498 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
| 499 |
+
|
| 500 |
+
y_len = y.shape[1]
|
| 501 |
+
x_attn_mask_pad = F.pad(
|
| 502 |
+
x_attn_mask,
|
| 503 |
+
(0, y_len),
|
| 504 |
+
value=True,
|
| 505 |
+
)
|
| 506 |
+
y_attn_mask = F.pad(
|
| 507 |
+
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
| 508 |
+
(x_len, 0),
|
| 509 |
+
value=False,
|
| 510 |
+
)
|
| 511 |
+
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
|
| 512 |
+
y.device
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
xy_dec, _ = self.ar_decoder(
|
| 516 |
+
(xy_pos, None),
|
| 517 |
+
mask=xy_attn_mask,
|
| 518 |
+
)
|
| 519 |
+
logits = self.ar_predict_layer(xy_dec[:, -1])
|
| 520 |
+
samples = topk_sampling(
|
| 521 |
+
logits, top_k=top_k, top_p=1.0, temperature=temperature
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
if (
|
| 525 |
+
torch.argmax(logits, dim=-1)[0] == self.audio_token_num
|
| 526 |
+
or samples[0, 0] == self.audio_token_num
|
| 527 |
+
or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
|
| 528 |
+
):
|
| 529 |
+
if prompts.shape[1] == y.shape[1]:
|
| 530 |
+
raise SyntaxError("well trained model shouldn't reach here.")
|
| 531 |
+
|
| 532 |
+
break
|
| 533 |
+
|
| 534 |
+
y = torch.concat([y, samples], dim=1)
|
| 535 |
+
|
| 536 |
+
codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
|
| 537 |
+
if self.num_quantizers == 1:
|
| 538 |
+
return torch.stack(codes, dim=-1)
|
| 539 |
+
|
| 540 |
+
# Non-AR Decoders
|
| 541 |
+
y_emb = self.nar_audio_embeddings[0](y[:, int(self.ar_audio_prepend_bos) :])
|
| 542 |
+
|
| 543 |
+
if self.prefix_mode in [2, 4]:
|
| 544 |
+
enrolled_len = enroll_x_lens.max().item()
|
| 545 |
+
# SOS + Synthesis Text + EOS
|
| 546 |
+
text = torch.concat(
|
| 547 |
+
[
|
| 548 |
+
text[:, :1],
|
| 549 |
+
text[:, enrolled_len - 1 :],
|
| 550 |
+
],
|
| 551 |
+
dim=1,
|
| 552 |
+
)
|
| 553 |
+
text_len = text_len - (enrolled_len - 2)
|
| 554 |
+
assert text.shape[0] == 1
|
| 555 |
+
|
| 556 |
+
x = self.nar_text_embedding(text)
|
| 557 |
+
x = self.nar_text_prenet(x)
|
| 558 |
+
x = self.nar_text_position(x)
|
| 559 |
+
|
| 560 |
+
if self.prefix_mode == 0:
|
| 561 |
+
for i, (predict_layer, embedding_layer) in enumerate(
|
| 562 |
+
zip(
|
| 563 |
+
self.nar_predict_layers,
|
| 564 |
+
self.nar_audio_embeddings[1:],
|
| 565 |
+
)
|
| 566 |
+
):
|
| 567 |
+
y_pos = self.nar_audio_prenet(y_emb)
|
| 568 |
+
y_pos = self.nar_audio_position(y_pos)
|
| 569 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
| 570 |
+
|
| 571 |
+
xy_dec, _ = self.nar_decoder(
|
| 572 |
+
(xy_pos, self.nar_stage_embeddings[i].weight)
|
| 573 |
+
)
|
| 574 |
+
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
|
| 575 |
+
|
| 576 |
+
samples = torch.argmax(logits, dim=-1)
|
| 577 |
+
codes.append(samples)
|
| 578 |
+
|
| 579 |
+
if i < self.num_quantizers - 2:
|
| 580 |
+
y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
|
| 581 |
+
y_emb[:, prefix_len:] += embedding_layer(samples)
|
| 582 |
+
else:
|
| 583 |
+
for j in range(1, self.num_quantizers):
|
| 584 |
+
y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])
|
| 585 |
+
|
| 586 |
+
for i, (predict_layer, embedding_layer) in enumerate(
|
| 587 |
+
zip(
|
| 588 |
+
self.nar_predict_layers,
|
| 589 |
+
self.nar_audio_embeddings[1:],
|
| 590 |
+
)
|
| 591 |
+
):
|
| 592 |
+
y_pos = self.nar_audio_prenet(y_emb)
|
| 593 |
+
y_pos = self.nar_audio_position(y_pos)
|
| 594 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
| 595 |
+
|
| 596 |
+
xy_dec, _ = self.nar_decoder(
|
| 597 |
+
(xy_pos, self.nar_stage_embeddings[i].weight)
|
| 598 |
+
)
|
| 599 |
+
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
|
| 600 |
+
|
| 601 |
+
samples = torch.argmax(logits, dim=-1)
|
| 602 |
+
codes.append(samples)
|
| 603 |
+
|
| 604 |
+
if i < self.num_quantizers - 2:
|
| 605 |
+
y_emb[:, prefix_len:] += embedding_layer(samples)
|
| 606 |
+
|
| 607 |
+
assert len(codes) == self.num_quantizers
|
| 608 |
+
return torch.stack(codes, dim=-1)
|
| 609 |
+
|
| 610 |
+
def continual(
|
| 611 |
+
self,
|
| 612 |
+
x: torch.Tensor,
|
| 613 |
+
x_lens: torch.Tensor,
|
| 614 |
+
y: torch.Tensor,
|
| 615 |
+
) -> torch.Tensor:
|
| 616 |
+
"""
|
| 617 |
+
Args:
|
| 618 |
+
x:
|
| 619 |
+
A 2-D tensor of shape (1, S).
|
| 620 |
+
x_lens:
|
| 621 |
+
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
| 622 |
+
before padding.
|
| 623 |
+
y:
|
| 624 |
+
A 3-D tensor of shape (1, T, 8).
|
| 625 |
+
Returns:
|
| 626 |
+
Return the predicted audio code matrix.
|
| 627 |
+
"""
|
| 628 |
+
assert x.ndim == 2, x.shape
|
| 629 |
+
assert x_lens.ndim == 1, x_lens.shape
|
| 630 |
+
assert y.ndim == 3, y.shape
|
| 631 |
+
assert y.shape[0] == 1, y.shape
|
| 632 |
+
|
| 633 |
+
assert torch.all(x_lens > 0)
|
| 634 |
+
assert self.num_quantizers == 8
|
| 635 |
+
|
| 636 |
+
text = x
|
| 637 |
+
x = self.ar_text_embedding(text)
|
| 638 |
+
x = self.ar_text_prenet(x)
|
| 639 |
+
x = self.ar_text_position(x)
|
| 640 |
+
|
| 641 |
+
text_len = x_lens.max()
|
| 642 |
+
|
| 643 |
+
prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)
|
| 644 |
+
|
| 645 |
+
# AR Decoder
|
| 646 |
+
prompts = y[:, :prefix_len]
|
| 647 |
+
|
| 648 |
+
codes = [y[:, prefix_len:, 0]]
|
| 649 |
+
# Non-AR Decoders
|
| 650 |
+
x = self.nar_text_embedding(text)
|
| 651 |
+
x = self.nar_text_prenet(x)
|
| 652 |
+
x = self.nar_text_position(x)
|
| 653 |
+
|
| 654 |
+
y_emb = self.nar_audio_embeddings[0](y[..., 0])
|
| 655 |
+
|
| 656 |
+
if self.prefix_mode == 0:
|
| 657 |
+
for i, (predict_layer, embedding_layer) in enumerate(
|
| 658 |
+
zip(
|
| 659 |
+
self.nar_predict_layers,
|
| 660 |
+
self.nar_audio_embeddings[1:],
|
| 661 |
+
)
|
| 662 |
+
):
|
| 663 |
+
y_pos = self.nar_audio_position(y_emb)
|
| 664 |
+
y_pos = self.nar_audio_prenet(y_pos)
|
| 665 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
| 666 |
+
|
| 667 |
+
xy_dec, _ = self.nar_decoder(
|
| 668 |
+
(xy_pos, self.nar_stage_embeddings[i].weight)
|
| 669 |
+
)
|
| 670 |
+
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
|
| 671 |
+
|
| 672 |
+
samples = torch.argmax(logits, dim=-1)
|
| 673 |
+
codes.append(samples)
|
| 674 |
+
|
| 675 |
+
if i < 6:
|
| 676 |
+
y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
|
| 677 |
+
y_emb[:, prefix_len:] += embedding_layer(samples)
|
| 678 |
+
else:
|
| 679 |
+
for j in range(1, 8):
|
| 680 |
+
y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])
|
| 681 |
+
|
| 682 |
+
for i, (predict_layer, embedding_layer) in enumerate(
|
| 683 |
+
zip(
|
| 684 |
+
self.nar_predict_layers,
|
| 685 |
+
self.nar_audio_embeddings[1:],
|
| 686 |
+
)
|
| 687 |
+
):
|
| 688 |
+
y_pos = self.nar_audio_prenet(y_emb)
|
| 689 |
+
y_pos = self.nar_audio_position(y_pos)
|
| 690 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
| 691 |
+
|
| 692 |
+
xy_dec, _ = self.nar_decoder(
|
| 693 |
+
(xy_pos, self.nar_stage_embeddings[i].weight)
|
| 694 |
+
)
|
| 695 |
+
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
|
| 696 |
+
|
| 697 |
+
samples = torch.argmax(logits, dim=-1)
|
| 698 |
+
codes.append(samples)
|
| 699 |
+
|
| 700 |
+
if i < 6:
|
| 701 |
+
y_emb[:, prefix_len:] += embedding_layer(samples)
|
| 702 |
+
|
| 703 |
+
assert len(codes) == 8
|
| 704 |
+
return torch.stack(codes, dim=-1)
|
| 705 |
+
|
| 706 |
+
def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
|
| 707 |
+
assert stage > 0
|
| 708 |
+
if stage == 1:
|
| 709 |
+
for name, param in self.named_parameters():
|
| 710 |
+
if name.startswith("ar_"):
|
| 711 |
+
yield param
|
| 712 |
+
|
| 713 |
+
if stage == 2:
|
| 714 |
+
for name, param in self.named_parameters():
|
| 715 |
+
if name.startswith("nar_"):
|
| 716 |
+
yield param
|
| 717 |
+
|
| 718 |
+
def stage_named_parameters(
|
| 719 |
+
self, stage: int = 1
|
| 720 |
+
) -> Iterator[Tuple[str, nn.Parameter]]:
|
| 721 |
+
assert stage > 0
|
| 722 |
+
if stage == 1:
|
| 723 |
+
for pair in self.named_parameters():
|
| 724 |
+
if pair[0].startswith("ar_"):
|
| 725 |
+
yield pair
|
| 726 |
+
|
| 727 |
+
if stage == 2:
|
| 728 |
+
for pair in self.named_parameters():
|
| 729 |
+
if pair[0].startswith("nar_"):
|
| 730 |
+
yield pair
|
| 731 |
+
|
| 732 |
+
def pad_y_eos(self, y, y_mask_int, eos_id):
|
| 733 |
+
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
|
| 734 |
+
y_mask_int, (0, 1), value=1
|
| 735 |
+
)
|
| 736 |
+
if self.ar_audio_prepend_bos:
|
| 737 |
+
return (
|
| 738 |
+
F.pad(targets[:, :-1], (1, 0), value=self.audio_token_num + 1),
|
| 739 |
+
targets,
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
return targets[:, :-1], targets[:, 1:]
|
| 743 |
+
|
| 744 |
+
def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes):
|
| 745 |
+
# 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds
|
| 746 |
+
# from the same utterance.
|
| 747 |
+
# We implement this differently.
|
| 748 |
+
if self.prefix_mode == 0:
|
| 749 |
+
# no prefix
|
| 750 |
+
prefix_len = 0
|
| 751 |
+
y_emb = self.nar_audio_embeddings[0](y)
|
| 752 |
+
for j in range(1, nar_stage):
|
| 753 |
+
# Formula (4) (5)
|
| 754 |
+
y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
|
| 755 |
+
elif self.prefix_mode == 1:
|
| 756 |
+
# prefix at begining
|
| 757 |
+
int_low = (0.25 * y_lens.min()).type(torch.int64).item()
|
| 758 |
+
prefix_len = torch.randint(int_low, int_low * 2, size=()).item()
|
| 759 |
+
prefix_len = min(prefix_len, 225) # 24000/320 * 3s = 225 frames
|
| 760 |
+
|
| 761 |
+
y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
|
| 762 |
+
y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
|
| 763 |
+
for j in range(1, self.num_quantizers):
|
| 764 |
+
y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j])
|
| 765 |
+
if j < nar_stage:
|
| 766 |
+
y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j])
|
| 767 |
+
y_emb = torch.concat([y_prompts, y_emb], axis=1)
|
| 768 |
+
elif self.prefix_mode in [2, 4]:
|
| 769 |
+
if self.prefix_mode == 2:
|
| 770 |
+
# random prefix
|
| 771 |
+
prefix_len = min(225, int(0.25 * y_lens.min().item()))
|
| 772 |
+
|
| 773 |
+
y_prompts_codes = []
|
| 774 |
+
for b in range(codes.shape[0]):
|
| 775 |
+
start = self.rng.randint(0, y_lens[b].item() - prefix_len)
|
| 776 |
+
y_prompts_codes.append(
|
| 777 |
+
torch.clone(codes[b, start : start + prefix_len])
|
| 778 |
+
)
|
| 779 |
+
codes[b, start : start + prefix_len, nar_stage] = NUM_AUDIO_TOKENS
|
| 780 |
+
y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
|
| 781 |
+
else:
|
| 782 |
+
prefix_len = y_prompts_codes.shape[1]
|
| 783 |
+
|
| 784 |
+
y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
|
| 785 |
+
y_emb = self.nar_audio_embeddings[0](y)
|
| 786 |
+
for j in range(1, self.num_quantizers):
|
| 787 |
+
y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j])
|
| 788 |
+
if j < nar_stage:
|
| 789 |
+
y_emb += self.nar_audio_embeddings[j](codes[..., j])
|
| 790 |
+
y_emb = torch.concat([y_prompts, y_emb], axis=1)
|
| 791 |
+
else:
|
| 792 |
+
raise ValueError
|
| 793 |
+
|
| 794 |
+
return y_emb, prefix_len
|
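
# --- Hypothetical usage sketch (not part of the original valle.py): how the VALLE
# model above could be instantiated and queried. The config fields are the ones
# read in VALLE.__init__; all concrete values, shapes, and token counts are
# illustrative assumptions rather than Amphion defaults.
from types import SimpleNamespace
import torch

cfg = SimpleNamespace(
    decoder_dim=1024, nhead=16, nar_scale_factor=1.0, num_quantizers=8,
    num_decoder_layers=12, text_token_num=512, audio_token_num=1024,
    prepend_bos=False, add_prenet=False, norm_first=True,
    prefix_mode=0, share_embedding=True,
)
model = VALLE(cfg).eval()

phone_ids = torch.randint(0, 512, (1, 30))          # (1, S) text/phone tokens
phone_lens = torch.tensor([30])                      # (1,) lengths before padding
prompt_codes = torch.randint(0, 1024, (1, 225, 8))   # (1, T, 8) acoustic prompt codes
with torch.no_grad():
    codes = model.inference(
        phone_ids, phone_lens, prompt_codes,
        enroll_x_lens=phone_lens, top_k=-100, temperature=1.0,
    )  # -> (1, T_generated, num_quantizers) predicted audio code matrix
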
Amphion/models/tts/valle/valle_inference.py
ADDED
|
@@ -0,0 +1,237 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import numpy as np
import torch
import torchaudio
import argparse


from text.g2p_module import G2PModule
from utils.tokenizer import AudioTokenizer, tokenize_audio
from models.tts.valle.valle import VALLE
from models.tts.base.tts_inferece import TTSInference
from models.tts.valle.valle_dataset import VALLETestDataset, VALLETestCollator
from processors.phone_extractor import phoneExtractor
from text.text_token_collation import phoneIDCollation


class VALLEInference(TTSInference):
    def __init__(self, args=None, cfg=None):
        TTSInference.__init__(self, args, cfg)

        self.g2p_module = G2PModule(backend=self.cfg.preprocess.phone_extractor)
        text_token_path = os.path.join(
            cfg.preprocess.processed_dir, cfg.dataset[0], cfg.preprocess.symbols_dict
        )
        self.audio_tokenizer = AudioTokenizer()

    def _build_model(self):
        model = VALLE(self.cfg.model)
        return model

    def _build_test_dataset(self):
        return VALLETestDataset, VALLETestCollator

    def inference_one_clip(self, text, text_prompt, audio_file, save_name="pred"):
        # get phone symbol file
        phone_symbol_file = None
        if self.cfg.preprocess.phone_extractor != "lexicon":
            phone_symbol_file = os.path.join(
                self.exp_dir, self.cfg.preprocess.symbols_dict
            )
            assert os.path.exists(phone_symbol_file)
        # convert text to phone sequence
        phone_extractor = phoneExtractor(self.cfg)
        # convert phone sequence to phone id sequence
        phon_id_collator = phoneIDCollation(
            self.cfg, symbols_dict_file=phone_symbol_file
        )

        text = f"{text_prompt} {text}".strip()
        phone_seq = phone_extractor.extract_phone(text)  # phone_seq: list
        phone_id_seq = phon_id_collator.get_phone_id_sequence(self.cfg, phone_seq)
        phone_id_seq_len = torch.IntTensor([len(phone_id_seq)]).to(self.device)

        # convert phone sequence to phone id sequence
        phone_id_seq = np.array([phone_id_seq])
        phone_id_seq = torch.from_numpy(phone_id_seq).to(self.device)

        # extract acoustic token
        encoded_frames = tokenize_audio(self.audio_tokenizer, audio_file)
        audio_prompt_token = encoded_frames[0][0].transpose(2, 1).to(self.device)

        # copysyn
        if self.args.copysyn:
            samples = self.audio_tokenizer.decode(encoded_frames)
            audio_copysyn = samples[0].cpu().detach()

            out_path = os.path.join(
                self.args.output_dir, self.infer_type, f"{save_name}_copysyn.wav"
            )
            torchaudio.save(out_path, audio_copysyn, self.cfg.preprocess.sampling_rate)

        if self.args.continual:
            encoded_frames = self.model.continual(
                phone_id_seq,
                phone_id_seq_len,
                audio_prompt_token,
            )
        else:
            enroll_x_lens = None
            if text_prompt:
                # prompt_phone_seq = tokenize_text(self.g2p_module, text=f"{text_prompt}".strip())
                # _, enroll_x_lens = self.text_tokenizer.get_token_id_seq(prompt_phone_seq)

                text = f"{text_prompt}".strip()
                prompt_phone_seq = phone_extractor.extract_phone(
                    text
                )  # phone_seq: list
                prompt_phone_id_seq = phon_id_collator.get_phone_id_sequence(
                    self.cfg, prompt_phone_seq
                )
                prompt_phone_id_seq_len = torch.IntTensor(
                    [len(prompt_phone_id_seq)]
                ).to(self.device)

            encoded_frames = self.model.inference(
                phone_id_seq,
                phone_id_seq_len,
                audio_prompt_token,
                enroll_x_lens=prompt_phone_id_seq_len,
                top_k=self.args.top_k,
                temperature=self.args.temperature,
            )

        samples = self.audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)])

        audio = samples[0].squeeze(0).cpu().detach()

        return audio

    def inference_for_single_utterance(self):
        text = self.args.text
        text_prompt = self.args.text_prompt
        audio_file = self.args.audio_prompt

        if not self.args.continual:
            assert text != ""
        else:
            text = ""
        assert text_prompt != ""
        assert audio_file != ""

        audio = self.inference_one_clip(text, text_prompt, audio_file)

        return audio

    def inference_for_batches(self):
        test_list_file = self.args.test_list_file
        assert test_list_file is not None

        pred_res = []
        with open(test_list_file, "r") as fin:
            for idx, line in enumerate(fin.readlines()):
                fields = line.strip().split("|")
                if self.args.continual:
                    assert len(fields) == 2
                    text_prompt, audio_prompt_path = fields
                    text = ""
                else:
                    assert len(fields) == 3
                    text_prompt, audio_prompt_path, text = fields

                audio = self.inference_one_clip(
                    text, text_prompt, audio_prompt_path, str(idx)
                )
                pred_res.append(audio)

        return pred_res

    """
    TODO: batch inference
    ###### Construct test_batch ######
    n_batch = len(self.test_dataloader)
    now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
    print(
        "Model eval time: {}, batch_size = {}, n_batch = {}".format(
            now, self.test_batch_size, n_batch
        )
    )

    ###### Inference for each batch ######
    pred_res = []
    with torch.no_grad():
        for i, batch_data in enumerate(
            self.test_dataloader if n_batch == 1 else tqdm(self.test_dataloader)
        ):
            if self.args.continual:
                encoded_frames = self.model.continual(
                    batch_data["phone_seq"],
                    batch_data["phone_len"],
                    batch_data["acoustic_token"],
                )
            else:
                encoded_frames = self.model.inference(
                    batch_data["phone_seq"],
                    batch_data["phone_len"],
                    batch_data["acoustic_token"],
                    enroll_x_lens=batch_data["pmt_phone_len"],
                    top_k=self.args.top_k,
                    temperature=self.args.temperature
                )

            samples = self.audio_tokenizer.decode(
                [(encoded_frames.transpose(2, 1), None)]
            )

            for idx in range(samples.size(0)):
                audio = samples[idx].cpu()
                pred_res.append(audio)

        return pred_res
    """

    def add_arguments(parser: argparse.ArgumentParser):
        parser.add_argument(
            "--text_prompt",
            type=str,
            default="",
            help="Text prompt that should be aligned with --audio_prompt.",
        )

        parser.add_argument(
            "--audio_prompt",
            type=str,
            default="",
            help="Audio prompt that should be aligned with --text_prompt.",
        )
        parser.add_argument(
            "--top-k",
            type=int,
            default=-100,
            help="Whether AR Decoder do top_k(if > 0) sampling.",
        )

        parser.add_argument(
            "--temperature",
            type=float,
            default=1.0,
            help="The temperature of AR Decoder top_k sampling.",
        )

        parser.add_argument(
            "--continual",
            action="store_true",
            help="Inference for continual task.",
        )

        parser.add_argument(
            "--copysyn",
            action="store_true",
            help="Copysyn: generate audio by decoder of the original audio tokenizer.",
        )
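
# --- Illustrative note (not part of the original file): the test_list_file read by
# inference_for_batches above holds one clip per line, with "|"-separated fields in
# the order "text_prompt|audio_prompt_path|text"; with --continual only the first
# two fields are expected. The paths and sentences below are placeholders.
#
#   A transcript of the prompt audio.|prompts/example_prompt.wav|The target sentence to synthesize.
#   Another prompt transcript.|prompts/second_prompt.wav|A second target sentence for the same speaker.
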
Amphion/models/tts/valle/valle_trainer.py
ADDED
|
@@ -0,0 +1,367 @@
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 12 |
+
from optimizer.optimizers import Eve, ScaledAdam
|
| 13 |
+
from schedulers.scheduler import NoamScheduler, Eden
|
| 14 |
+
from models.tts.valle.valle_dataset import (
|
| 15 |
+
VALLEDataset,
|
| 16 |
+
VALLECollator,
|
| 17 |
+
batch_by_size,
|
| 18 |
+
)
|
| 19 |
+
from models.base.base_sampler import VariableSampler
|
| 20 |
+
from models.tts.base import TTSTrainer
|
| 21 |
+
from models.tts.valle.valle import VALLE
|
| 22 |
+
import diffusers
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class VALLETrainer(TTSTrainer):
|
| 26 |
+
def __init__(self, args, cfg):
|
| 27 |
+
TTSTrainer.__init__(self, args, cfg)
|
| 28 |
+
|
| 29 |
+
def _build_model(self):
|
| 30 |
+
model = VALLE(self.cfg.model)
|
| 31 |
+
|
| 32 |
+
return model
|
| 33 |
+
|
| 34 |
+
def _build_dataset(self):
|
| 35 |
+
return VALLEDataset, VALLECollator
|
| 36 |
+
|
| 37 |
+
def _build_optimizer(self):
|
| 38 |
+
if self.args.train_stage:
|
| 39 |
+
if isinstance(self.model, DistributedDataParallel):
|
| 40 |
+
model = self.model.module
|
| 41 |
+
else:
|
| 42 |
+
model = self.model
|
| 43 |
+
model_parameters = model.stage_parameters(self.args.train_stage)
|
| 44 |
+
else:
|
| 45 |
+
model_parameters = self.model.parameters()
|
| 46 |
+
|
| 47 |
+
if self.cfg.train.optimizer == "ScaledAdam":
|
| 48 |
+
parameters_names = []
|
| 49 |
+
if self.args.train_stage != 0:
|
| 50 |
+
parameters_names.append(
|
| 51 |
+
[
|
| 52 |
+
name_param_pair[0]
|
| 53 |
+
for name_param_pair in model.stage_named_parameters(
|
| 54 |
+
self.args.train_stage
|
| 55 |
+
)
|
| 56 |
+
]
|
| 57 |
+
)
|
| 58 |
+
else:
|
| 59 |
+
parameters_names.append(
|
| 60 |
+
[name_param_pair[0] for name_param_pair in model.named_parameters()]
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
optimizer = ScaledAdam(
|
| 64 |
+
model_parameters,
|
| 65 |
+
lr=self.cfg.train.base_lr,
|
| 66 |
+
betas=(0.9, 0.95),
|
| 67 |
+
clipping_scale=2.0,
|
| 68 |
+
parameters_names=parameters_names,
|
| 69 |
+
show_dominant_parameters=False,
|
| 70 |
+
clipping_update_period=1000,
|
| 71 |
+
)
|
| 72 |
+
elif self.cfg.train.optimizer == "Eve":
|
| 73 |
+
optimizer = Eve(
|
| 74 |
+
model_parameters,
|
| 75 |
+
lr=self.cfg.train.base_lr,
|
| 76 |
+
betas=(0.9, 0.98),
|
| 77 |
+
target_rms=0.1,
|
| 78 |
+
)
|
| 79 |
+
elif self.cfg.train.optimizer == "AdamW":
|
| 80 |
+
optimizer = torch.optim.AdamW(
|
| 81 |
+
model_parameters,
|
| 82 |
+
lr=self.cfg.train.base_lr,
|
| 83 |
+
betas=(0.9, 0.95),
|
| 84 |
+
weight_decay=1e-2,
|
| 85 |
+
eps=1e-8,
|
| 86 |
+
)
|
| 87 |
+
elif self.cfg.train.optimizer == "Adam":
|
| 88 |
+
optimizer = torch.optim.Adam(
|
| 89 |
+
model_parameters,
|
| 90 |
+
lr=self.cfg.train.base_lr,
|
| 91 |
+
betas=(0.9, 0.95),
|
| 92 |
+
eps=1e-8,
|
| 93 |
+
)
|
| 94 |
+
else:
|
| 95 |
+
raise NotImplementedError()
|
| 96 |
+
|
| 97 |
+
return optimizer
|
| 98 |
+
|
| 99 |
+
def _build_scheduler(self):
|
| 100 |
+
if self.cfg.train.scheduler.lower() == "eden":
|
| 101 |
+
scheduler = Eden(
|
| 102 |
+
self.optimizer, 5000, 4, warmup_batches=self.cfg.train.warmup_steps
|
| 103 |
+
)
|
| 104 |
+
elif self.cfg.train.scheduler.lower() == "noam":
|
| 105 |
+
scheduler = NoamScheduler(
|
| 106 |
+
self.cfg.train.base_lr,
|
| 107 |
+
self.optimizer,
|
| 108 |
+
self.cfg.model.decoder_dim,
|
| 109 |
+
warmup_steps=self.cfg.train.warmup_steps,
|
| 110 |
+
)
|
| 111 |
+
elif self.cfg.train.scheduler.lower() == "cosine":
|
| 112 |
+
from diffusers.optimization import get_cosine_schedule_with_warmup
|
| 113 |
+
|
| 114 |
+
scheduler = get_cosine_schedule_with_warmup(
|
| 115 |
+
self.optimizer,
|
| 116 |
+
num_warmup_steps=self.cfg.train.warmup_steps
|
| 117 |
+
* self.accelerator.num_processes,
|
| 118 |
+
num_training_steps=self.cfg.train.total_training_steps
|
| 119 |
+
* self.accelerator.num_processes,
|
| 120 |
+
)
|
| 121 |
+
else:
|
| 122 |
+
raise NotImplementedError(f"{self.cfg.train.scheduler}")
|
| 123 |
+
|
| 124 |
+
return scheduler
|
| 125 |
+
|
| 126 |
+
def _train_epoch(self):
|
| 127 |
+
r"""Training epoch. Should return average loss of a batch (sample) over
|
| 128 |
+
one epoch. See ``train_loop`` for usage.
|
| 129 |
+
"""
|
| 130 |
+
if isinstance(self.model, dict):
|
| 131 |
+
for key in self.model.keys():
|
| 132 |
+
self.model[key].train()
|
| 133 |
+
else:
|
| 134 |
+
self.model.train()
|
| 135 |
+
|
| 136 |
+
epoch_sum_loss: float = 0.0
|
| 137 |
+
epoch_losses: dict = {}
|
| 138 |
+
epoch_step: int = 0
|
| 139 |
+
for batch in tqdm(
|
| 140 |
+
self.train_dataloader,
|
| 141 |
+
desc=f"Training Epoch {self.epoch}",
|
| 142 |
+
unit="batch",
|
| 143 |
+
colour="GREEN",
|
| 144 |
+
leave=False,
|
| 145 |
+
dynamic_ncols=True,
|
| 146 |
+
smoothing=0.04,
|
| 147 |
+
disable=not self.accelerator.is_main_process,
|
| 148 |
+
):
|
| 149 |
+
# Do training step and BP
|
| 150 |
+
with self.accelerator.accumulate(self.model):
|
| 151 |
+
total_loss, train_losses = self._train_step(batch)
|
| 152 |
+
self.accelerator.backward(total_loss)
|
| 153 |
+
self.optimizer.step()
|
| 154 |
+
self.optimizer.zero_grad()
|
| 155 |
+
self.batch_count += 1
|
| 156 |
+
|
| 157 |
+
if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
|
| 158 |
+
if self.cfg.train.optimizer not in ["ScaledAdam", "Eve"]:
|
| 159 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
| 160 |
+
|
| 161 |
+
for k in range(self.cfg.train.gradient_accumulation_step):
|
| 162 |
+
if isinstance(self.scheduler, Eden):
|
| 163 |
+
self.scheduler.step_batch(self.step)
|
| 164 |
+
else:
|
| 165 |
+
self.scheduler.step()
|
| 166 |
+
|
| 167 |
+
epoch_sum_loss += total_loss.detach().cpu().item()
|
| 168 |
+
|
| 169 |
+
if isinstance(train_losses, dict):
|
| 170 |
+
for key, value in train_losses.items():
|
| 171 |
+
if key not in epoch_losses.keys():
|
| 172 |
+
epoch_losses[key] = value
|
| 173 |
+
else:
|
| 174 |
+
epoch_losses[key] += value
|
| 175 |
+
|
| 176 |
+
if isinstance(train_losses, dict):
|
| 177 |
+
for key, loss in train_losses.items():
|
| 178 |
+
self.accelerator.log(
|
| 179 |
+
{"Step/Train {}".format(key): "{:.6f}".format(loss)},
|
| 180 |
+
step=self.step,
|
| 181 |
+
)
|
| 182 |
+
else:
|
| 183 |
+
self.accelerator.log(
|
| 184 |
+
{"Step/Train Loss": loss},
|
| 185 |
+
step=self.step,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
self.accelerator.log(
|
| 189 |
+
{"Step/lr": self.scheduler.get_last_lr()[0]},
|
| 190 |
+
step=self.step,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
# print loss every log_epoch_step steps
|
| 194 |
+
# if epoch_step % self.cfg.train.log_epoch_step == 0:
|
| 195 |
+
# for key, loss in train_losses.items():
|
| 196 |
+
# self.logger.info("Step/Train {}: {:.6f}".format(key, loss))
|
| 197 |
+
# print("Step/Train {}: {:.6f}".format(key, loss))
|
| 198 |
+
|
| 199 |
+
self.step += 1
|
| 200 |
+
epoch_step += 1
|
| 201 |
+
|
| 202 |
+
self.accelerator.wait_for_everyone()
|
| 203 |
+
|
| 204 |
+
epoch_sum_loss = (
|
| 205 |
+
epoch_sum_loss
|
| 206 |
+
/ len(self.train_dataloader)
|
| 207 |
+
* self.cfg.train.gradient_accumulation_step
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
for key in epoch_losses.keys():
|
| 211 |
+
epoch_losses[key] = (
|
| 212 |
+
epoch_losses[key]
|
| 213 |
+
/ len(self.train_dataloader)
|
| 214 |
+
* self.cfg.train.gradient_accumulation_step
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
return epoch_sum_loss, epoch_losses
|
| 218 |
+
|
| 219 |
+
    def _train_step(self, batch, is_training=True):
        text_tokens = batch["phone_seq"].to(self.device)
        text_tokens_lens = batch["phone_len"].to(self.device)
        assert text_tokens.ndim == 2

        audio_features = batch["acoustic_token"].to(self.device)
        audio_features_lens = batch["target_len"].to(self.device)
        assert audio_features.ndim == 3

        with torch.set_grad_enabled(is_training):
            loss, losses = self.model(
                x=text_tokens,
                x_lens=text_tokens_lens,
                y=audio_features,
                y_lens=audio_features_lens,
                train_stage=self.args.train_stage,
            )

        assert loss.requires_grad == is_training

        loss_dict = {}
        frames_sum = (audio_features_lens).sum()

        avg_loss = loss / frames_sum

        loss_dict["loss"] = avg_loss.detach().cpu().item()
        for l in losses:
            loss_dict[l] = losses[l].detach().cpu().item() / frames_sum.item()

        return avg_loss, loss_dict

    def _valid_step(self, batch):
        valid_losses = {}
        total_loss = 0
        valid_stats = {}

        total_loss, valid_losses = self._train_step(
            batch=batch,
            is_training=False,
        )
        assert total_loss.requires_grad is False

        total_loss = total_loss.detach().cpu().item()

        return total_loss, valid_losses, valid_stats

    def _build_dataloader(self):
        if not self.cfg.train.use_dynamic_batchsize:
            return super()._build_dataloader()
        if len(self.cfg.dataset) > 1:
            raise Exception("use_dynamic_batchsize only supports single dataset now.")
        Dataset, Collator = self._build_dataset()
        train_dataset = Dataset(
            self.cfg, self.cfg.dataset[0], is_valid=False
        )  # TODO: support use_dynamic_batchsize for more than one dataset.
        train_collate = Collator(self.cfg)
        batch_sampler = batch_by_size(
            train_dataset.num_frame_indices,
            train_dataset.get_num_frames,
            max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
            max_sentences=self.cfg.train.max_sentences * self.accelerator.num_processes,
            required_batch_size_multiple=self.accelerator.num_processes,
        )
        np.random.seed(1234)
        np.random.shuffle(batch_sampler)
        print(batch_sampler[:1])
        batches = [
            x[self.accelerator.local_process_index :: self.accelerator.num_processes]
            for x in batch_sampler
            if len(x) % self.accelerator.num_processes == 0
        ]

        train_loader = DataLoader(
            train_dataset,
            collate_fn=train_collate,
            num_workers=self.cfg.train.dataloader.num_worker,
            batch_sampler=VariableSampler(
                batches, drop_last=False, use_random_sampler=True
            ),
            pin_memory=False,
        )
        self.accelerator.wait_for_everyone()

        valid_dataset = Dataset(self.cfg, self.cfg.dataset[0], is_valid=True)
        valid_collate = Collator(self.cfg)
        batch_sampler = batch_by_size(
            valid_dataset.num_frame_indices,
            valid_dataset.get_num_frames,
            max_tokens=self.cfg.train.max_tokens * self.accelerator.num_processes,
            max_sentences=self.cfg.train.max_sentences * self.accelerator.num_processes,
            required_batch_size_multiple=self.accelerator.num_processes,
        )
        batches = [
            x[self.accelerator.local_process_index :: self.accelerator.num_processes]
            for x in batch_sampler
            if len(x) % self.accelerator.num_processes == 0
        ]
        valid_loader = DataLoader(
            valid_dataset,
            collate_fn=valid_collate,
            num_workers=self.cfg.train.dataloader.num_worker,
            batch_sampler=VariableSampler(batches, drop_last=False),
            pin_memory=False,
        )
        self.accelerator.wait_for_everyone()

        return train_loader, valid_loader

    def _accelerator_prepare(self):
        if not self.cfg.train.use_dynamic_batchsize:
            (
                self.train_dataloader,
                self.valid_dataloader,
            ) = self.accelerator.prepare(
                self.train_dataloader,
                self.valid_dataloader,
            )

        if isinstance(self.model, dict):
            for key in self.model.keys():
                self.model[key] = self.accelerator.prepare(self.model[key])
        else:
            self.model = self.accelerator.prepare(self.model)

        if isinstance(self.optimizer, dict):
            for key in self.optimizer.keys():
                self.optimizer[key] = self.accelerator.prepare(self.optimizer[key])
        else:
            self.optimizer = self.accelerator.prepare(self.optimizer)

        if isinstance(self.scheduler, dict):
            for key in self.scheduler.keys():
                self.scheduler[key] = self.accelerator.prepare(self.scheduler[key])
        else:
            self.scheduler = self.accelerator.prepare(self.scheduler)

    def add_arguments(parser: argparse.ArgumentParser):
        parser.add_argument(
            "--train_stage",
            type=int,
            default="1",
            help="0: train all modules, 1: AR Decoder, 2: NAR Decoder",
        )
        parser.add_argument(
            "--ar_model_ckpt_dir",
            type=str,
            default=None,
            help="Checkpoint of the AR model, used in the second training stage.",
        )
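The dynamic-batch path above relies on batch_by_size to group utterances so that the summed frame count per batch stays below max_tokens and the utterance count below max_sentences. A minimal, self-contained sketch of that grouping idea (not the repo's batch_by_size, which additionally handles the batch-size multiple and multi-process slicing):

# Illustrative sketch only: greedy grouping of utterance indices by total frames.
def group_by_frames(indices, num_frames, max_tokens=4000, max_sentences=8):
    batches, cur, cur_frames = [], [], 0
    for idx in indices:
        n = num_frames(idx)
        # start a new batch once the frame or sentence budget would be exceeded
        if cur and (cur_frames + n > max_tokens or len(cur) >= max_sentences):
            batches.append(cur)
            cur, cur_frames = [], 0
        cur.append(idx)
        cur_frames += n
    if cur:
        batches.append(cur)
    return batches

# Example: utterances with 300/500/1200/900 frames, at most 2000 frames per batch.
print(group_by_frames([0, 1, 2, 3], lambda i: [300, 500, 1200, 900][i], max_tokens=2000))
# -> [[0, 1, 2], [3]]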
Amphion/models/tts/vits/__init__.py
ADDED
File without changes
Amphion/models/tts/vits/vits.py
ADDED
@@ -0,0 +1,379 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# This code is modified from https://github.com/jaywalnut310/vits/blob/main/models.py
import math
import torch
from torch import nn
from torch.nn import functional as F

from utils.util import *
from modules.flow.modules import *
from modules.base.base_module import *
from modules.transformer.attentions import Encoder
from modules.duration_predictor.standard_duration_predictor import DurationPredictor
from modules.duration_predictor.stochastic_duration_predictor import (
    StochasticDurationPredictor,
)
from models.vocoders.gan.generator.hifigan import HiFiGAN_vits as Generator

try:
    from modules import monotonic_align
except ImportError:
    print("Monotonic align not found. Please make sure you have compiled it.")


class TextEncoder(nn.Module):
    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)

        self.encoder = Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths):
        x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask


class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x


class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask


class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(
        self,
        n_vocab,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=0,
        gin_channels=0,
        use_sdp=True,
        **kwargs,
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels

        self.use_sdp = use_sdp

        self.enc_p = TextEncoder(
            n_vocab,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
        )

        if use_sdp:
            self.dp = StochasticDurationPredictor(
                hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
            )
        else:
            self.dp = DurationPredictor(
                hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
            )

        if n_speakers >= 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)

    def forward(self, data):
        x = data["phone_seq"]
        x_lengths = data["phone_len"]
        y = data["linear"]
        y_lengths = data["target_len"]

        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        if self.n_speakers > 0:
            g = self.emb_g(data["spk_id"].squeeze(-1)).unsqueeze(-1)  # [b, h, 1]
        else:
            g = None

        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)

        with torch.no_grad():
            # negative cross-entropy
            s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
            neg_cent1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent2 = torch.matmul(
                -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent3 = torch.matmul(
                z_p.transpose(1, 2), (m_p * s_p_sq_r)
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent4 = torch.sum(
                -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = (
                monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
                .unsqueeze(1)
                .detach()
            )

        w = attn.sum(2)
        if self.use_sdp:
            l_length = self.dp(x, x_mask, w, g=g)
            l_length = l_length / torch.sum(x_mask)
        else:
            logw_ = torch.log(w + 1e-6) * x_mask
            logw = self.dp(x, x_mask, g=g)
            l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)

        # expand prior
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

        z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size)
        o = self.dec(z_slice, g=g)
        outputs = {
            "y_hat": o,
            "l_length": l_length,
            "attn": attn,
            "ids_slice": ids_slice,
            "x_mask": x_mask,
            "z_mask": y_mask,
            "z": z,
            "z_p": z_p,
            "m_p": m_p,
            "logs_p": logs_p,
            "m_q": m_q,
            "logs_q": logs_q,
        }
        return outputs

    def infer(
        self,
        x,
        x_lengths,
        sid=None,
        noise_scale=1,
        length_scale=1,
        noise_scale_w=1.0,
        max_len=None,
    ):
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        if self.n_speakers > 0:
            sid = sid.squeeze(-1)
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = None

        if self.use_sdp:
            logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
        else:
            logw = self.dp(x, x_mask, g=g)
        w = torch.exp(logw) * x_mask * length_scale
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = generate_path(w_ceil, attn_mask)

        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)

        outputs = {
            "y_hat": o,
            "attn": attn,
            "mask": y_mask,
            "z": z,
            "z_p": z_p,
            "m_p": m_p,
            "logs_p": logs_p,
        }

        return outputs

    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
        assert self.n_speakers > 0, "n_speakers have to be larger than 0."
        g_src = self.emb_g(sid_src).unsqueeze(-1)
        g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.dec(z_hat * y_mask, g=g_tgt)
        return o_hat, y_mask, (z, z_p, z_hat)
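The "expand prior" step in SynthesizerTrn.forward and infer multiplies the hard alignment attn by the phone-level statistics, repeating each phone's mean over the frames aligned to it. A small self-contained illustration of that matmul (plain PyTorch, batch dimension omitted):

# Sketch of prior expansion: attn is a hard [t_frames, t_phones] alignment built
# from per-phone durations; attn @ m_p.T repeats each phone mean over its frames.
import torch

m_p = torch.tensor([[0.0, 1.0, 2.0]])           # [d=1, t_phones=3] phone-level means
durations = torch.tensor([2, 1, 3])              # frames assigned to each phone
attn = torch.zeros(int(durations.sum()), 3)      # [t_frames=6, t_phones=3]
frame = 0
for phone, d in enumerate(durations.tolist()):
    attn[frame : frame + d, phone] = 1.0
    frame += d

expanded = attn @ m_p.t()                        # [t_frames, d]
print(expanded.squeeze(-1))                      # tensor([0., 0., 1., 2., 2., 2.])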
Amphion/models/tts/vits/vits_dataset.py
ADDED
@@ -0,0 +1,140 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import json
import numpy as np
from text import text_to_sequence
from text.text_token_collation import phoneIDCollation
from models.tts.base.tts_dataset import (
    TTSDataset,
    TTSCollator,
    TTSTestDataset,
    TTSTestCollator,
)


class VITSDataset(TTSDataset):
    def __init__(self, cfg, dataset, is_valid):
        super().__init__(cfg, dataset, is_valid=is_valid)

    def __getitem__(self, index):
        single_feature = super().__getitem__(index)
        return single_feature

    def __len__(self):
        return super().__len__()

    def get_metadata(self):
        metadata_filter = []
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        for utt_info in metadata:
            duration = utt_info["Duration"]
            frame_len = (
                duration
                * self.cfg.preprocess.sample_rate
                // self.cfg.preprocess.hop_size
            )
            if (
                frame_len
                < self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size
            ):
                continue
            metadata_filter.append(utt_info)

        return metadata_filter


class VITSCollator(TTSCollator):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        super().__init__(cfg)

    def __call__(self, batch):
        parsed_batch_features = super().__call__(batch)
        return parsed_batch_features


class VITSTestDataset(TTSTestDataset):
    def __init__(self, args, cfg):
        super().__init__(args, cfg)
        processed_data_dir = os.path.join(cfg.preprocess.processed_dir, args.dataset)
        if cfg.preprocess.use_spkid:
            spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
            with open(spk2id_path, "r") as f:
                self.spk2id = json.load(f)

            utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
            self.utt2spk = dict()
            with open(utt2spk_path, "r") as f:
                for line in f.readlines():
                    utt, spk = line.strip().split("\t")
                    self.utt2spk[utt] = spk

        if cfg.preprocess.use_text or cfg.preprocess.use_phone:
            self.utt2seq = {}
            for utt_info in self.metadata:
                dataset = utt_info["Dataset"]
                uid = utt_info["Uid"]
                utt = "{}_{}".format(dataset, uid)

                if cfg.preprocess.use_text:
                    text = utt_info["Text"]
                    sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
                elif cfg.preprocess.use_phone:
                    # load the phoneme sequence from the phone file
                    phone_path = os.path.join(
                        processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone"
                    )
                    with open(phone_path, "r") as fin:
                        phones = fin.readlines()
                        assert len(phones) == 1
                        phones = phones[0].strip()
                    phones_seq = phones.split(" ")

                    phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
                    sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)

                self.utt2seq[utt] = sequence

    def __getitem__(self, index):
        utt_info = self.metadata[index]

        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        single_feature = dict()

        if self.cfg.preprocess.use_spkid:
            single_feature["spk_id"] = np.array(
                [self.spk2id[self.utt2spk[utt]]], dtype=np.int32
            )

        if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
            single_feature["phone_seq"] = np.array(self.utt2seq[utt])
            single_feature["phone_len"] = len(self.utt2seq[utt])

        return single_feature

    def get_metadata(self):
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        return metadata

    def __len__(self):
        return len(self.metadata)


class VITSTestCollator(TTSTestCollator):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        return super().__call__(batch)
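VITSDataset.get_metadata keeps an utterance only if its frame count covers at least one training segment. A quick numeric check of that filter (the sample rate, hop size, and segment size below are illustrative values, not the repo's config):

# Sketch of the metadata filter: frames = duration * sample_rate // hop_size must
# reach segment_size // hop_size, otherwise the utterance is skipped.
sample_rate, hop_size, segment_size = 24000, 240, 8192

duration = 0.3                                   # seconds
frame_len = duration * sample_rate // hop_size   # 30.0 frames
min_frames = segment_size // hop_size            # 34 frames per training segment
print(frame_len >= min_frames)                   # False -> this utterance is skipped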
Amphion/models/tts/vits/vits_trainer.py
ADDED
@@ -0,0 +1,439 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ExponentialLR

from tqdm import tqdm

from utils.util import *
from utils.mel import mel_spectrogram_torch
from models.tts.base import TTSTrainer
from models.tts.vits.vits import SynthesizerTrn
from models.tts.vits.vits_dataset import VITSDataset, VITSCollator
from models.vocoders.gan.discriminator.mpd import (
    MultiPeriodDiscriminator_vits as MultiPeriodDiscriminator,
)


class VITSTrainer(TTSTrainer):
    def __init__(self, args, cfg):
        TTSTrainer.__init__(self, args, cfg)

        if cfg.preprocess.use_spkid and cfg.train.multi_speaker_training:
            if cfg.model.n_speakers == 0:
                cfg.model.n_speakers = len(self.speakers)

    def _build_model(self):
        net_g = SynthesizerTrn(
            self.cfg.model.text_token_num,
            self.cfg.preprocess.n_fft // 2 + 1,
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
            **self.cfg.model,
        )
        net_d = MultiPeriodDiscriminator(self.cfg.model.use_spectral_norm)
        model = {"generator": net_g, "discriminator": net_d}

        return model

    def _build_dataset(self):
        return VITSDataset, VITSCollator

    def _build_optimizer(self):
        optimizer_g = torch.optim.AdamW(
            self.model["generator"].parameters(),
            self.cfg.train.learning_rate,
            betas=self.cfg.train.AdamW.betas,
            eps=self.cfg.train.AdamW.eps,
        )
        optimizer_d = torch.optim.AdamW(
            self.model["discriminator"].parameters(),
            self.cfg.train.learning_rate,
            betas=self.cfg.train.AdamW.betas,
            eps=self.cfg.train.AdamW.eps,
        )
        optimizer = {"optimizer_g": optimizer_g, "optimizer_d": optimizer_d}

        return optimizer

    def _build_scheduler(self):
        scheduler_g = ExponentialLR(
            self.optimizer["optimizer_g"],
            gamma=self.cfg.train.lr_decay,
            last_epoch=self.epoch - 1,
        )
        scheduler_d = ExponentialLR(
            self.optimizer["optimizer_d"],
            gamma=self.cfg.train.lr_decay,
            last_epoch=self.epoch - 1,
        )

        scheduler = {"scheduler_g": scheduler_g, "scheduler_d": scheduler_d}
        return scheduler

    def _build_criterion(self):
        class GeneratorLoss(nn.Module):
            def __init__(self, cfg):
                super(GeneratorLoss, self).__init__()
                self.cfg = cfg
                self.l1_loss = nn.L1Loss()

            def generator_loss(self, disc_outputs):
                loss = 0
                gen_losses = []
                for dg in disc_outputs:
                    dg = dg.float()
                    l = torch.mean((1 - dg) ** 2)
                    gen_losses.append(l)
                    loss += l

                return loss, gen_losses

            def feature_loss(self, fmap_r, fmap_g):
                loss = 0
                for dr, dg in zip(fmap_r, fmap_g):
                    for rl, gl in zip(dr, dg):
                        rl = rl.float().detach()
                        gl = gl.float()
                        loss += torch.mean(torch.abs(rl - gl))

                return loss * 2

            def kl_loss(self, z_p, logs_q, m_p, logs_p, z_mask):
                """
                z_p, logs_q: [b, h, t_t]
                m_p, logs_p: [b, h, t_t]
                """
                z_p = z_p.float()
                logs_q = logs_q.float()
                m_p = m_p.float()
                logs_p = logs_p.float()
                z_mask = z_mask.float()

                kl = logs_p - logs_q - 0.5
                kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
                kl = torch.sum(kl * z_mask)
                l = kl / torch.sum(z_mask)
                return l

            def forward(
                self,
                outputs_g,
                outputs_d,
                y_mel,
                y_hat_mel,
            ):
                loss_g = {}

                # duration loss
                loss_dur = torch.sum(outputs_g["l_length"].float())
                loss_g["loss_dur"] = loss_dur

                # mel loss
                loss_mel = self.l1_loss(y_mel, y_hat_mel) * self.cfg.train.c_mel
                loss_g["loss_mel"] = loss_mel

                # kl loss
                loss_kl = (
                    self.kl_loss(
                        outputs_g["z_p"],
                        outputs_g["logs_q"],
                        outputs_g["m_p"],
                        outputs_g["logs_p"],
                        outputs_g["z_mask"],
                    )
                    * self.cfg.train.c_kl
                )
                loss_g["loss_kl"] = loss_kl

                # feature loss
                loss_fm = self.feature_loss(outputs_d["fmap_rs"], outputs_d["fmap_gs"])
                loss_g["loss_fm"] = loss_fm

                # gan loss
                loss_gen, losses_gen = self.generator_loss(outputs_d["y_d_hat_g"])
                loss_g["loss_gen"] = loss_gen
                loss_g["loss_gen_all"] = (
                    loss_dur + loss_mel + loss_kl + loss_fm + loss_gen
                )

                return loss_g

        class DiscriminatorLoss(nn.Module):
            def __init__(self, cfg):
                super(DiscriminatorLoss, self).__init__()
                self.cfg = cfg
                self.l1Loss = torch.nn.L1Loss(reduction="mean")

            def __call__(self, disc_real_outputs, disc_generated_outputs):
                loss_d = {}

                loss = 0
                r_losses = []
                g_losses = []
                for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
                    dr = dr.float()
                    dg = dg.float()
                    r_loss = torch.mean((1 - dr) ** 2)
                    g_loss = torch.mean(dg**2)
                    loss += r_loss + g_loss
                    r_losses.append(r_loss.item())
                    g_losses.append(g_loss.item())

                loss_d["loss_disc_all"] = loss

                return loss_d

        criterion = {
            "generator": GeneratorLoss(self.cfg),
            "discriminator": DiscriminatorLoss(self.cfg),
        }
        return criterion

    def write_summary(
        self,
        losses,
        stats,
        images={},
        audios={},
        audio_sampling_rate=24000,
        tag="train",
    ):
        for key, value in losses.items():
            self.sw.add_scalar(tag + "/" + key, value, self.step)
        self.sw.add_scalar(
            "learning_rate",
            self.optimizer["optimizer_g"].param_groups[0]["lr"],
            self.step,
        )

        if len(images) != 0:
            for key, value in images.items():
                self.sw.add_image(key, value, self.global_step, batchformats="HWC")
        if len(audios) != 0:
            for key, value in audios.items():
                self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)

    def write_valid_summary(
        self, losses, stats, images={}, audios={}, audio_sampling_rate=24000, tag="val"
    ):
        for key, value in losses.items():
            self.sw.add_scalar(tag + "/" + key, value, self.step)

        if len(images) != 0:
            for key, value in images.items():
                self.sw.add_image(key, value, self.global_step, batchformats="HWC")
        if len(audios) != 0:
            for key, value in audios.items():
                self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)

    def get_state_dict(self):
        state_dict = {
            "generator": self.model["generator"].state_dict(),
            "discriminator": self.model["discriminator"].state_dict(),
            "optimizer_g": self.optimizer["optimizer_g"].state_dict(),
            "optimizer_d": self.optimizer["optimizer_d"].state_dict(),
            "scheduler_g": self.scheduler["scheduler_g"].state_dict(),
            "scheduler_d": self.scheduler["scheduler_d"].state_dict(),
            "step": self.step,
            "epoch": self.epoch,
            "batch_size": self.cfg.train.batch_size,
        }
        return state_dict

    def load_model(self, checkpoint):
        self.step = checkpoint["step"]
        self.epoch = checkpoint["epoch"]
        self.model["generator"].load_state_dict(checkpoint["generator"])
        self.model["discriminator"].load_state_dict(checkpoint["discriminator"])
        self.optimizer["optimizer_g"].load_state_dict(checkpoint["optimizer_g"])
        self.optimizer["optimizer_d"].load_state_dict(checkpoint["optimizer_d"])
        self.scheduler["scheduler_g"].load_state_dict(checkpoint["scheduler_g"])
        self.scheduler["scheduler_d"].load_state_dict(checkpoint["scheduler_d"])

    @torch.inference_mode()
    def _valid_step(self, batch):
        r"""Validation forward step. Should return the average loss of a sample over
        one batch. Invoking ``_forward_step`` is recommended except for special cases.
        See ``_test_epoch`` for usage.
        """

        valid_losses = {}
        total_loss = 0
        valid_stats = {}

        batch["linear"] = batch["linear"].transpose(2, 1)  # [b, d, t]
        batch["mel"] = batch["mel"].transpose(2, 1)  # [b, d, t]
        batch["audio"] = batch["audio"].unsqueeze(1)  # [b, d, t]

        # Discriminator
        # Generator output
        outputs_g = self.model["generator"](batch)

        y_mel = slice_segments(
            batch["mel"],
            outputs_g["ids_slice"],
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
        )
        y_hat_mel = mel_spectrogram_torch(
            outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
        )
        y = slice_segments(
            batch["audio"],
            outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
            self.cfg.preprocess.segment_size,
        )

        # Discriminator output
        outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
        ## Discriminator loss
        loss_d = self.criterion["discriminator"](
            outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
        )
        valid_losses.update(loss_d)

        ## Generator
        outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
        loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
        valid_losses.update(loss_g)

        for item in valid_losses:
            valid_losses[item] = valid_losses[item].item()

        total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]

        return (
            total_loss.item(),
            valid_losses,
            valid_stats,
        )

    def _train_step(self, batch):
        r"""Forward step for training and inference. This function is called
        in the ``_train_step`` & ``_test_step`` functions.
        """

        train_losses = {}
        total_loss = 0
        training_stats = {}

        batch["linear"] = batch["linear"].transpose(2, 1)  # [b, d, t]
        batch["mel"] = batch["mel"].transpose(2, 1)  # [b, d, t]
        batch["audio"] = batch["audio"].unsqueeze(1)  # [b, d, t]

        # Train Discriminator
        # Generator output
        outputs_g = self.model["generator"](batch)

        y_mel = slice_segments(
            batch["mel"],
            outputs_g["ids_slice"],
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
        )
        y_hat_mel = mel_spectrogram_torch(
            outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
        )
        y = slice_segments(
            batch["audio"],
            outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
            self.cfg.preprocess.segment_size,
        )

        # Discriminator output
        outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
        ## Discriminator loss
        loss_d = self.criterion["discriminator"](
            outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
        )
        train_losses.update(loss_d)

        # BP and Grad Updated
        self.optimizer["optimizer_d"].zero_grad()
        self.accelerator.backward(loss_d["loss_disc_all"])
        self.optimizer["optimizer_d"].step()

        ## Train Generator
        outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
        loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
        train_losses.update(loss_g)

        # BP and Grad Updated
        self.optimizer["optimizer_g"].zero_grad()
        self.accelerator.backward(loss_g["loss_gen_all"])
        self.optimizer["optimizer_g"].step()

        for item in train_losses:
            train_losses[item] = train_losses[item].item()

        total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]

        return (
            total_loss.item(),
            train_losses,
            training_stats,
        )

    def _train_epoch(self):
        r"""Training epoch. Should return the average loss of a batch (sample) over
        one epoch. See ``train_loop`` for usage.
        """
        epoch_sum_loss: float = 0.0
        epoch_losses: dict = {}
        epoch_step: int = 0
        for batch in tqdm(
            self.train_dataloader,
            desc=f"Training Epoch {self.epoch}",
            unit="batch",
            colour="GREEN",
            leave=False,
            dynamic_ncols=True,
            smoothing=0.04,
            disable=not self.accelerator.is_main_process,
        ):
            with self.accelerator.accumulate(self.model):
                total_loss, train_losses, training_stats = self._train_step(batch)
            self.batch_count += 1

            if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
                epoch_sum_loss += total_loss
                for key, value in train_losses.items():
                    if key not in epoch_losses.keys():
                        epoch_losses[key] = value
                    else:
                        epoch_losses[key] += value

                self.accelerator.log(
                    {
                        "Step/Generator Loss": train_losses["loss_gen_all"],
                        "Step/Discriminator Loss": train_losses["loss_disc_all"],
                        "Step/Generator Learning Rate": self.optimizer[
                            "optimizer_g"
                        ].param_groups[0]["lr"],
                        "Step/Discriminator Learning Rate": self.optimizer[
                            "optimizer_d"
                        ].param_groups[0]["lr"],
                    },
                    step=self.step,
                )
                self.step += 1
                epoch_step += 1

        self.accelerator.wait_for_everyone()

        epoch_sum_loss = (
            epoch_sum_loss
            / len(self.train_dataloader)
            * self.cfg.train.gradient_accumulation_step
        )

        for key in epoch_losses.keys():
            epoch_losses[key] = (
                epoch_losses[key]
                / len(self.train_dataloader)
                * self.cfg.train.gradient_accumulation_step
            )

        return epoch_sum_loss, epoch_losses
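GeneratorLoss.generator_loss and DiscriminatorLoss above implement the least-squares GAN objective: the discriminator is pushed toward 1 on real slices and toward 0 on generated slices, while the generator pushes its fakes toward 1. A tiny numeric sketch with plain PyTorch:

# Sketch of the LSGAN terms used by the VITS trainer, on made-up discriminator scores.
import torch

d_real = torch.tensor([0.9, 0.8])   # discriminator scores on real audio slices
d_fake = torch.tensor([0.2, 0.4])   # discriminator scores on generated slices

loss_disc = torch.mean((1 - d_real) ** 2) + torch.mean(d_fake ** 2)  # discriminator
loss_gen = torch.mean((1 - d_fake) ** 2)                             # generator

print(round(loss_disc.item(), 3))  # 0.125
print(round(loss_gen.item(), 3))   # 0.5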
Amphion/models/vocoders/autoregressive/autoregressive_vocoder_trainer.py
ADDED
File without changes
Amphion/modules/activation_functions/gated_activation_unit.py
ADDED
@@ -0,0 +1,61 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from modules.general.utils import Conv1d


class GaU(nn.Module):
    r"""Gated Activation Unit (GaU) proposed in `Gated Activation Units for Neural
    Networks <https://arxiv.org/pdf/1606.05328.pdf>`_.

    Args:
        channels: number of input channels.
        kernel_size: kernel size of the convolution.
        dilation: dilation rate of the convolution.
        d_context: dimension of context tensor, None if don't use context.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: int = 1,
        d_context: int = None,
    ):
        super().__init__()

        self.context = d_context

        self.conv = Conv1d(
            channels,
            channels * 2,
            kernel_size,
            dilation=dilation,
            padding=dilation * (kernel_size - 1) // 2,
        )

        if self.context:
            self.context_proj = Conv1d(d_context, channels * 2, 1)

    def forward(self, x: torch.Tensor, context: torch.Tensor = None):
        r"""Calculate forward propagation.

        Args:
            x: input tensor with shape [B, C, T].
            context: context tensor with shape [B, ``d_context``, T], default to None.
        """

        h = self.conv(x)

        if self.context:
            h = h + self.context_proj(context)

        h1, h2 = h.chunk(2, 1)
        h = torch.tanh(h1) * torch.sigmoid(h2)

        return h
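The core of GaU.forward is the split-and-gate: one convolution doubles the channels, and the two halves are combined as tanh(h1) * sigmoid(h2). A self-contained sketch that uses torch.nn.Conv1d in place of the repo's Conv1d wrapper:

# Sketch of the gated activation: the "same" padding keeps the time axis, and the
# gating preserves the [B, C, T] shape of the input.
import torch
import torch.nn as nn

channels, kernel_size, dilation = 8, 3, 2
conv = nn.Conv1d(
    channels,
    channels * 2,
    kernel_size,
    dilation=dilation,
    padding=dilation * (kernel_size - 1) // 2,
)

x = torch.randn(4, channels, 50)          # [B, C, T]
h1, h2 = conv(x).chunk(2, dim=1)          # two [B, C, T] halves
out = torch.tanh(h1) * torch.sigmoid(h2)
print(out.shape)                          # torch.Size([4, 8, 50])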
Amphion/modules/base/base_module.py
ADDED
@@ -0,0 +1,75 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


class ConvReluNorm(nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        assert n_layers > 1, "Number of layers should be larger than 0."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(
            nn.Conv1d(
                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
            )
        )
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask
Amphion/modules/diffusion/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .bidilconv.bidilated_conv import BiDilConv
from .unet.unet import UNet
Amphion/modules/duration_predictor/__init__.py
ADDED
File without changes
Amphion/modules/duration_predictor/standard_duration_predictor.py
ADDED
@@ -0,0 +1,53 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# This code is modified from https://github.com/jaywalnut310/vits/blob/main/models.py

import torch
from torch import nn
from modules.base.base_module import LayerNorm


class DurationPredictor(nn.Module):
    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask
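DurationPredictor returns one log-duration per input position; at inference time (see SynthesizerTrn.infer above) these are exponentiated, scaled by length_scale, and ceiled into integer frame counts. A small numeric sketch of that conversion:

# Sketch: log-durations -> frame counts, as done in SynthesizerTrn.infer
# (w = exp(logw) * x_mask * length_scale, then ceil and sum).
import torch

logw = torch.tensor([[[0.0, 1.0, 2.0]]])   # [b=1, 1, t_phones=3]
x_mask = torch.ones_like(logw)
length_scale = 1.2                          # >1 slows speech down

w = torch.exp(logw) * x_mask * length_scale
w_ceil = torch.ceil(w)
y_length = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()

print(w_ceil.flatten().tolist())  # [2.0, 4.0, 9.0] frames per phone
print(y_length.item())            # 15 output frames in total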
Amphion/modules/duration_predictor/stochastic_duration_predictor.py
ADDED
@@ -0,0 +1,120 @@
| 1 |
+
# Copyright (c) 2023 Amphion.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# This code is modified from https://github.com/jaywalnut310/vits/blob/main/models.pyimport torch
|
| 7 |
+
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
import math
|
| 11 |
+
from modules.flow.modules import *
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class StochasticDurationPredictor(nn.Module):
    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        filter_channels = in_channels
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.log_flow = Log()
        self.flows = nn.ModuleList()
        self.flows.append(ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.flows.append(Flip())

        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        self.post_flows = nn.ModuleList()
        self.post_flows.append(ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(
                ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.post_flows.append(Flip())

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None

            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = (
                torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
                * x_mask
            )
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            logdet_tot_q += torch.sum(
                (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
            )
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
                - logdet_tot_q
            )

            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = (
                torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
                - logdet_tot
            )
            return nll + logq
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]
            z = (
                torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
                * noise_scale
            )
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw
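# --- Illustrative usage sketch (not part of the original file; the shapes and the
# --- hyperparameter values below are assumptions). In training mode the module
# --- returns a per-utterance negative log-likelihood of the observed durations `w`;
# --- with reverse=True it samples log-durations from noise, which are exponentiated
# --- to obtain predicted durations.
def _example_stochastic_duration_predictor():
    import torch

    sdp = StochasticDurationPredictor(
        in_channels=192, filter_channels=192, kernel_size=3, p_dropout=0.5
    )
    x = torch.randn(2, 192, 50)  # text-encoder hidden states (B, C, T_text)
    x_mask = torch.ones(2, 1, 50)  # 1.0 at valid positions
    w = torch.randint(1, 8, (2, 1, 50)).float()  # ground-truth durations in frames

    nll = sdp(x, x_mask, w=w)  # training: shape (B,), averaged into the duration loss
    log_w = sdp(x, x_mask, reverse=True, noise_scale=0.8)  # inference
    w_hat = torch.ceil(torch.exp(log_w) * x_mask)  # predicted integer durations
    return nll, w_hat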
Amphion/modules/general/scaling.py
ADDED
@@ -0,0 +1,1349 @@
# This module is modified from https://github.com/Plachtaa/VALL-E-X/blob/3faaf8ccadb154d63b38070caf518ce9309ea0f4/modules/scaling.py


import logging
import random
import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor


class Transpose(nn.Identity):
    """(N, T, D) -> (N, D, T)"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)


class ActivationBalancerFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        scale_factor: Tensor,
        sign_factor: Optional[Tensor],
        channel_dim: int,
    ) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        xgt0 = x > 0
        if sign_factor is None:
            ctx.save_for_backward(xgt0, scale_factor)
        else:
            ctx.save_for_backward(xgt0, scale_factor, sign_factor)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
        if len(ctx.saved_tensors) == 3:
            xgt0, scale_factor, sign_factor = ctx.saved_tensors
            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                scale_factor = scale_factor.unsqueeze(-1)
                sign_factor = sign_factor.unsqueeze(-1)
            factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        else:
            xgt0, scale_factor = ctx.saved_tensors
            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                scale_factor = scale_factor.unsqueeze(-1)
            factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        neg_delta_grad = x_grad.abs() * factor
        return (
            x_grad - neg_delta_grad,
            None,
            None,
            None,
        )


def _compute_scale_factor(
    x: Tensor,
    channel_dim: int,
    min_abs: float,
    max_abs: float,
    gain_factor: float,
    max_factor: float,
) -> Tensor:
    if channel_dim < 0:
        channel_dim += x.ndim
    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
    x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)

    if min_abs == 0.0:
        below_threshold = 0.0
    else:
        # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
        # x_abs_mean < min_abs.
        below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
            min=0, max=max_factor
        )

    above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
        min=0, max=max_factor
    )

    return below_threshold - above_threshold


def _compute_sign_factor(
    x: Tensor,
    channel_dim: int,
    min_positive: float,
    max_positive: float,
    gain_factor: float,
    max_factor: float,
) -> Tensor:
    if channel_dim < 0:
        channel_dim += x.ndim
    sum_dims = [d for d in range(x.ndim) if d != channel_dim]
    proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
    if min_positive == 0.0:
        factor1 = 0.0
    else:
        # 0 if proportion_positive >= min_positive, else can be
        # as large as max_factor.
        factor1 = (
            (min_positive - proportion_positive) * (gain_factor / min_positive)
        ).clamp_(min=0, max=max_factor)

    if max_positive == 1.0:
        factor2 = 0.0
    else:
        # 0 if self.proportion_positive <= max_positive, else can be
        # as large as -max_factor.
        factor2 = (
            (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
        ).clamp_(min=0, max=max_factor)
    sign_factor = factor1 - factor2
    # require min_positive != 0 or max_positive != 1:
    assert not isinstance(sign_factor, float)
    return sign_factor


class ActivationScaleBalancerFunction(torch.autograd.Function):
    """
    This object is used in class ActivationBalancer when the user specified
    min_positive=0, max_positive=1, so there are no constraints on the signs
    of the activations and only the absolute value has a constraint.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        sign_factor: Tensor,
        scale_factor: Tensor,
        channel_dim: int,
    ) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        xgt0 = x > 0
        ctx.save_for_backward(xgt0, sign_factor, scale_factor)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
        xgt0, sign_factor, scale_factor = ctx.saved_tensors
        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
            sign_factor = sign_factor.unsqueeze(-1)
            scale_factor = scale_factor.unsqueeze(-1)

        factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        neg_delta_grad = x_grad.abs() * factor
        return (
            x_grad - neg_delta_grad,
            None,
            None,
            None,
        )


class RandomClampFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        min: Optional[float],
        max: Optional[float],
        prob: float,
        reflect: float,
    ) -> Tensor:
        x_clamped = torch.clamp(x, min=min, max=max)
        mask = torch.rand_like(x) < prob
        ans = torch.where(mask, x_clamped, x)
        if x.requires_grad:
            ctx.save_for_backward(ans == x)
            ctx.reflect = reflect
        if reflect != 0.0:
            ans = ans * (1.0 + reflect) - (x * reflect)
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
        (is_same,) = ctx.saved_tensors
        x_grad = ans_grad * is_same.to(ans_grad.dtype)
        reflect = ctx.reflect
        if reflect != 0.0:
            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
        return x_grad, None, None, None, None


def random_clamp(
    x: Tensor,
    min: Optional[float] = None,
    max: Optional[float] = None,
    prob: float = 0.5,
    reflect: float = 0.0,
):
    return RandomClampFunction.apply(x, min, max, prob, reflect)


def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
    """
    A randomized way of casting a floating point value to half precision.
    """
    if x.dtype == torch.float16:
        return x
    x_abs = x.abs()
    is_too_small = x_abs < min_abs
    # for elements where is_too_small is true, random_val will contain +-min_abs with
    # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations,
    # for those elements].
    random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs)
    return torch.where(is_too_small, random_val, x).to(torch.float16)


class RandomGradFunction(torch.autograd.Function):
    """
    Does nothing in forward pass; in backward pass, gets rid of very small grads using
    randomized approach that preserves expectations (intended to reduce roundoff).
    """

    @staticmethod
    def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
        ctx.min_abs = min_abs
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
        if ans_grad.dtype == torch.float16:
            return (
                random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs),
                None,
            )
        else:
            return ans_grad, None


class RandomGrad(torch.nn.Module):
    """
    Gets rid of very small gradients using an expectation-preserving method, intended to increase
    accuracy of training when using amp (automatic mixed precision)
    """

    def __init__(self, min_abs: float = 5.0e-06):
        super(RandomGrad, self).__init__()
        self.min_abs = min_abs

    def forward(self, x: Tensor):
        if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
            return x
        else:
            return RandomGradFunction.apply(x, self.min_abs)


class SoftmaxFunction(torch.autograd.Function):
    """
    Tries to handle half-precision derivatives in a randomized way that should
    be more accurate for training than the default behavior.
    """

    @staticmethod
    def forward(ctx, x: Tensor, dim: int):
        ans = x.softmax(dim=dim)
        # if x dtype is float16, x.softmax() returns a float32 because
        # (presumably) that op does not support float16, and autocast
        # is enabled.
        if torch.is_autocast_enabled():
            ans = ans.to(torch.float16)
        ctx.save_for_backward(ans)
        ctx.x_dtype = x.dtype
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (ans,) = ctx.saved_tensors
        with torch.cuda.amp.autocast(enabled=False):
            ans_grad = ans_grad.to(torch.float32)
            ans = ans.to(torch.float32)
            x_grad = ans_grad * ans
            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
            return x_grad, None


def softmax(x: Tensor, dim: int):
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        return x.softmax(dim)

    return SoftmaxFunction.apply(x, dim)


class MaxEigLimiterFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        coeffs: Tensor,
        direction: Tensor,
        channel_dim: int,
        grad_scale: float,
    ) -> Tensor:
        ctx.channel_dim = channel_dim
        ctx.grad_scale = grad_scale
        ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
        return x

    @staticmethod
    def backward(ctx, x_grad, *args):
        with torch.enable_grad():
            (x_orig, coeffs, new_direction) = ctx.saved_tensors
            x_orig.requires_grad = True
            num_channels = x_orig.shape[ctx.channel_dim]
            x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
            new_direction.requires_grad = False
            x = x - x.mean(dim=0)
            x_var = (x**2).mean()
            x_residual = x - coeffs * new_direction
            x_residual_var = (x_residual**2).mean()
            # `variance_proportion` is the proportion of the variance accounted for
            # by the top eigen-direction. This is to be minimized.
            variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
            variance_proportion.backward()
        x_orig_grad = x_orig.grad
        x_extra_grad = (
            x_orig.grad
            * ctx.grad_scale
            * x_grad.norm()
            / (x_orig_grad.norm() + 1.0e-20)
        )
        return x_grad + x_extra_grad.detach(), None, None, None, None


class BasicNorm(torch.nn.Module):
    """
    This is intended to be a simpler, and hopefully cheaper, replacement for
    LayerNorm. The observation this is based on, is that Transformer-type
    networks, especially with pre-norm, sometimes seem to set one of the
    feature dimensions to a large constant value (e.g. 50), which "defeats"
    the LayerNorm because the output magnitude is then not strongly dependent
    on the other (useful) features. Presumably the weight and bias of the
    LayerNorm are required to allow it to do this.

    So the idea is to introduce this large constant value as an explicit
    parameter, that takes the role of the "eps" in LayerNorm, so the network
    doesn't have to do this trick. We make the "eps" learnable.

    Args:
       num_channels: the number of channels, e.g. 512.
       channel_dim: the axis/dimension corresponding to the channel,
          interpreted as an offset from the input's ndim if negative.
          This is NOT the num_channels; it should typically be one of
          {-2, -1, 0, 1, 2, 3}.
       eps: the initial "epsilon" that we add as ballast in:
             scale = ((input_vec**2).mean() + epsilon)**-0.5
          Note: our epsilon is actually large, but we keep the name
          to indicate the connection with conventional LayerNorm.
       learn_eps: if true, we learn epsilon; if false, we keep it
          at the initial value.
       eps_min: float
       eps_max: float
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int = -1,  # CAUTION: see documentation.
        eps: float = 0.25,
        learn_eps: bool = True,
        eps_min: float = -3.0,
        eps_max: float = 3.0,
    ) -> None:
        super(BasicNorm, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        if learn_eps:
            self.eps = nn.Parameter(torch.tensor(eps).log().detach())
        else:
            self.register_buffer("eps", torch.tensor(eps).log().detach())
        self.eps_min = eps_min
        self.eps_max = eps_max

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        eps = self.eps
        if self.training and random.random() < 0.25:
            # with probability 0.25, in training mode, clamp eps between the min
            # and max; this will encourage it to learn parameters within the
            # allowed range by making parameters that are outside the allowed
            # range noisy.

            # gradients to allow the parameter to get back into the allowed
            # region if it happens to exit it.
            eps = eps.clamp(min=self.eps_min, max=self.eps_max)
        scales = (
            torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp()
        ) ** -0.5
        return x * scales

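# --- Illustrative usage sketch (not part of the original file; sizes are arbitrary).
# --- BasicNorm is a drop-in replacement for LayerNorm: it rescales each vector by
# --- (mean(x**2) + eps.exp()) ** -0.5 and has no per-channel weight or bias.
def _example_basic_norm():
    import torch

    norm = BasicNorm(num_channels=512, channel_dim=-1)
    x = torch.randn(8, 100, 512)  # (batch, time, channels)
    y = norm(x)  # same shape; RMS of y is of order 1 when x has roughly unit variance
    return y
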
def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
    """
    Behaves like a constructor of a modified version of nn.Linear
    that gives an easy way to set the default initial parameter scale.

    Args:
        Accepts the standard args and kwargs that nn.Linear accepts
        e.g. in_features, out_features, bias=False.

        initial_scale: you can override this if you want to increase
           or decrease the initial magnitude of the module's output
           (affects the initialization of weight_scale and bias_scale).
           Another option, if you want to do something like this, is
           to re-initialize the parameters.
    """
    ans = nn.Linear(*args, **kwargs)
    with torch.no_grad():
        ans.weight[:] *= initial_scale
        if ans.bias is not None:
            torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
    return ans


def ScaledConv1d(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Conv1d:
    """
    Behaves like a constructor of a modified version of nn.Conv1d
    that gives an easy way to set the default initial parameter scale.

    Args:
        Accepts the standard args and kwargs that nn.Conv1d accepts
        e.g. in_channels, out_channels, bias=False.

        initial_scale: you can override this if you want to increase
           or decrease the initial magnitude of the module's output
           (affects the initialization of weight_scale and bias_scale).
           Another option, if you want to do something like this, is
           to re-initialize the parameters.
    """
    ans = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
    with torch.no_grad():
        ans.weight[:] *= initial_scale
        if ans.bias is not None:
            torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
    return ans


def TransposeScaledConv1d(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Sequential:
    """
    Transpose -> ScaledConv1d
    """
    return nn.Sequential(
        Transpose(),
        ScaledConv1d(
            *args,
            initial_scale=initial_scale,
            kernel_size=kernel_size,
            padding=padding,
            **kwargs,
        ),
    )


def ScaledConv1dTranspose(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Sequential:
    """
    ScaledConv1d -> Transpose
    """
    return nn.Sequential(
        ScaledConv1d(
            *args,
            initial_scale=initial_scale,
            kernel_size=kernel_size,
            padding=padding,
            **kwargs,
        ),
        Transpose(),
    )


def TransposeConv1d(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    Transpose -> Conv1d
    """
    return nn.Sequential(
        Transpose(),
        nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
    )


def Conv1dTranspose(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    Conv1d -> Transpose
    """
    return nn.Sequential(
        nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
        Transpose(),
    )

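# --- Illustrative usage sketch (not part of the original file; sizes are arbitrary).
# --- The Transpose/Conv1d wrappers above exist because attention-style code keeps
# --- tensors as (N, T, D) while nn.Conv1d expects (N, D, T); this shows the shape flow.
def _example_conv_wrappers():
    import torch

    x = torch.randn(4, 100, 256)  # (N, T, D), as used by the attention layers
    down = TransposeScaledConv1d(256, 256, kernel_size=3, initial_scale=0.5)
    h = down(x)  # Transpose -> ScaledConv1d, so h is (N, D, T) = (4, 256, 100)
    up = ScaledConv1dTranspose(256, 256, kernel_size=3, initial_scale=0.5)
    y = up(h)  # ScaledConv1d -> Transpose, back to (N, T, D) = (4, 100, 256)
    return y
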
class SRLinear(nn.Linear):
    """https://arxiv.org/abs/2303.06296
    Stabilizing Transformer Training by Preventing Attention Entropy Collapse
    """

    def __init__(self, in_features, out_features, bias=True, **kwargs):
        super().__init__(in_features, out_features, bias=bias, **kwargs)
        self.register_buffer(
            "u", nn.functional.normalize(torch.randn(in_features), dim=0)
        )
        with torch.no_grad():
            sigma = self.get_sigma()
        self.register_buffer("spectral_norm", sigma)
        self.sigma = nn.Parameter(torch.ones(1))

    def get_sigma(self):
        with torch.no_grad():
            u = self.u
            v = self.weight.mv(u)
            v = nn.functional.normalize(v, dim=0)
            u = self.weight.T.mv(v)
            u = nn.functional.normalize(u, dim=0)
            self.u.data.copy_(u)
        return torch.einsum("c,cd,d->", v, self.weight, u)

    def get_weight(self):
        sigma = self.get_sigma()
        if self.training:
            self.spectral_norm.data.copy_(sigma)
        weight = (self.sigma / sigma) * self.weight
        return weight

    def forward(self, x):
        return nn.functional.linear(x, self.get_weight(), self.bias)


class SRConv1d(SRLinear):
    def __init__(
        self,
        in_features,
        out_features,
        kernel_size,
        stride: int = 1,
        padding: str = "same",
        bias: bool = True,
        **kwargs,
    ):
        in_features = in_features * kernel_size
        super().__init__(in_features, out_features, bias=bias, **kwargs)
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        in_features = self.in_features // self.kernel_size
        weight = self.get_weight().view(
            self.out_features, in_features, self.kernel_size
        )
        return nn.functional.conv1d(
            x, weight, bias=self.bias, stride=self.stride, padding=self.padding
        )


def TransposeSRConv1d(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    Transpose -> SRConv1d
    """
    return nn.Sequential(
        Transpose(),
        SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
    )


def SRConv1dTranspose(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    SRConv1d -> Transpose
    """
    return nn.Sequential(
        SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
        Transpose(),
    )


class ActivationBalancer(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to encourage, for
    each channel, that it is positive at least a proportion `threshold` of the
    time. It does this by multiplying negative derivative values by up to
    (1+max_factor), and positive derivative values by up to (1-max_factor),
    interpolated from 1 at the threshold to those extremal values when none
    of the inputs are positive.

    Args:
       num_channels: the number of channels
       channel_dim: the dimension/axis corresponding to the channel, e.g.
           -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
       min_positive: the minimum, per channel, of the proportion of the time
           that (x > 0), below which we start to modify the derivatives.
       max_positive: the maximum, per channel, of the proportion of the time
           that (x > 0), above which we start to modify the derivatives.
       max_factor: the maximum factor by which we modify the derivatives for
          either the sign constraint or the magnitude constraint;
          e.g. with max_factor=0.02, the derivatives would be multiplied by
          values in the range [0.98..1.02].
       sign_gain_factor: determines the 'gain' with which we increase the
          change in gradient once the constraints on min_positive and max_positive
          are violated.
       scale_gain_factor: determines the 'gain' with which we increase the
          change in gradient once the constraints on min_abs and max_abs
          are violated.
       min_abs: the minimum average-absolute-value difference from the mean
           value per channel, which we allow, before we start to modify
           the derivatives to prevent this.
       max_abs: the maximum average-absolute-value difference from the mean
           value per channel, which we allow, before we start to modify
           the derivatives to prevent this.
       min_prob: determines the minimum probability with which we modify the
           gradients for the {min,max}_positive and {min,max}_abs constraints,
           on each forward(). This is done randomly to prevent all layers
           from doing it at the same time. Early in training we may use
           higher probabilities than this; it will decay to this value.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int,
        min_positive: float = 0.05,
        max_positive: float = 0.95,
        max_factor: float = 0.04,
        sign_gain_factor: float = 0.01,
        scale_gain_factor: float = 0.02,
        min_abs: float = 0.2,
        max_abs: float = 100.0,
        min_prob: float = 0.1,
    ):
        super(ActivationBalancer, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.min_positive = min_positive
        self.max_positive = max_positive
        self.max_factor = max_factor
        self.min_abs = min_abs
        self.max_abs = max_abs
        self.min_prob = min_prob
        self.sign_gain_factor = sign_gain_factor
        self.scale_gain_factor = scale_gain_factor

        # count measures how many times the forward() function has been called.
        # We occasionally sync this to a tensor called `count`, that exists to
        # make sure it is synced to disk when we load and save the model.
        self.cpu_count = 0
        self.register_buffer("count", torch.tensor(0, dtype=torch.int64))

    def forward(self, x: Tensor) -> Tensor:
        if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
            return _no_op(x)

        count = self.cpu_count
        self.cpu_count += 1

        if random.random() < 0.01:
            # Occasionally sync self.cpu_count with self.count.
            # count affects the decay of 'prob'. don't do this on every iter,
            # because syncing with the GPU is slow.
            self.cpu_count = max(self.cpu_count, self.count.item())
            self.count.fill_(self.cpu_count)

        # the prob of doing some work exponentially decreases from 0.5 till it hits
        # a floor at min_prob (==0.1, by default)
        prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))

        if random.random() < prob:
            sign_gain_factor = 0.5
            if self.min_positive != 0.0 or self.max_positive != 1.0:
                sign_factor = _compute_sign_factor(
                    x,
                    self.channel_dim,
                    self.min_positive,
                    self.max_positive,
                    gain_factor=self.sign_gain_factor / prob,
                    max_factor=self.max_factor,
                )
            else:
                sign_factor = None

            scale_factor = _compute_scale_factor(
                x.detach(),
                self.channel_dim,
                min_abs=self.min_abs,
                max_abs=self.max_abs,
                gain_factor=self.scale_gain_factor / prob,
                max_factor=self.max_factor,
            )
            return ActivationBalancerFunction.apply(
                x,
                scale_factor,
                sign_factor,
                self.channel_dim,
            )
        else:
            return _no_op(x)


def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
    """
    Returns x unmodified, but in backprop will put a penalty for the excess of
    the absolute values of elements of x over the limit "limit". E.g. if
    limit == 10.0, then if x has any values over 10 it will get a penalty.

    Caution: the value of this penalty will be affected by grad scaling used
    in automatic mixed precision training. For the reason we use this,
    it shouldn't really matter, or may even be helpful; we just use this
    to disallow really implausible values of scores to be given to softmax.
    """
    x_sign = x.sign()
    over_limit = (x.abs() - limit) > 0
    # The following is a memory efficient way to penalize the absolute values of
    # x that's over the limit. (The memory efficiency comes when you think
    # about which items torch needs to cache for the autograd, and which ones it
    # can throw away). The numerical value of aux_loss as computed here will
    # actually be larger than it should be, by limit * over_limit.sum(), but it
    # has the same derivative as the real aux_loss which is penalty * (x.abs() -
    # limit).relu().
    aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
    # note: we don't do sum() here on aux_loss, but it's as if we had done
    # sum() due to how with_loss() works.
    x = with_loss(x, aux_loss)
    # you must use x for something, or this will be ineffective.
    return x

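# --- Illustrative usage sketch (not part of the original file; sizes are arbitrary).
# --- ActivationBalancer and penalize_abs_values_gt are identities in the forward
# --- pass; any constraint violations only show up as modifications to the gradients.
def _example_activation_balancer():
    import torch

    balancer = ActivationBalancer(num_channels=256, channel_dim=-1)
    x = torch.randn(8, 50, 256, requires_grad=True)
    y = balancer(x)  # forward values unchanged
    scores = penalize_abs_values_gt(y, limit=25.0, penalty=1.0e-04)
    scores.sum().backward()  # the penalty and the balancing act only on x.grad
    return x.grad
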
def _diag(x: Tensor):  # like .diag(), but works for tensors with 3 dims.
    if x.ndim == 2:
        return x.diag()
    else:
        (batch, dim, dim) = x.shape
        x = x.reshape(batch, dim * dim)
        x = x[:, :: dim + 1]
        assert x.shape == (batch, dim)
        return x


def _whitening_metric(x: Tensor, num_groups: int):
    """
    Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues
    of the centered feature covariance are the same within each group's covariance matrix
    and also between groups.
    Args:
        x: a Tensor of shape (*, num_channels)
        num_groups: the number of groups of channels, a number >=1 that divides num_channels
    Returns:
        Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
        greater than 1.0 otherwise.
    """
    assert x.dtype != torch.float16
    x = x.reshape(-1, x.shape[-1])
    (num_frames, num_channels) = x.shape
    assert num_channels % num_groups == 0
    channels_per_group = num_channels // num_groups
    x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
    # x now has shape (num_groups, num_frames, channels_per_group)
    # subtract the mean so we use the centered, not uncentered, covariance.
    # My experience has been that when we "mess with the gradients" like this,
    # it's better not to do anything that tries to move the mean around, because
    # that can easily cause instability.
    x = x - x.mean(dim=1, keepdim=True)
    # x_covar: (num_groups, channels_per_group, channels_per_group)
    x_covar = torch.matmul(x.transpose(1, 2), x)
    x_covar_mean_diag = _diag(x_covar).mean()
    # the following expression is what we'd get if we took the matrix product
    # of each covariance and measured the mean of its trace, i.e.
    # the same as _diag(torch.matmul(x_covar, x_covar)).mean().
    x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group)
    # this metric will be >= 1.0; the larger it is, the less 'white' the data was.
    metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20)
    return metric


class WhiteningPenaltyFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        num_groups: int,
        whitening_limit: float,
        grad_scale: float,
    ) -> Tensor:
        ctx.save_for_backward(x)
        ctx.num_groups = num_groups
        ctx.whitening_limit = whitening_limit
        ctx.grad_scale = grad_scale
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        (x_orig,) = ctx.saved_tensors
        with torch.enable_grad():
            with torch.cuda.amp.autocast(enabled=False):
                x_detached = x_orig.to(torch.float32).detach()
                x_detached.requires_grad = True

                metric = _whitening_metric(x_detached, ctx.num_groups)

                if random.random() < 0.005 or __name__ == "__main__":
                    logging.info(
                        f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
                        f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
                    )

                (metric - ctx.whitening_limit).relu().backward()
                penalty_grad = x_detached.grad
                scale = ctx.grad_scale * (
                    x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20)
                )
                penalty_grad = penalty_grad * scale
        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None


class Whiten(nn.Module):
    def __init__(
        self,
        num_groups: int,
        whitening_limit: float,
        prob: Union[float, Tuple[float, float]],
        grad_scale: float,
    ):
        """
        Args:
          num_groups: the number of groups to divide the channel dim into before
            whitening. We will attempt to make the feature covariance
            within each group, after mean subtraction, as "white" as possible,
            while having the same trace across all groups.
         whitening_limit: a value greater than 1.0, that dictates how much
           freedom we have to violate the constraints. 1.0 would mean perfectly
           white, with exactly the same trace across groups; larger values
           give more freedom. E.g. 2.0.
         prob: the probability with which we apply the gradient modification
           (also affects the grad scale). May be supplied as a float,
           or as a pair (min_prob, max_prob)

          grad_scale: determines the scale on the gradient term from this object,
            relative to the rest of the gradient on the attention weights.
            E.g. 0.02 (you may want to use smaller values than this if prob is large)
        """
        super(Whiten, self).__init__()
        assert num_groups >= 1
        assert whitening_limit >= 1
        assert grad_scale >= 0
        self.num_groups = num_groups
        self.whitening_limit = whitening_limit
        if isinstance(prob, float):
            assert 0 < prob <= 1
            self.prob = prob
        else:
            (self.min_prob, self.max_prob) = prob
            assert 0 < self.min_prob < self.max_prob <= 1
            self.prob = self.max_prob

        self.grad_scale = grad_scale

    def forward(self, x: Tensor) -> Tensor:
        """
        In the forward pass, this function just returns the input unmodified.
        In the backward pass, it will modify the gradients to ensure that the
        distribution in each group has close to (lambda times I) as the covariance
        after mean subtraction, with the same lambda across groups.
        For whitening_limit > 1, there will be more freedom to violate this
        constraint.

        Args:
           x: the input of shape (*, num_channels)

        Returns:
            x, unmodified. You should make sure
            you use the returned value, or the graph will be freed
            and nothing will happen in backprop.
        """
        if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
            return _no_op(x)
        else:
            if hasattr(self, "min_prob") and random.random() < 0.25:
                # occasionally switch between min_prob and max_prob, based on whether
                # we are above or below the threshold.
                if (
                    _whitening_metric(x.to(torch.float32), self.num_groups)
                    > self.whitening_limit
                ):
                    # there would be a change to the grad.
                    self.prob = self.max_prob
                else:
                    self.prob = self.min_prob

            return WhiteningPenaltyFunction.apply(
                x, self.num_groups, self.whitening_limit, self.grad_scale
            )

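# --- Illustrative usage sketch (not part of the original file; sizes are arbitrary).
# --- Whiten is also an identity in the forward pass; the extra gradient term only
# --- appears when the whitening metric of the input exceeds whitening_limit.
def _example_whiten():
    import torch

    whiten = Whiten(num_groups=1, whitening_limit=5.0, prob=1.0, grad_scale=0.02)
    x = torch.randn(100, 128)
    x += 10.0 * torch.randn(128) * torch.randn(100, 1)  # inject a dominant direction
    x.requires_grad = True
    y = whiten(x)  # forward: y == x
    y.backward(gradient=torch.randn_like(x))  # backward adds the whitening penalty
    return _whitening_metric(x.detach(), num_groups=1)  # typically well above 5.0 here
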
| 927 |
+
class WithLoss(torch.autograd.Function):
|
| 928 |
+
@staticmethod
|
| 929 |
+
def forward(ctx, x: Tensor, y: Tensor):
|
| 930 |
+
ctx.y_shape = y.shape
|
| 931 |
+
return x
|
| 932 |
+
|
| 933 |
+
@staticmethod
|
| 934 |
+
def backward(ctx, ans_grad: Tensor):
|
| 935 |
+
return ans_grad, torch.ones(
|
| 936 |
+
ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
|
| 937 |
+
)
|
| 938 |
+
|
| 939 |
+
|
| 940 |
+
def with_loss(x, y):
|
| 941 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 942 |
+
return x
|
| 943 |
+
# returns x but adds y.sum() to the loss function.
|
| 944 |
+
return WithLoss.apply(x, y)
|
| 945 |
+
|
| 946 |
+
|
| 947 |
+
def _no_op(x: Tensor) -> Tensor:
|
| 948 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 949 |
+
return x
|
| 950 |
+
else:
|
| 951 |
+
# a no-op function that will have a node in the autograd graph,
|
| 952 |
+
# to avoid certain bugs relating to backward hooks
|
| 953 |
+
return x.chunk(1, dim=-1)[0]
|
| 954 |
+
|
| 955 |
+
|
| 956 |
+
class Identity(torch.nn.Module):
|
| 957 |
+
def __init__(self):
|
| 958 |
+
super(Identity, self).__init__()
|
| 959 |
+
|
| 960 |
+
def forward(self, x):
|
| 961 |
+
return _no_op(x)
|
| 962 |
+
|
| 963 |
+
|
| 964 |
+
class MaxEig(torch.nn.Module):
|
| 965 |
+
"""
|
| 966 |
+
Modifies the backpropped derivatives of a function to try to discourage
|
| 967 |
+
that any given direction in activation space accounts for more than
|
| 968 |
+
a specified proportion of the covariance (e.g. 0.2).
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
Args:
|
| 972 |
+
num_channels: the number of channels
|
| 973 |
+
channel_dim: the dimension/axis corresponding to the channel, e.g.
|
| 974 |
+
-1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
|
| 975 |
+
max_var_per_eig: the maximum proportion of the variance of the
|
| 976 |
+
features/channels, after mean subtraction, that can come from
|
| 977 |
+
any given eigenvalue.
|
| 978 |
+
min_prob: the minimum probability with which we apply this during any invocation
|
| 979 |
+
of forward(), assuming last time we applied the constraint it was
|
| 980 |
+
not active; supplied for speed.
|
| 981 |
+
scale: determines the scale with which we modify the gradients, relative
|
| 982 |
+
to the existing / unmodified gradients
|
| 983 |
+
"""
|
| 984 |
+
|
| 985 |
+
def __init__(
|
| 986 |
+
self,
|
| 987 |
+
num_channels: int,
|
| 988 |
+
channel_dim: int,
|
| 989 |
+
max_var_per_eig: float = 0.2,
|
| 990 |
+
min_prob: float = 0.01,
|
| 991 |
+
scale: float = 0.01,
|
| 992 |
+
):
|
| 993 |
+
super(MaxEig, self).__init__()
|
| 994 |
+
self.num_channels = num_channels
|
| 995 |
+
self.channel_dim = channel_dim
|
| 996 |
+
self.scale = scale
|
| 997 |
+
assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
|
| 998 |
+
self.max_var_per_eig = max_var_per_eig
|
| 999 |
+
|
| 1000 |
+
# we figure out the dominant direction using the power method: starting with
|
| 1001 |
+
# a random vector, keep multiplying by the covariance and renormalizing.
|
| 1002 |
+
with torch.no_grad():
|
| 1003 |
+
# arbitrary.. would use randn() but want to leave the rest of the model's
|
| 1004 |
+
# random parameters unchanged for comparison
|
| 1005 |
+
direction = torch.arange(num_channels).to(torch.float)
|
| 1006 |
+
direction = direction / direction.norm()
|
| 1007 |
+
self.register_buffer("max_eig_direction", direction)
|
| 1008 |
+
|
| 1009 |
+
self.min_prob = min_prob
|
| 1010 |
+
# cur_prob is the current probability we'll use to apply the ActivationBalancer.
|
| 1011 |
+
# We'll regress this towards prob, each tiem we try to apply it and it is not
|
| 1012 |
+
# active.
|
| 1013 |
+
self.cur_prob = 1.0
|
| 1014 |
+
|
| 1015 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 1016 |
+
if (
|
| 1017 |
+
torch.jit.is_scripting()
|
| 1018 |
+
or self.max_var_per_eig <= 0
|
| 1019 |
+
or random.random() > self.cur_prob
|
| 1020 |
+
or torch.jit.is_tracing()
|
| 1021 |
+
):
|
| 1022 |
+
return _no_op(x)
|
| 1023 |
+
|
| 1024 |
+
with torch.cuda.amp.autocast(enabled=False):
|
| 1025 |
+
eps = 1.0e-20
|
| 1026 |
+
orig_x = x
|
| 1027 |
+
x = x.to(torch.float32)
|
| 1028 |
+
with torch.no_grad():
|
| 1029 |
+
x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels)
|
| 1030 |
+
x = x - x.mean(dim=0)
|
| 1031 |
+
new_direction, coeffs = self._find_direction_coeffs(
|
| 1032 |
+
x, self.max_eig_direction
|
| 1033 |
+
)
|
| 1034 |
+
x_var = (x**2).mean()
|
| 1035 |
+
x_residual = x - coeffs * new_direction
|
| 1036 |
+
x_residual_var = (x_residual**2).mean()
|
| 1037 |
+
|
| 1038 |
+
# `variance_proportion` is the proportion of the variance accounted for
|
| 1039 |
+
# by the top eigen-direction.
|
| 1040 |
+
variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
|
| 1041 |
+
|
| 1042 |
+
# ensure new direction is nonzero even if x == 0, by including `direction`.
|
| 1043 |
+
self._set_direction(0.1 * self.max_eig_direction + new_direction)
|
| 1044 |
+
|
| 1045 |
+
if random.random() < 0.01 or __name__ == "__main__":
|
| 1046 |
+
logging.info(
|
| 1047 |
+
f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
|
| 1048 |
+
)
|
| 1049 |
+
|
| 1050 |
+
if variance_proportion >= self.max_var_per_eig:
|
| 1051 |
+
# The constraint is active. Note, we should quite rarely
|
| 1052 |
+
# reach here, only near the beginning of training if we are
|
| 1053 |
+
# starting to diverge, should this constraint be active.
|
| 1054 |
+
cur_prob = self.cur_prob
|
| 1055 |
+
self.cur_prob = 1.0 # next time, do the update with probability 1.0.
|
| 1056 |
+
return MaxEigLimiterFunction.apply(
|
| 1057 |
+
orig_x, coeffs, new_direction, self.channel_dim, self.scale
|
| 1058 |
+
)
|
| 1059 |
+
else:
|
| 1060 |
+
# let self.cur_prob exponentially approach self.min_prob, as
|
| 1061 |
+
# long as the constraint is inactive.
|
| 1062 |
+
self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
|
| 1063 |
+
return orig_x
|
| 1064 |
+
|
| 1065 |
+
def _set_direction(self, direction: Tensor):
|
| 1066 |
+
"""
|
| 1067 |
+
Sets self.max_eig_direction to a normalized version of `direction`
|
| 1068 |
+
"""
|
| 1069 |
+
direction = direction.detach()
|
| 1070 |
+
direction = direction / direction.norm()
|
| 1071 |
+
direction_sum = direction.sum().item()
|
| 1072 |
+
if direction_sum - direction_sum == 0: # no inf/nan
|
| 1073 |
+
self.max_eig_direction[:] = direction
|
| 1074 |
+
else:
|
| 1075 |
+
logging.info(
|
| 1076 |
+
f"Warning: sum of direction in MaxEig is {direction_sum}, "
|
| 1077 |
+
"num_channels={self.num_channels}, channel_dim={self.channel_dim}"
|
| 1078 |
+
)
|
| 1079 |
+
|
| 1080 |
+
def _find_direction_coeffs(
|
| 1081 |
+
self, x: Tensor, prev_direction: Tensor
|
| 1082 |
+
) -> Tuple[Tensor, Tensor, Tensor]:
|
| 1083 |
+
"""
|
| 1084 |
+
Figure out (an approximation to) the proportion of the variance of a set of
|
| 1085 |
+
feature vectors that can be attributed to the top eigen-direction.
|
| 1086 |
+
Args:
|
| 1087 |
+
x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
|
| 1088 |
+
prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
|
| 1089 |
+
of the top eigen-direction, or a random direction if this is the first
|
| 1090 |
+
iteration. Does not have to be normalized, but should be nonzero.
|
| 1091 |
+
|
| 1092 |
+
Returns: (cur_direction, coeffs), where:
|
| 1093 |
+
cur_direction: a Tensor of shape (num_channels,) that is the current
|
| 1094 |
+
estimate of the top eigen-direction.
|
| 1095 |
+
coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
|
| 1096 |
+
approximately minimizes, (x - coeffs * cur_direction).norm()
|
| 1097 |
+
"""
|
| 1098 |
+
(num_frames, num_channels) = x.shape
|
| 1099 |
+
assert num_channels > 1 and num_frames > 1
|
| 1100 |
+
assert prev_direction.shape == (num_channels,)
|
| 1101 |
+
# `coeffs` are the coefficients of `prev_direction` in x.
|
| 1102 |
+
# actually represent the coeffs up to a constant positive factor.
|
| 1103 |
+
coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
|
| 1104 |
+
cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20)
|
| 1105 |
+
return cur_direction, coeffs
|
| 1106 |
+
|
| 1107 |
+
|
| 1108 |
+

class DoubleSwishFunction(torch.autograd.Function):
    """
    double_swish(x) = x * torch.sigmoid(x-1)
    This is a definition, originally motivated by its close numerical
    similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).

    Memory-efficient derivative computation:
      double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
      double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
      Now, s'(x) = s(x) * (1-s(x)).
      double_swish'(x) = x * s'(x) + s(x)
                       = x * s(x) * (1-s(x)) + s(x)
                       = double_swish(x) * (1-s(x)) + s(x)
      ... so we just need to remember s(x) but not x itself.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        requires_grad = x.requires_grad
        x_dtype = x.dtype
        if x.dtype == torch.float16:
            x = x.to(torch.float32)

        s = torch.sigmoid(x - 1.0)
        y = x * s

        if requires_grad:
            deriv = y * (1 - s) + s
            # notes on derivative of x * sigmoid(x - 1):
            # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
            # min \simeq -0.043638.  Take floor as -0.043637 so it's a lower bound.
            # max \simeq 1.1990.  Take ceil to be 1.2 so it's an upper bound.
            # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8
            # (which floors) should be expectation-preserving.
            floor = -0.043637
            ceil = 1.2
            d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
                deriv
            )
            if __name__ == "__main__":
                # for self-testing only.
                assert d_scaled.min() >= 0.0
                assert d_scaled.max() < 256.0
            d_int = d_scaled.to(torch.uint8)
            ctx.save_for_backward(d_int)
        # cast back using the dtype recorded before the float16 -> float32 conversion.
        if x_dtype == torch.float16 or torch.is_autocast_enabled():
            y = y.to(torch.float16)
        return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        (d,) = ctx.saved_tensors
        # the same constants as used in forward pass.
        floor = -0.043637
        ceil = 1.2
        d = d * ((ceil - floor) / 255.0) + floor
        return y_grad * d

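# A short self-contained check (illustrative only, not part of the file) of the
# identity relied on above: double_swish'(x) = double_swish(x) * (1 - s(x)) + s(x)
# with s(x) = sigmoid(x - 1); it also prints the resolution of the uint8-quantized
# derivative that DoubleSwishFunction stores for the backward pass.
import torch

x = torch.linspace(-6.0, 6.0, steps=1001, dtype=torch.double, requires_grad=True)
s = torch.sigmoid(x - 1.0)
y = x * s  # double_swish(x)
(autograd_deriv,) = torch.autograd.grad(y.sum(), x)

analytic_deriv = y.detach() * (1 - s.detach()) + s.detach()
assert torch.allclose(autograd_deriv, analytic_deriv, atol=1.0e-12)

quant_step = (1.2 - (-0.043637)) / 255.0  # per-element error bound of the uint8 storage
print("max |autograd - analytic| =", (autograd_deriv - analytic_deriv).abs().max().item())
print("uint8 quantization step   =", quant_step)
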

class DoubleSwish(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        """Return double-swish activation, an approximation to swish(swish(x))
        that we implement as x * sigmoid(x-1).
        """
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return x * torch.sigmoid(x - 1.0)
        return DoubleSwishFunction.apply(x)


def BalancedDoubleSwish(
    d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
    """
    ActivationBalancer -> DoubleSwish
    """
    balancer = ActivationBalancer(
        d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
    )
    return nn.Sequential(
        balancer,
        DoubleSwish(),
    )

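# An illustrative usage sketch (hypothetical, not taken from the repository):
# BalancedDoubleSwish is typically used as the activation inside a transformer
# feed-forward block. The module path below follows the import used elsewhere in
# this commit (modules.general.scaling); the layer sizes are made up.
import torch
from torch import nn

from modules.general.scaling import BalancedDoubleSwish

d_model, d_ff = 256, 1024
feed_forward = nn.Sequential(
    nn.Linear(d_model, d_ff),
    BalancedDoubleSwish(d_ff),  # ActivationBalancer (gradient shaping) + DoubleSwish
    nn.Dropout(0.1),
    nn.Linear(d_ff, d_model),
)

x = torch.randn(8, 50, d_model)  # (batch, time, d_model)
with torch.no_grad():
    out = feed_forward(x)
print(out.shape)  # torch.Size([8, 50, 256])
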

def _test_max_eig():
    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"proportion = {proportion}")
        x = torch.randn(100, 128)
        direction = torch.randn(128)
        coeffs = torch.randn(100, 1)
        x += proportion * direction * coeffs

        x.requires_grad = True

        num_channels = 128
        m = MaxEig(
            num_channels, 1, 0.5, scale=0.1  # channel_dim  # max_var_per_eig
        )  # grad_scale

        for _ in range(4):
            y = m(x)

        y_grad = torch.randn_like(x)
        y.backward(gradient=y_grad)

        if proportion < 0.2:
            assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
        elif proportion > 1.0:
            assert not torch.allclose(x.grad, y_grad)


def _test_whiten():
    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"_test_whiten(): proportion = {proportion}")
        x = torch.randn(100, 128)
        direction = torch.randn(128)
        coeffs = torch.randn(100, 1)
        x += proportion * direction * coeffs

        x.requires_grad = True

        num_channels = 128
        m = Whiten(
            1, 5.0, prob=1.0, grad_scale=0.1  # num_groups  # whitening_limit,
        )  # grad_scale

        for _ in range(4):
            y = m(x)

        y_grad = torch.randn_like(x)
        y.backward(gradient=y_grad)

        if proportion < 0.2:
            assert torch.allclose(x.grad, y_grad)
        elif proportion > 1.0:
            assert not torch.allclose(x.grad, y_grad)


def _test_activation_balancer_sign():
    probs = torch.arange(0, 1, 0.01)
    N = 1000
    x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
    x = x.detach()
    x.requires_grad = True
    m = ActivationBalancer(
        probs.numel(),
        channel_dim=0,
        min_positive=0.05,
        max_positive=0.95,
        max_factor=0.2,
        min_abs=0.0,
    )

    y_grad = torch.sign(torch.randn(probs.numel(), N))

    y = m(x)
    y.backward(gradient=y_grad)
    print("_test_activation_balancer_sign: x = ", x)
    print("_test_activation_balancer_sign: y grad = ", y_grad)
    print("_test_activation_balancer_sign: x grad = ", x.grad)


def _test_activation_balancer_magnitude():
    magnitudes = torch.arange(0, 1, 0.01)
    N = 1000
    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
    x = x.detach()
    x.requires_grad = True
    m = ActivationBalancer(
        magnitudes.numel(),
        channel_dim=0,
        min_positive=0.0,
        max_positive=1.0,
        max_factor=0.2,
        min_abs=0.2,
        max_abs=0.8,
        min_prob=1.0,
    )

    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))

    y = m(x)
    y.backward(gradient=y_grad)
    print("_test_activation_balancer_magnitude: x = ", x)
    print("_test_activation_balancer_magnitude: y grad = ", y_grad)
    print("_test_activation_balancer_magnitude: x grad = ", x.grad)


def _test_basic_norm():
    num_channels = 128
    m = BasicNorm(num_channels=num_channels, channel_dim=1)

    x = torch.randn(500, num_channels)

    y = m(x)

    assert y.shape == x.shape
    x_rms = (x**2).mean().sqrt()
    y_rms = (y**2).mean().sqrt()
    print("x rms = ", x_rms)
    print("y rms = ", y_rms)
    assert y_rms < x_rms
    assert y_rms > 0.5 * x_rms


def _test_double_swish_deriv():
    x = torch.randn(10, 12, dtype=torch.double) * 3.0
    x.requires_grad = True
    m = DoubleSwish()

    tol = (1.2 - (-0.043637)) / 255.0
    torch.autograd.gradcheck(m, x, atol=tol)

    # for self-test.
    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
    x.requires_grad = True
    y = m(x)


def _test_softmax():
    a = torch.randn(2, 10, dtype=torch.float64)
    b = a.clone()
    a.requires_grad = True
    b.requires_grad = True
    a.softmax(dim=1)[:, 0].sum().backward()
    print("a grad = ", a.grad)
    softmax(b, dim=1)[:, 0].sum().backward()
    print("b grad = ", b.grad)
    assert torch.allclose(a.grad, b.grad)


if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    _test_softmax()
    _test_whiten()
    _test_max_eig()
    _test_activation_balancer_sign()
    _test_activation_balancer_magnitude()
    _test_basic_norm()
    _test_double_swish_deriv()

Amphion/modules/norms/norm.py ADDED
@@ -0,0 +1,173 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import copy
import numbers
from typing import Any, List, Tuple, Union

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from modules.general.scaling import ActivationBalancer
from modules.general.scaling import BasicNorm as _BasicNorm


_shape_t = Union[int, List[int], torch.Size]


class LayerNorm(nn.Module):
    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
            self.bias = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            input, embedding = input
            output = F.layer_norm(
                input, self.normalized_shape, self.weight, self.bias, self.eps
            )
            return output, embedding

        assert embedding is None
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )

    def extra_repr(self) -> str:
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )


class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization"""

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        if isinstance(input, tuple):
            input, embedding = input
            weight, bias = torch.split(
                self.project_layer(embedding),
                split_size_or_sections=self.d_model,
                dim=-1,
            )
            return (weight * self.norm(input) + bias, embedding)

        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        return weight * self.norm(input) + bias

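# An illustrative sketch (hypothetical usage, not part of the file) of what
# AdaptiveLayerNorm computes: the conditioning embedding is projected to a
# per-channel (weight, bias) pair that scales and shifts the normalized input.
# The import path assumes the Amphion root is on PYTHONPATH; shapes are made up.
import torch

from modules.norms.norm import AdaptiveLayerNorm, LayerNorm

d_model = 8
ada_ln = AdaptiveLayerNorm(d_model, norm=LayerNorm(d_model))

x = torch.randn(2, 5, d_model)  # (batch, time, d_model)
emb = torch.randn(2, 1, d_model)  # conditioning embedding, broadcast over time

out = ada_ln(x, embedding=emb)  # plain-tensor convention
out_tuple, emb_out = ada_ln((x, emb))  # tuple convention used inside nn.Sequential
assert torch.allclose(out, out_tuple)
print(out.shape)  # torch.Size([2, 5, 8])
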

class BasicNorm(_BasicNorm):
    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        super(BasicNorm, self).__init__(d_model, eps=eps)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            input, embedding = input
            return (
                super(BasicNorm, self).forward(input),
                embedding,
            )

        assert embedding is None
        return super(BasicNorm, self).forward(input)


class BalancedBasicNorm(nn.Module):
    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        super(BalancedBasicNorm, self).__init__()
        self.balancer = ActivationBalancer(
            d_model,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            max_abs=6.0,
        )
        self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            input, embedding = input
            return self.norm((self.balancer(input), embedding))

        assert embedding is None
        return self.norm(self.balancer(input))


class IdentityNorm(nn.Module):
    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        super(IdentityNorm, self).__init__()

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            return input

        assert embedding is None
        return input