RNNoise (models)
This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +7 -0
- models/WaveRNNModel/.gitattributes +1 -0
- models/WaveRNNModel/.gitignore +48 -0
- models/WaveRNNModel/LICENSE.txt +21 -0
- models/WaveRNNModel/README.md +102 -0
- models/WaveRNNModel/__init__.py +13 -0
- models/WaveRNNModel/assets/WaveRNN.png +0 -0
- models/WaveRNNModel/assets/tacotron_wavernn.png +3 -0
- models/WaveRNNModel/assets/training_viz.gif +3 -0
- models/WaveRNNModel/assets/wavernn_alt_model_hrz2.png +3 -0
- models/WaveRNNModel/data/dataset.pkl +3 -0
- models/WaveRNNModel/data/text_dict.pkl +3 -0
- models/WaveRNNModel/gen_tacotron.py +178 -0
- models/WaveRNNModel/gen_wavernn.py +142 -0
- models/WaveRNNModel/hparams.py +101 -0
- models/WaveRNNModel/loss_plot.py +70 -0
- models/WaveRNNModel/model_outputs/ljspeech_lsa_smooth_attention.tacotron.zip +3 -0
- models/WaveRNNModel/model_outputs/ljspeech_mol.wavernn.zip +3 -0
- models/WaveRNNModel/models/__init__.py +0 -0
- models/WaveRNNModel/models/deepmind_version.py +176 -0
- models/WaveRNNModel/models/fatchord_version.py +435 -0
- models/WaveRNNModel/models/tacotron.py +469 -0
- models/WaveRNNModel/notebooks/NB1 - Fit a Sine Wave.ipynb +0 -0
- models/WaveRNNModel/notebooks/NB2 - Fit a Short Sample.ipynb +0 -0
- models/WaveRNNModel/notebooks/NB3 - Fit a 30min Sample.ipynb +0 -0
- models/WaveRNNModel/notebooks/NB4a - Alternative Model (Preprocessing).ipynb +0 -0
- models/WaveRNNModel/notebooks/NB4b - Alternative Model (Training).ipynb +0 -0
- models/WaveRNNModel/notebooks/Pruning - Scratchpad.ipynb +0 -0
- models/WaveRNNModel/notebooks/__init__.py +0 -0
- models/WaveRNNModel/notebooks/models/wavernn.py +172 -0
- models/WaveRNNModel/notebooks/outputs/nb1/model_output.wav +0 -0
- models/WaveRNNModel/notebooks/outputs/nb2/3k_steps.wav +3 -0
- models/WaveRNNModel/notebooks/outputs/nb3/12k_steps.wav +3 -0
- models/WaveRNNModel/notebooks/utils/__init__.py +0 -0
- models/WaveRNNModel/notebooks/utils/display.py +40 -0
- models/WaveRNNModel/notebooks/utils/dsp.py +70 -0
- models/WaveRNNModel/preprocess.py +103 -0
- models/WaveRNNModel/quick_start.py +122 -0
- models/WaveRNNModel/quick_start/tts_weights/latest_weights.pyt +3 -0
- models/WaveRNNModel/quick_start/voc_weights/latest_weights.pyt +3 -0
- models/WaveRNNModel/requirements.txt +6 -0
- models/WaveRNNModel/sentences.txt +6 -0
- models/WaveRNNModel/source.txt +1 -0
- models/WaveRNNModel/train_tacotron.py +203 -0
- models/WaveRNNModel/train_wavernn.py +164 -0
- models/WaveRNNModel/utils/__init__.py +106 -0
- models/WaveRNNModel/utils/checkpoints.py +128 -0
- models/WaveRNNModel/utils/dataset.py +232 -0
- models/WaveRNNModel/utils/display.py +121 -0
- models/WaveRNNModel/utils/distribution.py +132 -0
.gitattributes
CHANGED
@@ -49,3 +49,10 @@ models/ailia-models/code/babble_15dB.wav filter=lfs diff=lfs merge=lfs -text
 models/ailia-models/code/denoised.wav filter=lfs diff=lfs merge=lfs -text
 models/rnnoise-wrapper/weights_5h_b_500k.hdf5 filter=lfs diff=lfs merge=lfs -text
 models/rnnoise-wrapper/weights_5h_ru_500k.hdf5 filter=lfs diff=lfs merge=lfs -text
+models/WaveRNNModel/assets/tacotron_wavernn.png filter=lfs diff=lfs merge=lfs -text
+models/WaveRNNModel/assets/training_viz.gif filter=lfs diff=lfs merge=lfs -text
+models/WaveRNNModel/assets/wavernn_alt_model_hrz2.png filter=lfs diff=lfs merge=lfs -text
+models/WaveRNNModel/notebooks/outputs/nb2/3k_steps.wav filter=lfs diff=lfs merge=lfs -text
+models/WaveRNNModel/notebooks/outputs/nb3/12k_steps.wav filter=lfs diff=lfs merge=lfs -text
+models/WaveRNNModel/quick_start/tts_weights/latest_weights.pyt filter=lfs diff=lfs merge=lfs -text
+models/WaveRNNModel/quick_start/voc_weights/latest_weights.pyt filter=lfs diff=lfs merge=lfs -text
models/WaveRNNModel/.gitattributes
ADDED
@@ -0,0 +1 @@
+*.ipynb linguist-language=Python
models/WaveRNNModel/.gitignore
ADDED
@@ -0,0 +1,48 @@
+# IDE files
+.idea
+.vscode
+
+# Mac files
+.DS_Store
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Jupyter Notebook
+.ipynb_checkpoints
models/WaveRNNModel/LICENSE.txt
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 fatchord (https://github.com/fatchord)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
models/WaveRNNModel/README.md
ADDED
@@ -0,0 +1,102 @@
+# WaveRNN
+
+##### (Update: Vanilla Tacotron One TTS system just implemented - more coming soon!)
+
+
+
+Pytorch implementation of Deepmind's WaveRNN model from [Efficient Neural Audio Synthesis](https://arxiv.org/abs/1802.08435v1)
+
+# Installation
+
+Ensure you have:
+
+* Python >= 3.6
+* [Pytorch 1 with CUDA](https://pytorch.org/)
+
+Then install the rest with pip:
+
+> pip install -r requirements.txt
+
+# How to Use
+
+### Quick Start
+
+If you want to use TTS functionality immediately you can simply use:
+
+> python quick_start.py
+
+This will generate everything in the default sentences.txt file and output to a new 'quick_start' folder, where you can play back the wav files and take a look at the attention plots.
+
+You can also use that script to generate custom tts sentences and/or use '-u' to generate unbatched (better audio quality):
+
+> python quick_start.py -u --input_text "What will happen if I run this command?"
+
+
+### Training your own Models
+
+
+Download the [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) Dataset.
+
+Edit **hparams.py**, point **wav_path** to your dataset and run:
+
+> python preprocess.py
+
+or use preprocess.py --path to point directly to the dataset
+___
+
+Here's my recommendation on what order to run things:
+
+1 - Train Tacotron with:
+
+> python train_tacotron.py
+
+2 - You can let that finish training, or at any point you can use:
+
+> python train_tacotron.py --force_gta
+
+this will force Tacotron to create a GTA dataset even if it hasn't finished training.
+
+3 - Train WaveRNN with:
+
+> python train_wavernn.py --gta
+
+NB: You can always just run train_wavernn.py without --gta if you're not interested in TTS.
+
+4 - Generate Sentences with both models using:
+
+> python gen_tacotron.py wavernn
+
+this will generate default sentences. If you want to generate custom sentences you can use
+
+> python gen_tacotron.py --input_text "this is whatever you want it to be" wavernn
+
+And finally, you can always use --help on any of those scripts to see what options are available :)
+
+
+
+# Samples
+
+[Can be found here.](https://fatchord.github.io/model_outputs/)
+
+# Pretrained Models
+
+Currently there are two pretrained models available in the /pretrained/ folder:
+
+Both are trained on LJSpeech:
+
+* WaveRNN (Mixture of Logistics output) trained to 800k steps
+* Tacotron trained to 180k steps
+
+____
+
+### References
+
+* [Efficient Neural Audio Synthesis](https://arxiv.org/abs/1802.08435v1)
+* [Tacotron: Towards End-to-End Speech Synthesis](https://arxiv.org/abs/1703.10135)
+* [Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions](https://arxiv.org/abs/1712.05884)
+
+### Acknowledgements
+
+* [https://github.com/keithito/tacotron](https://github.com/keithito/tacotron)
+* [https://github.com/r9y9/wavenet_vocoder](https://github.com/r9y9/wavenet_vocoder)
+* Special thanks to github users [G-Wang](https://github.com/G-Wang), [geneing](https://github.com/geneing) & [erogol](https://github.com/erogol)
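Note: gen_tacotron.py (added below in this commit) wraps the CLI above in a callable, gen_tacotron_from_inputtext(args_list=None), which forwards args_list straight to argparse and returns the path of the last generated wav. A sketch of driving it from Python rather than the shell, assuming the working directory is the WaveRNNModel root with hparams and weights in place (global flags go before the 'wavernn' subcommand, vocoder flags after it):

from gen_tacotron import gen_tacotron_from_inputtext

# Equivalent to: python gen_tacotron.py --input_text "..." wavernn -u
wav_path = gen_tacotron_from_inputtext(
    ['--input_text', 'What will happen if I run this command?', 'wavernn', '-u']
)
print('Generated:', wav_path)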
models/WaveRNNModel/__init__.py
ADDED
@@ -0,0 +1,13 @@
+import sys
+import os
+from pathlib import Path
+
+# Get the absolute path of this package (i.e. the WaveRNN_master directory)
+package_dir = Path(__file__).resolve().parent
+
+# Add that path to sys.path so it becomes a root directory for module searches
+if str(package_dir) not in sys.path:
+    sys.path.insert(0, str(package_dir))
+
+# Also set the PYTHONPATH environment variable (optional, improves compatibility)
+os.environ["PYTHONPATH"] = str(package_dir) + os.pathsep + os.environ.get("PYTHONPATH", "")
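This shim exists so the repo's internal absolute imports (from models.fatchord_version import ..., from utils import ...) keep resolving when WaveRNNModel is vendored as a sub-package. A minimal sketch of the intended effect, assuming the repo root is on sys.path so the package imports as models.WaveRNNModel:

# Hypothetical usage sketch: importing the vendored package runs its
# __init__.py, which prepends the package directory to sys.path.
import models.WaveRNNModel  # noqa: F401  (side effect: path shim)

# After that import, the repo-internal absolute imports used by the
# scripts in this commit resolve without installing anything:
from models.fatchord_version import WaveRNN  # found via the injected path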
models/WaveRNNModel/assets/WaveRNN.png
ADDED
models/WaveRNNModel/assets/tacotron_wavernn.png
ADDED (Git LFS)
models/WaveRNNModel/assets/training_viz.gif
ADDED (Git LFS)
models/WaveRNNModel/assets/wavernn_alt_model_hrz2.png
ADDED (Git LFS)
models/WaveRNNModel/data/dataset.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5f323c9bfdfcd5709ad210f538851303c62f63e4285acaee1af791d7da671d88
+size 234790

models/WaveRNNModel/data/text_dict.pkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3a9c7752b430df10a503697b4c563a7ce82cbf732ca211dfeed5d0bffae6e5a6
+size 1531658
models/WaveRNNModel/gen_tacotron.py
ADDED
@@ -0,0 +1,178 @@
+import torch
+from models.fatchord_version import WaveRNN
+from utils import hparams as hp
+from utils.text.symbols import symbols
+from utils.paths import Paths
+from models.tacotron import Tacotron
+import argparse
+from utils.text import text_to_sequence
+from utils.display import save_attention, simple_table
+from utils.dsp import reconstruct_waveform, save_wav
+import numpy as np
+
+def gen_tacotron_from_inputtext(args_list=None):
+    # Parse Arguments
+    parser = argparse.ArgumentParser(description='TTS Generator')
+    parser.add_argument('--input_text', '-i', type=str, help='[string] Type in something here and TTS will generate it!')
+    parser.add_argument('--tts_weights', type=str, help='[string/path] Load in different Tacotron weights')
+    parser.add_argument('--save_attention', '-a', dest='save_attn', action='store_true', help='Save Attention Plots')
+    parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
+    parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
+
+    parser.set_defaults(input_text=None)
+    parser.set_defaults(weights_path=None)
+
+    # name of subcommand goes to args.vocoder
+    subparsers = parser.add_subparsers(required=True, dest='vocoder')
+
+    wr_parser = subparsers.add_parser('wavernn', aliases=['wr'])
+    wr_parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation')
+    wr_parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slow Unbatched Generation')
+    wr_parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
+    wr_parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index')
+    wr_parser.add_argument('--voc_weights', type=str, help='[string/path] Load in different WaveRNN weights')
+    wr_parser.set_defaults(batched=None)
+
+    gl_parser = subparsers.add_parser('griffinlim', aliases=['gl'])
+    gl_parser.add_argument('--iters', type=int, default=32, help='[int] number of griffinlim iterations')
+
+    args = parser.parse_args(args=args_list)
+
+    if args.vocoder in ['griffinlim', 'gl']:
+        args.vocoder = 'griffinlim'
+    elif args.vocoder in ['wavernn', 'wr']:
+        args.vocoder = 'wavernn'
+    else:
+        raise argparse.ArgumentError('Must provide a valid vocoder type!')
+
+    if not hp.is_configured():
+        print("args.hp_file:", args.hp_file)
+        hp.configure(args.hp_file)  # Load hparams from file
+    # set defaults for any arguments that depend on hparams
+    if args.vocoder == 'wavernn':
+        if args.target is None:
+            args.target = hp.voc_target
+        if args.overlap is None:
+            args.overlap = hp.voc_overlap
+        if args.batched is None:
+            args.batched = hp.voc_gen_batched
+
+    batched = args.batched
+    target = args.target
+    overlap = args.overlap
+
+    input_text = args.input_text
+    tts_weights = args.tts_weights
+    save_attn = args.save_attn
+
+    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
+
+    if not args.force_cpu and torch.cuda.is_available():
+        device = torch.device('cuda')
+    else:
+        device = torch.device('cpu')
+    print('Using device:', device)
+
+    if args.vocoder == 'wavernn':
+        print('\nInitialising WaveRNN Model...\n')
+        # Instantiate WaveRNN Model
+        voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
+                            fc_dims=hp.voc_fc_dims,
+                            bits=hp.bits,
+                            pad=hp.voc_pad,
+                            upsample_factors=hp.voc_upsample_factors,
+                            feat_dims=hp.num_mels,
+                            compute_dims=hp.voc_compute_dims,
+                            res_out_dims=hp.voc_res_out_dims,
+                            res_blocks=hp.voc_res_blocks,
+                            hop_length=hp.hop_length,
+                            sample_rate=hp.sample_rate,
+                            mode=hp.voc_mode).to(device)
+
+        voc_load_path = args.voc_weights if args.voc_weights else paths.voc_latest_weights
+        voc_model.load(voc_load_path)
+
+    print('\nInitialising Tacotron Model...\n')
+
+    # Instantiate Tacotron Model
+    tts_model = Tacotron(embed_dims=hp.tts_embed_dims,
+                         num_chars=len(symbols),
+                         encoder_dims=hp.tts_encoder_dims,
+                         decoder_dims=hp.tts_decoder_dims,
+                         n_mels=hp.num_mels,
+                         fft_bins=hp.num_mels,
+                         postnet_dims=hp.tts_postnet_dims,
+                         encoder_K=hp.tts_encoder_K,
+                         lstm_dims=hp.tts_lstm_dims,
+                         postnet_K=hp.tts_postnet_K,
+                         num_highways=hp.tts_num_highways,
+                         dropout=hp.tts_dropout,
+                         stop_threshold=hp.tts_stop_threshold).to(device)
+
+    tts_load_path = tts_weights if tts_weights else paths.tts_latest_weights
+    tts_model.load(tts_load_path)
+
+    if input_text:
+        inputs = [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)]
+    else:
+        with open('sentences.txt') as f:
+            inputs = [text_to_sequence(l.strip(), hp.tts_cleaner_names) for l in f]
+
+    if args.vocoder == 'wavernn':
+        voc_k = voc_model.get_step() // 1000
+        tts_k = tts_model.get_step() // 1000
+
+        simple_table([('Tacotron', str(tts_k) + 'k'),
+                      ('r', tts_model.r),
+                      ('Vocoder Type', 'WaveRNN'),
+                      ('WaveRNN', str(voc_k) + 'k'),
+                      ('Generation Mode', 'Batched' if batched else 'Unbatched'),
+                      ('Target Samples', target if batched else 'N/A'),
+                      ('Overlap Samples', overlap if batched else 'N/A')])
+
+    elif args.vocoder == 'griffinlim':
+        tts_k = tts_model.get_step() // 1000
+        simple_table([('Tacotron', str(tts_k) + 'k'),
+                      ('r', tts_model.r),
+                      ('Vocoder Type', 'Griffin-Lim'),
+                      ('GL Iters', args.iters)])
+
+    for i, x in enumerate(inputs, 1):
+
+        print(f'\n| Generating {i}/{len(inputs)}')
+        _, m, attention = tts_model.generate(x)
+        # Fix mel spectrogram scaling to be from 0 to 1
+        m = (m + 4) / 8
+        np.clip(m, 0, 1, out=m)
+
+        if args.vocoder == 'griffinlim':
+            v_type = args.vocoder
+        elif args.vocoder == 'wavernn' and args.batched:
+            v_type = 'wavernn_batched'
+        else:
+            v_type = 'wavernn_unbatched'
+
+        if input_text:
+            print("path:", paths.tts_output)
+            save_path = paths.tts_output/f'__input_{input_text[:10]}_{v_type}_{tts_k}k.wav'
+        else:
+            print("path:", paths.tts_output)
+            save_path = paths.tts_output/f'{i}_{v_type}_{tts_k}k.wav'
+
+        if save_attn: save_attention(attention, save_path)
+
+        if args.vocoder == 'wavernn':
+            m = torch.tensor(m).unsqueeze(0)
+            voc_model.generate(m, save_path, batched, hp.voc_target, hp.voc_overlap, hp.mu_law)
+        elif args.vocoder == 'griffinlim':
+            wav = reconstruct_waveform(m, n_iter=args.iters)
+            save_wav(wav, save_path)
+
+    print('\n\nDone.\n')
+    return save_path
+
+
+
+if __name__ == "__main__":
+
+    gen_tacotron_from_inputtext()
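One step worth flagging in the script above: Tacotron's decoder emits mels in roughly [-4, 4] (see the tts_stop_threshold comment in hparams.py), while the vocoder expects values in [0, 1], hence m = (m + 4) / 8. A worked check of that remapping:

import numpy as np

# Example frame values at the edges of Tacotron's output range
m = np.array([-4.0, -2.0, 0.0, 2.0, 4.0])
m = (m + 4) / 8            # linear map [-4, 4] -> [0, 1]
np.clip(m, 0, 1, out=m)    # clamp any overshoot in place
print(m)                   # [0.   0.25 0.5  0.75 1.  ]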
models/WaveRNNModel/gen_wavernn.py
ADDED
@@ -0,0 +1,142 @@
+from utils.dataset import get_vocoder_datasets
+from utils.dsp import *
+from models.fatchord_version import WaveRNN
+from utils.paths import Paths
+from utils.display import simple_table
+import torch
+import argparse
+from pathlib import Path
+
+
+def gen_testset(model: WaveRNN, test_set, samples, batched, target, overlap, save_path: Path):
+
+    k = model.get_step() // 1000
+
+    for i, (m, x) in enumerate(test_set, 1):
+
+        if i > samples: break
+
+        print('\n| Generating: %i/%i' % (i, samples))
+
+        x = x[0].numpy()
+
+        bits = 16 if hp.voc_mode == 'MOL' else hp.bits
+
+        if hp.mu_law and hp.voc_mode != 'MOL':
+            x = decode_mu_law(x, 2**bits, from_labels=True)
+        else:
+            x = label_2_float(x, bits)
+
+        save_wav(x, save_path/f'{k}k_steps_{i}_target.wav')
+
+        batch_str = f'gen_batched_target{target}_overlap{overlap}' if batched else 'gen_NOT_BATCHED'
+        save_str = str(save_path/f'{k}k_steps_{i}_{batch_str}.wav')
+
+        _ = model.generate(m, save_str, batched, target, overlap, hp.mu_law)
+
+
+def gen_from_file(model: WaveRNN, load_path: Path, save_path: Path, batched, target, overlap):
+
+    k = model.get_step() // 1000
+    file_name = load_path.stem
+
+    suffix = load_path.suffix
+    if suffix == ".wav":
+        wav = load_wav(load_path)
+        save_wav(wav, save_path/f'__{file_name}__{k}k_steps_target.wav')
+        mel = melspectrogram(wav)
+    elif suffix == ".npy":
+        mel = np.load(load_path)
+        if mel.ndim != 2 or mel.shape[0] != hp.num_mels:
+            raise ValueError(f'Expected a numpy array shaped (n_mels, n_hops), but got {mel.shape}!')
+        _max = np.max(mel)
+        _min = np.min(mel)
+        if _max >= 1.01 or _min <= -0.01:
+            raise ValueError(f'Expected spectrogram range in [0,1] but was instead [{_min}, {_max}]')
+    else:
+        raise ValueError(f"Expected an extension of .wav or .npy, but got {suffix}!")
+
+
+    mel = torch.tensor(mel).unsqueeze(0)
+
+    batch_str = f'gen_batched_target{target}_overlap{overlap}' if batched else 'gen_NOT_BATCHED'
+    save_str = save_path/f'__{file_name}__{k}k_steps_{batch_str}.wav'
+
+    _ = model.generate(mel, save_str, batched, target, overlap, hp.mu_law)
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser(description='Generate WaveRNN Samples')
+    parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation')
+    parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slow Unbatched Generation')
+    parser.add_argument('--samples', '-s', type=int, help='[int] number of utterances to generate')
+    parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index')
+    parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
+    parser.add_argument('--file', '-f', type=str, help='[string/path] for testing a wav outside dataset')
+    parser.add_argument('--voc_weights', '-w', type=str, help='[string/path] Load in different WaveRNN weights')
+    parser.add_argument('--gta', '-g', dest='gta', action='store_true', help='Generate from GTA testset')
+    parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
+    parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
+
+    parser.set_defaults(batched=None)
+
+    args = parser.parse_args()
+
+    hp.configure(args.hp_file)  # Load hparams from file
+    # set defaults for any arguments that depend on hparams
+    if args.target is None:
+        args.target = hp.voc_target
+    if args.overlap is None:
+        args.overlap = hp.voc_overlap
+    if args.batched is None:
+        args.batched = hp.voc_gen_batched
+    if args.samples is None:
+        args.samples = hp.voc_gen_at_checkpoint
+
+    batched = args.batched
+    samples = args.samples
+    target = args.target
+    overlap = args.overlap
+    file = args.file
+    gta = args.gta
+
+    if not args.force_cpu and torch.cuda.is_available():
+        device = torch.device('cuda')
+    else:
+        device = torch.device('cpu')
+    print('Using device:', device)
+
+    print('\nInitialising Model...\n')
+
+    model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
+                    fc_dims=hp.voc_fc_dims,
+                    bits=hp.bits,
+                    pad=hp.voc_pad,
+                    upsample_factors=hp.voc_upsample_factors,
+                    feat_dims=hp.num_mels,
+                    compute_dims=hp.voc_compute_dims,
+                    res_out_dims=hp.voc_res_out_dims,
+                    res_blocks=hp.voc_res_blocks,
+                    hop_length=hp.hop_length,
+                    sample_rate=hp.sample_rate,
+                    mode=hp.voc_mode).to(device)
+
+    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
+
+    voc_weights = args.voc_weights if args.voc_weights else paths.voc_latest_weights
+
+    model.load(voc_weights)
+
+    simple_table([('Generation Mode', 'Batched' if batched else 'Unbatched'),
+                  ('Target Samples', target if batched else 'N/A'),
+                  ('Overlap Samples', overlap if batched else 'N/A')])
+
+    if file:
+        file = Path(file).expanduser()
+        gen_from_file(model, file, paths.voc_output, batched, target, overlap)
+    else:
+        _, test_set = get_vocoder_datasets(paths.data, 1, gta)
+        gen_testset(model, test_set, samples, batched, target, overlap, paths.voc_output)
+
+    print('\n\nExiting...\n')
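gen_from_file accepts a precomputed .npy mel but validates shape (n_mels, n_hops) and a value range of roughly [0, 1]. A sketch of producing a compatible array, assuming the usual dB normalisation against min_level_db = -100 from hparams.py (normalise_db here is a hypothetical helper, not part of this commit):

import numpy as np

def normalise_db(mel_db: np.ndarray, min_level_db: float = -100.0) -> np.ndarray:
    # Map [min_level_db, 0] dB into [0, 1], clipping anything outside.
    return np.clip((mel_db - min_level_db) / -min_level_db, 0.0, 1.0)

# Example: a fake (80, 200) dB mel in [-100, 0] passes gen_from_file's checks
mel_db = np.random.uniform(-100.0, 0.0, size=(80, 200)).astype(np.float32)
mel = normalise_db(mel_db)
assert mel.ndim == 2 and mel.shape[0] == 80
assert mel.min() > -0.01 and mel.max() < 1.01
np.save('example_mel.npy', mel)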
models/WaveRNNModel/hparams.py
ADDED
@@ -0,0 +1,101 @@
+
+# CONFIG -----------------------------------------------------------------------------------------------------------#
+
+# Here are the input and output data paths (Note: you can override wav_path in preprocess.py)
+wav_path = 'E:\\智能语音处理系统\\Noise-suppression-and-speech-recognition-systems-master\\WaveRNNModel\\data\\LJSpeech-1.1\\wavs'
+data_path = 'E:\\智能语音处理系统\\Noise-suppression-and-speech-recognition-systems-master\\WaveRNNModel\\data'
+
+# model ids are separate - that way you can use a new tts with an old wavernn and vice versa
+# NB: expect undefined behaviour if models were trained on different DSP settings
+voc_model_id = 'ljspeech_mol'
+tts_model_id = 'ljspeech_lsa_smooth_attention'
+
+# set this to True if you are only interested in WaveRNN
+ignore_tts = False
+
+
+# DSP --------------------------------------------------------------------------------------------------------------#
+
+# Settings for all models
+sample_rate = 22050
+n_fft = 2048
+fft_bins = n_fft // 2 + 1
+num_mels = 80
+hop_length = 275                    # 12.5ms - in line with Tacotron 2 paper
+win_length = 1100                   # 50ms - same reason as above
+fmin = 40
+min_level_db = -100
+ref_level_db = 20
+bits = 9                            # bit depth of signal
+mu_law = True                       # Recommended to suppress noise if using raw bits in hp.voc_mode below
+peak_norm = False                   # Normalise to the peak of each wav file
+
+
+# WAVERNN / VOCODER ------------------------------------------------------------------------------------------------#
+
+
+# Model Hparams
+voc_mode = 'MOL'                    # either 'RAW' (softmax on raw bits) or 'MOL' (sample from mixture of logistics)
+voc_upsample_factors = (5, 5, 11)   # NB - this needs to correctly factorise hop_length
+voc_rnn_dims = 512
+voc_fc_dims = 512
+voc_compute_dims = 128
+voc_res_out_dims = 128
+voc_res_blocks = 10
+
+# Training
+voc_batch_size = 32
+voc_lr = 1e-4
+voc_checkpoint_every = 25_000
+voc_gen_at_checkpoint = 5           # number of samples to generate at each checkpoint
+voc_total_steps = 1_000_000         # Total number of training steps
+voc_test_samples = 50               # How many unseen samples to put aside for testing
+voc_pad = 2                         # this will pad the input so that the resnet can 'see' wider than input length
+voc_seq_len = hop_length * 5        # must be a multiple of hop_length
+voc_clip_grad_norm = 4              # set to None if no gradient clipping needed
+
+# Generating / Synthesizing
+voc_gen_batched = True              # very fast (realtime+) single utterance batched generation
+voc_target = 11_000                 # target number of samples to be generated in each batch entry
+voc_overlap = 550                   # number of samples for crossfading between batches
+
+
+# TACOTRON/TTS -----------------------------------------------------------------------------------------------------#
+
+
+# Model Hparams
+tts_embed_dims = 256                # embedding dimension for the graphemes/phoneme inputs
+tts_encoder_dims = 128
+tts_decoder_dims = 256
+tts_postnet_dims = 128
+tts_encoder_K = 16
+tts_lstm_dims = 512
+tts_postnet_K = 8
+tts_num_highways = 4
+tts_dropout = 0.5
+tts_cleaner_names = ['english_cleaners']
+tts_stop_threshold = -3.4           # Value below which audio generation ends.
+                                    # For example, for a range of [-4, 4], this
+                                    # will terminate the sequence at the first
+                                    # frame that has all values < -3.4
+
+# Training
+
+#tts_schedule = [(7, 1e-3,  10_000, 32),   # progressive training schedule
+#                (5, 1e-4, 100_000, 32),   # (r, lr, step, batch_size)
+#                (2, 1e-4, 180_000, 16),
+#                (2, 1e-4, 350_000, 8)]
+tts_schedule = [(7, 1e-3, 10_000, 32)]     # progressive training schedule
+                #(5, 1e-4, 100_000, 64),   # (r, lr, step, batch_size)
+                #(2, 1e-4, 180_000, 64),
+                #(2, 1e-4, 350_000, 64)]
+
+tts_max_mel_len = 1250              # if you have a couple of extremely long spectrograms you might want to use this
+tts_bin_lengths = True              # bins the spectrogram lengths before sampling in data loader - speeds up training
+tts_clip_grad_norm = 1.0            # clips the gradient norm to prevent explosion - set to None if not needed
+tts_checkpoint_every = 2_000        # checkpoints the model every X steps
+# TODO: tts_phoneme_prob = 0.0      # [0 <-> 1] probability for feeding model phonemes vrs graphemes
+
+
+# ------------------------------------------------------------------------------------------------------------------#
+
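Two arithmetic constraints buried in the comments above are easy to break when editing these values; a quick sanity check restating them:

import numpy as np

hop_length = 275                  # 12.5 ms at 22050 Hz
voc_upsample_factors = (5, 5, 11)
voc_seq_len = hop_length * 5

# The upsample network stretches mel frames by the product of the factors,
# so that product must equal hop_length exactly: 5 * 5 * 11 = 275.
assert int(np.prod(voc_upsample_factors)) == hop_length

# Training sequences are sliced on frame boundaries, so voc_seq_len must be
# a whole number of hops.
assert voc_seq_len % hop_length == 0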
models/WaveRNNModel/loss_plot.py
ADDED
@@ -0,0 +1,70 @@
+import matplotlib.pyplot as plt
+import re
+import numpy as np
+
+# Set a font that can render Chinese characters
+plt.rcParams['font.sans-serif'] = ['SimHei']  # Simplified Chinese (adjust for your system)
+plt.rcParams['axes.unicode_minus'] = False    # fix minus-sign rendering
+
+# Read log data from a txt file
+def parse_log_file(file_path):
+    epochs = []
+    losses = []
+
+    with open(file_path, 'r', encoding='utf-8') as f:
+        for line in f:
+            # Use a regex to match valid lines
+            match = re.search(
+                r'Epoch:\s+(\d+)/*.*Loss:\s+(\d+\.\d+)',
+                line.strip()
+            )
+            if match:
+                epoch = int(match.group(1))
+                loss = float(match.group(2))
+                epochs.append(epoch)
+                losses.append(loss)
+
+    return epochs, losses
+
+# File path
+log_file = "E:\\智能语音处理系统\\Noise-suppression-and-speech-recognition-systems-master\\WaveRNNModel\\checkpoints\\ljspeech_lsa_smooth_attention.tacotron\\log_test.txt"
+
+# Extract the data
+try:
+    epochs_read, losses = parse_log_file(log_file)
+    print(epochs_read)
+    epochs = np.arange(len(epochs_read))
+    print(epochs)
+except FileNotFoundError:
+    print(f"错误:文件 {log_file} 不存在,请检查路径!")
+    exit()
+except Exception as e:
+    print(f"解析文件时出错: {str(e)}")
+    exit()
+
+# Plot the curve
+plt.figure(figsize=(10, 6))
+plt.plot(epochs, losses, 'b-', linewidth=2, label='训练损失')
+
+# Polish the chart
+plt.title('训练损失随轮次变化曲线', fontsize=14)
+plt.xlabel('训练轮次 (Epoch)', fontsize=12)
+plt.ylabel('损失值 (Loss)', fontsize=12)
+#plt.xticks(range(1, len(epochs)))  # force a tick for every epoch
+plt.grid(True, linestyle='--', alpha=0.7)
+plt.legend()
+
+# Annotate the minimum loss
+min_loss = min(losses)
+min_idx = losses.index(min_loss)
+plt.annotate(
+    f'最低损失: {min_loss:.3f}',
+    xy=(epochs[min_idx], min_loss),
+    xytext=(epochs[min_idx]-3, min_loss+0.1),
+    arrowprops=dict(arrowstyle='->', color='red'),
+    fontsize=10,
+    color='red'
+)
+
+plt.tight_layout()
+plt.show()
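parse_log_file only keeps lines that carry both an "Epoch: N/..." field and a "Loss: X.Y" field. The exact training-log format comes from the training scripts (not shown in this view); a hypothetical excerpt in that style, showing which lines the regex would keep:

import re

pattern = r'Epoch:\s+(\d+)/*.*Loss:\s+(\d+\.\d+)'

# Hypothetical log lines; only the first two match the pattern.
lines = [
    '| Epoch: 3/100 (250/250) | Loss: 0.712 | 2.1 steps/s | Step: 1k |',
    '| Epoch: 4/100 (250/250) | Loss: 0.645 | 2.2 steps/s | Step: 1k |',
    'Restoring from latest checkpoint...',
]
for line in lines:
    m = re.search(pattern, line.strip())
    if m:
        print(int(m.group(1)), float(m.group(2)))  # -> epoch, loss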
models/WaveRNNModel/model_outputs/ljspeech_lsa_smooth_attention.tacotron.zip
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed4c5cc52ae740080b0bcea155133d430f01dcb1a2d0097ff9aaef9ee698886a
+size 45040845

models/WaveRNNModel/model_outputs/ljspeech_mol.wavernn.zip
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:78a9cff91b58f6163f4cc9e9e878961829f07c7bf40e71778f0bf5447a4900fc
+size 15610590
models/WaveRNNModel/models/__init__.py
ADDED
(empty file)
models/WaveRNNModel/models/deepmind_version.py
ADDED
@@ -0,0 +1,176 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from utils.display import *
+from utils.dsp import *
+import numpy as np
+
+class WaveRNN(nn.Module):
+    def __init__(self, hidden_size=896, quantisation=256):
+        super(WaveRNN, self).__init__()
+
+        self.hidden_size = hidden_size
+        self.split_size = hidden_size // 2
+
+        # The main matmul
+        self.R = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
+
+        # Output fc layers
+        self.O1 = nn.Linear(self.split_size, self.split_size)
+        self.O2 = nn.Linear(self.split_size, quantisation)
+        self.O3 = nn.Linear(self.split_size, self.split_size)
+        self.O4 = nn.Linear(self.split_size, quantisation)
+
+        # Input fc layers
+        self.I_coarse = nn.Linear(2, 3 * self.split_size, bias=False)
+        self.I_fine = nn.Linear(3, 3 * self.split_size, bias=False)
+
+        # biases for the gates
+        self.bias_u = nn.Parameter(torch.zeros(self.hidden_size))
+        self.bias_r = nn.Parameter(torch.zeros(self.hidden_size))
+        self.bias_e = nn.Parameter(torch.zeros(self.hidden_size))
+
+        # display num params
+        self.num_params()
+
+
+    def forward(self, prev_y, prev_hidden, current_coarse):
+
+        # Main matmul - the projection is split 3 ways
+        R_hidden = self.R(prev_hidden)
+        R_u, R_r, R_e = torch.split(R_hidden, self.hidden_size, dim=1)
+
+        # Project the prev input
+        coarse_input_proj = self.I_coarse(prev_y)
+        I_coarse_u, I_coarse_r, I_coarse_e = \
+            torch.split(coarse_input_proj, self.split_size, dim=1)
+
+        # Project the prev input and current coarse sample
+        fine_input = torch.cat([prev_y, current_coarse], dim=1)
+        fine_input_proj = self.I_fine(fine_input)
+        I_fine_u, I_fine_r, I_fine_e = \
+            torch.split(fine_input_proj, self.split_size, dim=1)
+
+        # concatenate for the gates
+        I_u = torch.cat([I_coarse_u, I_fine_u], dim=1)
+        I_r = torch.cat([I_coarse_r, I_fine_r], dim=1)
+        I_e = torch.cat([I_coarse_e, I_fine_e], dim=1)
+
+        # Compute all gates for coarse and fine
+        u = F.sigmoid(R_u + I_u + self.bias_u)
+        r = F.sigmoid(R_r + I_r + self.bias_r)
+        e = F.tanh(r * R_e + I_e + self.bias_e)
+        hidden = u * prev_hidden + (1. - u) * e
+
+        # Split the hidden state
+        hidden_coarse, hidden_fine = torch.split(hidden, self.split_size, dim=1)
+
+        # Compute outputs
+        out_coarse = self.O2(F.relu(self.O1(hidden_coarse)))
+        out_fine = self.O4(F.relu(self.O3(hidden_fine)))
+
+        return out_coarse, out_fine, hidden
+
+
+    def generate(self, seq_len):
+        device = next(self.parameters()).device  # use same device as parameters
+
+        with torch.no_grad():
+
+            # First split up the biases for the gates
+            b_coarse_u, b_fine_u = torch.split(self.bias_u, self.split_size)
+            b_coarse_r, b_fine_r = torch.split(self.bias_r, self.split_size)
+            b_coarse_e, b_fine_e = torch.split(self.bias_e, self.split_size)
+
+            # Lists for the two output seqs
+            c_outputs, f_outputs = [], []
+
+            # Some initial inputs
+            out_coarse = torch.tensor([0], dtype=torch.long, device=device)
+            out_fine = torch.tensor([0], dtype=torch.long, device=device)
+
+            # We'll need a hidden state
+            hidden = self.get_initial_hidden()
+
+            # Need a clock for display
+            start = time.time()
+
+            # Loop for generation
+            for i in range(seq_len):
+
+                # Split into two hidden states
+                hidden_coarse, hidden_fine = \
+                    torch.split(hidden, self.split_size, dim=1)
+
+                # Scale and concat previous predictions
+                out_coarse = out_coarse.unsqueeze(0).float() / 127.5 - 1.
+                out_fine = out_fine.unsqueeze(0).float() / 127.5 - 1.
+                prev_outputs = torch.cat([out_coarse, out_fine], dim=1)
+
+                # Project input
+                coarse_input_proj = self.I_coarse(prev_outputs)
+                I_coarse_u, I_coarse_r, I_coarse_e = \
+                    torch.split(coarse_input_proj, self.split_size, dim=1)
+
+                # Project hidden state and split 6 ways
+                R_hidden = self.R(hidden)
+                R_coarse_u, R_fine_u, \
+                R_coarse_r, R_fine_r, \
+                R_coarse_e, R_fine_e = torch.split(R_hidden, self.split_size, dim=1)
+
+                # Compute the coarse gates
+                u = F.sigmoid(R_coarse_u + I_coarse_u + b_coarse_u)
+                r = F.sigmoid(R_coarse_r + I_coarse_r + b_coarse_r)
+                e = F.tanh(r * R_coarse_e + I_coarse_e + b_coarse_e)
+                hidden_coarse = u * hidden_coarse + (1. - u) * e
+
+                # Compute the coarse output
+                out_coarse = self.O2(F.relu(self.O1(hidden_coarse)))
+                posterior = F.softmax(out_coarse, dim=1)
+                distrib = torch.distributions.Categorical(posterior)
+                out_coarse = distrib.sample()
+                c_outputs.append(out_coarse)
+
+                # Project the [prev outputs and predicted coarse sample]
+                coarse_pred = out_coarse.float() / 127.5 - 1.
+                fine_input = torch.cat([prev_outputs, coarse_pred.unsqueeze(0)], dim=1)
+                fine_input_proj = self.I_fine(fine_input)
+                I_fine_u, I_fine_r, I_fine_e = \
+                    torch.split(fine_input_proj, self.split_size, dim=1)
+
+                # Compute the fine gates
+                u = F.sigmoid(R_fine_u + I_fine_u + b_fine_u)
+                r = F.sigmoid(R_fine_r + I_fine_r + b_fine_r)
+                e = F.tanh(r * R_fine_e + I_fine_e + b_fine_e)
+                hidden_fine = u * hidden_fine + (1. - u) * e
+
+                # Compute the fine output
+                out_fine = self.O4(F.relu(self.O3(hidden_fine)))
+                posterior = F.softmax(out_fine, dim=1)
+                distrib = torch.distributions.Categorical(posterior)
+                out_fine = distrib.sample()
+                f_outputs.append(out_fine)
+
+                # Put the hidden state back together
+                hidden = torch.cat([hidden_coarse, hidden_fine], dim=1)
+
+                # Display progress
+                speed = (i + 1) / (time.time() - start)
+                stream('Gen: %i/%i -- Speed: %i', (i + 1, seq_len, speed))
+
+        coarse = torch.stack(c_outputs).squeeze(1).cpu().data.numpy()
+        fine = torch.stack(f_outputs).squeeze(1).cpu().data.numpy()
+        output = combine_signal(coarse, fine)
+
+        return output, coarse, fine
+
+    def get_initial_hidden(self, batch_size=1):
+        device = next(self.parameters()).device  # use same device as parameters
+        return torch.zeros(batch_size, self.hidden_size, device=device)
+
+    def num_params(self, print_out=True):
+        parameters = filter(lambda p: p.requires_grad, self.parameters())
+        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
+        if print_out:
+            print('Trainable Parameters: %.3f million' % parameters)
+        return parameters
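The dual-softmax design above predicts each 16-bit sample as two 8-bit halves (quantisation=256 per half), which is why samples are rescaled by / 127.5 - 1 before being fed back into the network. A worked sketch of the split and recombination, assuming the straightforward coarse * 256 + fine encoding (combine_signal itself lives in utils/dsp.py, outside this view):

import numpy as np

# A 16-bit unsigned sample value in [0, 65535]
sample = 40000

# Split into an 8-bit coarse (high byte) and fine (low byte) part
coarse, fine = sample // 256, sample % 256   # -> 156, 64

# Each half is what the two 256-way softmaxes classify over;
# recombining recovers the original sample exactly.
assert coarse * 256 + fine == sample

# Before re-entering the network, each half is scaled to [-1, 1]:
scaled = np.array([coarse, fine]) / 127.5 - 1.0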
models/WaveRNNModel/models/fatchord_version.py
ADDED
@@ -0,0 +1,435 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from utils.distribution import sample_from_discretized_mix_logistic
+from utils.display import *
+from utils.dsp import *
+import os
+import numpy as np
+from pathlib import Path
+from typing import Union
+
+
+class ResBlock(nn.Module):
+    def __init__(self, dims):
+        super().__init__()
+        self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
+        self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
+        self.batch_norm1 = nn.BatchNorm1d(dims)
+        self.batch_norm2 = nn.BatchNorm1d(dims)
+
+    def forward(self, x):
+        residual = x
+        x = self.conv1(x)
+        x = self.batch_norm1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = self.batch_norm2(x)
+        return x + residual
+
+
+class MelResNet(nn.Module):
+    def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad):
+        super().__init__()
+        k_size = pad * 2 + 1
+        self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False)
+        self.batch_norm = nn.BatchNorm1d(compute_dims)
+        self.layers = nn.ModuleList()
+        for i in range(res_blocks):
+            self.layers.append(ResBlock(compute_dims))
+        self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
+
+    def forward(self, x):
+        x = self.conv_in(x)
+        x = self.batch_norm(x)
+        x = F.relu(x)
+        for f in self.layers: x = f(x)
+        x = self.conv_out(x)
+        return x
+
+
+class Stretch2d(nn.Module):
+    def __init__(self, x_scale, y_scale):
+        super().__init__()
+        self.x_scale = x_scale
+        self.y_scale = y_scale
+
+    def forward(self, x):
+        b, c, h, w = x.size()
+        x = x.unsqueeze(-1).unsqueeze(3)
+        x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
+        return x.view(b, c, h * self.y_scale, w * self.x_scale)
+
+
+class UpsampleNetwork(nn.Module):
+    def __init__(self, feat_dims, upsample_scales, compute_dims,
+                 res_blocks, res_out_dims, pad):
+        super().__init__()
+        total_scale = np.cumproduct(upsample_scales)[-1]
+        self.indent = pad * total_scale
+        self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad)
+        self.resnet_stretch = Stretch2d(total_scale, 1)
+        self.up_layers = nn.ModuleList()
+        for scale in upsample_scales:
+            k_size = (1, scale * 2 + 1)
+            padding = (0, scale)
+            stretch = Stretch2d(scale, 1)
+            conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
+            conv.weight.data.fill_(1. / k_size[1])
+            self.up_layers.append(stretch)
+            self.up_layers.append(conv)
+
+    def forward(self, m):
+        aux = self.resnet(m).unsqueeze(1)
+        aux = self.resnet_stretch(aux)
+        aux = aux.squeeze(1)
+        m = m.unsqueeze(1)
+        for f in self.up_layers: m = f(m)
+        m = m.squeeze(1)[:, :, self.indent:-self.indent]
+        return m.transpose(1, 2), aux.transpose(1, 2)
+
+
+class WaveRNN(nn.Module):
+    def __init__(self, rnn_dims, fc_dims, bits, pad, upsample_factors,
+                 feat_dims, compute_dims, res_out_dims, res_blocks,
+                 hop_length, sample_rate, mode='RAW'):
+        super().__init__()
+        self.mode = mode
+        self.pad = pad
+        if self.mode == 'RAW':
+            self.n_classes = 2 ** bits
+        elif self.mode == 'MOL':
+            self.n_classes = 30
+        else:
+            raise RuntimeError("Unknown model mode value - ", self.mode)
+
+        # List of rnns to call `flatten_parameters()` on
+        self._to_flatten = []
+
+        self.rnn_dims = rnn_dims
+        self.aux_dims = res_out_dims // 4
+        self.hop_length = hop_length
+        self.sample_rate = sample_rate
+
+        self.upsample = UpsampleNetwork(feat_dims, upsample_factors, compute_dims, res_blocks, res_out_dims, pad)
+        self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims)
+
+        self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
+        self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True)
+        self._to_flatten += [self.rnn1, self.rnn2]
+
+        self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
+        self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
+        self.fc3 = nn.Linear(fc_dims, self.n_classes)
+
+        self.register_buffer('step', torch.zeros(1, dtype=torch.long))
+        self.num_params()
+
+        # Avoid fragmentation of RNN parameters and associated warning
+        self._flatten_parameters()
+
+    def forward(self, x, mels):
+        device = next(self.parameters()).device  # use same device as parameters
+
+        # Although we `_flatten_parameters()` on init, when using DataParallel
+        # the model gets replicated, making it no longer guaranteed that the
+        # weights are contiguous in GPU memory. Hence, we must call it again
+        self._flatten_parameters()
+
+        self.step += 1
+        bsize = x.size(0)
+        h1 = torch.zeros(1, bsize, self.rnn_dims, device=device)
+        h2 = torch.zeros(1, bsize, self.rnn_dims, device=device)
+        mels, aux = self.upsample(mels)
+
+        aux_idx = [self.aux_dims * i for i in range(5)]
+        a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
+        a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
+        a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
+        a4 = aux[:, :, aux_idx[3]:aux_idx[4]]
+
+        x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
+        x = self.I(x)
+        res = x
+        x, _ = self.rnn1(x, h1)
+
+        x = x + res
+        res = x
+        x = torch.cat([x, a2], dim=2)
+        x, _ = self.rnn2(x, h2)
+
+        x = x + res
+        x = torch.cat([x, a3], dim=2)
+        x = F.relu(self.fc1(x))
+
+        x = torch.cat([x, a4], dim=2)
+        x = F.relu(self.fc2(x))
+        return self.fc3(x)
+
+    def generate(self, mels, save_path: Union[str, Path], batched, target, overlap, mu_law):
+        self.eval()
+
+        device = next(self.parameters()).device  # use same device as parameters
+
+        mu_law = mu_law if self.mode == 'RAW' else False
+
+        output = []
+        start = time.time()
+        rnn1 = self.get_gru_cell(self.rnn1)
+        rnn2 = self.get_gru_cell(self.rnn2)
+
+        with torch.no_grad():
+
+            mels = torch.as_tensor(mels, device=device)
+            wave_len = (mels.size(-1) - 1) * self.hop_length
+            mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side='both')
+            mels, aux = self.upsample(mels.transpose(1, 2))
+
+            if batched:
+                mels = self.fold_with_overlap(mels, target, overlap)
+                aux = self.fold_with_overlap(aux, target, overlap)
+
+            b_size, seq_len, _ = mels.size()
+
+            h1 = torch.zeros(b_size, self.rnn_dims, device=device)
+            h2 = torch.zeros(b_size, self.rnn_dims, device=device)
+            x = torch.zeros(b_size, 1, device=device)
+
+            d = self.aux_dims
+            aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)]
+
+            for i in range(seq_len):
+
+                m_t = mels[:, i, :]
+
+                a1_t, a2_t, a3_t, a4_t = \
+                    (a[:, i, :] for a in aux_split)
+
+                x = torch.cat([x, m_t, a1_t], dim=1)
+                x = self.I(x)
+                h1 = rnn1(x, h1)
+
+                x = x + h1
+                inp = torch.cat([x, a2_t], dim=1)
+                h2 = rnn2(inp, h2)
+
+                x = x + h2
+                x = torch.cat([x, a3_t], dim=1)
+                x = F.relu(self.fc1(x))
+
+                x = torch.cat([x, a4_t], dim=1)
+                x = F.relu(self.fc2(x))
+
+                logits = self.fc3(x)
+
+                if self.mode == 'MOL':
+                    sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
+                    output.append(sample.view(-1))
+                    # x = torch.FloatTensor([[sample]]).cuda()
+                    x = sample.transpose(0, 1)
+
+                elif self.mode == 'RAW':
+                    posterior = F.softmax(logits, dim=1)
+                    distrib = torch.distributions.Categorical(posterior)
+
+                    sample = 2 * distrib.sample().float() / (self.n_classes - 1.) - 1.
+                    output.append(sample)
+                    x = sample.unsqueeze(-1)
+                else:
+                    raise RuntimeError("Unknown model mode value - ", self.mode)
+
+                if i % 100 == 0: self.gen_display(i, seq_len, b_size, start)
+
+        output = torch.stack(output).transpose(0, 1)
+        output = output.cpu().numpy()
+        output = output.astype(np.float64)
+
+        if mu_law:
+            output = decode_mu_law(output, self.n_classes, False)
+
+        if batched:
+            output = self.xfade_and_unfold(output, target, overlap)
+        else:
+            output = output[0]
+
+        # Fade-out at the end to avoid signal cutting out suddenly
+        fade_out = np.linspace(1, 0, 20 * self.hop_length)
|
| 257 |
+
output = output[:wave_len]
|
| 258 |
+
output[-20 * self.hop_length:] *= fade_out
|
| 259 |
+
|
| 260 |
+
save_wav(output, save_path)
|
| 261 |
+
|
| 262 |
+
self.train()
|
| 263 |
+
|
| 264 |
+
return output
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def gen_display(self, i, seq_len, b_size, start):
|
| 268 |
+
gen_rate = (i + 1) / (time.time() - start) * b_size / 1000
|
| 269 |
+
pbar = progbar(i, seq_len)
|
| 270 |
+
msg = f'| {pbar} {i*b_size}/{seq_len*b_size} | Batch Size: {b_size} | Gen Rate: {gen_rate:.1f}kHz | '
|
| 271 |
+
stream(msg)
|
| 272 |
+
|
| 273 |
+
def get_gru_cell(self, gru):
|
| 274 |
+
gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
|
| 275 |
+
gru_cell.weight_hh.data = gru.weight_hh_l0.data
|
| 276 |
+
gru_cell.weight_ih.data = gru.weight_ih_l0.data
|
| 277 |
+
gru_cell.bias_hh.data = gru.bias_hh_l0.data
|
| 278 |
+
gru_cell.bias_ih.data = gru.bias_ih_l0.data
|
| 279 |
+
return gru_cell
|
| 280 |
+
|
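# Aside: get_gru_cell() below copies the layer-0 weights of the batch-first
# nn.GRU into an nn.GRUCell so generate() can step one sample at a time. A
# quick sanity check (dimensions illustrative) that the unrolled cell
# reproduces the full GRU:
import torch
import torch.nn as nn

gru = nn.GRU(8, 16, batch_first=True)
cell = nn.GRUCell(8, 16)
cell.weight_ih.data = gru.weight_ih_l0.data
cell.weight_hh.data = gru.weight_hh_l0.data
cell.bias_ih.data = gru.bias_ih_l0.data
cell.bias_hh.data = gru.bias_hh_l0.data

x = torch.randn(2, 5, 8)                        # (batch, time, features)
out, _ = gru(x)
h, steps = torch.zeros(2, 16), []
for t in range(5):
    h = cell(x[:, t], h)
    steps.append(h)
print(torch.allclose(out, torch.stack(steps, dim=1), atol=1e-6))  # True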
| 281 |
+
def pad_tensor(self, x, pad, side='both'):
|
| 282 |
+
# NB - this is just a quick method I need right now
|
| 283 |
+
# i.e., it won't generalise to other shapes/dims
|
| 284 |
+
b, t, c = x.size()
|
| 285 |
+
total = t + 2 * pad if side == 'both' else t + pad
|
| 286 |
+
padded = torch.zeros(b, total, c, device=x.device)
|
| 287 |
+
if side == 'before' or side == 'both':
|
| 288 |
+
padded[:, pad:pad + t, :] = x
|
| 289 |
+
elif side == 'after':
|
| 290 |
+
padded[:, :t, :] = x
|
| 291 |
+
return padded
|
| 292 |
+
|
| 293 |
+
def fold_with_overlap(self, x, target, overlap):
|
| 294 |
+
|
| 295 |
+
''' Fold the tensor with overlap for quick batched inference.
|
| 296 |
+
Overlap will be used for crossfading in xfade_and_unfold()
|
| 297 |
+
|
| 298 |
+
Args:
|
| 299 |
+
x (tensor) : Upsampled conditioning features.
|
| 300 |
+
shape=(1, timesteps, features)
|
| 301 |
+
target (int) : Target timesteps for each index of batch
|
| 302 |
+
overlap (int) : Timesteps for both xfade and rnn warmup
|
| 303 |
+
|
| 304 |
+
Return:
|
| 305 |
+
(tensor) : shape=(num_folds, target + 2 * overlap, features)
|
| 306 |
+
|
| 307 |
+
Details:
|
| 308 |
+
x = [[h1, h2, ... hn]]
|
| 309 |
+
|
| 310 |
+
Where each h is a vector of conditioning features
|
| 311 |
+
|
| 312 |
+
Eg: target=2, overlap=1 with x.size(1)=10
|
| 313 |
+
|
| 314 |
+
folded = [[h1, h2, h3, h4],
|
| 315 |
+
[h4, h5, h6, h7],
|
| 316 |
+
[h7, h8, h9, h10]]
|
| 317 |
+
'''
|
| 318 |
+
|
| 319 |
+
_, total_len, features = x.size()
|
| 320 |
+
|
| 321 |
+
# Calculate variables needed
|
| 322 |
+
num_folds = (total_len - overlap) // (target + overlap)
|
| 323 |
+
extended_len = num_folds * (overlap + target) + overlap
|
| 324 |
+
remaining = total_len - extended_len
|
| 325 |
+
|
| 326 |
+
# Pad if some time steps poking out
|
| 327 |
+
if remaining != 0:
|
| 328 |
+
num_folds += 1
|
| 329 |
+
padding = target + 2 * overlap - remaining
|
| 330 |
+
x = self.pad_tensor(x, padding, side='after')
|
| 331 |
+
|
| 332 |
+
folded = torch.zeros(num_folds, target + 2 * overlap, features, device=x.device)
|
| 333 |
+
|
| 334 |
+
# Get the values for the folded tensor
|
| 335 |
+
for i in range(num_folds):
|
| 336 |
+
start = i * (target + overlap)
|
| 337 |
+
end = start + target + 2 * overlap
|
| 338 |
+
folded[i] = x[:, start:end, :]
|
| 339 |
+
|
| 340 |
+
return folded
|
| 341 |
+
|
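# Aside: walking through the docstring example (target=2, overlap=1,
# x.size(1)=10) with plain arithmetic -- adjacent folds share exactly
# `overlap` frames:
total_len, target, overlap = 10, 2, 1
num_folds = (total_len - overlap) // (target + overlap)        # 3
assert num_folds * (overlap + target) + overlap == total_len   # nothing to pad
for i in range(num_folds):
    start = i * (target + overlap)
    print(list(range(start, start + target + 2 * overlap)))
# [0, 1, 2, 3] / [3, 4, 5, 6] / [6, 7, 8, 9], i.e. h1..h4, h4..h7, h7..h10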
| 342 |
+
def xfade_and_unfold(self, y, target, overlap):
|
| 343 |
+
|
| 344 |
+
''' Applies a crossfade and unfolds into a 1d array.
|
| 345 |
+
|
| 346 |
+
Args:
|
| 347 |
+
y (ndarry) : Batched sequences of audio samples
|
| 348 |
+
shape=(num_folds, target + 2 * overlap)
|
| 349 |
+
dtype=np.float64
|
| 350 |
+
overlap (int) : Timesteps for both xfade and rnn warmup
|
| 351 |
+
|
| 352 |
+
Return:
|
| 353 |
+
(ndarry) : audio samples in a 1d array
|
| 354 |
+
shape=(total_len)
|
| 355 |
+
dtype=np.float64
|
| 356 |
+
|
| 357 |
+
Details:
|
| 358 |
+
y = [[seq1],
|
| 359 |
+
[seq2],
|
| 360 |
+
[seq3]]
|
| 361 |
+
|
| 362 |
+
Apply a gain envelope at both ends of the sequences
|
| 363 |
+
|
| 364 |
+
y = [[seq1_in, seq1_target, seq1_out],
|
| 365 |
+
[seq2_in, seq2_target, seq2_out],
|
| 366 |
+
[seq3_in, seq3_target, seq3_out]]
|
| 367 |
+
|
| 368 |
+
Stagger and add up the groups of samples:
|
| 369 |
+
|
| 370 |
+
[seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...]
|
| 371 |
+
|
| 372 |
+
'''
|
| 373 |
+
|
| 374 |
+
num_folds, length = y.shape
|
| 375 |
+
target = length - 2 * overlap
|
| 376 |
+
total_len = num_folds * (target + overlap) + overlap
|
| 377 |
+
|
| 378 |
+
# Need some silence for the rnn warmup
|
| 379 |
+
silence_len = overlap // 2
|
| 380 |
+
fade_len = overlap - silence_len
|
| 381 |
+
silence = np.zeros((silence_len), dtype=np.float64)
|
| 382 |
+
linear = np.ones((silence_len), dtype=np.float64)
|
| 383 |
+
|
| 384 |
+
# Equal power crossfade
|
| 385 |
+
t = np.linspace(-1, 1, fade_len, dtype=np.float64)
|
| 386 |
+
fade_in = np.sqrt(0.5 * (1 + t))
|
| 387 |
+
fade_out = np.sqrt(0.5 * (1 - t))
|
| 388 |
+
|
| 389 |
+
# Concat the silence to the fades
|
| 390 |
+
fade_in = np.concatenate([silence, fade_in])
|
| 391 |
+
fade_out = np.concatenate([linear, fade_out])
|
| 392 |
+
|
| 393 |
+
# Apply the gain to the overlap samples
|
| 394 |
+
y[:, :overlap] *= fade_in
|
| 395 |
+
y[:, -overlap:] *= fade_out
|
| 396 |
+
|
| 397 |
+
unfolded = np.zeros((total_len), dtype=np.float64)
|
| 398 |
+
|
| 399 |
+
# Loop to add up all the samples
|
| 400 |
+
for i in range(num_folds):
|
| 401 |
+
start = i * (target + overlap)
|
| 402 |
+
end = start + target + 2 * overlap
|
| 403 |
+
unfolded[start:end] += y[i]
|
| 404 |
+
|
| 405 |
+
return unfolded
|
| 406 |
+
|
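# Aside: the fades above form an equal-power crossfade -- at every overlapping
# sample fade_in**2 + fade_out**2 == 1, so the summed signal keeps roughly
# constant energy across each seam. A quick check of the identity:
import numpy as np

t = np.linspace(-1, 1, 8, dtype=np.float64)
fade_in = np.sqrt(0.5 * (1 + t))
fade_out = np.sqrt(0.5 * (1 - t))
print(np.allclose(fade_in ** 2 + fade_out ** 2, 1.0))  # True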
| 407 |
+
def get_step(self):
|
| 408 |
+
return self.step.data.item()
|
| 409 |
+
|
| 410 |
+
def log(self, path, msg):
|
| 411 |
+
with open(path, 'a') as f:
|
| 412 |
+
print(msg, file=f)
|
| 413 |
+
|
| 414 |
+
def load(self, path: Union[str, Path]):
|
| 415 |
+
# Use device of model params as location for loaded state
|
| 416 |
+
device = next(self.parameters()).device
|
| 417 |
+
self.load_state_dict(torch.load(path, map_location=device), strict=False)
|
| 418 |
+
|
| 419 |
+
def save(self, path: Union[str, Path]):
|
| 420 |
+
# No optimizer argument because saving a model should not include data
|
| 421 |
+
# only relevant in the training process - it should only be properties
|
| 422 |
+
# of the model itself. Let caller take care of saving optimizer state.
|
| 423 |
+
torch.save(self.state_dict(), path)
|
| 424 |
+
|
| 425 |
+
def num_params(self, print_out=True):
|
| 426 |
+
parameters = filter(lambda p: p.requires_grad, self.parameters())
|
| 427 |
+
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
|
| 428 |
+
if print_out:
|
| 429 |
+
print('Trainable Parameters: %.3fM' % parameters)
|
| 430 |
+
return parameters
|
| 431 |
+
|
| 432 |
+
def _flatten_parameters(self):
|
| 433 |
+
"""Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
|
| 434 |
+
to improve efficiency and avoid PyTorch yelling at us."""
|
| 435 |
+
[m.flatten_parameters() for m in self._to_flatten]
|
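For orientation, a minimal instantiation of the WaveRNN class above. The dimensions below are illustrative placeholders rather than the project's hparams (quick_start.py, further down in this commit, shows the configuration used with the pretrained weights); note that the product of `upsample_factors` has to equal `hop_length` for the upsampled conditioning to line up with the audio:

    import torch

    model = WaveRNN(rnn_dims=512, fc_dims=512, bits=9, pad=2,
                    upsample_factors=(5, 5, 11),   # 5 * 5 * 11 == hop_length
                    feat_dims=80, compute_dims=128, res_out_dims=128,
                    res_blocks=10, hop_length=275, sample_rate=22050, mode='RAW')

    mel = torch.randn(1, 80, 40)                   # (batch, n_mels, frames)
    wav = model.generate(mel, 'sample.wav', batched=True,
                         target=11_000, overlap=550, mu_law=True)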
models/WaveRNNModel/models/tacotron.py
ADDED
|
@@ -0,0 +1,469 @@
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class HighwayNetwork(nn.Module):
|
| 11 |
+
def __init__(self, size):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.W1 = nn.Linear(size, size)
|
| 14 |
+
self.W2 = nn.Linear(size, size)
|
| 15 |
+
self.W1.bias.data.fill_(0.)
|
| 16 |
+
|
| 17 |
+
def forward(self, x):
|
| 18 |
+
x1 = self.W1(x)
|
| 19 |
+
x2 = self.W2(x)
|
| 20 |
+
g = torch.sigmoid(x2)
|
| 21 |
+
y = g * F.relu(x1) + (1. - g) * x
|
| 22 |
+
return y
|
| 23 |
+
|
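# Aside: the highway layer above computes y = g * relu(W1 x) + (1 - g) * x
# with gate g = sigmoid(W2 x), so each unit mixes a transformed copy of the
# input with the input itself. A toy check that a gate forced shut passes the
# input straight through (weights set by hand purely for illustration):
import torch

hw = HighwayNetwork(4)
hw.W2.weight.data.zero_()
hw.W2.bias.data.fill_(-20.)                 # g ~ 0  ->  pure pass-through
x = torch.randn(2, 4)
print(torch.allclose(hw(x), x, atol=1e-6))  # True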
| 24 |
+
|
| 25 |
+
class Encoder(nn.Module):
|
| 26 |
+
def __init__(self, embed_dims, num_chars, cbhg_channels, K, num_highways, dropout):
|
| 27 |
+
super().__init__()
|
| 28 |
+
self.embedding = nn.Embedding(num_chars, embed_dims)
|
| 29 |
+
self.pre_net = PreNet(embed_dims)
|
| 30 |
+
self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
|
| 31 |
+
proj_channels=[cbhg_channels, cbhg_channels],
|
| 32 |
+
num_highways=num_highways)
|
| 33 |
+
|
| 34 |
+
def forward(self, x):
|
| 35 |
+
x = self.embedding(x)
|
| 36 |
+
x = self.pre_net(x)
|
| 37 |
+
x.transpose_(1, 2)
|
| 38 |
+
x = self.cbhg(x)
|
| 39 |
+
return x
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class BatchNormConv(nn.Module):
|
| 43 |
+
def __init__(self, in_channels, out_channels, kernel, relu=True):
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
|
| 46 |
+
self.bnorm = nn.BatchNorm1d(out_channels)
|
| 47 |
+
self.relu = relu
|
| 48 |
+
|
| 49 |
+
def forward(self, x):
|
| 50 |
+
x = self.conv(x)
|
| 51 |
+
x = F.relu(x) if self.relu else x
|
| 52 |
+
return self.bnorm(x)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class CBHG(nn.Module):
|
| 56 |
+
def __init__(self, K, in_channels, channels, proj_channels, num_highways):
|
| 57 |
+
super().__init__()
|
| 58 |
+
|
| 59 |
+
# List of all rnns to call `flatten_parameters()` on
|
| 60 |
+
self._to_flatten = []
|
| 61 |
+
|
| 62 |
+
self.bank_kernels = [i for i in range(1, K + 1)]
|
| 63 |
+
self.conv1d_bank = nn.ModuleList()
|
| 64 |
+
for k in self.bank_kernels:
|
| 65 |
+
conv = BatchNormConv(in_channels, channels, k)
|
| 66 |
+
self.conv1d_bank.append(conv)
|
| 67 |
+
|
| 68 |
+
self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
|
| 69 |
+
|
| 70 |
+
self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
|
| 71 |
+
self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
|
| 72 |
+
|
| 73 |
+
# Fix the highway input if necessary
|
| 74 |
+
if proj_channels[-1] != channels:
|
| 75 |
+
self.highway_mismatch = True
|
| 76 |
+
self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
|
| 77 |
+
else:
|
| 78 |
+
self.highway_mismatch = False
|
| 79 |
+
|
| 80 |
+
self.highways = nn.ModuleList()
|
| 81 |
+
for i in range(num_highways):
|
| 82 |
+
hn = HighwayNetwork(channels)
|
| 83 |
+
self.highways.append(hn)
|
| 84 |
+
|
| 85 |
+
self.rnn = nn.GRU(channels, channels, batch_first=True, bidirectional=True)
|
| 86 |
+
self._to_flatten.append(self.rnn)
|
| 87 |
+
|
| 88 |
+
# Avoid fragmentation of RNN parameters and associated warning
|
| 89 |
+
self._flatten_parameters()
|
| 90 |
+
|
| 91 |
+
def forward(self, x):
|
| 92 |
+
# Although we `_flatten_parameters()` on init, when using DataParallel
|
| 93 |
+
# the model gets replicated, making it no longer guaranteed that the
|
| 94 |
+
# weights are contiguous in GPU memory. Hence, we must call it again
|
| 95 |
+
self._flatten_parameters()
|
| 96 |
+
|
| 97 |
+
# Save these for later
|
| 98 |
+
residual = x
|
| 99 |
+
seq_len = x.size(-1)
|
| 100 |
+
conv_bank = []
|
| 101 |
+
|
| 102 |
+
# Convolution Bank
|
| 103 |
+
for conv in self.conv1d_bank:
|
| 104 |
+
c = conv(x) # Convolution
|
| 105 |
+
conv_bank.append(c[:, :, :seq_len])
|
| 106 |
+
|
| 107 |
+
# Stack along the channel axis
|
| 108 |
+
conv_bank = torch.cat(conv_bank, dim=1)
|
| 109 |
+
|
| 110 |
+
# dump the last padding to fit residual
|
| 111 |
+
x = self.maxpool(conv_bank)[:, :, :seq_len]
|
| 112 |
+
|
| 113 |
+
# Conv1d projections
|
| 114 |
+
x = self.conv_project1(x)
|
| 115 |
+
x = self.conv_project2(x)
|
| 116 |
+
|
| 117 |
+
# Residual Connect
|
| 118 |
+
x = x + residual
|
| 119 |
+
|
| 120 |
+
# Through the highways
|
| 121 |
+
x = x.transpose(1, 2)
|
| 122 |
+
if self.highway_mismatch:
|
| 123 |
+
x = self.pre_highway(x)
|
| 124 |
+
for h in self.highways: x = h(x)
|
| 125 |
+
|
| 126 |
+
# And then the RNN
|
| 127 |
+
x, _ = self.rnn(x)
|
| 128 |
+
return x
|
| 129 |
+
|
| 130 |
+
def _flatten_parameters(self):
|
| 131 |
+
"""Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
|
| 132 |
+
to improve efficiency and avoid PyTorch yelling at us."""
|
| 133 |
+
[m.flatten_parameters() for m in self._to_flatten]
|
| 134 |
+
|
| 135 |
+
class PreNet(nn.Module):
|
| 136 |
+
def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
|
| 137 |
+
super().__init__()
|
| 138 |
+
self.fc1 = nn.Linear(in_dims, fc1_dims)
|
| 139 |
+
self.fc2 = nn.Linear(fc1_dims, fc2_dims)
|
| 140 |
+
self.p = dropout
|
| 141 |
+
|
| 142 |
+
def forward(self, x):
|
| 143 |
+
x = self.fc1(x)
|
| 144 |
+
x = F.relu(x)
|
| 145 |
+
x = F.dropout(x, self.p, training=self.training)
|
| 146 |
+
x = self.fc2(x)
|
| 147 |
+
x = F.relu(x)
|
| 148 |
+
x = F.dropout(x, self.p, training=self.training)
|
| 149 |
+
return x
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class Attention(nn.Module):
|
| 153 |
+
def __init__(self, attn_dims):
|
| 154 |
+
super().__init__()
|
| 155 |
+
self.W = nn.Linear(attn_dims, attn_dims, bias=False)
|
| 156 |
+
self.v = nn.Linear(attn_dims, 1, bias=False)
|
| 157 |
+
|
| 158 |
+
def forward(self, encoder_seq_proj, query, t):
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# Transform the query vector
|
| 162 |
+
query_proj = self.W(query).unsqueeze(1)
|
| 163 |
+
|
| 164 |
+
# Compute the scores
|
| 165 |
+
u = self.v(torch.tanh(encoder_seq_proj + query_proj))
|
| 166 |
+
scores = F.softmax(u, dim=1)
|
| 167 |
+
|
| 168 |
+
return scores.transpose(1, 2)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class LSA(nn.Module):
|
| 172 |
+
def __init__(self, attn_dim, kernel_size=31, filters=32):
|
| 173 |
+
super().__init__()
|
| 174 |
+
self.conv = nn.Conv1d(2, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=False)
|
| 175 |
+
self.L = nn.Linear(filters, attn_dim, bias=True)
|
| 176 |
+
self.W = nn.Linear(attn_dim, attn_dim, bias=True)
|
| 177 |
+
self.v = nn.Linear(attn_dim, 1, bias=False)
|
| 178 |
+
self.cumulative = None
|
| 179 |
+
self.attention = None
|
| 180 |
+
|
| 181 |
+
def init_attention(self, encoder_seq_proj):
|
| 182 |
+
device = next(self.parameters()).device # use same device as parameters
|
| 183 |
+
b, t, c = encoder_seq_proj.size()
|
| 184 |
+
self.cumulative = torch.zeros(b, t, device=device)
|
| 185 |
+
self.attention = torch.zeros(b, t, device=device)
|
| 186 |
+
|
| 187 |
+
def forward(self, encoder_seq_proj, query, t):
|
| 188 |
+
|
| 189 |
+
if t == 0: self.init_attention(encoder_seq_proj)
|
| 190 |
+
|
| 191 |
+
processed_query = self.W(query).unsqueeze(1)
|
| 192 |
+
|
| 193 |
+
location = torch.cat([self.cumulative.unsqueeze(1), self.attention.unsqueeze(1)], dim=1)
|
| 194 |
+
processed_loc = self.L(self.conv(location).transpose(1, 2))
|
| 195 |
+
|
| 196 |
+
u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
|
| 197 |
+
u = u.squeeze(-1)
|
| 198 |
+
|
| 199 |
+
# Smooth Attention
|
| 200 |
+
scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
|
| 201 |
+
# scores = F.softmax(u, dim=1)
|
| 202 |
+
self.attention = scores
|
| 203 |
+
self.cumulative += self.attention
|
| 204 |
+
|
| 205 |
+
return scores.unsqueeze(-1).transpose(1, 2)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class Decoder(nn.Module):
|
| 209 |
+
# Class variable because its value doesn't change between instances,
|
| 210 |
+
# yet it ought to be scoped to the class because it's a property of a Decoder
|
| 211 |
+
max_r = 20
|
| 212 |
+
def __init__(self, n_mels, decoder_dims, lstm_dims):
|
| 213 |
+
super().__init__()
|
| 214 |
+
self.register_buffer('r', torch.tensor(1, dtype=torch.int))
|
| 215 |
+
self.n_mels = n_mels
|
| 216 |
+
self.prenet = PreNet(n_mels)
|
| 217 |
+
self.attn_net = LSA(decoder_dims)
|
| 218 |
+
self.attn_rnn = nn.GRUCell(decoder_dims + decoder_dims // 2, decoder_dims)
|
| 219 |
+
self.rnn_input = nn.Linear(2 * decoder_dims, lstm_dims)
|
| 220 |
+
self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
|
| 221 |
+
self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
|
| 222 |
+
self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
|
| 223 |
+
|
| 224 |
+
def zoneout(self, prev, current, p=0.1):
|
| 225 |
+
device = next(self.parameters()).device # Use same device as parameters
|
| 226 |
+
mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
|
| 227 |
+
return prev * mask + current * (1 - mask)
|
| 228 |
+
|
| 229 |
+
def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
|
| 230 |
+
hidden_states, cell_states, context_vec, t):
|
| 231 |
+
|
| 232 |
+
# Need this for reshaping mels
|
| 233 |
+
batch_size = encoder_seq.size(0)
|
| 234 |
+
|
| 235 |
+
# Unpack the hidden and cell states
|
| 236 |
+
attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
|
| 237 |
+
rnn1_cell, rnn2_cell = cell_states
|
| 238 |
+
|
| 239 |
+
# PreNet for the Attention RNN
|
| 240 |
+
prenet_out = self.prenet(prenet_in)
|
| 241 |
+
|
| 242 |
+
# Compute the Attention RNN hidden state
|
| 243 |
+
attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
|
| 244 |
+
attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
|
| 245 |
+
|
| 246 |
+
# Compute the attention scores
|
| 247 |
+
scores = self.attn_net(encoder_seq_proj, attn_hidden, t)
|
| 248 |
+
|
| 249 |
+
# Dot product to create the context vector
|
| 250 |
+
context_vec = scores @ encoder_seq
|
| 251 |
+
context_vec = context_vec.squeeze(1)
|
| 252 |
+
|
| 253 |
+
# Concat Attention RNN output w. Context Vector & project
|
| 254 |
+
x = torch.cat([context_vec, attn_hidden], dim=1)
|
| 255 |
+
x = self.rnn_input(x)
|
| 256 |
+
|
| 257 |
+
# Compute first Residual RNN
|
| 258 |
+
rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
|
| 259 |
+
if self.training:
|
| 260 |
+
rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
|
| 261 |
+
else:
|
| 262 |
+
rnn1_hidden = rnn1_hidden_next
|
| 263 |
+
x = x + rnn1_hidden
|
| 264 |
+
|
| 265 |
+
# Compute second Residual RNN
|
| 266 |
+
rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
|
| 267 |
+
if self.training:
|
| 268 |
+
rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
|
| 269 |
+
else:
|
| 270 |
+
rnn2_hidden = rnn2_hidden_next
|
| 271 |
+
x = x + rnn2_hidden
|
| 272 |
+
|
| 273 |
+
# Project Mels
|
| 274 |
+
mels = self.mel_proj(x)
|
| 275 |
+
mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
|
| 276 |
+
hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
|
| 277 |
+
cell_states = (rnn1_cell, rnn2_cell)
|
| 278 |
+
|
| 279 |
+
return mels, scores, hidden_states, cell_states, context_vec
|
| 280 |
+
|
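# Aside: zoneout, used on the two residual LSTMs above, keeps each hidden unit
# at its previous value with probability p instead of updating it -- a
# dropout-like regulariser for recurrent state. The same masking on toy
# tensors (shapes illustrative):
import torch

prev = torch.zeros(1, 6)
current = torch.ones(1, 6)
mask = torch.zeros(prev.size()).bernoulli_(0.1)   # 1 -> keep previous value
print(prev * mask + current * (1 - mask))         # ~10% of units stay at 0.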
| 281 |
+
|
| 282 |
+
class Tacotron(nn.Module):
|
| 283 |
+
def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels, fft_bins, postnet_dims,
|
| 284 |
+
encoder_K, lstm_dims, postnet_K, num_highways, dropout, stop_threshold):
|
| 285 |
+
super().__init__()
|
| 286 |
+
self.n_mels = n_mels
|
| 287 |
+
self.lstm_dims = lstm_dims
|
| 288 |
+
self.decoder_dims = decoder_dims
|
| 289 |
+
self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
|
| 290 |
+
encoder_K, num_highways, dropout)
|
| 291 |
+
self.encoder_proj = nn.Linear(decoder_dims, decoder_dims, bias=False)
|
| 292 |
+
self.decoder = Decoder(n_mels, decoder_dims, lstm_dims)
|
| 293 |
+
self.postnet = CBHG(postnet_K, n_mels, postnet_dims, [256, 80], num_highways)
|
| 294 |
+
self.post_proj = nn.Linear(postnet_dims * 2, fft_bins, bias=False)
|
| 295 |
+
|
| 296 |
+
self.init_model()
|
| 297 |
+
self.num_params()
|
| 298 |
+
|
| 299 |
+
self.register_buffer('step', torch.zeros(1, dtype=torch.long))
|
| 300 |
+
self.register_buffer('stop_threshold', torch.tensor(stop_threshold, dtype=torch.float32))
|
| 301 |
+
|
| 302 |
+
@property
|
| 303 |
+
def r(self):
|
| 304 |
+
return self.decoder.r.item()
|
| 305 |
+
|
| 306 |
+
@r.setter
|
| 307 |
+
def r(self, value):
|
| 308 |
+
self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
|
| 309 |
+
|
| 310 |
+
def forward(self, x, m, generate_gta=False):
|
| 311 |
+
device = next(self.parameters()).device # use same device as parameters
|
| 312 |
+
|
| 313 |
+
self.step += 1
|
| 314 |
+
|
| 315 |
+
if generate_gta:
|
| 316 |
+
self.eval()
|
| 317 |
+
else:
|
| 318 |
+
self.train()
|
| 319 |
+
|
| 320 |
+
batch_size, _, steps = m.size()
|
| 321 |
+
|
| 322 |
+
# Initialise all hidden states and pack into tuple
|
| 323 |
+
attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
|
| 324 |
+
rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
| 325 |
+
rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
| 326 |
+
hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
|
| 327 |
+
|
| 328 |
+
# Initialise all lstm cell states and pack into tuple
|
| 329 |
+
rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
| 330 |
+
rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
| 331 |
+
cell_states = (rnn1_cell, rnn2_cell)
|
| 332 |
+
|
| 333 |
+
# <GO> Frame for start of decoder loop
|
| 334 |
+
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
| 335 |
+
|
| 336 |
+
# Need an initial context vector
|
| 337 |
+
context_vec = torch.zeros(batch_size, self.decoder_dims, device=device)
|
| 338 |
+
|
| 339 |
+
# Project the encoder outputs to avoid
|
| 340 |
+
# unnecessary matmuls in the decoder loop
|
| 341 |
+
encoder_seq = self.encoder(x)
|
| 342 |
+
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
| 343 |
+
|
| 344 |
+
# Need a couple of lists for outputs
|
| 345 |
+
mel_outputs, attn_scores = [], []
|
| 346 |
+
|
| 347 |
+
# Run the decoder loop
|
| 348 |
+
for t in range(0, steps, self.r):
|
| 349 |
+
prenet_in = m[:, :, t - 1] if t > 0 else go_frame
|
| 350 |
+
mel_frames, scores, hidden_states, cell_states, context_vec = \
|
| 351 |
+
self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
|
| 352 |
+
hidden_states, cell_states, context_vec, t)
|
| 353 |
+
mel_outputs.append(mel_frames)
|
| 354 |
+
attn_scores.append(scores)
|
| 355 |
+
|
| 356 |
+
# Concat the mel outputs into sequence
|
| 357 |
+
mel_outputs = torch.cat(mel_outputs, dim=2)
|
| 358 |
+
|
| 359 |
+
# Post-Process for Linear Spectrograms
|
| 360 |
+
postnet_out = self.postnet(mel_outputs)
|
| 361 |
+
linear = self.post_proj(postnet_out)
|
| 362 |
+
linear = linear.transpose(1, 2)
|
| 363 |
+
|
| 364 |
+
# For easy visualisation
|
| 365 |
+
attn_scores = torch.cat(attn_scores, 1)
|
| 366 |
+
# attn_scores = attn_scores.cpu().data.numpy()
|
| 367 |
+
|
| 368 |
+
return mel_outputs, linear, attn_scores
|
| 369 |
+
|
| 370 |
+
def generate(self, x, steps=2000):
|
| 371 |
+
self.eval()
|
| 372 |
+
device = next(self.parameters()).device # use same device as parameters
|
| 373 |
+
|
| 374 |
+
batch_size = 1
|
| 375 |
+
x = torch.as_tensor(x, dtype=torch.long, device=device).unsqueeze(0)
|
| 376 |
+
|
| 377 |
+
# Need to initialise all hidden states and pack into tuple for tidiness
|
| 378 |
+
attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
|
| 379 |
+
rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
| 380 |
+
rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
|
| 381 |
+
hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
|
| 382 |
+
|
| 383 |
+
# Need to initialise all lstm cell states and pack into tuple for tidiness
|
| 384 |
+
rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
| 385 |
+
rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
|
| 386 |
+
cell_states = (rnn1_cell, rnn2_cell)
|
| 387 |
+
|
| 388 |
+
# Need a <GO> Frame for start of decoder loop
|
| 389 |
+
go_frame = torch.zeros(batch_size, self.n_mels, device=device)
|
| 390 |
+
|
| 391 |
+
# Need an initial context vector
|
| 392 |
+
context_vec = torch.zeros(batch_size, self.decoder_dims, device=device)
|
| 393 |
+
|
| 394 |
+
# Project the encoder outputs to avoid
|
| 395 |
+
# unnecessary matmuls in the decoder loop
|
| 396 |
+
encoder_seq = self.encoder(x)
|
| 397 |
+
encoder_seq_proj = self.encoder_proj(encoder_seq)
|
| 398 |
+
|
| 399 |
+
# Need a couple of lists for outputs
|
| 400 |
+
mel_outputs, attn_scores = [], []
|
| 401 |
+
|
| 402 |
+
# Run the decoder loop
|
| 403 |
+
for t in range(0, steps, self.r):
|
| 404 |
+
prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
|
| 405 |
+
mel_frames, scores, hidden_states, cell_states, context_vec = \
|
| 406 |
+
self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
|
| 407 |
+
hidden_states, cell_states, context_vec, t)
|
| 408 |
+
mel_outputs.append(mel_frames)
|
| 409 |
+
attn_scores.append(scores)
|
| 410 |
+
# Stop the loop if silent frames present
|
| 411 |
+
if (mel_frames < self.stop_threshold).all() and t > 10: break
|
| 412 |
+
|
| 413 |
+
# Concat the mel outputs into sequence
|
| 414 |
+
mel_outputs = torch.cat(mel_outputs, dim=2)
|
| 415 |
+
|
| 416 |
+
# Post-Process for Linear Spectrograms
|
| 417 |
+
postnet_out = self.postnet(mel_outputs)
|
| 418 |
+
linear = self.post_proj(postnet_out)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
linear = linear.transpose(1, 2)[0].cpu().data.numpy()
|
| 422 |
+
mel_outputs = mel_outputs[0].cpu().data.numpy()
|
| 423 |
+
|
| 424 |
+
# For easy visualisation
|
| 425 |
+
attn_scores = torch.cat(attn_scores, 1)
|
| 426 |
+
attn_scores = attn_scores.cpu().data.numpy()[0]
|
| 427 |
+
|
| 428 |
+
self.train()
|
| 429 |
+
|
| 430 |
+
return mel_outputs, linear, attn_scores
|
| 431 |
+
|
| 432 |
+
def init_model(self):
|
| 433 |
+
for p in self.parameters():
|
| 434 |
+
if p.dim() > 1: nn.init.xavier_uniform_(p)
|
| 435 |
+
|
| 436 |
+
def get_step(self):
|
| 437 |
+
return self.step.data.item()
|
| 438 |
+
|
| 439 |
+
def reset_step(self):
|
| 440 |
+
# assignment to parameters or buffers is overloaded, updates internal dict entry
|
| 441 |
+
self.step = self.step.data.new_tensor(1)
|
| 442 |
+
|
| 443 |
+
def log(self, path, msg):
|
| 444 |
+
with open(path, 'a') as f:
|
| 445 |
+
print(msg, file=f)
|
| 446 |
+
|
| 447 |
+
def load(self, path: Union[str, Path]):
|
| 448 |
+
# Use device of model params as location for loaded state
|
| 449 |
+
device = next(self.parameters()).device
|
| 450 |
+
state_dict = torch.load(path, map_location=device)
|
| 451 |
+
|
| 452 |
+
# Backwards compatibility with old saved models
|
| 453 |
+
if 'r' in state_dict and 'decoder.r' not in state_dict:
|
| 454 |
+
self.r = state_dict['r']
|
| 455 |
+
|
| 456 |
+
self.load_state_dict(state_dict, strict=False)
|
| 457 |
+
|
| 458 |
+
def save(self, path: Union[str, Path]):
|
| 459 |
+
# No optimizer argument because saving a model should not include data
|
| 460 |
+
# only relevant in the training process - it should only be properties
|
| 461 |
+
# of the model itself. Let caller take care of saving optimizer state.
|
| 462 |
+
torch.save(self.state_dict(), path)
|
| 463 |
+
|
| 464 |
+
def num_params(self, print_out=True):
|
| 465 |
+
parameters = filter(lambda p: p.requires_grad, self.parameters())
|
| 466 |
+
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
|
| 467 |
+
if print_out:
|
| 468 |
+
print('Trainable Parameters: %.3fM' % parameters)
|
| 469 |
+
return parameters
|
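One detail of the decoder above worth spelling out is the reduction factor r: `mel_proj` always produces `max_r` frames' worth of output per decoder step, the slice `[:, :, :self.r]` keeps the first r of them, and the decoding loops advance with `range(0, steps, self.r)`, so each iteration emits r mel frames. A toy illustration of the reshape-and-slice (the values here are placeholders, with r chosen arbitrarily):

    import torch

    batch_size, n_mels, max_r, r = 2, 80, 20, 2
    proj = torch.randn(batch_size, n_mels * max_r)        # one decoder step
    frames = proj.view(batch_size, n_mels, max_r)[:, :, :r]
    print(frames.shape)                                   # torch.Size([2, 80, 2])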
models/WaveRNNModel/notebooks/NB1 - Fit a Sine Wave.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/WaveRNNModel/notebooks/NB2 - Fit a Short Sample.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/WaveRNNModel/notebooks/NB3 - Fit a 30min Sample.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/WaveRNNModel/notebooks/NB4a - Alternative Model (Preprocessing).ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/WaveRNNModel/notebooks/NB4b - Alternative Model (Training).ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/WaveRNNModel/notebooks/Pruning - Scratchpad.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/WaveRNNModel/notebooks/__init__.py
ADDED
|
File without changes
|
models/WaveRNNModel/notebooks/models/wavernn.py
ADDED
|
@@ -0,0 +1,172 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
import time         # used by generate() for the speed readout
import numpy as np  # used by num_params()
# NB: generate() also expects stream() (notebooks/utils/display.py) and
# combine_signal() (notebooks/utils/dsp.py) to be in scope.
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class WaveRNN(nn.Module) :
|
| 7 |
+
def __init__(self, hidden_size=896, quantisation=256) :
|
| 8 |
+
super(WaveRNN, self).__init__()
|
| 9 |
+
|
| 10 |
+
self.hidden_size = hidden_size
|
| 11 |
+
self.split_size = hidden_size // 2
|
| 12 |
+
|
| 13 |
+
# The main matmul
|
| 14 |
+
self.R = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
|
| 15 |
+
|
| 16 |
+
# Output fc layers
|
| 17 |
+
self.O1 = nn.Linear(self.split_size, self.split_size)
|
| 18 |
+
self.O2 = nn.Linear(self.split_size, quantisation)
|
| 19 |
+
self.O3 = nn.Linear(self.split_size, self.split_size)
|
| 20 |
+
self.O4 = nn.Linear(self.split_size, quantisation)
|
| 21 |
+
|
| 22 |
+
# Input fc layers
|
| 23 |
+
self.I_coarse = nn.Linear(2, 3 * self.split_size, bias=False)
|
| 24 |
+
self.I_fine = nn.Linear(3, 3 * self.split_size, bias=False)
|
| 25 |
+
|
| 26 |
+
# biases for the gates
|
| 27 |
+
self.bias_u = nn.Parameter(torch.zeros(self.hidden_size))
|
| 28 |
+
self.bias_r = nn.Parameter(torch.zeros(self.hidden_size))
|
| 29 |
+
self.bias_e = nn.Parameter(torch.zeros(self.hidden_size))
|
| 30 |
+
|
| 31 |
+
# display num params
|
| 32 |
+
self.num_params()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def forward(self, prev_y, prev_hidden, current_coarse) :
|
| 36 |
+
|
| 37 |
+
# Main matmul - the projection is split 3 ways
|
| 38 |
+
R_hidden = self.R(prev_hidden)
|
| 39 |
+
R_u, R_r, R_e, = torch.split(R_hidden, self.hidden_size, dim=1)
|
| 40 |
+
|
| 41 |
+
# Project the prev input
|
| 42 |
+
coarse_input_proj = self.I_coarse(prev_y)
|
| 43 |
+
I_coarse_u, I_coarse_r, I_coarse_e = \
|
| 44 |
+
torch.split(coarse_input_proj, self.split_size, dim=1)
|
| 45 |
+
|
| 46 |
+
# Project the prev input and current coarse sample
|
| 47 |
+
fine_input = torch.cat([prev_y, current_coarse], dim=1)
|
| 48 |
+
fine_input_proj = self.I_fine(fine_input)
|
| 49 |
+
I_fine_u, I_fine_r, I_fine_e = \
|
| 50 |
+
torch.split(fine_input_proj, self.split_size, dim=1)
|
| 51 |
+
|
| 52 |
+
# concatenate for the gates
|
| 53 |
+
I_u = torch.cat([I_coarse_u, I_fine_u], dim=1)
|
| 54 |
+
I_r = torch.cat([I_coarse_r, I_fine_r], dim=1)
|
| 55 |
+
I_e = torch.cat([I_coarse_e, I_fine_e], dim=1)
|
| 56 |
+
|
| 57 |
+
# Compute all gates for coarse and fine
|
| 58 |
+
u = torch.sigmoid(R_u + I_u + self.bias_u)
|
| 59 |
+
r = torch.sigmoid(R_r + I_r + self.bias_r)
|
| 60 |
+
e = torch.tanh(r * R_e + I_e + self.bias_e)
|
| 61 |
+
hidden = u * prev_hidden + (1. - u) * e
|
| 62 |
+
|
| 63 |
+
# Split the hidden state
|
| 64 |
+
hidden_coarse, hidden_fine = torch.split(hidden, self.split_size, dim=1)
|
| 65 |
+
|
| 66 |
+
# Compute outputs
|
| 67 |
+
out_coarse = self.O2(F.relu(self.O1(hidden_coarse)))
|
| 68 |
+
out_fine = self.O4(F.relu(self.O3(hidden_fine)))
|
| 69 |
+
|
| 70 |
+
return out_coarse, out_fine, hidden
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def generate(self, seq_len) :
|
| 74 |
+
|
| 75 |
+
with torch.no_grad() :
|
| 76 |
+
|
| 77 |
+
# First split up the biases for the gates
|
| 78 |
+
b_coarse_u, b_fine_u = torch.split(self.bias_u, self.split_size)
|
| 79 |
+
b_coarse_r, b_fine_r = torch.split(self.bias_r, self.split_size)
|
| 80 |
+
b_coarse_e, b_fine_e = torch.split(self.bias_e, self.split_size)
|
| 81 |
+
|
| 82 |
+
# Lists for the two output seqs
|
| 83 |
+
c_outputs, f_outputs = [], []
|
| 84 |
+
|
| 85 |
+
# Some initial inputs
|
| 86 |
+
out_coarse = torch.LongTensor([0]).cuda()
|
| 87 |
+
out_fine = torch.LongTensor([0]).cuda()
|
| 88 |
+
|
| 89 |
+
# We'll need a hidden state
|
| 90 |
+
hidden = self.init_hidden()
|
| 91 |
+
|
| 92 |
+
# Need a clock for display
|
| 93 |
+
start = time.time()
|
| 94 |
+
|
| 95 |
+
# Loop for generation
|
| 96 |
+
for i in range(seq_len) :
|
| 97 |
+
|
| 98 |
+
# Split into two hidden states
|
| 99 |
+
hidden_coarse, hidden_fine = \
|
| 100 |
+
torch.split(hidden, self.split_size, dim=1)
|
| 101 |
+
|
| 102 |
+
# Scale and concat previous predictions
|
| 103 |
+
out_coarse = out_coarse.unsqueeze(0).float() / 127.5 - 1.
|
| 104 |
+
out_fine = out_fine.unsqueeze(0).float() / 127.5 - 1.
|
| 105 |
+
prev_outputs = torch.cat([out_coarse, out_fine], dim=1)
|
| 106 |
+
|
| 107 |
+
# Project input
|
| 108 |
+
coarse_input_proj = self.I_coarse(prev_outputs)
|
| 109 |
+
I_coarse_u, I_coarse_r, I_coarse_e = \
|
| 110 |
+
torch.split(coarse_input_proj, self.split_size, dim=1)
|
| 111 |
+
|
| 112 |
+
# Project hidden state and split 6 ways
|
| 113 |
+
R_hidden = self.R(hidden)
|
| 114 |
+
R_coarse_u , R_fine_u, \
|
| 115 |
+
R_coarse_r, R_fine_r, \
|
| 116 |
+
R_coarse_e, R_fine_e = torch.split(R_hidden, self.split_size, dim=1)
|
| 117 |
+
|
| 118 |
+
# Compute the coarse gates
|
| 119 |
+
u = torch.sigmoid(R_coarse_u + I_coarse_u + b_coarse_u)
|
| 120 |
+
r = torch.sigmoid(R_coarse_r + I_coarse_r + b_coarse_r)
|
| 121 |
+
e = torch.tanh(r * R_coarse_e + I_coarse_e + b_coarse_e)
|
| 122 |
+
hidden_coarse = u * hidden_coarse + (1. - u) * e
|
| 123 |
+
|
| 124 |
+
# Compute the coarse output
|
| 125 |
+
out_coarse = self.O2(F.relu(self.O1(hidden_coarse)))
|
| 126 |
+
posterior = F.softmax(out_coarse, dim=1)
|
| 127 |
+
distrib = torch.distributions.Categorical(posterior)
|
| 128 |
+
out_coarse = distrib.sample()
|
| 129 |
+
c_outputs.append(out_coarse)
|
| 130 |
+
|
| 131 |
+
# Project the [prev outputs and predicted coarse sample]
|
| 132 |
+
coarse_pred = out_coarse.float() / 127.5 - 1.
|
| 133 |
+
fine_input = torch.cat([prev_outputs, coarse_pred.unsqueeze(0)], dim=1)
|
| 134 |
+
fine_input_proj = self.I_fine(fine_input)
|
| 135 |
+
I_fine_u, I_fine_r, I_fine_e = \
|
| 136 |
+
torch.split(fine_input_proj, self.split_size, dim=1)
|
| 137 |
+
|
| 138 |
+
# Compute the fine gates
|
| 139 |
+
u = torch.sigmoid(R_fine_u + I_fine_u + b_fine_u)
|
| 140 |
+
r = torch.sigmoid(R_fine_r + I_fine_r + b_fine_r)
|
| 141 |
+
e = torch.tanh(r * R_fine_e + I_fine_e + b_fine_e)
|
| 142 |
+
hidden_fine = u * hidden_fine + (1. - u) * e
|
| 143 |
+
|
| 144 |
+
# Compute the fine output
|
| 145 |
+
out_fine = self.O4(F.relu(self.O3(hidden_fine)))
|
| 146 |
+
posterior = F.softmax(out_fine, dim=1)
|
| 147 |
+
distrib = torch.distributions.Categorical(posterior)
|
| 148 |
+
out_fine = distrib.sample()
|
| 149 |
+
f_outputs.append(out_fine)
|
| 150 |
+
|
| 151 |
+
# Put the hidden state back together
|
| 152 |
+
hidden = torch.cat([hidden_coarse, hidden_fine], dim=1)
|
| 153 |
+
|
| 154 |
+
# Display progress
|
| 155 |
+
speed = (i + 1) / (time.time() - start)
|
| 156 |
+
stream('Gen: %i/%i -- Speed: %i', (i + 1, seq_len, speed))
|
| 157 |
+
|
| 158 |
+
coarse = torch.stack(c_outputs).squeeze(1).cpu().data.numpy()
|
| 159 |
+
fine = torch.stack(f_outputs).squeeze(1).cpu().data.numpy()
|
| 160 |
+
output = combine_signal(coarse, fine)
|
| 161 |
+
|
| 162 |
+
return output, coarse, fine
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def init_hidden(self, batch_size=1) :
|
| 166 |
+
return torch.zeros(batch_size, self.hidden_size).cuda()
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def num_params(self) :
|
| 170 |
+
parameters = filter(lambda p: p.requires_grad, self.parameters())
|
| 171 |
+
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
|
| 172 |
+
print('Trainable Parameters: %.3f million' % parameters)
|
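A minimal smoke test of the forward pass above, using dummy inputs (the shapes follow the layer definitions: `prev_y` carries the previous coarse and fine samples scaled to [-1, 1], and `current_coarse` conditions the fine half of the hidden state):

    import torch

    model = WaveRNN()                        # hidden_size=896, quantisation=256
    prev_y = torch.zeros(1, 2)               # previous (coarse, fine), scaled
    current_coarse = torch.zeros(1, 1)       # current coarse sample, scaled
    hidden = torch.zeros(1, 896)
    out_coarse, out_fine, hidden = model(prev_y, hidden, current_coarse)
    print(out_coarse.shape, out_fine.shape)  # two (1, 256) logit tensors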
models/WaveRNNModel/notebooks/outputs/nb1/model_output.wav
ADDED
|
Binary file (80 kB). View file
|
|
|
models/WaveRNNModel/notebooks/outputs/nb2/3k_steps.wav
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:eeadf150c5e62325421d700609eec88c106ba86330a011d6b54542a8764a0728
|
| 3 |
+
size 220544
|
models/WaveRNNModel/notebooks/outputs/nb3/12k_steps.wav
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b18a69000c2f2852d1b9f5104a7f9bc3421882c5ff2847f7185b081c9feab0b8
|
| 3 |
+
size 882044
|
models/WaveRNNModel/notebooks/utils/__init__.py
ADDED
|
File without changes
|
models/WaveRNNModel/notebooks/utils/display.py
ADDED
|
@@ -0,0 +1,40 @@
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import time, sys, math
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
def stream(string, variables) :
|
| 6 |
+
sys.stdout.write(f'\r{string}' % variables)
|
| 7 |
+
|
| 8 |
+
def num_params(model) :
|
| 9 |
+
parameters = filter(lambda p: p.requires_grad, model.parameters())
|
| 10 |
+
parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
|
| 11 |
+
print('Trainable Parameters: %.3f million' % parameters)
|
| 12 |
+
|
| 13 |
+
def time_since(started) :
|
| 14 |
+
elapsed = time.time() - started
|
| 15 |
+
m = int(elapsed // 60)
|
| 16 |
+
s = int(elapsed % 60)
|
| 17 |
+
if m >= 60 :
|
| 18 |
+
h = int(m // 60)
|
| 19 |
+
m = m % 60
|
| 20 |
+
return f'{h}h {m}m {s}s'
|
| 21 |
+
else :
|
| 22 |
+
return f'{m}m {s}s'
|
| 23 |
+
|
| 24 |
+
def plot(array) :
|
| 25 |
+
fig = plt.figure(figsize=(30, 5))
|
| 26 |
+
ax = fig.add_subplot(111)
|
| 27 |
+
ax.xaxis.label.set_color('grey')
|
| 28 |
+
ax.yaxis.label.set_color('grey')
|
| 29 |
+
ax.xaxis.label.set_fontsize(23)
|
| 30 |
+
ax.yaxis.label.set_fontsize(23)
|
| 31 |
+
ax.tick_params(axis='x', colors='grey', labelsize=23)
|
| 32 |
+
ax.tick_params(axis='y', colors='grey', labelsize=23)
|
| 33 |
+
plt.plot(array)
|
| 34 |
+
|
| 35 |
+
def plot_spec(M) :
|
| 36 |
+
M = np.flip(M, axis=0)
|
| 37 |
+
plt.figure(figsize=(18,4))
|
| 38 |
+
plt.imshow(M, interpolation='nearest', aspect='auto')
|
| 39 |
+
plt.show()
|
| 40 |
+
|
models/WaveRNNModel/notebooks/utils/dsp.py
ADDED
|
@@ -0,0 +1,70 @@
|
| 1 |
+
import numpy as np
|
| 2 |
+
import librosa, math
|
| 3 |
+
|
| 4 |
+
sample_rate = 22050
|
| 5 |
+
n_fft = 2048
|
| 6 |
+
fft_bins = n_fft // 2 + 1
|
| 7 |
+
num_mels = 80
|
| 8 |
+
hop_length = int(sample_rate * 0.0125) # 12.5ms
|
| 9 |
+
win_length = int(sample_rate * 0.05) # 50ms
|
| 10 |
+
fmin = 40
|
| 11 |
+
min_level_db = -100
|
| 12 |
+
ref_level_db = 20
|
| 13 |
+
|
| 14 |
+
def load_wav(filename, encode=True) :
|
| 15 |
+
x = librosa.load(filename, sr=sample_rate)[0]
|
| 16 |
+
if encode : x = encode_16bits(x)
|
| 17 |
+
return x
|
| 18 |
+
|
| 19 |
+
def save_wav(y, filename) :
|
| 20 |
+
if y.dtype != 'int16' :
|
| 21 |
+
y = encode_16bits(y)
|
| 22 |
+
librosa.output.write_wav(filename, y.astype(np.int16), sample_rate)
|
| 23 |
+
|
| 24 |
+
def split_signal(x) :
|
| 25 |
+
unsigned = x + 2**15
|
| 26 |
+
coarse = unsigned // 256
|
| 27 |
+
fine = unsigned % 256
|
| 28 |
+
return coarse, fine
|
| 29 |
+
|
| 30 |
+
def combine_signal(coarse, fine) :
|
| 31 |
+
return coarse * 256 + fine - 2**15
|
| 32 |
+
|
| 33 |
+
def encode_16bits(x) :
|
| 34 |
+
return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16)
|
| 35 |
+
|
| 36 |
+
mel_basis = None
|
| 37 |
+
|
| 38 |
+
def linear_to_mel(spectrogram):
|
| 39 |
+
global mel_basis
|
| 40 |
+
if mel_basis is None:
|
| 41 |
+
mel_basis = build_mel_basis()
|
| 42 |
+
return np.dot(mel_basis, spectrogram)
|
| 43 |
+
|
| 44 |
+
def build_mel_basis():
|
| 45 |
+
return librosa.filters.mel(sample_rate, n_fft, n_mels=num_mels, fmin=fmin)
|
| 46 |
+
|
| 47 |
+
def normalize(S):
|
| 48 |
+
return np.clip((S - min_level_db) / -min_level_db, 0, 1)
|
| 49 |
+
|
| 50 |
+
def denormalize(S):
|
| 51 |
+
return (np.clip(S, 0, 1) * -min_level_db) + min_level_db
|
| 52 |
+
|
| 53 |
+
def amp_to_db(x):
|
| 54 |
+
return 20 * np.log10(np.maximum(1e-5, x))
|
| 55 |
+
|
| 56 |
+
def db_to_amp(x):
|
| 57 |
+
return np.power(10.0, x * 0.05)
|
| 58 |
+
|
| 59 |
+
def spectrogram(y):
|
| 60 |
+
D = stft(y)
|
| 61 |
+
S = amp_to_db(np.abs(D)) - ref_level_db
|
| 62 |
+
return normalize(S)
|
| 63 |
+
|
| 64 |
+
def melspectrogram(y):
|
| 65 |
+
D = stft(y)
|
| 66 |
+
S = amp_to_db(linear_to_mel(np.abs(D)))
|
| 67 |
+
return normalize(S)
|
| 68 |
+
|
| 69 |
+
def stft(y):
|
| 70 |
+
return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
|
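The `split_signal`/`combine_signal` pair above implements the coarse/fine decomposition used by the notebook WaveRNN: a 16-bit sample, shifted into unsigned range, is split into its high and low 8-bit halves, and the operation round-trips exactly:

    import numpy as np

    x = np.array([-32768, -1, 0, 12345, 32767])             # 16-bit samples
    coarse, fine = split_signal(x)                          # high / low bytes
    print(np.array_equal(combine_signal(coarse, fine), x))  # True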
models/WaveRNNModel/preprocess.py
ADDED
|
@@ -0,0 +1,103 @@
|
| 1 |
+
import glob
|
| 2 |
+
from utils.display import *
|
| 3 |
+
from utils.dsp import *
|
| 4 |
+
from utils import hparams as hp
|
| 5 |
+
from multiprocessing import Pool, cpu_count
|
| 6 |
+
from utils.paths import Paths
|
| 7 |
+
import pickle
|
| 8 |
+
import argparse
|
| 9 |
+
from utils.text.recipes import ljspeech
|
| 10 |
+
from utils.files import get_files
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Helper functions for argument types
|
| 15 |
+
def valid_n_workers(num):
|
| 16 |
+
n = int(num)
|
| 17 |
+
if n < 1:
|
| 18 |
+
raise argparse.ArgumentTypeError('%r must be an integer greater than 0' % num)
|
| 19 |
+
return n
|
| 20 |
+
|
| 21 |
+
parser = argparse.ArgumentParser(description='Preprocessing for WaveRNN and Tacotron')
|
| 22 |
+
parser.add_argument('--path', '-p', help='directly point to dataset path (overrides hparams.wav_path)')
|
| 23 |
+
parser.add_argument('--extension', '-e', metavar='EXT', default='.wav', help='file extension to search for in dataset folder')
|
| 24 |
+
parser.add_argument('--num_workers', '-w', metavar='N', type=valid_n_workers, default=cpu_count()-1, help='The number of worker threads to use for preprocessing')
|
| 25 |
+
parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
|
| 26 |
+
|
| 27 |
+
args = parser.parse_args()
|
| 28 |
+
hp.configure(args.hp_file) # Load hparams from file
|
| 29 |
+
if args.path is None:
|
| 30 |
+
args.path = hp.wav_path
|
| 31 |
+
|
| 32 |
+
extension = args.extension
|
| 33 |
+
path = args.path
|
| 34 |
+
|
| 35 |
+
wav_files = get_files(path, extension)
|
| 36 |
+
paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
|
| 37 |
+
|
| 38 |
+
print(f'\n{len(wav_files)} {extension[1:]} files found in "{path}"\n')
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def convert_file(path: Path):
|
| 43 |
+
y = load_wav(path)
|
| 44 |
+
peak = np.abs(y).max()
|
| 45 |
+
if hp.peak_norm or peak > 1.0:
|
| 46 |
+
y /= peak
|
| 47 |
+
mel = melspectrogram(y)
|
| 48 |
+
if hp.voc_mode == 'RAW':
|
| 49 |
+
quant = encode_mu_law(y, mu=2**hp.bits) if hp.mu_law else float_2_label(y, bits=hp.bits)
|
| 50 |
+
elif hp.voc_mode == 'MOL':
|
| 51 |
+
quant = float_2_label(y, bits=16)
|
| 52 |
+
|
| 53 |
+
return mel.astype(np.float32), quant.astype(np.int64)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def process_wav(path: Path):
|
| 57 |
+
wav_id = path.stem
|
| 58 |
+
m, x = convert_file(path)
|
| 59 |
+
#print("paths.mel:::",paths.mel)
|
| 60 |
+
np.save(paths.mel/f'{wav_id}.npy', m, allow_pickle=False)
|
| 61 |
+
np.save(paths.quant/f'{wav_id}.npy', x, allow_pickle=False)
|
| 62 |
+
return wav_id, m.shape[-1]
|
| 63 |
+
|
| 64 |
+
if __name__ == '__main__':
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
if len(wav_files) == 0:
|
| 68 |
+
|
| 69 |
+
print('Please point wav_path in hparams.py to your dataset,')
|
| 70 |
+
print('or use the --path option.\n')
|
| 71 |
+
|
| 72 |
+
else:
|
| 73 |
+
|
| 74 |
+
if not hp.ignore_tts:
|
| 75 |
+
|
| 76 |
+
text_dict = ljspeech(path)
|
| 77 |
+
|
| 78 |
+
with open(paths.data/'text_dict.pkl', 'wb') as f:
|
| 79 |
+
pickle.dump(text_dict, f)
|
| 80 |
+
|
| 81 |
+
n_workers = max(1, args.num_workers)
|
| 82 |
+
|
| 83 |
+
simple_table([
|
| 84 |
+
('Sample Rate', hp.sample_rate),
|
| 85 |
+
('Bit Depth', hp.bits),
|
| 86 |
+
('Mu Law', hp.mu_law),
|
| 87 |
+
('Hop Length', hp.hop_length),
|
| 88 |
+
('CPU Usage', f'{n_workers}/{cpu_count()}')
|
| 89 |
+
])
|
| 90 |
+
|
| 91 |
+
pool = Pool(processes=n_workers)
|
| 92 |
+
dataset = []
|
| 93 |
+
print("test22222")
|
| 94 |
+
for i, (item_id, length) in enumerate(pool.imap_unordered(process_wav, wav_files), 1):
|
| 95 |
+
dataset += [(item_id, length)]
|
| 96 |
+
bar = progbar(i, len(wav_files))
|
| 97 |
+
message = f'{bar} {i}/{len(wav_files)} '
|
| 98 |
+
stream(message)
|
| 99 |
+
|
| 100 |
+
with open(paths.data/'dataset.pkl', 'wb') as f:
|
| 101 |
+
pickle.dump(dataset, f)
|
| 102 |
+
|
| 103 |
+
print('\n\nCompleted. Ready to run "python train_tacotron.py" or "python train_wavernn.py". \n')
|
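`convert_file` above quantises the audio either with mu-law companding (`encode_mu_law`, with mu = 2**bits) or with plain linear labels (`float_2_label`); those helpers live in the repository's utils and are not shown in this diff. For orientation only, a common mu-law formulation is sketched below -- this is an assumption about the general technique, not the project's exact code:

    import numpy as np

    def mu_law_encode(x, mu):                 # hypothetical stand-in helper
        mu = mu - 1
        fx = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
        return np.floor((fx + 1) / 2 * mu + 0.5).astype(np.int64)

    def mu_law_decode(y, mu):                 # hypothetical stand-in helper
        mu = mu - 1
        fx = 2 * (y / mu) - 1
        return np.sign(fx) / mu * ((1 + mu) ** np.abs(fx) - 1)

    x = np.linspace(-1, 1, 5)
    labels = mu_law_encode(x, mu=2 ** 9)      # 9-bit labels in [0, 510]
    print(labels, np.round(mu_law_decode(labels, mu=2 ** 9), 3))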
models/WaveRNNModel/quick_start.py
ADDED
|
@@ -0,0 +1,122 @@
|
| 1 |
+
import torch
|
| 2 |
+
from models.fatchord_version import WaveRNN
|
| 3 |
+
from utils import hparams as hp
|
| 4 |
+
from utils.text.symbols import symbols
|
| 5 |
+
from models.tacotron import Tacotron
|
| 6 |
+
import argparse
|
| 7 |
+
from utils.text import text_to_sequence
|
| 8 |
+
from utils.display import save_attention, simple_table
|
| 9 |
+
import zipfile, os
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
os.makedirs('quick_start/tts_weights/', exist_ok=True)
|
| 13 |
+
os.makedirs('quick_start/voc_weights/', exist_ok=True)
|
| 14 |
+
|
| 15 |
+
zip_ref = zipfile.ZipFile('pretrained/ljspeech.wavernn.mol.800k.zip', 'r')
|
| 16 |
+
zip_ref.extractall('quick_start/voc_weights/')
|
| 17 |
+
zip_ref.close()
|
| 18 |
+
|
| 19 |
+
zip_ref = zipfile.ZipFile('pretrained/ljspeech.tacotron.r2.180k.zip', 'r')
|
| 20 |
+
zip_ref.extractall('quick_start/tts_weights/')
|
| 21 |
+
zip_ref.close()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
if __name__ == "__main__":
|
| 25 |
+
|
| 26 |
+
# Parse Arguments
|
| 27 |
+
parser = argparse.ArgumentParser(description='TTS Generator')
|
| 28 |
+
parser.add_argument('--input_text', '-i', type=str, help='[string] Type in something here and TTS will generate it!')
|
| 29 |
+
parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation (lower quality)')
|
| 30 |
+
parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slower Unbatched Generation (better quality)')
|
| 31 |
+
parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
|
| 32 |
+
parser.add_argument('--hp_file', metavar='FILE', default='hparams.py',
|
| 33 |
+
help='The file to use for the hyperparameters')
|
| 34 |
+
# NB: argparse defaults must be registered before parse_args() to take effect
parser.set_defaults(batched=True)
parser.set_defaults(input_text=None)
args = parser.parse_args()
|
| 35 |
+
|
| 36 |
+
hp.configure(args.hp_file) # Load hparams from file
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
batched = args.batched
|
| 42 |
+
input_text = args.input_text
|
| 43 |
+
|
| 44 |
+
if not args.force_cpu and torch.cuda.is_available():
|
| 45 |
+
device = torch.device('cuda')
|
| 46 |
+
else:
|
| 47 |
+
device = torch.device('cpu')
|
| 48 |
+
print('Using device:', device)
|
| 49 |
+
|
| 50 |
+
print('\nInitialising WaveRNN Model...\n')
|
| 51 |
+
|
| 52 |
+
# Instantiate WaveRNN Model
|
| 53 |
+
voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
|
| 54 |
+
fc_dims=hp.voc_fc_dims,
|
| 55 |
+
bits=hp.bits,
|
| 56 |
+
pad=hp.voc_pad,
|
| 57 |
+
upsample_factors=hp.voc_upsample_factors,
|
| 58 |
+
feat_dims=hp.num_mels,
|
| 59 |
+
compute_dims=hp.voc_compute_dims,
|
| 60 |
+
res_out_dims=hp.voc_res_out_dims,
|
| 61 |
+
res_blocks=hp.voc_res_blocks,
|
| 62 |
+
hop_length=hp.hop_length,
|
| 63 |
+
sample_rate=hp.sample_rate,
|
| 64 |
+
mode='MOL').to(device)
|
| 65 |
+
|
| 66 |
+
voc_model.load('quick_start/voc_weights/latest_weights.pyt')
|
| 67 |
+
|
| 68 |
+
print('\nInitialising Tacotron Model...\n')
|
| 69 |
+
|
| 70 |
+
# Instantiate Tacotron Model
|
| 71 |
+
tts_model = Tacotron(embed_dims=hp.tts_embed_dims,
|
| 72 |
+
num_chars=len(symbols),
|
| 73 |
+
encoder_dims=hp.tts_encoder_dims,
|
| 74 |
+
decoder_dims=hp.tts_decoder_dims,
|
| 75 |
+
n_mels=hp.num_mels,
|
| 76 |
+
fft_bins=hp.num_mels,
|
| 77 |
+
postnet_dims=hp.tts_postnet_dims,
|
| 78 |
+
encoder_K=hp.tts_encoder_K,
|
| 79 |
+
lstm_dims=hp.tts_lstm_dims,
|
| 80 |
+
postnet_K=hp.tts_postnet_K,
|
| 81 |
+
num_highways=hp.tts_num_highways,
|
| 82 |
+
dropout=hp.tts_dropout,
|
| 83 |
+
stop_threshold=hp.tts_stop_threshold).to(device)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
tts_model.load('quick_start/tts_weights/latest_weights.pyt')
|
| 87 |
+
|
| 88 |
+
if input_text:
|
| 89 |
+
inputs = [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)]
|
| 90 |
+
else:
|
| 91 |
+
with open('sentences.txt') as f:
|
| 92 |
+
inputs = [text_to_sequence(l.strip(), hp.tts_cleaner_names) for l in f]
|
| 93 |
+
|
| 94 |
+
voc_k = voc_model.get_step() // 1000
|
| 95 |
+
tts_k = tts_model.get_step() // 1000
|
| 96 |
+
|
| 97 |
+
r = tts_model.r
|
| 98 |
+
|
| 99 |
+
simple_table([('WaveRNN', str(voc_k) + 'k'),
|
| 100 |
+
(f'Tacotron(r={r})', str(tts_k) + 'k'),
|
| 101 |
+
('Generation Mode', 'Batched' if batched else 'Unbatched'),
|
| 102 |
+
('Target Samples', 11_000 if batched else 'N/A'),
|
| 103 |
+
('Overlap Samples', 550 if batched else 'N/A')])
|
| 104 |
+
|
| 105 |
+
for i, x in enumerate(inputs, 1):
|
| 106 |
+
|
| 107 |
+
print(f'\n| Generating {i}/{len(inputs)}')
|
| 108 |
+
_, m, attention = tts_model.generate(x)
|
| 109 |
+
|
| 110 |
+
if input_text:
|
| 111 |
+
save_path = f'quick_start/__input_{input_text[:10]}_{tts_k}k.wav'
|
| 112 |
+
else:
|
| 113 |
+
save_path = f'quick_start/{i}_batched{str(batched)}_{tts_k}k.wav'
|
| 114 |
+
|
| 115 |
+
# save_attention(attention, save_path)
|
| 116 |
+
|
| 117 |
+
m = torch.tensor(m).unsqueeze(0)
|
| 118 |
+
m = (m + 4) / 8
|
| 119 |
+
|
| 120 |
+
voc_model.generate(m, save_path, batched, 11_000, 550, hp.mu_law)
|
| 121 |
+
|
| 122 |
+
print('\n\nDone.\n')
|
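A note on the m = (m + 4) / 8 step above: the pipeline stores mels in [0, 1], Tacotron trains on them rescaled to [-4, 4] (see collate_tts in utils/dataset.py), so the vocoder input must be mapped back before voc_model.generate. A minimal sketch of that convention, assuming only that on-disk mels are normalised to [0, 1]:

import numpy as np

def mel_to_tacotron_range(m: np.ndarray) -> np.ndarray:
    # [0, 1] -> [-4, 4], as done in collate_tts
    return (m * 8.0) - 4.0

def mel_to_vocoder_range(m: np.ndarray) -> np.ndarray:
    # [-4, 4] -> [0, 1], as done in quick_start.py and create_gta_features
    return (m + 4.0) / 8.0

m = np.random.rand(80, 100)  # stand-in mel, (n_mels, frames)
assert np.allclose(mel_to_vocoder_range(mel_to_tacotron_range(m)), m)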
models/WaveRNNModel/quick_start/tts_weights/latest_weights.pyt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9c4c491d6ad43f1b9b9e1f393c7d8437592da4b26412838a7bc20f446b76d2f0
size 44433225
models/WaveRNNModel/quick_start/voc_weights/latest_weights.pyt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:99e2a93453a8d531ba851b2e5c488105816f7c60ec50153c56436ff9bda8a26a
size 16985706
models/WaveRNNModel/requirements.txt
ADDED
@@ -0,0 +1,6 @@
numpy==1.22.0
librosa==0.6.3
matplotlib
unidecode
inflect
nltk
models/WaveRNNModel/sentences.txt
ADDED
@@ -0,0 +1,6 @@
However, no attempt has been made yet to formulate a relativistic generalisation of action-angle variables for geodesic motion in Kerr spacetime and to calculate the dynamical frequencies of arbitrary bound non-plunging orbits.

The investigation of bound geodesic orbits in Kerr spacetime presented in this article clearly illustrates that the properties of these orbits in the regime of strong gravity are profoundly different from Keplerian orbits in the Newtonian regime.

The observation of as few as ten EMRIs can provide a measurement of the slope of the black-hole mass function to better precision than is currently known.
models/WaveRNNModel/source.txt
ADDED
@@ -0,0 +1 @@
https://github.com/Adwardgyhjs/RNNoiseAndWaveRNN
models/WaveRNNModel/train_tacotron.py
ADDED
@@ -0,0 +1,203 @@
import torch
from torch import optim
import torch.nn.functional as F
from utils import hparams as hp
from utils.display import *
from utils.dataset import get_tts_datasets
from utils.text.symbols import symbols
from utils.paths import Paths
from models.tacotron import Tacotron
import argparse
from utils import data_parallel_workaround
import os
from pathlib import Path
import time
import numpy as np
import sys
from utils.checkpoints import save_checkpoint, restore_checkpoint


def np_now(x: torch.Tensor): return x.detach().cpu().numpy()


def main():
    # Parse Arguments
    parser = argparse.ArgumentParser(description='Train Tacotron TTS')
    parser.add_argument('--force_train', '-f', action='store_true', help='Forces the model to train past total steps')
    parser.add_argument('--force_gta', '-g', action='store_true', help='Force the model to create GTA features')
    parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even in a CUDA-capable environment')
    parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
    args = parser.parse_args()

    hp.configure(args.hp_file)  # Load hparams from file
    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

    force_train = args.force_train
    force_gta = args.force_gta

    if not args.force_cpu and torch.cuda.is_available():
        device = torch.device('cuda')
        for session in hp.tts_schedule:
            _, _, _, batch_size = session
            if batch_size % torch.cuda.device_count() != 0:
                raise ValueError('`batch_size` must be evenly divisible by n_gpus!')
    else:
        device = torch.device('cpu')
    print('Using device:', device)

    # Instantiate Tacotron Model
    print('\nInitialising Tacotron Model...\n')
    model = Tacotron(embed_dims=hp.tts_embed_dims,
                     num_chars=len(symbols),
                     encoder_dims=hp.tts_encoder_dims,
                     decoder_dims=hp.tts_decoder_dims,
                     n_mels=hp.num_mels,
                     fft_bins=hp.num_mels,
                     postnet_dims=hp.tts_postnet_dims,
                     encoder_K=hp.tts_encoder_K,
                     lstm_dims=hp.tts_lstm_dims,
                     postnet_K=hp.tts_postnet_K,
                     num_highways=hp.tts_num_highways,
                     dropout=hp.tts_dropout,
                     stop_threshold=hp.tts_stop_threshold).to(device)

    optimizer = optim.Adam(model.parameters())
    restore_checkpoint('tts', paths, model, optimizer, create_if_missing=True)

    if not force_gta:
        for i, session in enumerate(hp.tts_schedule):
            current_step = model.get_step()

            r, lr, max_step, batch_size = session

            training_steps = max_step - current_step

            # Do we need to change to the next session?
            if current_step >= max_step:
                # Are there no further sessions than the current one?
                if i == len(hp.tts_schedule) - 1:
                    # There are no more sessions. Check if we force training.
                    if force_train:
                        # Don't finish the loop - train forever
                        training_steps = 999_999_999
                    else:
                        # We have completed training. Breaking is same as continue
                        break
                else:
                    # There is a following session, go to it
                    continue

            model.r = r

            simple_table([(f'Steps with r={r}', str(training_steps//1000) + 'k Steps'),
                          ('Batch Size', batch_size),
                          ('Learning Rate', lr),
                          ('Outputs/Step (r)', model.r)])

            train_set, attn_example = get_tts_datasets(paths.data, batch_size, r)
            tts_train_loop(paths, model, optimizer, train_set, lr, training_steps, attn_example)

    print('Training Complete.')
    print('To continue training increase tts_total_steps in hparams.py or use --force_train\n')


    print('Creating Ground Truth Aligned Dataset...\n')

    train_set, attn_example = get_tts_datasets(paths.data, 8, model.r)
    create_gta_features(model, train_set, paths.gta)

    print('\n\nYou can now train WaveRNN on GTA features - use python train_wavernn.py --gta\n')


def tts_train_loop(paths: Paths, model: Tacotron, optimizer, train_set, lr, train_steps, attn_example):
    device = next(model.parameters()).device  # use same device as model parameters

    for g in optimizer.param_groups: g['lr'] = lr

    total_iters = len(train_set)
    epochs = train_steps // total_iters + 1

    for e in range(1, epochs+1):

        start = time.time()
        running_loss = 0

        # Perform 1 epoch
        for i, (x, m, ids, _) in enumerate(train_set, 1):

            x, m = x.to(device), m.to(device)

            # Parallelize model onto GPUs using workaround due to python bug
            if device.type == 'cuda' and torch.cuda.device_count() > 1:
                m1_hat, m2_hat, attention = data_parallel_workaround(model, x, m)
            else:
                m1_hat, m2_hat, attention = model(x, m)

            m1_loss = F.l1_loss(m1_hat, m)
            m2_loss = F.l1_loss(m2_hat, m)

            loss = m1_loss + m2_loss

            optimizer.zero_grad()
            loss.backward()
            if hp.tts_clip_grad_norm is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.tts_clip_grad_norm)
                if torch.isnan(grad_norm):
                    print('grad_norm was NaN!')

            optimizer.step()
            running_loss += loss.item()
            avg_loss = running_loss / i

            speed = i / (time.time() - start)

            step = model.get_step()
            k = step // 1000

            if step % hp.tts_checkpoint_every == 0:
                ckpt_name = f'taco_step{k}K'
                save_checkpoint('tts', paths, model, optimizer,
                                name=ckpt_name, is_silent=True)

            if attn_example in ids:
                idx = ids.index(attn_example)
                save_attention(np_now(attention[idx][:, :160]), paths.tts_attention/f'{step}')
                save_spectrogram(np_now(m2_hat[idx]), paths.tts_mel_plot/f'{step}', 600)

            msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:#.4} | {speed:#.2} steps/s | Step: {k}k | '
            stream(msg)

        # Must save latest optimizer state to ensure that resuming training
        # doesn't produce artifacts
        save_checkpoint('tts', paths, model, optimizer, is_silent=True)
        model.log(paths.tts_log, msg)
        print(' ')


def create_gta_features(model: Tacotron, train_set, save_path: Path):
    device = next(model.parameters()).device  # use same device as model parameters

    iters = len(train_set)

    for i, (x, mels, ids, mel_lens) in enumerate(train_set, 1):

        x, mels = x.to(device), mels.to(device)

        with torch.no_grad(): _, gta, _ = model(x, mels)

        gta = gta.cpu().numpy()

        for j, item_id in enumerate(ids):
            mel = gta[j][:, :mel_lens[j]]
            mel = (mel + 4) / 8
            np.save(save_path/f'{item_id}.npy', mel, allow_pickle=False)

        bar = progbar(i, iters)
        msg = f'{bar} {i}/{iters} Batches '
        stream(msg)


if __name__ == "__main__":
    main()
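hp.tts_schedule drives the session loop in main(): each entry is an (r, lr, max_step, batch_size) tuple, and training resumes inside whichever entry the current step falls into. A standalone sketch of that selection logic, with hypothetical schedule values (the real ones live in hparams.py):

tts_schedule = [(7, 1e-3, 10_000, 32),    # hypothetical values
                (5, 3e-4, 100_000, 32),
                (2, 1e-4, 180_000, 16)]

def active_session(schedule, current_step):
    for r, lr, max_step, batch_size in schedule:
        if current_step < max_step:
            return r, lr, max_step - current_step, batch_size
    return None  # finished; main() then breaks (or trains on with --force_train)

print(active_session(tts_schedule, 50_000))  # -> (5, 0.0003, 50000, 32)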
models/WaveRNNModel/train_wavernn.py
ADDED
@@ -0,0 +1,164 @@
import time
import numpy as np
import torch
from torch import optim
import torch.nn.functional as F
from utils.display import stream, simple_table
from utils.dataset import get_vocoder_datasets
from utils.distribution import discretized_mix_logistic_loss
from utils import hparams as hp
from models.fatchord_version import WaveRNN
from gen_wavernn import gen_testset
from utils.paths import Paths
import argparse
from utils import data_parallel_workaround
from utils.checkpoints import save_checkpoint, restore_checkpoint


def main():

    # Parse Arguments
    parser = argparse.ArgumentParser(description='Train WaveRNN Vocoder')
    parser.add_argument('--lr', '-l', type=float, help='[float] override hparams.py learning rate')
    parser.add_argument('--batch_size', '-b', type=int, help='[int] override hparams.py batch size')
    parser.add_argument('--force_train', '-f', action='store_true', help='Forces the model to train past total steps')
    parser.add_argument('--gta', '-g', action='store_true', help='train wavernn on GTA features')
    parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even in a CUDA-capable environment')
    parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
    args = parser.parse_args()

    hp.configure(args.hp_file)  # load hparams from file
    if args.lr is None:
        args.lr = hp.voc_lr
    if args.batch_size is None:
        args.batch_size = hp.voc_batch_size

    paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

    batch_size = args.batch_size
    force_train = args.force_train
    train_gta = args.gta
    lr = args.lr

    if not args.force_cpu and torch.cuda.is_available():
        device = torch.device('cuda')
        if batch_size % torch.cuda.device_count() != 0:
            raise ValueError('`batch_size` must be evenly divisible by n_gpus!')
    else:
        device = torch.device('cpu')
    print('Using device:', device)

    print('\nInitialising Model...\n')

    # Instantiate WaveRNN Model
    voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                        fc_dims=hp.voc_fc_dims,
                        bits=hp.bits,
                        pad=hp.voc_pad,
                        upsample_factors=hp.voc_upsample_factors,
                        feat_dims=hp.num_mels,
                        compute_dims=hp.voc_compute_dims,
                        res_out_dims=hp.voc_res_out_dims,
                        res_blocks=hp.voc_res_blocks,
                        hop_length=hp.hop_length,
                        sample_rate=hp.sample_rate,
                        mode=hp.voc_mode).to(device)

    # Check to make sure the hop length is correctly factorised
    assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length

    optimizer = optim.Adam(voc_model.parameters())
    restore_checkpoint('voc', paths, voc_model, optimizer, create_if_missing=True)

    train_set, test_set = get_vocoder_datasets(paths.data, batch_size, train_gta)

    total_steps = 10_000_000 if force_train else hp.voc_total_steps

    simple_table([('Remaining', str((total_steps - voc_model.get_step())//1000) + 'k Steps'),
                  ('Batch Size', batch_size),
                  ('LR', lr),
                  ('Sequence Len', hp.voc_seq_len),
                  ('GTA Train', train_gta)])

    loss_func = F.cross_entropy if voc_model.mode == 'RAW' else discretized_mix_logistic_loss

    voc_train_loop(paths, voc_model, loss_func, optimizer, train_set, test_set, lr, total_steps)

    print('Training Complete.')
    print('To continue training increase voc_total_steps in hparams.py or use --force_train')


def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer, train_set, test_set, lr, total_steps):
    # Use same device as model parameters
    device = next(model.parameters()).device

    for g in optimizer.param_groups: g['lr'] = lr

    total_iters = len(train_set)
    epochs = (total_steps - model.get_step()) // total_iters + 1

    for e in range(1, epochs + 1):

        start = time.time()
        running_loss = 0.

        for i, (x, y, m) in enumerate(train_set, 1):
            x, m, y = x.to(device), m.to(device), y.to(device)

            # Parallelize model onto GPUs using workaround due to python bug
            if device.type == 'cuda' and torch.cuda.device_count() > 1:
                y_hat = data_parallel_workaround(model, x, m)
            else:
                y_hat = model(x, m)

            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)

            elif model.mode == 'MOL':
                y = y.float()

            y = y.unsqueeze(-1)

            loss = loss_func(y_hat, y)

            optimizer.zero_grad()
            loss.backward()
            if hp.voc_clip_grad_norm is not None:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.voc_clip_grad_norm)
                if torch.isnan(grad_norm):
                    print('grad_norm was NaN!')
            optimizer.step()

            running_loss += loss.item()
            avg_loss = running_loss / i

            speed = i / (time.time() - start)

            step = model.get_step()
            k = step // 1000

            if step % hp.voc_checkpoint_every == 0:
                gen_testset(model, test_set, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
                            hp.voc_target, hp.voc_overlap, paths.voc_output)
                ckpt_name = f'wave_step{k}K'
                save_checkpoint('voc', paths, model, optimizer,
                                name=ckpt_name, is_silent=True)

            msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | '
            stream(msg)

        # Must save latest optimizer state to ensure that resuming training
        # doesn't produce artifacts
        save_checkpoint('voc', paths, model, optimizer, is_silent=True)
        model.log(paths.voc_log, msg)
        print(' ')


if __name__ == "__main__":
    main()
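The np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length assertion above ties the conditioning network to the audio: the upsampling layers must stretch one mel frame into exactly one hop of samples. A quick check with assumed hparams values (treat the numbers as an assumption, not the repository's actual configuration):

import numpy as np

voc_upsample_factors = (5, 5, 11)  # assumed hparams values
hop_length = 275                   # 5 * 5 * 11

assert np.cumprod(voc_upsample_factors)[-1] == hop_length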
models/WaveRNNModel/utils/__init__.py
ADDED
@@ -0,0 +1,106 @@
# Make it explicit that we do it the Python 3 way
from __future__ import absolute_import, division, print_function, unicode_literals
from builtins import *
import sys
import torch
import re

from importlib.util import spec_from_file_location, module_from_spec
from pathlib import Path
from typing import Union

# Credit: Ryuichi Yamamoto (https://github.com/r9y9/wavenet_vocoder/blob/1717f145c8f8c0f3f85ccdf346b5209fa2e1c920/train.py#L599)
# Modified by: Ryan Butler (https://github.com/TheButlah)
# workaround for https://github.com/pytorch/pytorch/issues/15716
# the idea is to return outputs and replicas explicitly, so that PyTorch
# does not release the nodes (this is a pytorch bug though)

_output_ref = None
_replicas_ref = None

def data_parallel_workaround(model, *input):
    global _output_ref
    global _replicas_ref
    device_ids = list(range(torch.cuda.device_count()))
    output_device = device_ids[0]
    replicas = torch.nn.parallel.replicate(model, device_ids)
    # input.shape = (num_args, batch, ...)
    inputs = torch.nn.parallel.scatter(input, device_ids)
    # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
    replicas = replicas[:len(inputs)]
    outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
    y_hat = torch.nn.parallel.gather(outputs, output_device)
    _output_ref = outputs
    _replicas_ref = replicas
    return y_hat


###### Deal with hparams import that has to be configured at runtime ######
class __HParams:
    """Manages the hyperparams pseudo-module"""
    def __init__(self, path: Union[str, Path]=None):
        """Constructs the hyperparameters from a path to a python module. If
        `path` is None, will raise an AttributeError whenever its attributes
        are accessed. Otherwise, configures self based on `path`."""
        if path is None:
            self._configured = False
        else:
            self.configure(path)

    def __getattr__(self, item):
        # Only invoked when normal attribute lookup has already failed
        if not self.is_configured():
            raise AttributeError("HParams not configured yet. Call self.configure()")
        else:
            raise AttributeError(item)

    def configure(self, path: Union[str, Path]):
        """Configures hparams by copying over attributes from a module with the
        given path. Raises an exception if already configured."""
        if self.is_configured():
            raise RuntimeError("Cannot reconfigure hparams!")

        ###### Check for proper path ######
        if not isinstance(path, Path):
            path = Path(path).expanduser()
        if not path.exists():
            raise FileNotFoundError(f"Could not find hparams file {path}")
        elif path.suffix != ".py":
            raise ValueError("`path` must be a python file")

        ###### Load in attributes from module ######
        m = _import_from_file("hparams", path)

        reg = re.compile(r"^__.+__$")  # Matches magic methods
        for name, value in m.__dict__.items():
            if reg.match(name):
                # Skip builtins
                continue
            if name in self.__dict__:
                # Cannot overwrite already existing attributes
                raise AttributeError(
                    f"module at `path` cannot contain attribute {name} as it "
                    "overwrites an attribute of the same name in utils.hparams")
            # Fair game to copy over the attribute
            self.__setattr__(name, value)

        self._configured = True

    def is_configured(self):
        return self._configured

hparams = __HParams()


def _import_from_file(name, path: Path):
    """Programmatically returns a module object from a filepath"""
    if not Path(path).exists():
        raise FileNotFoundError('"%s" doesn\'t exist!' % path)
    spec = spec_from_file_location(name, path)
    if spec is None:
        raise ValueError('could not load module from "%s"' % path)
    m = module_from_spec(spec)
    spec.loader.exec_module(m)
    return m
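Usage of the hparams pseudo-module above is two-step: configure() copies every top-level name out of a hyperparameter file exactly once, after which those names are plain attributes. A minimal sketch, assuming an hparams.py on the path that defines sample_rate:

from utils import hparams as hp

hp.configure('hparams.py')   # one-time; a second call raises RuntimeError
print(hp.sample_rate)        # any top-level name from hparams.py is now an attribute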
models/WaveRNNModel/utils/checkpoints.py
ADDED
@@ -0,0 +1,128 @@
import torch
from utils.paths import Paths
from models.tacotron import Tacotron


def get_checkpoint_paths(checkpoint_type: str, paths: Paths):
    """
    Returns the correct checkpointing paths
    depending on whether model is Vocoder or TTS

    Args:
        checkpoint_type: Either 'voc' or 'tts'
        paths: Paths object
    """
    # Compare with `==`, not `is`: identity of string literals is an
    # implementation detail and not guaranteed
    if checkpoint_type == 'tts':
        weights_path = paths.tts_latest_weights
        optim_path = paths.tts_latest_optim
        checkpoint_path = paths.tts_checkpoints
    elif checkpoint_type == 'voc':
        weights_path = paths.voc_latest_weights
        optim_path = paths.voc_latest_optim
        checkpoint_path = paths.voc_checkpoints
    else:
        raise NotImplementedError

    return weights_path, optim_path, checkpoint_path


def save_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
                    name=None, is_silent=False):
    """Saves the training session to disk.

    Args:
        paths: Provides information about the different paths to use.
        model: A `Tacotron` or `WaveRNN` model to save the parameters and buffers from.
        optimizer: An optimizer to save the state of (momentum, etc).
        name: If provided, will save a checkpoint with the given name. Note
            that regardless of whether this is provided or not, this function
            will always update the files specified in `paths` that give the
            location of the latest weights and optimizer state. Saving
            a named checkpoint happens in addition to this update.
    """
    def helper(path_dict, is_named):
        s = 'named' if is_named else 'latest'
        num_exist = sum(p.exists() for p in path_dict.values())

        if num_exist not in (0, 2):
            # Checkpoint broken
            raise FileNotFoundError(
                f'We expected either both or no files in the {s} checkpoint to '
                'exist, but instead we got exactly one!')

        if num_exist == 0:
            if not is_silent: print(f'Creating {s} checkpoint...')
            for p in path_dict.values():
                p.parent.mkdir(parents=True, exist_ok=True)
        else:
            if not is_silent: print(f'Saving to existing {s} checkpoint...')

        if not is_silent: print(f'Saving {s} weights: {path_dict["w"]}')
        model.save(path_dict['w'])
        if not is_silent: print(f'Saving {s} optimizer state: {path_dict["o"]}')
        torch.save(optimizer.state_dict(), path_dict['o'])

    weights_path, optim_path, checkpoint_path = \
        get_checkpoint_paths(checkpoint_type, paths)

    latest_paths = {'w': weights_path, 'o': optim_path}
    helper(latest_paths, False)

    if name:
        named_paths = {
            'w': checkpoint_path/f'{name}_weights.pyt',
            'o': checkpoint_path/f'{name}_optim.pyt',
        }
        helper(named_paths, True)


def restore_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
                       name=None, create_if_missing=False):
    """Restores from a training session saved to disk.

    NOTE: The optimizer's state is placed on the same device as its model
    parameters. Therefore, be sure you have done `model.to(device)` before
    calling this method.

    Args:
        paths: Provides information about the different paths to use.
        model: A `Tacotron` or `WaveRNN` model to load the parameters and buffers into.
        optimizer: An optimizer to load the state of (momentum, etc).
        name: If provided, will restore from a checkpoint with the given name.
            Otherwise, will restore from the latest weights and optimizer state
            as specified in `paths`.
        create_if_missing: If `True`, will create the checkpoint if it doesn't
            yet exist, as well as update the files specified in `paths` that
            give the location of the current latest weights and optimizer state.
            If `False` and the checkpoint doesn't exist, will raise a
            `FileNotFoundError`.
    """

    weights_path, optim_path, checkpoint_path = \
        get_checkpoint_paths(checkpoint_type, paths)

    if name:
        path_dict = {
            'w': checkpoint_path/f'{name}_weights.pyt',
            'o': checkpoint_path/f'{name}_optim.pyt',
        }
        s = 'named'
    else:
        path_dict = {
            'w': weights_path,
            'o': optim_path
        }
        s = 'latest'

    num_exist = sum(p.exists() for p in path_dict.values())
    if num_exist == 2:
        # Checkpoint exists
        print(f'Restoring from {s} checkpoint...')
        print(f'Loading {s} weights: {path_dict["w"]}')
        model.load(path_dict['w'])
        print(f'Loading {s} optimizer state: {path_dict["o"]}')
        optimizer.load_state_dict(torch.load(path_dict['o']))
    elif create_if_missing:
        save_checkpoint(checkpoint_type, paths, model, optimizer, name=name, is_silent=False)
    else:
        raise FileNotFoundError(f'The {s} checkpoint could not be found!')
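The helper above treats a checkpoint as an atomic pair, a weights file plus an optimizer file, and exactly one of the two existing means corruption. A standalone sketch of that invariant check (hypothetical paths):

from pathlib import Path

path_dict = {'w': Path('checkpoints/latest_weights.pyt'),   # hypothetical paths
             'o': Path('checkpoints/latest_optim.pyt')}

num_exist = sum(p.exists() for p in path_dict.values())
if num_exist not in (0, 2):
    raise FileNotFoundError('Expected both or neither checkpoint file to exist')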
models/WaveRNNModel/utils/dataset.py
ADDED
@@ -0,0 +1,232 @@
import pickle
import random
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import Sampler
from utils.dsp import *
from utils import hparams as hp
from utils.text import text_to_sequence
from utils.paths import Paths
from pathlib import Path

from functools import partial

###################################################################################
# WaveRNN/Vocoder Dataset #########################################################
###################################################################################


class VocoderDataset(Dataset):
    def __init__(self, path: Path, dataset_ids, train_gta=False):
        self.metadata = dataset_ids
        self.mel_path = path/'gta' if train_gta else path/'mel'
        self.quant_path = path/'quant'

    def __getitem__(self, index):
        item_id = self.metadata[index]
        m = np.load(self.mel_path/f'{item_id}.npy')
        x = np.load(self.quant_path/f'{item_id}.npy')
        return m, x

    def __len__(self):
        return len(self.metadata)


def get_vocoder_datasets(path: Path, batch_size, train_gta):

    with open(path/'dataset.pkl', 'rb') as f:
        dataset = pickle.load(f)

    dataset_ids = [x[0] for x in dataset]

    random.seed(1234)
    random.shuffle(dataset_ids)

    test_ids = dataset_ids[-hp.voc_test_samples:]
    train_ids = dataset_ids[:-hp.voc_test_samples]

    train_dataset = VocoderDataset(path, train_ids, train_gta)
    test_dataset = VocoderDataset(path, test_ids, train_gta)

    train_set = DataLoader(train_dataset,
                           collate_fn=collate_vocoder,
                           batch_size=batch_size,
                           num_workers=2,
                           shuffle=True,
                           pin_memory=True)

    test_set = DataLoader(test_dataset,
                          batch_size=1,
                          num_workers=1,
                          shuffle=False,
                          pin_memory=True)

    return train_set, test_set


def collate_vocoder(batch):
    if not hp.is_configured():
        # DataLoader worker processes may start with unconfigured hparams;
        # reconfigure from this machine-specific hard-coded path
        print("hparams not configured")
        hp.configure("E:\\智能语音处理系统\\Noise-suppression-and-speech-recognition-systems-master\\WaveRNNModel\\hparams.py")
    mel_win = hp.voc_seq_len // hp.hop_length + 2 * hp.voc_pad
    max_offsets = [x[0].shape[-1] - 2 - (mel_win + 2 * hp.voc_pad) for x in batch]
    mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
    sig_offsets = [(offset + hp.voc_pad) * hp.hop_length for offset in mel_offsets]

    mels = [x[0][:, mel_offsets[i]:mel_offsets[i] + mel_win] for i, x in enumerate(batch)]

    labels = [x[1][sig_offsets[i]:sig_offsets[i] + hp.voc_seq_len + 1] for i, x in enumerate(batch)]

    mels = np.stack(mels).astype(np.float32)
    labels = np.stack(labels).astype(np.int64)

    mels = torch.tensor(mels)
    labels = torch.tensor(labels).long()

    x = labels[:, :hp.voc_seq_len]
    y = labels[:, 1:]

    bits = 16 if hp.voc_mode == 'MOL' else hp.bits

    x = label_2_float(x.float(), bits)

    if hp.voc_mode == 'MOL':
        y = label_2_float(y.float(), bits)

    return x, y, mels


###################################################################################
# Tacotron/TTS Dataset ############################################################
###################################################################################


def get_tts_datasets(path: Path, batch_size, r):
    with open(path/'dataset.pkl', 'rb') as f:
        dataset = pickle.load(f)

    dataset_ids = []
    mel_lengths = []
    for (item_id, length) in dataset:
        if length <= hp.tts_max_mel_len:
            dataset_ids += [item_id]
            mel_lengths += [length]

    with open(path/'text_dict.pkl', 'rb') as f:
        text_dict = pickle.load(f)

    train_dataset = TTSDataset(path, dataset_ids, text_dict)

    sampler = None

    if hp.tts_bin_lengths:
        sampler = BinnedLengthSampler(mel_lengths, batch_size, batch_size * 3)

    train_set = DataLoader(train_dataset,
                           collate_fn=partial(collate_tts, r=r),
                           batch_size=batch_size,
                           sampler=sampler,
                           num_workers=1,
                           pin_memory=True)

    longest = mel_lengths.index(max(mel_lengths))

    # Used to evaluate attention during training process
    attn_example = dataset_ids[longest]

    return train_set, attn_example


class TTSDataset(Dataset):
    def __init__(self, path: Path, dataset_ids, text_dict):
        self.path = path
        self.metadata = dataset_ids
        self.text_dict = text_dict

    def __getitem__(self, index):
        item_id = self.metadata[index]
        if not hp.is_configured():
            # Same worker-process workaround as in collate_vocoder
            print("hparams not configured")
            hp.configure("E:\\智能语音处理系统\\Noise-suppression-and-speech-recognition-systems-master\\WaveRNNModel\\hparams.py")
        x = text_to_sequence(self.text_dict[item_id], hp.tts_cleaner_names)
        mel = np.load(self.path/'mel'/f'{item_id}.npy')
        mel_len = mel.shape[-1]
        return x, mel, item_id, mel_len

    def __len__(self):
        return len(self.metadata)


def pad1d(x, max_len):
    return np.pad(x, (0, max_len - len(x)), mode='constant')


def pad2d(x, max_len):
    return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode='constant')


def collate_tts(batch, r):

    x_lens = [len(x[0]) for x in batch]
    max_x_len = max(x_lens)

    chars = [pad1d(x[0], max_x_len) for x in batch]
    chars = np.stack(chars)

    spec_lens = [x[1].shape[-1] for x in batch]
    max_spec_len = max(spec_lens) + 1
    if max_spec_len % r != 0:
        max_spec_len += r - max_spec_len % r

    mel = [pad2d(x[1], max_spec_len) for x in batch]
    mel = np.stack(mel)

    ids = [x[2] for x in batch]
    mel_lens = [x[3] for x in batch]

    chars = torch.tensor(chars).long()
    mel = torch.tensor(mel)

    # scale spectrograms to -4 <--> 4
    mel = (mel * 8.) - 4.
    return chars, mel, ids, mel_lens


class BinnedLengthSampler(Sampler):
    def __init__(self, lengths, batch_size, bin_size):
        _, self.idx = torch.sort(torch.tensor(lengths).long())
        self.batch_size = batch_size
        self.bin_size = bin_size
        assert self.bin_size % self.batch_size == 0

    def __iter__(self):
        # Need to change to numpy since there's a bug in random.shuffle(tensor)
        # TODO: Post an issue on pytorch repo
        idx = self.idx.numpy()
        bins = []

        for i in range(len(idx) // self.bin_size):
            this_bin = idx[i * self.bin_size:(i + 1) * self.bin_size]
            random.shuffle(this_bin)
            bins += [this_bin]

        random.shuffle(bins)
        binned_idx = np.stack(bins).reshape(-1)

        if len(binned_idx) < len(idx):
            last_bin = idx[len(binned_idx):]
            random.shuffle(last_bin)
            binned_idx = np.concatenate([binned_idx, last_bin])

        return iter(torch.tensor(binned_idx).long())

    def __len__(self):
        return len(self.idx)
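The offset arithmetic in collate_vocoder keeps mel frames and waveform samples aligned: a random frame offset is drawn, padded by voc_pad context frames, and multiplied by hop_length to land on the matching sample. A self-contained sketch with assumed hparams values:

hop_length, voc_pad = 275, 2            # assumed hparams values
voc_seq_len = hop_length * 5            # samples per training snippet

mel_win = voc_seq_len // hop_length + 2 * voc_pad   # frames fetched per item
mel_offset = 3                                      # np.random.randint(...) in the real code
sig_offset = (mel_offset + voc_pad) * hop_length    # matching sample offset

# the pad frames are context only; the core frames cover exactly voc_seq_len samples
assert (mel_win - 2 * voc_pad) * hop_length == voc_seq_len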
models/WaveRNNModel/utils/display.py
ADDED
@@ -0,0 +1,121 @@
import matplotlib as mpl
mpl.use('agg')  # Use non-interactive backend by default
import matplotlib.pyplot as plt
import time
import numpy as np
import sys


def progbar(i, n, size=16):
    done = (i * size) // n
    bar = ''
    for cell in range(size):
        bar += '█' if cell <= done else '░'
    return bar


def stream(message):
    sys.stdout.write(f"\r{message}")


def simple_table(item_tuples):

    border_pattern = '+---------------------------------------'
    whitespace = ' ' * 40  # pool of spaces used for cell padding

    headings, cells = [], []

    for item in item_tuples:

        heading, cell = str(item[0]), str(item[1])

        pad_head = len(heading) < len(cell)

        pad = abs(len(heading) - len(cell))
        pad = whitespace[:pad]

        pad_left = pad[:len(pad)//2]
        pad_right = pad[len(pad)//2:]

        if pad_head:
            heading = pad_left + heading + pad_right
        else:
            cell = pad_left + cell + pad_right

        headings += [heading]
        cells += [cell]

    border, head, body = '', '', ''

    for i in range(len(item_tuples)):

        temp_head = f'| {headings[i]} '
        temp_body = f'| {cells[i]} '

        border += border_pattern[:len(temp_head)]
        head += temp_head
        body += temp_body

        if i == len(item_tuples) - 1:
            head += '|'
            body += '|'
            border += '+'

    print(border)
    print(head)
    print(border)
    print(body)
    print(border)
    print(' ')


def time_since(started):
    elapsed = time.time() - started
    m = int(elapsed // 60)
    s = int(elapsed % 60)
    if m >= 60:
        h = int(m // 60)
        m = m % 60
        return f'{h}h {m}m {s}s'
    else:
        return f'{m}m {s}s'


def save_attention(attn, path):
    fig = plt.figure(figsize=(12, 6))
    plt.imshow(attn.T, interpolation='nearest', aspect='auto')
    fig.savefig(path.parent/f'{path.stem}.png', bbox_inches='tight')
    plt.close(fig)


def save_spectrogram(M, path, length=None):
    M = np.flip(M, axis=0)
    if length: M = M[:, :length]
    fig = plt.figure(figsize=(12, 6))
    plt.imshow(M, interpolation='nearest', aspect='auto')
    fig.savefig(f'{path}.png', bbox_inches='tight')
    plt.close(fig)


def plot(array):
    mpl.interactive(True)
    fig = plt.figure(figsize=(30, 5))
    ax = fig.add_subplot(111)
    ax.xaxis.label.set_color('grey')
    ax.yaxis.label.set_color('grey')
    ax.xaxis.label.set_fontsize(23)
    ax.yaxis.label.set_fontsize(23)
    ax.tick_params(axis='x', colors='grey', labelsize=23)
    ax.tick_params(axis='y', colors='grey', labelsize=23)
    plt.plot(array)
    mpl.interactive(False)


def plot_spec(M):
    mpl.interactive(True)
    M = np.flip(M, axis=0)
    plt.figure(figsize=(18, 4))
    plt.imshow(M, interpolation='nearest', aspect='auto')
    plt.show()
    mpl.interactive(False)
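progbar and stream combine into an in-place console progress bar (the \r in stream rewinds the cursor each write). A quick usage sketch, assuming the module above is importable as utils.display:

from utils.display import progbar, stream

for i in range(1, 101):
    stream(f'{progbar(i, 100)} {i}/100 ')
print()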
models/WaveRNNModel/utils/distribution.py
ADDED
@@ -0,0 +1,132 @@
import numpy as np
import torch
import torch.nn.functional as F


def log_sum_exp(x):
    """ numerically stable log_sum_exp implementation that prevents overflow """
    # TF ordering
    axis = len(x.size()) - 1
    m, _ = torch.max(x, dim=axis)
    m2, _ = torch.max(x, dim=axis, keepdim=True)
    return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))


# Adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
def discretized_mix_logistic_loss(y_hat, y, num_classes=65536,
                                  log_scale_min=None, reduce=True):
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))
    y_hat = y_hat.permute(0, 2, 1)
    assert y_hat.dim() == 3
    assert y_hat.size(1) % 3 == 0
    nr_mix = y_hat.size(1) // 3

    # (B x T x C)
    y_hat = y_hat.transpose(1, 2)

    # unpack parameters. (B, T, num_mixtures) x 3
    logit_probs = y_hat[:, :, :nr_mix]
    means = y_hat[:, :, nr_mix:2 * nr_mix]
    log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min)

    # B x T x 1 -> B x T x num_mixtures
    y = y.expand_as(means)

    centered_y = y - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
    cdf_plus = torch.sigmoid(plus_in)
    min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
    cdf_min = torch.sigmoid(min_in)

    # log probability for edge case of 0 (before scaling)
    # equivalent: torch.log(F.sigmoid(plus_in))
    log_cdf_plus = plus_in - F.softplus(plus_in)

    # log probability for edge case of 255 (before scaling)
    # equivalent: (1 - F.sigmoid(min_in)).log()
    log_one_minus_cdf_min = -F.softplus(min_in)

    # probability for all other cases
    cdf_delta = cdf_plus - cdf_min

    mid_in = inv_stdv * centered_y
    # log probability in the center of the bin, to be used in extreme cases
    # (not actually used in our code)
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)

    # tf equivalent
    """
    log_probs = tf.where(x < -0.999, log_cdf_plus,
                         tf.where(x > 0.999, log_one_minus_cdf_min,
                                  tf.where(cdf_delta > 1e-5,
                                           tf.log(tf.maximum(cdf_delta, 1e-12)),
                                           log_pdf_mid - np.log(127.5))))
    """
    # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
    # for num_classes=65536 case? 1e-7? not sure..
    inner_inner_cond = (cdf_delta > 1e-5).float()

    inner_inner_out = inner_inner_cond * \
        torch.log(torch.clamp(cdf_delta, min=1e-12)) + \
        (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
    inner_cond = (y > 0.999).float()
    inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
    cond = (y < -0.999).float()
    log_probs = cond * log_cdf_plus + (1. - cond) * inner_out

    log_probs = log_probs + F.log_softmax(logit_probs, -1)

    if reduce:
        return -torch.mean(log_sum_exp(log_probs))
    else:
        return -log_sum_exp(log_probs).unsqueeze(-1)


def sample_from_discretized_mix_logistic(y, log_scale_min=None):
    """
    Sample from discretized mixture of logistic distributions
    Args:
        y (Tensor): B x C x T
        log_scale_min (float): Log scale minimum value
    Returns:
        Tensor: sample in range of [-1, 1].
    """
    if log_scale_min is None:
        log_scale_min = float(np.log(1e-14))
    assert y.size(1) % 3 == 0
    nr_mix = y.size(1) // 3

    # B x T x C
    y = y.transpose(1, 2)
    logit_probs = y[:, :, :nr_mix]

    # sample mixture indicator from softmax (Gumbel-max trick)
    temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
    temp = logit_probs.data - torch.log(- torch.log(temp))
    _, argmax = temp.max(dim=-1)

    # (B, T) -> (B, T, nr_mix)
    one_hot = F.one_hot(argmax, nr_mix).float()
    # select logistic parameters
    means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1)
    log_scales = torch.clamp(torch.sum(
        y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1), min=log_scale_min)
    # sample from logistic & clip to interval
    # we don't actually round to the nearest 8bit value when sampling
    u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))

    x = torch.clamp(torch.clamp(x, min=-1.), max=1.)

    return x

'''
def to_one_hot(tensor, n, fill_with=1.):
    # we perform one-hot encoding with respect to the last axis
    one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
    if tensor.is_cuda:
        one_hot = one_hot.cuda()
    one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
    return one_hot'''
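A smoke test for discretized_mix_logistic_loss above: the network output carries 3 * nr_mix channels (mixture logits, means, log-scales) per timestep, and the target is a float waveform in [-1, 1] with a trailing singleton channel, as prepared in voc_train_loop. A minimal sketch with random tensors:

import torch
from utils.distribution import discretized_mix_logistic_loss

B, T, nr_mix = 2, 100, 10
y_hat = torch.randn(B, T, 3 * nr_mix)   # (B, T, C), as WaveRNN emits in MOL mode
y = torch.rand(B, T, 1) * 2 - 1         # target in [-1, 1], unsqueezed as in voc_train_loop
print(discretized_mix_logistic_loss(y_hat, y).item())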