Spaces:
Paused
Paused
Upload 221 files
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- archive/README_INDEXTTS_1_5.md +247 -0
- examples/cases.jsonl +12 -0
- examples/emo_hate.wav +3 -0
- examples/emo_sad.wav +3 -0
- examples/voice_01.wav +3 -0
- examples/voice_02.wav +3 -0
- examples/voice_03.wav +3 -0
- examples/voice_04.wav +3 -0
- examples/voice_05.wav +3 -0
- examples/voice_06.wav +3 -0
- examples/voice_07.wav +3 -0
- examples/voice_08.wav +3 -0
- examples/voice_09.wav +3 -0
- examples/voice_10.wav +3 -0
- examples/voice_11.wav +3 -0
- examples/voice_12.wav +3 -0
- indextts/.DS_Store +0 -0
- indextts/BigVGAN/.DS_Store +0 -0
- indextts/BigVGAN/ECAPA_TDNN.py +656 -0
- indextts/BigVGAN/__init__.py +0 -0
- indextts/BigVGAN/activations.py +122 -0
- indextts/BigVGAN/alias_free_activation/.DS_Store +0 -0
- indextts/BigVGAN/alias_free_activation/__init__.py +0 -0
- indextts/BigVGAN/alias_free_activation/cuda/.gitignore +1 -0
- indextts/BigVGAN/alias_free_activation/cuda/__init__.py +0 -0
- indextts/BigVGAN/alias_free_activation/cuda/activation1d.py +76 -0
- indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
- indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu +256 -0
- indextts/BigVGAN/alias_free_activation/cuda/compat.h +29 -0
- indextts/BigVGAN/alias_free_activation/cuda/load.py +121 -0
- indextts/BigVGAN/alias_free_activation/cuda/type_shim.h +92 -0
- indextts/BigVGAN/alias_free_activation/torch/__init__.py +6 -0
- indextts/BigVGAN/alias_free_activation/torch/act.py +31 -0
- indextts/BigVGAN/alias_free_activation/torch/filter.py +102 -0
- indextts/BigVGAN/alias_free_activation/torch/resample.py +58 -0
- indextts/BigVGAN/alias_free_torch/__init__.py +6 -0
- indextts/BigVGAN/alias_free_torch/act.py +29 -0
- indextts/BigVGAN/alias_free_torch/filter.py +96 -0
- indextts/BigVGAN/alias_free_torch/resample.py +49 -0
- indextts/BigVGAN/bigvgan.py +534 -0
- indextts/BigVGAN/models.py +451 -0
- indextts/BigVGAN/nnet/CNN.py +546 -0
- indextts/BigVGAN/nnet/__init__.py +0 -0
- indextts/BigVGAN/nnet/linear.py +89 -0
- indextts/BigVGAN/nnet/normalization.py +670 -0
- indextts/BigVGAN/utils.py +101 -0
- indextts/__init__.py +0 -0
- indextts/cli.py +65 -0
- indextts/gpt/__init__.py +0 -0
- indextts/gpt/conformer/__init__.py +0 -0
archive/README_INDEXTTS_1_5.md
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
<div align="center">
|
| 3 |
+
<img src='assets/index_icon.png' width="250"/>
|
| 4 |
+
</div>
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
<h2><center>IndexTTS: An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System</h2>
|
| 8 |
+
|
| 9 |
+
<p align="center">
|
| 10 |
+
<a href='https://arxiv.org/abs/2502.05512'><img src='https://img.shields.io/badge/ArXiv-2502.05512-red'></a>
|
| 11 |
+
|
| 12 |
+
## 👉🏻 IndexTTS 👈🏻
|
| 13 |
+
|
| 14 |
+
[[HuggingFace Demo]](https://huggingface.co/spaces/IndexTeam/IndexTTS) [[ModelScope Demo]](https://modelscope.cn/studios/IndexTeam/IndexTTS-Demo) \
|
| 15 |
+
[[Paper]](https://arxiv.org/abs/2502.05512) [[Demos]](https://index-tts.github.io)
|
| 16 |
+
|
| 17 |
+
**IndexTTS** is a GPT-style text-to-speech (TTS) model mainly based on XTTS and Tortoise. It is capable of correcting the pronunciation of Chinese characters using pinyin and controlling pauses at any position through punctuation marks. We enhanced multiple modules of the system, including the improvement of speaker condition feature representation, and the integration of BigVGAN2 to optimize audio quality. Trained on tens of thousands of hours of data, our system achieves state-of-the-art performance, outperforming current popular TTS systems such as XTTS, CosyVoice2, Fish-Speech, and F5-TTS.
|
| 18 |
+
<span style="font-size:16px;">
|
| 19 |
+
Experience **IndexTTS**: Please contact <u>xuanwu@bilibili.com</u> for more detailed information. </span>
|
| 20 |
+
### Contact
|
| 21 |
+
QQ群(二群):1048202584 \
|
| 22 |
+
Discord:https://discord.gg/uT32E7KDmy \
|
| 23 |
+
简历:indexspeech@bilibili.com \
|
| 24 |
+
欢迎大家来交流讨论!
|
| 25 |
+
## 📣 Updates
|
| 26 |
+
|
| 27 |
+
- `2025/05/14` 🔥🔥 We release the **IndexTTS-1.5**, Significantly improve the model's stability and its performance in the English language.
|
| 28 |
+
- `2025/03/25` 🔥 We release IndexTTS-1.0 model parameters and inference code.
|
| 29 |
+
- `2025/02/12` 🔥 We submitted our paper on arXiv, and released our demos and test sets.
|
| 30 |
+
|
| 31 |
+
## 🖥️ Method
|
| 32 |
+
|
| 33 |
+
The overview of IndexTTS is shown as follows.
|
| 34 |
+
|
| 35 |
+
<picture>
|
| 36 |
+
<img src="assets/IndexTTS.png" width="800"/>
|
| 37 |
+
</picture>
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
The main improvements and contributions are summarized as follows:
|
| 41 |
+
- In Chinese scenarios, we have introduced a character-pinyin hybrid modeling approach. This allows for quick correction of mispronounced characters.
|
| 42 |
+
- **IndexTTS** incorporate a conformer conditioning encoder and a BigVGAN2-based speechcode decoder. This improves training stability, voice timbre similarity, and sound quality.
|
| 43 |
+
- We release all test sets here, including those for polysyllabic words, subjective and objective test sets.
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
## Model Download
|
| 48 |
+
| 🤗**HuggingFace** | **ModelScope** |
|
| 49 |
+
|----------------------------------------------------------|----------------------------------------------------------|
|
| 50 |
+
| [IndexTTS](https://huggingface.co/IndexTeam/Index-TTS) | [IndexTTS](https://modelscope.cn/models/IndexTeam/Index-TTS) |
|
| 51 |
+
| [😁IndexTTS-1.5](https://huggingface.co/IndexTeam/IndexTTS-1.5) | [IndexTTS-1.5](https://modelscope.cn/models/IndexTeam/IndexTTS-1.5) |
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
## 📑 Evaluation
|
| 55 |
+
|
| 56 |
+
**Word Error Rate (WER) Results for IndexTTS and Baseline Models on the** [**seed-test**](https://github.com/BytedanceSpeech/seed-tts-eval)
|
| 57 |
+
|
| 58 |
+
| **WER** | **test_zh** | **test_en** | **test_hard** |
|
| 59 |
+
|:----------------------:|:-----------:|:-----------:|:-------------:|
|
| 60 |
+
| **Human** | 1.26 | 2.14 | - |
|
| 61 |
+
| **SeedTTS** | 1.002 | 1.945 | **6.243** |
|
| 62 |
+
| **CosyVoice 2** | 1.45 | 2.57 | 6.83 |
|
| 63 |
+
| **F5TTS** | 1.56 | 1.83 | 8.67 |
|
| 64 |
+
| **FireRedTTS** | 1.51 | 3.82 | 17.45 |
|
| 65 |
+
| **MaskGCT** | 2.27 | 2.62 | 10.27 |
|
| 66 |
+
| **Spark-TTS** | 1.2 | 1.98 | - |
|
| 67 |
+
| **MegaTTS 3** | 1.36 | 1.82 | - |
|
| 68 |
+
| **IndexTTS** | 0.937 | 1.936 | 6.831 |
|
| 69 |
+
| **IndexTTS-1.5** | **0.821** | **1.606** | 6.565 |
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
**Word Error Rate (WER) Results for IndexTTS and Baseline Models on the other opensource test**
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
| **Model** | **aishell1_test** | **commonvoice_20_test_zh** | **commonvoice_20_test_en** | **librispeech_test_clean** | **avg** |
|
| 76 |
+
|:---------------:|:-----------------:|:--------------------------:|:--------------------------:|:--------------------------:|:--------:|
|
| 77 |
+
| **Human** | 2.0 | 9.5 | 10.0 | 2.4 | 5.1 |
|
| 78 |
+
| **CosyVoice 2** | 1.8 | 9.1 | 7.3 | 4.9 | 5.9 |
|
| 79 |
+
| **F5TTS** | 3.9 | 11.7 | 5.4 | 7.8 | 8.2 |
|
| 80 |
+
| **Fishspeech** | 2.4 | 11.4 | 8.8 | 8.0 | 8.3 |
|
| 81 |
+
| **FireRedTTS** | 2.2 | 11.0 | 16.3 | 5.7 | 7.7 |
|
| 82 |
+
| **XTTS** | 3.0 | 11.4 | 7.1 | 3.5 | 6.0 |
|
| 83 |
+
| **IndexTTS** | 1.3 | 7.0 | 5.3 | 2.1 | 3.7 |
|
| 84 |
+
| **IndexTTS-1.5** | **1.2** | **6.8** | **3.9** | **1.7** | **3.1** |
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
**Speaker Similarity (SS) Results for IndexTTS and Baseline Models**
|
| 88 |
+
|
| 89 |
+
| **Model** | **aishell1_test** | **commonvoice_20_test_zh** | **commonvoice_20_test_en** | **librispeech_test_clean** | **avg** |
|
| 90 |
+
|:---------------:|:-----------------:|:--------------------------:|:--------------------------:|:--------------------------:|:---------:|
|
| 91 |
+
| **Human** | 0.846 | 0.809 | 0.820 | 0.858 | 0.836 |
|
| 92 |
+
| **CosyVoice 2** | **0.796** | 0.743 | 0.742 | **0.837** | **0.788** |
|
| 93 |
+
| **F5TTS** | 0.743 | **0.747** | 0.746 | 0.828 | 0.779 |
|
| 94 |
+
| **Fishspeech** | 0.488 | 0.552 | 0.622 | 0.701 | 0.612 |
|
| 95 |
+
| **FireRedTTS** | 0.579 | 0.593 | 0.587 | 0.698 | 0.631 |
|
| 96 |
+
| **XTTS** | 0.573 | 0.586 | 0.648 | 0.761 | 0.663 |
|
| 97 |
+
| **IndexTTS** | 0.744 | 0.742 | **0.758** | 0.823 | 0.776 |
|
| 98 |
+
| **IndexTTS-1.5** | 0.741 | 0.722 | 0.753 | 0.819 | 0.771 |
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
**MOS Scores for Zero-Shot Cloned Voice**
|
| 103 |
+
|
| 104 |
+
| **Model** | **Prosody** | **Timbre** | **Quality** | **AVG** |
|
| 105 |
+
|-----------------|:-----------:|:----------:|:-----------:|:---------:|
|
| 106 |
+
| **CosyVoice 2** | 3.67 | 4.05 | 3.73 | 3.81 |
|
| 107 |
+
| **F5TTS** | 3.56 | 3.88 | 3.56 | 3.66 |
|
| 108 |
+
| **Fishspeech** | 3.40 | 3.63 | 3.69 | 3.57 |
|
| 109 |
+
| **FireRedTTS** | 3.79 | 3.72 | 3.60 | 3.70 |
|
| 110 |
+
| **XTTS** | 3.23 | 2.99 | 3.10 | 3.11 |
|
| 111 |
+
| **IndexTTS** | **3.79** | **4.20** | **4.05** | **4.01** |
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
## Usage Instructions
|
| 115 |
+
### Environment Setup
|
| 116 |
+
1. Download this repository:
|
| 117 |
+
```bash
|
| 118 |
+
git clone https://github.com/index-tts/index-tts.git
|
| 119 |
+
```
|
| 120 |
+
2. Install dependencies:
|
| 121 |
+
|
| 122 |
+
Create a new conda environment and install dependencies:
|
| 123 |
+
|
| 124 |
+
```bash
|
| 125 |
+
conda create -n index-tts python=3.10
|
| 126 |
+
conda activate index-tts
|
| 127 |
+
apt-get install ffmpeg
|
| 128 |
+
# or use conda to install ffmpeg
|
| 129 |
+
conda install -c conda-forge ffmpeg
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
Install [PyTorch](https://pytorch.org/get-started/locally/), e.g.:
|
| 133 |
+
```bash
|
| 134 |
+
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
> [!NOTE]
|
| 138 |
+
> If you are using Windows you may encounter [an error](https://github.com/index-tts/index-tts/issues/61) when installing `pynini`:
|
| 139 |
+
`ERROR: Failed building wheel for pynini`
|
| 140 |
+
> In this case, please install `pynini` via `conda`:
|
| 141 |
+
> ```bash
|
| 142 |
+
> # after conda activate index-tts
|
| 143 |
+
> conda install -c conda-forge pynini==2.1.6
|
| 144 |
+
> pip install WeTextProcessing --no-deps
|
| 145 |
+
> ```
|
| 146 |
+
|
| 147 |
+
Install `IndexTTS` as a package:
|
| 148 |
+
```bash
|
| 149 |
+
cd index-tts
|
| 150 |
+
pip install -e .
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
3. Download models:
|
| 154 |
+
|
| 155 |
+
Download by `huggingface-cli`:
|
| 156 |
+
|
| 157 |
+
```bash
|
| 158 |
+
huggingface-cli download IndexTeam/IndexTTS-1.5 \
|
| 159 |
+
config.yaml bigvgan_discriminator.pth bigvgan_generator.pth bpe.model dvae.pth gpt.pth unigram_12000.vocab \
|
| 160 |
+
--local-dir checkpoints
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
Recommended for China users. 如果下载速度慢,可以使用镜像:
|
| 164 |
+
```bash
|
| 165 |
+
export HF_ENDPOINT="https://hf-mirror.com"
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
Or by `wget`:
|
| 169 |
+
|
| 170 |
+
```bash
|
| 171 |
+
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/bigvgan_discriminator.pth -P checkpoints
|
| 172 |
+
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/bigvgan_generator.pth -P checkpoints
|
| 173 |
+
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/bpe.model -P checkpoints
|
| 174 |
+
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/dvae.pth -P checkpoints
|
| 175 |
+
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/gpt.pth -P checkpoints
|
| 176 |
+
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/unigram_12000.vocab -P checkpoints
|
| 177 |
+
wget https://huggingface.co/IndexTeam/IndexTTS-1.5/resolve/main/config.yaml -P checkpoints
|
| 178 |
+
```
|
| 179 |
+
|
| 180 |
+
> [!NOTE]
|
| 181 |
+
> If you prefer to use the `IndexTTS-1.0` model, please replace `IndexTeam/IndexTTS-1.5` with `IndexTeam/IndexTTS` in the above commands.
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
4. Run test script:
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
```bash
|
| 188 |
+
# Please put your prompt audio in 'test_data' and rename it to 'input.wav'
|
| 189 |
+
python indextts/infer.py
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
5. Use as command line tool:
|
| 193 |
+
|
| 194 |
+
```bash
|
| 195 |
+
# Make sure pytorch has been installed before running this command
|
| 196 |
+
indextts "大��好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!" \
|
| 197 |
+
--voice reference_voice.wav \
|
| 198 |
+
--model_dir checkpoints \
|
| 199 |
+
--config checkpoints/config.yaml \
|
| 200 |
+
--output output.wav
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
Use `--help` to see more options.
|
| 204 |
+
```bash
|
| 205 |
+
indextts --help
|
| 206 |
+
```
|
| 207 |
+
|
| 208 |
+
#### Web Demo
|
| 209 |
+
```bash
|
| 210 |
+
pip install -e ".[webui]" --no-build-isolation
|
| 211 |
+
python webui.py
|
| 212 |
+
|
| 213 |
+
# use another model version:
|
| 214 |
+
python webui.py --model_dir IndexTTS-1.5
|
| 215 |
+
```
|
| 216 |
+
|
| 217 |
+
Open your browser and visit `http://127.0.0.1:7860` to see the demo.
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
#### Sample Code
|
| 221 |
+
```python
|
| 222 |
+
from indextts.infer import IndexTTS
|
| 223 |
+
tts = IndexTTS(model_dir="checkpoints",cfg_path="checkpoints/config.yaml")
|
| 224 |
+
voice="reference_voice.wav"
|
| 225 |
+
text="大家好,我现在正在bilibili 体验 ai 科技,说实话,来之前我绝对想不到!AI技术已经发展到这样匪夷所思的地步了!比如说,现在正在说话的其实是B站为我现场复刻的数字分身,简直就是平行宇宙的另一个我了。如果大家也想体验更多深入的AIGC功能,可以访问 bilibili studio,相信我,你们也会吃惊的。"
|
| 226 |
+
tts.infer(voice, text, output_path)
|
| 227 |
+
```
|
| 228 |
+
|
| 229 |
+
## Acknowledge
|
| 230 |
+
1. [tortoise-tts](https://github.com/neonbjb/tortoise-tts)
|
| 231 |
+
2. [XTTSv2](https://github.com/coqui-ai/TTS)
|
| 232 |
+
3. [BigVGAN](https://github.com/NVIDIA/BigVGAN)
|
| 233 |
+
4. [wenet](https://github.com/wenet-e2e/wenet/tree/main)
|
| 234 |
+
5. [icefall](https://github.com/k2-fsa/icefall)
|
| 235 |
+
|
| 236 |
+
## 📚 Citation
|
| 237 |
+
|
| 238 |
+
🌟 If you find our work helpful, please leave us a star and cite our paper.
|
| 239 |
+
|
| 240 |
+
```
|
| 241 |
+
@article{deng2025indextts,
|
| 242 |
+
title={IndexTTS: An Industrial-Level Controllable and Efficient Zero-Shot Text-To-Speech System},
|
| 243 |
+
author={Wei Deng, Siyi Zhou, Jingchen Shu, Jinchao Wang, Lu Wang},
|
| 244 |
+
journal={arXiv preprint arXiv:2502.05512},
|
| 245 |
+
year={2025}
|
| 246 |
+
}
|
| 247 |
+
```
|
examples/cases.jsonl
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"prompt_audio":"voice_01.wav","text":"Translate for me, what is a surprise!","emo_mode":0}
|
| 2 |
+
{"prompt_audio":"voice_02.wav","text":"The palace is strict, no false rumors, Lady Qi!","emo_mode":0}
|
| 3 |
+
{"prompt_audio":"voice_03.wav","text":"这个呀,就是我们精心制作准备的纪念品,大家可以看到这个色泽和这个材质啊,哎呀多么的光彩照人。","emo_mode":0}
|
| 4 |
+
{"prompt_audio":"voice_04.wav","text":"你就需要我这种专业人士的帮助,就像手无缚鸡之力的人进入雪山狩猎,一定需要最老练的猎人指导。","emo_mode":0}
|
| 5 |
+
{"prompt_audio":"voice_05.wav","text":"在真正的日本剑道中,格斗过程极其短暂,常常短至半秒,最长也不超过两秒,利剑相击的转瞬间,已有一方倒在血泊中。但在这电光石火的对决之前,双方都要以一个石雕般凝固的姿势站定,长时间的逼视对方,这一过程可能长达十分钟!","emo_mode":0}
|
| 6 |
+
{"prompt_audio":"voice_06.wav","text":"今天呢,咱们开一部新书,叫《赛博朋克二零七七》。这词儿我听着都新鲜。这赛博朋克啊,简单理解就是“高科技,低生活”。这一听,我就明白了,于老师就爱用那高科技的东西,手机都得拿脚纹开,大冬天为了解锁脱得一丝不挂,冻得跟王八蛋似的。","emo_mode":0}
|
| 7 |
+
{"prompt_audio":"voice_07.wav","emo_audio":"emo_sad.wav","emo_weight": 1.0, "emo_mode":1,"text":"酒楼丧尽天良,开始借机竞拍房间,哎,一群蠢货。"}
|
| 8 |
+
{"prompt_audio":"voice_08.wav","emo_audio":"emo_hate.wav","emo_weight": 1.0, "emo_mode":1,"text":"你看看你,对我还有没有一点父子之间的信任了。"}
|
| 9 |
+
{"prompt_audio":"voice_09.wav","emo_vec_3":0.8,"emo_mode":2,"text":"对不起嘛!我的记性真的不太好,但是和你在一起的事情,我都会努力记住的~"}
|
| 10 |
+
{"prompt_audio":"voice_10.wav","emo_vec_7":1.0,"emo_mode":2,"text":"哇塞!这个爆率也太高了!欧皇附体了!"}
|
| 11 |
+
{"prompt_audio":"voice_11.wav","emo_mode":3,"emo_text":"极度悲伤","text":"这些年的时光终究是错付了... "}
|
| 12 |
+
{"prompt_audio":"voice_12.wav","emo_mode":3,"emo_text":"You scared me to death! What are you, a ghost?","text":"快躲起来!是他要来了!他要来抓我们了!"}
|
examples/emo_hate.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:89e6e7eee1a28303776e9cf43971e9505529bd0e669f5fcf47f4d1370f9187c4
|
| 3 |
+
size 145368
|
examples/emo_sad.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f7d3e5bf2b7bca6458f9e6d7a5ce073c41eb4418895e7df2f994e5a0c96c064a
|
| 3 |
+
size 842016
|
examples/voice_01.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e33e6ee0107a1dd58e1d66dd90c13df3d55a8683047cc3d7ea206dad84ed3fc8
|
| 3 |
+
size 478050
|
examples/voice_02.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8fe2dd1dbd54ef85a073fbc4c8fc0198f8d4523cc3320a600de0e347a3d8b491
|
| 3 |
+
size 574074
|
examples/voice_03.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:50e8b632efd794418919e2d33c8c2aab9189a57f4d21ef55020413be9f2b292a
|
| 3 |
+
size 616814
|
examples/voice_04.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2a3d2536245f45fd5e1eef046dd768ae7b72a0dba3ec3f370f145862fe64b3b2
|
| 3 |
+
size 681084
|
examples/voice_05.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:eefb7f4a29a8b36f08d5cc1014ea947dbe9f7bef348f07c40263058e604a98eb
|
| 3 |
+
size 1482796
|
examples/voice_06.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2d85800fe261d106c3274fa792cbb952458c4b0b2e1b908340a8cd0d63c73a30
|
| 3 |
+
size 299052
|
examples/voice_07.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bcb10f84e63c3fdbfe99ac4184ca403b46a6d20b50540732713d48c4c95375ce
|
| 3 |
+
size 591894
|
examples/voice_08.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2e2c5f4859999b1ada95ee801d50c3c72879147269a4ed99e385fd917dae5c6f
|
| 3 |
+
size 426812
|
examples/voice_09.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8702467b9b3c83a16bead578e131c4388b3ef82aeff861bd336e622a9ae8a511
|
| 3 |
+
size 1798188
|
examples/voice_10.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:39c2db8b395e4c6ea1122ec7463b5f7bd7dd7d7302f3255780e4c529a9ae9985
|
| 3 |
+
size 1942242
|
examples/voice_11.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:82730e38498413d4371a76e841cd91fa2f74843b79ad3b606d45ad8a7b7a736c
|
| 3 |
+
size 1520734
|
examples/voice_12.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d67bd4f51773677d5902409813b9bb4c1d59b8243c74fc104553b80b49edd22b
|
| 3 |
+
size 778626
|
indextts/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
indextts/BigVGAN/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
indextts/BigVGAN/ECAPA_TDNN.py
ADDED
|
@@ -0,0 +1,656 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""A popular speaker recognition and diarization model.
|
| 2 |
+
|
| 3 |
+
Authors
|
| 4 |
+
* Hwidong Na 2020
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import torch # noqa: F401
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
|
| 11 |
+
from indextts.BigVGAN.nnet.CNN import Conv1d as _Conv1d
|
| 12 |
+
from indextts.BigVGAN.nnet.linear import Linear
|
| 13 |
+
from indextts.BigVGAN.nnet.normalization import BatchNorm1d as _BatchNorm1d
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def length_to_mask(length, max_len=None, dtype=None, device=None):
|
| 17 |
+
"""Creates a binary mask for each sequence.
|
| 18 |
+
|
| 19 |
+
Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3
|
| 20 |
+
|
| 21 |
+
Arguments
|
| 22 |
+
---------
|
| 23 |
+
length : torch.LongTensor
|
| 24 |
+
Containing the length of each sequence in the batch. Must be 1D.
|
| 25 |
+
max_len : int
|
| 26 |
+
Max length for the mask, also the size of the second dimension.
|
| 27 |
+
dtype : torch.dtype, default: None
|
| 28 |
+
The dtype of the generated mask.
|
| 29 |
+
device: torch.device, default: None
|
| 30 |
+
The device to put the mask variable.
|
| 31 |
+
|
| 32 |
+
Returns
|
| 33 |
+
-------
|
| 34 |
+
mask : tensor
|
| 35 |
+
The binary mask.
|
| 36 |
+
|
| 37 |
+
Example
|
| 38 |
+
-------
|
| 39 |
+
>>> length=torch.Tensor([1,2,3])
|
| 40 |
+
>>> mask=length_to_mask(length)
|
| 41 |
+
>>> mask
|
| 42 |
+
tensor([[1., 0., 0.],
|
| 43 |
+
[1., 1., 0.],
|
| 44 |
+
[1., 1., 1.]])
|
| 45 |
+
"""
|
| 46 |
+
assert len(length.shape) == 1
|
| 47 |
+
|
| 48 |
+
if max_len is None:
|
| 49 |
+
max_len = length.max().long().item() # using arange to generate mask
|
| 50 |
+
mask = torch.arange(
|
| 51 |
+
max_len, device=length.device, dtype=length.dtype
|
| 52 |
+
).expand(len(length), max_len) < length.unsqueeze(1)
|
| 53 |
+
|
| 54 |
+
if dtype is None:
|
| 55 |
+
dtype = length.dtype
|
| 56 |
+
|
| 57 |
+
if device is None:
|
| 58 |
+
device = length.device
|
| 59 |
+
|
| 60 |
+
mask = torch.as_tensor(mask, dtype=dtype, device=device)
|
| 61 |
+
return mask
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Skip transpose as much as possible for efficiency
|
| 65 |
+
class Conv1d(_Conv1d):
|
| 66 |
+
"""1D convolution. Skip transpose is used to improve efficiency."""
|
| 67 |
+
|
| 68 |
+
def __init__(self, *args, **kwargs):
|
| 69 |
+
super().__init__(skip_transpose=True, *args, **kwargs)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class BatchNorm1d(_BatchNorm1d):
|
| 73 |
+
"""1D batch normalization. Skip transpose is used to improve efficiency."""
|
| 74 |
+
|
| 75 |
+
def __init__(self, *args, **kwargs):
|
| 76 |
+
super().__init__(skip_transpose=True, *args, **kwargs)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class TDNNBlock(nn.Module):
|
| 80 |
+
"""An implementation of TDNN.
|
| 81 |
+
|
| 82 |
+
Arguments
|
| 83 |
+
---------
|
| 84 |
+
in_channels : int
|
| 85 |
+
Number of input channels.
|
| 86 |
+
out_channels : int
|
| 87 |
+
The number of output channels.
|
| 88 |
+
kernel_size : int
|
| 89 |
+
The kernel size of the TDNN blocks.
|
| 90 |
+
dilation : int
|
| 91 |
+
The dilation of the TDNN block.
|
| 92 |
+
activation : torch class
|
| 93 |
+
A class for constructing the activation layers.
|
| 94 |
+
groups : int
|
| 95 |
+
The groups size of the TDNN blocks.
|
| 96 |
+
|
| 97 |
+
Example
|
| 98 |
+
-------
|
| 99 |
+
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
|
| 100 |
+
>>> layer = TDNNBlock(64, 64, kernel_size=3, dilation=1)
|
| 101 |
+
>>> out_tensor = layer(inp_tensor).transpose(1, 2)
|
| 102 |
+
>>> out_tensor.shape
|
| 103 |
+
torch.Size([8, 120, 64])
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
def __init__(
|
| 107 |
+
self,
|
| 108 |
+
in_channels,
|
| 109 |
+
out_channels,
|
| 110 |
+
kernel_size,
|
| 111 |
+
dilation,
|
| 112 |
+
activation=nn.ReLU,
|
| 113 |
+
groups=1,
|
| 114 |
+
):
|
| 115 |
+
super().__init__()
|
| 116 |
+
self.conv = Conv1d(
|
| 117 |
+
in_channels=in_channels,
|
| 118 |
+
out_channels=out_channels,
|
| 119 |
+
kernel_size=kernel_size,
|
| 120 |
+
dilation=dilation,
|
| 121 |
+
groups=groups,
|
| 122 |
+
)
|
| 123 |
+
self.activation = activation()
|
| 124 |
+
self.norm = BatchNorm1d(input_size=out_channels)
|
| 125 |
+
|
| 126 |
+
def forward(self, x):
|
| 127 |
+
"""Processes the input tensor x and returns an output tensor."""
|
| 128 |
+
return self.norm(self.activation(self.conv(x)))
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class Res2NetBlock(torch.nn.Module):
|
| 132 |
+
"""An implementation of Res2NetBlock w/ dilation.
|
| 133 |
+
|
| 134 |
+
Arguments
|
| 135 |
+
---------
|
| 136 |
+
in_channels : int
|
| 137 |
+
The number of channels expected in the input.
|
| 138 |
+
out_channels : int
|
| 139 |
+
The number of output channels.
|
| 140 |
+
scale : int
|
| 141 |
+
The scale of the Res2Net block.
|
| 142 |
+
kernel_size: int
|
| 143 |
+
The kernel size of the Res2Net block.
|
| 144 |
+
dilation : int
|
| 145 |
+
The dilation of the Res2Net block.
|
| 146 |
+
|
| 147 |
+
Example
|
| 148 |
+
-------
|
| 149 |
+
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
|
| 150 |
+
>>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
|
| 151 |
+
>>> out_tensor = layer(inp_tensor).transpose(1, 2)
|
| 152 |
+
>>> out_tensor.shape
|
| 153 |
+
torch.Size([8, 120, 64])
|
| 154 |
+
"""
|
| 155 |
+
|
| 156 |
+
def __init__(
|
| 157 |
+
self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1
|
| 158 |
+
):
|
| 159 |
+
super().__init__()
|
| 160 |
+
assert in_channels % scale == 0
|
| 161 |
+
assert out_channels % scale == 0
|
| 162 |
+
|
| 163 |
+
in_channel = in_channels // scale
|
| 164 |
+
hidden_channel = out_channels // scale
|
| 165 |
+
|
| 166 |
+
self.blocks = nn.ModuleList(
|
| 167 |
+
[
|
| 168 |
+
TDNNBlock(
|
| 169 |
+
in_channel,
|
| 170 |
+
hidden_channel,
|
| 171 |
+
kernel_size=kernel_size,
|
| 172 |
+
dilation=dilation,
|
| 173 |
+
)
|
| 174 |
+
for i in range(scale - 1)
|
| 175 |
+
]
|
| 176 |
+
)
|
| 177 |
+
self.scale = scale
|
| 178 |
+
|
| 179 |
+
def forward(self, x):
|
| 180 |
+
"""Processes the input tensor x and returns an output tensor."""
|
| 181 |
+
y = []
|
| 182 |
+
for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
|
| 183 |
+
if i == 0:
|
| 184 |
+
y_i = x_i
|
| 185 |
+
elif i == 1:
|
| 186 |
+
y_i = self.blocks[i - 1](x_i)
|
| 187 |
+
else:
|
| 188 |
+
y_i = self.blocks[i - 1](x_i + y_i)
|
| 189 |
+
y.append(y_i)
|
| 190 |
+
y = torch.cat(y, dim=1)
|
| 191 |
+
return y
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class SEBlock(nn.Module):
|
| 195 |
+
"""An implementation of squeeze-and-excitation block.
|
| 196 |
+
|
| 197 |
+
Arguments
|
| 198 |
+
---------
|
| 199 |
+
in_channels : int
|
| 200 |
+
The number of input channels.
|
| 201 |
+
se_channels : int
|
| 202 |
+
The number of output channels after squeeze.
|
| 203 |
+
out_channels : int
|
| 204 |
+
The number of output channels.
|
| 205 |
+
|
| 206 |
+
Example
|
| 207 |
+
-------
|
| 208 |
+
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
|
| 209 |
+
>>> se_layer = SEBlock(64, 16, 64)
|
| 210 |
+
>>> lengths = torch.rand((8,))
|
| 211 |
+
>>> out_tensor = se_layer(inp_tensor, lengths).transpose(1, 2)
|
| 212 |
+
>>> out_tensor.shape
|
| 213 |
+
torch.Size([8, 120, 64])
|
| 214 |
+
"""
|
| 215 |
+
|
| 216 |
+
def __init__(self, in_channels, se_channels, out_channels):
|
| 217 |
+
super().__init__()
|
| 218 |
+
|
| 219 |
+
self.conv1 = Conv1d(
|
| 220 |
+
in_channels=in_channels, out_channels=se_channels, kernel_size=1
|
| 221 |
+
)
|
| 222 |
+
self.relu = torch.nn.ReLU(inplace=True)
|
| 223 |
+
self.conv2 = Conv1d(
|
| 224 |
+
in_channels=se_channels, out_channels=out_channels, kernel_size=1
|
| 225 |
+
)
|
| 226 |
+
self.sigmoid = torch.nn.Sigmoid()
|
| 227 |
+
|
| 228 |
+
def forward(self, x, lengths=None):
|
| 229 |
+
"""Processes the input tensor x and returns an output tensor."""
|
| 230 |
+
L = x.shape[-1]
|
| 231 |
+
if lengths is not None:
|
| 232 |
+
mask = length_to_mask(lengths * L, max_len=L, device=x.device)
|
| 233 |
+
mask = mask.unsqueeze(1)
|
| 234 |
+
total = mask.sum(dim=2, keepdim=True)
|
| 235 |
+
s = (x * mask).sum(dim=2, keepdim=True) / total
|
| 236 |
+
else:
|
| 237 |
+
s = x.mean(dim=2, keepdim=True)
|
| 238 |
+
|
| 239 |
+
s = self.relu(self.conv1(s))
|
| 240 |
+
s = self.sigmoid(self.conv2(s))
|
| 241 |
+
|
| 242 |
+
return s * x
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class AttentiveStatisticsPooling(nn.Module):
|
| 246 |
+
"""This class implements an attentive statistic pooling layer for each channel.
|
| 247 |
+
It returns the concatenated mean and std of the input tensor.
|
| 248 |
+
|
| 249 |
+
Arguments
|
| 250 |
+
---------
|
| 251 |
+
channels: int
|
| 252 |
+
The number of input channels.
|
| 253 |
+
attention_channels: int
|
| 254 |
+
The number of attention channels.
|
| 255 |
+
global_context: bool
|
| 256 |
+
Whether to use global context.
|
| 257 |
+
|
| 258 |
+
Example
|
| 259 |
+
-------
|
| 260 |
+
>>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
|
| 261 |
+
>>> asp_layer = AttentiveStatisticsPooling(64)
|
| 262 |
+
>>> lengths = torch.rand((8,))
|
| 263 |
+
>>> out_tensor = asp_layer(inp_tensor, lengths).transpose(1, 2)
|
| 264 |
+
>>> out_tensor.shape
|
| 265 |
+
torch.Size([8, 1, 128])
|
| 266 |
+
"""
|
| 267 |
+
|
| 268 |
+
def __init__(self, channels, attention_channels=128, global_context=True):
|
| 269 |
+
super().__init__()
|
| 270 |
+
|
| 271 |
+
self.eps = 1e-12
|
| 272 |
+
self.global_context = global_context
|
| 273 |
+
if global_context:
|
| 274 |
+
self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
|
| 275 |
+
else:
|
| 276 |
+
self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
|
| 277 |
+
self.tanh = nn.Tanh()
|
| 278 |
+
self.conv = Conv1d(
|
| 279 |
+
in_channels=attention_channels, out_channels=channels, kernel_size=1
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
def forward(self, x, lengths=None):
|
| 283 |
+
"""Calculates mean and std for a batch (input tensor).
|
| 284 |
+
|
| 285 |
+
Arguments
|
| 286 |
+
---------
|
| 287 |
+
x : torch.Tensor
|
| 288 |
+
Tensor of shape [N, C, L].
|
| 289 |
+
lengths : torch.Tensor
|
| 290 |
+
The corresponding relative lengths of the inputs.
|
| 291 |
+
|
| 292 |
+
Returns
|
| 293 |
+
-------
|
| 294 |
+
pooled_stats : torch.Tensor
|
| 295 |
+
mean and std of batch
|
| 296 |
+
"""
|
| 297 |
+
L = x.shape[-1]
|
| 298 |
+
|
| 299 |
+
def _compute_statistics(x, m, dim=2, eps=self.eps):
|
| 300 |
+
mean = (m * x).sum(dim)
|
| 301 |
+
std = torch.sqrt(
|
| 302 |
+
(m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps)
|
| 303 |
+
)
|
| 304 |
+
return mean, std
|
| 305 |
+
|
| 306 |
+
if lengths is None:
|
| 307 |
+
lengths = torch.ones(x.shape[0], device=x.device)
|
| 308 |
+
|
| 309 |
+
# Make binary mask of shape [N, 1, L]
|
| 310 |
+
mask = length_to_mask(lengths * L, max_len=L, device=x.device)
|
| 311 |
+
mask = mask.unsqueeze(1)
|
| 312 |
+
|
| 313 |
+
# Expand the temporal context of the pooling layer by allowing the
|
| 314 |
+
# self-attention to look at global properties of the utterance.
|
| 315 |
+
if self.global_context:
|
| 316 |
+
# torch.std is unstable for backward computation
|
| 317 |
+
# https://github.com/pytorch/pytorch/issues/4320
|
| 318 |
+
total = mask.sum(dim=2, keepdim=True).float()
|
| 319 |
+
mean, std = _compute_statistics(x, mask / total)
|
| 320 |
+
mean = mean.unsqueeze(2).repeat(1, 1, L)
|
| 321 |
+
std = std.unsqueeze(2).repeat(1, 1, L)
|
| 322 |
+
attn = torch.cat([x, mean, std], dim=1)
|
| 323 |
+
else:
|
| 324 |
+
attn = x
|
| 325 |
+
|
| 326 |
+
# Apply layers
|
| 327 |
+
attn = self.conv(self.tanh(self.tdnn(attn)))
|
| 328 |
+
|
| 329 |
+
# Filter out zero-paddings
|
| 330 |
+
attn = attn.masked_fill(mask == 0, float("-inf"))
|
| 331 |
+
|
| 332 |
+
attn = F.softmax(attn, dim=2)
|
| 333 |
+
mean, std = _compute_statistics(x, attn)
|
| 334 |
+
# Append mean and std of the batch
|
| 335 |
+
pooled_stats = torch.cat((mean, std), dim=1)
|
| 336 |
+
pooled_stats = pooled_stats.unsqueeze(2)
|
| 337 |
+
|
| 338 |
+
return pooled_stats
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
class SERes2NetBlock(nn.Module):
|
| 342 |
+
"""An implementation of building block in ECAPA-TDNN, i.e.,
|
| 343 |
+
TDNN-Res2Net-TDNN-SEBlock.
|
| 344 |
+
|
| 345 |
+
Arguments
|
| 346 |
+
---------
|
| 347 |
+
in_channels: int
|
| 348 |
+
Expected size of input channels.
|
| 349 |
+
out_channels: int
|
| 350 |
+
The number of output channels.
|
| 351 |
+
res2net_scale: int
|
| 352 |
+
The scale of the Res2Net block.
|
| 353 |
+
se_channels : int
|
| 354 |
+
The number of output channels after squeeze.
|
| 355 |
+
kernel_size: int
|
| 356 |
+
The kernel size of the TDNN blocks.
|
| 357 |
+
dilation: int
|
| 358 |
+
The dilation of the Res2Net block.
|
| 359 |
+
activation : torch class
|
| 360 |
+
A class for constructing the activation layers.
|
| 361 |
+
groups: int
|
| 362 |
+
Number of blocked connections from input channels to output channels.
|
| 363 |
+
|
| 364 |
+
Example
|
| 365 |
+
-------
|
| 366 |
+
>>> x = torch.rand(8, 120, 64).transpose(1, 2)
|
| 367 |
+
>>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
|
| 368 |
+
>>> out = conv(x).transpose(1, 2)
|
| 369 |
+
>>> out.shape
|
| 370 |
+
torch.Size([8, 120, 64])
|
| 371 |
+
"""
|
| 372 |
+
|
| 373 |
+
def __init__(
|
| 374 |
+
self,
|
| 375 |
+
in_channels,
|
| 376 |
+
out_channels,
|
| 377 |
+
res2net_scale=8,
|
| 378 |
+
se_channels=128,
|
| 379 |
+
kernel_size=1,
|
| 380 |
+
dilation=1,
|
| 381 |
+
activation=torch.nn.ReLU,
|
| 382 |
+
groups=1,
|
| 383 |
+
):
|
| 384 |
+
super().__init__()
|
| 385 |
+
self.out_channels = out_channels
|
| 386 |
+
self.tdnn1 = TDNNBlock(
|
| 387 |
+
in_channels,
|
| 388 |
+
out_channels,
|
| 389 |
+
kernel_size=1,
|
| 390 |
+
dilation=1,
|
| 391 |
+
activation=activation,
|
| 392 |
+
groups=groups,
|
| 393 |
+
)
|
| 394 |
+
self.res2net_block = Res2NetBlock(
|
| 395 |
+
out_channels, out_channels, res2net_scale, kernel_size, dilation
|
| 396 |
+
)
|
| 397 |
+
self.tdnn2 = TDNNBlock(
|
| 398 |
+
out_channels,
|
| 399 |
+
out_channels,
|
| 400 |
+
kernel_size=1,
|
| 401 |
+
dilation=1,
|
| 402 |
+
activation=activation,
|
| 403 |
+
groups=groups,
|
| 404 |
+
)
|
| 405 |
+
self.se_block = SEBlock(out_channels, se_channels, out_channels)
|
| 406 |
+
|
| 407 |
+
self.shortcut = None
|
| 408 |
+
if in_channels != out_channels:
|
| 409 |
+
self.shortcut = Conv1d(
|
| 410 |
+
in_channels=in_channels,
|
| 411 |
+
out_channels=out_channels,
|
| 412 |
+
kernel_size=1,
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
def forward(self, x, lengths=None):
|
| 416 |
+
"""Processes the input tensor x and returns an output tensor."""
|
| 417 |
+
residual = x
|
| 418 |
+
if self.shortcut:
|
| 419 |
+
residual = self.shortcut(x)
|
| 420 |
+
|
| 421 |
+
x = self.tdnn1(x)
|
| 422 |
+
x = self.res2net_block(x)
|
| 423 |
+
x = self.tdnn2(x)
|
| 424 |
+
x = self.se_block(x, lengths)
|
| 425 |
+
|
| 426 |
+
return x + residual
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
class ECAPA_TDNN(torch.nn.Module):
|
| 430 |
+
"""An implementation of the speaker embedding model in a paper.
|
| 431 |
+
"ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
|
| 432 |
+
TDNN Based Speaker Verification" (https://arxiv.org/abs/2005.07143).
|
| 433 |
+
|
| 434 |
+
Arguments
|
| 435 |
+
---------
|
| 436 |
+
input_size : int
|
| 437 |
+
Expected size of the input dimension.
|
| 438 |
+
device : str
|
| 439 |
+
Device used, e.g., "cpu" or "cuda".
|
| 440 |
+
lin_neurons : int
|
| 441 |
+
Number of neurons in linear layers.
|
| 442 |
+
activation : torch class
|
| 443 |
+
A class for constructing the activation layers.
|
| 444 |
+
channels : list of ints
|
| 445 |
+
Output channels for TDNN/SERes2Net layer.
|
| 446 |
+
kernel_sizes : list of ints
|
| 447 |
+
List of kernel sizes for each layer.
|
| 448 |
+
dilations : list of ints
|
| 449 |
+
List of dilations for kernels in each layer.
|
| 450 |
+
attention_channels: int
|
| 451 |
+
The number of attention channels.
|
| 452 |
+
res2net_scale : int
|
| 453 |
+
The scale of the Res2Net block.
|
| 454 |
+
se_channels : int
|
| 455 |
+
The number of output channels after squeeze.
|
| 456 |
+
global_context: bool
|
| 457 |
+
Whether to use global context.
|
| 458 |
+
groups : list of ints
|
| 459 |
+
List of groups for kernels in each layer.
|
| 460 |
+
|
| 461 |
+
Example
|
| 462 |
+
-------
|
| 463 |
+
>>> input_feats = torch.rand([5, 120, 80])
|
| 464 |
+
>>> compute_embedding = ECAPA_TDNN(80, lin_neurons=192)
|
| 465 |
+
>>> outputs = compute_embedding(input_feats)
|
| 466 |
+
>>> outputs.shape
|
| 467 |
+
torch.Size([5, 1, 192])
|
| 468 |
+
"""
|
| 469 |
+
|
| 470 |
+
def __init__(
|
| 471 |
+
self,
|
| 472 |
+
input_size,
|
| 473 |
+
device="cpu",
|
| 474 |
+
lin_neurons=192,
|
| 475 |
+
activation=torch.nn.ReLU,
|
| 476 |
+
channels=[512, 512, 512, 512, 1536],
|
| 477 |
+
kernel_sizes=[5, 3, 3, 3, 1],
|
| 478 |
+
dilations=[1, 2, 3, 4, 1],
|
| 479 |
+
attention_channels=128,
|
| 480 |
+
res2net_scale=8,
|
| 481 |
+
se_channels=128,
|
| 482 |
+
global_context=True,
|
| 483 |
+
groups=[1, 1, 1, 1, 1],
|
| 484 |
+
):
|
| 485 |
+
super().__init__()
|
| 486 |
+
assert len(channels) == len(kernel_sizes)
|
| 487 |
+
assert len(channels) == len(dilations)
|
| 488 |
+
self.channels = channels
|
| 489 |
+
self.blocks = nn.ModuleList()
|
| 490 |
+
|
| 491 |
+
# The initial TDNN layer
|
| 492 |
+
self.blocks.append(
|
| 493 |
+
TDNNBlock(
|
| 494 |
+
input_size,
|
| 495 |
+
channels[0],
|
| 496 |
+
kernel_sizes[0],
|
| 497 |
+
dilations[0],
|
| 498 |
+
activation,
|
| 499 |
+
groups[0],
|
| 500 |
+
)
|
| 501 |
+
)
|
| 502 |
+
|
| 503 |
+
# SE-Res2Net layers
|
| 504 |
+
for i in range(1, len(channels) - 1):
|
| 505 |
+
self.blocks.append(
|
| 506 |
+
SERes2NetBlock(
|
| 507 |
+
channels[i - 1],
|
| 508 |
+
channels[i],
|
| 509 |
+
res2net_scale=res2net_scale,
|
| 510 |
+
se_channels=se_channels,
|
| 511 |
+
kernel_size=kernel_sizes[i],
|
| 512 |
+
dilation=dilations[i],
|
| 513 |
+
activation=activation,
|
| 514 |
+
groups=groups[i],
|
| 515 |
+
)
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
# Multi-layer feature aggregation
|
| 519 |
+
self.mfa = TDNNBlock(
|
| 520 |
+
channels[-2] * (len(channels) - 2),
|
| 521 |
+
channels[-1],
|
| 522 |
+
kernel_sizes[-1],
|
| 523 |
+
dilations[-1],
|
| 524 |
+
activation,
|
| 525 |
+
groups=groups[-1],
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
# Attentive Statistical Pooling
|
| 529 |
+
self.asp = AttentiveStatisticsPooling(
|
| 530 |
+
channels[-1],
|
| 531 |
+
attention_channels=attention_channels,
|
| 532 |
+
global_context=global_context,
|
| 533 |
+
)
|
| 534 |
+
self.asp_bn = BatchNorm1d(input_size=channels[-1] * 2)
|
| 535 |
+
|
| 536 |
+
# Final linear transformation
|
| 537 |
+
self.fc = Conv1d(
|
| 538 |
+
in_channels=channels[-1] * 2,
|
| 539 |
+
out_channels=lin_neurons,
|
| 540 |
+
kernel_size=1,
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
def forward(self, x, lengths=None):
|
| 544 |
+
"""Returns the embedding vector.
|
| 545 |
+
|
| 546 |
+
Arguments
|
| 547 |
+
---------
|
| 548 |
+
x : torch.Tensor
|
| 549 |
+
Tensor of shape (batch, time, channel).
|
| 550 |
+
lengths : torch.Tensor
|
| 551 |
+
Corresponding relative lengths of inputs.
|
| 552 |
+
|
| 553 |
+
Returns
|
| 554 |
+
-------
|
| 555 |
+
x : torch.Tensor
|
| 556 |
+
Embedding vector.
|
| 557 |
+
"""
|
| 558 |
+
# Minimize transpose for efficiency
|
| 559 |
+
x = x.transpose(1, 2)
|
| 560 |
+
|
| 561 |
+
xl = []
|
| 562 |
+
for layer in self.blocks:
|
| 563 |
+
try:
|
| 564 |
+
x = layer(x, lengths=lengths)
|
| 565 |
+
except TypeError:
|
| 566 |
+
x = layer(x)
|
| 567 |
+
xl.append(x)
|
| 568 |
+
|
| 569 |
+
# Multi-layer feature aggregation
|
| 570 |
+
x = torch.cat(xl[1:], dim=1)
|
| 571 |
+
x = self.mfa(x)
|
| 572 |
+
|
| 573 |
+
# Attentive Statistical Pooling
|
| 574 |
+
x = self.asp(x, lengths=lengths)
|
| 575 |
+
x = self.asp_bn(x)
|
| 576 |
+
|
| 577 |
+
# Final linear transformation
|
| 578 |
+
x = self.fc(x)
|
| 579 |
+
|
| 580 |
+
x = x.transpose(1, 2)
|
| 581 |
+
return x
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
class Classifier(torch.nn.Module):
|
| 585 |
+
"""This class implements the cosine similarity on the top of features.
|
| 586 |
+
|
| 587 |
+
Arguments
|
| 588 |
+
---------
|
| 589 |
+
input_size : int
|
| 590 |
+
Expected size of input dimension.
|
| 591 |
+
device : str
|
| 592 |
+
Device used, e.g., "cpu" or "cuda".
|
| 593 |
+
lin_blocks : int
|
| 594 |
+
Number of linear layers.
|
| 595 |
+
lin_neurons : int
|
| 596 |
+
Number of neurons in linear layers.
|
| 597 |
+
out_neurons : int
|
| 598 |
+
Number of classes.
|
| 599 |
+
|
| 600 |
+
Example
|
| 601 |
+
-------
|
| 602 |
+
>>> classify = Classifier(input_size=2, lin_neurons=2, out_neurons=2)
|
| 603 |
+
>>> outputs = torch.tensor([ [1., -1.], [-9., 1.], [0.9, 0.1], [0.1, 0.9] ])
|
| 604 |
+
>>> outputs = outputs.unsqueeze(1)
|
| 605 |
+
>>> cos = classify(outputs)
|
| 606 |
+
>>> (cos < -1.0).long().sum()
|
| 607 |
+
tensor(0)
|
| 608 |
+
>>> (cos > 1.0).long().sum()
|
| 609 |
+
tensor(0)
|
| 610 |
+
"""
|
| 611 |
+
|
| 612 |
+
def __init__(
|
| 613 |
+
self,
|
| 614 |
+
input_size,
|
| 615 |
+
device="cpu",
|
| 616 |
+
lin_blocks=0,
|
| 617 |
+
lin_neurons=192,
|
| 618 |
+
out_neurons=1211,
|
| 619 |
+
):
|
| 620 |
+
super().__init__()
|
| 621 |
+
self.blocks = nn.ModuleList()
|
| 622 |
+
|
| 623 |
+
for block_index in range(lin_blocks):
|
| 624 |
+
self.blocks.extend(
|
| 625 |
+
[
|
| 626 |
+
_BatchNorm1d(input_size=input_size),
|
| 627 |
+
Linear(input_size=input_size, n_neurons=lin_neurons),
|
| 628 |
+
]
|
| 629 |
+
)
|
| 630 |
+
input_size = lin_neurons
|
| 631 |
+
|
| 632 |
+
# Final Layer
|
| 633 |
+
self.weight = nn.Parameter(
|
| 634 |
+
torch.FloatTensor(out_neurons, input_size, device=device)
|
| 635 |
+
)
|
| 636 |
+
nn.init.xavier_uniform_(self.weight)
|
| 637 |
+
|
| 638 |
+
def forward(self, x):
|
| 639 |
+
"""Returns the output probabilities over speakers.
|
| 640 |
+
|
| 641 |
+
Arguments
|
| 642 |
+
---------
|
| 643 |
+
x : torch.Tensor
|
| 644 |
+
Torch tensor.
|
| 645 |
+
|
| 646 |
+
Returns
|
| 647 |
+
-------
|
| 648 |
+
out : torch.Tensor
|
| 649 |
+
Output probabilities over speakers.
|
| 650 |
+
"""
|
| 651 |
+
for layer in self.blocks:
|
| 652 |
+
x = layer(x)
|
| 653 |
+
|
| 654 |
+
# Need to be normalized
|
| 655 |
+
x = F.linear(F.normalize(x.squeeze(1)), F.normalize(self.weight))
|
| 656 |
+
return x.unsqueeze(1)
|
indextts/BigVGAN/__init__.py
ADDED
|
File without changes
|
indextts/BigVGAN/activations.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
| 2 |
+
# LICENSE is in incl_licenses directory.
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn, pow, sin
|
| 6 |
+
from torch.nn import Parameter
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Snake(nn.Module):
|
| 10 |
+
'''
|
| 11 |
+
Implementation of a sine-based periodic activation function
|
| 12 |
+
Shape:
|
| 13 |
+
- Input: (B, C, T)
|
| 14 |
+
- Output: (B, C, T), same shape as the input
|
| 15 |
+
Parameters:
|
| 16 |
+
- alpha - trainable parameter
|
| 17 |
+
References:
|
| 18 |
+
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
| 19 |
+
https://arxiv.org/abs/2006.08195
|
| 20 |
+
Examples:
|
| 21 |
+
>>> a1 = snake(256)
|
| 22 |
+
>>> x = torch.randn(256)
|
| 23 |
+
>>> x = a1(x)
|
| 24 |
+
'''
|
| 25 |
+
|
| 26 |
+
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
| 27 |
+
'''
|
| 28 |
+
Initialization.
|
| 29 |
+
INPUT:
|
| 30 |
+
- in_features: shape of the input
|
| 31 |
+
- alpha: trainable parameter
|
| 32 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
| 33 |
+
alpha will be trained along with the rest of your model.
|
| 34 |
+
'''
|
| 35 |
+
super(Snake, self).__init__()
|
| 36 |
+
self.in_features = in_features
|
| 37 |
+
|
| 38 |
+
# initialize alpha
|
| 39 |
+
self.alpha_logscale = alpha_logscale
|
| 40 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
| 41 |
+
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
| 42 |
+
else: # linear scale alphas initialized to ones
|
| 43 |
+
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
| 44 |
+
|
| 45 |
+
self.alpha.requires_grad = alpha_trainable
|
| 46 |
+
|
| 47 |
+
self.no_div_by_zero = 0.000000001
|
| 48 |
+
|
| 49 |
+
def forward(self, x):
|
| 50 |
+
'''
|
| 51 |
+
Forward pass of the function.
|
| 52 |
+
Applies the function to the input elementwise.
|
| 53 |
+
Snake ∶= x + 1/a * sin^2 (xa)
|
| 54 |
+
'''
|
| 55 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
| 56 |
+
if self.alpha_logscale:
|
| 57 |
+
alpha = torch.exp(alpha)
|
| 58 |
+
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
| 59 |
+
|
| 60 |
+
return x
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class SnakeBeta(nn.Module):
|
| 64 |
+
'''
|
| 65 |
+
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
| 66 |
+
Shape:
|
| 67 |
+
- Input: (B, C, T)
|
| 68 |
+
- Output: (B, C, T), same shape as the input
|
| 69 |
+
Parameters:
|
| 70 |
+
- alpha - trainable parameter that controls frequency
|
| 71 |
+
- beta - trainable parameter that controls magnitude
|
| 72 |
+
References:
|
| 73 |
+
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
| 74 |
+
https://arxiv.org/abs/2006.08195
|
| 75 |
+
Examples:
|
| 76 |
+
>>> a1 = snakebeta(256)
|
| 77 |
+
>>> x = torch.randn(256)
|
| 78 |
+
>>> x = a1(x)
|
| 79 |
+
'''
|
| 80 |
+
|
| 81 |
+
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
| 82 |
+
'''
|
| 83 |
+
Initialization.
|
| 84 |
+
INPUT:
|
| 85 |
+
- in_features: shape of the input
|
| 86 |
+
- alpha - trainable parameter that controls frequency
|
| 87 |
+
- beta - trainable parameter that controls magnitude
|
| 88 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
| 89 |
+
beta is initialized to 1 by default, higher values = higher-magnitude.
|
| 90 |
+
alpha will be trained along with the rest of your model.
|
| 91 |
+
'''
|
| 92 |
+
super(SnakeBeta, self).__init__()
|
| 93 |
+
self.in_features = in_features
|
| 94 |
+
|
| 95 |
+
# initialize alpha
|
| 96 |
+
self.alpha_logscale = alpha_logscale
|
| 97 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
| 98 |
+
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
| 99 |
+
self.beta = Parameter(torch.zeros(in_features) * alpha)
|
| 100 |
+
else: # linear scale alphas initialized to ones
|
| 101 |
+
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
| 102 |
+
self.beta = Parameter(torch.ones(in_features) * alpha)
|
| 103 |
+
|
| 104 |
+
self.alpha.requires_grad = alpha_trainable
|
| 105 |
+
self.beta.requires_grad = alpha_trainable
|
| 106 |
+
|
| 107 |
+
self.no_div_by_zero = 0.000000001
|
| 108 |
+
|
| 109 |
+
def forward(self, x):
|
| 110 |
+
'''
|
| 111 |
+
Forward pass of the function.
|
| 112 |
+
Applies the function to the input elementwise.
|
| 113 |
+
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
| 114 |
+
'''
|
| 115 |
+
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
| 116 |
+
beta = self.beta.unsqueeze(0).unsqueeze(-1)
|
| 117 |
+
if self.alpha_logscale:
|
| 118 |
+
alpha = torch.exp(alpha)
|
| 119 |
+
beta = torch.exp(beta)
|
| 120 |
+
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
| 121 |
+
|
| 122 |
+
return x
|
indextts/BigVGAN/alias_free_activation/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
indextts/BigVGAN/alias_free_activation/__init__.py
ADDED
|
File without changes
|
indextts/BigVGAN/alias_free_activation/cuda/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
/build
|
indextts/BigVGAN/alias_free_activation/cuda/__init__.py
ADDED
|
File without changes
|
indextts/BigVGAN/alias_free_activation/cuda/activation1d.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
|
| 7 |
+
from indextts.BigVGAN.alias_free_activation.cuda import load
|
| 8 |
+
from indextts.BigVGAN.alias_free_activation.torch.resample import DownSample1d, UpSample1d
|
| 9 |
+
|
| 10 |
+
anti_alias_activation_cuda = load.load()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class FusedAntiAliasActivation(torch.autograd.Function):
|
| 14 |
+
"""
|
| 15 |
+
Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
|
| 16 |
+
The hyperparameters are hard-coded in the kernel to maximize speed.
|
| 17 |
+
NOTE: The fused kenrel is incorrect for Activation1d with different hyperparameters.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
@staticmethod
|
| 21 |
+
def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
|
| 22 |
+
activation_results = anti_alias_activation_cuda.forward(
|
| 23 |
+
inputs, up_ftr, down_ftr, alpha, beta
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
return activation_results
|
| 27 |
+
|
| 28 |
+
@staticmethod
|
| 29 |
+
def backward(ctx, output_grads):
|
| 30 |
+
raise NotImplementedError
|
| 31 |
+
return output_grads, None, None
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class Activation1d(nn.Module):
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
activation,
|
| 38 |
+
up_ratio: int = 2,
|
| 39 |
+
down_ratio: int = 2,
|
| 40 |
+
up_kernel_size: int = 12,
|
| 41 |
+
down_kernel_size: int = 12,
|
| 42 |
+
fused: bool = True,
|
| 43 |
+
):
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.up_ratio = up_ratio
|
| 46 |
+
self.down_ratio = down_ratio
|
| 47 |
+
self.act = activation
|
| 48 |
+
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
| 49 |
+
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
| 50 |
+
|
| 51 |
+
self.fused = fused # Whether to use fused CUDA kernel or not
|
| 52 |
+
|
| 53 |
+
def forward(self, x):
|
| 54 |
+
if not self.fused:
|
| 55 |
+
x = self.upsample(x)
|
| 56 |
+
x = self.act(x)
|
| 57 |
+
x = self.downsample(x)
|
| 58 |
+
return x
|
| 59 |
+
else:
|
| 60 |
+
if self.act.__class__.__name__ == "Snake":
|
| 61 |
+
beta = self.act.alpha.data # Snake uses same params for alpha and beta
|
| 62 |
+
else:
|
| 63 |
+
beta = (
|
| 64 |
+
self.act.beta.data
|
| 65 |
+
) # Snakebeta uses different params for alpha and beta
|
| 66 |
+
alpha = self.act.alpha.data
|
| 67 |
+
if (
|
| 68 |
+
not self.act.alpha_logscale
|
| 69 |
+
): # Exp baked into cuda kernel, cancel it out with a log
|
| 70 |
+
alpha = torch.log(alpha)
|
| 71 |
+
beta = torch.log(beta)
|
| 72 |
+
|
| 73 |
+
x = FusedAntiAliasActivation.apply(
|
| 74 |
+
x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
|
| 75 |
+
)
|
| 76 |
+
return x
|
indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation.cpp
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* coding=utf-8
|
| 2 |
+
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#include <torch/extension.h>
|
| 18 |
+
|
| 19 |
+
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
|
| 20 |
+
|
| 21 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 22 |
+
m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
|
| 23 |
+
}
|
indextts/BigVGAN/alias_free_activation/cuda/anti_alias_activation_cuda.cu
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* coding=utf-8
|
| 2 |
+
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#include <ATen/ATen.h>
|
| 18 |
+
#include <cuda.h>
|
| 19 |
+
#include <cuda_runtime.h>
|
| 20 |
+
#include <cuda_fp16.h>
|
| 21 |
+
#include <cuda_profiler_api.h>
|
| 22 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 23 |
+
#include <torch/extension.h>
|
| 24 |
+
#include "type_shim.h"
|
| 25 |
+
#include <assert.h>
|
| 26 |
+
#include <cfloat>
|
| 27 |
+
#include <limits>
|
| 28 |
+
#include <stdint.h>
|
| 29 |
+
#include <c10/macros/Macros.h>
|
| 30 |
+
|
| 31 |
+
namespace
|
| 32 |
+
{
|
| 33 |
+
// Hard-coded hyperparameters
|
| 34 |
+
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
| 35 |
+
constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
|
| 36 |
+
constexpr int BUFFER_SIZE = 32;
|
| 37 |
+
constexpr int FILTER_SIZE = 12;
|
| 38 |
+
constexpr int HALF_FILTER_SIZE = 6;
|
| 39 |
+
constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
|
| 40 |
+
constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
|
| 41 |
+
constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
|
| 42 |
+
|
| 43 |
+
template <typename input_t, typename output_t, typename acc_t>
|
| 44 |
+
__global__ void anti_alias_activation_forward(
|
| 45 |
+
output_t *dst,
|
| 46 |
+
const input_t *src,
|
| 47 |
+
const acc_t *up_ftr,
|
| 48 |
+
const acc_t *down_ftr,
|
| 49 |
+
const acc_t *alpha,
|
| 50 |
+
const acc_t *beta,
|
| 51 |
+
int batch_size,
|
| 52 |
+
int channels,
|
| 53 |
+
int seq_len)
|
| 54 |
+
{
|
| 55 |
+
// Up and downsample filters
|
| 56 |
+
input_t up_filter[FILTER_SIZE];
|
| 57 |
+
input_t down_filter[FILTER_SIZE];
|
| 58 |
+
|
| 59 |
+
// Load data from global memory including extra indices reserved for replication paddings
|
| 60 |
+
input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
|
| 61 |
+
input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
|
| 62 |
+
|
| 63 |
+
// Output stores downsampled output before writing to dst
|
| 64 |
+
output_t output[BUFFER_SIZE];
|
| 65 |
+
|
| 66 |
+
// blockDim/threadIdx = (128, 1, 1)
|
| 67 |
+
// gridDim/blockIdx = (seq_blocks, channels, batches)
|
| 68 |
+
int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
|
| 69 |
+
int local_offset = threadIdx.x * BUFFER_SIZE;
|
| 70 |
+
int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
|
| 71 |
+
|
| 72 |
+
// intermediate have double the seq_len
|
| 73 |
+
int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
|
| 74 |
+
int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
|
| 75 |
+
|
| 76 |
+
// Get values needed for replication padding before moving pointer
|
| 77 |
+
const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
|
| 78 |
+
input_t seq_left_most_value = right_most_pntr[0];
|
| 79 |
+
input_t seq_right_most_value = right_most_pntr[seq_len - 1];
|
| 80 |
+
|
| 81 |
+
// Move src and dst pointers
|
| 82 |
+
src += block_offset + local_offset;
|
| 83 |
+
dst += block_offset + local_offset;
|
| 84 |
+
|
| 85 |
+
// Alpha and beta values for snake activatons. Applies exp by default
|
| 86 |
+
alpha = alpha + blockIdx.y;
|
| 87 |
+
beta = beta + blockIdx.y;
|
| 88 |
+
|
| 89 |
+
acc_t alpha_val = expf(alpha[0]);
|
| 90 |
+
acc_t beta_val = expf(beta[0]);
|
| 91 |
+
|
| 92 |
+
#pragma unroll
|
| 93 |
+
for (int it = 0; it < FILTER_SIZE; it += 1)
|
| 94 |
+
{
|
| 95 |
+
up_filter[it] = up_ftr[it];
|
| 96 |
+
down_filter[it] = down_ftr[it];
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
// Apply replication padding for upsampling, matching torch impl
|
| 100 |
+
#pragma unroll
|
| 101 |
+
for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
|
| 102 |
+
{
|
| 103 |
+
int element_index = seq_offset + it; // index for element
|
| 104 |
+
if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
|
| 105 |
+
{
|
| 106 |
+
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
|
| 107 |
+
}
|
| 108 |
+
if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
|
| 109 |
+
{
|
| 110 |
+
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
|
| 111 |
+
}
|
| 112 |
+
if ((element_index >= 0) && (element_index < seq_len))
|
| 113 |
+
{
|
| 114 |
+
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
// Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampilng conv later
|
| 119 |
+
#pragma unroll
|
| 120 |
+
for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
|
| 121 |
+
{
|
| 122 |
+
acc_t acc = 0.0;
|
| 123 |
+
int element_index = intermediate_seq_offset + it; // index for intermediate
|
| 124 |
+
#pragma unroll
|
| 125 |
+
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
|
| 126 |
+
{
|
| 127 |
+
if ((element_index + f_idx) >= 0)
|
| 128 |
+
{
|
| 129 |
+
acc += up_filter[f_idx] * elements[it + f_idx];
|
| 130 |
+
}
|
| 131 |
+
}
|
| 132 |
+
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
// Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampilng conv later
|
| 136 |
+
double no_div_by_zero = 0.000000001;
|
| 137 |
+
#pragma unroll
|
| 138 |
+
for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
|
| 139 |
+
{
|
| 140 |
+
acc_t a = sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
|
| 141 |
+
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * a * a;
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
// Apply replication padding before downsampling conv from intermediates
|
| 145 |
+
#pragma unroll
|
| 146 |
+
for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
|
| 147 |
+
{
|
| 148 |
+
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
|
| 149 |
+
}
|
| 150 |
+
#pragma unroll
|
| 151 |
+
for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
|
| 152 |
+
{
|
| 153 |
+
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
// Apply downsample strided convolution (assuming stride=2) from intermediates
|
| 157 |
+
#pragma unroll
|
| 158 |
+
for (int it = 0; it < BUFFER_SIZE; it += 1)
|
| 159 |
+
{
|
| 160 |
+
acc_t acc = 0.0;
|
| 161 |
+
#pragma unroll
|
| 162 |
+
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
|
| 163 |
+
{
|
| 164 |
+
// Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
|
| 165 |
+
acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
|
| 166 |
+
}
|
| 167 |
+
output[it] = acc;
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
// Write output to dst
|
| 171 |
+
#pragma unroll
|
| 172 |
+
for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
|
| 173 |
+
{
|
| 174 |
+
int element_index = seq_offset + it;
|
| 175 |
+
if (element_index < seq_len)
|
| 176 |
+
{
|
| 177 |
+
dst[it] = output[it];
|
| 178 |
+
}
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
template <typename input_t, typename output_t, typename acc_t>
|
| 184 |
+
void dispatch_anti_alias_activation_forward(
|
| 185 |
+
output_t *dst,
|
| 186 |
+
const input_t *src,
|
| 187 |
+
const acc_t *up_ftr,
|
| 188 |
+
const acc_t *down_ftr,
|
| 189 |
+
const acc_t *alpha,
|
| 190 |
+
const acc_t *beta,
|
| 191 |
+
int batch_size,
|
| 192 |
+
int channels,
|
| 193 |
+
int seq_len)
|
| 194 |
+
{
|
| 195 |
+
if (seq_len == 0)
|
| 196 |
+
{
|
| 197 |
+
return;
|
| 198 |
+
}
|
| 199 |
+
else
|
| 200 |
+
{
|
| 201 |
+
// Use 128 threads per block to maximimize gpu utilization
|
| 202 |
+
constexpr int threads_per_block = 128;
|
| 203 |
+
constexpr int seq_len_per_block = 4096;
|
| 204 |
+
int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
|
| 205 |
+
dim3 blocks(blocks_per_seq_len, channels, batch_size);
|
| 206 |
+
dim3 threads(threads_per_block, 1, 1);
|
| 207 |
+
|
| 208 |
+
anti_alias_activation_forward<input_t, output_t, acc_t>
|
| 209 |
+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
|
| 210 |
+
}
|
| 211 |
+
}
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
|
| 215 |
+
{
|
| 216 |
+
// Input is a 3d tensor with dimensions [batches, channels, seq_len]
|
| 217 |
+
const int batches = input.size(0);
|
| 218 |
+
const int channels = input.size(1);
|
| 219 |
+
const int seq_len = input.size(2);
|
| 220 |
+
|
| 221 |
+
// Output
|
| 222 |
+
auto act_options = input.options().requires_grad(false);
|
| 223 |
+
|
| 224 |
+
torch::Tensor anti_alias_activation_results =
|
| 225 |
+
torch::empty({batches, channels, seq_len}, act_options);
|
| 226 |
+
|
| 227 |
+
using float32 = float;
|
| 228 |
+
// The dtype of input is float16, bfloat16, or float32
|
| 229 |
+
// The dtype of up_filter, down_filter, alpha, and beta is float32
|
| 230 |
+
// printf("input scalar type: %d\n", input.scalar_type());
|
| 231 |
+
// printf("up_filter scalar type: %d\n", up_filter.scalar_type());
|
| 232 |
+
// printf("down_filter scalar type: %d\n", down_filter.scalar_type());
|
| 233 |
+
// printf("alpha scalar type: %d\n", alpha.scalar_type());
|
| 234 |
+
// printf("beta scalar type: %d\n", beta.scalar_type());
|
| 235 |
+
void *input_ptr = static_cast<void *>(input.data_ptr());
|
| 236 |
+
float32 *up_filter_ptr = static_cast<float32 *>(up_filter.data_ptr());
|
| 237 |
+
float32 *down_filter_ptr = static_cast<float32 *>(down_filter.data_ptr());
|
| 238 |
+
float32 *alpha_ptr = static_cast<float32 *>(alpha.data_ptr());
|
| 239 |
+
float32 *beta_ptr = static_cast<float32 *>(beta.data_ptr());
|
| 240 |
+
void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
|
| 241 |
+
|
| 242 |
+
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
| 243 |
+
input.scalar_type(),
|
| 244 |
+
"dispatch anti alias activation_forward",
|
| 245 |
+
dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float32>(
|
| 246 |
+
reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
|
| 247 |
+
reinterpret_cast<const scalar_t *>(input_ptr),
|
| 248 |
+
reinterpret_cast<const float32 *>(up_filter_ptr),
|
| 249 |
+
reinterpret_cast<const float32 *>(down_filter_ptr),
|
| 250 |
+
reinterpret_cast<const float32 *>(alpha_ptr),
|
| 251 |
+
reinterpret_cast<const float32 *>(beta_ptr),
|
| 252 |
+
batches,
|
| 253 |
+
channels,
|
| 254 |
+
seq_len););
|
| 255 |
+
return anti_alias_activation_results;
|
| 256 |
+
}
|
indextts/BigVGAN/alias_free_activation/cuda/compat.h
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* coding=utf-8
|
| 2 |
+
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
/*This code is copied fron NVIDIA apex:
|
| 18 |
+
* https://github.com/NVIDIA/apex
|
| 19 |
+
* with minor changes. */
|
| 20 |
+
|
| 21 |
+
#ifndef TORCH_CHECK
|
| 22 |
+
#define TORCH_CHECK AT_CHECK
|
| 23 |
+
#endif
|
| 24 |
+
|
| 25 |
+
#ifdef VERSION_GE_1_3
|
| 26 |
+
#define DATA_PTR data_ptr
|
| 27 |
+
#else
|
| 28 |
+
#define DATA_PTR data
|
| 29 |
+
#endif
|
indextts/BigVGAN/alias_free_activation/cuda/load.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import pathlib
|
| 6 |
+
import subprocess
|
| 7 |
+
|
| 8 |
+
from torch.utils import cpp_extension
|
| 9 |
+
|
| 10 |
+
"""
|
| 11 |
+
Setting this param to a list has a problem of generating different compilation commands (with diferent order of architectures) and leading to recompilation of fused kernels.
|
| 12 |
+
Set it to empty stringo avoid recompilation and assign arch flags explicity in extra_cuda_cflags below
|
| 13 |
+
"""
|
| 14 |
+
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import re
|
| 18 |
+
import shutil
|
| 19 |
+
import tempfile
|
| 20 |
+
|
| 21 |
+
# 补丁修复:sources 路径含中文字符时,生成 build.ninja 乱码导致编译失败
|
| 22 |
+
# 使用临时目录来规避 ninja 编译失败(比如中文路径)
|
| 23 |
+
def chinese_path_compile_support(sources, buildpath):
|
| 24 |
+
pattern = re.compile(r'[\u4e00-\u9fff]')
|
| 25 |
+
if not bool(pattern.search(str(sources[0].resolve()))):
|
| 26 |
+
return buildpath # 检测非中文路径跳过
|
| 27 |
+
# Create build directory
|
| 28 |
+
resolves = [ item.name for item in sources]
|
| 29 |
+
ninja_compile_dir = os.path.join(tempfile.gettempdir(), "BigVGAN", "cuda")
|
| 30 |
+
os.makedirs(ninja_compile_dir, exist_ok=True)
|
| 31 |
+
new_buildpath = os.path.join(ninja_compile_dir, "build")
|
| 32 |
+
os.makedirs(new_buildpath, exist_ok=True)
|
| 33 |
+
print(f"ninja_buildpath: {new_buildpath}")
|
| 34 |
+
# Copy files to directory
|
| 35 |
+
sources.clear()
|
| 36 |
+
current_dir = os.path.dirname(__file__)
|
| 37 |
+
ALLOWED_EXTENSIONS = {'.py', '.cu', '.cpp', '.h'}
|
| 38 |
+
for filename in os.listdir(current_dir):
|
| 39 |
+
item = pathlib.Path(current_dir).joinpath(filename)
|
| 40 |
+
tar_path = pathlib.Path(ninja_compile_dir).joinpath(item.name)
|
| 41 |
+
if not item.suffix.lower() in ALLOWED_EXTENSIONS:continue
|
| 42 |
+
pathlib.Path(shutil.copy2(item, tar_path))
|
| 43 |
+
if tar_path.name in resolves:sources.append(tar_path)
|
| 44 |
+
return new_buildpath
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def load():
|
| 49 |
+
# Check if cuda 11 is installed for compute capability 8.0
|
| 50 |
+
cc_flag = []
|
| 51 |
+
_, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
|
| 52 |
+
if int(bare_metal_major) >= 11:
|
| 53 |
+
cc_flag.append("-gencode")
|
| 54 |
+
cc_flag.append("arch=compute_80,code=sm_80")
|
| 55 |
+
|
| 56 |
+
# Build path
|
| 57 |
+
srcpath = pathlib.Path(__file__).parent.absolute()
|
| 58 |
+
buildpath = srcpath / "build"
|
| 59 |
+
_create_build_dir(buildpath)
|
| 60 |
+
|
| 61 |
+
# Helper function to build the kernels.
|
| 62 |
+
def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
|
| 63 |
+
return cpp_extension.load(
|
| 64 |
+
name=name,
|
| 65 |
+
sources=sources,
|
| 66 |
+
build_directory=buildpath,
|
| 67 |
+
extra_cflags=[
|
| 68 |
+
"-O3",
|
| 69 |
+
],
|
| 70 |
+
extra_cuda_cflags=[
|
| 71 |
+
"-O3",
|
| 72 |
+
"-gencode",
|
| 73 |
+
"arch=compute_70,code=sm_70",
|
| 74 |
+
"--use_fast_math",
|
| 75 |
+
]
|
| 76 |
+
+ extra_cuda_flags
|
| 77 |
+
+ cc_flag,
|
| 78 |
+
verbose=True,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
extra_cuda_flags = [
|
| 82 |
+
"-U__CUDA_NO_HALF_OPERATORS__",
|
| 83 |
+
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
| 84 |
+
"--expt-relaxed-constexpr",
|
| 85 |
+
"--expt-extended-lambda",
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
sources = [
|
| 89 |
+
srcpath / "anti_alias_activation.cpp",
|
| 90 |
+
srcpath / "anti_alias_activation_cuda.cu",
|
| 91 |
+
]
|
| 92 |
+
|
| 93 |
+
# 兼容方案:ninja 特殊字符路径编译支持处理(比如中文路径)
|
| 94 |
+
buildpath = chinese_path_compile_support(sources, buildpath)
|
| 95 |
+
|
| 96 |
+
anti_alias_activation_cuda = _cpp_extention_load_helper(
|
| 97 |
+
"anti_alias_activation_cuda", sources, extra_cuda_flags
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
return anti_alias_activation_cuda
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _get_cuda_bare_metal_version(cuda_dir):
|
| 104 |
+
raw_output = subprocess.check_output(
|
| 105 |
+
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
|
| 106 |
+
)
|
| 107 |
+
output = raw_output.split()
|
| 108 |
+
release_idx = output.index("release") + 1
|
| 109 |
+
release = output[release_idx].split(".")
|
| 110 |
+
bare_metal_major = release[0]
|
| 111 |
+
bare_metal_minor = release[1][0]
|
| 112 |
+
|
| 113 |
+
return raw_output, bare_metal_major, bare_metal_minor
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _create_build_dir(buildpath):
|
| 117 |
+
try:
|
| 118 |
+
os.mkdir(buildpath)
|
| 119 |
+
except OSError:
|
| 120 |
+
if not os.path.isdir(buildpath):
|
| 121 |
+
print(f"Creation of the build directory {buildpath} failed")
|
indextts/BigVGAN/alias_free_activation/cuda/type_shim.h
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* coding=utf-8
|
| 2 |
+
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
* you may not use this file except in compliance with the License.
|
| 6 |
+
* You may obtain a copy of the License at
|
| 7 |
+
*
|
| 8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
*
|
| 10 |
+
* Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
* See the License for the specific language governing permissions and
|
| 14 |
+
* limitations under the License.
|
| 15 |
+
*/
|
| 16 |
+
|
| 17 |
+
#include <ATen/ATen.h>
|
| 18 |
+
#include "compat.h"
|
| 19 |
+
|
| 20 |
+
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
| 21 |
+
switch (TYPE) \
|
| 22 |
+
{ \
|
| 23 |
+
case at::ScalarType::Float: \
|
| 24 |
+
{ \
|
| 25 |
+
using scalar_t = float; \
|
| 26 |
+
__VA_ARGS__; \
|
| 27 |
+
break; \
|
| 28 |
+
} \
|
| 29 |
+
case at::ScalarType::Half: \
|
| 30 |
+
{ \
|
| 31 |
+
using scalar_t = at::Half; \
|
| 32 |
+
__VA_ARGS__; \
|
| 33 |
+
break; \
|
| 34 |
+
} \
|
| 35 |
+
case at::ScalarType::BFloat16: \
|
| 36 |
+
{ \
|
| 37 |
+
using scalar_t = at::BFloat16; \
|
| 38 |
+
__VA_ARGS__; \
|
| 39 |
+
break; \
|
| 40 |
+
} \
|
| 41 |
+
default: \
|
| 42 |
+
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
|
| 46 |
+
switch (TYPEIN) \
|
| 47 |
+
{ \
|
| 48 |
+
case at::ScalarType::Float: \
|
| 49 |
+
{ \
|
| 50 |
+
using scalar_t_in = float; \
|
| 51 |
+
switch (TYPEOUT) \
|
| 52 |
+
{ \
|
| 53 |
+
case at::ScalarType::Float: \
|
| 54 |
+
{ \
|
| 55 |
+
using scalar_t_out = float; \
|
| 56 |
+
__VA_ARGS__; \
|
| 57 |
+
break; \
|
| 58 |
+
} \
|
| 59 |
+
case at::ScalarType::Half: \
|
| 60 |
+
{ \
|
| 61 |
+
using scalar_t_out = at::Half; \
|
| 62 |
+
__VA_ARGS__; \
|
| 63 |
+
break; \
|
| 64 |
+
} \
|
| 65 |
+
case at::ScalarType::BFloat16: \
|
| 66 |
+
{ \
|
| 67 |
+
using scalar_t_out = at::BFloat16; \
|
| 68 |
+
__VA_ARGS__; \
|
| 69 |
+
break; \
|
| 70 |
+
} \
|
| 71 |
+
default: \
|
| 72 |
+
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
|
| 73 |
+
} \
|
| 74 |
+
break; \
|
| 75 |
+
} \
|
| 76 |
+
case at::ScalarType::Half: \
|
| 77 |
+
{ \
|
| 78 |
+
using scalar_t_in = at::Half; \
|
| 79 |
+
using scalar_t_out = at::Half; \
|
| 80 |
+
__VA_ARGS__; \
|
| 81 |
+
break; \
|
| 82 |
+
} \
|
| 83 |
+
case at::ScalarType::BFloat16: \
|
| 84 |
+
{ \
|
| 85 |
+
using scalar_t_in = at::BFloat16; \
|
| 86 |
+
using scalar_t_out = at::BFloat16; \
|
| 87 |
+
__VA_ARGS__; \
|
| 88 |
+
break; \
|
| 89 |
+
} \
|
| 90 |
+
default: \
|
| 91 |
+
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
|
| 92 |
+
}
|
indextts/BigVGAN/alias_free_activation/torch/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
| 2 |
+
# LICENSE is in incl_licenses directory.
|
| 3 |
+
|
| 4 |
+
from .act import *
|
| 5 |
+
from .filter import *
|
| 6 |
+
from .resample import *
|
indextts/BigVGAN/alias_free_activation/torch/act.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
| 2 |
+
# LICENSE is in incl_licenses directory.
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from .resample import DownSample1d, UpSample1d
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Activation1d(nn.Module):
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
activation,
|
| 13 |
+
up_ratio: int = 2,
|
| 14 |
+
down_ratio: int = 2,
|
| 15 |
+
up_kernel_size: int = 12,
|
| 16 |
+
down_kernel_size: int = 12,
|
| 17 |
+
):
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.up_ratio = up_ratio
|
| 20 |
+
self.down_ratio = down_ratio
|
| 21 |
+
self.act = activation
|
| 22 |
+
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
| 23 |
+
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
| 24 |
+
|
| 25 |
+
# x: [B,C,T]
|
| 26 |
+
def forward(self, x):
|
| 27 |
+
x = self.upsample(x)
|
| 28 |
+
x = self.act(x)
|
| 29 |
+
x = self.downsample(x)
|
| 30 |
+
|
| 31 |
+
return x
|
indextts/BigVGAN/alias_free_activation/torch/filter.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
| 2 |
+
# LICENSE is in incl_licenses directory.
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
if "sinc" in dir(torch):
|
| 11 |
+
sinc = torch.sinc
|
| 12 |
+
else:
|
| 13 |
+
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
| 14 |
+
# https://adefossez.github.io/julius/julius/core.html
|
| 15 |
+
# LICENSE is in incl_licenses directory.
|
| 16 |
+
def sinc(x: torch.Tensor):
|
| 17 |
+
"""
|
| 18 |
+
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
| 19 |
+
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
| 20 |
+
"""
|
| 21 |
+
return torch.where(
|
| 22 |
+
x == 0,
|
| 23 |
+
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
| 24 |
+
torch.sin(math.pi * x) / math.pi / x,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
| 29 |
+
# https://adefossez.github.io/julius/julius/lowpass.html
|
| 30 |
+
# LICENSE is in incl_licenses directory.
|
| 31 |
+
def kaiser_sinc_filter1d(
|
| 32 |
+
cutoff, half_width, kernel_size
|
| 33 |
+
): # return filter [1,1,kernel_size]
|
| 34 |
+
even = kernel_size % 2 == 0
|
| 35 |
+
half_size = kernel_size // 2
|
| 36 |
+
|
| 37 |
+
# For kaiser window
|
| 38 |
+
delta_f = 4 * half_width
|
| 39 |
+
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
| 40 |
+
if A > 50.0:
|
| 41 |
+
beta = 0.1102 * (A - 8.7)
|
| 42 |
+
elif A >= 21.0:
|
| 43 |
+
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
|
| 44 |
+
else:
|
| 45 |
+
beta = 0.0
|
| 46 |
+
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
| 47 |
+
|
| 48 |
+
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
| 49 |
+
if even:
|
| 50 |
+
time = torch.arange(-half_size, half_size) + 0.5
|
| 51 |
+
else:
|
| 52 |
+
time = torch.arange(kernel_size) - half_size
|
| 53 |
+
if cutoff == 0:
|
| 54 |
+
filter_ = torch.zeros_like(time)
|
| 55 |
+
else:
|
| 56 |
+
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
| 57 |
+
"""
|
| 58 |
+
Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.
|
| 59 |
+
"""
|
| 60 |
+
filter_ /= filter_.sum()
|
| 61 |
+
filter = filter_.view(1, 1, kernel_size)
|
| 62 |
+
|
| 63 |
+
return filter
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class LowPassFilter1d(nn.Module):
|
| 67 |
+
def __init__(
|
| 68 |
+
self,
|
| 69 |
+
cutoff=0.5,
|
| 70 |
+
half_width=0.6,
|
| 71 |
+
stride: int = 1,
|
| 72 |
+
padding: bool = True,
|
| 73 |
+
padding_mode: str = "replicate",
|
| 74 |
+
kernel_size: int = 12,
|
| 75 |
+
):
|
| 76 |
+
"""
|
| 77 |
+
kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible.
|
| 78 |
+
"""
|
| 79 |
+
super().__init__()
|
| 80 |
+
if cutoff < -0.0:
|
| 81 |
+
raise ValueError("Minimum cutoff must be larger than zero.")
|
| 82 |
+
if cutoff > 0.5:
|
| 83 |
+
raise ValueError("A cutoff above 0.5 does not make sense.")
|
| 84 |
+
self.kernel_size = kernel_size
|
| 85 |
+
self.even = kernel_size % 2 == 0
|
| 86 |
+
self.pad_left = kernel_size // 2 - int(self.even)
|
| 87 |
+
self.pad_right = kernel_size // 2
|
| 88 |
+
self.stride = stride
|
| 89 |
+
self.padding = padding
|
| 90 |
+
self.padding_mode = padding_mode
|
| 91 |
+
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
| 92 |
+
self.register_buffer("filter", filter)
|
| 93 |
+
|
| 94 |
+
# Input [B, C, T]
|
| 95 |
+
def forward(self, x):
|
| 96 |
+
_, C, _ = x.shape
|
| 97 |
+
|
| 98 |
+
if self.padding:
|
| 99 |
+
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
| 100 |
+
out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
| 101 |
+
|
| 102 |
+
return out
|
indextts/BigVGAN/alias_free_activation/torch/resample.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
| 2 |
+
# LICENSE is in incl_licenses directory.
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
from .filter import LowPassFilter1d, kaiser_sinc_filter1d
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class UpSample1d(nn.Module):
|
| 11 |
+
def __init__(self, ratio=2, kernel_size=None):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.ratio = ratio
|
| 14 |
+
self.kernel_size = (
|
| 15 |
+
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
| 16 |
+
)
|
| 17 |
+
self.stride = ratio
|
| 18 |
+
self.pad = self.kernel_size // ratio - 1
|
| 19 |
+
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
| 20 |
+
self.pad_right = (
|
| 21 |
+
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
| 22 |
+
)
|
| 23 |
+
filter = kaiser_sinc_filter1d(
|
| 24 |
+
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
|
| 25 |
+
)
|
| 26 |
+
self.register_buffer("filter", filter)
|
| 27 |
+
|
| 28 |
+
# x: [B, C, T]
|
| 29 |
+
def forward(self, x):
|
| 30 |
+
_, C, _ = x.shape
|
| 31 |
+
|
| 32 |
+
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
| 33 |
+
x = self.ratio * F.conv_transpose1d(
|
| 34 |
+
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
|
| 35 |
+
)
|
| 36 |
+
x = x[..., self.pad_left : -self.pad_right]
|
| 37 |
+
|
| 38 |
+
return x
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class DownSample1d(nn.Module):
|
| 42 |
+
def __init__(self, ratio=2, kernel_size=None):
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.ratio = ratio
|
| 45 |
+
self.kernel_size = (
|
| 46 |
+
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
| 47 |
+
)
|
| 48 |
+
self.lowpass = LowPassFilter1d(
|
| 49 |
+
cutoff=0.5 / ratio,
|
| 50 |
+
half_width=0.6 / ratio,
|
| 51 |
+
stride=ratio,
|
| 52 |
+
kernel_size=self.kernel_size,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
def forward(self, x):
|
| 56 |
+
xx = self.lowpass(x)
|
| 57 |
+
|
| 58 |
+
return xx
|
indextts/BigVGAN/alias_free_torch/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
| 2 |
+
# LICENSE is in incl_licenses directory.
|
| 3 |
+
|
| 4 |
+
from .act import *
|
| 5 |
+
from .filter import *
|
| 6 |
+
from .resample import *
|
indextts/BigVGAN/alias_free_torch/act.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
| 2 |
+
# LICENSE is in incl_licenses directory.
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from .resample import DownSample1d, UpSample1d
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Activation1d(nn.Module):
|
| 10 |
+
def __init__(self,
|
| 11 |
+
activation,
|
| 12 |
+
up_ratio: int = 2,
|
| 13 |
+
down_ratio: int = 2,
|
| 14 |
+
up_kernel_size: int = 12,
|
| 15 |
+
down_kernel_size: int = 12):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.up_ratio = up_ratio
|
| 18 |
+
self.down_ratio = down_ratio
|
| 19 |
+
self.act = activation
|
| 20 |
+
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
| 21 |
+
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
| 22 |
+
|
| 23 |
+
# x: [B,C,T]
|
| 24 |
+
def forward(self, x):
|
| 25 |
+
x = self.upsample(x)
|
| 26 |
+
x = self.act(x)
|
| 27 |
+
x = self.downsample(x)
|
| 28 |
+
|
| 29 |
+
return x
|
indextts/BigVGAN/alias_free_torch/filter.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
| 2 |
+
# LICENSE is in incl_licenses directory.
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
if 'sinc' in dir(torch):
|
| 11 |
+
sinc = torch.sinc
|
| 12 |
+
else:
|
| 13 |
+
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
| 14 |
+
# https://adefossez.github.io/julius/julius/core.html
|
| 15 |
+
# LICENSE is in incl_licenses directory.
|
| 16 |
+
def sinc(x: torch.Tensor):
|
| 17 |
+
"""
|
| 18 |
+
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
| 19 |
+
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
| 20 |
+
"""
|
| 21 |
+
return torch.where(x == 0,
|
| 22 |
+
torch.tensor(1., device=x.device, dtype=x.dtype),
|
| 23 |
+
torch.sin(math.pi * x) / math.pi / x)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
| 27 |
+
# https://adefossez.github.io/julius/julius/lowpass.html
|
| 28 |
+
# LICENSE is in incl_licenses directory.
|
| 29 |
+
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
| 30 |
+
even = (kernel_size % 2 == 0)
|
| 31 |
+
half_size = kernel_size // 2
|
| 32 |
+
|
| 33 |
+
#For kaiser window
|
| 34 |
+
delta_f = 4 * half_width
|
| 35 |
+
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
| 36 |
+
if A > 50.:
|
| 37 |
+
beta = 0.1102 * (A - 8.7)
|
| 38 |
+
elif A >= 21.:
|
| 39 |
+
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
|
| 40 |
+
else:
|
| 41 |
+
beta = 0.
|
| 42 |
+
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
| 43 |
+
|
| 44 |
+
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
| 45 |
+
if even:
|
| 46 |
+
time = (torch.arange(-half_size, half_size) + 0.5)
|
| 47 |
+
else:
|
| 48 |
+
time = torch.arange(kernel_size) - half_size
|
| 49 |
+
if cutoff == 0:
|
| 50 |
+
filter_ = torch.zeros_like(time)
|
| 51 |
+
else:
|
| 52 |
+
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
| 53 |
+
# Normalize filter to have sum = 1, otherwise we will have a small leakage
|
| 54 |
+
# of the constant component in the input signal.
|
| 55 |
+
filter_ /= filter_.sum()
|
| 56 |
+
filter = filter_.view(1, 1, kernel_size)
|
| 57 |
+
|
| 58 |
+
return filter
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class LowPassFilter1d(nn.Module):
|
| 62 |
+
def __init__(self,
|
| 63 |
+
cutoff=0.5,
|
| 64 |
+
half_width=0.6,
|
| 65 |
+
stride: int = 1,
|
| 66 |
+
padding: bool = True,
|
| 67 |
+
padding_mode: str = 'replicate',
|
| 68 |
+
kernel_size: int = 12):
|
| 69 |
+
# kernel_size should be even number for stylegan3 setup,
|
| 70 |
+
# in this implementation, odd number is also possible.
|
| 71 |
+
super().__init__()
|
| 72 |
+
if cutoff < -0.:
|
| 73 |
+
raise ValueError("Minimum cutoff must be larger than zero.")
|
| 74 |
+
if cutoff > 0.5:
|
| 75 |
+
raise ValueError("A cutoff above 0.5 does not make sense.")
|
| 76 |
+
self.kernel_size = kernel_size
|
| 77 |
+
self.even = (kernel_size % 2 == 0)
|
| 78 |
+
self.pad_left = kernel_size // 2 - int(self.even)
|
| 79 |
+
self.pad_right = kernel_size // 2
|
| 80 |
+
self.stride = stride
|
| 81 |
+
self.padding = padding
|
| 82 |
+
self.padding_mode = padding_mode
|
| 83 |
+
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
| 84 |
+
self.register_buffer("filter", filter)
|
| 85 |
+
|
| 86 |
+
#input [B, C, T]
|
| 87 |
+
def forward(self, x):
|
| 88 |
+
_, C, _ = x.shape
|
| 89 |
+
|
| 90 |
+
if self.padding:
|
| 91 |
+
x = F.pad(x, (self.pad_left, self.pad_right),
|
| 92 |
+
mode=self.padding_mode)
|
| 93 |
+
out = F.conv1d(x, self.filter.expand(C, -1, -1),
|
| 94 |
+
stride=self.stride, groups=C)
|
| 95 |
+
|
| 96 |
+
return out
|
indextts/BigVGAN/alias_free_torch/resample.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
|
| 2 |
+
# LICENSE is in incl_licenses directory.
|
| 3 |
+
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
from .filter import LowPassFilter1d, kaiser_sinc_filter1d
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class UpSample1d(nn.Module):
|
| 11 |
+
def __init__(self, ratio=2, kernel_size=None):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.ratio = ratio
|
| 14 |
+
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
| 15 |
+
self.stride = ratio
|
| 16 |
+
self.pad = self.kernel_size // ratio - 1
|
| 17 |
+
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
| 18 |
+
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
| 19 |
+
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
|
| 20 |
+
half_width=0.6 / ratio,
|
| 21 |
+
kernel_size=self.kernel_size)
|
| 22 |
+
self.register_buffer("filter", filter)
|
| 23 |
+
|
| 24 |
+
# x: [B, C, T]
|
| 25 |
+
def forward(self, x):
|
| 26 |
+
_, C, _ = x.shape
|
| 27 |
+
|
| 28 |
+
x = F.pad(x, (self.pad, self.pad), mode='replicate')
|
| 29 |
+
x = self.ratio * F.conv_transpose1d(
|
| 30 |
+
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
| 31 |
+
x = x[..., self.pad_left:-self.pad_right]
|
| 32 |
+
|
| 33 |
+
return x
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class DownSample1d(nn.Module):
|
| 37 |
+
def __init__(self, ratio=2, kernel_size=None):
|
| 38 |
+
super().__init__()
|
| 39 |
+
self.ratio = ratio
|
| 40 |
+
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
| 41 |
+
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
|
| 42 |
+
half_width=0.6 / ratio,
|
| 43 |
+
stride=ratio,
|
| 44 |
+
kernel_size=self.kernel_size)
|
| 45 |
+
|
| 46 |
+
def forward(self, x):
|
| 47 |
+
xx = self.lowpass(x)
|
| 48 |
+
|
| 49 |
+
return xx
|
indextts/BigVGAN/bigvgan.py
ADDED
|
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
| 5 |
+
# LICENSE is in incl_licenses directory.
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Dict, Optional, Union
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
| 15 |
+
from torch.nn import Conv1d, ConvTranspose1d
|
| 16 |
+
from torch.nn.utils import remove_weight_norm, weight_norm
|
| 17 |
+
|
| 18 |
+
import indextts.BigVGAN.activations as activations
|
| 19 |
+
from indextts.BigVGAN.alias_free_activation.torch.act import \
|
| 20 |
+
Activation1d as TorchActivation1d
|
| 21 |
+
from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
|
| 22 |
+
from indextts.BigVGAN.env import AttrDict
|
| 23 |
+
from indextts.BigVGAN.utils import get_padding, init_weights
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def load_hparams_from_json(path) -> AttrDict:
|
| 27 |
+
with open(path) as f:
|
| 28 |
+
data = f.read()
|
| 29 |
+
return AttrDict(json.loads(data))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class AMPBlock1(torch.nn.Module):
|
| 33 |
+
"""
|
| 34 |
+
AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
|
| 35 |
+
AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
h (AttrDict): Hyperparameters.
|
| 39 |
+
channels (int): Number of convolution channels.
|
| 40 |
+
kernel_size (int): Size of the convolution kernel. Default is 3.
|
| 41 |
+
dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
|
| 42 |
+
activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
h: AttrDict,
|
| 48 |
+
channels: int,
|
| 49 |
+
kernel_size: int = 3,
|
| 50 |
+
dilation: tuple = (1, 3, 5),
|
| 51 |
+
activation: str = None,
|
| 52 |
+
):
|
| 53 |
+
super().__init__()
|
| 54 |
+
|
| 55 |
+
self.h = h
|
| 56 |
+
|
| 57 |
+
self.convs1 = nn.ModuleList(
|
| 58 |
+
[
|
| 59 |
+
weight_norm(
|
| 60 |
+
Conv1d(
|
| 61 |
+
channels,
|
| 62 |
+
channels,
|
| 63 |
+
kernel_size,
|
| 64 |
+
stride=1,
|
| 65 |
+
dilation=d,
|
| 66 |
+
padding=get_padding(kernel_size, d),
|
| 67 |
+
)
|
| 68 |
+
)
|
| 69 |
+
for d in dilation
|
| 70 |
+
]
|
| 71 |
+
)
|
| 72 |
+
self.convs1.apply(init_weights)
|
| 73 |
+
|
| 74 |
+
self.convs2 = nn.ModuleList(
|
| 75 |
+
[
|
| 76 |
+
weight_norm(
|
| 77 |
+
Conv1d(
|
| 78 |
+
channels,
|
| 79 |
+
channels,
|
| 80 |
+
kernel_size,
|
| 81 |
+
stride=1,
|
| 82 |
+
dilation=1,
|
| 83 |
+
padding=get_padding(kernel_size, 1),
|
| 84 |
+
)
|
| 85 |
+
)
|
| 86 |
+
for _ in range(len(dilation))
|
| 87 |
+
]
|
| 88 |
+
)
|
| 89 |
+
self.convs2.apply(init_weights)
|
| 90 |
+
|
| 91 |
+
self.num_layers = len(self.convs1) + len(
|
| 92 |
+
self.convs2
|
| 93 |
+
) # Total number of conv layers
|
| 94 |
+
|
| 95 |
+
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
| 96 |
+
if self.h.get("use_cuda_kernel", False):
|
| 97 |
+
from alias_free_activation.cuda.activation1d import \
|
| 98 |
+
Activation1d as CudaActivation1d
|
| 99 |
+
|
| 100 |
+
Activation1d = CudaActivation1d
|
| 101 |
+
else:
|
| 102 |
+
Activation1d = TorchActivation1d
|
| 103 |
+
|
| 104 |
+
# Activation functions
|
| 105 |
+
if activation == "snake":
|
| 106 |
+
self.activations = nn.ModuleList(
|
| 107 |
+
[
|
| 108 |
+
Activation1d(
|
| 109 |
+
activation=activations.Snake(
|
| 110 |
+
channels, alpha_logscale=h.snake_logscale
|
| 111 |
+
)
|
| 112 |
+
)
|
| 113 |
+
for _ in range(self.num_layers)
|
| 114 |
+
]
|
| 115 |
+
)
|
| 116 |
+
elif activation == "snakebeta":
|
| 117 |
+
self.activations = nn.ModuleList(
|
| 118 |
+
[
|
| 119 |
+
Activation1d(
|
| 120 |
+
activation=activations.SnakeBeta(
|
| 121 |
+
channels, alpha_logscale=h.snake_logscale
|
| 122 |
+
)
|
| 123 |
+
)
|
| 124 |
+
for _ in range(self.num_layers)
|
| 125 |
+
]
|
| 126 |
+
)
|
| 127 |
+
else:
|
| 128 |
+
raise NotImplementedError(
|
| 129 |
+
"activation incorrectly specified. check the config file and look for 'activation'."
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
def forward(self, x):
|
| 133 |
+
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
| 134 |
+
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
|
| 135 |
+
xt = a1(x)
|
| 136 |
+
xt = c1(xt)
|
| 137 |
+
xt = a2(xt)
|
| 138 |
+
xt = c2(xt)
|
| 139 |
+
x = xt + x
|
| 140 |
+
|
| 141 |
+
return x
|
| 142 |
+
|
| 143 |
+
def remove_weight_norm(self):
|
| 144 |
+
for l in self.convs1:
|
| 145 |
+
remove_weight_norm(l)
|
| 146 |
+
for l in self.convs2:
|
| 147 |
+
remove_weight_norm(l)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class AMPBlock2(torch.nn.Module):
|
| 151 |
+
"""
|
| 152 |
+
AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
|
| 153 |
+
Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
h (AttrDict): Hyperparameters.
|
| 157 |
+
channels (int): Number of convolution channels.
|
| 158 |
+
kernel_size (int): Size of the convolution kernel. Default is 3.
|
| 159 |
+
dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
|
| 160 |
+
activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
|
| 161 |
+
"""
|
| 162 |
+
|
| 163 |
+
def __init__(
|
| 164 |
+
self,
|
| 165 |
+
h: AttrDict,
|
| 166 |
+
channels: int,
|
| 167 |
+
kernel_size: int = 3,
|
| 168 |
+
dilation: tuple = (1, 3, 5),
|
| 169 |
+
activation: str = None,
|
| 170 |
+
):
|
| 171 |
+
super().__init__()
|
| 172 |
+
|
| 173 |
+
self.h = h
|
| 174 |
+
|
| 175 |
+
self.convs = nn.ModuleList(
|
| 176 |
+
[
|
| 177 |
+
weight_norm(
|
| 178 |
+
Conv1d(
|
| 179 |
+
channels,
|
| 180 |
+
channels,
|
| 181 |
+
kernel_size,
|
| 182 |
+
stride=1,
|
| 183 |
+
dilation=d,
|
| 184 |
+
padding=get_padding(kernel_size, d),
|
| 185 |
+
)
|
| 186 |
+
)
|
| 187 |
+
for d in dilation
|
| 188 |
+
]
|
| 189 |
+
)
|
| 190 |
+
self.convs.apply(init_weights)
|
| 191 |
+
|
| 192 |
+
self.num_layers = len(self.convs) # Total number of conv layers
|
| 193 |
+
|
| 194 |
+
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
| 195 |
+
if self.h.get("use_cuda_kernel", False):
|
| 196 |
+
from alias_free_activation.cuda.activation1d import \
|
| 197 |
+
Activation1d as CudaActivation1d
|
| 198 |
+
|
| 199 |
+
Activation1d = CudaActivation1d
|
| 200 |
+
else:
|
| 201 |
+
Activation1d = TorchActivation1d
|
| 202 |
+
|
| 203 |
+
# Activation functions
|
| 204 |
+
if activation == "snake":
|
| 205 |
+
self.activations = nn.ModuleList(
|
| 206 |
+
[
|
| 207 |
+
Activation1d(
|
| 208 |
+
activation=activations.Snake(
|
| 209 |
+
channels, alpha_logscale=h.snake_logscale
|
| 210 |
+
)
|
| 211 |
+
)
|
| 212 |
+
for _ in range(self.num_layers)
|
| 213 |
+
]
|
| 214 |
+
)
|
| 215 |
+
elif activation == "snakebeta":
|
| 216 |
+
self.activations = nn.ModuleList(
|
| 217 |
+
[
|
| 218 |
+
Activation1d(
|
| 219 |
+
activation=activations.SnakeBeta(
|
| 220 |
+
channels, alpha_logscale=h.snake_logscale
|
| 221 |
+
)
|
| 222 |
+
)
|
| 223 |
+
for _ in range(self.num_layers)
|
| 224 |
+
]
|
| 225 |
+
)
|
| 226 |
+
else:
|
| 227 |
+
raise NotImplementedError(
|
| 228 |
+
"activation incorrectly specified. check the config file and look for 'activation'."
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
def forward(self, x):
|
| 232 |
+
for c, a in zip(self.convs, self.activations):
|
| 233 |
+
xt = a(x)
|
| 234 |
+
xt = c(xt)
|
| 235 |
+
x = xt + x
|
| 236 |
+
return x
|
| 237 |
+
|
| 238 |
+
def remove_weight_norm(self):
|
| 239 |
+
for l in self.convs:
|
| 240 |
+
remove_weight_norm(l)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
'''
|
| 244 |
+
PyTorchModelHubMixin,
|
| 245 |
+
library_name="bigvgan",
|
| 246 |
+
repo_url="https://github.com/NVIDIA/BigVGAN",
|
| 247 |
+
docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
|
| 248 |
+
pipeline_tag="audio-to-audio",
|
| 249 |
+
license="mit",
|
| 250 |
+
tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
|
| 251 |
+
'''
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class BigVGAN(
|
| 255 |
+
torch.nn.Module,
|
| 256 |
+
):
|
| 257 |
+
"""
|
| 258 |
+
BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
|
| 259 |
+
New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.
|
| 260 |
+
|
| 261 |
+
Args:
|
| 262 |
+
h (AttrDict): Hyperparameters.
|
| 263 |
+
use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.
|
| 264 |
+
|
| 265 |
+
Note:
|
| 266 |
+
- The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
|
| 267 |
+
- Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
|
| 268 |
+
"""
|
| 269 |
+
|
| 270 |
+
def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
|
| 271 |
+
super().__init__()
|
| 272 |
+
self.h = h
|
| 273 |
+
self.h["use_cuda_kernel"] = use_cuda_kernel
|
| 274 |
+
|
| 275 |
+
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
| 276 |
+
if self.h.get("use_cuda_kernel", False):
|
| 277 |
+
from alias_free_activation.cuda.activation1d import \
|
| 278 |
+
Activation1d as CudaActivation1d
|
| 279 |
+
|
| 280 |
+
Activation1d = CudaActivation1d
|
| 281 |
+
else:
|
| 282 |
+
Activation1d = TorchActivation1d
|
| 283 |
+
|
| 284 |
+
self.num_kernels = len(h.resblock_kernel_sizes)
|
| 285 |
+
self.num_upsamples = len(h.upsample_rates)
|
| 286 |
+
|
| 287 |
+
self.feat_upsample = h.feat_upsample
|
| 288 |
+
self.cond_in_each_up_layer = h.cond_d_vector_in_each_upsampling_layer
|
| 289 |
+
|
| 290 |
+
# Pre-conv
|
| 291 |
+
self.conv_pre = weight_norm(
|
| 292 |
+
Conv1d(h.gpt_dim, h.upsample_initial_channel, 7, 1, padding=3)
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
# Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
| 296 |
+
if h.resblock == "1":
|
| 297 |
+
resblock_class = AMPBlock1
|
| 298 |
+
elif h.resblock == "2":
|
| 299 |
+
resblock_class = AMPBlock2
|
| 300 |
+
else:
|
| 301 |
+
raise ValueError(
|
| 302 |
+
f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
# Transposed conv-based upsamplers. does not apply anti-aliasing
|
| 306 |
+
self.ups = nn.ModuleList()
|
| 307 |
+
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
| 308 |
+
self.ups.append(
|
| 309 |
+
nn.ModuleList(
|
| 310 |
+
[
|
| 311 |
+
weight_norm(
|
| 312 |
+
ConvTranspose1d(
|
| 313 |
+
h.upsample_initial_channel // (2**i),
|
| 314 |
+
h.upsample_initial_channel // (2 ** (i + 1)),
|
| 315 |
+
k,
|
| 316 |
+
u,
|
| 317 |
+
padding=(k - u) // 2,
|
| 318 |
+
)
|
| 319 |
+
)
|
| 320 |
+
]
|
| 321 |
+
)
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
# Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
| 325 |
+
self.resblocks = nn.ModuleList()
|
| 326 |
+
for i in range(len(self.ups)):
|
| 327 |
+
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
| 328 |
+
for j, (k, d) in enumerate(
|
| 329 |
+
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
|
| 330 |
+
):
|
| 331 |
+
self.resblocks.append(
|
| 332 |
+
resblock_class(h, ch, k, d, activation=h.activation)
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
# Post-conv
|
| 336 |
+
activation_post = (
|
| 337 |
+
activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
| 338 |
+
if h.activation == "snake"
|
| 339 |
+
else (
|
| 340 |
+
activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
| 341 |
+
if h.activation == "snakebeta"
|
| 342 |
+
else None
|
| 343 |
+
)
|
| 344 |
+
)
|
| 345 |
+
if activation_post is None:
|
| 346 |
+
raise NotImplementedError(
|
| 347 |
+
"activation incorrectly specified. check the config file and look for 'activation'."
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
self.activation_post = Activation1d(activation=activation_post)
|
| 351 |
+
|
| 352 |
+
# Whether to use bias for the final conv_post. Default to True for backward compatibility
|
| 353 |
+
self.use_bias_at_final = h.get("use_bias_at_final", True)
|
| 354 |
+
self.conv_post = weight_norm(
|
| 355 |
+
Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
# Weight initialization
|
| 359 |
+
for i in range(len(self.ups)):
|
| 360 |
+
self.ups[i].apply(init_weights)
|
| 361 |
+
self.conv_post.apply(init_weights)
|
| 362 |
+
|
| 363 |
+
# Final tanh activation. Defaults to True for backward compatibility
|
| 364 |
+
self.use_tanh_at_final = h.get("use_tanh_at_final", True)
|
| 365 |
+
|
| 366 |
+
self.speaker_encoder = ECAPA_TDNN(h.num_mels, lin_neurons=h.speaker_embedding_dim)
|
| 367 |
+
self.cond_layer = nn.Conv1d(h.speaker_embedding_dim, h.upsample_initial_channel, 1)
|
| 368 |
+
if self.cond_in_each_up_layer:
|
| 369 |
+
self.conds = nn.ModuleList()
|
| 370 |
+
for i in range(len(self.ups)):
|
| 371 |
+
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
| 372 |
+
self.conds.append(nn.Conv1d(h.speaker_embedding_dim, ch, 1))
|
| 373 |
+
|
| 374 |
+
def forward(self, x, mel_refer, lens=None):
|
| 375 |
+
# Speaker reference
|
| 376 |
+
speaker_embedding = self.speaker_encoder(mel_refer, lens)
|
| 377 |
+
n_batch = x.size(0)
|
| 378 |
+
contrastive_loss = None
|
| 379 |
+
if n_batch * 2 == speaker_embedding.size(0):
|
| 380 |
+
spe_emb_chunk1, spe_emb_chunk2 = speaker_embedding[:n_batch, :, :], speaker_embedding[n_batch:, :, :]
|
| 381 |
+
contrastive_loss = self.cal_clip_loss(spe_emb_chunk1.squeeze(1), spe_emb_chunk2.squeeze(1),
|
| 382 |
+
self.logit_scale.exp())
|
| 383 |
+
|
| 384 |
+
speaker_embedding = speaker_embedding[:n_batch, :, :]
|
| 385 |
+
speaker_embedding = speaker_embedding.transpose(1, 2)
|
| 386 |
+
|
| 387 |
+
# upsample feat
|
| 388 |
+
if self.feat_upsample:
|
| 389 |
+
x = torch.nn.functional.interpolate(
|
| 390 |
+
x.transpose(1, 2),
|
| 391 |
+
scale_factor=[4],
|
| 392 |
+
mode="linear",
|
| 393 |
+
).squeeze(1)
|
| 394 |
+
else:
|
| 395 |
+
x = x.transpose(1, 2)
|
| 396 |
+
|
| 397 |
+
# BigVGAN
|
| 398 |
+
# Pre-conv
|
| 399 |
+
x = self.conv_pre(x)
|
| 400 |
+
x = x + self.cond_layer(speaker_embedding)
|
| 401 |
+
|
| 402 |
+
for i in range(self.num_upsamples):
|
| 403 |
+
# Upsampling
|
| 404 |
+
for i_up in range(len(self.ups[i])):
|
| 405 |
+
x = self.ups[i][i_up](x)
|
| 406 |
+
|
| 407 |
+
if self.cond_in_each_up_layer:
|
| 408 |
+
x = x + self.conds[i](speaker_embedding)
|
| 409 |
+
|
| 410 |
+
# AMP blocks
|
| 411 |
+
xs = None
|
| 412 |
+
for j in range(self.num_kernels):
|
| 413 |
+
if xs is None:
|
| 414 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
| 415 |
+
else:
|
| 416 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
| 417 |
+
x = xs / self.num_kernels
|
| 418 |
+
|
| 419 |
+
# Post-conv
|
| 420 |
+
x = self.activation_post(x)
|
| 421 |
+
x = self.conv_post(x)
|
| 422 |
+
# Final tanh activation
|
| 423 |
+
if self.use_tanh_at_final:
|
| 424 |
+
x = torch.tanh(x)
|
| 425 |
+
else:
|
| 426 |
+
x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1]
|
| 427 |
+
|
| 428 |
+
return x, contrastive_loss
|
| 429 |
+
|
| 430 |
+
def remove_weight_norm(self):
|
| 431 |
+
try:
|
| 432 |
+
print("Removing weight norm...")
|
| 433 |
+
for l in self.ups:
|
| 434 |
+
for l_i in l:
|
| 435 |
+
remove_weight_norm(l_i)
|
| 436 |
+
for l in self.resblocks:
|
| 437 |
+
l.remove_weight_norm()
|
| 438 |
+
remove_weight_norm(self.conv_pre)
|
| 439 |
+
remove_weight_norm(self.conv_post)
|
| 440 |
+
except ValueError:
|
| 441 |
+
print("[INFO] Model already removed weight norm. Skipping!")
|
| 442 |
+
pass
|
| 443 |
+
|
| 444 |
+
# Additional methods for huggingface_hub support
|
| 445 |
+
def _save_pretrained(self, save_directory: Path) -> None:
|
| 446 |
+
"""Save weights and config.json from a Pytorch model to a local directory."""
|
| 447 |
+
|
| 448 |
+
model_path = save_directory / "bigvgan_generator.pt"
|
| 449 |
+
torch.save({"generator": self.state_dict()}, model_path)
|
| 450 |
+
|
| 451 |
+
config_path = save_directory / "config.json"
|
| 452 |
+
with open(config_path, "w") as config_file:
|
| 453 |
+
json.dump(self.h, config_file, indent=4)
|
| 454 |
+
|
| 455 |
+
@classmethod
|
| 456 |
+
def _from_pretrained(
|
| 457 |
+
cls,
|
| 458 |
+
*,
|
| 459 |
+
model_id: str,
|
| 460 |
+
revision: str,
|
| 461 |
+
cache_dir: str,
|
| 462 |
+
force_download: bool,
|
| 463 |
+
proxies: Optional[Dict],
|
| 464 |
+
resume_download: bool,
|
| 465 |
+
local_files_only: bool,
|
| 466 |
+
token: Union[str, bool, None],
|
| 467 |
+
map_location: str = "cpu", # Additional argument
|
| 468 |
+
strict: bool = False, # Additional argument
|
| 469 |
+
use_cuda_kernel: bool = False,
|
| 470 |
+
**model_kwargs,
|
| 471 |
+
):
|
| 472 |
+
"""Load Pytorch pretrained weights and return the loaded model."""
|
| 473 |
+
|
| 474 |
+
# Download and load hyperparameters (h) used by BigVGAN
|
| 475 |
+
if os.path.isdir(model_id):
|
| 476 |
+
print("Loading config.json from local directory")
|
| 477 |
+
config_file = os.path.join(model_id, "config.json")
|
| 478 |
+
else:
|
| 479 |
+
config_file = hf_hub_download(
|
| 480 |
+
repo_id=model_id,
|
| 481 |
+
filename="config.json",
|
| 482 |
+
revision=revision,
|
| 483 |
+
cache_dir=cache_dir,
|
| 484 |
+
force_download=force_download,
|
| 485 |
+
proxies=proxies,
|
| 486 |
+
resume_download=resume_download,
|
| 487 |
+
token=token,
|
| 488 |
+
local_files_only=local_files_only,
|
| 489 |
+
)
|
| 490 |
+
h = load_hparams_from_json(config_file)
|
| 491 |
+
|
| 492 |
+
# instantiate BigVGAN using h
|
| 493 |
+
if use_cuda_kernel:
|
| 494 |
+
print(
|
| 495 |
+
f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
|
| 496 |
+
)
|
| 497 |
+
print(
|
| 498 |
+
f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
|
| 499 |
+
)
|
| 500 |
+
print(
|
| 501 |
+
f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
|
| 502 |
+
)
|
| 503 |
+
model = cls(h, use_cuda_kernel=use_cuda_kernel)
|
| 504 |
+
|
| 505 |
+
# Download and load pretrained generator weight
|
| 506 |
+
if os.path.isdir(model_id):
|
| 507 |
+
print("Loading weights from local directory")
|
| 508 |
+
model_file = os.path.join(model_id, "bigvgan_generator.pt")
|
| 509 |
+
else:
|
| 510 |
+
print(f"Loading weights from {model_id}")
|
| 511 |
+
model_file = hf_hub_download(
|
| 512 |
+
repo_id=model_id,
|
| 513 |
+
filename="bigvgan_generator.pt",
|
| 514 |
+
revision=revision,
|
| 515 |
+
cache_dir=cache_dir,
|
| 516 |
+
force_download=force_download,
|
| 517 |
+
proxies=proxies,
|
| 518 |
+
resume_download=resume_download,
|
| 519 |
+
token=token,
|
| 520 |
+
local_files_only=local_files_only,
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
checkpoint_dict = torch.load(model_file, map_location=map_location)
|
| 524 |
+
|
| 525 |
+
try:
|
| 526 |
+
model.load_state_dict(checkpoint_dict["generator"])
|
| 527 |
+
except RuntimeError:
|
| 528 |
+
print(
|
| 529 |
+
f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
|
| 530 |
+
)
|
| 531 |
+
model.remove_weight_norm()
|
| 532 |
+
model.load_state_dict(checkpoint_dict["generator"])
|
| 533 |
+
|
| 534 |
+
return model
|
indextts/BigVGAN/models.py
ADDED
|
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022 NVIDIA CORPORATION.
|
| 2 |
+
# Licensed under the MIT license.
|
| 3 |
+
|
| 4 |
+
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
| 5 |
+
# LICENSE is in incl_licenses directory.
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
|
| 10 |
+
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
| 11 |
+
|
| 12 |
+
import indextts.BigVGAN.activations as activations
|
| 13 |
+
|
| 14 |
+
from indextts.BigVGAN.ECAPA_TDNN import ECAPA_TDNN
|
| 15 |
+
from indextts.BigVGAN.utils import get_padding, init_weights
|
| 16 |
+
|
| 17 |
+
LRELU_SLOPE = 0.1
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class AMPBlock1(torch.nn.Module):
|
| 21 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
|
| 22 |
+
super(AMPBlock1, self).__init__()
|
| 23 |
+
self.h = h
|
| 24 |
+
|
| 25 |
+
self.convs1 = nn.ModuleList([
|
| 26 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
| 27 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
| 28 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
| 29 |
+
padding=get_padding(kernel_size, dilation[1]))),
|
| 30 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
| 31 |
+
padding=get_padding(kernel_size, dilation[2])))
|
| 32 |
+
])
|
| 33 |
+
self.convs1.apply(init_weights)
|
| 34 |
+
|
| 35 |
+
self.convs2 = nn.ModuleList([
|
| 36 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
| 37 |
+
padding=get_padding(kernel_size, 1))),
|
| 38 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
| 39 |
+
padding=get_padding(kernel_size, 1))),
|
| 40 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
| 41 |
+
padding=get_padding(kernel_size, 1)))
|
| 42 |
+
])
|
| 43 |
+
self.convs2.apply(init_weights)
|
| 44 |
+
|
| 45 |
+
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
|
| 46 |
+
if self.h.get("use_cuda_kernel", False):
|
| 47 |
+
from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d
|
| 48 |
+
else:
|
| 49 |
+
from indextts.BigVGAN.alias_free_torch import Activation1d
|
| 50 |
+
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
| 51 |
+
self.activations = nn.ModuleList([
|
| 52 |
+
Activation1d(
|
| 53 |
+
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
| 54 |
+
for _ in range(self.num_layers)
|
| 55 |
+
])
|
| 56 |
+
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
| 57 |
+
self.activations = nn.ModuleList([
|
| 58 |
+
Activation1d(
|
| 59 |
+
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
| 60 |
+
for _ in range(self.num_layers)
|
| 61 |
+
])
|
| 62 |
+
else:
|
| 63 |
+
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
|
| 64 |
+
|
| 65 |
+
def forward(self, x):
|
| 66 |
+
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
| 67 |
+
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
|
| 68 |
+
xt = a1(x)
|
| 69 |
+
xt = c1(xt)
|
| 70 |
+
xt = a2(xt)
|
| 71 |
+
xt = c2(xt)
|
| 72 |
+
x = xt + x
|
| 73 |
+
|
| 74 |
+
return x
|
| 75 |
+
|
| 76 |
+
def remove_weight_norm(self):
|
| 77 |
+
for l in self.convs1:
|
| 78 |
+
remove_weight_norm(l)
|
| 79 |
+
for l in self.convs2:
|
| 80 |
+
remove_weight_norm(l)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class AMPBlock2(torch.nn.Module):
|
| 84 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
|
| 85 |
+
super(AMPBlock2, self).__init__()
|
| 86 |
+
self.h = h
|
| 87 |
+
|
| 88 |
+
self.convs = nn.ModuleList([
|
| 89 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
| 90 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
| 91 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
| 92 |
+
padding=get_padding(kernel_size, dilation[1])))
|
| 93 |
+
])
|
| 94 |
+
self.convs.apply(init_weights)
|
| 95 |
+
|
| 96 |
+
self.num_layers = len(self.convs) # total number of conv layers
|
| 97 |
+
if self.h.get("use_cuda_kernel", False):
|
| 98 |
+
from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d
|
| 99 |
+
else:
|
| 100 |
+
from indextts.BigVGAN.alias_free_torch import Activation1d
|
| 101 |
+
|
| 102 |
+
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
| 103 |
+
self.activations = nn.ModuleList([
|
| 104 |
+
Activation1d(
|
| 105 |
+
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
| 106 |
+
for _ in range(self.num_layers)
|
| 107 |
+
])
|
| 108 |
+
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
| 109 |
+
self.activations = nn.ModuleList([
|
| 110 |
+
Activation1d(
|
| 111 |
+
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
| 112 |
+
for _ in range(self.num_layers)
|
| 113 |
+
])
|
| 114 |
+
else:
|
| 115 |
+
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
|
| 116 |
+
|
| 117 |
+
def forward(self, x):
|
| 118 |
+
for c, a in zip(self.convs, self.activations):
|
| 119 |
+
xt = a(x)
|
| 120 |
+
xt = c(xt)
|
| 121 |
+
x = xt + x
|
| 122 |
+
|
| 123 |
+
return x
|
| 124 |
+
|
| 125 |
+
def remove_weight_norm(self):
|
| 126 |
+
for l in self.convs:
|
| 127 |
+
remove_weight_norm(l)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class BigVGAN(torch.nn.Module):
|
| 131 |
+
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
|
| 132 |
+
def __init__(self, h, use_cuda_kernel=False):
|
| 133 |
+
"""
|
| 134 |
+
Args:
|
| 135 |
+
h (dict)
|
| 136 |
+
use_cuda_kernel (bool): whether to use custom cuda kernel for anti-aliased activation
|
| 137 |
+
"""
|
| 138 |
+
super(BigVGAN, self).__init__()
|
| 139 |
+
self.h = h
|
| 140 |
+
self.h["use_cuda_kernel"] = use_cuda_kernel
|
| 141 |
+
|
| 142 |
+
self.num_kernels = len(h.resblock_kernel_sizes)
|
| 143 |
+
self.num_upsamples = len(h.upsample_rates)
|
| 144 |
+
|
| 145 |
+
self.feat_upsample = h.feat_upsample
|
| 146 |
+
self.cond_in_each_up_layer = h.cond_d_vector_in_each_upsampling_layer
|
| 147 |
+
|
| 148 |
+
# pre conv
|
| 149 |
+
self.conv_pre = weight_norm(Conv1d(h.gpt_dim, h.upsample_initial_channel, 7, 1, padding=3))
|
| 150 |
+
|
| 151 |
+
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
| 152 |
+
resblock = AMPBlock1 if h.resblock == "1" else AMPBlock2
|
| 153 |
+
|
| 154 |
+
# transposed conv-based upsamplers. does not apply anti-aliasing
|
| 155 |
+
self.ups = nn.ModuleList()
|
| 156 |
+
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
| 157 |
+
self.ups.append(nn.ModuleList([
|
| 158 |
+
weight_norm(ConvTranspose1d(h.upsample_initial_channel // (2 ** i),
|
| 159 |
+
h.upsample_initial_channel // (2 ** (i + 1)),
|
| 160 |
+
k, u, padding=(k - u) // 2))
|
| 161 |
+
]))
|
| 162 |
+
|
| 163 |
+
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
| 164 |
+
self.resblocks = nn.ModuleList()
|
| 165 |
+
for i in range(len(self.ups)):
|
| 166 |
+
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
| 167 |
+
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
| 168 |
+
self.resblocks.append(resblock(self.h, ch, k, d, activation=h.activation))
|
| 169 |
+
if use_cuda_kernel:
|
| 170 |
+
from indextts.BigVGAN.alias_free_activation.cuda.activation1d import Activation1d
|
| 171 |
+
else:
|
| 172 |
+
from indextts.BigVGAN.alias_free_torch import Activation1d
|
| 173 |
+
|
| 174 |
+
# post conv
|
| 175 |
+
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
|
| 176 |
+
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
| 177 |
+
self.activation_post = Activation1d(activation=activation_post)
|
| 178 |
+
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
|
| 179 |
+
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
| 180 |
+
self.activation_post = Activation1d(activation=activation_post)
|
| 181 |
+
else:
|
| 182 |
+
raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.")
|
| 183 |
+
|
| 184 |
+
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
|
| 185 |
+
|
| 186 |
+
# weight initialization
|
| 187 |
+
for i in range(len(self.ups)):
|
| 188 |
+
self.ups[i].apply(init_weights)
|
| 189 |
+
self.conv_post.apply(init_weights)
|
| 190 |
+
|
| 191 |
+
self.speaker_encoder = ECAPA_TDNN(h.num_mels, lin_neurons=h.speaker_embedding_dim)
|
| 192 |
+
self.cond_layer = nn.Conv1d(h.speaker_embedding_dim, h.upsample_initial_channel, 1)
|
| 193 |
+
if self.cond_in_each_up_layer:
|
| 194 |
+
self.conds = nn.ModuleList()
|
| 195 |
+
for i in range(len(self.ups)):
|
| 196 |
+
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
| 197 |
+
self.conds.append(nn.Conv1d(h.speaker_embedding_dim, ch, 1))
|
| 198 |
+
|
| 199 |
+
# self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 200 |
+
|
| 201 |
+
def forward(self, x, mel_ref, lens=None):
|
| 202 |
+
speaker_embedding = self.speaker_encoder(mel_ref, lens)
|
| 203 |
+
n_batch = x.size(0)
|
| 204 |
+
contrastive_loss = None
|
| 205 |
+
if n_batch * 2 == speaker_embedding.size(0):
|
| 206 |
+
spe_emb_chunk1, spe_emb_chunk2 = speaker_embedding[:n_batch, :, :], speaker_embedding[n_batch:, :, :]
|
| 207 |
+
contrastive_loss = self.cal_clip_loss(spe_emb_chunk1.squeeze(1), spe_emb_chunk2.squeeze(1), self.logit_scale.exp())
|
| 208 |
+
|
| 209 |
+
speaker_embedding = speaker_embedding[:n_batch, :, :]
|
| 210 |
+
speaker_embedding = speaker_embedding.transpose(1, 2)
|
| 211 |
+
|
| 212 |
+
# upsample feat
|
| 213 |
+
if self.feat_upsample:
|
| 214 |
+
x = torch.nn.functional.interpolate(
|
| 215 |
+
x.transpose(1, 2),
|
| 216 |
+
scale_factor=[4],
|
| 217 |
+
mode="linear",
|
| 218 |
+
).squeeze(1)
|
| 219 |
+
else:
|
| 220 |
+
x = x.transpose(1, 2)
|
| 221 |
+
|
| 222 |
+
### bigVGAN ###
|
| 223 |
+
# pre conv
|
| 224 |
+
x = self.conv_pre(x)
|
| 225 |
+
|
| 226 |
+
x = x + self.cond_layer(speaker_embedding)
|
| 227 |
+
|
| 228 |
+
for i in range(self.num_upsamples):
|
| 229 |
+
# upsampling
|
| 230 |
+
for i_up in range(len(self.ups[i])):
|
| 231 |
+
x = self.ups[i][i_up](x)
|
| 232 |
+
|
| 233 |
+
if self.cond_in_each_up_layer:
|
| 234 |
+
x = x + self.conds[i](speaker_embedding)
|
| 235 |
+
|
| 236 |
+
# AMP blocks
|
| 237 |
+
xs = None
|
| 238 |
+
for j in range(self.num_kernels):
|
| 239 |
+
if xs is None:
|
| 240 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
| 241 |
+
else:
|
| 242 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
| 243 |
+
x = xs / self.num_kernels
|
| 244 |
+
|
| 245 |
+
# post conv
|
| 246 |
+
x = self.activation_post(x)
|
| 247 |
+
x = self.conv_post(x)
|
| 248 |
+
x = torch.tanh(x)
|
| 249 |
+
|
| 250 |
+
return x, contrastive_loss
|
| 251 |
+
|
| 252 |
+
def remove_weight_norm(self):
|
| 253 |
+
print('Removing weight norm...')
|
| 254 |
+
for l in self.ups:
|
| 255 |
+
for l_i in l:
|
| 256 |
+
remove_weight_norm(l_i)
|
| 257 |
+
for l in self.resblocks:
|
| 258 |
+
l.remove_weight_norm()
|
| 259 |
+
remove_weight_norm(self.conv_pre)
|
| 260 |
+
remove_weight_norm(self.conv_post)
|
| 261 |
+
|
| 262 |
+
def cal_clip_loss(self, image_features, text_features, logit_scale):
|
| 263 |
+
device = image_features.device
|
| 264 |
+
logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
|
| 265 |
+
labels = torch.arange(logits_per_image.shape[0], device=device, dtype=torch.long)
|
| 266 |
+
total_loss = (
|
| 267 |
+
F.cross_entropy(logits_per_image, labels) +
|
| 268 |
+
F.cross_entropy(logits_per_text, labels)
|
| 269 |
+
) / 2
|
| 270 |
+
return total_loss
|
| 271 |
+
|
| 272 |
+
def get_logits(self, image_features, text_features, logit_scale):
|
| 273 |
+
logits_per_image = logit_scale * image_features @ text_features.T
|
| 274 |
+
logits_per_text = logit_scale * text_features @ image_features.T
|
| 275 |
+
return logits_per_image, logits_per_text
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class DiscriminatorP(torch.nn.Module):
|
| 279 |
+
def __init__(self, h, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
| 280 |
+
super(DiscriminatorP, self).__init__()
|
| 281 |
+
self.period = period
|
| 282 |
+
self.d_mult = h.discriminator_channel_mult
|
| 283 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
| 284 |
+
self.convs = nn.ModuleList([
|
| 285 |
+
norm_f(Conv2d(1, int(32 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
| 286 |
+
norm_f(Conv2d(int(32 * self.d_mult), int(128 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
| 287 |
+
norm_f(Conv2d(int(128 * self.d_mult), int(512 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
| 288 |
+
norm_f(Conv2d(int(512 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
| 289 |
+
norm_f(Conv2d(int(1024 * self.d_mult), int(1024 * self.d_mult), (kernel_size, 1), 1, padding=(2, 0))),
|
| 290 |
+
])
|
| 291 |
+
self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
|
| 292 |
+
|
| 293 |
+
def forward(self, x):
|
| 294 |
+
fmap = []
|
| 295 |
+
|
| 296 |
+
# 1d to 2d
|
| 297 |
+
b, c, t = x.shape
|
| 298 |
+
if t % self.period != 0: # pad first
|
| 299 |
+
n_pad = self.period - (t % self.period)
|
| 300 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
| 301 |
+
t = t + n_pad
|
| 302 |
+
x = x.view(b, c, t // self.period, self.period)
|
| 303 |
+
|
| 304 |
+
for l in self.convs:
|
| 305 |
+
x = l(x)
|
| 306 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
| 307 |
+
fmap.append(x)
|
| 308 |
+
x = self.conv_post(x)
|
| 309 |
+
fmap.append(x)
|
| 310 |
+
x = torch.flatten(x, 1, -1)
|
| 311 |
+
|
| 312 |
+
return x, fmap
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
class MultiPeriodDiscriminator(torch.nn.Module):
|
| 316 |
+
def __init__(self, h):
|
| 317 |
+
super(MultiPeriodDiscriminator, self).__init__()
|
| 318 |
+
self.mpd_reshapes = h.mpd_reshapes
|
| 319 |
+
print("mpd_reshapes: {}".format(self.mpd_reshapes))
|
| 320 |
+
discriminators = [DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
|
| 321 |
+
self.discriminators = nn.ModuleList(discriminators)
|
| 322 |
+
|
| 323 |
+
def forward(self, y, y_hat):
|
| 324 |
+
y_d_rs = []
|
| 325 |
+
y_d_gs = []
|
| 326 |
+
fmap_rs = []
|
| 327 |
+
fmap_gs = []
|
| 328 |
+
for i, d in enumerate(self.discriminators):
|
| 329 |
+
y_d_r, fmap_r = d(y)
|
| 330 |
+
y_d_g, fmap_g = d(y_hat)
|
| 331 |
+
y_d_rs.append(y_d_r)
|
| 332 |
+
fmap_rs.append(fmap_r)
|
| 333 |
+
y_d_gs.append(y_d_g)
|
| 334 |
+
fmap_gs.append(fmap_g)
|
| 335 |
+
|
| 336 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class DiscriminatorR(nn.Module):
|
| 340 |
+
def __init__(self, cfg, resolution):
|
| 341 |
+
super().__init__()
|
| 342 |
+
|
| 343 |
+
self.resolution = resolution
|
| 344 |
+
assert len(self.resolution) == 3, \
|
| 345 |
+
"MRD layer requires list with len=3, got {}".format(self.resolution)
|
| 346 |
+
self.lrelu_slope = LRELU_SLOPE
|
| 347 |
+
|
| 348 |
+
norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
|
| 349 |
+
if hasattr(cfg, "mrd_use_spectral_norm"):
|
| 350 |
+
print("INFO: overriding MRD use_spectral_norm as {}".format(cfg.mrd_use_spectral_norm))
|
| 351 |
+
norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
|
| 352 |
+
self.d_mult = cfg.discriminator_channel_mult
|
| 353 |
+
if hasattr(cfg, "mrd_channel_mult"):
|
| 354 |
+
print("INFO: overriding mrd channel multiplier as {}".format(cfg.mrd_channel_mult))
|
| 355 |
+
self.d_mult = cfg.mrd_channel_mult
|
| 356 |
+
|
| 357 |
+
self.convs = nn.ModuleList([
|
| 358 |
+
norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))),
|
| 359 |
+
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
| 360 |
+
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
| 361 |
+
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 9), stride=(1, 2), padding=(1, 4))),
|
| 362 |
+
norm_f(nn.Conv2d(int(32 * self.d_mult), int(32 * self.d_mult), (3, 3), padding=(1, 1))),
|
| 363 |
+
])
|
| 364 |
+
self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
|
| 365 |
+
|
| 366 |
+
def forward(self, x):
|
| 367 |
+
fmap = []
|
| 368 |
+
|
| 369 |
+
x = self.spectrogram(x)
|
| 370 |
+
x = x.unsqueeze(1)
|
| 371 |
+
for l in self.convs:
|
| 372 |
+
x = l(x)
|
| 373 |
+
x = F.leaky_relu(x, self.lrelu_slope)
|
| 374 |
+
fmap.append(x)
|
| 375 |
+
x = self.conv_post(x)
|
| 376 |
+
fmap.append(x)
|
| 377 |
+
x = torch.flatten(x, 1, -1)
|
| 378 |
+
|
| 379 |
+
return x, fmap
|
| 380 |
+
|
| 381 |
+
def spectrogram(self, x):
|
| 382 |
+
n_fft, hop_length, win_length = self.resolution
|
| 383 |
+
x = F.pad(x, (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), mode='reflect')
|
| 384 |
+
x = x.squeeze(1)
|
| 385 |
+
x = torch.stft(x, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False, return_complex=True)
|
| 386 |
+
x = torch.view_as_real(x) # [B, F, TT, 2]
|
| 387 |
+
mag = torch.norm(x, p=2, dim=-1) # [B, F, TT]
|
| 388 |
+
|
| 389 |
+
return mag
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class MultiResolutionDiscriminator(nn.Module):
|
| 393 |
+
def __init__(self, cfg, debug=False):
|
| 394 |
+
super().__init__()
|
| 395 |
+
self.resolutions = cfg.resolutions
|
| 396 |
+
assert len(self.resolutions) == 3, \
|
| 397 |
+
"MRD requires list of list with len=3, each element having a list with len=3. got {}".\
|
| 398 |
+
format(self.resolutions)
|
| 399 |
+
self.discriminators = nn.ModuleList(
|
| 400 |
+
[DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
def forward(self, y, y_hat):
|
| 404 |
+
y_d_rs = []
|
| 405 |
+
y_d_gs = []
|
| 406 |
+
fmap_rs = []
|
| 407 |
+
fmap_gs = []
|
| 408 |
+
|
| 409 |
+
for i, d in enumerate(self.discriminators):
|
| 410 |
+
y_d_r, fmap_r = d(x=y)
|
| 411 |
+
y_d_g, fmap_g = d(x=y_hat)
|
| 412 |
+
y_d_rs.append(y_d_r)
|
| 413 |
+
fmap_rs.append(fmap_r)
|
| 414 |
+
y_d_gs.append(y_d_g)
|
| 415 |
+
fmap_gs.append(fmap_g)
|
| 416 |
+
|
| 417 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def feature_loss(fmap_r, fmap_g):
|
| 421 |
+
loss = 0
|
| 422 |
+
for dr, dg in zip(fmap_r, fmap_g):
|
| 423 |
+
for rl, gl in zip(dr, dg):
|
| 424 |
+
loss += torch.mean(torch.abs(rl - gl))
|
| 425 |
+
|
| 426 |
+
return loss * 2
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
| 430 |
+
loss = 0
|
| 431 |
+
r_losses = []
|
| 432 |
+
g_losses = []
|
| 433 |
+
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
| 434 |
+
r_loss = torch.mean((1 - dr)**2)
|
| 435 |
+
g_loss = torch.mean(dg**2)
|
| 436 |
+
loss += (r_loss + g_loss)
|
| 437 |
+
r_losses.append(r_loss.item())
|
| 438 |
+
g_losses.append(g_loss.item())
|
| 439 |
+
|
| 440 |
+
return loss, r_losses, g_losses
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def generator_loss(disc_outputs):
|
| 444 |
+
loss = 0
|
| 445 |
+
gen_losses = []
|
| 446 |
+
for dg in disc_outputs:
|
| 447 |
+
l = torch.mean((1 - dg)**2)
|
| 448 |
+
gen_losses.append(l)
|
| 449 |
+
loss += l
|
| 450 |
+
|
| 451 |
+
return loss, gen_losses
|
indextts/BigVGAN/nnet/CNN.py
ADDED
|
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Library implementing convolutional neural networks.
|
| 2 |
+
|
| 3 |
+
Authors
|
| 4 |
+
* Mirco Ravanelli 2020
|
| 5 |
+
* Jianyuan Zhong 2020
|
| 6 |
+
* Cem Subakan 2021
|
| 7 |
+
* Davide Borra 2021
|
| 8 |
+
* Andreas Nautsch 2022
|
| 9 |
+
* Sarthak Yadav 2022
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
import math
|
| 14 |
+
from typing import Tuple
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
import torchaudio
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class SincConv(nn.Module):
|
| 24 |
+
"""This function implements SincConv (SincNet).
|
| 25 |
+
|
| 26 |
+
M. Ravanelli, Y. Bengio, "Speaker Recognition from raw waveform with
|
| 27 |
+
SincNet", in Proc. of SLT 2018 (https://arxiv.org/abs/1808.00158)
|
| 28 |
+
|
| 29 |
+
Arguments
|
| 30 |
+
---------
|
| 31 |
+
out_channels : int
|
| 32 |
+
It is the number of output channels.
|
| 33 |
+
kernel_size: int
|
| 34 |
+
Kernel size of the convolutional filters.
|
| 35 |
+
input_shape : tuple
|
| 36 |
+
The shape of the input. Alternatively use ``in_channels``.
|
| 37 |
+
in_channels : int
|
| 38 |
+
The number of input channels. Alternatively use ``input_shape``.
|
| 39 |
+
stride : int
|
| 40 |
+
Stride factor of the convolutional filters. When the stride factor > 1,
|
| 41 |
+
a decimation in time is performed.
|
| 42 |
+
dilation : int
|
| 43 |
+
Dilation factor of the convolutional filters.
|
| 44 |
+
padding : str
|
| 45 |
+
(same, valid, causal). If "valid", no padding is performed.
|
| 46 |
+
If "same" and stride is 1, output shape is the same as the input shape.
|
| 47 |
+
"causal" results in causal (dilated) convolutions.
|
| 48 |
+
padding_mode : str
|
| 49 |
+
This flag specifies the type of padding. See torch.nn documentation
|
| 50 |
+
for more information.
|
| 51 |
+
sample_rate : int
|
| 52 |
+
Sampling rate of the input signals. It is only used for sinc_conv.
|
| 53 |
+
min_low_hz : float
|
| 54 |
+
Lowest possible frequency (in Hz) for a filter. It is only used for
|
| 55 |
+
sinc_conv.
|
| 56 |
+
min_band_hz : float
|
| 57 |
+
Lowest possible value (in Hz) for a filter bandwidth.
|
| 58 |
+
|
| 59 |
+
Example
|
| 60 |
+
-------
|
| 61 |
+
>>> inp_tensor = torch.rand([10, 16000])
|
| 62 |
+
>>> conv = SincConv(input_shape=inp_tensor.shape, out_channels=25, kernel_size=11)
|
| 63 |
+
>>> out_tensor = conv(inp_tensor)
|
| 64 |
+
>>> out_tensor.shape
|
| 65 |
+
torch.Size([10, 16000, 25])
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
def __init__(
|
| 69 |
+
self,
|
| 70 |
+
out_channels,
|
| 71 |
+
kernel_size,
|
| 72 |
+
input_shape=None,
|
| 73 |
+
in_channels=None,
|
| 74 |
+
stride=1,
|
| 75 |
+
dilation=1,
|
| 76 |
+
padding="same",
|
| 77 |
+
padding_mode="reflect",
|
| 78 |
+
sample_rate=16000,
|
| 79 |
+
min_low_hz=50,
|
| 80 |
+
min_band_hz=50,
|
| 81 |
+
):
|
| 82 |
+
super().__init__()
|
| 83 |
+
self.in_channels = in_channels
|
| 84 |
+
self.out_channels = out_channels
|
| 85 |
+
self.kernel_size = kernel_size
|
| 86 |
+
self.stride = stride
|
| 87 |
+
self.dilation = dilation
|
| 88 |
+
self.padding = padding
|
| 89 |
+
self.padding_mode = padding_mode
|
| 90 |
+
self.sample_rate = sample_rate
|
| 91 |
+
self.min_low_hz = min_low_hz
|
| 92 |
+
self.min_band_hz = min_band_hz
|
| 93 |
+
|
| 94 |
+
# input shape inference
|
| 95 |
+
if input_shape is None and self.in_channels is None:
|
| 96 |
+
raise ValueError("Must provide one of input_shape or in_channels")
|
| 97 |
+
|
| 98 |
+
if self.in_channels is None:
|
| 99 |
+
self.in_channels = self._check_input_shape(input_shape)
|
| 100 |
+
|
| 101 |
+
if self.out_channels % self.in_channels != 0:
|
| 102 |
+
raise ValueError(
|
| 103 |
+
"Number of output channels must be divisible by in_channels"
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
# Initialize Sinc filters
|
| 107 |
+
self._init_sinc_conv()
|
| 108 |
+
|
| 109 |
+
def forward(self, x):
|
| 110 |
+
"""Returns the output of the convolution.
|
| 111 |
+
|
| 112 |
+
Arguments
|
| 113 |
+
---------
|
| 114 |
+
x : torch.Tensor (batch, time, channel)
|
| 115 |
+
input to convolve. 2d or 4d tensors are expected.
|
| 116 |
+
|
| 117 |
+
Returns
|
| 118 |
+
-------
|
| 119 |
+
wx : torch.Tensor
|
| 120 |
+
The convolved outputs.
|
| 121 |
+
"""
|
| 122 |
+
x = x.transpose(1, -1)
|
| 123 |
+
self.device = x.device
|
| 124 |
+
|
| 125 |
+
unsqueeze = x.ndim == 2
|
| 126 |
+
if unsqueeze:
|
| 127 |
+
x = x.unsqueeze(1)
|
| 128 |
+
|
| 129 |
+
if self.padding == "same":
|
| 130 |
+
x = self._manage_padding(
|
| 131 |
+
x, self.kernel_size, self.dilation, self.stride
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
elif self.padding == "causal":
|
| 135 |
+
num_pad = (self.kernel_size - 1) * self.dilation
|
| 136 |
+
x = F.pad(x, (num_pad, 0))
|
| 137 |
+
|
| 138 |
+
elif self.padding == "valid":
|
| 139 |
+
pass
|
| 140 |
+
|
| 141 |
+
else:
|
| 142 |
+
raise ValueError(
|
| 143 |
+
"Padding must be 'same', 'valid' or 'causal'. Got %s."
|
| 144 |
+
% (self.padding)
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
sinc_filters = self._get_sinc_filters()
|
| 148 |
+
|
| 149 |
+
wx = F.conv1d(
|
| 150 |
+
x,
|
| 151 |
+
sinc_filters,
|
| 152 |
+
stride=self.stride,
|
| 153 |
+
padding=0,
|
| 154 |
+
dilation=self.dilation,
|
| 155 |
+
groups=self.in_channels,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
if unsqueeze:
|
| 159 |
+
wx = wx.squeeze(1)
|
| 160 |
+
|
| 161 |
+
wx = wx.transpose(1, -1)
|
| 162 |
+
|
| 163 |
+
return wx
|
| 164 |
+
|
| 165 |
+
def _check_input_shape(self, shape):
|
| 166 |
+
"""Checks the input shape and returns the number of input channels."""
|
| 167 |
+
|
| 168 |
+
if len(shape) == 2:
|
| 169 |
+
in_channels = 1
|
| 170 |
+
elif len(shape) == 3:
|
| 171 |
+
in_channels = shape[-1]
|
| 172 |
+
else:
|
| 173 |
+
raise ValueError(
|
| 174 |
+
"sincconv expects 2d or 3d inputs. Got " + str(len(shape))
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Kernel size must be odd
|
| 178 |
+
if self.kernel_size % 2 == 0:
|
| 179 |
+
raise ValueError(
|
| 180 |
+
"The field kernel size must be an odd number. Got %s."
|
| 181 |
+
% (self.kernel_size)
|
| 182 |
+
)
|
| 183 |
+
return in_channels
|
| 184 |
+
|
| 185 |
+
def _get_sinc_filters(self):
|
| 186 |
+
"""This functions creates the sinc-filters to used for sinc-conv."""
|
| 187 |
+
# Computing the low frequencies of the filters
|
| 188 |
+
low = self.min_low_hz + torch.abs(self.low_hz_)
|
| 189 |
+
|
| 190 |
+
# Setting minimum band and minimum freq
|
| 191 |
+
high = torch.clamp(
|
| 192 |
+
low + self.min_band_hz + torch.abs(self.band_hz_),
|
| 193 |
+
self.min_low_hz,
|
| 194 |
+
self.sample_rate / 2,
|
| 195 |
+
)
|
| 196 |
+
band = (high - low)[:, 0]
|
| 197 |
+
|
| 198 |
+
# Passing from n_ to the corresponding f_times_t domain
|
| 199 |
+
self.n_ = self.n_.to(self.device)
|
| 200 |
+
self.window_ = self.window_.to(self.device)
|
| 201 |
+
f_times_t_low = torch.matmul(low, self.n_)
|
| 202 |
+
f_times_t_high = torch.matmul(high, self.n_)
|
| 203 |
+
|
| 204 |
+
# Left part of the filters.
|
| 205 |
+
band_pass_left = (
|
| 206 |
+
(torch.sin(f_times_t_high) - torch.sin(f_times_t_low))
|
| 207 |
+
/ (self.n_ / 2)
|
| 208 |
+
) * self.window_
|
| 209 |
+
|
| 210 |
+
# Central element of the filter
|
| 211 |
+
band_pass_center = 2 * band.view(-1, 1)
|
| 212 |
+
|
| 213 |
+
# Right part of the filter (sinc filters are symmetric)
|
| 214 |
+
band_pass_right = torch.flip(band_pass_left, dims=[1])
|
| 215 |
+
|
| 216 |
+
# Combining left, central, and right part of the filter
|
| 217 |
+
band_pass = torch.cat(
|
| 218 |
+
[band_pass_left, band_pass_center, band_pass_right], dim=1
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
# Amplitude normalization
|
| 222 |
+
band_pass = band_pass / (2 * band[:, None])
|
| 223 |
+
|
| 224 |
+
# Setting up the filter coefficients
|
| 225 |
+
filters = band_pass.view(self.out_channels, 1, self.kernel_size)
|
| 226 |
+
|
| 227 |
+
return filters
|
| 228 |
+
|
| 229 |
+
def _init_sinc_conv(self):
|
| 230 |
+
"""Initializes the parameters of the sinc_conv layer."""
|
| 231 |
+
|
| 232 |
+
# Initialize filterbanks such that they are equally spaced in Mel scale
|
| 233 |
+
high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)
|
| 234 |
+
|
| 235 |
+
mel = torch.linspace(
|
| 236 |
+
self._to_mel(self.min_low_hz),
|
| 237 |
+
self._to_mel(high_hz),
|
| 238 |
+
self.out_channels + 1,
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
hz = self._to_hz(mel)
|
| 242 |
+
|
| 243 |
+
# Filter lower frequency and bands
|
| 244 |
+
self.low_hz_ = hz[:-1].unsqueeze(1)
|
| 245 |
+
self.band_hz_ = (hz[1:] - hz[:-1]).unsqueeze(1)
|
| 246 |
+
|
| 247 |
+
# Maiking freq and bands learnable
|
| 248 |
+
self.low_hz_ = nn.Parameter(self.low_hz_)
|
| 249 |
+
self.band_hz_ = nn.Parameter(self.band_hz_)
|
| 250 |
+
|
| 251 |
+
# Hamming window
|
| 252 |
+
n_lin = torch.linspace(
|
| 253 |
+
0, (self.kernel_size / 2) - 1, steps=int((self.kernel_size / 2))
|
| 254 |
+
)
|
| 255 |
+
self.window_ = 0.54 - 0.46 * torch.cos(
|
| 256 |
+
2 * math.pi * n_lin / self.kernel_size
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
# Time axis (only half is needed due to symmetry)
|
| 260 |
+
n = (self.kernel_size - 1) / 2.0
|
| 261 |
+
self.n_ = (
|
| 262 |
+
2 * math.pi * torch.arange(-n, 0).view(1, -1) / self.sample_rate
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
def _to_mel(self, hz):
|
| 266 |
+
"""Converts frequency in Hz to the mel scale."""
|
| 267 |
+
return 2595 * np.log10(1 + hz / 700)
|
| 268 |
+
|
| 269 |
+
def _to_hz(self, mel):
|
| 270 |
+
"""Converts frequency in the mel scale to Hz."""
|
| 271 |
+
return 700 * (10 ** (mel / 2595) - 1)
|
| 272 |
+
|
| 273 |
+
def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
|
| 274 |
+
"""This function performs zero-padding on the time axis
|
| 275 |
+
such that their lengths is unchanged after the convolution.
|
| 276 |
+
|
| 277 |
+
Arguments
|
| 278 |
+
---------
|
| 279 |
+
x : torch.Tensor
|
| 280 |
+
Input tensor.
|
| 281 |
+
kernel_size : int
|
| 282 |
+
Size of kernel.
|
| 283 |
+
dilation : int
|
| 284 |
+
Dilation used.
|
| 285 |
+
stride : int
|
| 286 |
+
Stride.
|
| 287 |
+
|
| 288 |
+
Returns
|
| 289 |
+
-------
|
| 290 |
+
x : torch.Tensor
|
| 291 |
+
"""
|
| 292 |
+
|
| 293 |
+
# Detecting input shape
|
| 294 |
+
L_in = self.in_channels
|
| 295 |
+
|
| 296 |
+
# Time padding
|
| 297 |
+
padding = get_padding_elem(L_in, stride, kernel_size, dilation)
|
| 298 |
+
|
| 299 |
+
# Applying padding
|
| 300 |
+
x = F.pad(x, padding, mode=self.padding_mode)
|
| 301 |
+
|
| 302 |
+
return x
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
class Conv1d(nn.Module):
|
| 306 |
+
"""This function implements 1d convolution.
|
| 307 |
+
|
| 308 |
+
Arguments
|
| 309 |
+
---------
|
| 310 |
+
out_channels : int
|
| 311 |
+
It is the number of output channels.
|
| 312 |
+
kernel_size : int
|
| 313 |
+
Kernel size of the convolutional filters.
|
| 314 |
+
input_shape : tuple
|
| 315 |
+
The shape of the input. Alternatively use ``in_channels``.
|
| 316 |
+
in_channels : int
|
| 317 |
+
The number of input channels. Alternatively use ``input_shape``.
|
| 318 |
+
stride : int
|
| 319 |
+
Stride factor of the convolutional filters. When the stride factor > 1,
|
| 320 |
+
a decimation in time is performed.
|
| 321 |
+
dilation : int
|
| 322 |
+
Dilation factor of the convolutional filters.
|
| 323 |
+
padding : str
|
| 324 |
+
(same, valid, causal). If "valid", no padding is performed.
|
| 325 |
+
If "same" and stride is 1, output shape is the same as the input shape.
|
| 326 |
+
"causal" results in causal (dilated) convolutions.
|
| 327 |
+
groups : int
|
| 328 |
+
Number of blocked connections from input channels to output channels.
|
| 329 |
+
bias : bool
|
| 330 |
+
Whether to add a bias term to convolution operation.
|
| 331 |
+
padding_mode : str
|
| 332 |
+
This flag specifies the type of padding. See torch.nn documentation
|
| 333 |
+
for more information.
|
| 334 |
+
skip_transpose : bool
|
| 335 |
+
If False, uses batch x time x channel convention of speechbrain.
|
| 336 |
+
If True, uses batch x channel x time convention.
|
| 337 |
+
weight_norm : bool
|
| 338 |
+
If True, use weight normalization,
|
| 339 |
+
to be removed with self.remove_weight_norm() at inference
|
| 340 |
+
conv_init : str
|
| 341 |
+
Weight initialization for the convolution network
|
| 342 |
+
default_padding: str or int
|
| 343 |
+
This sets the default padding mode that will be used by the pytorch Conv1d backend.
|
| 344 |
+
|
| 345 |
+
Example
|
| 346 |
+
-------
|
| 347 |
+
>>> inp_tensor = torch.rand([10, 40, 16])
|
| 348 |
+
>>> cnn_1d = Conv1d(
|
| 349 |
+
... input_shape=inp_tensor.shape, out_channels=8, kernel_size=5
|
| 350 |
+
... )
|
| 351 |
+
>>> out_tensor = cnn_1d(inp_tensor)
|
| 352 |
+
>>> out_tensor.shape
|
| 353 |
+
torch.Size([10, 40, 8])
|
| 354 |
+
"""
|
| 355 |
+
|
| 356 |
+
def __init__(
|
| 357 |
+
self,
|
| 358 |
+
out_channels,
|
| 359 |
+
kernel_size,
|
| 360 |
+
input_shape=None,
|
| 361 |
+
in_channels=None,
|
| 362 |
+
stride=1,
|
| 363 |
+
dilation=1,
|
| 364 |
+
padding="same",
|
| 365 |
+
groups=1,
|
| 366 |
+
bias=True,
|
| 367 |
+
padding_mode="reflect",
|
| 368 |
+
skip_transpose=False,
|
| 369 |
+
weight_norm=False,
|
| 370 |
+
conv_init=None,
|
| 371 |
+
default_padding=0,
|
| 372 |
+
):
|
| 373 |
+
super().__init__()
|
| 374 |
+
self.kernel_size = kernel_size
|
| 375 |
+
self.stride = stride
|
| 376 |
+
self.dilation = dilation
|
| 377 |
+
self.padding = padding
|
| 378 |
+
self.padding_mode = padding_mode
|
| 379 |
+
self.unsqueeze = False
|
| 380 |
+
self.skip_transpose = skip_transpose
|
| 381 |
+
|
| 382 |
+
if input_shape is None and in_channels is None:
|
| 383 |
+
raise ValueError("Must provide one of input_shape or in_channels")
|
| 384 |
+
|
| 385 |
+
if in_channels is None:
|
| 386 |
+
in_channels = self._check_input_shape(input_shape)
|
| 387 |
+
|
| 388 |
+
self.in_channels = in_channels
|
| 389 |
+
|
| 390 |
+
self.conv = nn.Conv1d(
|
| 391 |
+
in_channels,
|
| 392 |
+
out_channels,
|
| 393 |
+
self.kernel_size,
|
| 394 |
+
stride=self.stride,
|
| 395 |
+
dilation=self.dilation,
|
| 396 |
+
padding=default_padding,
|
| 397 |
+
groups=groups,
|
| 398 |
+
bias=bias,
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
if conv_init == "kaiming":
|
| 402 |
+
nn.init.kaiming_normal_(self.conv.weight)
|
| 403 |
+
elif conv_init == "zero":
|
| 404 |
+
nn.init.zeros_(self.conv.weight)
|
| 405 |
+
elif conv_init == "normal":
|
| 406 |
+
nn.init.normal_(self.conv.weight, std=1e-6)
|
| 407 |
+
|
| 408 |
+
if weight_norm:
|
| 409 |
+
self.conv = nn.utils.weight_norm(self.conv)
|
| 410 |
+
|
| 411 |
+
def forward(self, x):
|
| 412 |
+
"""Returns the output of the convolution.
|
| 413 |
+
|
| 414 |
+
Arguments
|
| 415 |
+
---------
|
| 416 |
+
x : torch.Tensor (batch, time, channel)
|
| 417 |
+
input to convolve. 2d or 4d tensors are expected.
|
| 418 |
+
|
| 419 |
+
Returns
|
| 420 |
+
-------
|
| 421 |
+
wx : torch.Tensor
|
| 422 |
+
The convolved outputs.
|
| 423 |
+
"""
|
| 424 |
+
if not self.skip_transpose:
|
| 425 |
+
x = x.transpose(1, -1)
|
| 426 |
+
|
| 427 |
+
if self.unsqueeze:
|
| 428 |
+
x = x.unsqueeze(1)
|
| 429 |
+
|
| 430 |
+
if self.padding == "same":
|
| 431 |
+
x = self._manage_padding(
|
| 432 |
+
x, self.kernel_size, self.dilation, self.stride
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
elif self.padding == "causal":
|
| 436 |
+
num_pad = (self.kernel_size - 1) * self.dilation
|
| 437 |
+
x = F.pad(x, (num_pad, 0))
|
| 438 |
+
|
| 439 |
+
elif self.padding == "valid":
|
| 440 |
+
pass
|
| 441 |
+
|
| 442 |
+
else:
|
| 443 |
+
raise ValueError(
|
| 444 |
+
"Padding must be 'same', 'valid' or 'causal'. Got "
|
| 445 |
+
+ self.padding
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
wx = self.conv(x)
|
| 449 |
+
|
| 450 |
+
if self.unsqueeze:
|
| 451 |
+
wx = wx.squeeze(1)
|
| 452 |
+
|
| 453 |
+
if not self.skip_transpose:
|
| 454 |
+
wx = wx.transpose(1, -1)
|
| 455 |
+
|
| 456 |
+
return wx
|
| 457 |
+
|
| 458 |
+
def _manage_padding(self, x, kernel_size: int, dilation: int, stride: int):
|
| 459 |
+
"""This function performs zero-padding on the time axis
|
| 460 |
+
such that their lengths is unchanged after the convolution.
|
| 461 |
+
|
| 462 |
+
Arguments
|
| 463 |
+
---------
|
| 464 |
+
x : torch.Tensor
|
| 465 |
+
Input tensor.
|
| 466 |
+
kernel_size : int
|
| 467 |
+
Size of kernel.
|
| 468 |
+
dilation : int
|
| 469 |
+
Dilation used.
|
| 470 |
+
stride : int
|
| 471 |
+
Stride.
|
| 472 |
+
|
| 473 |
+
Returns
|
| 474 |
+
-------
|
| 475 |
+
x : torch.Tensor
|
| 476 |
+
The padded outputs.
|
| 477 |
+
"""
|
| 478 |
+
|
| 479 |
+
# Detecting input shape
|
| 480 |
+
L_in = self.in_channels
|
| 481 |
+
|
| 482 |
+
# Time padding
|
| 483 |
+
padding = get_padding_elem(L_in, stride, kernel_size, dilation)
|
| 484 |
+
|
| 485 |
+
# Applying padding
|
| 486 |
+
x = F.pad(x, padding, mode=self.padding_mode)
|
| 487 |
+
|
| 488 |
+
return x
|
| 489 |
+
|
| 490 |
+
def _check_input_shape(self, shape):
|
| 491 |
+
"""Checks the input shape and returns the number of input channels."""
|
| 492 |
+
|
| 493 |
+
if len(shape) == 2:
|
| 494 |
+
self.unsqueeze = True
|
| 495 |
+
in_channels = 1
|
| 496 |
+
elif self.skip_transpose:
|
| 497 |
+
in_channels = shape[1]
|
| 498 |
+
elif len(shape) == 3:
|
| 499 |
+
in_channels = shape[2]
|
| 500 |
+
else:
|
| 501 |
+
raise ValueError(
|
| 502 |
+
"conv1d expects 2d, 3d inputs. Got " + str(len(shape))
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
# Kernel size must be odd
|
| 506 |
+
if not self.padding == "valid" and self.kernel_size % 2 == 0:
|
| 507 |
+
raise ValueError(
|
| 508 |
+
"The field kernel size must be an odd number. Got %s."
|
| 509 |
+
% (self.kernel_size)
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
return in_channels
|
| 513 |
+
|
| 514 |
+
def remove_weight_norm(self):
|
| 515 |
+
"""Removes weight normalization at inference if used during training."""
|
| 516 |
+
self.conv = nn.utils.remove_weight_norm(self.conv)
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
|
| 520 |
+
"""This function computes the number of elements to add for zero-padding.
|
| 521 |
+
|
| 522 |
+
Arguments
|
| 523 |
+
---------
|
| 524 |
+
L_in : int
|
| 525 |
+
stride: int
|
| 526 |
+
kernel_size : int
|
| 527 |
+
dilation : int
|
| 528 |
+
|
| 529 |
+
Returns
|
| 530 |
+
-------
|
| 531 |
+
padding : int
|
| 532 |
+
The size of the padding to be added
|
| 533 |
+
"""
|
| 534 |
+
if stride > 1:
|
| 535 |
+
padding = [math.floor(kernel_size / 2), math.floor(kernel_size / 2)]
|
| 536 |
+
|
| 537 |
+
else:
|
| 538 |
+
L_out = (
|
| 539 |
+
math.floor((L_in - dilation * (kernel_size - 1) - 1) / stride) + 1
|
| 540 |
+
)
|
| 541 |
+
padding = [
|
| 542 |
+
math.floor((L_in - L_out) / 2),
|
| 543 |
+
math.floor((L_in - L_out) / 2),
|
| 544 |
+
]
|
| 545 |
+
return padding
|
| 546 |
+
|
indextts/BigVGAN/nnet/__init__.py
ADDED
|
File without changes
|
indextts/BigVGAN/nnet/linear.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Library implementing linear transformation.
|
| 2 |
+
|
| 3 |
+
Authors
|
| 4 |
+
* Mirco Ravanelli 2020
|
| 5 |
+
* Davide Borra 2021
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Linear(torch.nn.Module):
|
| 15 |
+
"""Computes a linear transformation y = wx + b.
|
| 16 |
+
|
| 17 |
+
Arguments
|
| 18 |
+
---------
|
| 19 |
+
n_neurons : int
|
| 20 |
+
It is the number of output neurons (i.e, the dimensionality of the
|
| 21 |
+
output).
|
| 22 |
+
input_shape : tuple
|
| 23 |
+
It is the shape of the input tensor.
|
| 24 |
+
input_size : int
|
| 25 |
+
Size of the input tensor.
|
| 26 |
+
bias : bool
|
| 27 |
+
If True, the additive bias b is adopted.
|
| 28 |
+
max_norm : float
|
| 29 |
+
weight max-norm.
|
| 30 |
+
combine_dims : bool
|
| 31 |
+
If True and the input is 4D, combine 3rd and 4th dimensions of input.
|
| 32 |
+
|
| 33 |
+
Example
|
| 34 |
+
-------
|
| 35 |
+
>>> inputs = torch.rand(10, 50, 40)
|
| 36 |
+
>>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
|
| 37 |
+
>>> output = lin_t(inputs)
|
| 38 |
+
>>> output.shape
|
| 39 |
+
torch.Size([10, 50, 100])
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
n_neurons,
|
| 45 |
+
input_shape=None,
|
| 46 |
+
input_size=None,
|
| 47 |
+
bias=True,
|
| 48 |
+
max_norm=None,
|
| 49 |
+
combine_dims=False,
|
| 50 |
+
):
|
| 51 |
+
super().__init__()
|
| 52 |
+
self.max_norm = max_norm
|
| 53 |
+
self.combine_dims = combine_dims
|
| 54 |
+
|
| 55 |
+
if input_shape is None and input_size is None:
|
| 56 |
+
raise ValueError("Expected one of input_shape or input_size")
|
| 57 |
+
|
| 58 |
+
if input_size is None:
|
| 59 |
+
input_size = input_shape[-1]
|
| 60 |
+
if len(input_shape) == 4 and self.combine_dims:
|
| 61 |
+
input_size = input_shape[2] * input_shape[3]
|
| 62 |
+
|
| 63 |
+
# Weights are initialized following pytorch approach
|
| 64 |
+
self.w = nn.Linear(input_size, n_neurons, bias=bias)
|
| 65 |
+
|
| 66 |
+
def forward(self, x):
|
| 67 |
+
"""Returns the linear transformation of input tensor.
|
| 68 |
+
|
| 69 |
+
Arguments
|
| 70 |
+
---------
|
| 71 |
+
x : torch.Tensor
|
| 72 |
+
Input to transform linearly.
|
| 73 |
+
|
| 74 |
+
Returns
|
| 75 |
+
-------
|
| 76 |
+
wx : torch.Tensor
|
| 77 |
+
The linearly transformed outputs.
|
| 78 |
+
"""
|
| 79 |
+
if x.ndim == 4 and self.combine_dims:
|
| 80 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
|
| 81 |
+
|
| 82 |
+
if self.max_norm is not None:
|
| 83 |
+
self.w.weight.data = torch.renorm(
|
| 84 |
+
self.w.weight.data, p=2, dim=0, maxnorm=self.max_norm
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
wx = self.w(x)
|
| 88 |
+
|
| 89 |
+
return wx
|
indextts/BigVGAN/nnet/normalization.py
ADDED
|
@@ -0,0 +1,670 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Library implementing normalization.
|
| 2 |
+
|
| 3 |
+
Authors
|
| 4 |
+
* Mirco Ravanelli 2020
|
| 5 |
+
* Guillermo Cámbara 2021
|
| 6 |
+
* Sarthak Yadav 2022
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BatchNorm1d(nn.Module):
|
| 14 |
+
"""Applies 1d batch normalization to the input tensor.
|
| 15 |
+
|
| 16 |
+
Arguments
|
| 17 |
+
---------
|
| 18 |
+
input_shape : tuple
|
| 19 |
+
The expected shape of the input. Alternatively, use ``input_size``.
|
| 20 |
+
input_size : int
|
| 21 |
+
The expected size of the input. Alternatively, use ``input_shape``.
|
| 22 |
+
eps : float
|
| 23 |
+
This value is added to std deviation estimation to improve the numerical
|
| 24 |
+
stability.
|
| 25 |
+
momentum : float
|
| 26 |
+
It is a value used for the running_mean and running_var computation.
|
| 27 |
+
affine : bool
|
| 28 |
+
When set to True, the affine parameters are learned.
|
| 29 |
+
track_running_stats : bool
|
| 30 |
+
When set to True, this module tracks the running mean and variance,
|
| 31 |
+
and when set to False, this module does not track such statistics.
|
| 32 |
+
combine_batch_time : bool
|
| 33 |
+
When true, it combines batch an time axis.
|
| 34 |
+
skip_transpose : bool
|
| 35 |
+
Whether to skip the transposition.
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
Example
|
| 39 |
+
-------
|
| 40 |
+
>>> input = torch.randn(100, 10)
|
| 41 |
+
>>> norm = BatchNorm1d(input_shape=input.shape)
|
| 42 |
+
>>> output = norm(input)
|
| 43 |
+
>>> output.shape
|
| 44 |
+
torch.Size([100, 10])
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
input_shape=None,
|
| 50 |
+
input_size=None,
|
| 51 |
+
eps=1e-05,
|
| 52 |
+
momentum=0.1,
|
| 53 |
+
affine=True,
|
| 54 |
+
track_running_stats=True,
|
| 55 |
+
combine_batch_time=False,
|
| 56 |
+
skip_transpose=False,
|
| 57 |
+
):
|
| 58 |
+
super().__init__()
|
| 59 |
+
self.combine_batch_time = combine_batch_time
|
| 60 |
+
self.skip_transpose = skip_transpose
|
| 61 |
+
|
| 62 |
+
if input_size is None and skip_transpose:
|
| 63 |
+
input_size = input_shape[1]
|
| 64 |
+
elif input_size is None:
|
| 65 |
+
input_size = input_shape[-1]
|
| 66 |
+
|
| 67 |
+
self.norm = nn.BatchNorm1d(
|
| 68 |
+
input_size,
|
| 69 |
+
eps=eps,
|
| 70 |
+
momentum=momentum,
|
| 71 |
+
affine=affine,
|
| 72 |
+
track_running_stats=track_running_stats,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
def forward(self, x):
|
| 76 |
+
"""Returns the normalized input tensor.
|
| 77 |
+
|
| 78 |
+
Arguments
|
| 79 |
+
---------
|
| 80 |
+
x : torch.Tensor (batch, time, [channels])
|
| 81 |
+
input to normalize. 2d or 3d tensors are expected in input
|
| 82 |
+
4d tensors can be used when combine_dims=True.
|
| 83 |
+
|
| 84 |
+
Returns
|
| 85 |
+
-------
|
| 86 |
+
x_n : torch.Tensor
|
| 87 |
+
The normalized outputs.
|
| 88 |
+
"""
|
| 89 |
+
shape_or = x.shape
|
| 90 |
+
if self.combine_batch_time:
|
| 91 |
+
if x.ndim == 3:
|
| 92 |
+
x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
|
| 93 |
+
else:
|
| 94 |
+
x = x.reshape(
|
| 95 |
+
shape_or[0] * shape_or[1], shape_or[3], shape_or[2]
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
elif not self.skip_transpose:
|
| 99 |
+
x = x.transpose(-1, 1)
|
| 100 |
+
|
| 101 |
+
x_n = self.norm(x)
|
| 102 |
+
|
| 103 |
+
if self.combine_batch_time:
|
| 104 |
+
x_n = x_n.reshape(shape_or)
|
| 105 |
+
elif not self.skip_transpose:
|
| 106 |
+
x_n = x_n.transpose(1, -1)
|
| 107 |
+
|
| 108 |
+
return x_n
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class BatchNorm2d(nn.Module):
|
| 112 |
+
"""Applies 2d batch normalization to the input tensor.
|
| 113 |
+
|
| 114 |
+
Arguments
|
| 115 |
+
---------
|
| 116 |
+
input_shape : tuple
|
| 117 |
+
The expected shape of the input. Alternatively, use ``input_size``.
|
| 118 |
+
input_size : int
|
| 119 |
+
The expected size of the input. Alternatively, use ``input_shape``.
|
| 120 |
+
eps : float
|
| 121 |
+
This value is added to std deviation estimation to improve the numerical
|
| 122 |
+
stability.
|
| 123 |
+
momentum : float
|
| 124 |
+
It is a value used for the running_mean and running_var computation.
|
| 125 |
+
affine : bool
|
| 126 |
+
When set to True, the affine parameters are learned.
|
| 127 |
+
track_running_stats : bool
|
| 128 |
+
When set to True, this module tracks the running mean and variance,
|
| 129 |
+
and when set to False, this module does not track such statistics.
|
| 130 |
+
|
| 131 |
+
Example
|
| 132 |
+
-------
|
| 133 |
+
>>> input = torch.randn(100, 10, 5, 20)
|
| 134 |
+
>>> norm = BatchNorm2d(input_shape=input.shape)
|
| 135 |
+
>>> output = norm(input)
|
| 136 |
+
>>> output.shape
|
| 137 |
+
torch.Size([100, 10, 5, 20])
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
def __init__(
|
| 141 |
+
self,
|
| 142 |
+
input_shape=None,
|
| 143 |
+
input_size=None,
|
| 144 |
+
eps=1e-05,
|
| 145 |
+
momentum=0.1,
|
| 146 |
+
affine=True,
|
| 147 |
+
track_running_stats=True,
|
| 148 |
+
):
|
| 149 |
+
super().__init__()
|
| 150 |
+
|
| 151 |
+
if input_shape is None and input_size is None:
|
| 152 |
+
raise ValueError("Expected input_shape or input_size as input")
|
| 153 |
+
|
| 154 |
+
if input_size is None:
|
| 155 |
+
input_size = input_shape[-1]
|
| 156 |
+
|
| 157 |
+
self.norm = nn.BatchNorm2d(
|
| 158 |
+
input_size,
|
| 159 |
+
eps=eps,
|
| 160 |
+
momentum=momentum,
|
| 161 |
+
affine=affine,
|
| 162 |
+
track_running_stats=track_running_stats,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
def forward(self, x):
|
| 166 |
+
"""Returns the normalized input tensor.
|
| 167 |
+
|
| 168 |
+
Arguments
|
| 169 |
+
---------
|
| 170 |
+
x : torch.Tensor (batch, time, channel1, channel2)
|
| 171 |
+
input to normalize. 4d tensors are expected.
|
| 172 |
+
|
| 173 |
+
Returns
|
| 174 |
+
-------
|
| 175 |
+
x_n : torch.Tensor
|
| 176 |
+
The normalized outputs.
|
| 177 |
+
"""
|
| 178 |
+
x = x.transpose(-1, 1)
|
| 179 |
+
x_n = self.norm(x)
|
| 180 |
+
x_n = x_n.transpose(1, -1)
|
| 181 |
+
|
| 182 |
+
return x_n
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class LayerNorm(nn.Module):
|
| 186 |
+
"""Applies layer normalization to the input tensor.
|
| 187 |
+
|
| 188 |
+
Arguments
|
| 189 |
+
---------
|
| 190 |
+
input_size : int
|
| 191 |
+
The expected size of the dimension to be normalized.
|
| 192 |
+
input_shape : tuple
|
| 193 |
+
The expected shape of the input.
|
| 194 |
+
eps : float
|
| 195 |
+
This value is added to std deviation estimation to improve the numerical
|
| 196 |
+
stability.
|
| 197 |
+
elementwise_affine : bool
|
| 198 |
+
If True, this module has learnable per-element affine parameters
|
| 199 |
+
initialized to ones (for weights) and zeros (for biases).
|
| 200 |
+
|
| 201 |
+
Example
|
| 202 |
+
-------
|
| 203 |
+
>>> input = torch.randn(100, 101, 128)
|
| 204 |
+
>>> norm = LayerNorm(input_shape=input.shape)
|
| 205 |
+
>>> output = norm(input)
|
| 206 |
+
>>> output.shape
|
| 207 |
+
torch.Size([100, 101, 128])
|
| 208 |
+
"""
|
| 209 |
+
|
| 210 |
+
def __init__(
|
| 211 |
+
self,
|
| 212 |
+
input_size=None,
|
| 213 |
+
input_shape=None,
|
| 214 |
+
eps=1e-05,
|
| 215 |
+
elementwise_affine=True,
|
| 216 |
+
):
|
| 217 |
+
super().__init__()
|
| 218 |
+
self.eps = eps
|
| 219 |
+
self.elementwise_affine = elementwise_affine
|
| 220 |
+
|
| 221 |
+
if input_shape is not None:
|
| 222 |
+
input_size = input_shape[2:]
|
| 223 |
+
|
| 224 |
+
self.norm = torch.nn.LayerNorm(
|
| 225 |
+
input_size,
|
| 226 |
+
eps=self.eps,
|
| 227 |
+
elementwise_affine=self.elementwise_affine,
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
def forward(self, x):
|
| 231 |
+
"""Returns the normalized input tensor.
|
| 232 |
+
|
| 233 |
+
Arguments
|
| 234 |
+
---------
|
| 235 |
+
x : torch.Tensor (batch, time, channels)
|
| 236 |
+
input to normalize. 3d or 4d tensors are expected.
|
| 237 |
+
|
| 238 |
+
Returns
|
| 239 |
+
-------
|
| 240 |
+
The normalized outputs.
|
| 241 |
+
"""
|
| 242 |
+
return self.norm(x)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class InstanceNorm1d(nn.Module):
|
| 246 |
+
"""Applies 1d instance normalization to the input tensor.
|
| 247 |
+
|
| 248 |
+
Arguments
|
| 249 |
+
---------
|
| 250 |
+
input_shape : tuple
|
| 251 |
+
The expected shape of the input. Alternatively, use ``input_size``.
|
| 252 |
+
input_size : int
|
| 253 |
+
The expected size of the input. Alternatively, use ``input_shape``.
|
| 254 |
+
eps : float
|
| 255 |
+
This value is added to std deviation estimation to improve the numerical
|
| 256 |
+
stability.
|
| 257 |
+
momentum : float
|
| 258 |
+
It is a value used for the running_mean and running_var computation.
|
| 259 |
+
track_running_stats : bool
|
| 260 |
+
When set to True, this module tracks the running mean and variance,
|
| 261 |
+
and when set to False, this module does not track such statistics.
|
| 262 |
+
affine : bool
|
| 263 |
+
A boolean value that when set to True, this module has learnable
|
| 264 |
+
affine parameters, initialized the same way as done for
|
| 265 |
+
batch normalization. Default: False.
|
| 266 |
+
|
| 267 |
+
Example
|
| 268 |
+
-------
|
| 269 |
+
>>> input = torch.randn(100, 10, 20)
|
| 270 |
+
>>> norm = InstanceNorm1d(input_shape=input.shape)
|
| 271 |
+
>>> output = norm(input)
|
| 272 |
+
>>> output.shape
|
| 273 |
+
torch.Size([100, 10, 20])
|
| 274 |
+
"""
|
| 275 |
+
|
| 276 |
+
def __init__(
|
| 277 |
+
self,
|
| 278 |
+
input_shape=None,
|
| 279 |
+
input_size=None,
|
| 280 |
+
eps=1e-05,
|
| 281 |
+
momentum=0.1,
|
| 282 |
+
track_running_stats=True,
|
| 283 |
+
affine=False,
|
| 284 |
+
):
|
| 285 |
+
super().__init__()
|
| 286 |
+
|
| 287 |
+
if input_shape is None and input_size is None:
|
| 288 |
+
raise ValueError("Expected input_shape or input_size as input")
|
| 289 |
+
|
| 290 |
+
if input_size is None:
|
| 291 |
+
input_size = input_shape[-1]
|
| 292 |
+
|
| 293 |
+
self.norm = nn.InstanceNorm1d(
|
| 294 |
+
input_size,
|
| 295 |
+
eps=eps,
|
| 296 |
+
momentum=momentum,
|
| 297 |
+
track_running_stats=track_running_stats,
|
| 298 |
+
affine=affine,
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
def forward(self, x):
|
| 302 |
+
"""Returns the normalized input tensor.
|
| 303 |
+
|
| 304 |
+
Arguments
|
| 305 |
+
---------
|
| 306 |
+
x : torch.Tensor (batch, time, channels)
|
| 307 |
+
input to normalize. 3d tensors are expected.
|
| 308 |
+
|
| 309 |
+
Returns
|
| 310 |
+
-------
|
| 311 |
+
x_n : torch.Tensor
|
| 312 |
+
The normalized outputs.
|
| 313 |
+
"""
|
| 314 |
+
x = x.transpose(-1, 1)
|
| 315 |
+
x_n = self.norm(x)
|
| 316 |
+
x_n = x_n.transpose(1, -1)
|
| 317 |
+
|
| 318 |
+
return x_n
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class InstanceNorm2d(nn.Module):
|
| 322 |
+
"""Applies 2d instance normalization to the input tensor.
|
| 323 |
+
|
| 324 |
+
Arguments
|
| 325 |
+
---------
|
| 326 |
+
input_shape : tuple
|
| 327 |
+
The expected shape of the input. Alternatively, use ``input_size``.
|
| 328 |
+
input_size : int
|
| 329 |
+
The expected size of the input. Alternatively, use ``input_shape``.
|
| 330 |
+
eps : float
|
| 331 |
+
This value is added to std deviation estimation to improve the numerical
|
| 332 |
+
stability.
|
| 333 |
+
momentum : float
|
| 334 |
+
It is a value used for the running_mean and running_var computation.
|
| 335 |
+
track_running_stats : bool
|
| 336 |
+
When set to True, this module tracks the running mean and variance,
|
| 337 |
+
and when set to False, this module does not track such statistics.
|
| 338 |
+
affine : bool
|
| 339 |
+
A boolean value that when set to True, this module has learnable
|
| 340 |
+
affine parameters, initialized the same way as done for
|
| 341 |
+
batch normalization. Default: False.
|
| 342 |
+
|
| 343 |
+
Example
|
| 344 |
+
-------
|
| 345 |
+
>>> input = torch.randn(100, 10, 20, 2)
|
| 346 |
+
>>> norm = InstanceNorm2d(input_shape=input.shape)
|
| 347 |
+
>>> output = norm(input)
|
| 348 |
+
>>> output.shape
|
| 349 |
+
torch.Size([100, 10, 20, 2])
|
| 350 |
+
"""
|
| 351 |
+
|
| 352 |
+
def __init__(
|
| 353 |
+
self,
|
| 354 |
+
input_shape=None,
|
| 355 |
+
input_size=None,
|
| 356 |
+
eps=1e-05,
|
| 357 |
+
momentum=0.1,
|
| 358 |
+
track_running_stats=True,
|
| 359 |
+
affine=False,
|
| 360 |
+
):
|
| 361 |
+
super().__init__()
|
| 362 |
+
|
| 363 |
+
if input_shape is None and input_size is None:
|
| 364 |
+
raise ValueError("Expected input_shape or input_size as input")
|
| 365 |
+
|
| 366 |
+
if input_size is None:
|
| 367 |
+
input_size = input_shape[-1]
|
| 368 |
+
|
| 369 |
+
self.norm = nn.InstanceNorm2d(
|
| 370 |
+
input_size,
|
| 371 |
+
eps=eps,
|
| 372 |
+
momentum=momentum,
|
| 373 |
+
track_running_stats=track_running_stats,
|
| 374 |
+
affine=affine,
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
def forward(self, x):
|
| 378 |
+
"""Returns the normalized input tensor.
|
| 379 |
+
|
| 380 |
+
Arguments
|
| 381 |
+
---------
|
| 382 |
+
x : torch.Tensor (batch, time, channel1, channel2)
|
| 383 |
+
input to normalize. 4d tensors are expected.
|
| 384 |
+
|
| 385 |
+
Returns
|
| 386 |
+
-------
|
| 387 |
+
x_n : torch.Tensor
|
| 388 |
+
The normalized outputs.
|
| 389 |
+
"""
|
| 390 |
+
x = x.transpose(-1, 1)
|
| 391 |
+
x_n = self.norm(x)
|
| 392 |
+
x_n = x_n.transpose(1, -1)
|
| 393 |
+
|
| 394 |
+
return x_n
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
class GroupNorm(nn.Module):
|
| 398 |
+
"""Applies group normalization to the input tensor.
|
| 399 |
+
|
| 400 |
+
Arguments
|
| 401 |
+
---------
|
| 402 |
+
input_shape : tuple
|
| 403 |
+
The expected shape of the input. Alternatively, use ``input_size``.
|
| 404 |
+
input_size : int
|
| 405 |
+
The expected size of the input. Alternatively, use ``input_shape``.
|
| 406 |
+
num_groups : int
|
| 407 |
+
Number of groups to separate the channels into.
|
| 408 |
+
eps : float
|
| 409 |
+
This value is added to std deviation estimation to improve the numerical
|
| 410 |
+
stability.
|
| 411 |
+
affine : bool
|
| 412 |
+
A boolean value that when set to True, this module has learnable per-channel
|
| 413 |
+
affine parameters initialized to ones (for weights) and zeros (for biases).
|
| 414 |
+
|
| 415 |
+
Example
|
| 416 |
+
-------
|
| 417 |
+
>>> input = torch.randn(100, 101, 128)
|
| 418 |
+
>>> norm = GroupNorm(input_size=128, num_groups=128)
|
| 419 |
+
>>> output = norm(input)
|
| 420 |
+
>>> output.shape
|
| 421 |
+
torch.Size([100, 101, 128])
|
| 422 |
+
"""
|
| 423 |
+
|
| 424 |
+
def __init__(
|
| 425 |
+
self,
|
| 426 |
+
input_shape=None,
|
| 427 |
+
input_size=None,
|
| 428 |
+
num_groups=None,
|
| 429 |
+
eps=1e-05,
|
| 430 |
+
affine=True,
|
| 431 |
+
):
|
| 432 |
+
super().__init__()
|
| 433 |
+
self.eps = eps
|
| 434 |
+
self.affine = affine
|
| 435 |
+
|
| 436 |
+
if input_shape is None and input_size is None:
|
| 437 |
+
raise ValueError("Expected input_shape or input_size as input")
|
| 438 |
+
|
| 439 |
+
if num_groups is None:
|
| 440 |
+
raise ValueError("Expected num_groups as input")
|
| 441 |
+
|
| 442 |
+
if input_shape is not None:
|
| 443 |
+
input_size = input_shape[-1]
|
| 444 |
+
|
| 445 |
+
self.norm = torch.nn.GroupNorm(
|
| 446 |
+
num_groups,
|
| 447 |
+
input_size,
|
| 448 |
+
eps=self.eps,
|
| 449 |
+
affine=self.affine,
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
def forward(self, x):
|
| 453 |
+
"""Returns the normalized input tensor.
|
| 454 |
+
|
| 455 |
+
Arguments
|
| 456 |
+
---------
|
| 457 |
+
x : torch.Tensor (batch, time, channels)
|
| 458 |
+
input to normalize. 3d or 4d tensors are expected.
|
| 459 |
+
|
| 460 |
+
Returns
|
| 461 |
+
-------
|
| 462 |
+
x_n : torch.Tensor
|
| 463 |
+
The normalized outputs.
|
| 464 |
+
"""
|
| 465 |
+
x = x.transpose(-1, 1)
|
| 466 |
+
x_n = self.norm(x)
|
| 467 |
+
x_n = x_n.transpose(1, -1)
|
| 468 |
+
|
| 469 |
+
return x_n
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
class ExponentialMovingAverage(nn.Module):
|
| 473 |
+
"""
|
| 474 |
+
Applies learnable exponential moving average, as required by learnable PCEN layer
|
| 475 |
+
|
| 476 |
+
Arguments
|
| 477 |
+
---------
|
| 478 |
+
input_size : int
|
| 479 |
+
The expected size of the input.
|
| 480 |
+
coeff_init: float
|
| 481 |
+
Initial smoothing coefficient value
|
| 482 |
+
per_channel: bool
|
| 483 |
+
Controls whether every smoothing coefficients are learned
|
| 484 |
+
independently for every input channel
|
| 485 |
+
trainable: bool
|
| 486 |
+
whether to learn the PCEN parameters or use fixed
|
| 487 |
+
skip_transpose : bool
|
| 488 |
+
If False, uses batch x time x channel convention of speechbrain.
|
| 489 |
+
If True, uses batch x channel x time convention.
|
| 490 |
+
|
| 491 |
+
Example
|
| 492 |
+
-------
|
| 493 |
+
>>> inp_tensor = torch.rand([10, 50, 40])
|
| 494 |
+
>>> pcen = ExponentialMovingAverage(40)
|
| 495 |
+
>>> out_tensor = pcen(inp_tensor)
|
| 496 |
+
>>> out_tensor.shape
|
| 497 |
+
torch.Size([10, 50, 40])
|
| 498 |
+
"""
|
| 499 |
+
|
| 500 |
+
def __init__(
|
| 501 |
+
self,
|
| 502 |
+
input_size: int,
|
| 503 |
+
coeff_init: float = 0.04,
|
| 504 |
+
per_channel: bool = False,
|
| 505 |
+
trainable: bool = True,
|
| 506 |
+
skip_transpose: bool = False,
|
| 507 |
+
):
|
| 508 |
+
super().__init__()
|
| 509 |
+
self._coeff_init = coeff_init
|
| 510 |
+
self._per_channel = per_channel
|
| 511 |
+
self.skip_transpose = skip_transpose
|
| 512 |
+
self.trainable = trainable
|
| 513 |
+
weights = (
|
| 514 |
+
torch.ones(
|
| 515 |
+
input_size,
|
| 516 |
+
)
|
| 517 |
+
if self._per_channel
|
| 518 |
+
else torch.ones(
|
| 519 |
+
1,
|
| 520 |
+
)
|
| 521 |
+
)
|
| 522 |
+
self._weights = nn.Parameter(
|
| 523 |
+
weights * self._coeff_init, requires_grad=trainable
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
def forward(self, x):
|
| 527 |
+
"""Returns the normalized input tensor.
|
| 528 |
+
|
| 529 |
+
Arguments
|
| 530 |
+
---------
|
| 531 |
+
x : torch.Tensor (batch, time, channels)
|
| 532 |
+
input to normalize.
|
| 533 |
+
"""
|
| 534 |
+
if not self.skip_transpose:
|
| 535 |
+
x = x.transpose(1, -1)
|
| 536 |
+
w = torch.clamp(self._weights, min=0.0, max=1.0)
|
| 537 |
+
initial_state = x[:, :, 0]
|
| 538 |
+
|
| 539 |
+
def scan(init_state, x, w):
|
| 540 |
+
"""Loops and accumulates."""
|
| 541 |
+
x = x.permute(2, 0, 1)
|
| 542 |
+
acc = init_state
|
| 543 |
+
results = []
|
| 544 |
+
for ix in range(x.shape[0]):
|
| 545 |
+
acc = (w * x[ix]) + ((1.0 - w) * acc)
|
| 546 |
+
results.append(acc.unsqueeze(0))
|
| 547 |
+
results = torch.cat(results, dim=0)
|
| 548 |
+
results = results.permute(1, 2, 0)
|
| 549 |
+
return results
|
| 550 |
+
|
| 551 |
+
output = scan(initial_state, x, w)
|
| 552 |
+
if not self.skip_transpose:
|
| 553 |
+
output = output.transpose(1, -1)
|
| 554 |
+
return output
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
class PCEN(nn.Module):
|
| 558 |
+
"""
|
| 559 |
+
This class implements a learnable Per-channel energy normalization (PCEN) layer, supporting both
|
| 560 |
+
original PCEN as specified in [1] as well as sPCEN as specified in [2]
|
| 561 |
+
|
| 562 |
+
[1] Yuxuan Wang, Pascal Getreuer, Thad Hughes, Richard F. Lyon, Rif A. Saurous, "Trainable Frontend For
|
| 563 |
+
Robust and Far-Field Keyword Spotting", in Proc of ICASSP 2017 (https://arxiv.org/abs/1607.05666)
|
| 564 |
+
|
| 565 |
+
[2] Neil Zeghidour, Olivier Teboul, F{\'e}lix de Chaumont Quitry & Marco Tagliasacchi, "LEAF: A LEARNABLE FRONTEND
|
| 566 |
+
FOR AUDIO CLASSIFICATION", in Proc of ICLR 2021 (https://arxiv.org/abs/2101.08596)
|
| 567 |
+
|
| 568 |
+
The default argument values correspond with those used by [2].
|
| 569 |
+
|
| 570 |
+
Arguments
|
| 571 |
+
---------
|
| 572 |
+
input_size : int
|
| 573 |
+
The expected size of the input.
|
| 574 |
+
alpha: float
|
| 575 |
+
specifies alpha coefficient for PCEN
|
| 576 |
+
smooth_coef: float
|
| 577 |
+
specified smooth coefficient for PCEN
|
| 578 |
+
delta: float
|
| 579 |
+
specifies delta coefficient for PCEN
|
| 580 |
+
root: float
|
| 581 |
+
specifies root coefficient for PCEN
|
| 582 |
+
floor: float
|
| 583 |
+
specifies floor coefficient for PCEN
|
| 584 |
+
trainable: bool
|
| 585 |
+
whether to learn the PCEN parameters or use fixed
|
| 586 |
+
per_channel_smooth_coef: bool
|
| 587 |
+
whether to learn independent smooth coefficients for every channel.
|
| 588 |
+
when True, essentially using sPCEN from [2]
|
| 589 |
+
skip_transpose : bool
|
| 590 |
+
If False, uses batch x time x channel convention of speechbrain.
|
| 591 |
+
If True, uses batch x channel x time convention.
|
| 592 |
+
|
| 593 |
+
Example
|
| 594 |
+
-------
|
| 595 |
+
>>> inp_tensor = torch.rand([10, 50, 40])
|
| 596 |
+
>>> pcen = PCEN(40, alpha=0.96) # sPCEN
|
| 597 |
+
>>> out_tensor = pcen(inp_tensor)
|
| 598 |
+
>>> out_tensor.shape
|
| 599 |
+
torch.Size([10, 50, 40])
|
| 600 |
+
"""
|
| 601 |
+
|
| 602 |
+
def __init__(
|
| 603 |
+
self,
|
| 604 |
+
input_size,
|
| 605 |
+
alpha: float = 0.96,
|
| 606 |
+
smooth_coef: float = 0.04,
|
| 607 |
+
delta: float = 2.0,
|
| 608 |
+
root: float = 2.0,
|
| 609 |
+
floor: float = 1e-12,
|
| 610 |
+
trainable: bool = True,
|
| 611 |
+
per_channel_smooth_coef: bool = True,
|
| 612 |
+
skip_transpose: bool = False,
|
| 613 |
+
):
|
| 614 |
+
super().__init__()
|
| 615 |
+
self._smooth_coef = smooth_coef
|
| 616 |
+
self._floor = floor
|
| 617 |
+
self._per_channel_smooth_coef = per_channel_smooth_coef
|
| 618 |
+
self.skip_transpose = skip_transpose
|
| 619 |
+
self.alpha = nn.Parameter(
|
| 620 |
+
torch.ones(input_size) * alpha, requires_grad=trainable
|
| 621 |
+
)
|
| 622 |
+
self.delta = nn.Parameter(
|
| 623 |
+
torch.ones(input_size) * delta, requires_grad=trainable
|
| 624 |
+
)
|
| 625 |
+
self.root = nn.Parameter(
|
| 626 |
+
torch.ones(input_size) * root, requires_grad=trainable
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
self.ema = ExponentialMovingAverage(
|
| 630 |
+
input_size,
|
| 631 |
+
coeff_init=self._smooth_coef,
|
| 632 |
+
per_channel=self._per_channel_smooth_coef,
|
| 633 |
+
skip_transpose=True,
|
| 634 |
+
trainable=trainable,
|
| 635 |
+
)
|
| 636 |
+
|
| 637 |
+
def forward(self, x):
|
| 638 |
+
"""Returns the normalized input tensor.
|
| 639 |
+
|
| 640 |
+
Arguments
|
| 641 |
+
---------
|
| 642 |
+
x : torch.Tensor (batch, time, channels)
|
| 643 |
+
input to normalize.
|
| 644 |
+
|
| 645 |
+
Returns
|
| 646 |
+
-------
|
| 647 |
+
output : torch.Tensor
|
| 648 |
+
The normalized outputs.
|
| 649 |
+
"""
|
| 650 |
+
if not self.skip_transpose:
|
| 651 |
+
x = x.transpose(1, -1)
|
| 652 |
+
alpha = torch.min(
|
| 653 |
+
self.alpha, torch.tensor(1.0, dtype=x.dtype, device=x.device)
|
| 654 |
+
)
|
| 655 |
+
root = torch.max(
|
| 656 |
+
self.root, torch.tensor(1.0, dtype=x.dtype, device=x.device)
|
| 657 |
+
)
|
| 658 |
+
ema_smoother = self.ema(x)
|
| 659 |
+
one_over_root = 1.0 / root
|
| 660 |
+
output = (
|
| 661 |
+
x / (self._floor + ema_smoother) ** alpha.view(1, -1, 1)
|
| 662 |
+
+ self.delta.view(1, -1, 1)
|
| 663 |
+
) ** one_over_root.view(1, -1, 1) - self.delta.view(
|
| 664 |
+
1, -1, 1
|
| 665 |
+
) ** one_over_root.view(
|
| 666 |
+
1, -1, 1
|
| 667 |
+
)
|
| 668 |
+
if not self.skip_transpose:
|
| 669 |
+
output = output.transpose(1, -1)
|
| 670 |
+
return output
|
indextts/BigVGAN/utils.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
| 2 |
+
# LICENSE is in incl_licenses directory.
|
| 3 |
+
|
| 4 |
+
import glob
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
import matplotlib
|
| 8 |
+
import matplotlib.pylab as plt
|
| 9 |
+
import torch
|
| 10 |
+
from scipy.io.wavfile import write
|
| 11 |
+
from torch.nn.utils import weight_norm
|
| 12 |
+
|
| 13 |
+
matplotlib.use("Agg")
|
| 14 |
+
|
| 15 |
+
MAX_WAV_VALUE = 32768.0
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def plot_spectrogram(spectrogram):
|
| 19 |
+
fig, ax = plt.subplots(figsize=(10, 2))
|
| 20 |
+
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
| 21 |
+
plt.colorbar(im, ax=ax)
|
| 22 |
+
|
| 23 |
+
fig.canvas.draw()
|
| 24 |
+
plt.close()
|
| 25 |
+
|
| 26 |
+
return fig
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
|
| 30 |
+
fig, ax = plt.subplots(figsize=(10, 2))
|
| 31 |
+
im = ax.imshow(
|
| 32 |
+
spectrogram,
|
| 33 |
+
aspect="auto",
|
| 34 |
+
origin="lower",
|
| 35 |
+
interpolation="none",
|
| 36 |
+
vmin=1e-6,
|
| 37 |
+
vmax=clip_max,
|
| 38 |
+
)
|
| 39 |
+
plt.colorbar(im, ax=ax)
|
| 40 |
+
|
| 41 |
+
fig.canvas.draw()
|
| 42 |
+
plt.close()
|
| 43 |
+
|
| 44 |
+
return fig
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def init_weights(m, mean=0.0, std=0.01):
|
| 48 |
+
classname = m.__class__.__name__
|
| 49 |
+
if classname.find("Conv") != -1:
|
| 50 |
+
m.weight.data.normal_(mean, std)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def apply_weight_norm(m):
|
| 54 |
+
classname = m.__class__.__name__
|
| 55 |
+
if classname.find("Conv") != -1:
|
| 56 |
+
weight_norm(m)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def get_padding(kernel_size, dilation=1):
|
| 60 |
+
return int((kernel_size * dilation - dilation) / 2)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def load_checkpoint(filepath, device):
|
| 64 |
+
assert os.path.isfile(filepath)
|
| 65 |
+
print(f"Loading '{filepath}'")
|
| 66 |
+
checkpoint_dict = torch.load(filepath, map_location=device)
|
| 67 |
+
print("Complete.")
|
| 68 |
+
return checkpoint_dict
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def save_checkpoint(filepath, obj):
|
| 72 |
+
print(f"Saving checkpoint to {filepath}")
|
| 73 |
+
torch.save(obj, filepath)
|
| 74 |
+
print("Complete.")
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def scan_checkpoint(cp_dir, prefix, renamed_file=None):
|
| 78 |
+
# Fallback to original scanning logic first
|
| 79 |
+
pattern = os.path.join(cp_dir, prefix + "????????")
|
| 80 |
+
cp_list = glob.glob(pattern)
|
| 81 |
+
|
| 82 |
+
if len(cp_list) > 0:
|
| 83 |
+
last_checkpoint_path = sorted(cp_list)[-1]
|
| 84 |
+
print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'")
|
| 85 |
+
return last_checkpoint_path
|
| 86 |
+
|
| 87 |
+
# If no pattern-based checkpoints are found, check for renamed file
|
| 88 |
+
if renamed_file:
|
| 89 |
+
renamed_path = os.path.join(cp_dir, renamed_file)
|
| 90 |
+
if os.path.isfile(renamed_path):
|
| 91 |
+
print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'")
|
| 92 |
+
return renamed_path
|
| 93 |
+
|
| 94 |
+
return None
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def save_audio(audio, path, sr):
|
| 98 |
+
# wav: torch with 1d shape
|
| 99 |
+
audio = audio * MAX_WAV_VALUE
|
| 100 |
+
audio = audio.cpu().numpy().astype("int16")
|
| 101 |
+
write(path, sr, audio)
|
indextts/__init__.py
ADDED
|
File without changes
|
indextts/cli.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import warnings
|
| 4 |
+
# Suppress warnings from tensorflow and other libraries
|
| 5 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 6 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 7 |
+
def main():
|
| 8 |
+
import argparse
|
| 9 |
+
parser = argparse.ArgumentParser(description="IndexTTS Command Line")
|
| 10 |
+
parser.add_argument("text", type=str, help="Text to be synthesized")
|
| 11 |
+
parser.add_argument("-v", "--voice", type=str, required=True, help="Path to the audio prompt file (wav format)")
|
| 12 |
+
parser.add_argument("-o", "--output_path", type=str, default="gen.wav", help="Path to the output wav file")
|
| 13 |
+
parser.add_argument("-c", "--config", type=str, default="checkpoints/config.yaml", help="Path to the config file. Default is 'checkpoints/config.yaml'")
|
| 14 |
+
parser.add_argument("--model_dir", type=str, default="checkpoints", help="Path to the model directory. Default is 'checkpoints'")
|
| 15 |
+
parser.add_argument("--fp16", action="store_true", default=False, help="Use FP16 for inference if available")
|
| 16 |
+
parser.add_argument("-f", "--force", action="store_true", default=False, help="Force to overwrite the output file if it exists")
|
| 17 |
+
parser.add_argument("-d", "--device", type=str, default=None, help="Device to run the model on (cpu, cuda, mps, xpu)." )
|
| 18 |
+
args = parser.parse_args()
|
| 19 |
+
if len(args.text.strip()) == 0:
|
| 20 |
+
print("ERROR: Text is empty.")
|
| 21 |
+
parser.print_help()
|
| 22 |
+
sys.exit(1)
|
| 23 |
+
if not os.path.exists(args.voice):
|
| 24 |
+
print(f"Audio prompt file {args.voice} does not exist.")
|
| 25 |
+
parser.print_help()
|
| 26 |
+
sys.exit(1)
|
| 27 |
+
if not os.path.exists(args.config):
|
| 28 |
+
print(f"Config file {args.config} does not exist.")
|
| 29 |
+
parser.print_help()
|
| 30 |
+
sys.exit(1)
|
| 31 |
+
|
| 32 |
+
output_path = args.output_path
|
| 33 |
+
if os.path.exists(output_path):
|
| 34 |
+
if not args.force:
|
| 35 |
+
print(f"ERROR: Output file {output_path} already exists. Use --force to overwrite.")
|
| 36 |
+
parser.print_help()
|
| 37 |
+
sys.exit(1)
|
| 38 |
+
else:
|
| 39 |
+
os.remove(output_path)
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
import torch
|
| 43 |
+
except ImportError:
|
| 44 |
+
print("ERROR: PyTorch is not installed. Please install it first.")
|
| 45 |
+
sys.exit(1)
|
| 46 |
+
|
| 47 |
+
if args.device is None:
|
| 48 |
+
if torch.cuda.is_available():
|
| 49 |
+
args.device = "cuda:0"
|
| 50 |
+
elif hasattr(torch, "xpu") and torch.xpu.is_available():
|
| 51 |
+
args.device = "xpu"
|
| 52 |
+
elif hasattr(torch, "mps") and torch.mps.is_available():
|
| 53 |
+
args.device = "mps"
|
| 54 |
+
else:
|
| 55 |
+
args.device = "cpu"
|
| 56 |
+
args.fp16 = False # Disable FP16 on CPU
|
| 57 |
+
print("WARNING: Running on CPU may be slow.")
|
| 58 |
+
|
| 59 |
+
# TODO: Add CLI support for IndexTTS2.
|
| 60 |
+
from indextts.infer import IndexTTS
|
| 61 |
+
tts = IndexTTS(cfg_path=args.config, model_dir=args.model_dir, use_fp16=args.fp16, device=args.device)
|
| 62 |
+
tts.infer(audio_prompt=args.voice, text=args.text.strip(), output_path=output_path)
|
| 63 |
+
|
| 64 |
+
if __name__ == "__main__":
|
| 65 |
+
main()
|
indextts/gpt/__init__.py
ADDED
|
File without changes
|
indextts/gpt/conformer/__init__.py
ADDED
|
File without changes
|