niobures committed
Commit 2e62044 · verified · 1 Parent(s): ca4e3a2

RNNoise (models)

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +7 -0
  2. models/WaveRNNModel/.gitattributes +1 -0
  3. models/WaveRNNModel/.gitignore +48 -0
  4. models/WaveRNNModel/LICENSE.txt +21 -0
  5. models/WaveRNNModel/README.md +102 -0
  6. models/WaveRNNModel/__init__.py +13 -0
  7. models/WaveRNNModel/assets/WaveRNN.png +0 -0
  8. models/WaveRNNModel/assets/tacotron_wavernn.png +3 -0
  9. models/WaveRNNModel/assets/training_viz.gif +3 -0
  10. models/WaveRNNModel/assets/wavernn_alt_model_hrz2.png +3 -0
  11. models/WaveRNNModel/data/dataset.pkl +3 -0
  12. models/WaveRNNModel/data/text_dict.pkl +3 -0
  13. models/WaveRNNModel/gen_tacotron.py +178 -0
  14. models/WaveRNNModel/gen_wavernn.py +142 -0
  15. models/WaveRNNModel/hparams.py +101 -0
  16. models/WaveRNNModel/loss_plot.py +70 -0
  17. models/WaveRNNModel/model_outputs/ljspeech_lsa_smooth_attention.tacotron.zip +3 -0
  18. models/WaveRNNModel/model_outputs/ljspeech_mol.wavernn.zip +3 -0
  19. models/WaveRNNModel/models/__init__.py +0 -0
  20. models/WaveRNNModel/models/deepmind_version.py +176 -0
  21. models/WaveRNNModel/models/fatchord_version.py +435 -0
  22. models/WaveRNNModel/models/tacotron.py +469 -0
  23. models/WaveRNNModel/notebooks/NB1 - Fit a Sine Wave.ipynb +0 -0
  24. models/WaveRNNModel/notebooks/NB2 - Fit a Short Sample.ipynb +0 -0
  25. models/WaveRNNModel/notebooks/NB3 - Fit a 30min Sample.ipynb +0 -0
  26. models/WaveRNNModel/notebooks/NB4a - Alternative Model (Preprocessing).ipynb +0 -0
  27. models/WaveRNNModel/notebooks/NB4b - Alternative Model (Training).ipynb +0 -0
  28. models/WaveRNNModel/notebooks/Pruning - Scratchpad.ipynb +0 -0
  29. models/WaveRNNModel/notebooks/__init__.py +0 -0
  30. models/WaveRNNModel/notebooks/models/wavernn.py +172 -0
  31. models/WaveRNNModel/notebooks/outputs/nb1/model_output.wav +0 -0
  32. models/WaveRNNModel/notebooks/outputs/nb2/3k_steps.wav +3 -0
  33. models/WaveRNNModel/notebooks/outputs/nb3/12k_steps.wav +3 -0
  34. models/WaveRNNModel/notebooks/utils/__init__.py +0 -0
  35. models/WaveRNNModel/notebooks/utils/display.py +40 -0
  36. models/WaveRNNModel/notebooks/utils/dsp.py +70 -0
  37. models/WaveRNNModel/preprocess.py +103 -0
  38. models/WaveRNNModel/quick_start.py +122 -0
  39. models/WaveRNNModel/quick_start/tts_weights/latest_weights.pyt +3 -0
  40. models/WaveRNNModel/quick_start/voc_weights/latest_weights.pyt +3 -0
  41. models/WaveRNNModel/requirements.txt +6 -0
  42. models/WaveRNNModel/sentences.txt +6 -0
  43. models/WaveRNNModel/source.txt +1 -0
  44. models/WaveRNNModel/train_tacotron.py +203 -0
  45. models/WaveRNNModel/train_wavernn.py +164 -0
  46. models/WaveRNNModel/utils/__init__.py +106 -0
  47. models/WaveRNNModel/utils/checkpoints.py +128 -0
  48. models/WaveRNNModel/utils/dataset.py +232 -0
  49. models/WaveRNNModel/utils/display.py +121 -0
  50. models/WaveRNNModel/utils/distribution.py +132 -0
.gitattributes CHANGED
@@ -49,3 +49,10 @@ models/ailia-models/code/babble_15dB.wav filter=lfs diff=lfs merge=lfs -text
49
  models/ailia-models/code/denoised.wav filter=lfs diff=lfs merge=lfs -text
50
  models/rnnoise-wrapper/weights_5h_b_500k.hdf5 filter=lfs diff=lfs merge=lfs -text
51
  models/rnnoise-wrapper/weights_5h_ru_500k.hdf5 filter=lfs diff=lfs merge=lfs -text
52
+ models/WaveRNNModel/assets/tacotron_wavernn.png filter=lfs diff=lfs merge=lfs -text
53
+ models/WaveRNNModel/assets/training_viz.gif filter=lfs diff=lfs merge=lfs -text
54
+ models/WaveRNNModel/assets/wavernn_alt_model_hrz2.png filter=lfs diff=lfs merge=lfs -text
55
+ models/WaveRNNModel/notebooks/outputs/nb2/3k_steps.wav filter=lfs diff=lfs merge=lfs -text
56
+ models/WaveRNNModel/notebooks/outputs/nb3/12k_steps.wav filter=lfs diff=lfs merge=lfs -text
57
+ models/WaveRNNModel/quick_start/tts_weights/latest_weights.pyt filter=lfs diff=lfs merge=lfs -text
58
+ models/WaveRNNModel/quick_start/voc_weights/latest_weights.pyt filter=lfs diff=lfs merge=lfs -text
models/WaveRNNModel/.gitattributes ADDED
@@ -0,0 +1 @@
1
+ *.ipynb linguist-language=Python
models/WaveRNNModel/.gitignore ADDED
@@ -0,0 +1,48 @@
1
+ # IDE files
2
+ .idea
3
+ .vscode
4
+
5
+ # Mac files
6
+ .DS_Store
7
+
8
+ # Environments
9
+ .env
10
+ .venv
11
+ env/
12
+ venv/
13
+ ENV/
14
+ env.bak/
15
+ venv.bak/
16
+
17
+ # Byte-compiled / optimized / DLL files
18
+ __pycache__/
19
+ *.py[cod]
20
+ *$py.class
21
+
22
+ # Distribution / packaging
23
+ .Python
24
+ build/
25
+ develop-eggs/
26
+ dist/
27
+ downloads/
28
+ eggs/
29
+ .eggs/
30
+ lib/
31
+ lib64/
32
+ parts/
33
+ sdist/
34
+ var/
35
+ wheels/
36
+ pip-wheel-metadata/
37
+ share/python-wheels/
38
+ *.egg-info/
39
+ .installed.cfg
40
+ *.egg
41
+ MANIFEST
42
+
43
+ # Installer logs
44
+ pip-log.txt
45
+ pip-delete-this-directory.txt
46
+
47
+ # Jupyter Notebook
48
+ .ipynb_checkpoints
models/WaveRNNModel/LICENSE.txt ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2019 fatchord (https://github.com/fatchord)
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
models/WaveRNNModel/README.md ADDED
@@ -0,0 +1,102 @@
1
+ # WaveRNN
2
+
3
+ ##### (Update: Vanilla Tacotron One TTS system just implemented - more coming soon!)
4
+
5
+ ![Tacotron with WaveRNN diagrams](assets/tacotron_wavernn.png)
6
+
7
+ Pytorch implementation of Deepmind's WaveRNN model from [Efficient Neural Audio Synthesis](https://arxiv.org/abs/1802.08435v1)
8
+
9
+ # Installation
10
+
11
+ Ensure you have:
12
+
13
+ * Python >= 3.6
14
+ * [Pytorch 1 with CUDA](https://pytorch.org/)
15
+
16
+ Then install the rest with pip:
17
+
18
+ > pip install -r requirements.txt
19
+
20
+ # How to Use
21
+
22
+ ### Quick Start
23
+
24
+ If you want to use TTS functionality immediately you can simply use:
25
+
26
+ > python quick_start.py
27
+
28
+ This will generate everything in the default sentences.txt file and output to a new 'quick_start' folder where you can play back the wav files and take a look at the attention plots.
29
+
30
+ You can also use that script to generate custom TTS sentences and/or use '-u' to generate unbatched audio (better quality):
31
+
32
+ > python quick_start.py -u --input_text "What will happen if I run this command?"
33
+
34
+
35
+ ### Training your own Models
36
+ ![Attention and Mel Training GIF](assets/training_viz.gif)
37
+
38
+ Download the [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) Dataset.
39
+
40
+ Edit **hparams.py**, point **wav_path** to your dataset and run:
41
+
42
+ > python preprocess.py
43
+
44
+ or use preprocess.py --path to point directly to the dataset
45
+ ___
46
+
47
+ Here's my recommendation on what order to run things:
48
+
49
+ 1 - Train Tacotron with:
50
+
51
+ > python train_tacotron.py
52
+
53
+ 2 - You can let that finish training, or at any point you can use:
54
+
55
+ > python train_tacotron.py --force_gta
56
+
57
+ this will force Tacotron to create a GTA dataset even if it hasn't finished training.
58
+
59
+ 3 - Train WaveRNN with:
60
+
61
+ > python train_wavernn.py --gta
62
+
63
+ NB: You can always just run train_wavernn.py without --gta if you're not interested in TTS.
64
+
65
+ 4 - Generate Sentences with both models using:
66
+
67
+ > python gen_tacotron.py wavernn
68
+
69
+ this will generate the default sentences. If you want to generate custom sentences you can use:
70
+
71
+ > python gen_tacotron.py --input_text "this is whatever you want it to be" wavernn
72
+
73
+ And finally, you can always use --help on any of those scripts to see what options are available :)
74
+
75
+
76
+
77
+ # Samples
78
+
79
+ [Can be found here.](https://fatchord.github.io/model_outputs/)
80
+
81
+ # Pretrained Models
82
+
83
+ Currently there are two pretrained models available in the /pretrained/ folder:
84
+
85
+ Both are trained on LJSpeech
86
+
87
+ * WaveRNN (Mixture of Logistics output) trained to 800k steps
88
+ * Tacotron trained to 180k steps
89
+
90
+ ____
91
+
92
+ ### References
93
+
94
+ * [Efficient Neural Audio Synthesis](https://arxiv.org/abs/1802.08435v1)
95
+ * [Tacotron: Towards End-to-End Speech Synthesis](https://arxiv.org/abs/1703.10135)
96
+ * [Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions](https://arxiv.org/abs/1712.05884)
97
+
98
+ ### Acknowledgements
99
+
100
+ * [https://github.com/keithito/tacotron](https://github.com/keithito/tacotron)
101
+ * [https://github.com/r9y9/wavenet_vocoder](https://github.com/r9y9/wavenet_vocoder)
102
+ * Special thanks to github users [G-Wang](https://github.com/G-Wang), [geneing](https://github.com/geneing) & [erogol](https://github.com/erogol)
models/WaveRNNModel/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ import sys
2
+ import os
3
+ from pathlib import Path
4
+
5
+ # Get the absolute path of this package (i.e. the WaveRNN_master directory)
6
+ package_dir = Path(__file__).resolve().parent
7
+
8
+ # Add that path to sys.path so it becomes a root directory for module lookup
9
+ if str(package_dir) not in sys.path:
10
+ sys.path.insert(0, str(package_dir))
11
+
12
+ # Also set the PYTHONPATH environment variable (optional, improves compatibility)
13
+ os.environ["PYTHONPATH"] = str(package_dir) + os.pathsep + os.environ.get("PYTHONPATH", "")
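Note: the __init__.py above works purely by side effect. Importing the package prepends its own directory to sys.path (and mirrors it into PYTHONPATH), which is what lets the repo-internal absolute imports such as `from utils import hparams` resolve from outside this folder. A minimal sketch of that effect, assuming the repository root is on sys.path so the package is importable as models.WaveRNNModel:

    import sys
    import models.WaveRNNModel as wavernn_pkg   # runs the __init__.py shown above

    # package_dir is defined at module level in the __init__.py; after the import it
    # normally sits at the front of the module search path (unless it was already there).
    print(wavernn_pkg.package_dir)
    print(sys.path[0] == str(wavernn_pkg.package_dir))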
models/WaveRNNModel/assets/WaveRNN.png ADDED
models/WaveRNNModel/assets/tacotron_wavernn.png ADDED

Git LFS Details

  • SHA256: 1357e893810079e088224833c4677bdf59eb7b3a3ea52e0836ef5ab3d2c73544
  • Pointer size: 131 Bytes
  • Size of remote file: 203 kB
models/WaveRNNModel/assets/training_viz.gif ADDED

Git LFS Details

  • SHA256: 73728a59e80755d3e924b060ec8c88f4d1e91aa0469f7e18240cf869deb88aca
  • Pointer size: 132 Bytes
  • Size of remote file: 8.58 MB
models/WaveRNNModel/assets/wavernn_alt_model_hrz2.png ADDED

Git LFS Details

  • SHA256: 31b5fb4a0b8a5bce580719f699707be9f2b36d7671269b840b975750a3bec827
  • Pointer size: 131 Bytes
  • Size of remote file: 200 kB
models/WaveRNNModel/data/dataset.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f323c9bfdfcd5709ad210f538851303c62f63e4285acaee1af791d7da671d88
3
+ size 234790
models/WaveRNNModel/data/text_dict.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3a9c7752b430df10a503697b4c563a7ce82cbf732ca211dfeed5d0bffae6e5a6
3
+ size 1531658
models/WaveRNNModel/gen_tacotron.py ADDED
@@ -0,0 +1,178 @@
1
+ import torch
2
+ from models.fatchord_version import WaveRNN
3
+ from utils import hparams as hp
4
+ from utils.text.symbols import symbols
5
+ from utils.paths import Paths
6
+ from models.tacotron import Tacotron
7
+ import argparse
8
+ from utils.text import text_to_sequence
9
+ from utils.display import save_attention, simple_table
10
+ from utils.dsp import reconstruct_waveform, save_wav
11
+ import numpy as np
12
+
13
+ def gen_tacotron_from_inputtext(args_list=None):
14
+ # Parse Arguments
15
+ parser = argparse.ArgumentParser(description='TTS Generator')
16
+ parser.add_argument('--input_text', '-i', type=str, help='[string] Type in something here and TTS will generate it!')
17
+ parser.add_argument('--tts_weights', type=str, help='[string/path] Load in different Tacotron weights')
18
+ parser.add_argument('--save_attention', '-a', dest='save_attn', action='store_true', help='Save Attention Plots')
19
+ parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
20
+ parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
21
+
22
+ parser.set_defaults(input_text=None)
23
+ parser.set_defaults(weights_path=None)
24
+
25
+ # name of subcommand goes to args.vocoder
26
+ subparsers = parser.add_subparsers(required=True, dest='vocoder')
27
+
28
+ wr_parser = subparsers.add_parser('wavernn', aliases=['wr'])
29
+ wr_parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation')
30
+ wr_parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slow Unbatched Generation')
31
+ wr_parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
32
+ wr_parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index')
33
+ wr_parser.add_argument('--voc_weights', type=str, help='[string/path] Load in different WaveRNN weights')
34
+ wr_parser.set_defaults(batched=None)
35
+
36
+ gl_parser = subparsers.add_parser('griffinlim', aliases=['gl'])
37
+ gl_parser.add_argument('--iters', type=int, default=32, help='[int] number of griffinlim iterations')
38
+
39
+ args = parser.parse_args(args=args_list)
40
+
41
+ if args.vocoder in ['griffinlim', 'gl']:
42
+ args.vocoder = 'griffinlim'
43
+ elif args.vocoder in ['wavernn', 'wr']:
44
+ args.vocoder = 'wavernn'
45
+ else:
46
+ raise argparse.ArgumentError(None, 'Must provide a valid vocoder type!')
47
+
48
+ if not hp.is_configured():
49
+ print("args.hp_file:",args.hp_file)
50
+ hp.configure(args.hp_file) # Load hparams from file
51
+ # set defaults for any arguments that depend on hparams
52
+ if args.vocoder == 'wavernn':
53
+ if args.target is None:
54
+ args.target = hp.voc_target
55
+ if args.overlap is None:
56
+ args.overlap = hp.voc_overlap
57
+ if args.batched is None:
58
+ args.batched = hp.voc_gen_batched
59
+
60
+ batched = args.batched
61
+ target = args.target
62
+ overlap = args.overlap
63
+
64
+ input_text = args.input_text
65
+ tts_weights = args.tts_weights
66
+ save_attn = args.save_attn
67
+
68
+ paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
69
+
70
+ if not args.force_cpu and torch.cuda.is_available():
71
+ device = torch.device('cuda')
72
+ else:
73
+ device = torch.device('cpu')
74
+ print('Using device:', device)
75
+
76
+ if args.vocoder == 'wavernn':
77
+ print('\nInitialising WaveRNN Model...\n')
78
+ # Instantiate WaveRNN Model
79
+ voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
80
+ fc_dims=hp.voc_fc_dims,
81
+ bits=hp.bits,
82
+ pad=hp.voc_pad,
83
+ upsample_factors=hp.voc_upsample_factors,
84
+ feat_dims=hp.num_mels,
85
+ compute_dims=hp.voc_compute_dims,
86
+ res_out_dims=hp.voc_res_out_dims,
87
+ res_blocks=hp.voc_res_blocks,
88
+ hop_length=hp.hop_length,
89
+ sample_rate=hp.sample_rate,
90
+ mode=hp.voc_mode).to(device)
91
+
92
+ voc_load_path = args.voc_weights if args.voc_weights else paths.voc_latest_weights
93
+ voc_model.load(voc_load_path)
94
+
95
+ print('\nInitialising Tacotron Model...\n')
96
+
97
+ # Instantiate Tacotron Model
98
+ tts_model = Tacotron(embed_dims=hp.tts_embed_dims,
99
+ num_chars=len(symbols),
100
+ encoder_dims=hp.tts_encoder_dims,
101
+ decoder_dims=hp.tts_decoder_dims,
102
+ n_mels=hp.num_mels,
103
+ fft_bins=hp.num_mels,
104
+ postnet_dims=hp.tts_postnet_dims,
105
+ encoder_K=hp.tts_encoder_K,
106
+ lstm_dims=hp.tts_lstm_dims,
107
+ postnet_K=hp.tts_postnet_K,
108
+ num_highways=hp.tts_num_highways,
109
+ dropout=hp.tts_dropout,
110
+ stop_threshold=hp.tts_stop_threshold).to(device)
111
+
112
+ tts_load_path = tts_weights if tts_weights else paths.tts_latest_weights
113
+ tts_model.load(tts_load_path)
114
+
115
+ if input_text:
116
+ inputs = [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)]
117
+ else:
118
+ with open('sentences.txt') as f:
119
+ inputs = [text_to_sequence(l.strip(), hp.tts_cleaner_names) for l in f]
120
+
121
+ if args.vocoder == 'wavernn':
122
+ voc_k = voc_model.get_step() // 1000
123
+ tts_k = tts_model.get_step() // 1000
124
+
125
+ simple_table([('Tacotron', str(tts_k) + 'k'),
126
+ ('r', tts_model.r),
127
+ ('Vocoder Type', 'WaveRNN'),
128
+ ('WaveRNN', str(voc_k) + 'k'),
129
+ ('Generation Mode', 'Batched' if batched else 'Unbatched'),
130
+ ('Target Samples', target if batched else 'N/A'),
131
+ ('Overlap Samples', overlap if batched else 'N/A')])
132
+
133
+ elif args.vocoder == 'griffinlim':
134
+ tts_k = tts_model.get_step() // 1000
135
+ simple_table([('Tacotron', str(tts_k) + 'k'),
136
+ ('r', tts_model.r),
137
+ ('Vocoder Type', 'Griffin-Lim'),
138
+ ('GL Iters', args.iters)])
139
+
140
+ for i, x in enumerate(inputs, 1):
141
+
142
+ print(f'\n| Generating {i}/{len(inputs)}')
143
+ _, m, attention = tts_model.generate(x)
144
+ # Fix mel spectrogram scaling to be from 0 to 1
145
+ m = (m + 4) / 8
146
+ np.clip(m, 0, 1, out=m)
147
+
148
+ if args.vocoder == 'griffinlim':
149
+ v_type = args.vocoder
150
+ elif args.vocoder == 'wavernn' and args.batched:
151
+ v_type = 'wavernn_batched'
152
+ else:
153
+ v_type = 'wavernn_unbatched'
154
+
155
+ if input_text:
156
+ print("path:",paths.tts_output)
157
+ save_path = paths.tts_output/f'__input_{input_text[:10]}_{v_type}_{tts_k}k.wav'
158
+ else:
159
+ print("path:",paths.tts_output)
160
+ save_path = paths.tts_output/f'{i}_{v_type}_{tts_k}k.wav'
161
+
162
+ if save_attn: save_attention(attention, save_path)
163
+
164
+ if args.vocoder == 'wavernn':
165
+ m = torch.tensor(m).unsqueeze(0)
166
+ voc_model.generate(m, save_path, batched, hp.voc_target, hp.voc_overlap, hp.mu_law)
167
+ elif args.vocoder == 'griffinlim':
168
+ wav = reconstruct_waveform(m, n_iter=args.iters)
169
+ save_wav(wav, save_path)
170
+
171
+ print('\n\nDone.\n')
172
+ return save_path
173
+
174
+
175
+
176
+ if __name__ == "__main__":
177
+
178
+ gen_tacotron_from_inputtext()
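Because gen_tacotron_from_inputtext() forwards args_list straight to argparse, it can also be called programmatically. A hedged usage sketch (the sentence is illustrative, and trained weights plus a configured hparams.py are assumed to be in place):

    from gen_tacotron import gen_tacotron_from_inputtext

    # Equivalent to: python gen_tacotron.py --input_text "..." wavernn -u
    # Main-parser options come first, then the vocoder subcommand and its flags.
    wav_path = gen_tacotron_from_inputtext(
        ['--input_text', 'What will happen if I run this command?', 'wavernn', '--unbatched'])
    print(wav_path)   # path of the generated wav, as returned at the end of the function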
models/WaveRNNModel/gen_wavernn.py ADDED
@@ -0,0 +1,142 @@
1
+ from utils.dataset import get_vocoder_datasets
2
+ from utils.dsp import *
3
+ from models.fatchord_version import WaveRNN
4
+ from utils.paths import Paths
5
+ from utils.display import simple_table
6
+ import torch
7
+ import argparse
8
+ from pathlib import Path
9
+
10
+
11
+ def gen_testset(model: WaveRNN, test_set, samples, batched, target, overlap, save_path: Path):
12
+
13
+ k = model.get_step() // 1000
14
+
15
+ for i, (m, x) in enumerate(test_set, 1):
16
+
17
+ if i > samples: break
18
+
19
+ print('\n| Generating: %i/%i' % (i, samples))
20
+
21
+ x = x[0].numpy()
22
+
23
+ bits = 16 if hp.voc_mode == 'MOL' else hp.bits
24
+
25
+ if hp.mu_law and hp.voc_mode != 'MOL':
26
+ x = decode_mu_law(x, 2**bits, from_labels=True)
27
+ else:
28
+ x = label_2_float(x, bits)
29
+
30
+ save_wav(x, save_path/f'{k}k_steps_{i}_target.wav')
31
+
32
+ batch_str = f'gen_batched_target{target}_overlap{overlap}' if batched else 'gen_NOT_BATCHED'
33
+ save_str = str(save_path/f'{k}k_steps_{i}_{batch_str}.wav')
34
+
35
+ _ = model.generate(m, save_str, batched, target, overlap, hp.mu_law)
36
+
37
+
38
+ def gen_from_file(model: WaveRNN, load_path: Path, save_path: Path, batched, target, overlap):
39
+
40
+ k = model.get_step() // 1000
41
+ file_name = load_path.stem
42
+
43
+ suffix = load_path.suffix
44
+ if suffix == ".wav":
45
+ wav = load_wav(load_path)
46
+ save_wav(wav, save_path/f'__{file_name}__{k}k_steps_target.wav')
47
+ mel = melspectrogram(wav)
48
+ elif suffix == ".npy":
49
+ mel = np.load(load_path)
50
+ if mel.ndim != 2 or mel.shape[0] != hp.num_mels:
51
+ raise ValueError(f'Expected a numpy array shaped (n_mels, n_hops), but got {mel.shape}!')
52
+ _max = np.max(mel)
53
+ _min = np.min(mel)
54
+ if _max >= 1.01 or _min <= -0.01:
55
+ raise ValueError(f'Expected spectrogram range in [0,1] but was instead [{_min}, {_max}]')
56
+ else:
57
+ raise ValueError(f"Expected an extension of .wav or .npy, but got {suffix}!")
58
+
59
+
60
+ mel = torch.tensor(mel).unsqueeze(0)
61
+
62
+ batch_str = f'gen_batched_target{target}_overlap{overlap}' if batched else 'gen_NOT_BATCHED'
63
+ save_str = save_path/f'__{file_name}__{k}k_steps_{batch_str}.wav'
64
+
65
+ _ = model.generate(mel, save_str, batched, target, overlap, hp.mu_law)
66
+
67
+
68
+ if __name__ == "__main__":
69
+
70
+ parser = argparse.ArgumentParser(description='Generate WaveRNN Samples')
71
+ parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation')
72
+ parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slow Unbatched Generation')
73
+ parser.add_argument('--samples', '-s', type=int, help='[int] number of utterances to generate')
74
+ parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index')
75
+ parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
76
+ parser.add_argument('--file', '-f', type=str, help='[string/path] for testing a wav outside dataset')
77
+ parser.add_argument('--voc_weights', '-w', type=str, help='[string/path] Load in different WaveRNN weights')
78
+ parser.add_argument('--gta', '-g', dest='gta', action='store_true', help='Generate from GTA testset')
79
+ parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
80
+ parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
81
+
82
+ parser.set_defaults(batched=None)
83
+
84
+ args = parser.parse_args()
85
+
86
+ hp.configure(args.hp_file) # Load hparams from file
87
+ # set defaults for any arguments that depend on hparams
88
+ if args.target is None:
89
+ args.target = hp.voc_target
90
+ if args.overlap is None:
91
+ args.overlap = hp.voc_overlap
92
+ if args.batched is None:
93
+ args.batched = hp.voc_gen_batched
94
+ if args.samples is None:
95
+ args.samples = hp.voc_gen_at_checkpoint
96
+
97
+ batched = args.batched
98
+ samples = args.samples
99
+ target = args.target
100
+ overlap = args.overlap
101
+ file = args.file
102
+ gta = args.gta
103
+
104
+ if not args.force_cpu and torch.cuda.is_available():
105
+ device = torch.device('cuda')
106
+ else:
107
+ device = torch.device('cpu')
108
+ print('Using device:', device)
109
+
110
+ print('\nInitialising Model...\n')
111
+
112
+ model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
113
+ fc_dims=hp.voc_fc_dims,
114
+ bits=hp.bits,
115
+ pad=hp.voc_pad,
116
+ upsample_factors=hp.voc_upsample_factors,
117
+ feat_dims=hp.num_mels,
118
+ compute_dims=hp.voc_compute_dims,
119
+ res_out_dims=hp.voc_res_out_dims,
120
+ res_blocks=hp.voc_res_blocks,
121
+ hop_length=hp.hop_length,
122
+ sample_rate=hp.sample_rate,
123
+ mode=hp.voc_mode).to(device)
124
+
125
+ paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
126
+
127
+ voc_weights = args.voc_weights if args.voc_weights else paths.voc_latest_weights
128
+
129
+ model.load(voc_weights)
130
+
131
+ simple_table([('Generation Mode', 'Batched' if batched else 'Unbatched'),
132
+ ('Target Samples', target if batched else 'N/A'),
133
+ ('Overlap Samples', overlap if batched else 'N/A')])
134
+
135
+ if file:
136
+ file = Path(file).expanduser()
137
+ gen_from_file(model, file, paths.voc_output, batched, target, overlap)
138
+ else:
139
+ _, test_set = get_vocoder_datasets(paths.data, 1, gta)
140
+ gen_testset(model, test_set, samples, batched, target, overlap, paths.voc_output)
141
+
142
+ print('\n\nExiting...\n')
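For reference, gen_from_file() above accepts either a .wav file or a pre-computed mel spectrogram saved as .npy. A small sketch of the .npy contract it enforces (the file name is hypothetical; num_mels=80 comes from hparams.py):

    import numpy as np

    mel = np.load('example_mel.npy')                # hypothetical spectrogram file
    assert mel.ndim == 2 and mel.shape[0] == 80     # expected shape: (n_mels, n_hops)
    assert mel.min() > -0.01 and mel.max() < 1.01   # values roughly normalised to [0, 1]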
models/WaveRNNModel/hparams.py ADDED
@@ -0,0 +1,101 @@
1
+
2
+ # CONFIG -----------------------------------------------------------------------------------------------------------#
3
+
4
+ # Here are the input and output data paths (Note: you can override wav_path in preprocess.py)
5
+ wav_path = 'E:\\智能语音处理系统\\Noise-suppression-and-speech-recognition-systems-master\\WaveRNNModel\\data\\LJSpeech-1.1\\wavs'
6
+ data_path = 'E:\\智能语音处理系统\\Noise-suppression-and-speech-recognition-systems-master\\WaveRNNModel\\data'
7
+
8
+ # model ids are separate - that way you can use a new tts with an old wavernn and vice versa
9
+ # NB: expect undefined behaviour if models were trained on different DSP settings
10
+ voc_model_id = 'ljspeech_mol'
11
+ tts_model_id = 'ljspeech_lsa_smooth_attention'
12
+
13
+ # set this to True if you are only interested in WaveRNN
14
+ ignore_tts = False
15
+
16
+
17
+ # DSP --------------------------------------------------------------------------------------------------------------#
18
+
19
+ # Settings for all models
20
+ sample_rate = 22050
21
+ n_fft = 2048
22
+ fft_bins = n_fft // 2 + 1
23
+ num_mels = 80
24
+ hop_length = 275 # 12.5ms - in line with Tacotron 2 paper
25
+ win_length = 1100 # 50ms - same reason as above
26
+ fmin = 40
27
+ min_level_db = -100
28
+ ref_level_db = 20
29
+ bits = 9 # bit depth of signal
30
+ mu_law = True # Recommended to suppress noise if using raw bits in hp.voc_mode below
31
+ peak_norm = False # Normalise to the peak of each wav file
32
+
33
+
34
+ # WAVERNN / VOCODER ------------------------------------------------------------------------------------------------#
35
+
36
+
37
+ # Model Hparams
38
+ voc_mode = 'MOL' # either 'RAW' (softmax on raw bits) or 'MOL' (sample from mixture of logistics)
39
+ voc_upsample_factors = (5, 5, 11) # NB - this needs to correctly factorise hop_length
40
+ voc_rnn_dims = 512
41
+ voc_fc_dims = 512
42
+ voc_compute_dims = 128
43
+ voc_res_out_dims = 128
44
+ voc_res_blocks = 10
45
+
46
+ # Training
47
+ voc_batch_size = 32
48
+ voc_lr = 1e-4
49
+ voc_checkpoint_every = 25_000
50
+ voc_gen_at_checkpoint = 5 # number of samples to generate at each checkpoint
51
+ voc_total_steps = 1_000_000 # Total number of training steps
52
+ voc_test_samples = 50 # How many unseen samples to put aside for testing
53
+ voc_pad = 2 # this will pad the input so that the resnet can 'see' wider than input length
54
+ voc_seq_len = hop_length * 5 # must be a multiple of hop_length
55
+ voc_clip_grad_norm = 4 # set to None if no gradient clipping needed
56
+
57
+ # Generating / Synthesizing
58
+ voc_gen_batched = True # very fast (realtime+) single utterance batched generation
59
+ voc_target = 11_000 # target number of samples to be generated in each batch entry
60
+ voc_overlap = 550 # number of samples for crossfading between batches
61
+
62
+
63
+ # TACOTRON/TTS -----------------------------------------------------------------------------------------------------#
64
+
65
+
66
+ # Model Hparams
67
+ tts_embed_dims = 256 # embedding dimension for the graphemes/phoneme inputs
68
+ tts_encoder_dims = 128
69
+ tts_decoder_dims = 256
70
+ tts_postnet_dims = 128
71
+ tts_encoder_K = 16
72
+ tts_lstm_dims = 512
73
+ tts_postnet_K = 8
74
+ tts_num_highways = 4
75
+ tts_dropout = 0.5
76
+ tts_cleaner_names = ['english_cleaners']
77
+ tts_stop_threshold = -3.4 # Value below which audio generation ends.
78
+ # For example, for a range of [-4, 4], this
79
+ # will terminate the sequence at the first
80
+ # frame that has all values < -3.4
81
+
82
+ # Training
83
+
84
+ #tts_schedule = [(7, 1e-3, 10_000, 32), # progressive training schedule
85
+ # (5, 1e-4, 100_000, 32), # (r, lr, step, batch_size)
86
+ # (2, 1e-4, 180_000, 16),
87
+ # (2, 1e-4, 350_000, 8)]
88
+ tts_schedule = [(7, 1e-3, 10_000, 32)] # progressive training schedule
89
+ #(5, 1e-4, 100_000, 64), # (r, lr, step, batch_size)
90
+ #(2, 1e-4, 180_000, 64),
91
+ #(2, 1e-4, 350_000, 64)]
92
+
93
+ tts_max_mel_len = 1250 # if you have a couple of extremely long spectrograms you might want to use this
94
+ tts_bin_lengths = True # bins the spectrogram lengths before sampling in data loader - speeds up training
95
+ tts_clip_grad_norm = 1.0 # clips the gradient norm to prevent explosion - set to None if not needed
96
+ tts_checkpoint_every = 2_000 # checkpoints the model every X steps
97
+ # TODO: tts_phoneme_prob = 0.0 # [0 <-> 1] probability for feeding model phonemes vrs graphemes
98
+
99
+
100
+ # ------------------------------------------------------------------------------------------------------------------#
101
+
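As the comments above note, voc_upsample_factors has to multiply out to hop_length exactly, and voc_seq_len has to be a multiple of hop_length. A quick sanity check with the values in this file:

    hop_length = 275
    voc_upsample_factors = (5, 5, 11)
    voc_seq_len = hop_length * 5

    assert 5 * 5 * 11 == hop_length        # the upsample factors exactly factorise the hop length
    assert voc_seq_len % hop_length == 0   # 1375 samples, i.e. 5 mel frames per training sequence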
models/WaveRNNModel/loss_plot.py ADDED
@@ -0,0 +1,70 @@
1
+ import matplotlib.pyplot as plt
2
+ import re
3
+ import numpy as np
4
+
5
+ # Use a font that can render CJK characters
6
+ plt.rcParams['font.sans-serif'] = ['SimHei'] # Simplified Chinese font (adjust for your system)
7
+ plt.rcParams['axes.unicode_minus'] = False # Fix rendering of the minus sign
8
+
9
+ # Read the log data from a txt file
10
+ def parse_log_file(file_path):
11
+ epochs = []
12
+ losses = []
13
+
14
+ with open(file_path, 'r', encoding='utf-8') as f:
15
+ for line in f:
16
+ # Use a regex to match valid lines
17
+ match = re.search(
18
+ r'Epoch:\s+(\d+)/*.*Loss:\s+(\d+\.\d+)',
19
+ line.strip()
20
+ )
21
+ if match:
22
+ epoch = int(match.group(1))
23
+ loss = float(match.group(2))
24
+ epochs.append(epoch)
25
+ losses.append(loss)
26
+
27
+ return epochs, losses
28
+
29
+ # Log file path
30
+ log_file = "E:\\智能语音处理系统\\Noise-suppression-and-speech-recognition-systems-master\\WaveRNNModel\\checkpoints\\ljspeech_lsa_smooth_attention.tacotron\\log_test.txt"
31
+
32
+ # Extract the data
33
+ try:
34
+ epochs_read, losses = parse_log_file(log_file)
35
+ print(epochs_read)
36
+ epochs=np.arange(len(epochs_read))
37
+ print(epochs)
38
+ except FileNotFoundError:
39
+ print(f"错误:文件 {log_file} 不存在,请检查路径!")
40
+ exit()
41
+ except Exception as e:
42
+ print(f"解析文件时出错: {str(e)}")
43
+ exit()
44
+
45
+ # Plot the curve
46
+ plt.figure(figsize=(10, 6))
47
+ plt.plot(epochs, losses, 'b-', linewidth=2, label='Training loss')
48
+
49
+ # Style the chart
50
+ plt.title('Training loss per epoch', fontsize=14)
51
+ plt.xlabel('Epoch', fontsize=12)
52
+ plt.ylabel('Loss', fontsize=12)
53
+ #plt.xticks(range(1, len(epochs))) # Force a tick for every epoch
54
+ plt.grid(True, linestyle='--', alpha=0.7)
55
+ plt.legend()
56
+
57
+ # Annotate the minimum loss
58
+ min_loss = min(losses)
59
+ min_idx = losses.index(min_loss)
60
+ plt.annotate(
61
+ f'Minimum loss: {min_loss:.3f}',
62
+ xy=(epochs[min_idx], min_loss),
63
+ xytext=(epochs[min_idx]-3, min_loss+0.1),
64
+ arrowprops=dict(arrowstyle='->', color='red'),
65
+ fontsize=10,
66
+ color='red'
67
+ )
68
+
69
+ plt.tight_layout()
70
+ plt.show()
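The regex in parse_log_file() only needs an 'Epoch: N/...' field and a 'Loss: X.Y' field somewhere on the line. A hypothetical example of a line it would match (the real layout of log_test.txt is not shown in this commit):

    import re

    line = '| Epoch: 3/10 (250/500) | Loss: 0.8421 | 1.2 steps/s |'
    m = re.search(r'Epoch:\s+(\d+)/*.*Loss:\s+(\d+\.\d+)', line)
    print(int(m.group(1)), float(m.group(2)))   # -> 3 0.8421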
models/WaveRNNModel/model_outputs/ljspeech_lsa_smooth_attention.tacotron.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed4c5cc52ae740080b0bcea155133d430f01dcb1a2d0097ff9aaef9ee698886a
3
+ size 45040845
models/WaveRNNModel/model_outputs/ljspeech_mol.wavernn.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78a9cff91b58f6163f4cc9e9e878961829f07c7bf40e71778f0bf5447a4900fc
3
+ size 15610590
models/WaveRNNModel/models/__init__.py ADDED
File without changes
models/WaveRNNModel/models/deepmind_version.py ADDED
@@ -0,0 +1,176 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from utils.display import *
5
+ from utils.dsp import *
6
+ import numpy as np
7
+
8
+ class WaveRNN(nn.Module):
9
+ def __init__(self, hidden_size=896, quantisation=256):
10
+ super(WaveRNN, self).__init__()
11
+
12
+ self.hidden_size = hidden_size
13
+ self.split_size = hidden_size // 2
14
+
15
+ # The main matmul
16
+ self.R = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
17
+
18
+ # Output fc layers
19
+ self.O1 = nn.Linear(self.split_size, self.split_size)
20
+ self.O2 = nn.Linear(self.split_size, quantisation)
21
+ self.O3 = nn.Linear(self.split_size, self.split_size)
22
+ self.O4 = nn.Linear(self.split_size, quantisation)
23
+
24
+ # Input fc layers
25
+ self.I_coarse = nn.Linear(2, 3 * self.split_size, bias=False)
26
+ self.I_fine = nn.Linear(3, 3 * self.split_size, bias=False)
27
+
28
+ # biases for the gates
29
+ self.bias_u = nn.Parameter(torch.zeros(self.hidden_size))
30
+ self.bias_r = nn.Parameter(torch.zeros(self.hidden_size))
31
+ self.bias_e = nn.Parameter(torch.zeros(self.hidden_size))
32
+
33
+ # display num params
34
+ self.num_params()
35
+
36
+
37
+ def forward(self, prev_y, prev_hidden, current_coarse):
38
+
39
+ # Main matmul - the projection is split 3 ways
40
+ R_hidden = self.R(prev_hidden)
41
+ R_u, R_r, R_e, = torch.split(R_hidden, self.hidden_size, dim=1)
42
+
43
+ # Project the prev input
44
+ coarse_input_proj = self.I_coarse(prev_y)
45
+ I_coarse_u, I_coarse_r, I_coarse_e = \
46
+ torch.split(coarse_input_proj, self.split_size, dim=1)
47
+
48
+ # Project the prev input and current coarse sample
49
+ fine_input = torch.cat([prev_y, current_coarse], dim=1)
50
+ fine_input_proj = self.I_fine(fine_input)
51
+ I_fine_u, I_fine_r, I_fine_e = \
52
+ torch.split(fine_input_proj, self.split_size, dim=1)
53
+
54
+ # concatenate for the gates
55
+ I_u = torch.cat([I_coarse_u, I_fine_u], dim=1)
56
+ I_r = torch.cat([I_coarse_r, I_fine_r], dim=1)
57
+ I_e = torch.cat([I_coarse_e, I_fine_e], dim=1)
58
+
59
+ # Compute all gates for coarse and fine
60
+ u = F.sigmoid(R_u + I_u + self.bias_u)
61
+ r = F.sigmoid(R_r + I_r + self.bias_r)
62
+ e = F.tanh(r * R_e + I_e + self.bias_e)
63
+ hidden = u * prev_hidden + (1. - u) * e
64
+
65
+ # Split the hidden state
66
+ hidden_coarse, hidden_fine = torch.split(hidden, self.split_size, dim=1)
67
+
68
+ # Compute outputs
69
+ out_coarse = self.O2(F.relu(self.O1(hidden_coarse)))
70
+ out_fine = self.O4(F.relu(self.O3(hidden_fine)))
71
+
72
+ return out_coarse, out_fine, hidden
73
+
74
+
75
+ def generate(self, seq_len):
76
+ device = next(self.parameters()).device # use same device as parameters
77
+
78
+ with torch.no_grad():
79
+
80
+ # First split up the biases for the gates
81
+ b_coarse_u, b_fine_u = torch.split(self.bias_u, self.split_size)
82
+ b_coarse_r, b_fine_r = torch.split(self.bias_r, self.split_size)
83
+ b_coarse_e, b_fine_e = torch.split(self.bias_e, self.split_size)
84
+
85
+ # Lists for the two output seqs
86
+ c_outputs, f_outputs = [], []
87
+
88
+ # Some initial inputs
89
+ out_coarse = torch.tensor([0], dtype=torch.long, device=device)
90
+ out_fine = torch.tensor([0], dtype=torch.long, device=device)
91
+
92
+ # We'll need a hidden state
93
+ hidden = self.get_initial_hidden()
94
+
95
+ # Need a clock for display
96
+ start = time.time()
97
+
98
+ # Loop for generation
99
+ for i in range(seq_len):
100
+
101
+ # Split into two hidden states
102
+ hidden_coarse, hidden_fine = \
103
+ torch.split(hidden, self.split_size, dim=1)
104
+
105
+ # Scale and concat previous predictions
106
+ out_coarse = out_coarse.unsqueeze(0).float() / 127.5 - 1.
107
+ out_fine = out_fine.unsqueeze(0).float() / 127.5 - 1.
108
+ prev_outputs = torch.cat([out_coarse, out_fine], dim=1)
109
+
110
+ # Project input
111
+ coarse_input_proj = self.I_coarse(prev_outputs)
112
+ I_coarse_u, I_coarse_r, I_coarse_e = \
113
+ torch.split(coarse_input_proj, self.split_size, dim=1)
114
+
115
+ # Project hidden state and split 6 ways
116
+ R_hidden = self.R(hidden)
117
+ R_coarse_u , R_fine_u, \
118
+ R_coarse_r, R_fine_r, \
119
+ R_coarse_e, R_fine_e = torch.split(R_hidden, self.split_size, dim=1)
120
+
121
+ # Compute the coarse gates
122
+ u = F.sigmoid(R_coarse_u + I_coarse_u + b_coarse_u)
123
+ r = F.sigmoid(R_coarse_r + I_coarse_r + b_coarse_r)
124
+ e = F.tanh(r * R_coarse_e + I_coarse_e + b_coarse_e)
125
+ hidden_coarse = u * hidden_coarse + (1. - u) * e
126
+
127
+ # Compute the coarse output
128
+ out_coarse = self.O2(F.relu(self.O1(hidden_coarse)))
129
+ posterior = F.softmax(out_coarse, dim=1)
130
+ distrib = torch.distributions.Categorical(posterior)
131
+ out_coarse = distrib.sample()
132
+ c_outputs.append(out_coarse)
133
+
134
+ # Project the [prev outputs and predicted coarse sample]
135
+ coarse_pred = out_coarse.float() / 127.5 - 1.
136
+ fine_input = torch.cat([prev_outputs, coarse_pred.unsqueeze(0)], dim=1)
137
+ fine_input_proj = self.I_fine(fine_input)
138
+ I_fine_u, I_fine_r, I_fine_e = \
139
+ torch.split(fine_input_proj, self.split_size, dim=1)
140
+
141
+ # Compute the fine gates
142
+ u = F.sigmoid(R_fine_u + I_fine_u + b_fine_u)
143
+ r = F.sigmoid(R_fine_r + I_fine_r + b_fine_r)
144
+ e = F.tanh(r * R_fine_e + I_fine_e + b_fine_e)
145
+ hidden_fine = u * hidden_fine + (1. - u) * e
146
+
147
+ # Compute the fine output
148
+ out_fine = self.O4(F.relu(self.O3(hidden_fine)))
149
+ posterior = F.softmax(out_fine, dim=1)
150
+ distrib = torch.distributions.Categorical(posterior)
151
+ out_fine = distrib.sample()
152
+ f_outputs.append(out_fine)
153
+
154
+ # Put the hidden state back together
155
+ hidden = torch.cat([hidden_coarse, hidden_fine], dim=1)
156
+
157
+ # Display progress
158
+ speed = (i + 1) / (time.time() - start)
159
+ stream('Gen: %i/%i -- Speed: %i', (i + 1, seq_len, speed))
160
+
161
+ coarse = torch.stack(c_outputs).squeeze(1).cpu().data.numpy()
162
+ fine = torch.stack(f_outputs).squeeze(1).cpu().data.numpy()
163
+ output = combine_signal(coarse, fine)
164
+
165
+ return output, coarse, fine
166
+
167
+ def get_initial_hidden(self, batch_size=1):
168
+ device = next(self.parameters()).device # use same device as parameters
169
+ return torch.zeros(batch_size, self.hidden_size, device=device)
170
+
171
+ def num_params(self, print_out=True):
172
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
173
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
174
+ if print_out:
175
+ print('Trainable Parameters: %.3f million' % parameters)
176
+ return parameters
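In this dual-softmax variant each 16-bit sample is modelled as an 8-bit coarse and an 8-bit fine half, which is why quantisation defaults to 256 and the samples are rescaled with /127.5 - 1 above. combine_signal() lives in utils/dsp.py, which is outside the 50 files shown here; the sketch below is an assumed recombination in the spirit of the WaveRNN paper, not the repo's actual implementation:

    import numpy as np

    def combine_coarse_fine(coarse, fine):
        # Assumed: 16-bit sample = coarse * 256 + fine, rescaled to roughly [-1, 1).
        # The real combine_signal() in utils/dsp.py may differ in offset or dtype.
        signal = coarse.astype(np.int64) * 256 + fine.astype(np.int64)   # 0 .. 65535
        return signal.astype(np.float64) / 2 ** 15 - 1.0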
models/WaveRNNModel/models/fatchord_version.py ADDED
@@ -0,0 +1,435 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from utils.distribution import sample_from_discretized_mix_logistic
5
+ from utils.display import *
6
+ from utils.dsp import *
7
+ import os
8
+ import numpy as np
9
+ from pathlib import Path
10
+ from typing import Union
11
+
12
+
13
+ class ResBlock(nn.Module):
14
+ def __init__(self, dims):
15
+ super().__init__()
16
+ self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
17
+ self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
18
+ self.batch_norm1 = nn.BatchNorm1d(dims)
19
+ self.batch_norm2 = nn.BatchNorm1d(dims)
20
+
21
+ def forward(self, x):
22
+ residual = x
23
+ x = self.conv1(x)
24
+ x = self.batch_norm1(x)
25
+ x = F.relu(x)
26
+ x = self.conv2(x)
27
+ x = self.batch_norm2(x)
28
+ return x + residual
29
+
30
+
31
+ class MelResNet(nn.Module):
32
+ def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad):
33
+ super().__init__()
34
+ k_size = pad * 2 + 1
35
+ self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False)
36
+ self.batch_norm = nn.BatchNorm1d(compute_dims)
37
+ self.layers = nn.ModuleList()
38
+ for i in range(res_blocks):
39
+ self.layers.append(ResBlock(compute_dims))
40
+ self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
41
+
42
+ def forward(self, x):
43
+ x = self.conv_in(x)
44
+ x = self.batch_norm(x)
45
+ x = F.relu(x)
46
+ for f in self.layers: x = f(x)
47
+ x = self.conv_out(x)
48
+ return x
49
+
50
+
51
+ class Stretch2d(nn.Module):
52
+ def __init__(self, x_scale, y_scale):
53
+ super().__init__()
54
+ self.x_scale = x_scale
55
+ self.y_scale = y_scale
56
+
57
+ def forward(self, x):
58
+ b, c, h, w = x.size()
59
+ x = x.unsqueeze(-1).unsqueeze(3)
60
+ x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
61
+ return x.view(b, c, h * self.y_scale, w * self.x_scale)
62
+
63
+
64
+ class UpsampleNetwork(nn.Module):
65
+ def __init__(self, feat_dims, upsample_scales, compute_dims,
66
+ res_blocks, res_out_dims, pad):
67
+ super().__init__()
68
+ total_scale = np.cumproduct(upsample_scales)[-1]
69
+ self.indent = pad * total_scale
70
+ self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad)
71
+ self.resnet_stretch = Stretch2d(total_scale, 1)
72
+ self.up_layers = nn.ModuleList()
73
+ for scale in upsample_scales:
74
+ k_size = (1, scale * 2 + 1)
75
+ padding = (0, scale)
76
+ stretch = Stretch2d(scale, 1)
77
+ conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
78
+ conv.weight.data.fill_(1. / k_size[1])
79
+ self.up_layers.append(stretch)
80
+ self.up_layers.append(conv)
81
+
82
+ def forward(self, m):
83
+ aux = self.resnet(m).unsqueeze(1)
84
+ aux = self.resnet_stretch(aux)
85
+ aux = aux.squeeze(1)
86
+ m = m.unsqueeze(1)
87
+ for f in self.up_layers: m = f(m)
88
+ m = m.squeeze(1)[:, :, self.indent:-self.indent]
89
+ return m.transpose(1, 2), aux.transpose(1, 2)
90
+
91
+
92
+ class WaveRNN(nn.Module):
93
+ def __init__(self, rnn_dims, fc_dims, bits, pad, upsample_factors,
94
+ feat_dims, compute_dims, res_out_dims, res_blocks,
95
+ hop_length, sample_rate, mode='RAW'):
96
+ super().__init__()
97
+ self.mode = mode
98
+ self.pad = pad
99
+ if self.mode == 'RAW':
100
+ self.n_classes = 2 ** bits
101
+ elif self.mode == 'MOL':
102
+ self.n_classes = 30
103
+ else:
104
+ raise RuntimeError("Unknown model mode value - ", self.mode)
105
+
106
+ # List of rnns to call `flatten_parameters()` on
107
+ self._to_flatten = []
108
+
109
+ self.rnn_dims = rnn_dims
110
+ self.aux_dims = res_out_dims // 4
111
+ self.hop_length = hop_length
112
+ self.sample_rate = sample_rate
113
+
114
+ self.upsample = UpsampleNetwork(feat_dims, upsample_factors, compute_dims, res_blocks, res_out_dims, pad)
115
+ self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims)
116
+
117
+ self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
118
+ self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True)
119
+ self._to_flatten += [self.rnn1, self.rnn2]
120
+
121
+ self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
122
+ self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
123
+ self.fc3 = nn.Linear(fc_dims, self.n_classes)
124
+
125
+ self.register_buffer('step', torch.zeros(1, dtype=torch.long))
126
+ self.num_params()
127
+
128
+ # Avoid fragmentation of RNN parameters and associated warning
129
+ self._flatten_parameters()
130
+
131
+ def forward(self, x, mels):
132
+ device = next(self.parameters()).device # use same device as parameters
133
+
134
+ # Although we `_flatten_parameters()` on init, when using DataParallel
135
+ # the model gets replicated, making it no longer guaranteed that the
136
+ # weights are contiguous in GPU memory. Hence, we must call it again
137
+ self._flatten_parameters()
138
+
139
+ self.step += 1
140
+ bsize = x.size(0)
141
+ h1 = torch.zeros(1, bsize, self.rnn_dims, device=device)
142
+ h2 = torch.zeros(1, bsize, self.rnn_dims, device=device)
143
+ mels, aux = self.upsample(mels)
144
+
145
+ aux_idx = [self.aux_dims * i for i in range(5)]
146
+ a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
147
+ a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
148
+ a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
149
+ a4 = aux[:, :, aux_idx[3]:aux_idx[4]]
150
+
151
+ x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
152
+ x = self.I(x)
153
+ res = x
154
+ x, _ = self.rnn1(x, h1)
155
+
156
+ x = x + res
157
+ res = x
158
+ x = torch.cat([x, a2], dim=2)
159
+ x, _ = self.rnn2(x, h2)
160
+
161
+ x = x + res
162
+ x = torch.cat([x, a3], dim=2)
163
+ x = F.relu(self.fc1(x))
164
+
165
+ x = torch.cat([x, a4], dim=2)
166
+ x = F.relu(self.fc2(x))
167
+ return self.fc3(x)
168
+
169
+ def generate(self, mels, save_path: Union[str, Path], batched, target, overlap, mu_law):
170
+ self.eval()
171
+
172
+ device = next(self.parameters()).device # use same device as parameters
173
+
174
+ mu_law = mu_law if self.mode == 'RAW' else False
175
+
176
+ output = []
177
+ start = time.time()
178
+ rnn1 = self.get_gru_cell(self.rnn1)
179
+ rnn2 = self.get_gru_cell(self.rnn2)
180
+
181
+ with torch.no_grad():
182
+
183
+ mels = torch.as_tensor(mels, device=device)
184
+ wave_len = (mels.size(-1) - 1) * self.hop_length
185
+ mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side='both')
186
+ mels, aux = self.upsample(mels.transpose(1, 2))
187
+
188
+ if batched:
189
+ mels = self.fold_with_overlap(mels, target, overlap)
190
+ aux = self.fold_with_overlap(aux, target, overlap)
191
+
192
+ b_size, seq_len, _ = mels.size()
193
+
194
+ h1 = torch.zeros(b_size, self.rnn_dims, device=device)
195
+ h2 = torch.zeros(b_size, self.rnn_dims, device=device)
196
+ x = torch.zeros(b_size, 1, device=device)
197
+
198
+ d = self.aux_dims
199
+ aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)]
200
+
201
+ for i in range(seq_len):
202
+
203
+ m_t = mels[:, i, :]
204
+
205
+ a1_t, a2_t, a3_t, a4_t = \
206
+ (a[:, i, :] for a in aux_split)
207
+
208
+ x = torch.cat([x, m_t, a1_t], dim=1)
209
+ x = self.I(x)
210
+ h1 = rnn1(x, h1)
211
+
212
+ x = x + h1
213
+ inp = torch.cat([x, a2_t], dim=1)
214
+ h2 = rnn2(inp, h2)
215
+
216
+ x = x + h2
217
+ x = torch.cat([x, a3_t], dim=1)
218
+ x = F.relu(self.fc1(x))
219
+
220
+ x = torch.cat([x, a4_t], dim=1)
221
+ x = F.relu(self.fc2(x))
222
+
223
+ logits = self.fc3(x)
224
+
225
+ if self.mode == 'MOL':
226
+ sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
227
+ output.append(sample.view(-1))
228
+ # x = torch.FloatTensor([[sample]]).cuda()
229
+ x = sample.transpose(0, 1)
230
+
231
+ elif self.mode == 'RAW':
232
+ posterior = F.softmax(logits, dim=1)
233
+ distrib = torch.distributions.Categorical(posterior)
234
+
235
+ sample = 2 * distrib.sample().float() / (self.n_classes - 1.) - 1.
236
+ output.append(sample)
237
+ x = sample.unsqueeze(-1)
238
+ else:
239
+ raise RuntimeError("Unknown model mode value - ", self.mode)
240
+
241
+ if i % 100 == 0: self.gen_display(i, seq_len, b_size, start)
242
+
243
+ output = torch.stack(output).transpose(0, 1)
244
+ output = output.cpu().numpy()
245
+ output = output.astype(np.float64)
246
+
247
+ if mu_law:
248
+ output = decode_mu_law(output, self.n_classes, False)
249
+
250
+ if batched:
251
+ output = self.xfade_and_unfold(output, target, overlap)
252
+ else:
253
+ output = output[0]
254
+
255
+ # Fade-out at the end to avoid signal cutting out suddenly
256
+ fade_out = np.linspace(1, 0, 20 * self.hop_length)
257
+ output = output[:wave_len]
258
+ output[-20 * self.hop_length:] *= fade_out
259
+
260
+ save_wav(output, save_path)
261
+
262
+ self.train()
263
+
264
+ return output
265
+
266
+
267
+ def gen_display(self, i, seq_len, b_size, start):
268
+ gen_rate = (i + 1) / (time.time() - start) * b_size / 1000
269
+ pbar = progbar(i, seq_len)
270
+ msg = f'| {pbar} {i*b_size}/{seq_len*b_size} | Batch Size: {b_size} | Gen Rate: {gen_rate:.1f}kHz | '
271
+ stream(msg)
272
+
273
+ def get_gru_cell(self, gru):
274
+ gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
275
+ gru_cell.weight_hh.data = gru.weight_hh_l0.data
276
+ gru_cell.weight_ih.data = gru.weight_ih_l0.data
277
+ gru_cell.bias_hh.data = gru.bias_hh_l0.data
278
+ gru_cell.bias_ih.data = gru.bias_ih_l0.data
279
+ return gru_cell
280
+
281
+ def pad_tensor(self, x, pad, side='both'):
282
+ # NB - this is just a quick method i need right now
283
+ # i.e., it won't generalise to other shapes/dims
284
+ b, t, c = x.size()
285
+ total = t + 2 * pad if side == 'both' else t + pad
286
+ padded = torch.zeros(b, total, c, device=x.device)
287
+ if side == 'before' or side == 'both':
288
+ padded[:, pad:pad + t, :] = x
289
+ elif side == 'after':
290
+ padded[:, :t, :] = x
291
+ return padded
292
+
293
+ def fold_with_overlap(self, x, target, overlap):
294
+
295
+ ''' Fold the tensor with overlap for quick batched inference.
296
+ Overlap will be used for crossfading in xfade_and_unfold()
297
+
298
+ Args:
299
+ x (tensor) : Upsampled conditioning features.
300
+ shape=(1, timesteps, features)
301
+ target (int) : Target timesteps for each index of batch
302
+ overlap (int) : Timesteps for both xfade and rnn warmup
303
+
304
+ Return:
305
+ (tensor) : shape=(num_folds, target + 2 * overlap, features)
306
+
307
+ Details:
308
+ x = [[h1, h2, ... hn]]
309
+
310
+ Where each h is a vector of conditioning features
311
+
312
+ Eg: target=2, overlap=1 with x.size(1)=10
313
+
314
+ folded = [[h1, h2, h3, h4],
315
+ [h4, h5, h6, h7],
316
+ [h7, h8, h9, h10]]
317
+ '''
318
+
319
+ _, total_len, features = x.size()
320
+
321
+ # Calculate variables needed
322
+ num_folds = (total_len - overlap) // (target + overlap)
323
+ extended_len = num_folds * (overlap + target) + overlap
324
+ remaining = total_len - extended_len
325
+
326
+ # Pad if some time steps poking out
327
+ if remaining != 0:
328
+ num_folds += 1
329
+ padding = target + 2 * overlap - remaining
330
+ x = self.pad_tensor(x, padding, side='after')
331
+
332
+ folded = torch.zeros(num_folds, target + 2 * overlap, features, device=x.device)
333
+
334
+ # Get the values for the folded tensor
335
+ for i in range(num_folds):
336
+ start = i * (target + overlap)
337
+ end = start + target + 2 * overlap
338
+ folded[i] = x[:, start:end, :]
339
+
340
+ return folded
341
+
342
+ def xfade_and_unfold(self, y, target, overlap):
343
+
344
+ ''' Applies a crossfade and unfolds into a 1d array.
345
+
346
+ Args:
347
+ y (ndarray) : Batched sequences of audio samples
348
+ shape=(num_folds, target + 2 * overlap)
349
+ dtype=np.float64
350
+ overlap (int) : Timesteps for both xfade and rnn warmup
351
+
352
+ Return:
353
+ (ndarray) : audio samples in a 1d array
354
+ shape=(total_len)
355
+ dtype=np.float64
356
+
357
+ Details:
358
+ y = [[seq1],
359
+ [seq2],
360
+ [seq3]]
361
+
362
+ Apply a gain envelope at both ends of the sequences
363
+
364
+ y = [[seq1_in, seq1_target, seq1_out],
365
+ [seq2_in, seq2_target, seq2_out],
366
+ [seq3_in, seq3_target, seq3_out]]
367
+
368
+ Stagger and add up the groups of samples:
369
+
370
+ [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...]
371
+
372
+ '''
373
+
374
+ num_folds, length = y.shape
375
+ target = length - 2 * overlap
376
+ total_len = num_folds * (target + overlap) + overlap
377
+
378
+ # Need some silence for the rnn warmup
379
+ silence_len = overlap // 2
380
+ fade_len = overlap - silence_len
381
+ silence = np.zeros((silence_len), dtype=np.float64)
382
+ linear = np.ones((silence_len), dtype=np.float64)
383
+
384
+ # Equal power crossfade
385
+ t = np.linspace(-1, 1, fade_len, dtype=np.float64)
386
+ fade_in = np.sqrt(0.5 * (1 + t))
387
+ fade_out = np.sqrt(0.5 * (1 - t))
388
+
389
+ # Concat the silence to the fades
390
+ fade_in = np.concatenate([silence, fade_in])
391
+ fade_out = np.concatenate([linear, fade_out])
392
+
393
+ # Apply the gain to the overlap samples
394
+ y[:, :overlap] *= fade_in
395
+ y[:, -overlap:] *= fade_out
396
+
397
+ unfolded = np.zeros((total_len), dtype=np.float64)
398
+
399
+ # Loop to add up all the samples
400
+ for i in range(num_folds):
401
+ start = i * (target + overlap)
402
+ end = start + target + 2 * overlap
403
+ unfolded[start:end] += y[i]
404
+
405
+ return unfolded
406
+
407
+ def get_step(self):
408
+ return self.step.data.item()
409
+
410
+ def log(self, path, msg):
411
+ with open(path, 'a') as f:
412
+ print(msg, file=f)
413
+
414
+ def load(self, path: Union[str, Path]):
415
+ # Use device of model params as location for loaded state
416
+ device = next(self.parameters()).device
417
+ self.load_state_dict(torch.load(path, map_location=device), strict=False)
418
+
419
+ def save(self, path: Union[str, Path]):
420
+ # No optimizer argument because saving a model should not include data
421
+ # only relevant in the training process - it should only be properties
422
+ # of the model itself. Let caller take care of saving optimzier state.
423
+ torch.save(self.state_dict(), path)
424
+
425
+ def num_params(self, print_out=True):
426
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
427
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
428
+ if print_out:
429
+ print('Trainable Parameters: %.3fM' % parameters)
430
+ return parameters
431
+
432
+ def _flatten_parameters(self):
433
+ """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
434
+ to improve efficiency and avoid PyTorch yelling at us."""
435
+ [m.flatten_parameters() for m in self._to_flatten]
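To make the batched-generation bookkeeping concrete, here is the fold_with_overlap() arithmetic worked through with the hparams.py defaults (voc_target=11000, voc_overlap=550) for an illustrative utterance of 220,500 upsampled steps (about 10 s at 22,050 Hz):

    target, overlap = 11_000, 550
    total_len = 220_500                                      # illustrative utterance length

    num_folds = (total_len - overlap) // (target + overlap)  # 19
    extended_len = num_folds * (overlap + target) + overlap  # 220_000
    remaining = total_len - extended_len                     # 500 -> one extra padded fold
    if remaining != 0:
        num_folds += 1

    print(num_folds, target + 2 * overlap)                   # 20 folds of 12_100 steps each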
models/WaveRNNModel/models/tacotron.py ADDED
@@ -0,0 +1,469 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from pathlib import Path
7
+ from typing import Union
8
+
9
+
10
+ class HighwayNetwork(nn.Module):
11
+ def __init__(self, size):
12
+ super().__init__()
13
+ self.W1 = nn.Linear(size, size)
14
+ self.W2 = nn.Linear(size, size)
15
+ self.W1.bias.data.fill_(0.)
16
+
17
+ def forward(self, x):
18
+ x1 = self.W1(x)
19
+ x2 = self.W2(x)
20
+ g = torch.sigmoid(x2)
21
+ y = g * F.relu(x1) + (1. - g) * x
22
+ return y
23
+
24
+
25
+ class Encoder(nn.Module):
26
+ def __init__(self, embed_dims, num_chars, cbhg_channels, K, num_highways, dropout):
27
+ super().__init__()
28
+ self.embedding = nn.Embedding(num_chars, embed_dims)
29
+ self.pre_net = PreNet(embed_dims)
30
+ self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
31
+ proj_channels=[cbhg_channels, cbhg_channels],
32
+ num_highways=num_highways)
33
+
34
+ def forward(self, x):
35
+ x = self.embedding(x)
36
+ x = self.pre_net(x)
37
+ x.transpose_(1, 2)
38
+ x = self.cbhg(x)
39
+ return x
40
+
41
+
42
+ class BatchNormConv(nn.Module):
43
+ def __init__(self, in_channels, out_channels, kernel, relu=True):
44
+ super().__init__()
45
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
46
+ self.bnorm = nn.BatchNorm1d(out_channels)
47
+ self.relu = relu
48
+
49
+ def forward(self, x):
50
+ x = self.conv(x)
51
+ x = F.relu(x) if self.relu is True else x
52
+ return self.bnorm(x)
53
+
54
+
55
+ class CBHG(nn.Module):
56
+ def __init__(self, K, in_channels, channels, proj_channels, num_highways):
57
+ super().__init__()
58
+
59
+ # List of all rnns to call `flatten_parameters()` on
60
+ self._to_flatten = []
61
+
62
+ self.bank_kernels = [i for i in range(1, K + 1)]
63
+ self.conv1d_bank = nn.ModuleList()
64
+ for k in self.bank_kernels:
65
+ conv = BatchNormConv(in_channels, channels, k)
66
+ self.conv1d_bank.append(conv)
67
+
68
+ self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
69
+
70
+ self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
71
+ self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
72
+
73
+ # Fix the highway input if necessary
74
+ if proj_channels[-1] != channels:
75
+ self.highway_mismatch = True
76
+ self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
77
+ else:
78
+ self.highway_mismatch = False
79
+
80
+ self.highways = nn.ModuleList()
81
+ for i in range(num_highways):
82
+ hn = HighwayNetwork(channels)
83
+ self.highways.append(hn)
84
+
85
+ self.rnn = nn.GRU(channels, channels, batch_first=True, bidirectional=True)
86
+ self._to_flatten.append(self.rnn)
87
+
88
+ # Avoid fragmentation of RNN parameters and associated warning
89
+ self._flatten_parameters()
90
+
91
+ def forward(self, x):
92
+ # Although we `_flatten_parameters()` on init, when using DataParallel
93
+ # the model gets replicated, making it no longer guaranteed that the
94
+ # weights are contiguous in GPU memory. Hence, we must call it again
95
+ self._flatten_parameters()
96
+
97
+ # Save these for later
98
+ residual = x
99
+ seq_len = x.size(-1)
100
+ conv_bank = []
101
+
102
+ # Convolution Bank
103
+ for conv in self.conv1d_bank:
104
+ c = conv(x) # Convolution
105
+ conv_bank.append(c[:, :, :seq_len])
106
+
107
+ # Stack along the channel axis
108
+ conv_bank = torch.cat(conv_bank, dim=1)
109
+
110
+ # dump the last padding to fit residual
111
+ x = self.maxpool(conv_bank)[:, :, :seq_len]
112
+
113
+ # Conv1d projections
114
+ x = self.conv_project1(x)
115
+ x = self.conv_project2(x)
116
+
117
+ # Residual Connect
118
+ x = x + residual
119
+
120
+ # Through the highways
121
+ x = x.transpose(1, 2)
122
+ if self.highway_mismatch is True:
123
+ x = self.pre_highway(x)
124
+ for h in self.highways: x = h(x)
125
+
126
+ # And then the RNN
127
+ x, _ = self.rnn(x)
128
+ return x
129
+
130
+ def _flatten_parameters(self):
131
+ """Calls `flatten_parameters` on all the rnns used by the CBHG. Used
132
+ to improve efficiency and avoid PyTorch yelling at us."""
133
+ [m.flatten_parameters() for m in self._to_flatten]
134
+
135
+ class PreNet(nn.Module):
136
+ def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
137
+ super().__init__()
138
+ self.fc1 = nn.Linear(in_dims, fc1_dims)
139
+ self.fc2 = nn.Linear(fc1_dims, fc2_dims)
140
+ self.p = dropout
141
+
142
+ def forward(self, x):
143
+ x = self.fc1(x)
144
+ x = F.relu(x)
145
+ x = F.dropout(x, self.p, training=self.training)
146
+ x = self.fc2(x)
147
+ x = F.relu(x)
148
+ x = F.dropout(x, self.p, training=self.training)
149
+ return x
150
+
151
+
152
+ class Attention(nn.Module):
153
+ def __init__(self, attn_dims):
154
+ super().__init__()
155
+ self.W = nn.Linear(attn_dims, attn_dims, bias=False)
156
+ self.v = nn.Linear(attn_dims, 1, bias=False)
157
+
158
+ def forward(self, encoder_seq_proj, query, t):
159
+
160
+ # print(encoder_seq_proj.shape)
161
+ # Transform the query vector
162
+ query_proj = self.W(query).unsqueeze(1)
163
+
164
+ # Compute the scores
165
+ u = self.v(torch.tanh(encoder_seq_proj + query_proj))
166
+ scores = F.softmax(u, dim=1)
167
+
168
+ return scores.transpose(1, 2)
169
+
170
+
171
+ class LSA(nn.Module):
172
+ def __init__(self, attn_dim, kernel_size=31, filters=32):
173
+ super().__init__()
174
+ self.conv = nn.Conv1d(2, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=False)
175
+ self.L = nn.Linear(filters, attn_dim, bias=True)
176
+ self.W = nn.Linear(attn_dim, attn_dim, bias=True)
177
+ self.v = nn.Linear(attn_dim, 1, bias=False)
178
+ self.cumulative = None
179
+ self.attention = None
180
+
181
+ def init_attention(self, encoder_seq_proj):
182
+ device = next(self.parameters()).device # use same device as parameters
183
+ b, t, c = encoder_seq_proj.size()
184
+ self.cumulative = torch.zeros(b, t, device=device)
185
+ self.attention = torch.zeros(b, t, device=device)
186
+
187
+ def forward(self, encoder_seq_proj, query, t):
188
+
189
+ if t == 0: self.init_attention(encoder_seq_proj)
190
+
191
+ processed_query = self.W(query).unsqueeze(1)
192
+
193
+ location = torch.cat([self.cumulative.unsqueeze(1), self.attention.unsqueeze(1)], dim=1)
194
+ processed_loc = self.L(self.conv(location).transpose(1, 2))
195
+
196
+ u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
197
+ u = u.squeeze(-1)
198
+
199
+ # Smooth Attention
200
+ scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
201
+ # scores = F.softmax(u, dim=1)
202
+ self.attention = scores
203
+ self.cumulative += self.attention
204
+
205
+ return scores.unsqueeze(-1).transpose(1, 2)
206
+
207
+
208
+ class Decoder(nn.Module):
209
+ # Class variable because its value doesn't change between instances
210
+ # yet ought to be scoped by class because it's a property of a Decoder
211
+ max_r = 20
212
+ def __init__(self, n_mels, decoder_dims, lstm_dims):
213
+ super().__init__()
214
+ self.register_buffer('r', torch.tensor(1, dtype=torch.int))
215
+ self.n_mels = n_mels
216
+ self.prenet = PreNet(n_mels)
217
+ self.attn_net = LSA(decoder_dims)
218
+ self.attn_rnn = nn.GRUCell(decoder_dims + decoder_dims // 2, decoder_dims)
219
+ self.rnn_input = nn.Linear(2 * decoder_dims, lstm_dims)
220
+ self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
221
+ self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
222
+ self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
223
+
224
+ def zoneout(self, prev, current, p=0.1):
225
+ device = next(self.parameters()).device # Use same device as parameters
226
+ mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
227
+ return prev * mask + current * (1 - mask)
228
+
229
+ def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
230
+ hidden_states, cell_states, context_vec, t):
231
+
232
+ # Need this for reshaping mels
233
+ batch_size = encoder_seq.size(0)
234
+
235
+ # Unpack the hidden and cell states
236
+ attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
237
+ rnn1_cell, rnn2_cell = cell_states
238
+
239
+ # PreNet for the Attention RNN
240
+ prenet_out = self.prenet(prenet_in)
241
+
242
+ # Compute the Attention RNN hidden state
243
+ attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
244
+ attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
245
+
246
+ # Compute the attention scores
247
+ scores = self.attn_net(encoder_seq_proj, attn_hidden, t)
248
+
249
+ # Dot product to create the context vector
250
+ context_vec = scores @ encoder_seq
251
+ context_vec = context_vec.squeeze(1)
252
+
253
+ # Concat Attention RNN output w. Context Vector & project
254
+ x = torch.cat([context_vec, attn_hidden], dim=1)
255
+ x = self.rnn_input(x)
256
+
257
+ # Compute first Residual RNN
258
+ rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
259
+ if self.training:
260
+ rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
261
+ else:
262
+ rnn1_hidden = rnn1_hidden_next
263
+ x = x + rnn1_hidden
264
+
265
+ # Compute second Residual RNN
266
+ rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
267
+ if self.training:
268
+ rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
269
+ else:
270
+ rnn2_hidden = rnn2_hidden_next
271
+ x = x + rnn2_hidden
272
+
273
+ # Project Mels
274
+ mels = self.mel_proj(x)
275
+ mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
276
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
277
+ cell_states = (rnn1_cell, rnn2_cell)
278
+
279
+ return mels, scores, hidden_states, cell_states, context_vec
280
+
281
+
282
+ class Tacotron(nn.Module):
283
+ def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels, fft_bins, postnet_dims,
284
+ encoder_K, lstm_dims, postnet_K, num_highways, dropout, stop_threshold):
285
+ super().__init__()
286
+ self.n_mels = n_mels
287
+ self.lstm_dims = lstm_dims
288
+ self.decoder_dims = decoder_dims
289
+ self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
290
+ encoder_K, num_highways, dropout)
291
+ self.encoder_proj = nn.Linear(decoder_dims, decoder_dims, bias=False)
292
+ self.decoder = Decoder(n_mels, decoder_dims, lstm_dims)
293
+ self.postnet = CBHG(postnet_K, n_mels, postnet_dims, [256, 80], num_highways)
294
+ self.post_proj = nn.Linear(postnet_dims * 2, fft_bins, bias=False)
295
+
296
+ self.init_model()
297
+ self.num_params()
298
+
299
+ self.register_buffer('step', torch.zeros(1, dtype=torch.long))
300
+ self.register_buffer('stop_threshold', torch.tensor(stop_threshold, dtype=torch.float32))
301
+
302
+ @property
303
+ def r(self):
304
+ return self.decoder.r.item()
305
+
306
+ @r.setter
307
+ def r(self, value):
308
+ self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
309
+
310
+ def forward(self, x, m, generate_gta=False):
311
+ device = next(self.parameters()).device # use same device as parameters
312
+
313
+ self.step += 1
314
+
315
+ if generate_gta:
316
+ self.eval()
317
+ else:
318
+ self.train()
319
+
320
+ batch_size, _, steps = m.size()
321
+
322
+ # Initialise all hidden states and pack into tuple
323
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
324
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
325
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
326
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
327
+
328
+ # Initialise all lstm cell states and pack into tuple
329
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
330
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
331
+ cell_states = (rnn1_cell, rnn2_cell)
332
+
333
+ # <GO> Frame for start of decoder loop
334
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
335
+
336
+ # Need an initial context vector
337
+ context_vec = torch.zeros(batch_size, self.decoder_dims, device=device)
338
+
339
+ # Project the encoder outputs to avoid
340
+ # unnecessary matmuls in the decoder loop
341
+ encoder_seq = self.encoder(x)
342
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
343
+
344
+ # Need a couple of lists for outputs
345
+ mel_outputs, attn_scores = [], []
346
+
347
+ # Run the decoder loop
348
+ for t in range(0, steps, self.r):
349
+ prenet_in = m[:, :, t - 1] if t > 0 else go_frame
350
+ mel_frames, scores, hidden_states, cell_states, context_vec = \
351
+ self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
352
+ hidden_states, cell_states, context_vec, t)
353
+ mel_outputs.append(mel_frames)
354
+ attn_scores.append(scores)
355
+
356
+ # Concat the mel outputs into sequence
357
+ mel_outputs = torch.cat(mel_outputs, dim=2)
358
+
359
+ # Post-Process for Linear Spectrograms
360
+ postnet_out = self.postnet(mel_outputs)
361
+ linear = self.post_proj(postnet_out)
362
+ linear = linear.transpose(1, 2)
363
+
364
+ # For easy visualisation
365
+ attn_scores = torch.cat(attn_scores, 1)
366
+ # attn_scores = attn_scores.cpu().data.numpy()
367
+
368
+ return mel_outputs, linear, attn_scores
369
+
370
+ def generate(self, x, steps=2000):
371
+ self.eval()
372
+ device = next(self.parameters()).device # use same device as parameters
373
+
374
+ batch_size = 1
375
+ x = torch.as_tensor(x, dtype=torch.long, device=device).unsqueeze(0)
376
+
377
+ # Need to initialise all hidden states and pack into tuple for tidyness
378
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
379
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
380
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
381
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
382
+
383
+ # Need to initialise all lstm cell states and pack into tuple for tidyness
384
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
385
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
386
+ cell_states = (rnn1_cell, rnn2_cell)
387
+
388
+ # Need a <GO> Frame for start of decoder loop
389
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
390
+
391
+ # Need an initial context vector
392
+ context_vec = torch.zeros(batch_size, self.decoder_dims, device=device)
393
+
394
+ # Project the encoder outputs to avoid
395
+ # unnecessary matmuls in the decoder loop
396
+ encoder_seq = self.encoder(x)
397
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
398
+
399
+ # Need a couple of lists for outputs
400
+ mel_outputs, attn_scores = [], []
401
+
402
+ # Run the decoder loop
403
+ for t in range(0, steps, self.r):
404
+ prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
405
+ mel_frames, scores, hidden_states, cell_states, context_vec = \
406
+ self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
407
+ hidden_states, cell_states, context_vec, t)
408
+ mel_outputs.append(mel_frames)
409
+ attn_scores.append(scores)
410
+ # Stop the loop if silent frames present
411
+ if (mel_frames < self.stop_threshold).all() and t > 10: break
412
+
413
+ # Concat the mel outputs into sequence
414
+ mel_outputs = torch.cat(mel_outputs, dim=2)
415
+
416
+ # Post-Process for Linear Spectrograms
417
+ postnet_out = self.postnet(mel_outputs)
418
+ linear = self.post_proj(postnet_out)
419
+
420
+
421
+ linear = linear.transpose(1, 2)[0].cpu().data.numpy()
422
+ mel_outputs = mel_outputs[0].cpu().data.numpy()
423
+
424
+ # For easy visualisation
425
+ attn_scores = torch.cat(attn_scores, 1)
426
+ attn_scores = attn_scores.cpu().data.numpy()[0]
427
+
428
+ self.train()
429
+
430
+ return mel_outputs, linear, attn_scores
431
+
432
+ def init_model(self):
433
+ for p in self.parameters():
434
+ if p.dim() > 1: nn.init.xavier_uniform_(p)
435
+
436
+ def get_step(self):
437
+ return self.step.data.item()
438
+
439
+ def reset_step(self):
440
+ # assignment to parameters or buffers is overloaded, updates internal dict entry
441
+ self.step = self.step.data.new_tensor(1)
442
+
443
+ def log(self, path, msg):
444
+ with open(path, 'a') as f:
445
+ print(msg, file=f)
446
+
447
+ def load(self, path: Union[str, Path]):
448
+ # Use device of model params as location for loaded state
449
+ device = next(self.parameters()).device
450
+ state_dict = torch.load(path, map_location=device)
451
+
452
+ # Backwards compatibility with old saved models
453
+ if 'r' in state_dict and 'decoder.r' not in state_dict:
454
+ self.r = state_dict['r']
455
+
456
+ self.load_state_dict(state_dict, strict=False)
457
+
458
+ def save(self, path: Union[str, Path]):
459
+ # No optimizer argument because saving a model should not include data
460
+ # only relevant in the training process - it should only be properties
461
+ # of the model itself. Let caller take care of saving optimizer state.
462
+ torch.save(self.state_dict(), path)
463
+
464
+ def num_params(self, print_out=True):
465
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
466
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
467
+ if print_out:
468
+ print('Trainable Parameters: %.3fM' % parameters)
469
+ return parameters
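
One detail worth noting in the file above: the LSA module's "Smooth Attention" line normalises scores with a sigmoid divided by its sum rather than a softmax. A small self-contained comparison on a made-up score vector shows how much flatter the resulting distribution is:

    import torch
    import torch.nn.functional as F

    u = torch.tensor([[2.0, 0.5, -1.0, 0.0]])  # dummy unnormalised attention scores

    # "Smooth" normalisation used by LSA above
    smooth = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
    # Conventional softmax (the commented-out alternative in the source)
    sharp = F.softmax(u, dim=1)

    print(smooth)  # roughly [0.39, 0.27, 0.12, 0.22] -- mass spread out
    print(sharp)   # roughly [0.71, 0.16, 0.04, 0.10] -- dominated by the largest score
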
models/WaveRNNModel/notebooks/NB1 - Fit a Sine Wave.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/WaveRNNModel/notebooks/NB2 - Fit a Short Sample.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/WaveRNNModel/notebooks/NB3 - Fit a 30min Sample.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/WaveRNNModel/notebooks/NB4a - Alternative Model (Preprocessing).ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/WaveRNNModel/notebooks/NB4b - Alternative Model (Training).ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/WaveRNNModel/notebooks/Pruning - Scratchpad.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/WaveRNNModel/notebooks/__init__.py ADDED
File without changes
models/WaveRNNModel/notebooks/models/wavernn.py ADDED
@@ -0,0 +1,172 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
+ import time
+ import numpy as np
+
+ # stream() and combine_signal() used below come from the notebook utils package
+ # (notebooks/utils/display.py and notebooks/utils/dsp.py)
+ from utils.display import stream
+ from utils.dsp import combine_signal
+
5
+
6
+ class WaveRNN(nn.Module) :
7
+ def __init__(self, hidden_size=896, quantisation=256) :
8
+ super(WaveRNN, self).__init__()
9
+
10
+ self.hidden_size = hidden_size
11
+ self.split_size = hidden_size // 2
12
+
13
+ # The main matmul
14
+ self.R = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
15
+
16
+ # Output fc layers
17
+ self.O1 = nn.Linear(self.split_size, self.split_size)
18
+ self.O2 = nn.Linear(self.split_size, quantisation)
19
+ self.O3 = nn.Linear(self.split_size, self.split_size)
20
+ self.O4 = nn.Linear(self.split_size, quantisation)
21
+
22
+ # Input fc layers
23
+ self.I_coarse = nn.Linear(2, 3 * self.split_size, bias=False)
24
+ self.I_fine = nn.Linear(3, 3 * self.split_size, bias=False)
25
+
26
+ # biases for the gates
27
+ self.bias_u = nn.Parameter(torch.zeros(self.hidden_size))
28
+ self.bias_r = nn.Parameter(torch.zeros(self.hidden_size))
29
+ self.bias_e = nn.Parameter(torch.zeros(self.hidden_size))
30
+
31
+ # display num params
32
+ self.num_params()
33
+
34
+
35
+ def forward(self, prev_y, prev_hidden, current_coarse) :
36
+
37
+ # Main matmul - the projection is split 3 ways
38
+ R_hidden = self.R(prev_hidden)
39
+ R_u, R_r, R_e, = torch.split(R_hidden, self.hidden_size, dim=1)
40
+
41
+ # Project the prev input
42
+ coarse_input_proj = self.I_coarse(prev_y)
43
+ I_coarse_u, I_coarse_r, I_coarse_e = \
44
+ torch.split(coarse_input_proj, self.split_size, dim=1)
45
+
46
+ # Project the prev input and current coarse sample
47
+ fine_input = torch.cat([prev_y, current_coarse], dim=1)
48
+ fine_input_proj = self.I_fine(fine_input)
49
+ I_fine_u, I_fine_r, I_fine_e = \
50
+ torch.split(fine_input_proj, self.split_size, dim=1)
51
+
52
+ # concatenate for the gates
53
+ I_u = torch.cat([I_coarse_u, I_fine_u], dim=1)
54
+ I_r = torch.cat([I_coarse_r, I_fine_r], dim=1)
55
+ I_e = torch.cat([I_coarse_e, I_fine_e], dim=1)
56
+
57
+ # Compute all gates for coarse and fine
58
+ u = F.sigmoid(R_u + I_u + self.bias_u)
59
+ r = F.sigmoid(R_r + I_r + self.bias_r)
60
+ e = F.tanh(r * R_e + I_e + self.bias_e)
61
+ hidden = u * prev_hidden + (1. - u) * e
62
+
63
+ # Split the hidden state
64
+ hidden_coarse, hidden_fine = torch.split(hidden, self.split_size, dim=1)
65
+
66
+ # Compute outputs
67
+ out_coarse = self.O2(F.relu(self.O1(hidden_coarse)))
68
+ out_fine = self.O4(F.relu(self.O3(hidden_fine)))
69
+
70
+ return out_coarse, out_fine, hidden
71
+
72
+
73
+ def generate(self, seq_len) :
74
+
75
+ with torch.no_grad() :
76
+
77
+ # First split up the biases for the gates
78
+ b_coarse_u, b_fine_u = torch.split(self.bias_u, self.split_size)
79
+ b_coarse_r, b_fine_r = torch.split(self.bias_r, self.split_size)
80
+ b_coarse_e, b_fine_e = torch.split(self.bias_e, self.split_size)
81
+
82
+ # Lists for the two output seqs
83
+ c_outputs, f_outputs = [], []
84
+
85
+ # Some initial inputs
86
+ out_coarse = torch.LongTensor([0]).cuda()
87
+ out_fine = torch.LongTensor([0]).cuda()
88
+
89
+ # We'll need a hidden state
90
+ hidden = self.init_hidden()
91
+
92
+ # Need a clock for display
93
+ start = time.time()
94
+
95
+ # Loop for generation
96
+ for i in range(seq_len) :
97
+
98
+ # Split into two hidden states
99
+ hidden_coarse, hidden_fine = \
100
+ torch.split(hidden, self.split_size, dim=1)
101
+
102
+ # Scale and concat previous predictions
103
+ out_coarse = out_coarse.unsqueeze(0).float() / 127.5 - 1.
104
+ out_fine = out_fine.unsqueeze(0).float() / 127.5 - 1.
105
+ prev_outputs = torch.cat([out_coarse, out_fine], dim=1)
106
+
107
+ # Project input
108
+ coarse_input_proj = self.I_coarse(prev_outputs)
109
+ I_coarse_u, I_coarse_r, I_coarse_e = \
110
+ torch.split(coarse_input_proj, self.split_size, dim=1)
111
+
112
+ # Project hidden state and split 6 ways
113
+ R_hidden = self.R(hidden)
114
+ R_coarse_u , R_fine_u, \
115
+ R_coarse_r, R_fine_r, \
116
+ R_coarse_e, R_fine_e = torch.split(R_hidden, self.split_size, dim=1)
117
+
118
+ # Compute the coarse gates
119
+ u = F.sigmoid(R_coarse_u + I_coarse_u + b_coarse_u)
120
+ r = F.sigmoid(R_coarse_r + I_coarse_r + b_coarse_r)
121
+ e = F.tanh(r * R_coarse_e + I_coarse_e + b_coarse_e)
122
+ hidden_coarse = u * hidden_coarse + (1. - u) * e
123
+
124
+ # Compute the coarse output
125
+ out_coarse = self.O2(F.relu(self.O1(hidden_coarse)))
126
+ posterior = F.softmax(out_coarse, dim=1)
127
+ distrib = torch.distributions.Categorical(posterior)
128
+ out_coarse = distrib.sample()
129
+ c_outputs.append(out_coarse)
130
+
131
+ # Project the [prev outputs and predicted coarse sample]
132
+ coarse_pred = out_coarse.float() / 127.5 - 1.
133
+ fine_input = torch.cat([prev_outputs, coarse_pred.unsqueeze(0)], dim=1)
134
+ fine_input_proj = self.I_fine(fine_input)
135
+ I_fine_u, I_fine_r, I_fine_e = \
136
+ torch.split(fine_input_proj, self.split_size, dim=1)
137
+
138
+ # Compute the fine gates
139
+ u = F.sigmoid(R_fine_u + I_fine_u + b_fine_u)
140
+ r = F.sigmoid(R_fine_r + I_fine_r + b_fine_r)
141
+ e = F.tanh(r * R_fine_e + I_fine_e + b_fine_e)
142
+ hidden_fine = u * hidden_fine + (1. - u) * e
143
+
144
+ # Compute the fine output
145
+ out_fine = self.O4(F.relu(self.O3(hidden_fine)))
146
+ posterior = F.softmax(out_fine, dim=1)
147
+ distrib = torch.distributions.Categorical(posterior)
148
+ out_fine = distrib.sample()
149
+ f_outputs.append(out_fine)
150
+
151
+ # Put the hidden state back together
152
+ hidden = torch.cat([hidden_coarse, hidden_fine], dim=1)
153
+
154
+ # Display progress
155
+ speed = (i + 1) / (time.time() - start)
156
+ stream('Gen: %i/%i -- Speed: %i', (i + 1, seq_len, speed))
157
+
158
+ coarse = torch.stack(c_outputs).squeeze(1).cpu().data.numpy()
159
+ fine = torch.stack(f_outputs).squeeze(1).cpu().data.numpy()
160
+ output = combine_signal(coarse, fine)
161
+
162
+ return output, coarse, fine
163
+
164
+
165
+ def init_hidden(self, batch_size=1) :
166
+ return torch.zeros(batch_size, self.hidden_size).cuda()
167
+
168
+
169
+ def num_params(self) :
170
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
171
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
172
+ print('Trainable Parameters: %.3f million' % parameters)
models/WaveRNNModel/notebooks/outputs/nb1/model_output.wav ADDED
Binary file (80 kB). View file
 
models/WaveRNNModel/notebooks/outputs/nb2/3k_steps.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eeadf150c5e62325421d700609eec88c106ba86330a011d6b54542a8764a0728
3
+ size 220544
models/WaveRNNModel/notebooks/outputs/nb3/12k_steps.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b18a69000c2f2852d1b9f5104a7f9bc3421882c5ff2847f7185b081c9feab0b8
3
+ size 882044
models/WaveRNNModel/notebooks/utils/__init__.py ADDED
File without changes
models/WaveRNNModel/notebooks/utils/display.py ADDED
@@ -0,0 +1,40 @@
1
+ import matplotlib.pyplot as plt
2
+ import time, sys, math
3
+ import numpy as np
4
+
5
+ def stream(string, variables) :
6
+ sys.stdout.write(f'\r{string}' % variables)
7
+
8
+ def num_params(model) :
9
+ parameters = filter(lambda p: p.requires_grad, model.parameters())
10
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
11
+ print('Trainable Parameters: %.3f million' % parameters)
12
+
13
+ def time_since(started) :
14
+ elapsed = time.time() - started
15
+ m = int(elapsed // 60)
16
+ s = int(elapsed % 60)
17
+ if m >= 60 :
18
+ h = int(m // 60)
19
+ m = m % 60
20
+ return f'{h}h {m}m {s}s'
21
+ else :
22
+ return f'{m}m {s}s'
23
+
24
+ def plot(array) :
25
+ fig = plt.figure(figsize=(30, 5))
26
+ ax = fig.add_subplot(111)
27
+ ax.xaxis.label.set_color('grey')
28
+ ax.yaxis.label.set_color('grey')
29
+ ax.xaxis.label.set_fontsize(23)
30
+ ax.yaxis.label.set_fontsize(23)
31
+ ax.tick_params(axis='x', colors='grey', labelsize=23)
32
+ ax.tick_params(axis='y', colors='grey', labelsize=23)
33
+ plt.plot(array)
34
+
35
+ def plot_spec(M) :
36
+ M = np.flip(M, axis=0)
37
+ plt.figure(figsize=(18,4))
38
+ plt.imshow(M, interpolation='nearest', aspect='auto')
39
+ plt.show()
40
+
models/WaveRNNModel/notebooks/utils/dsp.py ADDED
@@ -0,0 +1,70 @@
1
+ import numpy as np
2
+ import librosa, math
3
+
4
+ sample_rate = 22050
5
+ n_fft = 2048
6
+ fft_bins = n_fft // 2 + 1
7
+ num_mels = 80
8
+ hop_length = int(sample_rate * 0.0125) # 12.5ms
9
+ win_length = int(sample_rate * 0.05) # 50ms
10
+ fmin = 40
11
+ min_level_db = -100
12
+ ref_level_db = 20
13
+
14
+ def load_wav(filename, encode=True) :
15
+ x = librosa.load(filename, sr=sample_rate)[0]
16
+ if encode == True : x = encode_16bits(x)
17
+ return x
18
+
19
+ def save_wav(y, filename) :
20
+ if y.dtype != 'int16' :
21
+ y = encode_16bits(y)
22
+ librosa.output.write_wav(filename, y.astype(np.int16), sample_rate)
23
+
24
+ def split_signal(x) :
25
+ unsigned = x + 2**15
26
+ coarse = unsigned // 256
27
+ fine = unsigned % 256
28
+ return coarse, fine
29
+
30
+ def combine_signal(coarse, fine) :
31
+ return coarse * 256 + fine - 2**15
32
+
33
+ def encode_16bits(x) :
34
+ return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16)
35
+
36
+ mel_basis = None
37
+
38
+ def linear_to_mel(spectrogram):
39
+ global mel_basis
40
+ if mel_basis is None:
41
+ mel_basis = build_mel_basis()
42
+ return np.dot(mel_basis, spectrogram)
43
+
44
+ def build_mel_basis():
45
+ return librosa.filters.mel(sample_rate, n_fft, n_mels=num_mels, fmin=fmin)
46
+
47
+ def normalize(S):
48
+ return np.clip((S - min_level_db) / -min_level_db, 0, 1)
49
+
50
+ def denormalize(S):
51
+ return (np.clip(S, 0, 1) * -min_level_db) + min_level_db
52
+
53
+ def amp_to_db(x):
54
+ return 20 * np.log10(np.maximum(1e-5, x))
55
+
56
+ def db_to_amp(x):
57
+ return np.power(10.0, x * 0.05)
58
+
59
+ def spectrogram(y):
60
+ D = stft(y)
61
+ S = amp_to_db(np.abs(D)) - ref_level_db
62
+ return normalize(S)
63
+
64
+ def melspectrogram(y):
65
+ D = stft(y)
66
+ S = amp_to_db(linear_to_mel(np.abs(D)))
67
+ return normalize(S)
68
+
69
+ def stft(y):
70
+ return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
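
split_signal and combine_signal above implement the coarse/fine factorisation consumed by the dual-softmax WaveRNN in notebooks/models/wavernn.py: a signed 16-bit sample is shifted to unsigned and split into its high and low bytes. A quick round-trip check with a few hand-picked values:

    import numpy as np

    x = np.array([-32768, -1, 0, 255, 32767], dtype=np.int64)

    unsigned = x + 2 ** 15        # shift into [0, 65535]
    coarse = unsigned // 256      # high byte, each value in [0, 255]
    fine = unsigned % 256         # low byte, each value in [0, 255]

    recovered = coarse * 256 + fine - 2 ** 15
    assert np.array_equal(recovered, x)
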
models/WaveRNNModel/preprocess.py ADDED
@@ -0,0 +1,103 @@
1
+ import glob
2
+ from utils.display import *
3
+ from utils.dsp import *
4
+ from utils import hparams as hp
5
+ from multiprocessing import Pool, cpu_count
6
+ from utils.paths import Paths
7
+ import pickle
8
+ import argparse
9
+ from utils.text.recipes import ljspeech
10
+ from utils.files import get_files
11
+ from pathlib import Path
12
+
13
+
14
+ # Helper functions for argument types
15
+ def valid_n_workers(num):
16
+ n = int(num)
17
+ if n < 1:
18
+ raise argparse.ArgumentTypeError('%r must be an integer greater than 0' % num)
19
+ return n
20
+
21
+ parser = argparse.ArgumentParser(description='Preprocessing for WaveRNN and Tacotron')
22
+ parser.add_argument('--path', '-p', help='directly point to dataset path (overrides hparams.wav_path)')
23
+ parser.add_argument('--extension', '-e', metavar='EXT', default='.wav', help='file extension to search for in dataset folder')
24
+ parser.add_argument('--num_workers', '-w', metavar='N', type=valid_n_workers, default=cpu_count()-1, help='The number of worker threads to use for preprocessing')
25
+ parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
26
+
27
+ args = parser.parse_args()
28
+ hp.configure(args.hp_file) # Load hparams from file
29
+ if args.path is None:
30
+ args.path = hp.wav_path
31
+
32
+ extension = args.extension
33
+ path = args.path
34
+
35
+ wav_files = get_files(path, extension)
36
+ paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
37
+
38
+ print(f'\n{len(wav_files)} {extension[1:]} files found in "{path}"\n')
39
+
40
+
41
+
42
+ def convert_file(path: Path):
43
+ y = load_wav(path)
44
+ peak = np.abs(y).max()
45
+ if hp.peak_norm or peak > 1.0:
46
+ y /= peak
47
+ mel = melspectrogram(y)
48
+ if hp.voc_mode == 'RAW':
49
+ quant = encode_mu_law(y, mu=2**hp.bits) if hp.mu_law else float_2_label(y, bits=hp.bits)
50
+ elif hp.voc_mode == 'MOL':
51
+ quant = float_2_label(y, bits=16)
52
+
53
+ return mel.astype(np.float32), quant.astype(np.int64)
54
+
55
+
56
+ def process_wav(path: Path):
57
+ wav_id = path.stem
58
+ m, x = convert_file(path)
59
+ #print("paths.mel:::",paths.mel)
60
+ np.save(paths.mel/f'{wav_id}.npy', m, allow_pickle=False)
61
+ np.save(paths.quant/f'{wav_id}.npy', x, allow_pickle=False)
62
+ return wav_id, m.shape[-1]
63
+
64
+ if __name__ == '__main__':
65
+
66
+
67
+ if len(wav_files) == 0:
68
+
69
+ print('Please point wav_path in hparams.py to your dataset,')
70
+ print('or use the --path option.\n')
71
+
72
+ else:
73
+
74
+ if not hp.ignore_tts:
75
+
76
+ text_dict = ljspeech(path)
77
+
78
+ with open(paths.data/'text_dict.pkl', 'wb') as f:
79
+ pickle.dump(text_dict, f)
80
+
81
+ n_workers = max(1, args.num_workers)
82
+
83
+ simple_table([
84
+ ('Sample Rate', hp.sample_rate),
85
+ ('Bit Depth', hp.bits),
86
+ ('Mu Law', hp.mu_law),
87
+ ('Hop Length', hp.hop_length),
88
+ ('CPU Usage', f'{n_workers}/{cpu_count()}')
89
+ ])
90
+
91
+ pool = Pool(processes=n_workers)
92
+ dataset = []
93
94
+ for i, (item_id, length) in enumerate(pool.imap_unordered(process_wav, wav_files), 1):
95
+ dataset += [(item_id, length)]
96
+ bar = progbar(i, len(wav_files))
97
+ message = f'{bar} {i}/{len(wav_files)} '
98
+ stream(message)
99
+
100
+ with open(paths.data/'dataset.pkl', 'wb') as f:
101
+ pickle.dump(dataset, f)
102
+
103
+ print('\n\nCompleted. Ready to run "python train_tacotron.py" or "python train_wavernn.py". \n')
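
In RAW mode, convert_file() above quantises the waveform with encode_mu_law(y, mu=2**hp.bits). That helper lives in utils/dsp.py, which is not included in this diff, so the sketch below shows only the textbook mu-law companding it is based on, not necessarily the repo's exact implementation:

    import numpy as np

    def mu_law_encode(x, mu=256):
        # x is float audio in [-1, 1]; returns integers in [0, mu - 1]
        mu = mu - 1
        fx = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
        return np.floor((fx + 1) / 2 * mu + 0.5).astype(np.int64)

    def mu_law_decode(y, mu=256):
        mu = mu - 1
        fx = 2 * (y / mu) - 1
        return np.sign(fx) / mu * ((1 + mu) ** np.abs(fx) - 1)

    x = np.linspace(-1.0, 1.0, 5)
    q = mu_law_encode(x)        # integers in [0, 255] for mu=256
    x_hat = mu_law_decode(q)    # close to x, with finer resolution near zero
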
models/WaveRNNModel/quick_start.py ADDED
@@ -0,0 +1,122 @@
1
+ import torch
2
+ from models.fatchord_version import WaveRNN
3
+ from utils import hparams as hp
4
+ from utils.text.symbols import symbols
5
+ from models.tacotron import Tacotron
6
+ import argparse
7
+ from utils.text import text_to_sequence
8
+ from utils.display import save_attention, simple_table
9
+ import zipfile, os
10
+
11
+
12
+ os.makedirs('quick_start/tts_weights/', exist_ok=True)
13
+ os.makedirs('quick_start/voc_weights/', exist_ok=True)
14
+
15
+ zip_ref = zipfile.ZipFile('pretrained/ljspeech.wavernn.mol.800k.zip', 'r')
16
+ zip_ref.extractall('quick_start/voc_weights/')
17
+ zip_ref.close()
18
+
19
+ zip_ref = zipfile.ZipFile('pretrained/ljspeech.tacotron.r2.180k.zip', 'r')
20
+ zip_ref.extractall('quick_start/tts_weights/')
21
+ zip_ref.close()
22
+
23
+
24
+ if __name__ == "__main__":
25
+
26
+ # Parse Arguments
27
+ parser = argparse.ArgumentParser(description='TTS Generator')
28
+ parser.add_argument('--input_text', '-i', type=str, help='[string] Type in something here and TTS will generate it!')
29
+ parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation (lower quality)')
30
+ parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slower Unbatched Generation (better quality)')
31
+ parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
32
+ parser.add_argument('--hp_file', metavar='FILE', default='hparams.py',
33
+ help='The file to use for the hyperparameters')
+
+ # Defaults must be registered before parse_args() for them to take effect
+ parser.set_defaults(batched=True)
+ parser.set_defaults(input_text=None)
+ args = parser.parse_args()
+
+ hp.configure(args.hp_file) # Load hparams from file
40
+
41
+ batched = args.batched
42
+ input_text = args.input_text
43
+
44
+ if not args.force_cpu and torch.cuda.is_available():
45
+ device = torch.device('cuda')
46
+ else:
47
+ device = torch.device('cpu')
48
+ print('Using device:', device)
49
+
50
+ print('\nInitialising WaveRNN Model...\n')
51
+
52
+ # Instantiate WaveRNN Model
53
+ voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
54
+ fc_dims=hp.voc_fc_dims,
55
+ bits=hp.bits,
56
+ pad=hp.voc_pad,
57
+ upsample_factors=hp.voc_upsample_factors,
58
+ feat_dims=hp.num_mels,
59
+ compute_dims=hp.voc_compute_dims,
60
+ res_out_dims=hp.voc_res_out_dims,
61
+ res_blocks=hp.voc_res_blocks,
62
+ hop_length=hp.hop_length,
63
+ sample_rate=hp.sample_rate,
64
+ mode='MOL').to(device)
65
+
66
+ voc_model.load('quick_start/voc_weights/latest_weights.pyt')
67
+
68
+ print('\nInitialising Tacotron Model...\n')
69
+
70
+ # Instantiate Tacotron Model
71
+ tts_model = Tacotron(embed_dims=hp.tts_embed_dims,
72
+ num_chars=len(symbols),
73
+ encoder_dims=hp.tts_encoder_dims,
74
+ decoder_dims=hp.tts_decoder_dims,
75
+ n_mels=hp.num_mels,
76
+ fft_bins=hp.num_mels,
77
+ postnet_dims=hp.tts_postnet_dims,
78
+ encoder_K=hp.tts_encoder_K,
79
+ lstm_dims=hp.tts_lstm_dims,
80
+ postnet_K=hp.tts_postnet_K,
81
+ num_highways=hp.tts_num_highways,
82
+ dropout=hp.tts_dropout,
83
+ stop_threshold=hp.tts_stop_threshold).to(device)
84
+
85
+
86
+ tts_model.load('quick_start/tts_weights/latest_weights.pyt')
87
+
88
+ if input_text:
89
+ inputs = [text_to_sequence(input_text.strip(), hp.tts_cleaner_names)]
90
+ else:
91
+ with open('sentences.txt') as f:
92
+ inputs = [text_to_sequence(l.strip(), hp.tts_cleaner_names) for l in f]
93
+
94
+ voc_k = voc_model.get_step() // 1000
95
+ tts_k = tts_model.get_step() // 1000
96
+
97
+ r = tts_model.r
98
+
99
+ simple_table([('WaveRNN', str(voc_k) + 'k'),
100
+ (f'Tacotron(r={r})', str(tts_k) + 'k'),
101
+ ('Generation Mode', 'Batched' if batched else 'Unbatched'),
102
+ ('Target Samples', 11_000 if batched else 'N/A'),
103
+ ('Overlap Samples', 550 if batched else 'N/A')])
104
+
105
+ for i, x in enumerate(inputs, 1):
106
+
107
+ print(f'\n| Generating {i}/{len(inputs)}')
108
+ _, m, attention = tts_model.generate(x)
109
+
110
+ if input_text:
111
+ save_path = f'quick_start/__input_{input_text[:10]}_{tts_k}k.wav'
112
+ else:
113
+ save_path = f'quick_start/{i}_batched{str(batched)}_{tts_k}k.wav'
114
+
115
+ # save_attention(attention, save_path)
116
+
117
+ m = torch.tensor(m).unsqueeze(0)
118
+ m = (m + 4) / 8
119
+
120
+ voc_model.generate(m, save_path, batched, 11_000, 550, hp.mu_law)
121
+
122
+ print('\n\nDone.\n')
models/WaveRNNModel/quick_start/tts_weights/latest_weights.pyt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9c4c491d6ad43f1b9b9e1f393c7d8437592da4b26412838a7bc20f446b76d2f0
3
+ size 44433225
models/WaveRNNModel/quick_start/voc_weights/latest_weights.pyt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99e2a93453a8d531ba851b2e5c488105816f7c60ec50153c56436ff9bda8a26a
3
+ size 16985706
models/WaveRNNModel/requirements.txt ADDED
@@ -0,0 +1,6 @@
1
+ numpy==1.22.0
2
+ librosa==0.6.3
3
+ matplotlib
4
+ unidecode
5
+ inflect
6
+ nltk
models/WaveRNNModel/sentences.txt ADDED
@@ -0,0 +1,6 @@
1
+ However, no attempt has been made yet to formulate a relativistic generalisation of action-angle variables for geodesic motion in Kerr spacetime and to calculate the dynamical frequencies of arbitrary bound non-plunging orbits.
2
+
3
+ The investigation of bound geodesic orbits in Kerr spacetime presented in this article clearly illustrates that the properties of these orbits in the regime of strong gravity are profoundly different from Keplerian orbits in the Newtonian regime.
4
+
5
+ The observation of as few as ten EMRIs can provide a measurement of the slope of the black-hole mass function to better precision than is currently known.
6
+
models/WaveRNNModel/source.txt ADDED
@@ -0,0 +1 @@
1
+ https://github.com/Adwardgyhjs/RNNoiseAndWaveRNN
models/WaveRNNModel/train_tacotron.py ADDED
@@ -0,0 +1,203 @@
1
+ import torch
2
+ from torch import optim
3
+ import torch.nn.functional as F
4
+ from utils import hparams as hp
5
+ from utils.display import *
6
+ from utils.dataset import get_tts_datasets
7
+ from utils.text.symbols import symbols
8
+ from utils.paths import Paths
9
+ from models.tacotron import Tacotron
10
+ import argparse
11
+ from utils import data_parallel_workaround
12
+ import os
13
+ from pathlib import Path
14
+ import time
15
+ import numpy as np
16
+ import sys
17
+ from utils.checkpoints import save_checkpoint, restore_checkpoint
18
+
19
+
20
+ def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
21
+
22
+
23
+ def main():
24
+ # Parse Arguments
25
+ parser = argparse.ArgumentParser(description='Train Tacotron TTS')
26
+ parser.add_argument('--force_train', '-f', action='store_true', help='Forces the model to train past total steps')
27
+ parser.add_argument('--force_gta', '-g', action='store_true', help='Force the model to create GTA features')
28
+ parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
29
+ parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
30
+ args = parser.parse_args()
31
+
32
+ hp.configure(args.hp_file) # Load hparams from file
33
+ paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
34
+
35
+ force_train = args.force_train
36
+ force_gta = args.force_gta
37
+
38
+ if not args.force_cpu and torch.cuda.is_available():
39
+ device = torch.device('cuda')
40
+ for session in hp.tts_schedule:
41
+ _, _, _, batch_size = session
42
+ if batch_size % torch.cuda.device_count() != 0:
43
+ raise ValueError('`batch_size` must be evenly divisible by n_gpus!')
44
+ else:
45
+ device = torch.device('cpu')
46
+ print('Using device:', device)
47
+
48
+ # Instantiate Tacotron Model
49
+ print('\nInitialising Tacotron Model...\n')
50
+ model = Tacotron(embed_dims=hp.tts_embed_dims,
51
+ num_chars=len(symbols),
52
+ encoder_dims=hp.tts_encoder_dims,
53
+ decoder_dims=hp.tts_decoder_dims,
54
+ n_mels=hp.num_mels,
55
+ fft_bins=hp.num_mels,
56
+ postnet_dims=hp.tts_postnet_dims,
57
+ encoder_K=hp.tts_encoder_K,
58
+ lstm_dims=hp.tts_lstm_dims,
59
+ postnet_K=hp.tts_postnet_K,
60
+ num_highways=hp.tts_num_highways,
61
+ dropout=hp.tts_dropout,
62
+ stop_threshold=hp.tts_stop_threshold).to(device)
63
+
64
+ optimizer = optim.Adam(model.parameters())
65
+ restore_checkpoint('tts', paths, model, optimizer, create_if_missing=True)
66
+
67
+ if not force_gta:
68
+ for i, session in enumerate(hp.tts_schedule):
69
+ current_step = model.get_step()
70
+
71
+ r, lr, max_step, batch_size = session
72
+
73
+ training_steps = max_step - current_step
74
+
75
+ # Do we need to change to the next session?
76
+ if current_step >= max_step:
77
+ # Are there no further sessions than the current one?
78
+ if i == len(hp.tts_schedule)-1:
79
+ # There are no more sessions. Check if we force training.
80
+ if force_train:
81
+ # Don't finish the loop - train forever
82
+ training_steps = 999_999_999
83
+ else:
84
+ # We have completed training. Breaking here is equivalent to continuing.
85
+ break
86
+ else:
87
+ # There is a following session, go to it
88
+ continue
89
+
90
+ model.r = r
91
+
92
+ simple_table([(f'Steps with r={r}', str(training_steps//1000) + 'k Steps'),
93
+ ('Batch Size', batch_size),
94
+ ('Learning Rate', lr),
95
+ ('Outputs/Step (r)', model.r)])
96
+
97
+ train_set, attn_example = get_tts_datasets(paths.data, batch_size, r)
98
+ tts_train_loop(paths, model, optimizer, train_set, lr, training_steps, attn_example)
99
+
100
+ print('Training Complete.')
101
+ print('To continue training increase tts_total_steps in hparams.py or use --force_train\n')
102
+
103
+
104
+ print('Creating Ground Truth Aligned Dataset...\n')
105
+
106
+ train_set, attn_example = get_tts_datasets(paths.data, 8, model.r)
107
+ create_gta_features(model, train_set, paths.gta)
108
+
109
+ print('\n\nYou can now train WaveRNN on GTA features - use python train_wavernn.py --gta\n')
110
+
111
+
112
+ def tts_train_loop(paths: Paths, model: Tacotron, optimizer, train_set, lr, train_steps, attn_example):
113
+ device = next(model.parameters()).device # use same device as model parameters
114
+
115
+ for g in optimizer.param_groups: g['lr'] = lr
116
+
117
+ total_iters = len(train_set)
118
+ #print("train set",total_iters)
119
+ epochs = train_steps // total_iters + 1
120
+
121
+ for e in range(1, epochs+1):
122
+
123
+ start = time.time()
124
+ running_loss = 0
125
+
126
+ # Perform 1 epoch
127
+ for i, (x, m, ids, _) in enumerate(train_set, 1):
128
+
129
+ x, m = x.to(device), m.to(device)
130
+ #print("test33333333")
131
+ # Parallelize model onto GPUS using workaround due to python bug
132
+ if device.type == 'cuda' and torch.cuda.device_count() > 1:
133
+ m1_hat, m2_hat, attention = data_parallel_workaround(model, x, m)
134
+ else:
135
+ m1_hat, m2_hat, attention = model(x, m)
136
+
137
+ m1_loss = F.l1_loss(m1_hat, m)
138
+ m2_loss = F.l1_loss(m2_hat, m)
139
+
140
+ loss = m1_loss + m2_loss
141
+
142
+ optimizer.zero_grad()
143
+ loss.backward()
144
+ if hp.tts_clip_grad_norm is not None:
145
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.tts_clip_grad_norm)
146
+ if torch.isnan(grad_norm):
147
+ print('grad_norm was NaN!')
148
+
149
+ optimizer.step()
150
+ #print("test4444444")
151
+ running_loss += loss.item()
152
+ avg_loss = running_loss / i
153
+
154
+ speed = i / (time.time() - start)
155
+
156
+ step = model.get_step()
157
+ k = step // 1000
158
+
159
+ if step % hp.tts_checkpoint_every == 0:
160
+ ckpt_name = f'taco_step{k}K'
161
+ save_checkpoint('tts', paths, model, optimizer,
162
+ name=ckpt_name, is_silent=True)
163
+
164
+ if attn_example in ids:
165
+ idx = ids.index(attn_example)
166
+ save_attention(np_now(attention[idx][:, :160]), paths.tts_attention/f'{step}')
167
+ save_spectrogram(np_now(m2_hat[idx]), paths.tts_mel_plot/f'{step}', 600)
168
+
169
+ msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:#.4} | {speed:#.2} steps/s | Step: {k}k | '
170
+ stream(msg)
171
+
172
+ # Must save latest optimizer state to ensure that resuming training
173
+ # doesn't produce artifacts
174
+ save_checkpoint('tts', paths, model, optimizer, is_silent=True)
175
+ model.log(paths.tts_log, msg)
176
+ print(' ')
177
+
178
+
179
+ def create_gta_features(model: Tacotron, train_set, save_path: Path):
180
+ device = next(model.parameters()).device # use same device as model parameters
181
+
182
+ iters = len(train_set)
183
+
184
+ for i, (x, mels, ids, mel_lens) in enumerate(train_set, 1):
185
+
186
+ x, mels = x.to(device), mels.to(device)
187
+
188
+ with torch.no_grad(): _, gta, _ = model(x, mels)
189
+
190
+ gta = gta.cpu().numpy()
191
+
192
+ for j, item_id in enumerate(ids):
193
+ mel = gta[j][:, :mel_lens[j]]
194
+ mel = (mel + 4) / 8
195
+ np.save(save_path/f'{item_id}.npy', mel, allow_pickle=False)
196
+
197
+ bar = progbar(i, iters)
198
+ msg = f'{bar} {i}/{iters} Batches '
199
+ stream(msg)
200
+
201
+
202
+ if __name__ == "__main__":
203
+ main()
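
The session loop in main() above unpacks each hp.tts_schedule entry as (r, learning_rate, max_step, batch_size), annealing the reduction factor r (mel frames emitted per decoder step) downwards over training. The schedule below is purely illustrative; the real one lives in hparams.py, which is not part of this diff:

    # Hypothetical schedule in the format the training loop expects
    tts_schedule = [
        (7, 1e-3,  10_000, 32),   # start coarse: 7 mel frames per decoder step
        (5, 1e-4, 100_000, 32),
        (2, 1e-4, 180_000, 16),   # finish with a fine-grained r=2
    ]

    for session in tts_schedule:
        r, lr, max_step, batch_size = session
        print(f'train until step {max_step}: r={r}, lr={lr}, batch={batch_size}')
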
models/WaveRNNModel/train_wavernn.py ADDED
@@ -0,0 +1,164 @@
1
+ import time
2
+ import numpy as np
3
+ import torch
4
+ from torch import optim
5
+ import torch.nn.functional as F
6
+ from utils.display import stream, simple_table
7
+ from utils.dataset import get_vocoder_datasets
8
+ from utils.distribution import discretized_mix_logistic_loss
9
+ from utils import hparams as hp
10
+ from models.fatchord_version import WaveRNN
11
+ from gen_wavernn import gen_testset
12
+ from utils.paths import Paths
13
+ import argparse
14
+ from utils import data_parallel_workaround
15
+ from utils.checkpoints import save_checkpoint, restore_checkpoint
16
+
17
+
18
+ def main():
19
+
20
+ # Parse Arguments
21
+ parser = argparse.ArgumentParser(description='Train WaveRNN Vocoder')
22
+ parser.add_argument('--lr', '-l', type=float, help='[float] override hparams.py learning rate')
23
+ parser.add_argument('--batch_size', '-b', type=int, help='[int] override hparams.py batch size')
24
+ parser.add_argument('--force_train', '-f', action='store_true', help='Forces the model to train past total steps')
25
+ parser.add_argument('--gta', '-g', action='store_true', help='train wavernn on GTA features')
26
+ parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
27
+ parser.add_argument('--hp_file', metavar='FILE', default='hparams.py', help='The file to use for the hyperparameters')
28
+ args = parser.parse_args()
29
+
30
+ hp.configure(args.hp_file) # load hparams from file
31
+ if args.lr is None:
32
+ args.lr = hp.voc_lr
33
+ if args.batch_size is None:
34
+ args.batch_size = hp.voc_batch_size
35
+
36
+ paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
37
+
38
+ batch_size = args.batch_size
39
+ force_train = args.force_train
40
+ train_gta = args.gta
41
+ lr = args.lr
42
+
43
+ if not args.force_cpu and torch.cuda.is_available():
44
+ device = torch.device('cuda')
45
+ if batch_size % torch.cuda.device_count() != 0:
46
+ raise ValueError('`batch_size` must be evenly divisible by n_gpus!')
47
+ else:
48
+ device = torch.device('cpu')
49
+ print('Using device:', device)
50
+
51
+ print('\nInitialising Model...\n')
52
+
53
+ # Instantiate WaveRNN Model
54
+ voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
55
+ fc_dims=hp.voc_fc_dims,
56
+ bits=hp.bits,
57
+ pad=hp.voc_pad,
58
+ upsample_factors=hp.voc_upsample_factors,
59
+ feat_dims=hp.num_mels,
60
+ compute_dims=hp.voc_compute_dims,
61
+ res_out_dims=hp.voc_res_out_dims,
62
+ res_blocks=hp.voc_res_blocks,
63
+ hop_length=hp.hop_length,
64
+ sample_rate=hp.sample_rate,
65
+ mode=hp.voc_mode).to(device)
66
+
67
+ # Check to make sure the hop length is correctly factorised
68
+ assert np.cumprod(hp.voc_upsample_factors)[-1] == hp.hop_length
69
+
70
+ optimizer = optim.Adam(voc_model.parameters())
71
+ restore_checkpoint('voc', paths, voc_model, optimizer, create_if_missing=True)
72
+
73
+ train_set, test_set = get_vocoder_datasets(paths.data, batch_size, train_gta)
74
+
75
+ total_steps = 10_000_000 if force_train else hp.voc_total_steps
76
+
77
+ simple_table([('Remaining', str((total_steps - voc_model.get_step())//1000) + 'k Steps'),
78
+ ('Batch Size', batch_size),
79
+ ('LR', lr),
80
+ ('Sequence Len', hp.voc_seq_len),
81
+ ('GTA Train', train_gta)])
82
+
83
+ loss_func = F.cross_entropy if voc_model.mode == 'RAW' else discretized_mix_logistic_loss
84
+ #print("test5555555555")
85
+ voc_train_loop(paths, voc_model, loss_func, optimizer, train_set, test_set, lr, total_steps)
86
+
87
+ print('Training Complete.')
88
+ print('To continue training increase voc_total_steps in hparams.py or use --force_train')
89
+
90
+
91
+ def voc_train_loop(paths: Paths, model: WaveRNN, loss_func, optimizer, train_set, test_set, lr, total_steps):
92
+ # Use same device as model parameters
93
+ device = next(model.parameters()).device
94
+
95
+ for g in optimizer.param_groups: g['lr'] = lr
96
+
97
+ total_iters = len(train_set)
98
99
+ epochs = (total_steps - model.get_step()) // total_iters + 1
100
+
101
+ for e in range(1, epochs + 1):
102
+
103
+ start = time.time()
104
+ running_loss = 0.
105
+ #print("test666666666")
106
+ for i, (x, y, m) in enumerate(train_set, 1):
107
+ #print("test44444444444")
108
+ x, m, y = x.to(device), m.to(device), y.to(device)
109
+
110
+ # Parallelize model onto GPUS using workaround due to python bug
111
+ if device.type == 'cuda' and torch.cuda.device_count() > 1:
112
+ y_hat = data_parallel_workaround(model, x, m)
113
+ else:
114
+ y_hat = model(x, m)
115
+
116
+ if model.mode == 'RAW':
117
+ y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
118
+
119
+ elif model.mode == 'MOL':
120
+ y = y.float()
121
+
122
+ y = y.unsqueeze(-1)
123
+
124
+
125
+ loss = loss_func(y_hat, y)
126
+
127
+ optimizer.zero_grad()
128
+ loss.backward()
129
+ #print("test111111111111111111")
130
+ if hp.voc_clip_grad_norm is not None:
131
+ #print("test333333333333")
132
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hp.voc_clip_grad_norm)
133
+ if torch.isnan(grad_norm):
134
+ print('grad_norm was NaN!')
135
+ optimizer.step()
136
+
137
+ running_loss += loss.item()
138
+ avg_loss = running_loss / i
139
+
140
+ speed = i / (time.time() - start)
141
+
142
+ step = model.get_step()
143
+ k = step // 1000
144
+
145
+ if step % hp.voc_checkpoint_every == 0:
146
+ #print("test22222222222222222")
147
+ gen_testset(model, test_set, hp.voc_gen_at_checkpoint, hp.voc_gen_batched,
148
+ hp.voc_target, hp.voc_overlap, paths.voc_output)
149
+ ckpt_name = f'wave_step{k}K'
150
+ save_checkpoint('voc', paths, model, optimizer,
151
+ name=ckpt_name, is_silent=True)
152
+
153
+ msg = f'| Epoch: {e}/{epochs} ({i}/{total_iters}) | Loss: {avg_loss:.4f} | {speed:.1f} steps/s | Step: {k}k | '
154
+ stream(msg)
155
+
156
+ # Must save latest optimizer state to ensure that resuming training
157
+ # doesn't produce artifacts
158
+ save_checkpoint('voc', paths, model, optimizer, is_silent=True)
159
+ model.log(paths.voc_log, msg)
160
+ print(' ')
161
+
162
+
163
+ if __name__ == "__main__":
164
+ main()
models/WaveRNNModel/utils/__init__.py ADDED
@@ -0,0 +1,106 @@
1
+ # Make it explicit that we do it the Python 3 way
2
+ from __future__ import absolute_import, division, print_function, unicode_literals
3
+ from builtins import *
4
+ import sys
5
+ import torch
6
+ import re
7
+
8
+ from importlib.util import spec_from_file_location, module_from_spec
9
+ from pathlib import Path
10
+ from typing import Union
11
+
12
+ # Credit: Ryuichi Yamamoto (https://github.com/r9y9/wavenet_vocoder/blob/1717f145c8f8c0f3f85ccdf346b5209fa2e1c920/train.py#L599)
13
+ # Modified by: Ryan Butler (https://github.com/TheButlah)
14
+ # workaround for https://github.com/pytorch/pytorch/issues/15716
15
+ # the idea is to return outputs and replicas explicitly, so that making pytorch
16
+ # not to release the nodes (this is a pytorch bug though)
17
+
18
+ _output_ref = None
19
+ _replicas_ref = None
20
+
21
+ def data_parallel_workaround(model, *input):
22
+ global _output_ref
23
+ global _replicas_ref
24
+ device_ids = list(range(torch.cuda.device_count()))
25
+ output_device = device_ids[0]
26
+ replicas = torch.nn.parallel.replicate(model, device_ids)
27
+ # input.shape = (num_args, batch, ...)
28
+ inputs = torch.nn.parallel.scatter(input, device_ids)
29
+ # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
30
+ replicas = replicas[:len(inputs)]
31
+ outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
32
+ y_hat = torch.nn.parallel.gather(outputs, output_device)
33
+ _output_ref = outputs
34
+ _replicas_ref = replicas
35
+ return y_hat
36
+
37
+
38
+ ###### Deal with hparams import that has to be configured at runtime ######
39
+ class __HParams:
40
+ """Manages the hyperparams pseudo-module"""
41
+ def __init__(self, path: Union[str, Path]=None):
42
+ """Constructs the hyperparameters from a path to a python module. If
43
+ `path` is None, will raise an AttributeError whenever its attributes
44
+ are accessed. Otherwise, configures self based on `path`."""
45
+ if path is None:
46
+ print("path is none")
47
+ self._configured = False
48
+ else:
49
+ self.configure(path)
50
+
51
+ def __getattr__(self, item):
52
+ print("self config2222:",self.is_configured())
53
+ if not self.is_configured():
54
+ raise AttributeError("HParams not configured yet. Call self.configure()")
55
+ else:
56
+ return super().__getattr__(item)
57
+
58
+ def configure(self, path: Union[str, Path]):
59
+ """Configures hparams by copying over attributes from a module with the
60
+ given path. Raises an exception if already configured."""
61
+ if self.is_configured():
62
+ raise RuntimeError("Cannot reconfigure hparams!")
63
+ print("path=",path)
64
+ ###### Check for proper path ######
65
+ if not isinstance(path, Path):
66
+ path = Path(path).expanduser()
67
+ if not path.exists():
68
+ raise FileNotFoundError(f"Could not find hparams file {path}")
69
+ elif path.suffix != ".py":
70
+ raise ValueError("`path` must be a python file")
71
+
72
+ ###### Load in attributes from module ######
73
+ m = _import_from_file("hparams", path)
74
+
75
+ reg = re.compile(r"^__.+__$") # Matches magic methods
76
+ for name, value in m.__dict__.items():
77
+ if reg.match(name):
78
+ # Skip builtins
79
+ continue
80
+ if name in self.__dict__:
81
+ # Cannot overwrite already existing attributes
82
+ raise AttributeError(
83
+ f"module at `path` cannot contain attribute {name} as it "
84
+ "overwrites an attribute of the same name in utils.hparams")
85
+ # Fair game to copy over the attribute
86
+ self.__setattr__(name, value)
87
+
88
+ self._configured = True
89
+ print("self config1111:",self._configured)
90
+
91
+ def is_configured(self):
92
+ return self._configured
93
+
94
+ hparams = __HParams()
95
+
96
+
97
+ def _import_from_file(name, path: Path):
98
+ """Programmatically returns a module object from a filepath"""
99
+ if not Path(path).exists():
100
+ raise FileNotFoundError('"%s" doesn\'t exist!' % path)
101
+ spec = spec_from_file_location(name, path)
102
+ if spec is None:
103
+ raise ValueError('could not load module from "%s"' % path)
104
+ m = module_from_spec(spec)
105
+ spec.loader.exec_module(m)
106
+ return m
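
The __HParams class above makes utils.hparams behave like a configure-once pseudo-module: every script in this diff calls hp.configure(args.hp_file) before touching any attribute. A minimal usage sketch (the attribute names are whatever the chosen hparams.py defines):

    from utils import hparams as hp

    hp.configure('hparams.py')   # copies module-level names from hparams.py onto hp
    print(hp.sample_rate)        # any non-dunder attribute defined in hparams.py

    # Reconfiguring raises RuntimeError("Cannot reconfigure hparams!"), and reading
    # an attribute before configure() has been called raises AttributeError.
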
models/WaveRNNModel/utils/checkpoints.py ADDED
@@ -0,0 +1,128 @@
1
+ import torch
2
+ from utils.paths import Paths
3
+ from models.tacotron import Tacotron
4
+
5
+
6
+ def get_checkpoint_paths(checkpoint_type: str, paths: Paths):
7
+ """
8
+ Returns the correct checkpointing paths
9
+ depending on whether the model is the vocoder ('voc') or TTS ('tts') model
10
+
11
+ Args:
12
+ checkpoint_type: Either 'voc' or 'tts'
13
+ paths: Paths object
14
+ """
15
+ if checkpoint_type == 'tts':
16
+ weights_path = paths.tts_latest_weights
17
+ optim_path = paths.tts_latest_optim
18
+ checkpoint_path = paths.tts_checkpoints
19
+ elif checkpoint_type == 'voc':
20
+ weights_path = paths.voc_latest_weights
21
+ optim_path = paths.voc_latest_optim
22
+ checkpoint_path = paths.voc_checkpoints
23
+ else:
24
+ raise NotImplementedError
25
+
26
+ return weights_path, optim_path, checkpoint_path
27
+
28
+
29
+ def save_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
30
+ name=None, is_silent=False):
31
+ """Saves the training session to disk.
32
+
33
+ Args:
34
+ paths: Provides information about the different paths to use.
35
+ model: A `Tacotron` or `WaveRNN` model to save the parameters and buffers from.
36
+ optimizer: An optimizer to save the state of (momentum, etc.).
37
+ name: If provided, will also save to a checkpoint with the given name. Note
38
+ that regardless of whether this is provided or not, this function
39
+ will always update the files specified in `paths` that give the
40
+ location of the latest weights and optimizer state. Saving
41
+ a named checkpoint happens in addition to this update.
42
+ """
43
+ def helper(path_dict, is_named):
44
+ s = 'named' if is_named else 'latest'
45
+ num_exist = sum(p.exists() for p in path_dict.values())
46
+
47
+ if num_exist not in (0,2):
48
+ # Checkpoint broken
49
+ raise FileNotFoundError(
50
+ f'We expected either both or no files in the {s} checkpoint to '
51
+ 'exist, but instead we got exactly one!')
52
+
53
+ if num_exist == 0:
54
+ if not is_silent: print(f'Creating {s} checkpoint...')
55
+ for p in path_dict.values():
56
+ p.parent.mkdir(parents=True, exist_ok=True)
57
+ else:
58
+ if not is_silent: print(f'Saving to existing {s} checkpoint...')
59
+
60
+ if not is_silent: print(f'Saving {s} weights: {path_dict["w"]}')
61
+ model.save(path_dict['w'])
62
+ if not is_silent: print(f'Saving {s} optimizer state: {path_dict["o"]}')
63
+ torch.save(optimizer.state_dict(), path_dict['o'])
64
+
65
+ weights_path, optim_path, checkpoint_path = \
66
+ get_checkpoint_paths(checkpoint_type, paths)
67
+
68
+ latest_paths = {'w': weights_path, 'o': optim_path}
69
+ helper(latest_paths, False)
70
+
71
+ if name:
72
+ named_paths = {
73
+ 'w': checkpoint_path/f'{name}_weights.pyt',
74
+ 'o': checkpoint_path/f'{name}_optim.pyt',
75
+ }
76
+ helper(named_paths, True)
77
+
78
+
79
+ def restore_checkpoint(checkpoint_type: str, paths: Paths, model, optimizer, *,
80
+ name=None, create_if_missing=False):
81
+ """Restores from a training session saved to disk.
82
+
83
+ NOTE: The optimizer's state is placed on the same device as its model
84
+ parameters. Therefore, be sure you have done `model.to(device)` before
85
+ calling this method.
86
+
87
+ Args:
88
+ paths: Provides information about the different paths to use.
89
+ model: A `Tacotron` or `WaveRNN` model to load the parameters and buffers into.
90
+ optimizer: An optimizer to load the saved state (momentum, etc.) into.
91
+ name: If provided, will restore from a checkpoint with the given name.
92
+ Otherwise, will restore from the latest weights and optimizer state
93
+ as specified in `paths`.
94
+ create_if_missing: If `True`, will create the checkpoint if it doesn't
95
+ yet exist, as well as update the files specified in `paths` that
96
+ give the location of the current latest weights and optimizer state.
97
+ If `False` and the checkpoint doesn't exist, will raise a
98
+ `FileNotFoundError`.
99
+ """
100
+
101
+ weights_path, optim_path, checkpoint_path = \
102
+ get_checkpoint_paths(checkpoint_type, paths)
103
+
104
+ if name:
105
+ path_dict = {
106
+ 'w': checkpoint_path/f'{name}_weights.pyt',
107
+ 'o': checkpoint_path/f'{name}_optim.pyt',
108
+ }
109
+ s = 'named'
110
+ else:
111
+ path_dict = {
112
+ 'w': weights_path,
113
+ 'o': optim_path
114
+ }
115
+ s = 'latest'
116
+
117
+ num_exist = sum(p.exists() for p in path_dict.values())
118
+ if num_exist == 2:
119
+ # Checkpoint exists
120
+ print(f'Restoring from {s} checkpoint...')
121
+ print(f'Loading {s} weights: {path_dict["w"]}')
122
+ model.load(path_dict['w'])
123
+ print(f'Loading {s} optimizer state: {path_dict["o"]}')
124
+ optimizer.load_state_dict(torch.load(path_dict['o']))
125
+ elif create_if_missing:
126
+ save_checkpoint(checkpoint_type, paths, model, optimizer, name=name, is_silent=False)
127
+ else:
128
+ raise FileNotFoundError(f'The {s} checkpoint could not be found!')
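A sketch of how these helpers are meant to be wired together in a training script, mirroring the pattern used by train_wavernn.py in this repo (the WaveRNN constructor keywords and hparam names are assumptions based on that script; adjust if your hparams.py differs):

```python
import torch
from utils import hparams as hp
from utils.paths import Paths
from utils.checkpoints import restore_checkpoint, save_checkpoint
from models.fatchord_version import WaveRNN

hp.configure('hparams.py')
paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move the model to its device *before* restoring so the optimizer state ends up
# on the same device as the parameters (see the NOTE in restore_checkpoint).
model = WaveRNN(rnn_dims=hp.voc_rnn_dims, fc_dims=hp.voc_fc_dims, bits=hp.bits,
                pad=hp.voc_pad, upsample_factors=hp.voc_upsample_factors,
                feat_dims=hp.num_mels, compute_dims=hp.voc_compute_dims,
                res_out_dims=hp.voc_res_out_dims, res_blocks=hp.voc_res_blocks,
                hop_length=hp.hop_length, sample_rate=hp.sample_rate,
                mode=hp.voc_mode).to(device)
optimizer = torch.optim.Adam(model.parameters())

# Loads the 'latest' checkpoint, or creates it on the first run.
restore_checkpoint('voc', paths, model, optimizer, create_if_missing=True)

# ... training steps ...

# Updates the 'latest' checkpoint and additionally writes a named one.
save_checkpoint('voc', paths, model, optimizer, name='50k_steps', is_silent=True)
```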
models/WaveRNNModel/utils/dataset.py ADDED
@@ -0,0 +1,232 @@
1
+ import pickle
2
+ import random
3
+ import torch
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from torch.utils.data.sampler import Sampler
6
+ from utils.dsp import *
7
+ from utils import hparams as hp
8
+ from utils.text import text_to_sequence
9
+ from utils.paths import Paths
10
+ from pathlib import Path
11
+
12
+ from functools import partial
13
+
14
+ ###################################################################################
15
+ # WaveRNN/Vocoder Dataset #########################################################
16
+ ###################################################################################
17
+
18
+
19
+ class VocoderDataset(Dataset):
20
+ def __init__(self, path: Path, dataset_ids, train_gta=False):
21
+ self.metadata = dataset_ids
22
+ self.mel_path = path/'gta' if train_gta else path/'mel'
23
+ self.quant_path = path/'quant'
24
+
25
+
26
+ def __getitem__(self, index):
27
+ item_id = self.metadata[index]
28
+ m = np.load(self.mel_path/f'{item_id}.npy')
29
+ x = np.load(self.quant_path/f'{item_id}.npy')
30
+ return m, x
31
+
32
+ def __len__(self):
33
+ return len(self.metadata)
34
+
35
+
36
+ def get_vocoder_datasets(path: Path, batch_size, train_gta):
37
+
38
+ with open(path/'dataset.pkl', 'rb') as f:
39
+ dataset = pickle.load(f)
40
+
41
+ dataset_ids = [x[0] for x in dataset]
42
+
43
+ random.seed(1234)
44
+ random.shuffle(dataset_ids)
45
+
46
+ test_ids = dataset_ids[-hp.voc_test_samples:]
47
+ train_ids = dataset_ids[:-hp.voc_test_samples]
48
+
49
+ train_dataset = VocoderDataset(path, train_ids, train_gta)
50
+ test_dataset = VocoderDataset(path, test_ids, train_gta)
51
+
52
+ train_set = DataLoader(train_dataset,
53
+ collate_fn=collate_vocoder,
54
+ batch_size=batch_size,
55
+ num_workers=2,
56
+ shuffle=True,
57
+ pin_memory=True)
58
+
59
+ test_set = DataLoader(test_dataset,
60
+ batch_size=1,
61
+ num_workers=1,
62
+ shuffle=False,
63
+ pin_memory=True)
64
+
65
+ return train_set, test_set
66
+
67
+
68
+ def collate_vocoder(batch):
69
+ if not hp.is_configured():
70
+ print("未配置参数")
71
+ hp.configure("E:\\智能语音处理系统\\Noise-suppression-and-speech-recognition-systems-master\\WaveRNNModel\\hparams.py")
72
+ mel_win = hp.voc_seq_len // hp.hop_length + 2 * hp.voc_pad
73
+ max_offsets = [x[0].shape[-1] -2 - (mel_win + 2 * hp.voc_pad) for x in batch]
74
+ mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
75
+ sig_offsets = [(offset + hp.voc_pad) * hp.hop_length for offset in mel_offsets]
76
+
77
+ mels = [x[0][:, mel_offsets[i]:mel_offsets[i] + mel_win] for i, x in enumerate(batch)]
78
+
79
+ labels = [x[1][sig_offsets[i]:sig_offsets[i] + hp.voc_seq_len + 1] for i, x in enumerate(batch)]
80
+
81
+ mels = np.stack(mels).astype(np.float32)
82
+ labels = np.stack(labels).astype(np.int64)
83
+
84
+ mels = torch.tensor(mels)
85
+ labels = torch.tensor(labels).long()
86
+
87
+ x = labels[:, :hp.voc_seq_len]
88
+ y = labels[:, 1:]
89
+
90
+ bits = 16 if hp.voc_mode == 'MOL' else hp.bits
91
+
92
+ x = label_2_float(x.float(), bits)
93
+
94
+ if hp.voc_mode == 'MOL':
95
+ y = label_2_float(y.float(), bits)
96
+
97
+ return x, y, mels
98
+
99
+
100
+ ###################################################################################
101
+ # Tacotron/TTS Dataset ############################################################
102
+ ###################################################################################
103
+
104
+
105
+ def get_tts_datasets(path: Path, batch_size, r):
106
+ print("path",path)
107
+ with open(path/'dataset.pkl', 'rb') as f:
108
+ dataset = pickle.load(f)
109
+
110
+ dataset_ids = []
111
+ mel_lengths = []
112
+ print("hp.tts_max_mel_len",hp.tts_max_mel_len)
113
+ for (item_id, len) in dataset:
114
+ if len <= hp.tts_max_mel_len:
115
+ dataset_ids += [item_id]
116
+ mel_lengths += [len]
117
+
118
+ with open(path/'text_dict.pkl', 'rb') as f:
119
+ text_dict = pickle.load(f)
120
+
121
+
122
+ train_dataset = TTSDataset(path, dataset_ids, text_dict)
123
+
124
+ sampler = None
125
+
126
+ if hp.tts_bin_lengths:
127
+ sampler = BinnedLengthSampler(mel_lengths, batch_size, batch_size * 3)
128
+
129
+ train_set = DataLoader(train_dataset,
130
+ collate_fn=partial(collate_tts, r=r),
131
+ batch_size=batch_size,
132
+ sampler=sampler,
133
+ num_workers=1,
134
+ pin_memory=True)
135
+
136
+ longest = mel_lengths.index(max(mel_lengths))
137
+
138
+ # Used to evaluate attention during training process
139
+ attn_example = dataset_ids[longest]
140
+
141
+ # print(attn_example)
142
+
143
+ return train_set, attn_example
144
+
145
+
146
+ class TTSDataset(Dataset):
147
+ def __init__(self, path: Path, dataset_ids, text_dict):
148
+ self.path = path
149
+ self.metadata = dataset_ids
150
+ self.text_dict = text_dict
151
+
152
+ def __getitem__(self, index):
153
+ item_id = self.metadata[index]
154
+ #print("path555555",self.path)
155
+ if not hp.is_configured():
156
+ print("未配置参数")
157
+ hp.configure("E:\\智能语音处理系统\\Noise-suppression-and-speech-recognition-systems-master\\WaveRNNModel\\hparams.py")
158
+ #print("test666666",hp.tts_cleaner_names)
159
+ x = text_to_sequence(self.text_dict[item_id], hp.tts_cleaner_names)
160
+ mel = np.load(self.path/'mel'/f'{item_id}.npy')
161
+ mel_len = mel.shape[-1]
162
+ return x, mel, item_id, mel_len
163
+
164
+ def __len__(self):
165
+ return len(self.metadata)
166
+
167
+
168
+ def pad1d(x, max_len):
169
+ return np.pad(x, (0, max_len - len(x)), mode='constant')
170
+
171
+
172
+ def pad2d(x, max_len):
173
+ return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode='constant')
174
+
175
+
176
+ def collate_tts(batch, r):
177
+
178
+ x_lens = [len(x[0]) for x in batch]
179
+ max_x_len = max(x_lens)
180
+
181
+ chars = [pad1d(x[0], max_x_len) for x in batch]
182
+ chars = np.stack(chars)
183
+
184
+ spec_lens = [x[1].shape[-1] for x in batch]
185
+ max_spec_len = max(spec_lens) + 1
186
+ if max_spec_len % r != 0:
187
+ max_spec_len += r - max_spec_len % r
188
+
189
+ mel = [pad2d(x[1], max_spec_len) for x in batch]
190
+ mel = np.stack(mel)
191
+
192
+ ids = [x[2] for x in batch]
193
+ mel_lens = [x[3] for x in batch]
194
+
195
+ chars = torch.tensor(chars).long()
196
+ mel = torch.tensor(mel)
197
+
198
+ # scale spectrograms to -4 <--> 4
199
+ mel = (mel * 8.) - 4.
200
+ return chars, mel, ids, mel_lens
201
+
202
+
203
+ class BinnedLengthSampler(Sampler):
204
+ def __init__(self, lengths, batch_size, bin_size):
205
+ _, self.idx = torch.sort(torch.tensor(lengths).long())
206
+ self.batch_size = batch_size
207
+ self.bin_size = bin_size
208
+ assert self.bin_size % self.batch_size == 0
209
+
210
+ def __iter__(self):
211
+ # Need to change to numpy since there's a bug in random.shuffle(tensor)
212
+ # TODO: Post an issue on pytorch repo
213
+ idx = self.idx.numpy()
214
+ bins = []
215
+
216
+ for i in range(len(idx) // self.bin_size):
217
+ this_bin = idx[i * self.bin_size:(i + 1) * self.bin_size]
218
+ random.shuffle(this_bin)
219
+ bins += [this_bin]
220
+
221
+ random.shuffle(bins)
222
+ binned_idx = np.stack(bins).reshape(-1)
223
+
224
+ if len(binned_idx) < len(idx):
225
+ last_bin = idx[len(binned_idx):]
226
+ random.shuffle(last_bin)
227
+ binned_idx = np.concatenate([binned_idx, last_bin])
228
+
229
+ return iter(torch.tensor(binned_idx).long())
230
+
231
+ def __len__(self):
232
+ return len(self.idx)
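To make the offset arithmetic in `collate_vocoder` concrete, here is a small numeric sketch; the hop_length / voc_pad / voc_seq_len values are illustrative assumptions in the spirit of this repo's defaults, not a quote of hparams.py:

```python
# Illustrative hyperparameters (assumed for this sketch)
hop_length = 275                     # audio samples per mel frame
voc_pad = 2                          # mel frames of extra context on each side
voc_seq_len = 5 * hop_length         # 1375 audio samples trained on per item

# Mel frames needed to cover the audio window plus the padded context
mel_win = voc_seq_len // hop_length + 2 * voc_pad   # 5 + 4 = 9 frames

# A random per-item mel frame offset (as drawn in collate_vocoder) ...
mel_offset = 10
# ... maps to this audio sample offset, skipping the padded context frames
sig_offset = (mel_offset + voc_pad) * hop_length    # (10 + 2) * 275 = 3300

# Each batch item is then sliced as:
#   mels:   mel[:, mel_offset : mel_offset + mel_win]          -> 9 frames
#   labels: quant[sig_offset : sig_offset + voc_seq_len + 1]   -> 1376 samples
# and x / y are the labels shifted by one sample (next-sample prediction).
print(mel_win, sig_offset)
```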
models/WaveRNNModel/utils/display.py ADDED
@@ -0,0 +1,121 @@
1
+ import matplotlib as mpl
2
+ mpl.use('agg') # Use non-interactive backend by default
3
+ import matplotlib.pyplot as plt
4
+ import time
5
+ import numpy as np
6
+ import sys
7
+
8
+
9
+ def progbar(i, n, size=16):
10
+ done = (i * size) // n
11
+ bar = ''
12
+ for i in range(size):
13
+ bar += '█' if i <= done else '░'
14
+ return bar
15
+
16
+
17
+ def stream(message):
18
+ sys.stdout.write(f"\r{message}")
19
+
20
+
21
+ def simple_table(item_tuples):
22
+
23
+ border_pattern = '+---------------------------------------'
24
+ whitespace = ' ' * 40  # source of padding spaces for cell alignment
25
+
26
+ headings, cells = [], []
27
+
28
+ for item in item_tuples:
29
+
30
+ heading, cell = str(item[0]), str(item[1])
31
+
32
+ pad_head = True if len(heading) < len(cell) else False
33
+
34
+ pad = abs(len(heading) - len(cell))
35
+ pad = whitespace[:pad]
36
+
37
+ pad_left = pad[:len(pad)//2]
38
+ pad_right = pad[len(pad)//2:]
39
+
40
+ if pad_head:
41
+ heading = pad_left + heading + pad_right
42
+ else:
43
+ cell = pad_left + cell + pad_right
44
+
45
+ headings += [heading]
46
+ cells += [cell]
47
+
48
+ border, head, body = '', '', ''
49
+
50
+ for i in range(len(item_tuples)):
51
+
52
+ temp_head = f'| {headings[i]} '
53
+ temp_body = f'| {cells[i]} '
54
+
55
+ border += border_pattern[:len(temp_head)]
56
+ head += temp_head
57
+ body += temp_body
58
+
59
+ if i == len(item_tuples) - 1:
60
+ head += '|'
61
+ body += '|'
62
+ border += '+'
63
+
64
+ print(border)
65
+ print(head)
66
+ print(border)
67
+ print(body)
68
+ print(border)
69
+ print(' ')
70
+
71
+
72
+ def time_since(started):
73
+ elapsed = time.time() - started
74
+ m = int(elapsed // 60)
75
+ s = int(elapsed % 60)
76
+ if m >= 60:
77
+ h = int(m // 60)
78
+ m = m % 60
79
+ return f'{h}h {m}m {s}s'
80
+ else:
81
+ return f'{m}m {s}s'
82
+
83
+
84
+ def save_attention(attn, path):
85
+ fig = plt.figure(figsize=(12, 6))
86
+ plt.imshow(attn.T, interpolation='nearest', aspect='auto')
87
+ fig.savefig(path.parent/f'{path.stem}.png', bbox_inches='tight')
88
+ plt.close(fig)
89
+
90
+
91
+ def save_spectrogram(M, path, length=None):
92
+ M = np.flip(M, axis=0)
93
+ if length: M = M[:, :length]
94
+ fig = plt.figure(figsize=(12, 6))
95
+ plt.imshow(M, interpolation='nearest', aspect='auto')
96
+ fig.savefig(f'{path}.png', bbox_inches='tight')
97
+ plt.close(fig)
98
+
99
+
100
+ def plot(array):
101
+ mpl.interactive(True)
102
+ fig = plt.figure(figsize=(30, 5))
103
+ ax = fig.add_subplot(111)
104
+ ax.xaxis.label.set_color('grey')
105
+ ax.yaxis.label.set_color('grey')
106
+ ax.xaxis.label.set_fontsize(23)
107
+ ax.yaxis.label.set_fontsize(23)
108
+ ax.tick_params(axis='x', colors='grey', labelsize=23)
109
+ ax.tick_params(axis='y', colors='grey', labelsize=23)
110
+ plt.plot(array)
111
+ mpl.interactive(False)
112
+
113
+
114
+ def plot_spec(M):
115
+ mpl.interactive(True)
116
+ M = np.flip(M, axis=0)
117
+ plt.figure(figsize=(18,4))
118
+ plt.imshow(M, interpolation='nearest', aspect='auto')
119
+ plt.show()
120
+ mpl.interactive(False)
121
+
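A tiny usage sketch for the console helpers above (purely illustrative):

```python
import time
from utils.display import progbar, stream, time_since

start = time.time()
n_steps = 50
for step in range(1, n_steps + 1):
    time.sleep(0.05)  # stand-in for one training step
    msg = f'| Step {step}/{n_steps} {progbar(step, n_steps)} | {time_since(start)} '
    stream(msg)       # rewrites the same console line via the leading \r
print()
```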
models/WaveRNNModel/utils/distribution.py ADDED
@@ -0,0 +1,132 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def log_sum_exp(x):
7
+ """ numerically stable log_sum_exp implementation that prevents overflow """
8
+ # TF ordering
9
+ axis = len(x.size()) - 1
10
+ m, _ = torch.max(x, dim=axis)
11
+ m2, _ = torch.max(x, dim=axis, keepdim=True)
12
+ return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
13
+
14
+
15
+ # It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
16
+ def discretized_mix_logistic_loss(y_hat, y, num_classes=65536,
17
+ log_scale_min=None, reduce=True):
18
+ if log_scale_min is None:
19
+ log_scale_min = float(np.log(1e-14))
20
+ y_hat = y_hat.permute(0,2,1)
21
+ assert y_hat.dim() == 3
22
+ assert y_hat.size(1) % 3 == 0
23
+ nr_mix = y_hat.size(1) // 3
24
+
25
+ # (B x T x C)
26
+ y_hat = y_hat.transpose(1, 2)
27
+
28
+ # unpack parameters. (B, T, num_mixtures) x 3
29
+ logit_probs = y_hat[:, :, :nr_mix]
30
+ means = y_hat[:, :, nr_mix:2 * nr_mix]
31
+ log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min)
32
+
33
+ # B x T x 1 -> B x T x num_mixtures
34
+ y = y.expand_as(means)
35
+
36
+ centered_y = y - means
37
+ inv_stdv = torch.exp(-log_scales)
38
+ plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
39
+ cdf_plus = torch.sigmoid(plus_in)
40
+ min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
41
+ cdf_min = torch.sigmoid(min_in)
42
+
43
+ # log probability for edge case of 0 (before scaling)
44
+ # equivalent: torch.log(F.sigmoid(plus_in))
45
+ log_cdf_plus = plus_in - F.softplus(plus_in)
46
+
47
+ # log probability for edge case of 255 (before scaling)
48
+ # equivalent: (1 - F.sigmoid(min_in)).log()
49
+ log_one_minus_cdf_min = -F.softplus(min_in)
50
+
51
+ # probability for all other cases
52
+ cdf_delta = cdf_plus - cdf_min
53
+
54
+ mid_in = inv_stdv * centered_y
55
+ # log probability in the center of the bin, to be used in extreme cases
56
+ # (not actually used in our code)
57
+ log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)
58
+
59
+ # tf equivalent
60
+ """
61
+ log_probs = tf.where(x < -0.999, log_cdf_plus,
62
+ tf.where(x > 0.999, log_one_minus_cdf_min,
63
+ tf.where(cdf_delta > 1e-5,
64
+ tf.log(tf.maximum(cdf_delta, 1e-12)),
65
+ log_pdf_mid - np.log(127.5))))
66
+ """
67
+ # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
68
+ # for num_classes=65536 case? 1e-7? not sure..
69
+ inner_inner_cond = (cdf_delta > 1e-5).float()
70
+
71
+ inner_inner_out = inner_inner_cond * \
72
+ torch.log(torch.clamp(cdf_delta, min=1e-12)) + \
73
+ (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
74
+ inner_cond = (y > 0.999).float()
75
+ inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
76
+ cond = (y < -0.999).float()
77
+ log_probs = cond * log_cdf_plus + (1. - cond) * inner_out
78
+
79
+ log_probs = log_probs + F.log_softmax(logit_probs, -1)
80
+
81
+ if reduce:
82
+ return -torch.mean(log_sum_exp(log_probs))
83
+ else:
84
+ return -log_sum_exp(log_probs).unsqueeze(-1)
85
+
86
+
87
+ def sample_from_discretized_mix_logistic(y, log_scale_min=None):
88
+ """
89
+ Sample from discretized mixture of logistic distributions
90
+ Args:
91
+ y (Tensor): B x C x T
92
+ log_scale_min (float): Log scale minimum value
93
+ Returns:
94
+ Tensor: sample in range of [-1, 1].
95
+ """
96
+ if log_scale_min is None:
97
+ log_scale_min = float(np.log(1e-14))
98
+ assert y.size(1) % 3 == 0
99
+ nr_mix = y.size(1) // 3
100
+
101
+ # B x T x C
102
+ y = y.transpose(1, 2)
103
+ logit_probs = y[:, :, :nr_mix]
104
+
105
+ # sample mixture indicator from softmax
106
+ temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
107
+ temp = logit_probs.data - torch.log(- torch.log(temp))
108
+ _, argmax = temp.max(dim=-1)
109
+
110
+ # (B, T) -> (B, T, nr_mix)
111
+ one_hot = F.one_hot(argmax, nr_mix).float()
112
+ # select logistic parameters
113
+ means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1)
114
+ log_scales = torch.clamp(torch.sum(
115
+ y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1), min=log_scale_min)
116
+ # sample from logistic & clip to interval
117
+ # we don't actually round to the nearest 8bit value when sampling
118
+ u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
119
+ x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))
120
+
121
+ x = torch.clamp(torch.clamp(x, min=-1.), max=1.)
122
+
123
+ return x
124
+
125
+ '''
126
+ def to_one_hot(tensor, n, fill_with=1.):
127
+ # we perform one hot encore with respect to the last axis
128
+ one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
129
+ if tensor.is_cuda:
130
+ one_hot = one_hot.cuda()
131
+ one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
132
+ return one_hot'''
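A shape-checking sketch for the two functions above; the tensor sizes are illustrative (with this repo's MOL vocoder the channel dimension is 3 * nr_mix = 30):

```python
import torch
from utils.distribution import (discretized_mix_logistic_loss,
                                sample_from_discretized_mix_logistic)

B, T, nr_mix = 2, 100, 10

# Loss: network output of shape (batch, time, 3*nr_mix) -- mixture logits,
# means and log-scales -- against targets of shape (batch, time, 1) in [-1, 1].
y_hat = torch.randn(B, T, 3 * nr_mix)
y = torch.empty(B, T, 1).uniform_(-1.0, 1.0)
loss = discretized_mix_logistic_loss(y_hat, y)          # scalar when reduce=True

# Sampling expects (batch, 3*nr_mix, time) and returns waveform samples in [-1, 1].
wav = sample_from_discretized_mix_logistic(y_hat.transpose(1, 2))
print(loss.item(), wav.shape)                            # wav.shape == (B, T)
```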