nukopy committed on
Commit
5becb6b
·
1 Parent(s): 725afe9

Add audio files with Git LFS support

Browse files

- Add .wav, .ogg, .mp3, .flac to .gitattributes for LFS tracking
- Migrate existing audio files to Git LFS
- This resolves the binary file rejection from Hugging Face Spaces

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +5 -0
  2. .gitignore +219 -0
  3. Makefile +5 -0
  4. README.md +2 -0
  5. app.py +18 -0
  6. apps/audio_cloning/main.py +15 -0
  7. apps/audio_cloning/vallex/data/__init__.py +3 -0
  8. apps/audio_cloning/vallex/data/collation.py +107 -0
  9. apps/audio_cloning/vallex/data/datamodule.py +419 -0
  10. apps/audio_cloning/vallex/data/dataset.py +242 -0
  11. apps/audio_cloning/vallex/data/fbank.py +212 -0
  12. apps/audio_cloning/vallex/data/input_strategies.py +159 -0
  13. apps/audio_cloning/vallex/data/symbol_table.py +289 -0
  14. apps/audio_cloning/vallex/data/tokenizer.py +121 -0
  15. apps/audio_cloning/vallex/descriptions.py +34 -0
  16. apps/audio_cloning/vallex/examples.py +108 -0
  17. apps/audio_cloning/vallex/g2p/__init__.py +84 -0
  18. apps/audio_cloning/vallex/g2p/bpe_1024.json +2049 -0
  19. apps/audio_cloning/vallex/g2p/bpe_69.json +141 -0
  20. apps/audio_cloning/vallex/g2p/cleaners.py +76 -0
  21. apps/audio_cloning/vallex/g2p/english.py +197 -0
  22. apps/audio_cloning/vallex/g2p/japanese.py +173 -0
  23. apps/audio_cloning/vallex/g2p/mandarin.py +337 -0
  24. apps/audio_cloning/vallex/g2p/symbols.py +76 -0
  25. apps/audio_cloning/vallex/macros.py +34 -0
  26. apps/audio_cloning/vallex/main.py +461 -0
  27. apps/audio_cloning/vallex/models/__init__.py +127 -0
  28. apps/audio_cloning/vallex/models/macros.py +11 -0
  29. apps/audio_cloning/vallex/models/transformer.py +386 -0
  30. apps/audio_cloning/vallex/models/vallex.py +823 -0
  31. apps/audio_cloning/vallex/models/visualizer.py +102 -0
  32. apps/audio_cloning/vallex/modules/__init__.py +0 -0
  33. apps/audio_cloning/vallex/modules/activation.py +612 -0
  34. apps/audio_cloning/vallex/modules/embedding.py +97 -0
  35. apps/audio_cloning/vallex/modules/optim.py +1105 -0
  36. apps/audio_cloning/vallex/modules/scaling.py +1369 -0
  37. apps/audio_cloning/vallex/modules/scheduler.py +78 -0
  38. apps/audio_cloning/vallex/modules/transformer.py +683 -0
  39. apps/audio_cloning/vallex/presets/acou_1.npz +3 -0
  40. apps/audio_cloning/vallex/presets/acou_2.npz +3 -0
  41. apps/audio_cloning/vallex/presets/acou_3.npz +3 -0
  42. apps/audio_cloning/vallex/presets/acou_4.npz +3 -0
  43. apps/audio_cloning/vallex/presets/alan.npz +3 -0
  44. apps/audio_cloning/vallex/presets/amused.npz +3 -0
  45. apps/audio_cloning/vallex/presets/anger.npz +3 -0
  46. apps/audio_cloning/vallex/presets/babara.npz +3 -0
  47. apps/audio_cloning/vallex/presets/bronya.npz +3 -0
  48. apps/audio_cloning/vallex/presets/cafe.npz +3 -0
  49. apps/audio_cloning/vallex/presets/dingzhen.npz +3 -0
  50. apps/audio_cloning/vallex/presets/disgust.npz +3 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ # Audio files
37
+ *.wav filter=lfs diff=lfs merge=lfs -text
38
+ *.ogg filter=lfs diff=lfs merge=lfs -text
39
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
40
+ *.flac filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pretrained models
2
+ models/**/*.pt
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[codz]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py.cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ # Pipfile.lock
99
+
100
+ # UV
101
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # uv.lock
105
+
106
+ # poetry
107
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
108
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
109
+ # commonly ignored for libraries.
110
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
111
+ # poetry.lock
112
+ # poetry.toml
113
+
114
+ # pdm
115
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
116
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
117
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
118
+ # pdm.lock
119
+ # pdm.toml
120
+ .pdm-python
121
+ .pdm-build/
122
+
123
+ # pixi
124
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
125
+ # pixi.lock
126
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
127
+ # in the .venv directory. It is recommended not to include this directory in version control.
128
+ .pixi
129
+
130
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
131
+ __pypackages__/
132
+
133
+ # Celery stuff
134
+ celerybeat-schedule
135
+ celerybeat.pid
136
+
137
+ # Redis
138
+ *.rdb
139
+ *.aof
140
+ *.pid
141
+
142
+ # RabbitMQ
143
+ mnesia/
144
+ rabbitmq/
145
+ rabbitmq-data/
146
+
147
+ # ActiveMQ
148
+ activemq-data/
149
+
150
+ # SageMath parsed files
151
+ *.sage.py
152
+
153
+ # Environments
154
+ .env
155
+ .envrc
156
+ .venv
157
+ env/
158
+ venv/
159
+ ENV/
160
+ env.bak/
161
+ venv.bak/
162
+
163
+ # Spyder project settings
164
+ .spyderproject
165
+ .spyproject
166
+
167
+ # Rope project settings
168
+ .ropeproject
169
+
170
+ # mkdocs documentation
171
+ /site
172
+
173
+ # mypy
174
+ .mypy_cache/
175
+ .dmypy.json
176
+ dmypy.json
177
+
178
+ # Pyre type checker
179
+ .pyre/
180
+
181
+ # pytype static type analyzer
182
+ .pytype/
183
+
184
+ # Cython debug symbols
185
+ cython_debug/
186
+
187
+ # PyCharm
188
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
189
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
190
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
191
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
192
+ # .idea/
193
+
194
+ # Abstra
195
+ # Abstra is an AI-powered process automation framework.
196
+ # Ignore directories containing user credentials, local state, and settings.
197
+ # Learn more at https://abstra.io/docs
198
+ .abstra/
199
+
200
+ # Visual Studio Code
201
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
202
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
203
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
204
+ # you could uncomment the following to ignore the entire vscode folder
205
+ # .vscode/
206
+
207
+ # Ruff stuff:
208
+ .ruff_cache/
209
+
210
+ # PyPI configuration file
211
+ .pypirc
212
+
213
+ # Marimo
214
+ marimo/_static/
215
+ marimo/_lsp/
216
+ __marimo__/
217
+
218
+ # Streamlit
219
+ .streamlit/secrets.toml
Makefile ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ run:
2
+ PYTHONPATH=. uv run gradio app.py
3
+
4
+ export-requirements:
5
+ uv export --format requirements-txt > requirements.txt
README.md CHANGED
@@ -4,7 +4,9 @@ emoji: 🐨
4
  colorFrom: purple
5
  colorTo: pink
6
  sdk: gradio
 
7
  sdk_version: 5.49.1
 
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: purple
5
  colorTo: pink
6
  sdk: gradio
7
+ python_version: 3.13
8
  sdk_version: 5.49.1
9
+ suggested_hardware: g4
10
  app_file: app.py
11
  pinned: false
12
  ---
app.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from apps.audio_cloning.main import main as audio_cloning
4
+ from apps.dev.main import main as dev
5
+
6
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
7
+ audio_cloning()
8
+
9
+
10
+ with demo.route(name="Dev", path="/dev"):
11
+ dev()
12
+
13
+
14
+ if __name__ == "__main__":
15
+ # demo.queue(max_size=2, concurrency_limit=2, concurrency_id="gpu_queue")
16
+ # auth = ("charaxim", "chrmx-demo-wordpass")
17
+ # demo.launch(share=False, auth=auth)
18
+ demo.launch(share=False)
apps/audio_cloning/main.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from .vallex.main import main as vallex
4
+
5
+
6
+ def main():
7
+ gr.Markdown("# Charamix Audio Cloning Prototype")
8
+
9
+ # zero-shot audio cloning
10
+ with gr.Tab("Zero-shot Audio Cloning with VALL-E-X"):
11
+ vallex()
12
+
13
+ # fine-tuning audio cloning
14
+ # with gr.Tab("Fine-tuning Audio Cloning"):
15
+ # gr.Markdown("TODO")
apps/audio_cloning/vallex/data/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # from .datamodule import *
2
+ # from .tokenizer import *
3
+ from .collation import *
apps/audio_cloning/vallex/data/collation.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ class TextTokenCollater:
8
+ """Collate list of text tokens
9
+
10
+ Map sentences to integers. Sentences are padded to equal length.
11
+ Beginning and end-of-sequence symbols can be added.
12
+
13
+ Example:
14
+ >>> token_collater = TextTokenCollater(text_tokens)
15
+ >>> tokens_batch, tokens_lens = token_collater(text)
16
+
17
+ Returns:
18
+ tokens_batch: IntTensor of shape (B, L)
19
+ B: batch dimension, number of input sentences
20
+ L: length of the longest sentence
21
+ tokens_lens: IntTensor of shape (B,)
22
+ Length of each sentence after adding <eos> and <bos>
23
+ but before padding.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ text_tokens: List[str],
29
+ add_eos: bool = True,
30
+ add_bos: bool = True,
31
+ pad_symbol: str = "<pad>",
32
+ bos_symbol: str = "<bos>",
33
+ eos_symbol: str = "<eos>",
34
+ ):
35
+ self.pad_symbol = pad_symbol
36
+
37
+ self.add_eos = add_eos
38
+ self.add_bos = add_bos
39
+
40
+ self.bos_symbol = bos_symbol
41
+ self.eos_symbol = eos_symbol
42
+
43
+ unique_tokens = (
44
+ [pad_symbol]
45
+ + ([bos_symbol] if add_bos else [])
46
+ + ([eos_symbol] if add_eos else [])
47
+ + sorted(text_tokens)
48
+ )
49
+
50
+ self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
51
+ self.idx2token = [token for token in unique_tokens]
52
+
53
+ def index(self, tokens_list: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
54
+ seqs, seq_lens = [], []
55
+ for tokens in tokens_list:
56
+ assert all([True if s in self.token2idx else False for s in tokens]) is True
57
+ seq = (
58
+ ([self.bos_symbol] if self.add_bos else [])
59
+ + list(tokens)
60
+ + ([self.eos_symbol] if self.add_eos else [])
61
+ )
62
+ seqs.append(seq)
63
+ seq_lens.append(len(seq))
64
+
65
+ max_len = max(seq_lens)
66
+ for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)):
67
+ seq.extend([self.pad_symbol] * (max_len - seq_len))
68
+
69
+ tokens = torch.from_numpy(
70
+ np.array(
71
+ [[self.token2idx[token] for token in seq] for seq in seqs],
72
+ dtype=np.int64,
73
+ )
74
+ )
75
+ tokens_lens = torch.IntTensor(seq_lens)
76
+
77
+ return tokens, tokens_lens
78
+
79
+ def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
80
+ tokens_seqs = [[p for p in text] for text in texts]
81
+ max_len = len(max(tokens_seqs, key=len))
82
+
83
+ seqs = [
84
+ ([self.bos_symbol] if self.add_bos else [])
85
+ + list(seq)
86
+ + ([self.eos_symbol] if self.add_eos else [])
87
+ + [self.pad_symbol] * (max_len - len(seq))
88
+ for seq in tokens_seqs
89
+ ]
90
+
91
+ tokens_batch = torch.from_numpy(
92
+ np.array(
93
+ [seq for seq in seqs],
94
+ dtype=np.int64,
95
+ )
96
+ )
97
+
98
+ tokens_lens = torch.IntTensor(
99
+ [len(seq) + int(self.add_eos) + int(self.add_bos) for seq in tokens_seqs]
100
+ )
101
+
102
+ return tokens_batch, tokens_lens
103
+
104
+
105
+ def get_text_token_collater() -> TextTokenCollater:
106
+ collater = TextTokenCollater(["0"], add_bos=False, add_eos=False)
107
+ return collater
apps/audio_cloning/vallex/data/datamodule.py ADDED
@@ -0,0 +1,419 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 (authors: Feiteng Li)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ import argparse
19
+ import inspect
20
+ import logging
21
+ from functools import lru_cache
22
+ from pathlib import Path
23
+ from typing import Any, Dict, Optional
24
+
25
+ import torch
26
+ # from icefall.utils import str2bool
27
+ # from lhotse import CutSet, load_manifest_lazy
28
+ # from lhotse.dataset import (
29
+ # CutConcatenate,
30
+ # DynamicBucketingSampler,
31
+ # PrecomputedFeatures,
32
+ # SingleCutSampler,
33
+ # SpecAugment,
34
+ # )
35
+ # from lhotse.dataset.input_strategies import OnTheFlyFeatures
36
+ # from lhotse.utils import fix_random_seed
37
+ from torch.utils.data import DataLoader
38
+
39
+ from data.collation import get_text_token_collater
40
+ # from data.dataset import SpeechSynthesisDataset
41
+ from data.fbank import get_fbank_extractor
42
+ from data.input_strategies import PromptedPrecomputedFeatures
43
+
44
+ # PrecomputedFeatures = PrecomputedFeatures
45
+
46
+
47
+ class _SeedWorkers:
48
+ def __init__(self, seed: int):
49
+ self.seed = seed
50
+
51
+ def __call__(self, worker_id: int):
52
+ fix_random_seed(self.seed + worker_id)
53
+
54
+
55
+ def _get_input_strategy(input_strategy, dataset, cuts):
56
+ if input_strategy == "PromptedPrecomputedFeatures":
57
+ return PromptedPrecomputedFeatures(dataset, cuts)
58
+
59
+ return eval(input_strategy)()
60
+
61
+
62
+ class TtsDataModule:
63
+ """
64
+ DataModule for VALL-E TTS experiments.
65
+ It assumes there is always one train and valid dataloader.
66
+
67
+ It contains all the common data pipeline modules used in TTS
68
+ experiments, e.g.:
69
+ - dynamic batch size,
70
+ - bucketing samplers,
71
+ - cut concatenation[not used & tested yet],
72
+ - augmentation[not used & tested yet],
73
+ - on-the-fly feature extraction[not used & tested yet]
74
+
75
+ This class should be derived for specific corpora used in TTS tasks.
76
+ """
77
+
78
+ def __init__(self, args: argparse.Namespace):
79
+ self.args = args
80
+
81
+ @classmethod
82
+ def add_arguments(cls, parser: argparse.ArgumentParser):
83
+ group = parser.add_argument_group(
84
+ title="TTS data related options",
85
+ description="These options are used for the preparation of "
86
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
87
+ "effective batch sizes, sampling strategies, applied data "
88
+ "augmentations, etc.",
89
+ )
90
+ group.add_argument(
91
+ "--manifest-dir",
92
+ type=Path,
93
+ default=Path("data/tokenized"),
94
+ help="Path to directory with train/valid/test cuts.",
95
+ )
96
+ group.add_argument(
97
+ "--max-duration",
98
+ type=int,
99
+ default=40.0,
100
+ help="Maximum pooled recordings duration (seconds) in a "
101
+ "single batch. You can reduce it if it causes CUDA OOM.",
102
+ )
103
+ group.add_argument(
104
+ "--bucketing-sampler",
105
+ type=str2bool,
106
+ default=True,
107
+ help="When enabled, the batches will come from buckets of "
108
+ "similar duration (saves padding frames).",
109
+ )
110
+ group.add_argument(
111
+ "--num-buckets",
112
+ type=int,
113
+ default=10,
114
+ help="The number of buckets for the DynamicBucketingSampler"
115
+ "(you might want to increase it for larger datasets).",
116
+ )
117
+ group.add_argument(
118
+ "--concatenate-cuts",
119
+ type=str2bool,
120
+ default=False,
121
+ help="When enabled, utterances (cuts) will be concatenated "
122
+ "to minimize the amount of padding.",
123
+ )
124
+ group.add_argument(
125
+ "--duration-factor",
126
+ type=float,
127
+ default=1.0,
128
+ help="Determines the maximum duration of a concatenated cut "
129
+ "relative to the duration of the longest cut in a batch.",
130
+ )
131
+ group.add_argument(
132
+ "--gap",
133
+ type=float,
134
+ default=0.1,
135
+ help="The amount of padding (in seconds) inserted between "
136
+ "concatenated cuts. This padding is filled with noise when "
137
+ "noise augmentation is used.",
138
+ )
139
+ group.add_argument(
140
+ "--on-the-fly-feats",
141
+ type=str2bool,
142
+ default=False,
143
+ help="When enabled, use on-the-fly cut mixing and feature "
144
+ "extraction. Will drop existing precomputed feature manifests "
145
+ "if available.",
146
+ )
147
+ group.add_argument(
148
+ "--shuffle",
149
+ type=str2bool,
150
+ default=True,
151
+ help="When enabled (=default), the examples will be "
152
+ "shuffled for each epoch.",
153
+ )
154
+ group.add_argument(
155
+ "--drop-last",
156
+ type=str2bool,
157
+ default=False,
158
+ help="Whether to drop last batch. Used by sampler.",
159
+ )
160
+ group.add_argument(
161
+ "--return-cuts",
162
+ type=str2bool,
163
+ default=True,
164
+ help="When enabled, each batch will have the "
165
+ "field: batch['supervisions']['cut'] with the cuts that "
166
+ "were used to construct it.",
167
+ )
168
+
169
+ group.add_argument(
170
+ "--num-workers",
171
+ type=int,
172
+ default=8,
173
+ help="The number of training dataloader workers that "
174
+ "collect the batches.",
175
+ )
176
+
177
+ group.add_argument(
178
+ "--enable-spec-aug",
179
+ type=str2bool,
180
+ default=False,
181
+ help="When enabled, use SpecAugment for training dataset.",
182
+ )
183
+
184
+ group.add_argument(
185
+ "--spec-aug-time-warp-factor",
186
+ type=int,
187
+ default=80,
188
+ help="Used only when --enable-spec-aug is True. "
189
+ "It specifies the factor for time warping in SpecAugment. "
190
+ "Larger values mean more warping. "
191
+ "A value less than 1 means to disable time warp.",
192
+ )
193
+
194
+ group.add_argument(
195
+ "--input-strategy",
196
+ type=str,
197
+ default="PrecomputedFeatures",
198
+ help="AudioSamples or PrecomputedFeatures or PromptedPrecomputedFeatures",
199
+ )
200
+
201
+ group.add_argument(
202
+ "--dataset",
203
+ type=str,
204
+ default="ljspeech",
205
+ help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.",
206
+ )
207
+
208
+ parser.add_argument(
209
+ "--text-tokens",
210
+ type=str,
211
+ default="data/tokenized/unique_text_tokens.k2symbols",
212
+ help="Path to the unique text tokens file",
213
+ )
214
+
215
+ parser.add_argument(
216
+ "--sampling-rate",
217
+ type=int,
218
+ default=24000,
219
+ help="""Audio sampling rate.""",
220
+ )
221
+
222
+ def train_dataloaders(
223
+ self,
224
+ cuts_train: CutSet,
225
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
226
+ ) -> DataLoader:
227
+ """
228
+ Args:
229
+ cuts_train:
230
+ CutSet for training.
231
+ sampler_state_dict:
232
+ The state dict for the training sampler.
233
+ """
234
+ transforms = []
235
+
236
+ if self.args.concatenate_cuts:
237
+ logging.info(
238
+ f"Using cut concatenation with duration factor "
239
+ f"{self.args.duration_factor} and gap {self.args.gap}."
240
+ )
241
+ # Cut concatenation should be the first transform in the list,
242
+ # so that if we e.g. mix noise in, it will fill the gaps between
243
+ # different utterances.
244
+ transforms = [
245
+ CutConcatenate(
246
+ duration_factor=self.args.duration_factor, gap=self.args.gap
247
+ )
248
+ ] + transforms
249
+
250
+ input_transforms = []
251
+ if self.args.enable_spec_aug:
252
+ logging.info("Enable SpecAugment")
253
+ logging.info(
254
+ f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
255
+ )
256
+ # Set the value of num_frame_masks according to Lhotse's version.
257
+ # In different Lhotse's versions, the default of num_frame_masks is
258
+ # different.
259
+ num_frame_masks = 10
260
+ num_frame_masks_parameter = inspect.signature(
261
+ SpecAugment.__init__
262
+ ).parameters["num_frame_masks"]
263
+ if num_frame_masks_parameter.default == 1:
264
+ num_frame_masks = 2
265
+ logging.info(f"Num frame mask: {num_frame_masks}")
266
+ input_transforms.append(
267
+ SpecAugment(
268
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
269
+ num_frame_masks=num_frame_masks,
270
+ features_mask_size=27,
271
+ num_feature_masks=2,
272
+ frames_mask_size=100,
273
+ )
274
+ )
275
+ else:
276
+ logging.info("Disable SpecAugment")
277
+
278
+ logging.info("About to create train dataset")
279
+ if self.args.on_the_fly_feats:
280
+ # NOTE: the PerturbSpeed transform should be added only if we
281
+ # remove it from data prep stage.
282
+ # Add on-the-fly speed perturbation; since originally it would
283
+ # have increased epoch size by 3, we will apply prob 2/3 and use
284
+ # 3x more epochs.
285
+ # Speed perturbation probably should come first before
286
+ # concatenation, but in principle the transforms order doesn't have
287
+ # to be strict (e.g. could be randomized)
288
+ # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
289
+ # Drop feats to be on the safe side.
290
+ train = SpeechSynthesisDataset(
291
+ get_text_token_collater(self.args.text_tokens),
292
+ cut_transforms=transforms,
293
+ feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
294
+ feature_transforms=input_transforms,
295
+ )
296
+ else:
297
+ train = SpeechSynthesisDataset(
298
+ get_text_token_collater(self.args.text_tokens),
299
+ feature_input_strategy=_get_input_strategy(
300
+ self.args.input_strategy, self.args.dataset, cuts_train
301
+ ),
302
+ cut_transforms=transforms,
303
+ feature_transforms=input_transforms,
304
+ )
305
+
306
+ if self.args.bucketing_sampler:
307
+ logging.info("Using DynamicBucketingSampler")
308
+ train_sampler = DynamicBucketingSampler(
309
+ cuts_train,
310
+ max_duration=self.args.max_duration,
311
+ shuffle=self.args.shuffle,
312
+ num_buckets=self.args.num_buckets,
313
+ drop_last=self.args.drop_last,
314
+ )
315
+ else:
316
+ logging.info(
317
+ "Using SingleCutSampler and sort by duraton(ascending=True)."
318
+ )
319
+ cuts_train = cuts_train.to_eager().sort_by_duration(ascending=True)
320
+ train_sampler = SingleCutSampler(
321
+ cuts_train,
322
+ max_duration=self.args.max_duration,
323
+ shuffle=self.args.shuffle,
324
+ )
325
+ logging.info("About to create train dataloader")
326
+
327
+ if sampler_state_dict is not None:
328
+ logging.info("Loading sampler state dict")
329
+ train_sampler.load_state_dict(sampler_state_dict)
330
+
331
+ # 'seed' is derived from the current random state, which will have
332
+ # previously been set in the main process.
333
+ seed = torch.randint(0, 100000, ()).item()
334
+ worker_init_fn = _SeedWorkers(seed)
335
+
336
+ train_dl = DataLoader(
337
+ train,
338
+ sampler=train_sampler,
339
+ batch_size=None,
340
+ num_workers=self.args.num_workers,
341
+ persistent_workers=False,
342
+ worker_init_fn=worker_init_fn,
343
+ )
344
+
345
+ return train_dl
346
+
347
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
348
+ logging.info("About to create dev dataset")
349
+ if self.args.on_the_fly_feats:
350
+ validate = SpeechSynthesisDataset(
351
+ get_text_token_collater(self.args.text_tokens),
352
+ feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
353
+ cut_transforms=[],
354
+ )
355
+ else:
356
+ validate = SpeechSynthesisDataset(
357
+ get_text_token_collater(self.args.text_tokens),
358
+ feature_input_strategy=_get_input_strategy(
359
+ self.args.input_strategy, self.args.dataset, cuts_valid
360
+ ),
361
+ cut_transforms=[],
362
+ )
363
+ valid_sampler = DynamicBucketingSampler(
364
+ cuts_valid,
365
+ max_duration=self.args.max_duration,
366
+ shuffle=False,
367
+ )
368
+ logging.info("About to create dev dataloader")
369
+ valid_dl = DataLoader(
370
+ validate,
371
+ sampler=valid_sampler,
372
+ batch_size=None,
373
+ num_workers=4,
374
+ persistent_workers=False,
375
+ )
376
+
377
+ return valid_dl
378
+
379
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
380
+ logging.debug("About to create test dataset")
381
+ test = SpeechSynthesisDataset(
382
+ get_text_token_collater(self.args.text_tokens),
383
+ feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor())
384
+ if self.args.on_the_fly_feats
385
+ else _get_input_strategy(
386
+ self.args.input_strategy, self.args.dataset, cuts
387
+ ),
388
+ cut_transforms=[],
389
+ )
390
+ sampler = DynamicBucketingSampler(
391
+ cuts,
392
+ max_duration=self.args.max_duration,
393
+ shuffle=False,
394
+ )
395
+ logging.debug("About to create test dataloader")
396
+ test_dl = DataLoader(
397
+ test,
398
+ batch_size=None,
399
+ sampler=sampler,
400
+ num_workers=self.args.num_workers,
401
+ )
402
+ return test_dl
403
+
404
+ @lru_cache()
405
+ def train_cuts(self) -> CutSet:
406
+ logging.info("About to get train cuts")
407
+ return load_manifest_lazy(
408
+ self.args.manifest_dir / "cuts_train.jsonl.gz"
409
+ )
410
+
411
+ @lru_cache()
412
+ def dev_cuts(self) -> CutSet:
413
+ logging.info("About to get dev cuts")
414
+ return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
415
+
416
+ @lru_cache()
417
+ def test_cuts(self) -> CutSet:
418
+ logging.info("About to get test cuts")
419
+ return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz")
apps/audio_cloning/vallex/data/dataset.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 (authors: Feiteng Li)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """
18
+ modified from lhoste.dataset.speech_synthesis.py
19
+ """
20
+
21
+ import torch
22
+ import math
23
+ import h5py
24
+ from tokenizers import Tokenizer
25
+ from typing import Union, List
26
+ import numpy as np
27
+ from tqdm import tqdm
28
+
29
# Phoneme inventory: the padding symbol sits at index 0, followed by
# punctuation and then letter/IPA characters. A token id is simply an
# index into ``symbols``.
_pad = '_'
_punctuation = ',.!?-~…'
_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
symbols = [_pad] + list(_punctuation) + list(_letters)

# Numeric ids used for the language embedding.
language_dict = {
    'en': 0,
    'zh': 1,
    'ja': 2,
}


def seq2phone(tokens: Union[List, np.ndarray]):
    """
    Convert a tokenized phoneme ID sequence back to a phoneme string.

    :param tokens: phoneme token ids (indices into ``symbols``)
    :return: recovered phoneme sequence as a single string
    """
    return "".join(symbols[token_id] for token_id in tokens)
47
+
48
class DynamicBatchSampler(torch.utils.data.Sampler):
    """Batch sampler that groups similarly-sized samples into length buckets
    and emits variable-size batches bounded by a token budget.

    :param sampler: base index sampler (e.g. a DistributedSampler)
    :param num_tokens_fn: callable mapping an index to that sample's length
    :param num_buckets: number of length buckets used to group similar lengths
    :param min_size: samples shorter than this are skipped
    :param max_size: samples longer than this are skipped; together with
        ``min_size`` it defines the bucket range
    :param max_tokens: per-batch budget of ``batch_len * max_sample_len``
    :param max_sentences: hard cap on batch size; ``max_tokens`` and
        ``max_sentences`` may be combined, but at least one must be given
    :param drop_last: if True, drop the final partial batch
    """

    def __init__(self, sampler, num_tokens_fn, num_buckets=100, min_size=0, max_size=1000,
                 max_tokens=None, max_sentences=None, drop_last=False):
        super(DynamicBatchSampler, self).__init__(sampler)
        self.sampler = sampler
        self.num_tokens_fn = num_tokens_fn
        self.num_buckets = num_buckets

        self.min_size = min_size
        self.max_size = max_size

        assert max_tokens is not None or max_sentences is not None, \
            "max_tokens and max_sentences should not be null at the same time, please specify one parameter at least"
        # BUGFIX: only compare max_size against max_tokens when max_tokens is
        # given -- the original asserted ``max_size <= max_tokens`` first and
        # raised TypeError whenever max_tokens was None.
        assert max_tokens is None or max_size <= max_tokens, \
            "max_size should be smaller than max tokens"
        self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
        self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
        self.drop_last = drop_last

    def set_epoch(self, epoch):
        # Forward the epoch to the wrapped sampler (needed for
        # DistributedSampler's per-epoch shuffling).
        self.sampler.set_epoch(epoch)

    def is_batch_full(self, num_tokens, batch):
        """Return True when the batch has hit the sentence cap or token budget."""
        if len(batch) == 0:
            return False
        if len(batch) == self.max_sentences:
            return True
        if num_tokens > self.max_tokens:
            return True
        return False

    def __iter__(self):
        buckets = [[] for _ in range(self.num_buckets)]
        sample_len = [0] * self.num_buckets  # current max length seen per bucket

        for idx in self.sampler:
            idx_length = self.num_tokens_fn(idx)
            if not (self.min_size <= idx_length <= self.max_size):
                print("sentence at index {} of size {} is outside [min_size, max_size], the sentence is ignored".format(idx, idx_length))
                continue

            # Map the length onto one of ``num_buckets`` equal-width bins.
            index_buckets = math.floor((idx_length - self.min_size) / (self.max_size - self.min_size + 1)
                                       * self.num_buckets)
            sample_len[index_buckets] = max(sample_len[index_buckets], idx_length)

            # Token cost if this sample were added: every sample is padded to
            # the bucket's current max length.
            num_tokens = (len(buckets[index_buckets]) + 1) * sample_len[index_buckets]
            if self.is_batch_full(num_tokens, buckets[index_buckets]):
                # yield this batch before starting a fresh one
                yield buckets[index_buckets]
                buckets[index_buckets] = []
                sample_len[index_buckets] = 0

            buckets[index_buckets].append(idx)

        # process left-over samples still sitting in the buckets
        leftover_batch = []
        leftover_sample_len = 0
        leftover = [idx for bucket in buckets for idx in bucket]
        for idx in leftover:
            idx_length = self.num_tokens_fn(idx)
            leftover_sample_len = max(leftover_sample_len, idx_length)
            num_tokens = (len(leftover_batch) + 1) * leftover_sample_len
            if self.is_batch_full(num_tokens, leftover_batch):
                yield leftover_batch
                leftover_batch = []
                leftover_sample_len = 0
            leftover_batch.append(idx)

        if len(leftover_batch) > 0 and not self.drop_last:
            yield leftover_batch

    def __len__(self):
        # The exact number of batches is data-dependent and unknown ahead of
        # time, so len(dataloader) is intentionally unsupported.
        pass
129
+
130
+
131
class AudioDataset(torch.utils.data.Dataset):
    """Dataset of pre-extracted audio codec tokens and phoneme tokens stored
    in a single HDF5 archive.

    The annotation file is pipe-separated, one utterance per line:
    ``h5_key|duration|language|text|...`` -- the final column is discarded.
    """

    def __init__(self, h5_path, ann_path, tokenizer_path):
        # Path to the HDF5 archive; the file itself is opened lazily (see
        # ``archive``) so that each DataLoader worker gets its own handle.
        self.h5_path = h5_path
        with open(ann_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        ls = [l.split("|") for l in lines]
        # Transpose rows -> columns, then drop the trailing column.
        ls_T = list(zip(*ls))
        del ls_T[-1]
        self.h5_paths, self.durations, self.langs, self.texts = \
            list(ls_T[0]), list(ls_T[1]), list(ls_T[2]), list(ls_T[3])
        self.durations = [float(dur) for dur in self.durations]
        self.tokenizer = Tokenizer.from_file(tokenizer_path)

        self._archive = None

    def __len__(self):
        return len(self.h5_paths)

    def get_dur(self, idx):
        """Duration (seconds) of utterance ``idx``; used by the batch sampler."""
        return self.durations[idx]

    @property
    def archive(self):
        # lazy loading here! h5py file handles are not picklable, so the
        # archive must be opened after DataLoader workers fork.
        if self._archive is None:
            self._archive = h5py.File(self.h5_path, "r")
        return self._archive

    def __getitem__(self, idx):
        """Return one utterance: audio codec tokens plus BPE-encoded phoneme ids."""
        archive = self.archive
        h5_path = self.h5_paths[idx]
        sub = archive[h5_path]
        audio_tokens = sub['audio'][()]
        phone_tokens = sub['text'][()]
        dur = self.durations[idx]
        lang = self.langs[idx]
        text = self.texts[idx]
        # tokenization should be done within dataloader
        phones = seq2phone(phone_tokens)
        phones = phones.replace(" ", "_")
        if not len(phones):
            # No stored phonemes -- fall back to BPE-encoding the raw text.
            cptpho_tokens = self.tokenizer.encode(text).ids
        else:
            cptpho_tokens = self.tokenizer.encode(phones).ids
        assert len(cptpho_tokens)
        return {
            'utt_id': h5_path,
            'text': text,
            'audio': None,
            'audio_lens': None,
            'audio_features': audio_tokens,
            # Frame count; audio_tokens appears to be laid out (codebooks, frames)
            # given the transpose here and in collate() -- TODO confirm.
            'audio_features_lens': len(audio_tokens.T),
            'text_tokens': np.array(cptpho_tokens),
            'text_tokens_lens': len(cptpho_tokens),
            'language': language_dict[lang],
        }
185
+
186
def collate(batch):
    """Pad a list of AudioDataset items into batched tensors.

    Audio features are padded with -1; text tokens are padded with id 3
    ([PAD] token id 3).
    """
    n = len(batch)

    utt_id_s = [item['utt_id'] for item in batch]
    text_s = [item['text'] for item in batch]
    audio_s = [item['audio'] for item in batch]
    audio_lens_s = [item['audio_lens'] for item in batch]
    audio_features_lens_s = [item['audio_features_lens'] for item in batch]
    text_tokens_lens_s = [item['text_tokens_lens'] for item in batch]
    language_s = [item['language'] for item in batch]

    # Pre-allocate padded tensors sized to the longest item in the batch.
    audio_features_s = torch.full(
        [n, max(audio_features_lens_s), 8], -1, dtype=torch.int64
    )  # audio pad with -1
    text_tokens_s = torch.full(
        [n, max(text_tokens_lens_s)], 3, dtype=torch.int64
    )  # [PAD] token id 3

    for i, item in enumerate(batch):
        feat_len = item['audio_features_lens']
        # Stored as (codebooks, frames); transpose to (frames, codebooks).
        audio_features_s[i, :feat_len, :] = torch.LongTensor(item['audio_features'].T)
        tok_len = item['text_tokens_lens']
        text_tokens_s[i, :tok_len] = torch.LongTensor(item['text_tokens'])

    return {
        'utt_id': utt_id_s,
        'text': text_s,
        'audio': audio_s,
        'audio_lens': audio_lens_s,
        'audio_features': audio_features_s,
        'audio_features_lens': torch.LongTensor(np.array(audio_features_lens_s)),
        'text_tokens': text_tokens_s,
        'text_tokens_lens': torch.LongTensor(np.array(text_tokens_lens_s)),
        'languages': torch.LongTensor(np.array(language_s)),
    }
224
+
225
def create_dataloader(data_dir="/root/valle/egs/mix", n_gpus=1, rank=0, num_workers=0, num_buckets=10, max_duration=120):
    """Build the distributed training DataLoader over the HDF5 audio dataset.

    Batches are assembled by a DynamicBatchSampler layered on top of a
    DistributedSampler, bucketing utterances by duration so that each batch
    stays within ``max_duration`` total seconds.
    """
    dataset = AudioDataset(
        h5_path=f"{data_dir}/audio_sum.hdf5",
        ann_path=f"{data_dir}/audio_ann_sum.txt",
        tokenizer_path=f"{data_dir}/bpe_69.json",
    )
    base_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
    batch_sampler = DynamicBatchSampler(
        base_sampler,
        dataset.get_dur,
        num_buckets=num_buckets,
        max_size=20,
        max_tokens=max_duration,
    )
    return torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        collate_fn=collate,
        batch_sampler=batch_sampler,
    )
apps/audio_cloning/vallex/data/fbank.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 (authors: Feiteng Li)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ from dataclasses import asdict, dataclass
19
+ from typing import Any, Dict, Optional, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ # from lhotse.features.base import FeatureExtractor
24
+ # from lhotse.utils import EPSILON, Seconds, compute_num_frames
25
+ from librosa.filters import mel as librosa_mel_fn
26
+
27
+
28
# BUGFIX: ``Seconds`` comes from lhotse.utils, but the lhotse imports above
# are commented out, so the annotations below raised NameError at import
# time. Define the alias locally (lhotse's Seconds is a float of seconds).
Seconds = float


@dataclass
class BigVGANFbankConfig:
    # Spectogram-related part
    # Note that frame_length and frame_shift will be converted to milliseconds before torchaudio/Kaldi sees them
    frame_length: Seconds = 1024 / 24000.0
    frame_shift: Seconds = 256 / 24000.0
    remove_dc_offset: bool = True
    round_to_power_of_two: bool = True

    # Fbank-related part
    low_freq: float = 0.0
    high_freq: float = 12000.0
    num_mel_bins: int = 100
    use_energy: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return the config as a plain dict (for serialization)."""
        return asdict(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "BigVGANFbankConfig":
        """Reconstruct a config from a dict produced by ``to_dict``."""
        return BigVGANFbankConfig(**data)
49
+
50
+
51
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
52
+ return torch.log(torch.clamp(x, min=clip_val) * C)
53
+
54
+
55
+ def spectral_normalize_torch(magnitudes):
56
+ output = dynamic_range_compression_torch(magnitudes)
57
+ return output
58
+
59
+
60
# https://github.com/NVIDIA/BigVGAN
# bigvgan_24khz_100band https://drive.google.com/drive/folders/1EpxX6AsxjCbbk0mmAhE0td6eYiABr8Oz
class BigVGANFbank(FeatureExtractor):
    """Mel-spectrogram (fbank) extractor matching NVIDIA BigVGAN's 24 kHz,
    100-band vocoder front end.

    NOTE(review): ``FeatureExtractor``, ``compute_num_frames``, ``Seconds``
    and ``EPSILON`` come from lhotse, whose imports at the top of this file
    are commented out -- as written these names are undefined; confirm the
    lhotse dependency before using this class.
    """

    name = "fbank"
    config_type = BigVGANFbankConfig

    def __init__(self, config: Optional[Any] = None):
        super(BigVGANFbank, self).__init__(config)
        sampling_rate = 24000
        # NOTE(review): positional librosa_mel_fn(sr, n_fft, n_mels, fmin, fmax)
        # requires librosa < 0.10 (newer versions take keyword-only args) --
        # confirm the pinned librosa version.
        self.mel_basis = torch.from_numpy(
            librosa_mel_fn(
                sampling_rate,
                1024,
                self.config.num_mel_bins,
                self.config.low_freq,
                self.config.high_freq,
            ).astype(np.float32)
        )
        self.hann_window = torch.hann_window(1024)

    def _feature_fn(self, samples, **kwargs):
        """STFT -> mel projection -> log compression.

        Returns a tensor shaped (num_frames, num_mel_bins).
        """
        win_length, n_fft = 1024, 1024
        hop_size = 256
        # Right-pad so the produced frame count matches lhotse's
        # compute_num_frames for the same duration/frame_shift.
        if True:
            sampling_rate = 24000
            duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
            expected_num_frames = compute_num_frames(
                duration=duration,
                frame_shift=self.frame_shift,
                sampling_rate=sampling_rate,
            )
            pad_size = (
                (expected_num_frames - 1) * hop_size
                + win_length
                - samples.shape[-1]
            )
            assert pad_size >= 0

            y = torch.nn.functional.pad(
                samples,
                (0, pad_size),
                mode="constant",
            )
        else:
            # Dead branch kept from the BigVGAN reference implementation:
            # centered reflect padding.
            y = torch.nn.functional.pad(
                samples,
                (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
                mode="reflect",
            )

        y = y.squeeze(1)

        # complex tensor as default, then use view_as_real for future pytorch compatibility
        spec = torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_length,
            window=self.hann_window,
            center=False,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)
        # Magnitude with a small epsilon for numerical stability.
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        # Project onto the mel filterbank, then log-compress.
        spec = torch.matmul(self.mel_basis, spec)
        spec = spectral_normalize_torch(spec)

        return spec.transpose(2, 1).squeeze(0)

    def extract(
        self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
    ) -> np.ndarray:
        """Extract log-mel features from 24 kHz mono samples as a numpy array."""
        assert sampling_rate == 24000
        params = asdict(self.config)
        params.update({"sample_frequency": sampling_rate, "snip_edges": False})
        # Kaldi-style params expect milliseconds, config stores seconds.
        params["frame_shift"] *= 1000.0
        params["frame_length"] *= 1000.0
        if not isinstance(samples, torch.Tensor):
            samples = torch.from_numpy(samples)
        # Torchaudio Kaldi feature extractors expect the channel dimension to be first.
        if len(samples.shape) == 1:
            samples = samples.unsqueeze(0)
        features = self._feature_fn(samples, **params).to(torch.float32)
        return features.numpy()

    @property
    def frame_shift(self) -> Seconds:
        return self.config.frame_shift

    def feature_dim(self, sampling_rate: int) -> int:
        return self.config.num_mel_bins

    @staticmethod
    def mix(
        features_a: np.ndarray,
        features_b: np.ndarray,
        energy_scaling_factor_b: float,
    ) -> np.ndarray:
        """Mix two log-energy feature matrices in the linear-energy domain."""
        return np.log(
            np.maximum(
                # protection against log(0); max with EPSILON is adequate since these are energies (always >= 0)
                EPSILON,
                np.exp(features_a)
                + energy_scaling_factor_b * np.exp(features_b),
            )
        )

    @staticmethod
    def compute_energy(features: np.ndarray) -> float:
        """Total linear-domain energy of a log-energy feature matrix."""
        return float(np.sum(np.exp(features)))
+
175
+
176
+ def get_fbank_extractor() -> BigVGANFbank:
177
+ return BigVGANFbank(BigVGANFbankConfig())
178
+
179
+
180
+ if __name__ == "__main__":
181
+ extractor = BigVGANFbank(BigVGANFbankConfig())
182
+
183
+ samples = torch.from_numpy(np.random.random([1000]).astype(np.float32))
184
+ samples = torch.clip(samples, -1.0, 1.0)
185
+ fbank = extractor.extract(samples, 24000.0)
186
+ print(f"fbank {fbank.shape}")
187
+
188
+ from scipy.io.wavfile import read
189
+
190
+ MAX_WAV_VALUE = 32768.0
191
+
192
+ sampling_rate, samples = read(
193
+ "egs/libritts/prompts/5639_40744_000000_000002.wav"
194
+ )
195
+ print(f"samples: [{samples.min()}, {samples.max()}]")
196
+ fbank = extractor.extract(samples.astype(np.float32) / MAX_WAV_VALUE, 24000)
197
+ print(f"fbank {fbank.shape}")
198
+
199
+ import matplotlib.pyplot as plt
200
+
201
+ _ = plt.figure(figsize=(18, 10))
202
+ plt.imshow(
203
+ X=fbank.transpose(1, 0),
204
+ cmap=plt.get_cmap("jet"),
205
+ aspect="auto",
206
+ interpolation="nearest",
207
+ )
208
+ plt.gca().invert_yaxis()
209
+ plt.savefig("egs/libritts/prompts/5639_40744_000000_000002.png")
210
+ plt.close()
211
+
212
+ print("fbank test PASS!")
apps/audio_cloning/vallex/data/input_strategies.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from collections import defaultdict
3
+ from concurrent.futures import ThreadPoolExecutor
4
+ from typing import Tuple, Type
5
+
6
+ # from lhotse import CutSet
7
+ # from lhotse.dataset.collation import collate_features
8
+ # from lhotse.dataset.input_strategies import (
9
+ # ExecutorType,
10
+ # PrecomputedFeatures,
11
+ # _get_executor,
12
+ # )
13
+ # from lhotse.utils import fastcopy
14
+
15
+
16
class PromptedFeatures:
    """Pairs prompt features with target features so the two can be moved
    between devices and inspected together."""

    def __init__(self, prompts, features):
        self.prompts = prompts
        self.features = features

    def to(self, device):
        """Return a new pair with both tensors moved to ``device``."""
        moved_prompts = self.prompts.to(device)
        moved_features = self.features.to(device)
        return PromptedFeatures(moved_prompts, moved_features)

    def sum(self):
        """Sum over the target features only (prompts are excluded)."""
        return self.features.sum()

    @property
    def ndim(self):
        """Dimensionality of the target features."""
        return self.features.ndim

    @property
    def data(self):
        """The underlying ``(prompts, features)`` tuple."""
        return (self.prompts, self.features)
36
+
37
+
38
+ # class PromptedPrecomputedFeatures(PrecomputedFeatures):
39
+ # """
40
+ # :class:`InputStrategy` that reads pre-computed features, whose manifests
41
+ # are attached to cuts, from disk.
42
+ #
43
+ # It automatically pads the feature matrices with pre or post feature.
44
+ #
45
+ # .. automethod:: __call__
46
+ # """
47
+ #
48
+ # def __init__(
49
+ # self,
50
+ # dataset: str,
51
+ # cuts: CutSet,
52
+ # num_workers: int = 0,
53
+ # executor_type: Type[ExecutorType] = ThreadPoolExecutor,
54
+ # ) -> None:
55
+ # super(PromptedPrecomputedFeatures, self).__init__(
56
+ # num_workers, executor_type
57
+ # )
58
+ #
59
+ # self.utt2neighbors = defaultdict(lambda: [])
60
+ #
61
+ # if dataset.lower() == "libritts":
62
+ # # 909_131041_000013_000002
63
+ # # 909_131041_000013_000003
64
+ # speaker2utts = defaultdict(lambda: [])
65
+ #
66
+ # utt2cut = {}
67
+ # for cut in cuts:
68
+ # speaker = cut.supervisions[0].speaker
69
+ # speaker2utts[speaker].append(cut.id)
70
+ # utt2cut[cut.id] = cut
71
+ #
72
+ # for spk in speaker2utts:
73
+ # uttids = sorted(speaker2utts[spk])
74
+ # # Using the property of sorted keys to find previous utterance
75
+ # # The keys has structure speaker_book_x_y e.g. 1089_134691_000004_000001
76
+ # if len(uttids) == 1:
77
+ # self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
78
+ # continue
79
+ #
80
+ # utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
81
+ # utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
82
+ #
83
+ # for utt in utt2prevutt:
84
+ # self.utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]])
85
+ #
86
+ # for utt in utt2postutt:
87
+ # self.utt2neighbors[utt].append(utt2cut[utt2postutt[utt]])
88
+ # elif dataset.lower() == "ljspeech":
89
+ # utt2cut = {}
90
+ # uttids = []
91
+ # for cut in cuts:
92
+ # uttids.append(cut.id)
93
+ # utt2cut[cut.id] = cut
94
+ #
95
+ # if len(uttids) == 1:
96
+ # self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
97
+ # else:
98
+ # # Using the property of sorted keys to find previous utterance
99
+ # # The keys has structure: LJ001-0010
100
+ # utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
101
+ # utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
102
+ #
103
+ # for utt in utt2postutt:
104
+ # postutt = utt2postutt[utt]
105
+ # if utt[:5] == postutt[:5]:
106
+ # self.utt2neighbors[utt].append(utt2cut[postutt])
107
+ #
108
+ # for utt in utt2prevutt:
109
+ # prevutt = utt2prevutt[utt]
110
+ # if utt[:5] == prevutt[:5] or not self.utt2neighbors[utt]:
111
+ # self.utt2neighbors[utt].append(utt2cut[prevutt])
112
+ # else:
113
+ # raise ValueError
114
+ #
115
+ # def __call__(
116
+ # self, cuts: CutSet
117
+ # ) -> Tuple[PromptedFeatures, PromptedFeatures]:
118
+ # """
119
+ # Reads the pre-computed features from disk/other storage.
120
+ # The returned shape is``(B, T, F) => (batch_size, num_frames, num_features)``.
121
+ #
122
+ # :return: a tensor with collated features, and a tensor of ``num_frames`` of each cut before padding.
123
+ # """
124
+ # features, features_lens = collate_features(
125
+ # cuts,
126
+ # executor=_get_executor(
127
+ # self.num_workers, executor_type=self._executor_type
128
+ # ),
129
+ # )
130
+ #
131
+ # prompts_cuts = []
132
+ # for k, cut in enumerate(cuts):
133
+ # prompts_cut = random.choice(self.utt2neighbors[cut.id])
134
+ # prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}"))
135
+ #
136
+ # mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0])
137
+ # # prompts_cuts = CutSet.from_cuts(prompts_cuts).truncate(
138
+ # # max_duration=mini_duration,
139
+ # # offset_type="random",
140
+ # # preserve_id=True,
141
+ # # )
142
+ # prompts_cuts = CutSet(
143
+ # cuts={k: cut for k, cut in enumerate(prompts_cuts)}
144
+ # ).truncate(
145
+ # max_duration=mini_duration,
146
+ # offset_type="random",
147
+ # preserve_id=False,
148
+ # )
149
+ #
150
+ # prompts, prompts_lens = collate_features(
151
+ # prompts_cuts,
152
+ # executor=_get_executor(
153
+ # self.num_workers, executor_type=self._executor_type
154
+ # ),
155
+ # )
156
+ #
157
+ # return PromptedFeatures(prompts, features), PromptedFeatures(
158
+ # prompts_lens, features_lens
159
+ # )
apps/audio_cloning/vallex/data/symbol_table.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
2
+ #
3
+ # See ../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from dataclasses import dataclass, field
18
+ from typing import Dict, Generic, List, Optional, TypeVar, Union
19
+
20
+ Symbol = TypeVar("Symbol")
21
+
22
+
23
+ # Disable __repr__ otherwise it could freeze e.g. Jupyter.
24
+ @dataclass(repr=False)
25
+ class SymbolTable(Generic[Symbol]):
26
+ """SymbolTable that maps symbol IDs, found on the FSA arcs to
27
+ actual objects. These objects can be arbitrary Python objects
28
+ that can serve as keys in a dictionary (i.e. they need to be
29
+ hashable and immutable).
30
+
31
+ The SymbolTable can only be read to/written from disk if the
32
+ symbols are strings.
33
+ """
34
+
35
+ _id2sym: Dict[int, Symbol] = field(default_factory=dict)
36
+ """Map an integer to a symbol.
37
+ """
38
+
39
+ _sym2id: Dict[Symbol, int] = field(default_factory=dict)
40
+ """Map a symbol to an integer.
41
+ """
42
+
43
+ _next_available_id: int = 1
44
+ """A helper internal field that helps adding new symbols
45
+ to the table efficiently.
46
+ """
47
+
48
+ eps: Symbol = "<eps>"
49
+ """Null symbol, always mapped to index 0.
50
+ """
51
+
52
+ def __post_init__(self):
53
+ for idx, sym in self._id2sym.items():
54
+ assert self._sym2id[sym] == idx
55
+ assert idx >= 0
56
+
57
+ for sym, idx in self._sym2id.items():
58
+ assert idx >= 0
59
+ assert self._id2sym[idx] == sym
60
+
61
+ if 0 not in self._id2sym:
62
+ self._id2sym[0] = self.eps
63
+ self._sym2id[self.eps] = 0
64
+ else:
65
+ assert self._id2sym[0] == self.eps
66
+ assert self._sym2id[self.eps] == 0
67
+
68
+ self._next_available_id = max(self._id2sym) + 1
69
+
70
+ @staticmethod
71
+ def from_str(s: str) -> "SymbolTable":
72
+ """Build a symbol table from a string.
73
+
74
+ The string consists of lines. Every line has two fields separated
75
+ by space(s), tab(s) or both. The first field is the symbol and the
76
+ second the integer id of the symbol.
77
+
78
+ Args:
79
+ s:
80
+ The input string with the format described above.
81
+ Returns:
82
+ An instance of :class:`SymbolTable`.
83
+ """
84
+ id2sym: Dict[int, str] = dict()
85
+ sym2id: Dict[str, int] = dict()
86
+
87
+ for line in s.split("\n"):
88
+ fields = line.split()
89
+ if len(fields) == 0:
90
+ continue # skip empty lines
91
+ assert len(fields) == 2, (
92
+ f"Expect a line with 2 fields. Given: {len(fields)}"
93
+ )
94
+ sym, idx = fields[0], int(fields[1])
95
+ assert sym not in sym2id, f"Duplicated symbol {sym}"
96
+ assert idx not in id2sym, f"Duplicated id {idx}"
97
+ id2sym[idx] = sym
98
+ sym2id[sym] = idx
99
+
100
+ eps = id2sym.get(0, "<eps>")
101
+
102
+ return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps)
103
+
104
+ @staticmethod
105
+ def from_file(filename: str) -> "SymbolTable":
106
+ """Build a symbol table from file.
107
+
108
+ Every line in the symbol table file has two fields separated by
109
+ space(s), tab(s) or both. The following is an example file:
110
+
111
+ .. code-block::
112
+
113
+ <eps> 0
114
+ a 1
115
+ b 2
116
+ c 3
117
+
118
+ Args:
119
+ filename:
120
+ Name of the symbol table file. Its format is documented above.
121
+
122
+ Returns:
123
+ An instance of :class:`SymbolTable`.
124
+
125
+ """
126
+ with open(filename, "r", encoding="utf-8") as f:
127
+ return SymbolTable.from_str(f.read().strip())
128
+
129
+ def to_str(self) -> str:
130
+ """
131
+ Returns:
132
+ Return a string representation of this object. You can pass
133
+ it to the method ``from_str`` to recreate an identical object.
134
+ """
135
+ s = ""
136
+ for idx, symbol in sorted(self._id2sym.items()):
137
+ s += f"{symbol} {idx}\n"
138
+ return s
139
+
140
+ def to_file(self, filename: str):
141
+ """Serialize the SymbolTable to a file.
142
+
143
+ Every line in the symbol table file has two fields separated by
144
+ space(s), tab(s) or both. The following is an example file:
145
+
146
+ .. code-block::
147
+
148
+ <eps> 0
149
+ a 1
150
+ b 2
151
+ c 3
152
+
153
+ Args:
154
+ filename:
155
+ Name of the symbol table file. Its format is documented above.
156
+ """
157
+ with open(filename, "w", encoding="utf-8") as f:
158
+ for idx, symbol in sorted(self._id2sym.items()):
159
+ print(symbol, idx, file=f)
160
+
161
+ def add(self, symbol: Symbol, index: Optional[int] = None) -> int:
162
+ """Add a new symbol to the SymbolTable.
163
+
164
+ Args:
165
+ symbol:
166
+ The symbol to be added.
167
+ index:
168
+ Optional int id to which the symbol should be assigned.
169
+ If it is not available, a ValueError will be raised.
170
+
171
+ Returns:
172
+ The int id to which the symbol has been assigned.
173
+ """
174
+ # Already in the table? Return its ID.
175
+ if symbol in self._sym2id:
176
+ return self._sym2id[symbol]
177
+ # Specific ID not provided - use next available.
178
+ if index is None:
179
+ index = self._next_available_id
180
+ # Specific ID provided but not available.
181
+ if index in self._id2sym:
182
+ raise ValueError(
183
+ f"Cannot assign id '{index}' to '{symbol}' - "
184
+ f"already occupied by {self._id2sym[index]}"
185
+ )
186
+ self._sym2id[symbol] = index
187
+ self._id2sym[index] = symbol
188
+
189
+ # Update next available ID if needed
190
+ if self._next_available_id <= index:
191
+ self._next_available_id = index + 1
192
+
193
+ return index
194
+
195
+ def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]:
196
+ """Get a symbol for an id or get an id for a symbol
197
+
198
+ Args:
199
+ k:
200
+ If it is an id, it tries to find the symbol corresponding
201
+ to the id; if it is a symbol, it tries to find the id
202
+ corresponding to the symbol.
203
+
204
+ Returns:
205
+ An id or a symbol depending on the given `k`.
206
+ """
207
+ if isinstance(k, int):
208
+ return self._id2sym[k]
209
+ else:
210
+ return self._sym2id[k]
211
+
212
+ def merge(self, other: "SymbolTable") -> "SymbolTable":
213
+ """Create a union of two SymbolTables.
214
+ Raises an AssertionError if the same IDs are occupied by
215
+ different symbols.
216
+
217
+ Args:
218
+ other:
219
+ A symbol table to merge with ``self``.
220
+
221
+ Returns:
222
+ A new symbol table.
223
+ """
224
+ self._check_compatible(other)
225
+
226
+ id2sym = {**self._id2sym, **other._id2sym}
227
+ sym2id = {**self._sym2id, **other._sym2id}
228
+
229
+ return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps)
230
+
231
+ def _check_compatible(self, other: "SymbolTable") -> None:
232
+ # Epsilon compatibility
233
+ assert self.eps == other.eps, (
234
+ f"Mismatched epsilon symbol: {self.eps} != {other.eps}"
235
+ )
236
+ # IDs compatibility
237
+ common_ids = set(self._id2sym).intersection(other._id2sym)
238
+ for idx in common_ids:
239
+ assert self[idx] == other[idx], (
240
+ f"ID conflict for id: {idx}, "
241
+ f'self[idx] = "{self[idx]}", '
242
+ f'other[idx] = "{other[idx]}"'
243
+ )
244
+ # Symbols compatibility
245
+ common_symbols = set(self._sym2id).intersection(other._sym2id)
246
+ for sym in common_symbols:
247
+ assert self[sym] == other[sym], (
248
+ f"ID conflict for id: {sym}, "
249
+ f'self[sym] = "{self[sym]}", '
250
+ f'other[sym] = "{other[sym]}"'
251
+ )
252
+
253
+ def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]:
254
+ return self.get(item)
255
+
256
+ def __contains__(self, item: Union[int, Symbol]) -> bool:
257
+ if isinstance(item, int):
258
+ return item in self._id2sym
259
+ else:
260
+ return item in self._sym2id
261
+
262
+ def __len__(self) -> int:
263
+ return len(self._id2sym)
264
+
265
+ def __eq__(self, other: "SymbolTable") -> bool:
266
+ if len(self) != len(other):
267
+ return False
268
+
269
+ for s in self.symbols:
270
+ if self[s] != other[s]:
271
+ return False
272
+
273
+ return True
274
+
275
+ @property
276
+ def ids(self) -> List[int]:
277
+ """Returns a list of integer IDs corresponding to the symbols."""
278
+ ans = list(self._id2sym.keys())
279
+ ans.sort()
280
+ return ans
281
+
282
+ @property
283
+ def symbols(self) -> List[Symbol]:
284
+ """Returns a list of symbols (e.g., strings) corresponding to
285
+ the integer IDs.
286
+ """
287
+ ans = list(self._sym2id.keys())
288
+ ans.sort()
289
+ return ans
apps/audio_cloning/vallex/data/tokenizer.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2023 (authors: Feiteng Li)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torchaudio
21
+ from encodec import EncodecModel
22
+ from encodec.utils import convert_audio
23
+
24
# NOTE(review): a no-op ``try: pass / except Exception: pass`` used to sit
# here (apparently a leftover from a removed optional import). It had no
# effect at runtime and has been deleted.
28
+
29
+
30
def remove_encodec_weight_norm(model):
    """Strip weight normalization from every conv layer of an EnCodec model.

    Mutates ``model`` in place; both the encoder and decoder stacks are
    walked, including the convs nested inside SEANet residual blocks.
    """
    from encodec.modules import SConv1d
    from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
    from torch.nn.utils import remove_weight_norm

    def _strip_resnet_block(resnet_block):
        # A residual block carries a shortcut conv plus inner SConv1d layers.
        remove_weight_norm(resnet_block.shortcut.conv.conv)
        for inner in resnet_block.block._modules.values():
            if isinstance(inner, SConv1d):
                remove_weight_norm(inner.conv.conv)

    # Encoder: residual blocks and plain convolutions.
    for layer in model.encoder.model._modules.values():
        if isinstance(layer, SEANetResnetBlock):
            _strip_resnet_block(layer)
        elif isinstance(layer, SConv1d):
            remove_weight_norm(layer.conv.conv)

    # Decoder: additionally contains transposed convolutions.
    for layer in model.decoder.model._modules.values():
        if isinstance(layer, SEANetResnetBlock):
            _strip_resnet_block(layer)
        elif isinstance(layer, SConvTranspose1d):
            remove_weight_norm(layer.convtr.convtr)
        elif isinstance(layer, SConv1d):
            remove_weight_norm(layer.conv.conv)
58
+
59
+
60
class AudioTokenizer:
    """EnCodec audio."""

    def __init__(
        self,
        device: Any = None,
    ) -> None:
        """Load the pretrained 24 kHz EnCodec codec at 6 kbps.

        If ``device`` is falsy, pick a backend automatically; the checks
        run CPU -> CUDA -> MPS, so MPS wins when available (preserving the
        original preference order).
        """
        codec = EncodecModel.encodec_model_24khz()
        codec.set_target_bandwidth(6.0)
        remove_encodec_weight_norm(codec)

        if not device:
            device = torch.device("cpu")
            if torch.cuda.is_available():
                device = torch.device("cuda:0")
            if torch.backends.mps.is_available():
                device = torch.device("mps")

        self._device = device
        self.codec = codec.to(device)
        self.sample_rate = codec.sample_rate
        self.channels = codec.channels

    @property
    def device(self):
        """Torch device the codec currently lives on."""
        return self._device

    def encode(self, wav: torch.Tensor) -> torch.Tensor:
        """Encode a waveform (moved to ``self.device``) into EnCodec frames."""
        return self.codec.encode(wav.to(self.device))

    def decode(self, frames: torch.Tensor) -> torch.Tensor:
        """Reconstruct a waveform from encoded EnCodec frames."""
        return self.codec.decode(frames)
94
+
95
+
96
def tokenize_audio(tokenizer: AudioTokenizer, audio):
    """Encode ``audio`` into discrete EnCodec codes.

    ``audio`` is either a file path (loaded via torchaudio) or a
    ``(waveform, sample_rate)`` pair.
    """
    wav, sr = torchaudio.load(audio) if isinstance(audio, str) else audio
    # Resample / remix to the codec's rate and channel count, then add a
    # batch dimension before encoding.
    wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
    wav = wav.unsqueeze(0)
    with torch.no_grad():
        return tokenizer.encode(wav)
109
+
110
+
111
if __name__ == "__main__":
    # Sanity check: removing weight norm must leave the encoder output
    # unchanged for the same random batch.
    codec = EncodecModel.encodec_model_24khz()
    codec.set_target_bandwidth(6.0)

    batch = torch.from_numpy(np.random.random([4, 1, 1600])).type(torch.float32)
    codes_before = codec.encode(batch)

    remove_encodec_weight_norm(codec)
    codes_after = codec.encode(batch)

    assert torch.allclose(codes_before[0][0], codes_after[0][0])
apps/audio_cloning/vallex/descriptions.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ top_md_org = """
2
+ # VALL-E X
3
+ VALL-E X can synthesize high-quality personalized speech with only a 3-second enrolled recording of
4
+ an unseen speaker as an acoustic prompt, even in another language for a monolingual speaker.<br>
5
+ This implementation supports zero-shot, mono-lingual/cross-lingual text-to-speech functionality of three languages (English, Chinese, Japanese)<br>
6
+ See this [demo](https://plachtaa.github.io/) page for more details.
7
+ """
8
+
9
+ top_ja_md = """
10
+ # VALL-E X
11
+
12
+ VALL-E X は、未学習の話者でも 3 秒間の音声プロンプトだけで高品質なパーソナライズ音声を合成できます。<br>
13
+ 単一言語話者であっても別の言語による音声合成が可能です。<br>
14
+ 本実装は英語・中国語・日本語のゼロショット単言語/クロス言語テキスト読み上げをサポートしています。
15
+
16
+ ## Reference
17
+
18
+ - [github.com/Plachtaa/VALL-E-X](https://github.com/Plachtaa/VALL-E-X/tree/master#readme)
19
+ - [github.com/gemelo-ai/vocos](https://github.com/gemelo-ai/vocos)
20
+ """
21
+
22
+ infer_from_audio_md_org = """
23
+ Upload a speech of 3~10 seconds as the audio prompt and type in the text you'd like to synthesize.<br>
24
+ The model will synthesize speech of given text with the same voice of your audio prompt.<br>
25
+ The model also tends to preserve the emotion & acoustic environment of your given speech.<br>
26
+ For faster inference, please use **"Make prompt"** to get a `.npz` file as the encoded audio prompt, and use it by **"Infer from prompt"**
27
+ """
28
+
29
+ infer_from_audio_ja_md = """
30
+ 3〜10 秒程度の音声をプロンプトとしてアップロードし、合成したいテキストを入力してください。<br>
31
+ モデルは、プロンプトと同じ声質でテキストを読み上げる音声を生成します。<br>
32
+ 元の音声に含まれる感情や音響環境も比較的保持されます。<br>
33
+ 推論を高速化したい場合は **"Make prompt"** で `.npz` ファイルを作成し、 **"Infer from prompt"** で利用してください。
34
+ """
apps/audio_cloning/vallex/examples.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _prompts_dir = "apps/audio_cloning/vallex/prompts"
2
+
3
+ infer_from_audio_examples = [
4
+ [
5
+ "私のクローンに騙されないでください。",
6
+ "日本語",
7
+ "no-accent",
8
+ f"{_prompts_dir}/ja-okuwaki.wav",
9
+ None,
10
+ "こんにちは、私の名前はオクワキヨウスケです。",
11
+ ],
12
+ [
13
+ "ぼくのクローンに騙されないでくれなのだ。",
14
+ "日本語",
15
+ "no-accent",
16
+ f"{_prompts_dir}/ja-zundamon.wav",
17
+ None,
18
+ "はじめまして、ずんだもんなのだ",
19
+ ],
20
+ [
21
+ "私のクローンに騙されないでください。",
22
+ "日本語",
23
+ "no-accent",
24
+ f"{_prompts_dir}/ja-okuwaki-long.wav",
25
+ None,
26
+ "こんにちは、私の名前はオクワキヨウスケです。これは音声クローニング用のサンプルです。",
27
+ ],
28
+ [
29
+ "私の声を真似するのはそんなに面白いですか?",
30
+ "日本語",
31
+ "no-accent",
32
+ f"{_prompts_dir}/ja-2.ogg",
33
+ None,
34
+ "初めまして、朝武よしのです。",
35
+ ],
36
+ [
37
+ "This is how this machine has taken my voice.",
38
+ "English",
39
+ "no-accent",
40
+ f"{_prompts_dir}/en-2.wav",
41
+ None,
42
+ "Wow, look at that! That's no ordinary Teddy bear!",
43
+ ],
44
+ [
45
+ "我喜欢抽电子烟,尤其是锐刻五代。",
46
+ "中文",
47
+ "no-accent",
48
+ f"{_prompts_dir}/zh-1.wav",
49
+ None,
50
+ "今天我很荣幸,",
51
+ ],
52
+ [
53
+ "你可以听得出来我有多困。",
54
+ "中文",
55
+ "no-accent",
56
+ f"{_prompts_dir}/en-1.wav",
57
+ None,
58
+ "",
59
+ ],
60
+ [
61
+ "この文は、クロスリンガル合成の例です。",
62
+ "日本語",
63
+ "no-accent",
64
+ f"{_prompts_dir}/zh-2.wav",
65
+ None,
66
+ "",
67
+ ],
68
+ [
69
+ "Actually, I can't speak English, but this machine helped me do it.",
70
+ "English",
71
+ "no-accent",
72
+ f"{_prompts_dir}/ja-1.wav",
73
+ None,
74
+ "",
75
+ ],
76
+ ]
77
+
78
+ make_npz_prompt_examples = [
79
+ [
80
+ "Gem-trader",
81
+ f"{_prompts_dir}/en-2.wav",
82
+ None,
83
+ "Wow, look at that! That's no ordinary Teddy bear!",
84
+ ],
85
+ ["Ding Zhen", f"{_prompts_dir}/zh-1.wav", None, "今天我很荣幸,"],
86
+ ["Yoshino", f"{_prompts_dir}/ja-2.ogg", None, "初めまして、朝武よしのです。"],
87
+ ["Sleepy-woman", f"{_prompts_dir}/en-1.wav", None, ""],
88
+ ["Yae", f"{_prompts_dir}/zh-2.wav", None, ""],
89
+ ["Cafe", f"{_prompts_dir}/ja-1.wav", None, ""],
90
+ ]
91
+
92
+ infer_from_prompt_examples = [
93
+ [
94
+ "A prompt contains voice, prosody and emotion information of a certain speaker.",
95
+ "English",
96
+ "no-accent",
97
+ f"{_prompts_dir}/vctk_1",
98
+ None,
99
+ ],
100
+ [
101
+ "This prompt is made with an audio of three seconds.",
102
+ "English",
103
+ "no-accent",
104
+ f"{_prompts_dir}/librispeech_1",
105
+ None,
106
+ ],
107
+ ["This prompt is made with Chinese speech", "English", "no-accent", "seel", None],
108
+ ]
apps/audio_cloning/vallex/g2p/__init__.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """from https://github.com/keithito/tacotron"""
2
+
3
+ # import utils.g2p.cleaners
4
+ from tokenizers import Tokenizer
5
+
6
+ import apps.audio_cloning.vallex.g2p.cleaners as cleaners
7
+
8
+ from .symbols import symbols
9
+
10
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}

# Path to the trained phoneme-BPE vocabulary, relative to the process
# working directory (the app is expected to run from the repo root).
TOKENIZER_PATH = "./apps/audio_cloning/vallex/g2p/bpe_1024.json"
15
+
16
+
17
class PhonemeBpeTokenizer:
    """BPE tokenizer over cleaned phoneme strings.

    Wraps a pre-trained ``tokenizers`` BPE model; input text is first run
    through the ``cje_cleaners`` cleaner (Chinese/Japanese/English).
    """

    def __init__(self, tokenizer_path=TOKENIZER_PATH):
        print(f"Initializing PhonemeBpeTokenizer with tokenizer path: {tokenizer_path}")
        self.tokenizer = Tokenizer.from_file(tokenizer_path)

    def tokenize(self, text):
        """Convert ``text`` to BPE token IDs plus parallel language tags.

        Returns:
            (phoneme_tokens, langs): BPE token IDs and the language markers
            produced by the cleaner, one per token.

        Raises:
            ValueError: if cleaning/tokenization produces no tokens.
        """
        # 1. convert text to phoneme
        phonemes, langs = _clean_text(text, ["cje_cleaners"])
        # 2. replace blank space " " with "_"
        phonemes = phonemes.replace(" ", "_")
        # 3. tokenize phonemes
        phoneme_tokens = self.tokenizer.encode(phonemes).ids
        # Fail early on empty input so the intended ValueError is raised
        # instead of the length assert below tripping first.
        if not phoneme_tokens:
            raise ValueError("Empty text is given")
        assert len(phoneme_tokens) == len(langs)
        return phoneme_tokens, langs
33
+
34
+
35
def text_to_sequence(text, cleaner_names):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
    Args:
      text: string to convert to a sequence
      cleaner_names: names of the cleaner functions to run the text through
    Returns:
      List of integers corresponding to the symbols in the text
    """
    # BUG FIX: ``_clean_text`` returns a ``(text, langs)`` tuple. The old
    # code iterated the tuple itself instead of the cleaned string, so it
    # never looked at individual symbols and effectively returned [].
    cleaned, _langs = _clean_text(text, cleaner_names)
    # Reuse the module-level table instead of rebuilding it on every call.
    return [_symbol_to_id[symbol] for symbol in cleaned if symbol in _symbol_to_id]
52
+
53
+
54
def cleaned_text_to_sequence(cleaned_text):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
    Args:
      text: string to convert to a sequence
    Returns:
      List of integers corresponding to the symbols in the text
    """
    ids = []
    for symbol in cleaned_text:
        # Symbols missing from the table are silently skipped.
        if symbol in _symbol_to_id:
            ids.append(_symbol_to_id[symbol])
    return ids
67
+
68
+
69
def sequence_to_text(sequence):
    """Converts a sequence of IDs back to a string"""
    # str.join avoids the quadratic cost of repeated string concatenation;
    # an unknown ID still raises KeyError, as before.
    return "".join(_id_to_symbol[symbol_id] for symbol_id in sequence)
76
+
77
+
78
def _clean_text(text, cleaner_names):
    """Run ``text`` through each named cleaner function in ``cleaners``.

    Returns the cleaned text and the language tags from the (last) cleaner.

    Raises:
        Exception: when ``name`` is not a cleaner in the ``cleaners`` module.
    """
    for name in cleaner_names:
        # BUG FIX: a plain getattr() raises AttributeError before the
        # "Unknown cleaner" branch could ever execute; default to None so
        # the intended error message actually fires.
        cleaner = getattr(cleaners, name, None)
        if not cleaner:
            raise Exception("Unknown cleaner: %s" % name)
        text, langs = cleaner(text)
    return text, langs
apps/audio_cloning/vallex/g2p/bpe_1024.json ADDED
@@ -0,0 +1,2049 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "1.0",
3
+ "truncation": null,
4
+ "padding": null,
5
+ "added_tokens": [
6
+ {
7
+ "id": 0,
8
+ "content": "[UNK]",
9
+ "single_word": false,
10
+ "lstrip": false,
11
+ "rstrip": false,
12
+ "normalized": false,
13
+ "special": true
14
+ },
15
+ {
16
+ "id": 1,
17
+ "content": "[CLS]",
18
+ "single_word": false,
19
+ "lstrip": false,
20
+ "rstrip": false,
21
+ "normalized": false,
22
+ "special": true
23
+ },
24
+ {
25
+ "id": 2,
26
+ "content": "[SEP]",
27
+ "single_word": false,
28
+ "lstrip": false,
29
+ "rstrip": false,
30
+ "normalized": false,
31
+ "special": true
32
+ },
33
+ {
34
+ "id": 3,
35
+ "content": "[PAD]",
36
+ "single_word": false,
37
+ "lstrip": false,
38
+ "rstrip": false,
39
+ "normalized": false,
40
+ "special": true
41
+ },
42
+ {
43
+ "id": 4,
44
+ "content": "[MASK]",
45
+ "single_word": false,
46
+ "lstrip": false,
47
+ "rstrip": false,
48
+ "normalized": false,
49
+ "special": true
50
+ }
51
+ ],
52
+ "normalizer": null,
53
+ "pre_tokenizer": {
54
+ "type": "Whitespace"
55
+ },
56
+ "post_processor": null,
57
+ "decoder": null,
58
+ "model": {
59
+ "type": "BPE",
60
+ "dropout": null,
61
+ "unk_token": "[UNK]",
62
+ "continuing_subword_prefix": null,
63
+ "end_of_word_suffix": null,
64
+ "fuse_unk": false,
65
+ "byte_fallback": false,
66
+ "vocab": {
67
+ "[UNK]": 0,
68
+ "[CLS]": 1,
69
+ "[SEP]": 2,
70
+ "[PAD]": 3,
71
+ "[MASK]": 4,
72
+ "!": 5,
73
+ "#": 6,
74
+ "*": 7,
75
+ ",": 8,
76
+ "-": 9,
77
+ ".": 10,
78
+ "=": 11,
79
+ "?": 12,
80
+ "N": 13,
81
+ "Q": 14,
82
+ "^": 15,
83
+ "_": 16,
84
+ "`": 17,
85
+ "a": 18,
86
+ "b": 19,
87
+ "d": 20,
88
+ "e": 21,
89
+ "f": 22,
90
+ "g": 23,
91
+ "h": 24,
92
+ "i": 25,
93
+ "j": 26,
94
+ "k": 27,
95
+ "l": 28,
96
+ "m": 29,
97
+ "n": 30,
98
+ "o": 31,
99
+ "p": 32,
100
+ "s": 33,
101
+ "t": 34,
102
+ "u": 35,
103
+ "v": 36,
104
+ "w": 37,
105
+ "x": 38,
106
+ "y": 39,
107
+ "z": 40,
108
+ "~": 41,
109
+ "æ": 42,
110
+ "ç": 43,
111
+ "ð": 44,
112
+ "ŋ": 45,
113
+ "ɑ": 46,
114
+ "ɔ": 47,
115
+ "ə": 48,
116
+ "ɛ": 49,
117
+ "ɥ": 50,
118
+ "ɪ": 51,
119
+ "ɫ": 52,
120
+ "ɯ": 53,
121
+ "ɸ": 54,
122
+ "ɹ": 55,
123
+ "ɾ": 56,
124
+ "ʃ": 57,
125
+ "ʊ": 58,
126
+ "ʑ": 59,
127
+ "ʒ": 60,
128
+ "ʰ": 61,
129
+ "ˈ": 62,
130
+ "ˌ": 63,
131
+ "θ": 64,
132
+ "…": 65,
133
+ "⁼": 66,
134
+ "↑": 67,
135
+ "→": 68,
136
+ "↓": 69,
137
+ "_t": 70,
138
+ "↓↑": 71,
139
+ "_ˈ": 72,
140
+ "ən": 73,
141
+ "_s": 74,
142
+ "aɪ": 75,
143
+ "əɹ": 76,
144
+ "eɪ": 77,
145
+ "oʊ": 78,
146
+ "_k": 79,
147
+ "ʃi": 80,
148
+ "_w": 81,
149
+ "_ð": 82,
150
+ "ts": 83,
151
+ "tʃ": 84,
152
+ "_ts": 85,
153
+ "_h": 86,
154
+ "_ə": 87,
155
+ "_m": 88,
156
+ "an": 89,
157
+ "_n": 90,
158
+ "_ðə": 91,
159
+ "ɛn": 92,
160
+ "ɑʊ": 93,
161
+ "ɑŋ": 94,
162
+ "`⁼": 95,
163
+ "_p": 96,
164
+ "_i": 97,
165
+ "_ɪ": 98,
166
+ "_tʃ": 99,
167
+ "_l": 100,
168
+ "jɛn": 101,
169
+ "_d": 102,
170
+ "_f": 103,
171
+ "_j": 104,
172
+ "wo": 105,
173
+ "_b": 106,
174
+ "ta": 107,
175
+ "`↓": 108,
176
+ "te": 109,
177
+ "ənd": 110,
178
+ "_ʃi": 111,
179
+ "wa": 112,
180
+ "ka": 113,
181
+ "ɪŋ": 114,
182
+ "in": 115,
183
+ "st": 116,
184
+ "li": 117,
185
+ "ʊŋ": 118,
186
+ "_tɪ": 119,
187
+ "to": 120,
188
+ "weɪ": 121,
189
+ "_ənd": 122,
190
+ "ʰi": 123,
191
+ "_əv": 124,
192
+ "əŋ": 125,
193
+ "no": 126,
194
+ "_x": 127,
195
+ "ɾɯ": 128,
196
+ "na": 129,
197
+ "_a": 130,
198
+ "_ɹ": 131,
199
+ "ɪn": 132,
200
+ "ga": 133,
201
+ "de": 134,
202
+ "joʊ": 135,
203
+ "æn": 136,
204
+ "kɯ": 137,
205
+ "ɾe": 138,
206
+ "ma": 139,
207
+ "_ðə_ˈ": 140,
208
+ "ɾa": 141,
209
+ "ɛɹ": 142,
210
+ "mo": 143,
211
+ "ɔɹ": 144,
212
+ "əɫ": 145,
213
+ "_g": 146,
214
+ "da": 147,
215
+ "*↑": 148,
216
+ "ɪˈ": 149,
217
+ "_o": 150,
218
+ "_ʃ": 151,
219
+ "iŋ": 152,
220
+ "ja": 153,
221
+ "əm": 154,
222
+ "_ˌ": 155,
223
+ "aʊ": 156,
224
+ "_əˈ": 157,
225
+ "`↑": 158,
226
+ "ət": 159,
227
+ "_aɪ": 160,
228
+ "oo": 161,
229
+ "sɯ": 162,
230
+ "↓.": 163,
231
+ "_ɪn": 164,
232
+ "_hi": 165,
233
+ "_wɪ": 166,
234
+ "ɪz": 167,
235
+ "_na": 168,
236
+ "wan": 169,
237
+ "_ko": 170,
238
+ "_wo": 171,
239
+ "ɪd": 172,
240
+ "ɾi": 173,
241
+ "_ju": 174,
242
+ "mə": 175,
243
+ "_lə": 176,
244
+ "_hæ": 177,
245
+ "_ðət": 178,
246
+ "ɑɹ": 179,
247
+ "tʰ": 180,
248
+ "ki": 181,
249
+ "……": 182,
250
+ "ɑz": 183,
251
+ "_ɔ": 184,
252
+ "_mi": 185,
253
+ "_wɑz": 186,
254
+ "_ˈs": 187,
255
+ "↓,": 188,
256
+ "_tʰ": 189,
257
+ "əˈ": 190,
258
+ "dʑ": 191,
259
+ "ɪt": 192,
260
+ "_kʰ": 193,
261
+ "iɛ": 194,
262
+ "_ma": 195,
263
+ "ɪs": 196,
264
+ "tsɯ": 197,
265
+ "_ni": 198,
266
+ "_ɪt": 199,
267
+ "ke": 200,
268
+ "iɑʊ": 201,
269
+ "_ka": 202,
270
+ "_əɹ": 203,
271
+ "nd": 204,
272
+ "_ˈp": 205,
273
+ "ko": 206,
274
+ "jo": 207,
275
+ "ɹi": 208,
276
+ "mən": 209,
277
+ "ʊd": 210,
278
+ "_ˈm": 211,
279
+ "_fəɹ": 212,
280
+ "tʃʰi": 213,
281
+ "sa": 214,
282
+ "ʰɥ": 215,
283
+ "kʰ": 216,
284
+ "ˈs": 217,
285
+ "ɑt": 218,
286
+ "ɛd": 219,
287
+ "se": 220,
288
+ "tʃi": 221,
289
+ "ɛɫ": 222,
290
+ "_ˈk": 223,
291
+ "_joʊ": 224,
292
+ "təɹ": 225,
293
+ "ɛz": 226,
294
+ "--": 227,
295
+ "vəɹ": 228,
296
+ "`→": 229,
297
+ "ʃən": 230,
298
+ "_ɪz": 231,
299
+ "_meɪ": 232,
300
+ "_æ": 233,
301
+ "dʒ": 234,
302
+ "_ki": 235,
303
+ "_hɪz": 236,
304
+ "_bi": 237,
305
+ "uɑŋ": 238,
306
+ "_ˈf": 239,
307
+ "↓↑.": 240,
308
+ "_wɪθ": 241,
309
+ "ju": 242,
310
+ "iɑŋ": 243,
311
+ "→.": 244,
312
+ "_so": 245,
313
+ "_həɹ": 246,
314
+ "↑.": 247,
315
+ "ni": 248,
316
+ "_mo": 249,
317
+ "_maɪ": 250,
318
+ "laɪ": 251,
319
+ "ɥɛ": 252,
320
+ "_ta": 253,
321
+ "ənt": 254,
322
+ "_tʃʰi": 255,
323
+ "_sɯ": 256,
324
+ "_θ": 257,
325
+ "_ɛz": 258,
326
+ "wən": 259,
327
+ "me": 260,
328
+ "mi": 261,
329
+ "_hæd": 262,
330
+ "_ha": 263,
331
+ "əs": 264,
332
+ "_ˈl": 265,
333
+ "_st": 266,
334
+ "ðəɹ": 267,
335
+ "oʊn": 268,
336
+ "_wa": 269,
337
+ "ʰəŋ": 270,
338
+ "_nɑt": 271,
339
+ "*.": 272,
340
+ "kt": 273,
341
+ "_ˈh": 274,
342
+ "do": 275,
343
+ "ɥæn": 276,
344
+ "ne": 277,
345
+ "_to": 278,
346
+ "_wən": 279,
347
+ "_no": 280,
348
+ "_laɪ": 281,
349
+ "_wəɹ": 282,
350
+ "↑,": 283,
351
+ "→,": 284,
352
+ "ɛs": 285,
353
+ "↓↑,": 286,
354
+ "_ɔn": 287,
355
+ "ʰu": 288,
356
+ "so": 289,
357
+ "_ˈb": 290,
358
+ "ɫd": 291,
359
+ "ɪk": 292,
360
+ "ɪst": 293,
361
+ "_fɹ": 294,
362
+ "_ðɛɹ": 295,
363
+ "_weɪ": 296,
364
+ "kaɾa": 297,
365
+ "_ˈd": 298,
366
+ "_hæv": 299,
367
+ "tsʰ": 300,
368
+ "waɪ": 301,
369
+ "ɾo": 302,
370
+ "ɛm": 303,
371
+ "_æt": 304,
372
+ "ʊɹ": 305,
373
+ "_ˈw": 306,
374
+ "ba": 307,
375
+ "_noʊ": 308,
376
+ "ʰjɛn": 309,
377
+ "ɹeɪ": 310,
378
+ "_jo": 311,
379
+ "ɸɯ": 312,
380
+ "_sa": 313,
381
+ "_ɹɪˈ": 314,
382
+ "_ˈn": 315,
383
+ "ai": 316,
384
+ "_bət": 317,
385
+ "ɪɹ": 318,
386
+ "tʃʰɥ": 319,
387
+ "_dʑ": 320,
388
+ "əˌ": 321,
389
+ "_ðɪs": 322,
390
+ "..": 323,
391
+ "xwa": 324,
392
+ "_ɪm": 325,
393
+ "_dɪˈ": 326,
394
+ "_kən": 327,
395
+ "dʑi": 328,
396
+ "*,": 329,
397
+ "ɑn": 330,
398
+ "_ʃiɑŋ": 331,
399
+ "_kɯ": 332,
400
+ "ʃin": 333,
401
+ "_soʊ": 334,
402
+ "bi": 335,
403
+ "tʰjɛn": 336,
404
+ "te_i": 337,
405
+ "_tsʰ": 338,
406
+ "_ɯ": 339,
407
+ "aɪt": 340,
408
+ "ʰiŋ": 341,
409
+ "ðə": 342,
410
+ "_ɔɫ": 343,
411
+ "_ˈɹ": 344,
412
+ "nai": 345,
413
+ "əɹd": 346,
414
+ "_ˈt": 347,
415
+ "_ən": 348,
416
+ "_tʃʰɥ": 349,
417
+ "_iɛ": 350,
418
+ "leɪ": 351,
419
+ "ɛɹi": 352,
420
+ "ˈt": 353,
421
+ "ha": 354,
422
+ "ʃiŋ": 355,
423
+ "ɛvəɹ": 356,
424
+ "zɯ": 357,
425
+ "_wi": 358,
426
+ "_ja": 359,
427
+ "ɛk": 360,
428
+ "ʰɑŋ": 361,
429
+ "_tsɯ": 362,
430
+ "_əv_ðə": 363,
431
+ "taʃi": 364,
432
+ "_sɛd": 365,
433
+ "_xə": 366,
434
+ "_li": 367,
435
+ "_si": 368,
436
+ "desɯ": 369,
437
+ "_ˌɪn": 370,
438
+ "ʃjɛn": 371,
439
+ "_baɪ": 372,
440
+ "on": 373,
441
+ "_xɑʊ": 374,
442
+ "_ðeɪ": 375,
443
+ "_xaɪ": 376,
444
+ "`↓↑": 377,
445
+ "xweɪ": 378,
446
+ "hi": 379,
447
+ "_se": 380,
448
+ "ə_s": 381,
449
+ "_fɹəm": 382,
450
+ "ʊt": 383,
451
+ "di": 384,
452
+ "aʊt": 385,
453
+ "əb": 386,
454
+ "sɹ": 387,
455
+ "əz": 388,
456
+ "_xweɪ": 389,
457
+ "_kʰə": 390,
458
+ "ɹu": 391,
459
+ "_u": 392,
460
+ "_de": 393,
461
+ "aɪd": 394,
462
+ "ɪv": 395,
463
+ "bɯ": 396,
464
+ "_ho": 397,
465
+ "əɹz": 398,
466
+ "joo": 399,
467
+ "_bɪˈ": 400,
468
+ "_tʰa": 401,
469
+ "ɛt": 402,
470
+ "en": 403,
471
+ "ɛni": 404,
472
+ "əst": 405,
473
+ "æk": 406,
474
+ "ə_ts": 407,
475
+ "_ˈɪn": 408,
476
+ "ti": 409,
477
+ "ɥn": 410,
478
+ "_dʒ": 411,
479
+ "xɑʊ": 412,
480
+ "_ˈv": 413,
481
+ "ʃiɑŋ": 414,
482
+ "pʰ": 415,
483
+ "_wɪtʃ": 416,
484
+ "eɪm": 417,
485
+ "oʊz": 418,
486
+ "əðəɹ": 419,
487
+ "fɑŋ": 420,
488
+ "_ˈg": 421,
489
+ "_do": 422,
490
+ "_ʃiɑʊ": 423,
491
+ "_ˈæ": 424,
492
+ "_jʊɹ": 425,
493
+ "_ðɛm": 426,
494
+ "ɪm": 427,
495
+ "ɛst": 428,
496
+ "ænd": 429,
497
+ "_du": 430,
498
+ "ɯɯ": 431,
499
+ "kan": 432,
500
+ "_da": 433,
501
+ "ino": 434,
502
+ "_e": 435,
503
+ "_wʊd": 436,
504
+ "ɛnd": 437,
505
+ "meɪ": 438,
506
+ "θɪŋ": 439,
507
+ "_ʃjɛn": 440,
508
+ "iz": 441,
509
+ "aɪm": 442,
510
+ "_hu": 443,
511
+ "_əˈb": 444,
512
+ "əns": 445,
513
+ "_wɪɫ": 446,
514
+ "tʰi": 447,
515
+ "go": 448,
516
+ "ɛnt": 449,
517
+ "fu": 450,
518
+ "æp": 451,
519
+ "xoʊ": 452,
520
+ "eɪk": 453,
521
+ "ʊk": 454,
522
+ "əɹˈ": 455,
523
+ "_θɪŋ": 456,
524
+ "əl": 457,
525
+ "pɹ": 458,
526
+ "ətʃ": 459,
527
+ "nt": 460,
528
+ "_ɸɯ": 461,
529
+ "lu": 462,
530
+ "_ˈɔ": 463,
531
+ "_iɑʊ": 464,
532
+ "lə": 465,
533
+ "tu": 466,
534
+ "_dʑi": 467,
535
+ "eɪt": 468,
536
+ "_ʃin": 469,
537
+ "nna": 470,
538
+ "_ˈpɹ": 471,
539
+ "fən": 472,
540
+ "_əp": 473,
541
+ "njɛn": 474,
542
+ "_aʊt": 475,
543
+ "fɔɹ": 476,
544
+ "_tu": 477,
545
+ "eɪʃən": 478,
546
+ "ɪɫ": 479,
547
+ "_wət": 480,
548
+ "_ɪf": 481,
549
+ "_ɥ": 482,
550
+ "_fa": 483,
551
+ "ˈw": 484,
552
+ "tʃʰjɛn": 485,
553
+ "_wɪn": 486,
554
+ "oʊɫd": 487,
555
+ "_əˈp": 488,
556
+ "aʊnd": 489,
557
+ "san": 490,
558
+ "he": 491,
559
+ "_bɪn": 492,
560
+ "fa": 493,
561
+ "ɪf": 494,
562
+ "ɔŋ": 495,
563
+ "ge": 496,
564
+ "_ɪn_ðə": 497,
565
+ "miŋ": 498,
566
+ "_pɹ": 499,
567
+ "ina": 500,
568
+ "ano": 501,
569
+ "əbəɫ": 502,
570
+ "kˈs": 503,
571
+ "_ˈɛni": 504,
572
+ "nəŋ": 505,
573
+ "əd": 506,
574
+ "_əv_ðə_ˈ": 507,
575
+ "_waɪ": 508,
576
+ "_taɪm": 509,
577
+ "ˈsɛɫ": 510,
578
+ "ʃiɛ": 511,
579
+ "_kəm": 512,
580
+ "æst": 513,
581
+ "_goʊ": 514,
582
+ "mɯ": 515,
583
+ "ˈp": 516,
584
+ "_ˈst": 517,
585
+ "ə_t": 518,
586
+ "pt": 519,
587
+ "_pʰ": 520,
588
+ "ʰɹ": 521,
589
+ "ʃja": 522,
590
+ "iwa": 523,
591
+ "ɪl": 524,
592
+ "bət": 525,
593
+ "_fɑŋ": 526,
594
+ "ho": 527,
595
+ "iv": 528,
596
+ "loʊ": 529,
597
+ "be": 530,
598
+ "_laɪk": 531,
599
+ "ɪʃ": 532,
600
+ "_fu": 533,
601
+ "ze": 534,
602
+ "ə_tʃ": 535,
603
+ "ɑɹt": 536,
604
+ "ɔɹd": 537,
605
+ "tʃʰiŋ": 538,
606
+ "mp": 539,
607
+ "_ðə_s": 540,
608
+ "_əˈbaʊt": 541,
609
+ "_ˈoʊ": 542,
610
+ "kʰə": 543,
611
+ "d_tɪ": 544,
612
+ "ŋga": 545,
613
+ "əli": 546,
614
+ "_kʰan": 547,
615
+ "çi": 548,
616
+ "_ˈju": 549,
617
+ "_kʊd": 550,
618
+ "ɔɫ": 551,
619
+ "ɔt": 552,
620
+ "_ɪts": 553,
621
+ "_san": 554,
622
+ "tʃa": 555,
623
+ "i_na": 556,
624
+ "xə": 557,
625
+ "ɛkt": 558,
626
+ "_mɔɹ": 559,
627
+ "te_kɯ": 560,
628
+ "ɪdʒ": 561,
629
+ "jʊŋ": 562,
630
+ "_wan": 563,
631
+ "æt": 564,
632
+ "kat": 565,
633
+ "ˈsɛɫf": 566,
634
+ "_ke": 567,
635
+ "aɪnd": 568,
636
+ "it": 569,
637
+ "_ɑɹ": 570,
638
+ "sp": 571,
639
+ "oʊnt": 572,
640
+ "_tʃi": 573,
641
+ "tsʰɹ": 574,
642
+ "_xən": 575,
643
+ "_əˈg": 576,
644
+ "ə_k": 577,
645
+ "to_i": 578,
646
+ "_tʰi": 579,
647
+ "_iŋ": 580,
648
+ "aʊn": 581,
649
+ "gɯ": 582,
650
+ "_ɪkˈs": 583,
651
+ "ɛv": 584,
652
+ "gi": 585,
653
+ "ks": 586,
654
+ "_səm": 587,
655
+ "ana": 588,
656
+ "ɪtəɫ": 589,
657
+ "nan": 590,
658
+ "_ˈɪntu": 591,
659
+ "_hiɹ": 592,
660
+ "_te": 593,
661
+ "_naʊ": 594,
662
+ "ʃiɑʊ": 595,
663
+ "ʃo": 596,
664
+ "ɹe": 597,
665
+ "xaɪ": 598,
666
+ "_tʃʰiŋ": 599,
667
+ "_sɹ": 600,
668
+ "_haʊ": 601,
669
+ "?.": 602,
670
+ "_feɪ": 603,
671
+ "liŋ": 604,
672
+ "_ʃja": 605,
673
+ "_ˈdʒ": 606,
674
+ "_seɪ": 607,
675
+ "ˈn": 608,
676
+ "soʊ": 609,
677
+ "tʰʊŋ": 610,
678
+ "_ljoʊ": 611,
679
+ "maɪ": 612,
680
+ "_bɹ": 613,
681
+ "ɹeɪt": 614,
682
+ "_nəŋ": 615,
683
+ "ʰə": 616,
684
+ "æns": 617,
685
+ "_ˈɔl": 618,
686
+ "tatʃi": 619,
687
+ "nto": 620,
688
+ "_ˌɪnˈ": 621,
689
+ "le": 622,
690
+ "nde": 623,
691
+ "_ˈvɛɹi": 624,
692
+ "mənt": 625,
693
+ "ɾima": 626,
694
+ "_ðɛn": 627,
695
+ "_həz": 628,
696
+ "_ɹi": 629,
697
+ "ftəɹ": 630,
698
+ "_sp": 631,
699
+ "ɾewa": 632,
700
+ "ga_a": 633,
701
+ "z_əv": 634,
702
+ "_miŋ": 635,
703
+ "_tɪ_ðə": 636,
704
+ "ɹaɪ": 637,
705
+ "ɛl": 638,
706
+ "ɹæ": 639,
707
+ "_hoʊ": 640,
708
+ "xu": 641,
709
+ "oʊnli": 642,
710
+ "ŋk": 643,
711
+ "i_i": 644,
712
+ "_dɪd": 645,
713
+ "_dʒɪst": 646,
714
+ "ing": 647,
715
+ "kai": 648,
716
+ "_mæn": 649,
717
+ "_in": 650,
718
+ "zo": 651,
719
+ "əf": 652,
720
+ "dake": 653,
721
+ "_ˈsəm": 654,
722
+ "ɾɯ_no": 655,
723
+ "_go": 656,
724
+ "tʃəɹ": 657,
725
+ "ite": 658,
726
+ "`↓.": 659,
727
+ "_kʰaɪ": 660,
728
+ "sk": 661,
729
+ "ɔɹs": 662,
730
+ "_tʰiŋ": 663,
731
+ "_nə": 664,
732
+ "pəɫ": 665,
733
+ "_tɪ_bi": 666,
734
+ "ˈfɔɹ": 667,
735
+ "mu": 668,
736
+ "su": 669,
737
+ "aa": 670,
738
+ "ɪstəɹ": 671,
739
+ "ʰan": 672,
740
+ "pəɹ": 673,
741
+ "ə_p": 674,
742
+ "liɑŋ": 675,
743
+ "_v": 676,
744
+ "oʊst": 677,
745
+ "_əˈgɛn": 678,
746
+ "ənz": 679,
747
+ "No": 680,
748
+ "ɔɹt": 681,
749
+ "_səˈ": 682,
750
+ "_mɯ": 683,
751
+ "tʃʰ": 684,
752
+ "_ˈlɪtəɫ": 685,
753
+ "_xwo": 686,
754
+ "_ˌbi": 687,
755
+ "_ˈoʊvəɹ": 688,
756
+ "_çi": 689,
757
+ "_deɪ": 690,
758
+ "aɪn": 691,
759
+ "_ʃiŋ": 692,
760
+ "i_ʃi": 693,
761
+ "_tsʰaɪ": 694,
762
+ "ʃoo": 695,
763
+ "ɾoo": 696,
764
+ "bəɹ": 697,
765
+ "ʰa": 698,
766
+ "ˈɛs": 699,
767
+ "_ɪn_ðə_ˈ": 700,
768
+ "Nwa": 701,
769
+ "_ðən": 702,
770
+ "saɪ": 703,
771
+ "_ˈjuˈɛs": 704,
772
+ "nda": 705,
773
+ "_pleɪ": 706,
774
+ "ɪŋ_tɪ": 707,
775
+ "ɪti": 708,
776
+ "_me": 709,
777
+ "_ʃʊd": 710,
778
+ "_nu": 711,
779
+ "_ðə_k": 712,
780
+ "za": 713,
781
+ "_ˈɛvəɹ": 714,
782
+ "əɹn": 715,
783
+ "æd": 716,
784
+ "ˈm": 717,
785
+ "_doʊnt": 718,
786
+ "_məst": 719,
787
+ "jɯɯ": 720,
788
+ "ɑɹd": 721,
789
+ "_jɛn": 722,
790
+ "ʃɥ": 723,
791
+ "_ˈoʊnli": 724,
792
+ "_ʃo": 725,
793
+ "_liŋ": 726,
794
+ "ss": 727,
795
+ "ɑl": 728,
796
+ "dea": 729,
797
+ "ɾeta": 730,
798
+ "mjɛn": 731,
799
+ "_gʊd": 732,
800
+ "_wɔ": 733,
801
+ "imo": 734,
802
+ "no_ko": 735,
803
+ "_ɥæn": 736,
804
+ "ndʒ": 737,
805
+ "ɪʃən": 738,
806
+ "o_ʃi": 739,
807
+ "_θɪŋk": 740,
808
+ "_nan": 741,
809
+ "to_o": 742,
810
+ "_tʰʊŋ": 743,
811
+ "ljoʊ": 744,
812
+ "tai": 745,
813
+ "mə_s": 746,
814
+ "_jɯ": 747,
815
+ "_uɑŋ": 748,
816
+ "_ˌbiˈfɔɹ": 749,
817
+ "æs": 750,
818
+ "_tʃʰjɛn": 751,
819
+ "ik": 752,
820
+ "_bæk": 753,
821
+ "_ˈiv": 754,
822
+ "eɪn": 755,
823
+ "un": 756,
824
+ "la": 757,
825
+ "ˈk": 758,
826
+ "_daʊn": 759,
827
+ "anai": 760,
828
+ "_lɛ": 761,
829
+ "əɹt": 762,
830
+ "ðɛɹ": 763,
831
+ "_ˈæftəɹ": 764,
832
+ "dat": 765,
833
+ "fan": 766,
834
+ "bəɫ": 767,
835
+ "temo": 768,
836
+ "tʰa": 769,
837
+ "ɾɯ_ko": 770,
838
+ "ˈv": 771,
839
+ "feɪ": 772,
840
+ "_mətʃ": 773,
841
+ "xwo": 774,
842
+ "ɹoʊ": 775,
843
+ "_ba": 776,
844
+ "_ˈnɛvəɹ": 777,
845
+ "_meɪd": 778,
846
+ "_jʊŋ": 779,
847
+ "_əˈpɑn": 780,
848
+ "!?": 781,
849
+ "_ˈʃ": 782,
850
+ "_ðə_ˈk": 783,
851
+ "ft": 784,
852
+ "_bo": 785,
853
+ "_ɪn_ə": 786,
854
+ "tʃʰɥæn": 787,
855
+ "ˈz": 788,
856
+ "`↓,": 789,
857
+ "_bɪˈk": 790,
858
+ "ɪg": 791,
859
+ "kin": 792,
860
+ "_kl": 793,
861
+ "ɾɯ_n": 794,
862
+ "_lɑʊ": 795,
863
+ "----": 796,
864
+ "ika": 797,
865
+ "_ɹaɪt": 798,
866
+ "zd": 799,
867
+ "z_ənd": 800,
868
+ "_kjo": 801,
869
+ "xwan": 802,
870
+ "too": 803,
871
+ "_gɪt": 804,
872
+ "_liɑŋ": 805,
873
+ "ta_n": 806,
874
+ "_keɪm": 807,
875
+ "_ˈəðəɹ": 808,
876
+ "_wɛɫ": 809,
877
+ "teki": 810,
878
+ "see": 811,
879
+ "jɯ": 812,
880
+ "i_o": 813,
881
+ "to_ʃi": 814,
882
+ "fəɫ": 815,
883
+ "bo": 816,
884
+ "ˌt": 817,
885
+ "ɪp": 818,
886
+ "ane": 819,
887
+ "_tʰjɛn": 820,
888
+ "_tʃo": 821,
889
+ "ɾjo": 822,
890
+ "ɪns": 823,
891
+ "_he": 824,
892
+ "ŋka": 825,
893
+ "ʃɥɛ": 826,
894
+ "dʑa": 827,
895
+ "vd": 828,
896
+ "ʰwan": 829,
897
+ "_gɹeɪt": 830,
898
+ "_əv_ə": 831,
899
+ "əndəɹ": 832,
900
+ "kedo": 833,
901
+ "_ðə_b": 834,
902
+ "ək": 835,
903
+ "_teɪk": 836,
904
+ "kʰan": 837,
905
+ "_ˈɔlˌ": 838,
906
+ "swo": 839,
907
+ "_ɪt_wɑz": 840,
908
+ "_ʃɥ": 841,
909
+ "_sim": 842,
910
+ "_ˈfɑ": 843,
911
+ "min": 844,
912
+ "i_a": 845,
913
+ "soo": 846,
914
+ "ɛns": 847,
915
+ "_sətʃ": 848,
916
+ "tʰaɪ": 849,
917
+ "_ga": 850,
918
+ "i_ka": 851,
919
+ "koo": 852,
920
+ "_fəɹst": 853,
921
+ "_ˈtʃ": 854,
922
+ "nno": 855,
923
+ "ə_ɹ": 856,
924
+ "taɾa": 857,
925
+ "tʃʰjoʊ": 858,
926
+ "_æm": 859,
927
+ "_mu": 860,
928
+ "_meɪk": 861,
929
+ "↓…": 862,
930
+ "ɪˈθ": 863,
931
+ "ɑb": 864,
932
+ "ɹa": 865,
933
+ "_wɛɹ": 866,
934
+ "_ðə_ˈs": 867,
935
+ "_əˈl": 868,
936
+ "_oʊɫd": 869,
937
+ "æl": 870,
938
+ "_ˈpi": 871,
939
+ "_lɔŋ": 872,
940
+ "dʑo": 873,
941
+ "_tʰaɪ": 874,
942
+ "ɔɹn": 875,
943
+ "əɫz": 876,
944
+ "_təˈ": 877,
945
+ "_əˈweɪ": 878,
946
+ "pa": 879,
947
+ "_ðiz": 880,
948
+ "_ˈsp": 881,
949
+ "nn": 882,
950
+ "mae": 883,
951
+ "towa": 884,
952
+ "ta_no": 885,
953
+ "_an": 886,
954
+ "kʰaɪ": 887,
955
+ "ɾaɾe": 888,
956
+ "eɪs": 889,
957
+ "ɑd": 890,
958
+ "_wɪˈθ": 891,
959
+ "_ˈivɪn": 892,
960
+ "_lu": 893,
961
+ "ɔɪ": 894,
962
+ "lɪŋ": 895,
963
+ "əti": 896,
964
+ "_ðə_f": 897,
965
+ "oʃi": 898,
966
+ "_la": 899,
967
+ "si": 900,
968
+ "tɪd": 901,
969
+ "haʊ": 902,
970
+ "pʰin": 903,
971
+ "ˈst": 904,
972
+ "_ˈpəɹ": 905,
973
+ "eɹ": 906,
974
+ "*!": 907,
975
+ "_ˈmɪstəɹ": 908,
976
+ "ʃa": 909,
977
+ "_ˌɪm": 910,
978
+ "ˌθɪŋ": 911,
979
+ "_neɪ": 912,
980
+ "_nɥ": 913,
981
+ "ɑk": 914,
982
+ "_ɹu": 915,
983
+ "_ʃɯ": 916,
984
+ "_ðə_ˈm": 917,
985
+ "demo": 918,
986
+ "_dɹ": 919,
987
+ "dʑoo": 920,
988
+ "_stɪɫ": 921,
989
+ "_pʰiŋ": 922,
990
+ "ə_i": 923,
991
+ "_ɪkˈsp": 924,
992
+ "_wɛnt": 925,
993
+ "ɪɹi": 926,
994
+ "əˈm": 927,
995
+ "o_ka": 928,
996
+ "_əˈk": 929,
997
+ "ɔk": 930,
998
+ "_ɥɛ": 931,
999
+ "_lʊk": 932,
1000
+ "ˈd": 933,
1001
+ "kaʃi": 934,
1002
+ "_wɪθ_ə": 935,
1003
+ "ljɛn": 936,
1004
+ "ɔn": 937,
1005
+ "_ljɛn": 938,
1006
+ "_hɛɫ": 939,
1007
+ "uɹ": 940,
1008
+ "_tʰoʊ": 941,
1009
+ "_tʃʰɥæn": 942,
1010
+ "_sk": 943,
1011
+ "tsʰaɪ": 944,
1012
+ "ɛtəɹ": 945,
1013
+ "_min": 946,
1014
+ "noʊ": 947,
1015
+ "ʃɯ": 948,
1016
+ "_θɹu": 949,
1017
+ "_θɔt": 950,
1018
+ "dajo": 951,
1019
+ "wi": 952,
1020
+ "i_ko": 953,
1021
+ "_tɹ": 954,
1022
+ "_fan": 955,
1023
+ "ɹɛ": 956,
1024
+ "saN": 957,
1025
+ "_hi_wɑz": 958,
1026
+ "_ɾe": 959,
1027
+ "_əm": 960,
1028
+ "te_ki": 961,
1029
+ "_xoʊ": 962,
1030
+ "ˈl": 963,
1031
+ "ˈg": 964,
1032
+ "ga_i": 965,
1033
+ "_ɔn_ðə": 966,
1034
+ "_xwa": 967,
1035
+ "vɪŋ": 968,
1036
+ "man": 969,
1037
+ "fəɹ": 970,
1038
+ "_oʊn": 971,
1039
+ "ˈɹ": 972,
1040
+ "_kɹ": 973,
1041
+ "te_o": 974,
1042
+ "ɪli": 975,
1043
+ "_ʃɥɛ": 976,
1044
+ "_fəŋ": 977,
1045
+ "æɫ": 978,
1046
+ "ɑp": 979,
1047
+ "_ˈɛv": 980,
1048
+ "eɪndʒ": 981,
1049
+ "iɫ": 982,
1050
+ "wət": 983,
1051
+ "ɛðəɹ": 984,
1052
+ "_fən": 985,
1053
+ "ɾee": 986,
1054
+ "_hi_hæd": 987,
1055
+ "_maɪt": 988,
1056
+ "_ge": 989,
1057
+ "ækt": 990,
1058
+ "ɪts": 991,
1059
+ "_hɪm": 992,
1060
+ "_ze": 993,
1061
+ "ii": 994,
1062
+ "_N": 995,
1063
+ "_əv_hɪz": 996,
1064
+ "_gɹ": 997,
1065
+ "ænt": 998,
1066
+ "ɪˌ": 999,
1067
+ "_hɪmˈsɛɫf": 1000,
1068
+ "wa_na": 1001,
1069
+ "aɪəɹ": 1002,
1070
+ "dʑanai": 1003,
1071
+ "kana": 1004,
1072
+ "aɪz": 1005,
1073
+ "_ɪt_ɪz": 1006,
1074
+ "mase": 1007,
1075
+ "wɪn": 1008,
1076
+ "əθɪŋ": 1009,
1077
+ "_pɹəˈ": 1010,
1078
+ "kɯn": 1011,
1079
+ "ˈju": 1012,
1080
+ "_fɔɹ": 1013,
1081
+ "pʰi": 1014,
1082
+ "pʰiŋ": 1015,
1083
+ "o_i": 1016,
1084
+ "vz": 1017,
1085
+ "ɔɪn": 1018,
1086
+ "tʰiŋ": 1019,
1087
+ "_ne": 1020,
1088
+ "gəɹ": 1021,
1089
+ "æts": 1022,
1090
+ "_ˈɹi": 1023
1091
+ },
1092
+ "merges": [
1093
+ "_ t",
1094
+ "↓ ↑",
1095
+ "_ ˈ",
1096
+ "ə n",
1097
+ "_ s",
1098
+ "a ɪ",
1099
+ "ə ɹ",
1100
+ "e ɪ",
1101
+ "o ʊ",
1102
+ "_ k",
1103
+ "ʃ i",
1104
+ "_ w",
1105
+ "_ ð",
1106
+ "t s",
1107
+ "t ʃ",
1108
+ "_t s",
1109
+ "_ h",
1110
+ "_ ə",
1111
+ "_ m",
1112
+ "a n",
1113
+ "_ n",
1114
+ "_ð ə",
1115
+ "ɛ n",
1116
+ "ɑ ʊ",
1117
+ "ɑ ŋ",
1118
+ "` ⁼",
1119
+ "_ p",
1120
+ "_ i",
1121
+ "_ ɪ",
1122
+ "_t ʃ",
1123
+ "_ l",
1124
+ "j ɛn",
1125
+ "_ d",
1126
+ "_ f",
1127
+ "_ j",
1128
+ "w o",
1129
+ "_ b",
1130
+ "t a",
1131
+ "` ↓",
1132
+ "t e",
1133
+ "ən d",
1134
+ "_ ʃi",
1135
+ "w a",
1136
+ "k a",
1137
+ "ɪ ŋ",
1138
+ "i n",
1139
+ "s t",
1140
+ "l i",
1141
+ "ʊ ŋ",
1142
+ "_t ɪ",
1143
+ "t o",
1144
+ "w eɪ",
1145
+ "_ ənd",
1146
+ "ʰ i",
1147
+ "_ə v",
1148
+ "ə ŋ",
1149
+ "n o",
1150
+ "_ x",
1151
+ "ɾ ɯ",
1152
+ "n a",
1153
+ "_ a",
1154
+ "_ ɹ",
1155
+ "ɪ n",
1156
+ "g a",
1157
+ "d e",
1158
+ "j oʊ",
1159
+ "æ n",
1160
+ "k ɯ",
1161
+ "ɾ e",
1162
+ "m a",
1163
+ "_ðə _ˈ",
1164
+ "ɾ a",
1165
+ "ɛ ɹ",
1166
+ "m o",
1167
+ "ɔ ɹ",
1168
+ "ə ɫ",
1169
+ "_ g",
1170
+ "d a",
1171
+ "* ↑",
1172
+ "ɪ ˈ",
1173
+ "_ o",
1174
+ "_ ʃ",
1175
+ "i ŋ",
1176
+ "j a",
1177
+ "ə m",
1178
+ "_ ˌ",
1179
+ "a ʊ",
1180
+ "_ə ˈ",
1181
+ "` ↑",
1182
+ "ə t",
1183
+ "_ aɪ",
1184
+ "o o",
1185
+ "s ɯ",
1186
+ "↓ .",
1187
+ "_ɪ n",
1188
+ "_h i",
1189
+ "_w ɪ",
1190
+ "ɪ z",
1191
+ "_n a",
1192
+ "w an",
1193
+ "_k o",
1194
+ "_w o",
1195
+ "ɪ d",
1196
+ "ɾ i",
1197
+ "_j u",
1198
+ "m ə",
1199
+ "_l ə",
1200
+ "_h æ",
1201
+ "_ðə t",
1202
+ "ɑ ɹ",
1203
+ "t ʰ",
1204
+ "k i",
1205
+ "… …",
1206
+ "ɑ z",
1207
+ "_ ɔ",
1208
+ "_m i",
1209
+ "_w ɑz",
1210
+ "_ˈ s",
1211
+ "↓ ,",
1212
+ "_t ʰ",
1213
+ "ə ˈ",
1214
+ "d ʑ",
1215
+ "ɪ t",
1216
+ "_k ʰ",
1217
+ "i ɛ",
1218
+ "_m a",
1219
+ "ɪ s",
1220
+ "ts ɯ",
1221
+ "_n i",
1222
+ "_ɪ t",
1223
+ "k e",
1224
+ "i ɑʊ",
1225
+ "_k a",
1226
+ "_ əɹ",
1227
+ "n d",
1228
+ "_ˈ p",
1229
+ "k o",
1230
+ "j o",
1231
+ "ɹ i",
1232
+ "m ən",
1233
+ "ʊ d",
1234
+ "_ˈ m",
1235
+ "_f əɹ",
1236
+ "tʃ ʰi",
1237
+ "s a",
1238
+ "ʰ ɥ",
1239
+ "k ʰ",
1240
+ "ˈ s",
1241
+ "ɑ t",
1242
+ "ɛ d",
1243
+ "s e",
1244
+ "t ʃi",
1245
+ "ɛ ɫ",
1246
+ "_ˈ k",
1247
+ "_j oʊ",
1248
+ "t əɹ",
1249
+ "ɛ z",
1250
+ "- -",
1251
+ "v əɹ",
1252
+ "` →",
1253
+ "ʃ ən",
1254
+ "_ɪ z",
1255
+ "_m eɪ",
1256
+ "_ æ",
1257
+ "d ʒ",
1258
+ "_k i",
1259
+ "_h ɪz",
1260
+ "_b i",
1261
+ "u ɑŋ",
1262
+ "_ˈ f",
1263
+ "↓↑ .",
1264
+ "_wɪ θ",
1265
+ "j u",
1266
+ "i ɑŋ",
1267
+ "→ .",
1268
+ "_s o",
1269
+ "_h əɹ",
1270
+ "↑ .",
1271
+ "n i",
1272
+ "_m o",
1273
+ "_m aɪ",
1274
+ "l aɪ",
1275
+ "ɥ ɛ",
1276
+ "_t a",
1277
+ "ən t",
1278
+ "_tʃ ʰi",
1279
+ "_s ɯ",
1280
+ "_ θ",
1281
+ "_ ɛz",
1282
+ "w ən",
1283
+ "m e",
1284
+ "m i",
1285
+ "_hæ d",
1286
+ "_h a",
1287
+ "ə s",
1288
+ "_ˈ l",
1289
+ "_s t",
1290
+ "ð əɹ",
1291
+ "oʊ n",
1292
+ "_w a",
1293
+ "ʰ əŋ",
1294
+ "_n ɑt",
1295
+ "* .",
1296
+ "k t",
1297
+ "_ˈ h",
1298
+ "d o",
1299
+ "ɥ æn",
1300
+ "n e",
1301
+ "_t o",
1302
+ "_w ən",
1303
+ "_n o",
1304
+ "_l aɪ",
1305
+ "_w əɹ",
1306
+ "↑ ,",
1307
+ "→ ,",
1308
+ "ɛ s",
1309
+ "↓↑ ,",
1310
+ "_ɔ n",
1311
+ "ʰ u",
1312
+ "s o",
1313
+ "_ˈ b",
1314
+ "ɫ d",
1315
+ "ɪ k",
1316
+ "ɪ st",
1317
+ "_f ɹ",
1318
+ "_ð ɛɹ",
1319
+ "_w eɪ",
1320
+ "ka ɾa",
1321
+ "_ˈ d",
1322
+ "_hæ v",
1323
+ "ts ʰ",
1324
+ "w aɪ",
1325
+ "ɾ o",
1326
+ "ɛ m",
1327
+ "_æ t",
1328
+ "ʊ ɹ",
1329
+ "_ˈ w",
1330
+ "b a",
1331
+ "_n oʊ",
1332
+ "ʰ jɛn",
1333
+ "ɹ eɪ",
1334
+ "_j o",
1335
+ "ɸ ɯ",
1336
+ "_s a",
1337
+ "_ɹ ɪˈ",
1338
+ "_ˈ n",
1339
+ "a i",
1340
+ "_b ət",
1341
+ "ɪ ɹ",
1342
+ "tʃ ʰɥ",
1343
+ "_d ʑ",
1344
+ "ə ˌ",
1345
+ "_ð ɪs",
1346
+ ". .",
1347
+ "x wa",
1348
+ "_ɪ m",
1349
+ "_d ɪˈ",
1350
+ "_k ən",
1351
+ "dʑ i",
1352
+ "* ,",
1353
+ "ɑ n",
1354
+ "_ʃi ɑŋ",
1355
+ "_k ɯ",
1356
+ "ʃi n",
1357
+ "_s oʊ",
1358
+ "b i",
1359
+ "tʰ jɛn",
1360
+ "te _i",
1361
+ "_ts ʰ",
1362
+ "_ ɯ",
1363
+ "aɪ t",
1364
+ "ʰi ŋ",
1365
+ "ð ə",
1366
+ "_ɔ ɫ",
1367
+ "_ˈ ɹ",
1368
+ "na i",
1369
+ "əɹ d",
1370
+ "_ˈ t",
1371
+ "_ ən",
1372
+ "_tʃ ʰɥ",
1373
+ "_i ɛ",
1374
+ "l eɪ",
1375
+ "ɛɹ i",
1376
+ "ˈ t",
1377
+ "h a",
1378
+ "ʃi ŋ",
1379
+ "ɛ vəɹ",
1380
+ "z ɯ",
1381
+ "_w i",
1382
+ "_j a",
1383
+ "ɛ k",
1384
+ "ʰ ɑŋ",
1385
+ "_ts ɯ",
1386
+ "_əv _ðə",
1387
+ "ta ʃi",
1388
+ "_s ɛd",
1389
+ "_x ə",
1390
+ "_l i",
1391
+ "_s i",
1392
+ "de sɯ",
1393
+ "_ˌ ɪn",
1394
+ "ʃ jɛn",
1395
+ "_b aɪ",
1396
+ "o n",
1397
+ "_x ɑʊ",
1398
+ "_ð eɪ",
1399
+ "_x aɪ",
1400
+ "` ↓↑",
1401
+ "x weɪ",
1402
+ "h i",
1403
+ "_s e",
1404
+ "ə _s",
1405
+ "_fɹ əm",
1406
+ "ʊ t",
1407
+ "d i",
1408
+ "aʊ t",
1409
+ "ə b",
1410
+ "s ɹ",
1411
+ "ə z",
1412
+ "_x weɪ",
1413
+ "_kʰ ə",
1414
+ "ɹ u",
1415
+ "_ u",
1416
+ "_d e",
1417
+ "aɪ d",
1418
+ "ɪ v",
1419
+ "b ɯ",
1420
+ "_h o",
1421
+ "əɹ z",
1422
+ "j oo",
1423
+ "_b ɪˈ",
1424
+ "_tʰ a",
1425
+ "ɛ t",
1426
+ "e n",
1427
+ "ɛn i",
1428
+ "ə st",
1429
+ "æ k",
1430
+ "ə _ts",
1431
+ "_ˈ ɪn",
1432
+ "t i",
1433
+ "ɥ n",
1434
+ "_d ʒ",
1435
+ "x ɑʊ",
1436
+ "_ˈ v",
1437
+ "ʃi ɑŋ",
1438
+ "p ʰ",
1439
+ "_wɪ tʃ",
1440
+ "eɪ m",
1441
+ "oʊ z",
1442
+ "ə ðəɹ",
1443
+ "f ɑŋ",
1444
+ "_ˈ g",
1445
+ "_d o",
1446
+ "_ʃi ɑʊ",
1447
+ "_ˈ æ",
1448
+ "_j ʊɹ",
1449
+ "_ð ɛm",
1450
+ "ɪ m",
1451
+ "ɛ st",
1452
+ "æn d",
1453
+ "_d u",
1454
+ "ɯ ɯ",
1455
+ "k an",
1456
+ "_d a",
1457
+ "in o",
1458
+ "_ e",
1459
+ "_w ʊd",
1460
+ "ɛn d",
1461
+ "m eɪ",
1462
+ "θ ɪŋ",
1463
+ "_ʃ jɛn",
1464
+ "i z",
1465
+ "aɪ m",
1466
+ "_h u",
1467
+ "_əˈ b",
1468
+ "ən s",
1469
+ "_wɪ ɫ",
1470
+ "t ʰi",
1471
+ "g o",
1472
+ "ɛn t",
1473
+ "f u",
1474
+ "æ p",
1475
+ "x oʊ",
1476
+ "eɪ k",
1477
+ "ʊ k",
1478
+ "əɹ ˈ",
1479
+ "_θ ɪŋ",
1480
+ "ə l",
1481
+ "p ɹ",
1482
+ "ə tʃ",
1483
+ "n t",
1484
+ "_ ɸɯ",
1485
+ "l u",
1486
+ "_ˈ ɔ",
1487
+ "_i ɑʊ",
1488
+ "l ə",
1489
+ "t u",
1490
+ "_dʑ i",
1491
+ "eɪ t",
1492
+ "_ʃi n",
1493
+ "n na",
1494
+ "_ˈp ɹ",
1495
+ "f ən",
1496
+ "_ə p",
1497
+ "n jɛn",
1498
+ "_a ʊt",
1499
+ "f ɔɹ",
1500
+ "_t u",
1501
+ "eɪ ʃən",
1502
+ "ɪ ɫ",
1503
+ "_w ət",
1504
+ "_ɪ f",
1505
+ "_ ɥ",
1506
+ "_f a",
1507
+ "ˈ w",
1508
+ "tʃ ʰjɛn",
1509
+ "_w ɪn",
1510
+ "oʊ ɫd",
1511
+ "_əˈ p",
1512
+ "aʊ nd",
1513
+ "s an",
1514
+ "h e",
1515
+ "_b ɪn",
1516
+ "f a",
1517
+ "ɪ f",
1518
+ "ɔ ŋ",
1519
+ "g e",
1520
+ "_ɪn _ðə",
1521
+ "m iŋ",
1522
+ "_p ɹ",
1523
+ "in a",
1524
+ "an o",
1525
+ "əb əɫ",
1526
+ "k ˈs",
1527
+ "_ˈ ɛni",
1528
+ "n əŋ",
1529
+ "ə d",
1530
+ "_əv _ðə_ˈ",
1531
+ "_w aɪ",
1532
+ "_t aɪm",
1533
+ "ˈs ɛɫ",
1534
+ "ʃi ɛ",
1535
+ "_k əm",
1536
+ "æ st",
1537
+ "_g oʊ",
1538
+ "m ɯ",
1539
+ "ˈ p",
1540
+ "_ˈ st",
1541
+ "ə _t",
1542
+ "p t",
1543
+ "_p ʰ",
1544
+ "ʰ ɹ",
1545
+ "ʃ ja",
1546
+ "i wa",
1547
+ "ɪ l",
1548
+ "b ət",
1549
+ "_f ɑŋ",
1550
+ "h o",
1551
+ "i v",
1552
+ "l oʊ",
1553
+ "b e",
1554
+ "_laɪ k",
1555
+ "ɪ ʃ",
1556
+ "_f u",
1557
+ "z e",
1558
+ "ə _tʃ",
1559
+ "ɑɹ t",
1560
+ "ɔɹ d",
1561
+ "tʃʰi ŋ",
1562
+ "m p",
1563
+ "_ðə _s",
1564
+ "_əˈb aʊt",
1565
+ "_ˈ oʊ",
1566
+ "kʰ ə",
1567
+ "d _tɪ",
1568
+ "ŋ ga",
1569
+ "ə li",
1570
+ "_kʰ an",
1571
+ "ç i",
1572
+ "_ˈ ju",
1573
+ "_k ʊd",
1574
+ "ɔ ɫ",
1575
+ "ɔ t",
1576
+ "_ɪ ts",
1577
+ "_s an",
1578
+ "tʃ a",
1579
+ "i _na",
1580
+ "x ə",
1581
+ "ɛ kt",
1582
+ "_m ɔɹ",
1583
+ "te _kɯ",
1584
+ "ɪd ʒ",
1585
+ "j ʊŋ",
1586
+ "_w an",
1587
+ "æ t",
1588
+ "ka t",
1589
+ "ˈsɛɫ f",
1590
+ "_k e",
1591
+ "aɪ nd",
1592
+ "i t",
1593
+ "_ ɑɹ",
1594
+ "s p",
1595
+ "oʊn t",
1596
+ "_t ʃi",
1597
+ "tsʰ ɹ",
1598
+ "_x ən",
1599
+ "_əˈ g",
1600
+ "ə _k",
1601
+ "to _i",
1602
+ "_t ʰi",
1603
+ "_i ŋ",
1604
+ "aʊ n",
1605
+ "g ɯ",
1606
+ "_ɪ kˈs",
1607
+ "ɛ v",
1608
+ "g i",
1609
+ "k s",
1610
+ "_s əm",
1611
+ "an a",
1612
+ "ɪt əɫ",
1613
+ "n an",
1614
+ "_ˈɪn tu",
1615
+ "_hi ɹ",
1616
+ "_t e",
1617
+ "_n aʊ",
1618
+ "ʃi ɑʊ",
1619
+ "ʃ o",
1620
+ "ɹ e",
1621
+ "x aɪ",
1622
+ "_tʃʰi ŋ",
1623
+ "_s ɹ",
1624
+ "_h aʊ",
1625
+ "? .",
1626
+ "_f eɪ",
1627
+ "li ŋ",
1628
+ "_ʃ ja",
1629
+ "_ˈ dʒ",
1630
+ "_s eɪ",
1631
+ "ˈ n",
1632
+ "s oʊ",
1633
+ "tʰ ʊŋ",
1634
+ "_l joʊ",
1635
+ "m aɪ",
1636
+ "_b ɹ",
1637
+ "ɹeɪ t",
1638
+ "_n əŋ",
1639
+ "ʰ ə",
1640
+ "æn s",
1641
+ "_ˈɔ l",
1642
+ "ta tʃi",
1643
+ "n to",
1644
+ "_ˌɪn ˈ",
1645
+ "l e",
1646
+ "n de",
1647
+ "_ˈv ɛɹi",
1648
+ "mən t",
1649
+ "ɾi ma",
1650
+ "_ð ɛn",
1651
+ "_h əz",
1652
+ "_ɹ i",
1653
+ "f təɹ",
1654
+ "_s p",
1655
+ "ɾe wa",
1656
+ "ga _a",
1657
+ "z _əv",
1658
+ "_m iŋ",
1659
+ "_tɪ _ðə",
1660
+ "ɹ aɪ",
1661
+ "ɛ l",
1662
+ "ɹ æ",
1663
+ "_h oʊ",
1664
+ "x u",
1665
+ "oʊn li",
1666
+ "ŋ k",
1667
+ "i _i",
1668
+ "_d ɪd",
1669
+ "_dʒ ɪst",
1670
+ "in g",
1671
+ "ka i",
1672
+ "_m æn",
1673
+ "_i n",
1674
+ "z o",
1675
+ "ə f",
1676
+ "da ke",
1677
+ "_ˈs əm",
1678
+ "ɾɯ _no",
1679
+ "_g o",
1680
+ "tʃ əɹ",
1681
+ "i te",
1682
+ "`↓ .",
1683
+ "_kʰ aɪ",
1684
+ "s k",
1685
+ "ɔɹ s",
1686
+ "_t ʰiŋ",
1687
+ "_n ə",
1688
+ "p əɫ",
1689
+ "_tɪ _bi",
1690
+ "ˈ fɔɹ",
1691
+ "m u",
1692
+ "s u",
1693
+ "a a",
1694
+ "ɪst əɹ",
1695
+ "ʰ an",
1696
+ "p əɹ",
1697
+ "ə _p",
1698
+ "li ɑŋ",
1699
+ "_ v",
1700
+ "oʊ st",
1701
+ "_əˈg ɛn",
1702
+ "ən z",
1703
+ "N o",
1704
+ "ɔɹ t",
1705
+ "_s əˈ",
1706
+ "_m ɯ",
1707
+ "tʃ ʰ",
1708
+ "_ˈl ɪtəɫ",
1709
+ "_x wo",
1710
+ "_ˌ bi",
1711
+ "_ˈoʊ vəɹ",
1712
+ "_ çi",
1713
+ "_d eɪ",
1714
+ "aɪ n",
1715
+ "_ʃi ŋ",
1716
+ "i _ʃi",
1717
+ "_tsʰ aɪ",
1718
+ "ʃ oo",
1719
+ "ɾ oo",
1720
+ "b əɹ",
1721
+ "ʰ a",
1722
+ "ˈ ɛs",
1723
+ "_ɪn _ðə_ˈ",
1724
+ "N wa",
1725
+ "_ð ən",
1726
+ "s aɪ",
1727
+ "_ˈju ˈɛs",
1728
+ "n da",
1729
+ "_p leɪ",
1730
+ "ɪŋ _tɪ",
1731
+ "ɪt i",
1732
+ "_m e",
1733
+ "_ʃ ʊd",
1734
+ "_n u",
1735
+ "_ðə _k",
1736
+ "z a",
1737
+ "_ˈ ɛvəɹ",
1738
+ "əɹ n",
1739
+ "æ d",
1740
+ "ˈ m",
1741
+ "_d oʊnt",
1742
+ "_m əst",
1743
+ "j ɯɯ",
1744
+ "ɑɹ d",
1745
+ "_ jɛn",
1746
+ "ʃ ɥ",
1747
+ "_ˈ oʊnli",
1748
+ "_ʃ o",
1749
+ "_l iŋ",
1750
+ "s s",
1751
+ "ɑ l",
1752
+ "de a",
1753
+ "ɾe ta",
1754
+ "m jɛn",
1755
+ "_g ʊd",
1756
+ "_w ɔ",
1757
+ "i mo",
1758
+ "no _ko",
1759
+ "_ ɥæn",
1760
+ "nd ʒ",
1761
+ "ɪ ʃən",
1762
+ "o _ʃi",
1763
+ "_θɪŋ k",
1764
+ "_n an",
1765
+ "to _o",
1766
+ "_tʰ ʊŋ",
1767
+ "l joʊ",
1768
+ "ta i",
1769
+ "mə _s",
1770
+ "_j ɯ",
1771
+ "_ uɑŋ",
1772
+ "_ˌbi ˈfɔɹ",
1773
+ "æ s",
1774
+ "_tʃ ʰjɛn",
1775
+ "i k",
1776
+ "_b æk",
1777
+ "_ˈ iv",
1778
+ "eɪ n",
1779
+ "u n",
1780
+ "l a",
1781
+ "ˈ k",
1782
+ "_d aʊn",
1783
+ "an ai",
1784
+ "_l ɛ",
1785
+ "əɹ t",
1786
+ "ð ɛɹ",
1787
+ "_ˈæ ftəɹ",
1788
+ "da t",
1789
+ "f an",
1790
+ "b əɫ",
1791
+ "te mo",
1792
+ "tʰ a",
1793
+ "ɾɯ _ko",
1794
+ "ˈ v",
1795
+ "f eɪ",
1796
+ "_m ətʃ",
1797
+ "x wo",
1798
+ "ɹ oʊ",
1799
+ "_b a",
1800
+ "_ˈn ɛvəɹ",
1801
+ "_meɪ d",
1802
+ "_j ʊŋ",
1803
+ "_əˈp ɑn",
1804
+ "! ?",
1805
+ "_ˈ ʃ",
1806
+ "_ðə_ˈ k",
1807
+ "f t",
1808
+ "_b o",
1809
+ "_ɪn _ə",
1810
+ "tʃʰɥ æn",
1811
+ "ˈ z",
1812
+ "`↓ ,",
1813
+ "_bɪˈ k",
1814
+ "ɪ g",
1815
+ "k in",
1816
+ "_k l",
1817
+ "ɾɯ _n",
1818
+ "_l ɑʊ",
1819
+ "-- --",
1820
+ "i ka",
1821
+ "_ɹ aɪt",
1822
+ "z d",
1823
+ "z _ənd",
1824
+ "_k jo",
1825
+ "x wan",
1826
+ "to o",
1827
+ "_g ɪt",
1828
+ "_l iɑŋ",
1829
+ "ta _n",
1830
+ "_k eɪm",
1831
+ "_ˈ əðəɹ",
1832
+ "_w ɛɫ",
1833
+ "te ki",
1834
+ "se e",
1835
+ "j ɯ",
1836
+ "i _o",
1837
+ "to _ʃi",
1838
+ "f əɫ",
1839
+ "b o",
1840
+ "ˌ t",
1841
+ "ɪ p",
1842
+ "an e",
1843
+ "_tʰ jɛn",
1844
+ "_tʃ o",
1845
+ "ɾ jo",
1846
+ "ɪn s",
1847
+ "_h e",
1848
+ "ŋ ka",
1849
+ "ʃ ɥɛ",
1850
+ "dʑ a",
1851
+ "v d",
1852
+ "ʰ wan",
1853
+ "_g ɹeɪt",
1854
+ "_əv _ə",
1855
+ "ənd əɹ",
1856
+ "ke do",
1857
+ "_ðə _b",
1858
+ "ə k",
1859
+ "_t eɪk",
1860
+ "kʰ an",
1861
+ "_ˈɔl ˌ",
1862
+ "s wo",
1863
+ "_ɪt _wɑz",
1864
+ "_ʃ ɥ",
1865
+ "_si m",
1866
+ "_ˈf ɑ",
1867
+ "m in",
1868
+ "i _a",
1869
+ "s oo",
1870
+ "ɛn s",
1871
+ "_s ətʃ",
1872
+ "tʰ aɪ",
1873
+ "_ ga",
1874
+ "i _ka",
1875
+ "k oo",
1876
+ "_fəɹ st",
1877
+ "_ˈ tʃ",
1878
+ "n no",
1879
+ "ə _ɹ",
1880
+ "ta ɾa",
1881
+ "tʃʰ joʊ",
1882
+ "_æ m",
1883
+ "_m u",
1884
+ "_meɪ k",
1885
+ "↓ …",
1886
+ "ɪˈ θ",
1887
+ "ɑ b",
1888
+ "ɹ a",
1889
+ "_w ɛɹ",
1890
+ "_ðə_ˈ s",
1891
+ "_əˈ l",
1892
+ "_ oʊɫd",
1893
+ "æ l",
1894
+ "_ˈp i",
1895
+ "_l ɔŋ",
1896
+ "dʑ o",
1897
+ "_tʰ aɪ",
1898
+ "ɔɹ n",
1899
+ "əɫ z",
1900
+ "_t əˈ",
1901
+ "_əˈ weɪ",
1902
+ "p a",
1903
+ "_ð iz",
1904
+ "_ˈs p",
1905
+ "n n",
1906
+ "ma e",
1907
+ "to wa",
1908
+ "ta _no",
1909
+ "_ an",
1910
+ "kʰ aɪ",
1911
+ "ɾa ɾe",
1912
+ "eɪ s",
1913
+ "ɑ d",
1914
+ "_w ɪˈθ",
1915
+ "_ˈiv ɪn",
1916
+ "_l u",
1917
+ "ɔ ɪ",
1918
+ "l ɪŋ",
1919
+ "ət i",
1920
+ "_ðə _f",
1921
+ "o ʃi",
1922
+ "_l a",
1923
+ "s i",
1924
+ "t ɪd",
1925
+ "h aʊ",
1926
+ "pʰ in",
1927
+ "ˈ st",
1928
+ "_ˈp əɹ",
1929
+ "e ɹ",
1930
+ "* !",
1931
+ "_ˈm ɪstəɹ",
1932
+ "ʃ a",
1933
+ "_ˌ ɪm",
1934
+ "ˌ θɪŋ",
1935
+ "_n eɪ",
1936
+ "_n ɥ",
1937
+ "ɑ k",
1938
+ "_ɹ u",
1939
+ "_ʃ ɯ",
1940
+ "_ðə_ˈ m",
1941
+ "de mo",
1942
+ "_d ɹ",
1943
+ "dʑ oo",
1944
+ "_st ɪɫ",
1945
+ "_p ʰiŋ",
1946
+ "ə _i",
1947
+ "_ɪkˈs p",
1948
+ "_w ɛnt",
1949
+ "ɪ ɹi",
1950
+ "əˈ m",
1951
+ "o _ka",
1952
+ "_əˈ k",
1953
+ "ɔ k",
1954
+ "_ ɥɛ",
1955
+ "_l ʊk",
1956
+ "ˈ d",
1957
+ "ka ʃi",
1958
+ "_wɪθ _ə",
1959
+ "l jɛn",
1960
+ "ɔ n",
1961
+ "_l jɛn",
1962
+ "_h ɛɫ",
1963
+ "u ɹ",
1964
+ "_tʰ oʊ",
1965
+ "_tʃʰɥ æn",
1966
+ "_s k",
1967
+ "tsʰ aɪ",
1968
+ "ɛ təɹ",
1969
+ "_m in",
1970
+ "n oʊ",
1971
+ "ʃ ɯ",
1972
+ "_θ ɹu",
1973
+ "_θ ɔt",
1974
+ "da jo",
1975
+ "w i",
1976
+ "i _ko",
1977
+ "_t ɹ",
1978
+ "_f an",
1979
+ "ɹ ɛ",
1980
+ "sa N",
1981
+ "_hi _wɑz",
1982
+ "_ ɾe",
1983
+ "_ə m",
1984
+ "te _ki",
1985
+ "_x oʊ",
1986
+ "ˈ l",
1987
+ "ˈ g",
1988
+ "ga _i",
1989
+ "_ɔn _ðə",
1990
+ "_x wa",
1991
+ "v ɪŋ",
1992
+ "m an",
1993
+ "f əɹ",
1994
+ "_ oʊn",
1995
+ "ˈ ɹ",
1996
+ "_k ɹ",
1997
+ "te _o",
1998
+ "ɪ li",
1999
+ "_ʃ ɥɛ",
2000
+ "_f əŋ",
2001
+ "æ ɫ",
2002
+ "ɑ p",
2003
+ "_ˈ ɛv",
2004
+ "eɪ ndʒ",
2005
+ "i ɫ",
2006
+ "w ət",
2007
+ "ɛ ðəɹ",
2008
+ "_f ən",
2009
+ "ɾe e",
2010
+ "_hi _hæd",
2011
+ "_maɪ t",
2012
+ "_g e",
2013
+ "æ kt",
2014
+ "ɪ ts",
2015
+ "_h ɪm",
2016
+ "_ ze",
2017
+ "i i",
2018
+ "_ N",
2019
+ "_əv _hɪz",
2020
+ "_g ɹ",
2021
+ "æn t",
2022
+ "ɪ ˌ",
2023
+ "_hɪm ˈsɛɫf",
2024
+ "wa _na",
2025
+ "aɪ əɹ",
2026
+ "dʑ anai",
2027
+ "kan a",
2028
+ "aɪ z",
2029
+ "_ɪt _ɪz",
2030
+ "ma se",
2031
+ "w ɪn",
2032
+ "ə θɪŋ",
2033
+ "_pɹ əˈ",
2034
+ "kɯ n",
2035
+ "ˈ ju",
2036
+ "_f ɔɹ",
2037
+ "p ʰi",
2038
+ "p ʰiŋ",
2039
+ "o _i",
2040
+ "v z",
2041
+ "ɔ ɪn",
2042
+ "t ʰiŋ",
2043
+ "_n e",
2044
+ "g əɹ",
2045
+ "æ ts",
2046
+ "_ˈ ɹi"
2047
+ ]
2048
+ }
2049
+ }
apps/audio_cloning/vallex/g2p/bpe_69.json ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "1.0",
3
+ "truncation": null,
4
+ "padding": null,
5
+ "added_tokens": [
6
+ {
7
+ "id": 0,
8
+ "content": "[UNK]",
9
+ "single_word": false,
10
+ "lstrip": false,
11
+ "rstrip": false,
12
+ "normalized": false,
13
+ "special": true
14
+ },
15
+ {
16
+ "id": 1,
17
+ "content": "[CLS]",
18
+ "single_word": false,
19
+ "lstrip": false,
20
+ "rstrip": false,
21
+ "normalized": false,
22
+ "special": true
23
+ },
24
+ {
25
+ "id": 2,
26
+ "content": "[SEP]",
27
+ "single_word": false,
28
+ "lstrip": false,
29
+ "rstrip": false,
30
+ "normalized": false,
31
+ "special": true
32
+ },
33
+ {
34
+ "id": 3,
35
+ "content": "[PAD]",
36
+ "single_word": false,
37
+ "lstrip": false,
38
+ "rstrip": false,
39
+ "normalized": false,
40
+ "special": true
41
+ },
42
+ {
43
+ "id": 4,
44
+ "content": "[MASK]",
45
+ "single_word": false,
46
+ "lstrip": false,
47
+ "rstrip": false,
48
+ "normalized": false,
49
+ "special": true
50
+ }
51
+ ],
52
+ "normalizer": null,
53
+ "pre_tokenizer": {
54
+ "type": "Whitespace"
55
+ },
56
+ "post_processor": null,
57
+ "decoder": null,
58
+ "model": {
59
+ "type": "BPE",
60
+ "dropout": null,
61
+ "unk_token": "[UNK]",
62
+ "continuing_subword_prefix": null,
63
+ "end_of_word_suffix": null,
64
+ "fuse_unk": false,
65
+ "byte_fallback": false,
66
+ "vocab": {
67
+ "[UNK]": 0,
68
+ "[CLS]": 1,
69
+ "[SEP]": 2,
70
+ "[PAD]": 3,
71
+ "[MASK]": 4,
72
+ "!": 5,
73
+ "#": 6,
74
+ "*": 7,
75
+ ",": 8,
76
+ "-": 9,
77
+ ".": 10,
78
+ "=": 11,
79
+ "?": 12,
80
+ "N": 13,
81
+ "Q": 14,
82
+ "^": 15,
83
+ "_": 16,
84
+ "`": 17,
85
+ "a": 18,
86
+ "b": 19,
87
+ "d": 20,
88
+ "e": 21,
89
+ "f": 22,
90
+ "g": 23,
91
+ "h": 24,
92
+ "i": 25,
93
+ "j": 26,
94
+ "k": 27,
95
+ "l": 28,
96
+ "m": 29,
97
+ "n": 30,
98
+ "o": 31,
99
+ "p": 32,
100
+ "s": 33,
101
+ "t": 34,
102
+ "u": 35,
103
+ "v": 36,
104
+ "w": 37,
105
+ "x": 38,
106
+ "y": 39,
107
+ "z": 40,
108
+ "~": 41,
109
+ "æ": 42,
110
+ "ç": 43,
111
+ "ð": 44,
112
+ "ŋ": 45,
113
+ "ɑ": 46,
114
+ "ɔ": 47,
115
+ "ə": 48,
116
+ "ɛ": 49,
117
+ "ɥ": 50,
118
+ "ɪ": 51,
119
+ "ɫ": 52,
120
+ "ɯ": 53,
121
+ "ɸ": 54,
122
+ "ɹ": 55,
123
+ "ɾ": 56,
124
+ "ʃ": 57,
125
+ "ʊ": 58,
126
+ "ʑ": 59,
127
+ "ʒ": 60,
128
+ "ʰ": 61,
129
+ "ˈ": 62,
130
+ "ˌ": 63,
131
+ "θ": 64,
132
+ "…": 65,
133
+ "⁼": 66,
134
+ "↑": 67,
135
+ "→": 68,
136
+ "↓": 69
137
+ },
138
+ "merges": [
139
+ ]
140
+ }
141
+ }
apps/audio_cloning/vallex/g2p/cleaners.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ from .english import english_to_ipa2
4
+ from .japanese import japanese_to_ipa2, japanese_to_romaji_with_accent
5
+ from .mandarin import (
6
+ chinese_to_bopomofo,
7
+ chinese_to_ipa,
8
+ latin_to_bopomofo,
9
+ number_to_chinese,
10
+ )
11
+
12
# Regexes capturing language-tagged spans — [EN]…[EN], [ZH]…[ZH], [JA]…[JA] —
# used by cje_cleaners to locate segments and by clean_one to phonemize them.
patterns = [r"\[EN\](.*?)\[EN\]", r"\[ZH\](.*?)\[ZH\]", r"\[JA\](.*?)\[JA\]"]
13
+
14
+
15
def japanese_cleaners(text):
    """Convert Japanese text to accented romaji, ensuring trailing punctuation."""
    romaji = japanese_to_romaji_with_accent(text)
    # If the result ends with a bare letter, terminate it with a period.
    return re.sub(r"([A-Za-z])$", r"\1.", romaji)
19
+
20
+
21
def japanese_cleaners2(text):
    """Like japanese_cleaners, but with the 'ʦ' ligature and '…' ellipsis."""
    cleaned = japanese_cleaners(text)
    cleaned = cleaned.replace("ts", "ʦ")
    return cleaned.replace("...", "…")
23
+
24
+
25
def chinese_cleaners(text):
    """Pipeline for Chinese text: numbers -> bopomofo -> latin -> punctuation."""
    for step in (number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo):
        text = step(text)
    # Append an ideographic full stop after a trailing tone mark.
    return re.sub(r"([ˉˊˇˋ˙])$", r"\1。", text)
32
+
33
+
34
def cje_cleaners(text):
    """Phonemize mixed-language text tagged with [EN]/[ZH]/[JA] markers.

    Returns a tuple ``(phonemes, langs)`` where ``langs`` is a list holding
    one language code per character of ``phonemes`` (kept aligned by the
    final assertion).
    """
    tagged_spans = []
    for pattern in patterns:
        tagged_spans.extend(re.finditer(pattern, text))
    # Process tagged segments in document order.
    tagged_spans.sort(key=lambda m: m.start())

    outputs = ""
    output_langs = []
    tag_to_lang = {"[EN]": "en", "[ZH]": "zh", "[JA]": "ja"}
    for span in tagged_spans:
        segment = text[span.start() : span.end()]
        phon = clean_one(segment)
        # Checked in the same order the original tags are declared.
        for tag, code in tag_to_lang.items():
            if tag in segment:
                lang = code
                break
        else:
            raise ValueError("If you see this error, please report this bug to issues.")
        outputs += phon
        output_langs.extend([lang] * len(phon))
    assert len(outputs) == len(output_langs)
    return outputs, output_langs
59
+
60
+
61
def clean_one(text):
    """Phonemize every tagged region of *text* and normalise its ending.

    ZH, JA and EN regions are converted in that fixed order; the segment is
    stripped of trailing whitespace and guaranteed to end in punctuation.
    """
    converters = (
        (r"\[ZH\](.*?)\[ZH\]", chinese_to_ipa),
        (r"\[JA\](.*?)\[JA\]", japanese_to_ipa2),
        (r"\[EN\](.*?)\[EN\]", english_to_ipa2),
    )
    for tag_pattern, convert in converters:
        if re.search(tag_pattern, text):
            # Bind the converter as a default arg to avoid late-binding issues.
            text = re.sub(tag_pattern, lambda m, c=convert: c(m.group(1)) + " ", text)
    text = re.sub(r"\s+$", "", text)
    # Guarantee the segment ends with punctuation.
    return re.sub(r"([^\.,!\?\-…~])$", r"\1.", text)
apps/audio_cloning/vallex/g2p/english.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import inflect
4
+ from unidecode import unidecode
5
+
6
+ """from https://github.com/keithito/tacotron"""
7
+
8
+ """
9
+ Cleaners are transformations that run over the input text at both training and eval time.
10
+
11
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
12
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
13
+ 1. "english_cleaners" for English text
14
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
15
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
16
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
17
+ the symbols in symbols.py to match your data).
18
+ """
19
+
20
+
21
+ _inflect = inflect.engine()
22
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
23
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
24
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
25
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
26
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
27
+ _number_re = re.compile(r"[0-9]+")
28
+
29
+ # List of (regular expression, replacement) pairs for abbreviations:
30
+ _abbreviations = [
31
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
32
+ for x in [
33
+ ("mrs", "misess"),
34
+ ("mr", "mister"),
35
+ ("dr", "doctor"),
36
+ ("st", "saint"),
37
+ ("co", "company"),
38
+ ("jr", "junior"),
39
+ ("maj", "major"),
40
+ ("gen", "general"),
41
+ ("drs", "doctors"),
42
+ ("rev", "reverend"),
43
+ ("lt", "lieutenant"),
44
+ ("hon", "honorable"),
45
+ ("sgt", "sergeant"),
46
+ ("capt", "captain"),
47
+ ("esq", "esquire"),
48
+ ("ltd", "limited"),
49
+ ("col", "colonel"),
50
+ ("ft", "fort"),
51
+ ]
52
+ ]
53
+
54
+
55
+ # List of (ipa, lazy ipa) pairs:
56
+ _lazy_ipa = [
57
+ (re.compile("%s" % x[0]), x[1])
58
+ for x in [
59
+ ("r", "ɹ"),
60
+ ("æ", "e"),
61
+ ("ɑ", "a"),
62
+ ("ɔ", "o"),
63
+ ("ð", "z"),
64
+ ("θ", "s"),
65
+ ("ɛ", "e"),
66
+ ("ɪ", "i"),
67
+ ("ʊ", "u"),
68
+ ("ʒ", "ʥ"),
69
+ ("ʤ", "ʥ"),
70
+ ("ˈ", "↓"),
71
+ ]
72
+ ]
73
+
74
+ # List of (ipa, lazy ipa2) pairs:
75
+ _lazy_ipa2 = [
76
+ (re.compile("%s" % x[0]), x[1])
77
+ for x in [
78
+ ("r", "ɹ"),
79
+ ("ð", "z"),
80
+ ("θ", "s"),
81
+ ("ʒ", "ʑ"),
82
+ ("ʤ", "dʑ"),
83
+ ("ˈ", "↓"),
84
+ ]
85
+ ]
86
+
87
+ # List of (ipa, ipa2) pairs
88
+ _ipa_to_ipa2 = [
89
+ (re.compile("%s" % x[0]), x[1]) for x in [("r", "ɹ"), ("ʤ", "dʒ"), ("ʧ", "tʃ")]
90
+ ]
91
+
92
+
93
def expand_abbreviations(text):
    """Replace known abbreviations ('mr.', 'dr.', ...) with their spoken form."""
    for pattern, spoken in _abbreviations:
        text = pattern.sub(spoken, text)
    return text
97
+
98
+
99
def collapse_whitespace(text):
    """Collapse every run of whitespace into a single space."""
    return " ".join(re.split(r"\s+", text))
101
+
102
+
103
def _remove_commas(m):
    """Strip thousands-separator commas from the matched number group."""
    digits = m.group(1)
    return digits.replace(",", "")
105
+
106
+
107
def _expand_decimal_point(m):
    """Read a decimal point aloud: '1.5' -> '1 point 5'."""
    return " point ".join(m.group(1).split("."))
109
+
110
+
111
def _expand_dollars(m):
    """Spell out a matched dollar amount, e.g. '10.50' -> '10 dollars, 50 cents'.

    Amounts with more than one decimal point are returned verbatim with
    ' dollars' appended (unexpected format).  Thousands separators that
    survived earlier normalisation (the comma regex requires digits on both
    sides, so e.g. '$,5' slips through) are stripped here so int() cannot
    raise ValueError.
    """
    match = m.group(1)
    parts = match.split(".")
    if len(parts) > 2:
        return match + " dollars"  # Unexpected format
    # Strip stray separators before conversion (see docstring).
    dollars_str = parts[0].replace(",", "")
    cents_str = parts[1].replace(",", "") if len(parts) > 1 else ""
    dollars = int(dollars_str) if dollars_str else 0
    cents = int(cents_str) if cents_str else 0
    if dollars and cents:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
    elif dollars:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        return "%s %s" % (dollars, dollar_unit)
    elif cents:
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s" % (cents, cent_unit)
    else:
        return "zero dollars"
130
+
131
+
132
def _expand_ordinal(m):
    """Spell out an ordinal like '2nd' -> 'second' via inflect."""
    ordinal = m.group(0)
    return _inflect.number_to_words(ordinal)
134
+
135
+
136
def _expand_number(m):
    """Spell out a cardinal number, using year-style reading for 1001-2999."""
    num = int(m.group(0))
    if not 1000 < num < 3000:
        return _inflect.number_to_words(num, andword="")
    if num == 2000:
        return "two thousand"
    if 2000 < num < 2010:
        return "two thousand " + _inflect.number_to_words(num % 100)
    if num % 100 == 0:
        return _inflect.number_to_words(num // 100) + " hundred"
    # Year-style grouping: e.g. 1999 -> "nineteen ninety-nine".
    words = _inflect.number_to_words(num, andword="", zero="oh", group=2)
    return words.replace(", ", " ")
151
+
152
+
153
def normalize_numbers(text):
    """Rewrite numeric expressions (money, decimals, ordinals) as words.

    Substitution order matters: commas are removed first so the later
    money/number expansions see plain digit runs.
    """
    substitutions = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r"\1 pounds"),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text
161
+
162
+
163
def mark_dark_l(text):
    """Replace syllable-final 'l' (followed by non-vowels then space/end) with 'ɫ'."""
    syllable_final_l = r"l([^aeiouæɑɔəɛɪʊ ]*(?: |$))"
    return re.sub(syllable_final_l, lambda m: "ɫ" + m.group(1), text)
165
+
166
+
167
def english_to_ipa(text):
    """Convert English text to IPA phonemes via eng_to_ipa.

    The input is first transliterated to ASCII, lower-cased, and has its
    abbreviations and numbers expanded to words.
    """
    import eng_to_ipa as ipa

    normalized = unidecode(text).lower()
    normalized = expand_abbreviations(normalized)
    normalized = normalize_numbers(normalized)
    return collapse_whitespace(ipa.convert(normalized))
176
+
177
+
178
def english_to_lazy_ipa(text):
    """English text -> simplified ('lazy') IPA using the _lazy_ipa mapping."""
    result = english_to_ipa(text)
    for pattern, replacement in _lazy_ipa:
        result = pattern.sub(replacement, result)
    return result
183
+
184
+
185
def english_to_ipa2(text):
    """English text -> IPA with dark-l marking and two-character affricates."""
    result = mark_dark_l(english_to_ipa(text))
    for pattern, replacement in _ipa_to_ipa2:
        result = pattern.sub(replacement, result)
    return result.replace("...", "…")
191
+
192
+
193
def english_to_lazy_ipa2(text):
    """English text -> lazy IPA (variant 2) using the _lazy_ipa2 mapping."""
    result = english_to_ipa(text)
    for pattern, replacement in _lazy_ipa2:
        result = pattern.sub(replacement, result)
    return result
apps/audio_cloning/vallex/g2p/japanese.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ from unidecode import unidecode
4
+
5
+ # Regular expression matching Japanese without punctuation marks:
6
+ _japanese_characters = re.compile(
7
+ r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
8
+ )
9
+
10
+ # Regular expression matching non-Japanese characters or punctuation marks:
11
+ _japanese_marks = re.compile(
12
+ r"[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
13
+ )
14
+
15
+ # List of (symbol, Japanese) pairs for marks:
16
+ _symbols_to_japanese = [(re.compile("%s" % x[0]), x[1]) for x in [("%", "パーセント")]]
17
+
18
+ # List of (romaji, ipa) pairs for marks:
19
+ _romaji_to_ipa = [
20
+ (re.compile("%s" % x[0]), x[1])
21
+ for x in [
22
+ ("ts", "ʦ"),
23
+ ("u", "ɯ"),
24
+ ("j", "ʥ"),
25
+ ("y", "j"),
26
+ ("ni", "n^i"),
27
+ ("nj", "n^"),
28
+ ("hi", "çi"),
29
+ ("hj", "ç"),
30
+ ("f", "ɸ"),
31
+ ("I", "i*"),
32
+ ("U", "ɯ*"),
33
+ ("r", "ɾ"),
34
+ ]
35
+ ]
36
+
37
+ # List of (romaji, ipa2) pairs for marks:
38
+ _romaji_to_ipa2 = [
39
+ (re.compile("%s" % x[0]), x[1])
40
+ for x in [
41
+ ("u", "ɯ"),
42
+ ("ʧ", "tʃ"),
43
+ ("j", "dʑ"),
44
+ ("y", "j"),
45
+ ("ni", "n^i"),
46
+ ("nj", "n^"),
47
+ ("hi", "çi"),
48
+ ("hj", "ç"),
49
+ ("f", "ɸ"),
50
+ ("I", "i*"),
51
+ ("U", "ɯ*"),
52
+ ("r", "ɾ"),
53
+ ]
54
+ ]
55
+
56
+ # List of (consonant, sokuon) pairs:
57
+ _real_sokuon = [
58
+ (re.compile("%s" % x[0]), x[1])
59
+ for x in [
60
+ (r"Q([↑↓]*[kg])", r"k#\1"),
61
+ (r"Q([↑↓]*[tdjʧ])", r"t#\1"),
62
+ (r"Q([↑↓]*[sʃ])", r"s\1"),
63
+ (r"Q([↑↓]*[pb])", r"p#\1"),
64
+ ]
65
+ ]
66
+
67
+ # List of (consonant, hatsuon) pairs:
68
+ _real_hatsuon = [
69
+ (re.compile("%s" % x[0]), x[1])
70
+ for x in [
71
+ (r"N([↑↓]*[pbm])", r"m\1"),
72
+ (r"N([↑↓]*[ʧʥj])", r"n^\1"),
73
+ (r"N([↑↓]*[tdn])", r"n\1"),
74
+ (r"N([↑↓]*[kg])", r"ŋ\1"),
75
+ ]
76
+ ]
77
+
78
+
79
def symbols_to_japanese(text):
    """Spell out symbols (e.g. '%') with their Japanese reading."""
    for pattern, reading in _symbols_to_japanese:
        text = pattern.sub(reading, text)
    return text
83
+
84
+
85
def japanese_to_romaji_with_accent(text):
    """Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html

    Convert Japanese text to romaji annotated with pitch-accent marks:
    '↑' marks a pitch rise, '↓' a pitch fall, and a space marks an
    accent-phrase boundary.  Non-Japanese characters/punctuation between
    Japanese runs are transliterated with unidecode (spaces removed).
    """
    import pyopenjtalk

    text = symbols_to_japanese(text)
    # Split into Japanese runs; `marks` keeps the separators so they can be
    # re-inserted between converted runs below.
    sentences = re.split(_japanese_marks, text)
    marks = re.findall(_japanese_marks, text)
    text = ""
    for i, sentence in enumerate(sentences):
        if re.match(_japanese_characters, sentence):
            if text != "":
                text += " "
            labels = pyopenjtalk.extract_fullcontext(sentence)
            for n, label in enumerate(labels):
                # Current phoneme is the "-p+" field of the full-context label.
                phoneme = re.search(r"\-([^\+]*)\+", label).group(1)
                if phoneme not in ["sil", "pau"]:
                    # Normalise digraphs to single IPA-ish symbols; "cl" (sokuon)
                    # becomes the placeholder "Q", resolved later by get_real_sokuon.
                    text += (
                        phoneme.replace("ch", "ʧ").replace("sh", "ʃ").replace("cl", "Q")
                    )
                else:
                    continue
                # n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
                # A-field counters of the label; compared against the next
                # label below to detect accent boundaries and pitch movement.
                a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
                a2 = int(re.search(r"\+(\d+)\+", label).group(1))
                a3 = int(re.search(r"\+(\d+)/", label).group(1))
                # Lookahead is safe: the final label is silence, which hits the
                # `continue` above before labels[n + 1] would go out of range.
                if re.search(r"\-([^\+]*)\+", labels[n + 1]).group(1) in ["sil", "pau"]:
                    a2_next = -1
                else:
                    a2_next = int(re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
                # Accent phrase boundary
                if a3 == 1 and a2_next == 1:
                    text += " "
                # Falling
                elif a1 == 0 and a2_next == a2 + 1:
                    text += "↓"
                # Rising
                elif a2 == 1 and a2_next == 2:
                    text += "↑"
        if i < len(marks):
            text += unidecode(marks[i]).replace(" ", "")
    return text
126
+
127
+
128
def get_real_sokuon(text):
    """Resolve the sokuon placeholder 'Q' into a concrete geminate consonant."""
    for pattern, replacement in _real_sokuon:
        text = pattern.sub(replacement, text)
    return text
132
+
133
+
134
def get_real_hatsuon(text):
    """Resolve the moraic-nasal placeholder 'N' by its following consonant."""
    for pattern, replacement in _real_hatsuon:
        text = pattern.sub(replacement, text)
    return text
138
+
139
+
140
def japanese_to_ipa(text):
    """Japanese text -> IPA (variant 1), with long vowels written using 'ː'."""
    text = japanese_to_romaji_with_accent(text).replace("...", "…")
    # Collapse repeated vowels into vowel + length marks.
    text = re.sub(
        r"([aiueo])\1+", lambda m: m.group(0)[0] + "ː" * (len(m.group(0)) - 1), text
    )
    text = get_real_sokuon(text)
    text = get_real_hatsuon(text)
    for pattern, replacement in _romaji_to_ipa:
        text = pattern.sub(replacement, text)
    return text
150
+
151
+
152
def japanese_to_ipa2(text):
    """Japanese text -> IPA (variant 2), without long-vowel collapsing."""
    text = japanese_to_romaji_with_accent(text).replace("...", "…")
    text = get_real_sokuon(text)
    text = get_real_hatsuon(text)
    for pattern, replacement in _romaji_to_ipa2:
        text = pattern.sub(replacement, text)
    return text
159
+
160
+
161
def japanese_to_ipa3(text):
    """Japanese text -> narrow IPA (variant 3) with combining diacritics."""
    # Placeholder characters from variant 2 become their narrow-IPA forms.
    replacements = (("n^", "ȵ"), ("ʃ", "ɕ"), ("*", "\u0325"), ("#", "\u031a"))
    text = japanese_to_ipa2(text)
    for old, new in replacements:
        text = text.replace(old, new)
    # Collapse repeated vowels into vowel + length marks.
    text = re.sub(
        r"([aiɯeo])\1+", lambda m: m.group(0)[0] + "ː" * (len(m.group(0)) - 1), text
    )
    # Aspirate word-initial ts / tɕ / k / p / t.
    text = re.sub(r"((?:^|\s)(?:ts|tɕ|[kpt]))", r"\1ʰ", text)
    return text
apps/audio_cloning/vallex/g2p/mandarin.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import cn2an
4
+ import jieba
5
+
6
+ # List of (Latin alphabet, bopomofo) pairs:
7
+ _latin_to_bopomofo = [
8
+ (re.compile("%s" % x[0], re.IGNORECASE), x[1])
9
+ for x in [
10
+ ("a", "ㄟˉ"),
11
+ ("b", "ㄅㄧˋ"),
12
+ ("c", "ㄙㄧˉ"),
13
+ ("d", "ㄉㄧˋ"),
14
+ ("e", "ㄧˋ"),
15
+ ("f", "ㄝˊㄈㄨˋ"),
16
+ ("g", "ㄐㄧˋ"),
17
+ ("h", "ㄝˇㄑㄩˋ"),
18
+ ("i", "ㄞˋ"),
19
+ ("j", "ㄐㄟˋ"),
20
+ ("k", "ㄎㄟˋ"),
21
+ ("l", "ㄝˊㄛˋ"),
22
+ ("m", "ㄝˊㄇㄨˋ"),
23
+ ("n", "ㄣˉ"),
24
+ ("o", "ㄡˉ"),
25
+ ("p", "ㄆㄧˉ"),
26
+ ("q", "ㄎㄧㄡˉ"),
27
+ ("r", "ㄚˋ"),
28
+ ("s", "ㄝˊㄙˋ"),
29
+ ("t", "ㄊㄧˋ"),
30
+ ("u", "ㄧㄡˉ"),
31
+ ("v", "ㄨㄧˉ"),
32
+ ("w", "ㄉㄚˋㄅㄨˋㄌㄧㄡˋ"),
33
+ ("x", "ㄝˉㄎㄨˋㄙˋ"),
34
+ ("y", "ㄨㄞˋ"),
35
+ ("z", "ㄗㄟˋ"),
36
+ ]
37
+ ]
38
+
39
+ # List of (bopomofo, romaji) pairs:
40
+ _bopomofo_to_romaji = [
41
+ (re.compile("%s" % x[0]), x[1])
42
+ for x in [
43
+ ("ㄅㄛ", "p⁼wo"),
44
+ ("ㄆㄛ", "pʰwo"),
45
+ ("ㄇㄛ", "mwo"),
46
+ ("ㄈㄛ", "fwo"),
47
+ ("ㄅ", "p⁼"),
48
+ ("ㄆ", "pʰ"),
49
+ ("ㄇ", "m"),
50
+ ("ㄈ", "f"),
51
+ ("ㄉ", "t⁼"),
52
+ ("ㄊ", "tʰ"),
53
+ ("ㄋ", "n"),
54
+ ("ㄌ", "l"),
55
+ ("ㄍ", "k⁼"),
56
+ ("ㄎ", "kʰ"),
57
+ ("ㄏ", "h"),
58
+ ("ㄐ", "ʧ⁼"),
59
+ ("ㄑ", "ʧʰ"),
60
+ ("ㄒ", "ʃ"),
61
+ ("ㄓ", "ʦ`⁼"),
62
+ ("ㄔ", "ʦ`ʰ"),
63
+ ("ㄕ", "s`"),
64
+ ("ㄖ", "ɹ`"),
65
+ ("ㄗ", "ʦ⁼"),
66
+ ("ㄘ", "ʦʰ"),
67
+ ("ㄙ", "s"),
68
+ ("ㄚ", "a"),
69
+ ("ㄛ", "o"),
70
+ ("ㄜ", "ə"),
71
+ ("ㄝ", "e"),
72
+ ("ㄞ", "ai"),
73
+ ("ㄟ", "ei"),
74
+ ("ㄠ", "au"),
75
+ ("ㄡ", "ou"),
76
+ ("ㄧㄢ", "yeNN"),
77
+ ("ㄢ", "aNN"),
78
+ ("ㄧㄣ", "iNN"),
79
+ ("ㄣ", "əNN"),
80
+ ("ㄤ", "aNg"),
81
+ ("ㄧㄥ", "iNg"),
82
+ ("ㄨㄥ", "uNg"),
83
+ ("ㄩㄥ", "yuNg"),
84
+ ("ㄥ", "əNg"),
85
+ ("ㄦ", "əɻ"),
86
+ ("ㄧ", "i"),
87
+ ("ㄨ", "u"),
88
+ ("ㄩ", "ɥ"),
89
+ ("ˉ", "→"),
90
+ ("ˊ", "↑"),
91
+ ("ˇ", "↓↑"),
92
+ ("ˋ", "↓"),
93
+ ("˙", ""),
94
+ (",", ","),
95
+ ("。", "."),
96
+ ("!", "!"),
97
+ ("?", "?"),
98
+ ("—", "-"),
99
+ ]
100
+ ]
101
+
102
+ # List of (romaji, ipa) pairs:
103
+ _romaji_to_ipa = [
104
+ (re.compile("%s" % x[0], re.IGNORECASE), x[1])
105
+ for x in [
106
+ ("ʃy", "ʃ"),
107
+ ("ʧʰy", "ʧʰ"),
108
+ ("ʧ⁼y", "ʧ⁼"),
109
+ ("NN", "n"),
110
+ ("Ng", "ŋ"),
111
+ ("y", "j"),
112
+ ("h", "x"),
113
+ ]
114
+ ]
115
+
116
+ # List of (bopomofo, ipa) pairs:
117
+ _bopomofo_to_ipa = [
118
+ (re.compile("%s" % x[0]), x[1])
119
+ for x in [
120
+ ("ㄅㄛ", "p⁼wo"),
121
+ ("ㄆㄛ", "pʰwo"),
122
+ ("ㄇㄛ", "mwo"),
123
+ ("ㄈㄛ", "fwo"),
124
+ ("ㄅ", "p⁼"),
125
+ ("ㄆ", "pʰ"),
126
+ ("ㄇ", "m"),
127
+ ("ㄈ", "f"),
128
+ ("ㄉ", "t⁼"),
129
+ ("ㄊ", "tʰ"),
130
+ ("ㄋ", "n"),
131
+ ("ㄌ", "l"),
132
+ ("ㄍ", "k⁼"),
133
+ ("ㄎ", "kʰ"),
134
+ ("ㄏ", "x"),
135
+ ("ㄐ", "tʃ⁼"),
136
+ ("ㄑ", "tʃʰ"),
137
+ ("ㄒ", "ʃ"),
138
+ ("ㄓ", "ts`⁼"),
139
+ ("ㄔ", "ts`ʰ"),
140
+ ("ㄕ", "s`"),
141
+ ("ㄖ", "ɹ`"),
142
+ ("ㄗ", "ts⁼"),
143
+ ("ㄘ", "tsʰ"),
144
+ ("ㄙ", "s"),
145
+ ("ㄚ", "a"),
146
+ ("ㄛ", "o"),
147
+ ("ㄜ", "ə"),
148
+ ("ㄝ", "ɛ"),
149
+ ("ㄞ", "aɪ"),
150
+ ("ㄟ", "eɪ"),
151
+ ("ㄠ", "ɑʊ"),
152
+ ("ㄡ", "oʊ"),
153
+ ("ㄧㄢ", "jɛn"),
154
+ ("ㄩㄢ", "ɥæn"),
155
+ ("ㄢ", "an"),
156
+ ("ㄧㄣ", "in"),
157
+ ("ㄩㄣ", "ɥn"),
158
+ ("ㄣ", "ən"),
159
+ ("ㄤ", "ɑŋ"),
160
+ ("ㄧㄥ", "iŋ"),
161
+ ("ㄨㄥ", "ʊŋ"),
162
+ ("ㄩㄥ", "jʊŋ"),
163
+ ("ㄥ", "əŋ"),
164
+ ("ㄦ", "əɻ"),
165
+ ("ㄧ", "i"),
166
+ ("ㄨ", "u"),
167
+ ("ㄩ", "ɥ"),
168
+ ("ˉ", "→"),
169
+ ("ˊ", "↑"),
170
+ ("ˇ", "↓↑"),
171
+ ("ˋ", "↓"),
172
+ ("˙", ""),
173
+ (",", ","),
174
+ ("。", "."),
175
+ ("!", "!"),
176
+ ("?", "?"),
177
+ ("—", "-"),
178
+ ]
179
+ ]
180
+
181
+ # List of (bopomofo, ipa2) pairs:
182
+ _bopomofo_to_ipa2 = [
183
+ (re.compile("%s" % x[0]), x[1])
184
+ for x in [
185
+ ("ㄅㄛ", "pwo"),
186
+ ("ㄆㄛ", "pʰwo"),
187
+ ("ㄇㄛ", "mwo"),
188
+ ("ㄈㄛ", "fwo"),
189
+ ("ㄅ", "p"),
190
+ ("ㄆ", "pʰ"),
191
+ ("ㄇ", "m"),
192
+ ("ㄈ", "f"),
193
+ ("ㄉ", "t"),
194
+ ("ㄊ", "tʰ"),
195
+ ("ㄋ", "n"),
196
+ ("ㄌ", "l"),
197
+ ("ㄍ", "k"),
198
+ ("ㄎ", "kʰ"),
199
+ ("ㄏ", "h"),
200
+ ("ㄐ", "tɕ"),
201
+ ("ㄑ", "tɕʰ"),
202
+ ("ㄒ", "ɕ"),
203
+ ("ㄓ", "tʂ"),
204
+ ("ㄔ", "tʂʰ"),
205
+ ("ㄕ", "ʂ"),
206
+ ("ㄖ", "ɻ"),
207
+ ("ㄗ", "ts"),
208
+ ("ㄘ", "tsʰ"),
209
+ ("���", "s"),
210
+ ("ㄚ", "a"),
211
+ ("ㄛ", "o"),
212
+ ("ㄜ", "ɤ"),
213
+ ("ㄝ", "ɛ"),
214
+ ("ㄞ", "aɪ"),
215
+ ("ㄟ", "eɪ"),
216
+ ("ㄠ", "ɑʊ"),
217
+ ("ㄡ", "oʊ"),
218
+ ("ㄧㄢ", "jɛn"),
219
+ ("ㄩㄢ", "yæn"),
220
+ ("ㄢ", "an"),
221
+ ("ㄧㄣ", "in"),
222
+ ("ㄩㄣ", "yn"),
223
+ ("ㄣ", "ən"),
224
+ ("ㄤ", "ɑŋ"),
225
+ ("ㄧㄥ", "iŋ"),
226
+ ("ㄨㄥ", "ʊŋ"),
227
+ ("ㄩㄥ", "jʊŋ"),
228
+ ("ㄥ", "ɤŋ"),
229
+ ("ㄦ", "əɻ"),
230
+ ("ㄧ", "i"),
231
+ ("ㄨ", "u"),
232
+ ("ㄩ", "y"),
233
+ ("ˉ", "˥"),
234
+ ("ˊ", "˧˥"),
235
+ ("ˇ", "˨˩˦"),
236
+ ("ˋ", "˥˩"),
237
+ ("˙", ""),
238
+ (",", ","),
239
+ ("。", "."),
240
+ ("!", "!"),
241
+ ("?", "?"),
242
+ ("—", "-"),
243
+ ]
244
+ ]
245
+
246
+
247
def number_to_chinese(text):
    """Spell out every Arabic-numeral run in *text* in Chinese (via cn2an)."""
    for digits in re.findall(r"\d+(?:\.?\d+)?", text):
        # Replace one occurrence per match so repeated numbers are each
        # converted exactly once, left to right.
        text = text.replace(digits, cn2an.an2cn(digits), 1)
    return text
252
+
253
+
254
def chinese_to_bopomofo(text):
    """Convert Chinese text to bopomofo, one space-separated chunk per jieba word."""
    from pypinyin import BOPOMOFO, lazy_pinyin

    # Normalize enumeration/pause punctuation to plain full-width commas.
    normalized = text.replace("、", ",").replace(";", ",").replace(":", ",")
    out = ""
    for word in jieba.lcut(normalized, cut_all=False):
        # Segments without any CJK ideograph pass through untouched.
        if not re.search("[\u4e00-\u9fff]", word):
            out += word
            continue
        bopomofos = lazy_pinyin(word, BOPOMOFO)
        # Tag toneless syllables (ending on a bare bopomofo letter) as first tone.
        for idx in range(len(bopomofos)):
            bopomofos[idx] = re.sub(r"([\u3105-\u3129])$", r"\1ˉ", bopomofos[idx])
        if out != "":
            out += " "
        out += "".join(bopomofos)
    return out
271
+
272
+
273
def latin_to_bopomofo(text):
    """Replace each Latin letter with a bopomofo reading of its English name."""
    converted = text
    for pattern, replacement in _latin_to_bopomofo:
        converted = re.sub(pattern, replacement, converted)
    return converted
277
+
278
+
279
def bopomofo_to_romaji(text):
    """Map bopomofo symbols (and tone marks) to the romaji-like notation."""
    converted = text
    for pattern, replacement in _bopomofo_to_romaji:
        converted = re.sub(pattern, replacement, converted)
    return converted
283
+
284
+
285
def bopomofo_to_ipa(text):
    """Map bopomofo symbols (and tone marks) to IPA via ``_bopomofo_to_ipa``."""
    converted = text
    for pattern, replacement in _bopomofo_to_ipa:
        converted = re.sub(pattern, replacement, converted)
    return converted
289
+
290
+
291
def bopomofo_to_ipa2(text):
    """Map bopomofo symbols to IPA with tone letters via ``_bopomofo_to_ipa2``."""
    converted = text
    for pattern, replacement in _bopomofo_to_ipa2:
        converted = re.sub(pattern, replacement, converted)
    return converted
295
+
296
+
297
def chinese_to_romaji(text):
    """Convert Chinese text to the romaji-like notation used by the cleaners."""
    # Pipeline: digits -> Chinese, Chinese -> bopomofo, Latin -> bopomofo,
    # bopomofo -> romaji.
    for stage in (
        number_to_chinese,
        chinese_to_bopomofo,
        latin_to_bopomofo,
        bopomofo_to_romaji,
    ):
        text = stage(text)
    # Glides: i/u before a vowel become y/w.
    text = re.sub("i([aoe])", r"y\1", text)
    text = re.sub("u([aoəe])", r"w\1", text)
    # Append the syllabic vowel after bare sibilant/retroflex initials.
    text = re.sub("([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ`\2", text).replace("ɻ", "ɹ`")
    text = re.sub("([ʦs][⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text)
    return text
307
+
308
+
309
def chinese_to_lazy_ipa(text):
    """Convert Chinese text to lazy IPA by post-processing the romaji stage."""
    converted = chinese_to_romaji(text)
    for pattern, replacement in _romaji_to_ipa:
        converted = re.sub(pattern, replacement, converted)
    return converted
314
+
315
+
316
def chinese_to_ipa(text):
    """Convert Chinese text to IPA using the first-pass ``_bopomofo_to_ipa`` table."""
    for stage in (
        number_to_chinese,
        chinese_to_bopomofo,
        latin_to_bopomofo,
        bopomofo_to_ipa,
    ):
        text = stage(text)
    # Glides: i/u before a vowel become j/w.
    text = re.sub("i([aoe])", r"j\1", text)
    text = re.sub("u([aoəe])", r"w\1", text)
    # Append the syllabic vowel after bare sibilant/retroflex initials.
    text = re.sub("([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ`\2", text).replace("ɻ", "ɹ`")
    text = re.sub("([s][⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text)
    return text
326
+
327
+
328
def chinese_to_ipa2(text):
    """Convert Chinese text to IPA with tone letters via ``_bopomofo_to_ipa2``."""
    for stage in (
        number_to_chinese,
        chinese_to_bopomofo,
        latin_to_bopomofo,
        bopomofo_to_ipa2,
    ):
        text = stage(text)
    # Glides: i/u before a vowel become j/w.
    text = re.sub(r"i([aoe])", r"j\1", text)
    text = re.sub(r"u([aoəe])", r"w\1", text)
    # Append apical vowels after bare retroflex / dental sibilant initials.
    text = re.sub(r"([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)", r"\1ʅ\2", text)
    text = re.sub(r"(sʰ?)([˩˨˧˦˥ ]+|$)", r"\1ɿ\2", text)
    return text
apps/audio_cloning/vallex/g2p/symbols.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Defines the set of symbols used in text input to the model.
3
+ '''
4
+
5
+ # japanese_cleaners
6
+ # _pad = '_'
7
+ # _punctuation = ',.!?-'
8
+ # _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
9
+
10
+
11
+ '''# japanese_cleaners2
12
+ _pad = '_'
13
+ _punctuation = ',.!?-~…'
14
+ _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
15
+ '''
16
+
17
+
18
+ '''# korean_cleaners
19
+ _pad = '_'
20
+ _punctuation = ',.!?…~'
21
+ _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
22
+ '''
23
+
24
+ '''# chinese_cleaners
25
+ _pad = '_'
26
+ _punctuation = ',。!?—…'
27
+ _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
28
+ '''
29
+
30
+ # # zh_ja_mixture_cleaners
31
+ # _pad = '_'
32
+ # _punctuation = ',.!?-~…'
33
+ # _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
34
+
35
+
36
+ '''# sanskrit_cleaners
37
+ _pad = '_'
38
+ _punctuation = '।'
39
+ _letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ '
40
+ '''
41
+
42
+ '''# cjks_cleaners
43
+ _pad = '_'
44
+ _punctuation = ',.!?-~…'
45
+ _letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ '
46
+ '''
47
+
48
+ '''# thai_cleaners
49
+ _pad = '_'
50
+ _punctuation = '.!? '
51
+ _letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์'
52
+ '''
53
+
54
+ # # cjke_cleaners2
55
+ _pad = '_'
56
+ _punctuation = ',.!?-~…'
57
+ _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
58
+
59
+
60
+ '''# shanghainese_cleaners
61
+ _pad = '_'
62
+ _punctuation = ',.!?…'
63
+ _letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 '
64
+ '''
65
+
66
+ '''# chinese_dialect_cleaners
67
+ _pad = '_'
68
+ _punctuation = ',.!?~…─'
69
+ _letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ '
70
+ '''
71
+
72
+ # Export all symbols:
73
+ symbols = [_pad] + list(_punctuation) + list(_letters)
74
+
75
+ # Special symbol ids
76
+ SPACE_ID = symbols.index(" ")
apps/audio_cloning/vallex/macros.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Model hyper-parameters used when instantiating VALLE in main.py.
NUM_LAYERS = 12
NUM_HEAD = 16
N_DIM = 1024
PREFIX_MODE = 1
NUM_QUANTIZERS = 8
SAMPLE_RATE = 24000

# Language code -> special token wrapped around text of that language.
lang2token = {
    "zh": "[ZH]",
    "ja": "[JA]",
    "en": "[EN]",
    "mix": "",
}

# Language code -> integer id (stored e.g. in prompt .npz files).
lang2code = {
    "zh": 0,
    "ja": 1,
    "en": 2,
}

# Inverse mappings, derived so they can never drift out of sync.
token2lang = {token: lang for lang, token in lang2token.items()}
code2lang = {code: lang for lang, code in lang2code.items()}

# UI dropdown label -> language token.
langdropdown2token = {
    "English": "[EN]",
    "中文": "[ZH]",
    "日本語": "[JA]",
    "Mix": "",
}
apps/audio_cloning/vallex/main.py ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import multiprocessing
3
+ import os
4
+ import pathlib
5
+ import platform
6
+ import sys
7
+ import tempfile
8
+ import time
9
+
10
+ import gradio as gr
11
+ import langid
12
+ import nltk
13
+ import numpy as np
14
+ import torch
15
+ import torchaudio
16
+ import whisper
17
+ from vocos import Vocos
18
+
19
+ from .data.collation import get_text_token_collater
20
+ from .data.tokenizer import (
21
+ AudioTokenizer,
22
+ tokenize_audio,
23
+ )
24
+ from .descriptions import infer_from_audio_ja_md, top_ja_md
25
+ from .examples import infer_from_audio_examples
26
+ from .g2p import PhonemeBpeTokenizer
27
+ from .macros import (
28
+ N_DIM,
29
+ NUM_HEAD,
30
+ NUM_LAYERS,
31
+ NUM_QUANTIZERS,
32
+ PREFIX_MODE,
33
+ lang2code,
34
+ lang2token,
35
+ langdropdown2token,
36
+ token2lang,
37
+ )
38
+ from .models.vallex import VALLE
39
+
40
+ logger = logging.getLogger(__name__)
41
+
42
+ # set languages
43
+ langid.set_languages(["en", "zh", "ja"])
44
+
45
+ # set nltk data path
46
+ nltk.data.path = nltk.data.path + [os.path.join(os.getcwd(), "nltk_data")]
47
+ logger.info("nltk_data path: %s", nltk.data.path)
48
+
49
+ # get encoding
50
+ logger.info(
51
+ "default encoding is %s,file system encoding is %s",
52
+ sys.getdefaultencoding(),
53
+ sys.getfilesystemencoding(),
54
+ )
55
+
56
+ # check python version
57
+ logger.info("You are using Python version %s", platform.python_version())
58
+ if sys.version_info[0] < 3 or sys.version_info[1] < 7:
59
+ logger.warning("The Python version is too low and may cause problems")
60
+ if platform.system().lower() == "windows":
61
+ temp = pathlib.PosixPath
62
+ pathlib.PosixPath = pathlib.WindowsPath
63
+ else:
64
+ temp = pathlib.WindowsPath
65
+ pathlib.WindowsPath = pathlib.PosixPath
66
+ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
67
+
68
+ # set torch threads (guarded for hot-reload)
69
+ thread_count = multiprocessing.cpu_count()
70
+ logger.info("Use %d cpu cores for computing", thread_count)
71
+ if not getattr(torch, "_vallex_threads_configured", False):
72
+ torch.set_num_threads(thread_count)
73
+ try:
74
+ torch.set_num_interop_threads(thread_count)
75
+ except RuntimeError as err:
76
+ logger.warning("Skipping set_num_interop_threads: %s", err)
77
+ torch._C._jit_set_profiling_executor(False)
78
+ torch._C._jit_set_profiling_mode(False)
79
+ torch._C._set_graph_executor_optimize(False)
80
+
81
+ # gradio のリロード時に torch.set_num_iterop_threads を実行するとエラーになるので、設定済みのフラグをセット
82
+ setattr(torch, "_vallex_threads_configured", True)
83
+ else:
84
+ logger.info("Torch threads already configured; skipping reconfiguration")
85
+
86
+ # set text tokenizer and collater
87
+ logger.info("Setting text tokenizer and collater...")
88
+ tokenizer_path = "./apps/audio_cloning/vallex/g2p/bpe_69.json"
89
+ text_tokenizer = PhonemeBpeTokenizer(tokenizer_path=tokenizer_path)
90
+ text_collater = get_text_token_collater()
91
+
92
+ # set device
93
+ logger.info("Setting device...")
94
+ device = torch.device("cpu")
95
+ if torch.cuda.is_available():
96
+ device = torch.device("cuda", 0)
97
+ # if torch.backends.mps.is_available():
98
+ # device = torch.device("mps")
99
+ logger.info("Device set to %s", device)
100
+
101
+ # Download VALL-E-X model weights if not exists
102
+ OUTPUT_DIR_CHECKPOINTS = "./models/checkpoints"
103
+ if platform.system().lower() == "linux":
104
+ # docker(linux)環境では /app/models/checkpoints にする
105
+ OUTPUT_DIR_CHECKPOINTS = "/app/models/checkpoints"
106
+
107
+ OUTPUT_FILENAME_CHECKPOINTS = "vallex-checkpoint.pt"
108
+ OUTPUT_PATH_CHECKPOINTS = os.path.join(
109
+ OUTPUT_DIR_CHECKPOINTS, OUTPUT_FILENAME_CHECKPOINTS
110
+ )
111
if not os.path.exists(OUTPUT_DIR_CHECKPOINTS):
    os.makedirs(OUTPUT_DIR_CHECKPOINTS, exist_ok=True)
if not os.path.exists(OUTPUT_PATH_CHECKPOINTS):
    import wget

    logging.info(
        "Downloading model from https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt ..."
    )
    # BUGFIX: the original raised unconditionally right after wget.download(),
    # so a *successful* first-run download still aborted startup. Raise only
    # when the download actually fails.
    try:
        wget.download(
            "https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt",
            out=OUTPUT_PATH_CHECKPOINTS,
            bar=wget.bar_adaptive,
        )
    except Exception as download_error:
        raise Exception(
            "\n Model weights download failed, please go to 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'"
            "\n manually download model weights and put it to {} .".format(
                os.getcwd() + f"{OUTPUT_DIR_CHECKPOINTS}"
            )
        ) from download_error
130
+
131
+ # initialize VALL-E-X model
132
+ model = VALLE(
133
+ N_DIM,
134
+ NUM_HEAD,
135
+ NUM_LAYERS,
136
+ norm_first=True,
137
+ add_prenet=False,
138
+ prefix_mode=PREFIX_MODE,
139
+ share_embedding=True,
140
+ nar_scale_factor=1.0,
141
+ prepend_bos=True,
142
+ num_quantizers=NUM_QUANTIZERS,
143
+ )
144
+ checkpoint = torch.load(OUTPUT_PATH_CHECKPOINTS, map_location="cpu", weights_only=False)
145
+ missing_keys, unexpected_keys = model.load_state_dict(checkpoint["model"], strict=True)
146
+ assert not missing_keys
147
+ model.eval()
148
+
149
+ # Encodec-based tokenizer: converts reference audio into discrete conditioning tokens for VALLE
150
+ logger.info("Initializing Encodec-based tokenizer...")
151
+ audio_tokenizer = AudioTokenizer(device)
152
+
153
+ # Vocos vocoder: decodes VALLE's discrete acoustic codes back into a 24 kHz waveform
154
+ vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(device)
155
+
156
+ # initialize ASR model
157
+ OUTPUT_DIR_WHISPER = "./models/whisper"
158
+ if platform.system().lower() == "linux":
159
+ OUTPUT_DIR_WHISPER = "/app/models/whisper"
160
+
161
+ if not os.path.exists(OUTPUT_DIR_WHISPER):
162
+ os.makedirs(OUTPUT_DIR_WHISPER, exist_ok=True)
163
+ try:
164
+ logger.info("Loading Whisper model...")
165
+ model_name = "tiny"
166
+ whisper_model = whisper.load_model(
167
+ model_name, download_root=OUTPUT_DIR_WHISPER
168
+ ).cpu()
169
+ logger.info("Whisper model loaded successfully")
170
+ except Exception as e:
171
+ logging.info(e)
172
+ raise Exception(
173
+ "\n Whisper download failed or damaged, please go to "
174
+ "'https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt'"
175
+ "\n manually download model and put it to {} .".format(os.getcwd() + "/whisper")
176
+ )
177
+
178
+ # Initialize Voice Presets
179
+ logger.info("Initializing Voice Presets...")
180
+ PRESETS_DIR = "apps/audio_cloning/vallex/presets"
181
+ preset_list = os.walk(PRESETS_DIR).__next__()[2]
182
+ preset_list = [preset[:-4] for preset in preset_list if preset.endswith(".npz")]
183
+
184
+
185
def clear_prompts():
    """Delete .npz prompt files older than 60 seconds from the temp directory.

    Best-effort cleanup: any error is logged and swallowed so a failed delete
    never breaks the calling request.
    """
    try:
        tmp_dir = tempfile.gettempdir()
        for entry in os.listdir(tmp_dir):
            full_path = os.path.join(tmp_dir, entry)
            if not (os.path.isfile(full_path) and full_path.endswith(".npz")):
                continue
            cutoff = time.time() - 60
            if os.stat(full_path).st_mtime < cutoff:
                os.remove(full_path)
    except Exception as e:
        logger.error("Error clearing prompts: %s", e)
    return
198
+
199
+
200
def transcribe_one(model, audio_path):
    """Transcribe one audio file with Whisper and detect its spoken language.

    Args:
        model: a loaded ``whisper`` model instance.
        audio_path: path to the audio file to transcribe.

    Returns:
        ``(lang, text)`` — the detected language code and the transcript,
        guaranteed non-empty and ending in a punctuation mark.
    """
    # load audio and pad/trim it to fit 30 seconds
    audio = whisper.load_audio(audio_path)
    audio = whisper.pad_or_trim(audio)

    # make log-Mel spectrogram and move to the same device as the model
    mel = whisper.log_mel_spectrogram(audio).to(model.device)

    # detect the spoken language
    _, probs = model.detect_language(mel)
    lang = max(probs, key=probs.get)
    # Use the module logger instead of bare print() for consistency.
    logger.info("Detected language: %s", lang)

    # decode the audio (fp16 only when not running on CPU)
    options = whisper.DecodingOptions(
        temperature=1.0,
        best_of=5,
        fp16=device != torch.device("cpu"),
        sample_len=150,
    )
    result = whisper.decode(model, mel, options)
    logger.info("Transcription: %s", result.text)

    text_pr = result.text
    # BUGFIX: the original indexed ``text_pr.strip(" ")[-1]`` directly, which
    # raised IndexError when Whisper returned an empty transcription (e.g.
    # silent audio). Guard the empty case and still end with punctuation.
    stripped = text_pr.strip(" ")
    if not stripped or stripped[-1] not in "?!.,。,?!。、":
        text_pr += "."
    return lang, text_pr
228
+
229
+
230
def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
    """Build a reusable voice-prompt .npz from uploaded or recorded audio.

    Tokenizes the reference audio (EnCodec tokens) and its transcript
    (phoneme tokens), then saves both plus the language code to
    ``<tempdir>/<name>.npz``. Returns ``(status_message, npz_path)``.
    """
    global model, text_collater, text_tokenizer, audio_tokenizer
    # Drop stale prompt files from previous requests.
    clear_prompts()
    # Prefer the uploaded clip; fall back to the microphone recording.
    audio_prompt = uploaded_audio if uploaded_audio is not None else recorded_audio
    sr, wav_pr = audio_prompt
    if not isinstance(wav_pr, torch.FloatTensor):
        wav_pr = torch.FloatTensor(wav_pr)
    # Peak-normalize into [-1, 1] if the samples exceed unit range.
    if wav_pr.abs().max() > 1:
        wav_pr /= wav_pr.abs().max()
    # Stereo input: keep only the left channel.
    # NOTE(review): assumes shape (samples, channels) for stereo — confirm
    # against the gradio Audio component's numpy layout.
    if wav_pr.size(-1) == 2:
        wav_pr = wav_pr[:, 0]
    if wav_pr.ndim == 1:
        wav_pr = wav_pr.unsqueeze(0)
    # NOTE(review): ``wav_pr.ndim and ...`` looks like it was meant to be
    # ``wav_pr.ndim == 2 and ...`` — verify upstream intent before changing.
    assert wav_pr.ndim and wav_pr.size(0) == 1

    # No transcript given: let Whisper transcribe and tag the language.
    if transcript_content == "":
        text_pr, lang_pr = make_prompt(name, wav_pr, sr, save=False)
    else:
        # Transcript supplied: detect its language and wrap in language tokens.
        lang_pr = langid.classify(str(transcript_content))[0]
        lang_token = lang2token[lang_pr]
        text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"
    # tokenize audio into EnCodec codes, shape transposed for the model
    encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
    audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()

    # tokenize text into phoneme token ids
    phonemes, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
    text_tokens, enroll_x_lens = text_collater([phonemes])

    message = f"Detected language: {lang_pr}\n Detected text {text_pr}\n"

    # save audio tokens, text tokens and language code as one npz file
    np.savez(
        os.path.join(tempfile.gettempdir(), f"{name}.npz"),
        audio_tokens=audio_tokens,
        text_tokens=text_tokens,
        lang_code=lang2code[lang_pr],
    )
    return message, os.path.join(tempfile.gettempdir(), f"{name}.npz")
269
+
270
+
271
def make_prompt(name, wav, sr, save=True):
    """Transcribe a reference waveform and build the language-tagged prompt text.

    Saves the waveform under ``./prompts/<name>.wav``, transcribes it with
    Whisper, writes the tagged transcript next to it, and (when ``save`` is
    False) removes both files again after transcription.

    Args:
        name: basename for the temporary prompt files.
        wav: waveform tensor/array; normalized and reshaped to (1, samples).
        sr: sample rate of ``wav``.
        save: keep the .wav/.txt files on disk when True.

    Returns:
        ``(text, lang)`` — transcript wrapped in language tokens, and the
        detected language code.
    """
    global whisper_model
    whisper_model.to(device)
    if not isinstance(wav, torch.FloatTensor):
        wav = torch.tensor(wav)
    # Peak-normalize into [-1, 1] if the samples exceed unit range.
    if wav.abs().max() > 1:
        wav /= wav.abs().max()
    # Stereo input: downmix by averaging the two channels.
    if wav.size(-1) == 2:
        wav = wav.mean(-1, keepdim=False)
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    assert wav.ndim and wav.size(0) == 1
    # BUGFIX: torchaudio.save does not create missing directories; on a fresh
    # checkout ./prompts may not exist yet.
    os.makedirs("./prompts", exist_ok=True)
    torchaudio.save(f"./prompts/{name}.wav", wav, sr)
    lang, text = transcribe_one(whisper_model, f"./prompts/{name}.wav")
    lang_token = lang2token[lang]
    text = lang_token + text + lang_token
    with open(f"./prompts/{name}.txt", "w", encoding="utf-8") as f:
        f.write(text)
    if not save:
        os.remove(f"./prompts/{name}.wav")
        os.remove(f"./prompts/{name}.txt")

    # Move Whisper back to CPU to free accelerator memory for the TTS model.
    whisper_model.cpu()
    torch.cuda.empty_cache()
    return text, lang
296
+
297
+
298
@torch.no_grad()
def infer_from_audio(
    text, language, accent, audio_prompt, record_audio_prompt, transcript_content
):
    """Synthesize ``text`` in the voice of the supplied reference audio.

    Gradio callback: conditions VALLE on EnCodec tokens of the reference
    audio plus its (given or Whisper-derived) transcript, runs inference,
    and decodes the result with Vocos. Returns ``(status_message,
    (sample_rate, waveform))``.
    """
    global model, text_collater, text_tokenizer, audio_tokenizer
    # Prefer the uploaded clip; fall back to the microphone recording.
    audio_prompt = audio_prompt if audio_prompt is not None else record_audio_prompt
    sr, wav_pr = audio_prompt
    if not isinstance(wav_pr, torch.FloatTensor):
        wav_pr = torch.FloatTensor(wav_pr)
    # Peak-normalize into [-1, 1] if the samples exceed unit range.
    if wav_pr.abs().max() > 1:
        wav_pr /= wav_pr.abs().max()
    # Stereo input: keep only the left channel.
    if wav_pr.size(-1) == 2:
        wav_pr = wav_pr[:, 0]
    if wav_pr.ndim == 1:
        wav_pr = wav_pr.unsqueeze(0)
    assert wav_pr.ndim and wav_pr.size(0) == 1

    # Transcript of the reference audio: supplied by the user, or produced
    # by Whisper via make_prompt ("dummy" files are deleted afterwards).
    if transcript_content == "":
        text_pr, lang_pr = make_prompt("dummy", wav_pr, sr, save=False)
    else:
        lang_pr = langid.classify(str(transcript_content))[0]
        lang_token = lang2token[lang_pr]
        text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"

    # Wrap the target text in the requested (or auto-detected) language token.
    if language == "auto-detect":
        lang_token = lang2token[langid.classify(text)[0]]
    else:
        lang_token = langdropdown2token[language]
    lang = token2lang[lang_token]
    text = lang_token + text + lang_token

    # onload model (kept on CPU between requests to save accelerator memory)
    model.to(device)

    # tokenize reference audio into EnCodec codes
    encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
    audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)

    # tokenize target text into phoneme token ids
    logging.info(f"synthesize text: {text}")
    phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
    text_tokens, text_tokens_lens = text_collater([phone_tokens])

    # Prepend the reference transcript's tokens so the model sees
    # "prompt text + target text" as one sequence.
    enroll_x_lens = None
    if text_pr:
        text_prompts, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
        text_prompts, enroll_x_lens = text_collater([text_prompts])
        text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
        text_tokens_lens += enroll_x_lens
    # An explicit accent choice overrides the text language for synthesis.
    lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
    encoded_frames = model.inference(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        audio_prompts,
        enroll_x_lens=enroll_x_lens,
        top_k=-100,
        temperature=1,
        prompt_language=lang_pr,
        text_language=langs if accent == "no-accent" else lang,
        best_of=5,
    )
    # Decode the discrete acoustic codes back to a waveform with Vocos.
    frames = encoded_frames.permute(2, 0, 1)
    features = vocos.codes_to_features(frames)
    samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))

    # offload model back to CPU and release cached accelerator memory
    model.to("cpu")
    torch.cuda.empty_cache()

    # NOTE(review): "sythesized" is a typo in the user-visible message;
    # left as-is here since this edit changes documentation only.
    message = f"text prompt: {text_pr}\nsythesized text: {text}"
    # 24000 matches the EnCodec/Vocos 24 kHz sample rate used above.
    return message, (24000, samples.squeeze(0).cpu().numpy())
370
+
371
+
372
+ def main():
373
+ app = gr.Blocks(title="VALL-E X")
374
+ with app:
375
+ gr.Markdown(top_ja_md)
376
+ with gr.Tab("Infer from audio"):
377
+ gr.Markdown(infer_from_audio_ja_md)
378
+ with gr.Row():
379
+ with gr.Column():
380
+ textbox = gr.TextArea(
381
+ label="音声合成で喋らせたいテキスト",
382
+ # placeholder="Type your sentence here",
383
+ placeholder="ここに音声合成で喋らせたいテキストを入力してください。",
384
+ value="Welcome back, Master. What can I do for you today?",
385
+ elem_id="tts-input",
386
+ )
387
+ language_dropdown = gr.Dropdown(
388
+ choices=["auto-detect", "English", "中文", "日本語"],
389
+ value="auto-detect",
390
+ label="language",
391
+ )
392
+ accent_dropdown = gr.Dropdown(
393
+ choices=["no-accent", "English", "中文", "日本語"],
394
+ value="no-accent",
395
+ label="accent",
396
+ )
397
+ textbox_transcript = gr.TextArea(
398
+ label="Transcript",
399
+ # placeholder="Write transcript here. (leave empty to use whisper)",
400
+ placeholder="アップロードした音声、または録音した音声のテキストを入力してください。(whisper を使用する場合は空のままにしてください。)",
401
+ value="",
402
+ elem_id="prompt-name",
403
+ )
404
+ upload_audio_prompt = gr.Audio(
405
+ label="音声アップロード",
406
+ sources=["upload"],
407
+ interactive=True,
408
+ )
409
+ record_audio_prompt = gr.Audio(
410
+ label="音声を録音する",
411
+ sources=["microphone"],
412
+ interactive=True,
413
+ )
414
+ with gr.Column():
415
+ text_output = gr.Textbox(label="Message")
416
+ audio_output = gr.Audio(label="Output Audio", elem_id="tts-audio")
417
+ btn = gr.Button("音声合成を開始する")
418
+ btn.click(
419
+ infer_from_audio,
420
+ inputs=[
421
+ textbox,
422
+ language_dropdown,
423
+ accent_dropdown,
424
+ upload_audio_prompt,
425
+ record_audio_prompt,
426
+ textbox_transcript,
427
+ ],
428
+ outputs=[text_output, audio_output],
429
+ )
430
+ textbox_mp = gr.TextArea(
431
+ label="Prompt name",
432
+ placeholder="Name your prompt here",
433
+ value="prompt_1",
434
+ elem_id="prompt-name",
435
+ )
436
+ btn_mp = gr.Button("Make prompt!")
437
+ prompt_output = gr.File(interactive=False)
438
+ btn_mp.click(
439
+ make_npz_prompt,
440
+ inputs=[
441
+ textbox_mp,
442
+ upload_audio_prompt,
443
+ record_audio_prompt,
444
+ textbox_transcript,
445
+ ],
446
+ outputs=[text_output, prompt_output],
447
+ )
448
+ gr.Examples(
449
+ examples=infer_from_audio_examples,
450
+ inputs=[
451
+ textbox,
452
+ language_dropdown,
453
+ accent_dropdown,
454
+ upload_audio_prompt,
455
+ record_audio_prompt,
456
+ textbox_transcript,
457
+ ],
458
+ outputs=[text_output, audio_output],
459
+ fn=infer_from_audio,
460
+ cache_examples=False,
461
+ )
apps/audio_cloning/vallex/models/__init__.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch.nn as nn
4
+
5
+ from .transformer import Transformer
6
+ from .vallex import VALLE, VALLF
7
+
8
+
9
def _str2bool(value):
    """Parse a boolean CLI value ("true"/"false", "yes"/"no", "1"/"0")."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def add_model_arguments(parser: argparse.ArgumentParser):
    """Register all model-construction CLI options on ``parser``.

    BUGFIX: the boolean options previously used ``type=bool``, which is
    broken with argparse — ``bool("False")`` is ``True``, so e.g.
    ``--norm-first False`` silently yielded ``True``. They now use a proper
    string-to-bool converter; all defaults are unchanged.
    """
    parser.add_argument(
        "--model-name",
        type=str,
        default="VALL-E",
        help="VALL-E, VALL-F, Transformer.",
    )
    parser.add_argument(
        "--decoder-dim",
        type=int,
        default=1024,
        help="Embedding dimension in the decoder model.",
    )
    parser.add_argument(
        "--nhead",
        type=int,
        default=16,
        help="Number of attention heads in the Decoder layers.",
    )
    parser.add_argument(
        "--num-decoder-layers",
        type=int,
        default=12,
        help="Number of Decoder layers.",
    )
    parser.add_argument(
        "--scale-factor",
        type=float,
        default=1.0,
        help="Model scale factor which will be assigned different meanings in different models.",
    )
    parser.add_argument(
        "--norm-first",
        type=_str2bool,
        default=True,
        help="Pre or Post Normalization.",
    )
    parser.add_argument(
        "--add-prenet",
        type=_str2bool,
        default=False,
        help="Whether add PreNet after Inputs.",
    )

    # VALL-E & F
    parser.add_argument(
        "--prefix-mode",
        type=int,
        default=1,
        help="The mode for how to prefix VALL-E NAR Decoder, "
        "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.",
    )
    parser.add_argument(
        "--share-embedding",
        type=_str2bool,
        default=True,
        help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.",
    )
    parser.add_argument(
        "--prepend-bos",
        type=_str2bool,
        default=False,
        help="Whether prepend <BOS> to the acoustic tokens -> AR Decoder inputs.",
    )
    parser.add_argument(
        "--num-quantizers",
        type=int,
        default=8,
        help="Number of Audio/Semantic quantization layers.",
    )

    # Transformer
    parser.add_argument(
        "--scaling-xformers",
        type=_str2bool,
        default=False,
        help="Apply Reworked Conformer scaling on Transformers.",
    )
88
+
89
def get_model(params) -> nn.Module:
    """Instantiate the model selected by ``params.model_name``.

    Accepts "VALL-E"/"VALLE", "VALL-F"/"VALLF" (case-insensitive) or the
    exact name "Transformer"; any other value fails the assertion.
    """
    requested = params.model_name.lower()
    if requested in ("vall-f", "vallf") or requested in ("vall-e", "valle"):
        # VALL-F and VALL-E share the same constructor signature.
        model_cls = VALLF if requested in ("vall-f", "vallf") else VALLE
        return model_cls(
            params.decoder_dim,
            params.nhead,
            params.num_decoder_layers,
            norm_first=params.norm_first,
            add_prenet=params.add_prenet,
            prefix_mode=params.prefix_mode,
            share_embedding=params.share_embedding,
            nar_scale_factor=params.scale_factor,
            prepend_bos=params.prepend_bos,
            num_quantizers=params.num_quantizers,
        )
    assert params.model_name in ["Transformer"]
    return Transformer(
        params.decoder_dim,
        params.nhead,
        params.num_decoder_layers,
        norm_first=params.norm_first,
        add_prenet=params.add_prenet,
        scaling_xformers=params.scaling_xformers,
    )
apps/audio_cloning/vallex/models/macros.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# Vocabulary and feature-size constants shared by the model definitions.

# Text
NUM_TEXT_TOKENS = 2048

# Audio
NUM_AUDIO_TOKENS = 1024  # EnCodec RVQ bins
NUM_MEL_BINS = 100  # BigVGAN bigvgan_24khz_100band


# Speaker
NUM_SPEAKER_CLASSES = 4096
SPEAKER_EMBEDDING_DIM = 64
apps/audio_cloning/vallex/models/transformer.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 (authors: Feiteng Li)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from functools import partial
16
+ from typing import Any, Dict, List, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from ..modules.embedding import SinePositionalEmbedding, TokenEmbedding
23
+ from ..modules.scaling import BalancedDoubleSwish, ScaledLinear
24
+ from ..modules.transformer import (
25
+ BalancedBasicNorm,
26
+ IdentityNorm,
27
+ TransformerDecoderLayer,
28
+ TransformerEncoder,
29
+ TransformerEncoderLayer,
30
+ )
31
+ from .macros import NUM_MEL_BINS, NUM_TEXT_TOKENS
32
+
33
+ # from icefall.utils import make_pad_mask
34
+ # from torchmetrics.classification import BinaryAccuracy
35
+ from .vallex import Transpose
36
+ from .visualizer import visualize
37
+
38
+ IdentityNorm = IdentityNorm
39
+
40
+
41
class Transformer(nn.Module):
    """It implements seq2seq Transformer TTS for debug(No StopPredictor and SpeakerEmbeding)
    Neural Speech Synthesis with Transformer Network
    https://arxiv.org/abs/1809.08895

    NOTE(review): ``make_pad_mask`` (used in ``forward``/``inference``) and
    ``BinaryAccuracy`` (used in ``__init__``) are referenced in this class,
    but their imports are commented out at the top of this module.  As
    written, constructing or running this class raises ``NameError`` unless
    those names are provided elsewhere -- confirm before use.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        norm_first: bool = True,
        add_prenet: bool = False,
        scaling_xformers: bool = False,
    ):
        """
        Args:
          d_model:
            The number of expected features in the input (required).
          nhead:
            The number of heads in the multiheadattention models (required).
          num_layers:
            The number of sub-decoder-layers in the decoder (required).
          norm_first:
            If True, use pre-norm layers and a final norm after each stack.
          add_prenet:
            If True, add Tacotron-style conv/MLP prenets before the
            encoder/decoder inputs.
          scaling_xformers:
            If True, build the reworked-Conformer-scaled encoder/decoder
            (ScaledLinear / BalancedDoubleSwish variants).
        """
        super().__init__()
        self.text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS)  # W_x

        if add_prenet:
            # Text prenet: 3 x (conv5 + BN + ReLU + dropout) then a linear
            # projection; Transpose() adapts (N, T, D) <-> (N, D, T) for Conv1d.
            self.encoder_prenet = nn.Sequential(
                Transpose(),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                Transpose(),
                nn.Linear(d_model, d_model),
            )

            # Mel prenet: bottleneck MLP projecting mel frames to d_model.
            self.decoder_prenet = nn.Sequential(
                nn.Linear(NUM_MEL_BINS, 256),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(256, d_model),
            )

            assert scaling_xformers is False  # TODO: update this block
        else:
            self.encoder_prenet = nn.Identity()
            if scaling_xformers:
                self.decoder_prenet = ScaledLinear(NUM_MEL_BINS, d_model)
            else:
                self.decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model)

        self.encoder_position = SinePositionalEmbedding(
            d_model,
            dropout=0.1,
            scale=False,
        )
        self.decoder_position = SinePositionalEmbedding(
            d_model, dropout=0.1, scale=False
        )

        if scaling_xformers:
            # Scaled (reworked-Conformer) variant of the encoder/decoder
            # stacks: ScaledLinear projections, BalancedDoubleSwish
            # activations, and IdentityNorm in place of LayerNorm.
            self.encoder = TransformerEncoder(
                TransformerEncoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                    linear1_self_attention_cls=ScaledLinear,
                    linear2_self_attention_cls=partial(
                        ScaledLinear, initial_scale=0.01
                    ),
                    linear1_feedforward_cls=ScaledLinear,
                    linear2_feedforward_cls=partial(ScaledLinear, initial_scale=0.01),
                    activation=partial(
                        BalancedDoubleSwish,
                        channel_dim=-1,
                        max_abs=10.0,
                        min_prob=0.25,
                    ),
                    layer_norm_cls=IdentityNorm,
                ),
                num_layers=num_layers,
                norm=BalancedBasicNorm(d_model) if norm_first else None,
            )

            self.decoder = nn.TransformerDecoder(
                TransformerDecoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                    linear1_self_attention_cls=ScaledLinear,
                    linear2_self_attention_cls=partial(
                        ScaledLinear, initial_scale=0.01
                    ),
                    linear1_feedforward_cls=ScaledLinear,
                    linear2_feedforward_cls=partial(ScaledLinear, initial_scale=0.01),
                    activation=partial(
                        BalancedDoubleSwish,
                        channel_dim=-1,
                        max_abs=10.0,
                        min_prob=0.25,
                    ),
                    layer_norm_cls=IdentityNorm,
                ),
                num_layers=num_layers,
                norm=BalancedBasicNorm(d_model) if norm_first else None,
            )

            self.predict_layer = ScaledLinear(d_model, NUM_MEL_BINS)
            self.stop_layer = nn.Linear(d_model, 1)
        else:
            # Vanilla torch.nn transformer stacks.
            self.encoder = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    activation=F.relu,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                ),
                num_layers=num_layers,
                norm=nn.LayerNorm(d_model) if norm_first else None,
            )

            self.decoder = nn.TransformerDecoder(
                nn.TransformerDecoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    activation=F.relu,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                ),
                num_layers=num_layers,
                norm=nn.LayerNorm(d_model) if norm_first else None,
            )

            self.predict_layer = nn.Linear(d_model, NUM_MEL_BINS)
            self.stop_layer = nn.Linear(d_model, 1)

        # NOTE(review): BinaryAccuracy comes from torchmetrics, whose import
        # is commented out at the top of this module -- NameError as written.
        self.stop_accuracy_metric = BinaryAccuracy(
            threshold=0.5, multidim_average="global"
        )

        # self.apply(self._init_weights)

    # def _init_weights(self, module):
    #     if isinstance(module, (nn.Linear)):
    #         module.weight.data.normal_(mean=0.0, std=0.02)
    #         if isinstance(module, nn.Linear) and module.bias is not None:
    #             module.bias.data.zero_()
    #     elif isinstance(module, nn.LayerNorm):
    #         module.bias.data.zero_()
    #         module.weight.data.fill_(1.0)
    #     elif isinstance(module, nn.Embedding):
    #         module.weight.data.normal_(mean=0.0, std=0.02)

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        y_lens: torch.Tensor,
        reduction: str = "sum",
        train_stage: int = 0,
        **kwargs,
    ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        """
        Args:
          x:
            A 2-D tensor of shape (N, S).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (N, T, 8).
          y_lens:
            A 1-D tensor of shape (N,). It contains the number of tokens in `x`
            before padding.
          train_stage:
            Not used in this model.
        Returns:
          Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy.
        """
        del train_stage

        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y_lens.ndim == 1, y_lens.shape

        assert torch.all(x_lens > 0)

        # NOTE: x has been padded in TextTokenCollater
        x_mask = make_pad_mask(x_lens).to(x.device)

        # Encode text: embedding -> (optional) prenet -> positions -> encoder.
        x = self.text_embedding(x)
        x = self.encoder_prenet(x)
        x = self.encoder_position(x)
        x = self.encoder(x, src_key_padding_mask=x_mask)

        total_loss, metrics = 0.0, {}

        y_mask = make_pad_mask(y_lens).to(y.device)
        y_mask_float = y_mask.type(torch.float32)
        # data_mask is 1.0 on real frames, 0.0 on padding.
        data_mask = 1.0 - y_mask_float.unsqueeze(-1)

        # Training
        # AR Decoder
        def pad_y(y):
            # Shift the target sequence right by one zero frame so the decoder
            # predicts frame t from frames < t (teacher forcing).
            y = F.pad(y, (0, 0, 1, 0, 0, 0), value=0).detach()
            # inputs, targets
            return y[:, :-1], y[:, 1:]

        y, targets = pad_y(y * data_mask)  # mask padding as zeros

        y_emb = self.decoder_prenet(y)
        y_pos = self.decoder_position(y_emb)

        # Causal mask: position t may only attend to positions <= t.
        y_len = y_lens.max()
        tgt_mask = torch.triu(
            torch.ones(y_len, y_len, device=y.device, dtype=torch.bool),
            diagonal=1,
        )
        y_dec = self.decoder(
            y_pos,
            x,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=x_mask,
        )

        predict = self.predict_layer(y_dec)
        # loss
        total_loss = F.mse_loss(predict, targets, reduction=reduction)

        # Stop-token head: padding positions (y_mask == 1) are the positive
        # class, up-weighted 5x to counter class imbalance.
        logits = self.stop_layer(y_dec).squeeze(-1)
        stop_loss = F.binary_cross_entropy_with_logits(
            logits,
            y_mask_float.detach(),
            weight=1.0 + y_mask_float.detach() * 4.0,
            reduction=reduction,
        )
        metrics["stop_loss"] = stop_loss.detach()

        stop_accuracy = self.stop_accuracy_metric(
            (torch.sigmoid(logits) >= 0.5).type(torch.int64),
            y_mask.type(torch.int64),
        )
        # icefall MetricsTracker.norm_items()
        metrics["stop_accuracy"] = stop_accuracy.item() * y_lens.sum().type(
            torch.float32
        )

        # Stop loss is weighted 100x against the mel reconstruction loss.
        return ((x, predict), total_loss + 100.0 * stop_loss, metrics)

    def inference(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: Any = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 2-D tensor of shape (1, S).
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
        Returns:
          Return the predicted audio code matrix and cross-entropy loss.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape

        assert torch.all(x_lens > 0)

        x_mask = make_pad_mask(x_lens).to(x.device)

        x = self.text_embedding(x)
        x = self.encoder_prenet(x)
        x = self.encoder_position(x)
        x = self.encoder(x, src_key_padding_mask=x_mask)

        # NOTE(review): this recomputes x_mask identically to the line above
        # the encoder call -- appears redundant; confirm before removing.
        x_mask = make_pad_mask(x_lens).to(x.device)

        # AR Decoder
        # TODO: Managing decoder steps avoid repetitive computation
        # Seed the autoregressive loop with a single all-zero "go" frame.
        y = torch.zeros(
            [x.shape[0], 1, NUM_MEL_BINS], dtype=torch.float32, device=x.device
        )
        while True:
            y_emb = self.decoder_prenet(y)
            y_pos = self.decoder_position(y_emb)

            tgt_mask = torch.triu(
                torch.ones(y.shape[1], y.shape[1], device=y.device, dtype=torch.bool),
                diagonal=1,
            )

            y_dec = self.decoder(
                y_pos,
                x,
                tgt_mask=tgt_mask,
                memory_mask=None,
                memory_key_padding_mask=x_mask,
            )
            predict = self.predict_layer(y_dec[:, -1:])

            # Stop when the stop head fires for every item, or after a hard
            # cap of 10 audio frames per text token.
            logits = self.stop_layer(y_dec[:, -1:]) > 0  # sigmoid(0.0) = 0.5
            if y.shape[1] > x_lens.max() * 10 or all(logits.cpu().numpy()):
                print(f"TransformerTTS EOS [Text {x_lens[0]} -> Audio {y.shape[1]}]")
                break

            y = torch.concat([y, predict], dim=1)

        # Drop the initial zero "go" frame.
        return y[:, 1:]

    def visualize(
        self,
        predicts: Tuple[torch.Tensor],
        batch: Dict[str, Union[List, torch.Tensor]],
        output_dir: str,
        limit: int = 4,
    ) -> None:
        """Delegate plotting of predictions to the shared visualize() helper."""
        visualize(predicts, batch, output_dir, limit=limit)
apps/audio_cloning/vallex/models/vallex.py ADDED
@@ -0,0 +1,823 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 (authors: Feiteng Li)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+ from typing import Dict, Iterator, List, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+
23
+ # from icefall.utils import make_pad_mask
24
+ # from torchmetrics.classification import MulticlassAccuracy
25
+ from ..data.input_strategies import PromptedFeatures
26
+ from ..modules.embedding import SinePositionalEmbedding, TokenEmbedding
27
+ from ..modules.transformer import (
28
+ AdaptiveLayerNorm,
29
+ LayerNorm,
30
+ TransformerDecoderLayer,
31
+ TransformerEncoder,
32
+ TransformerEncoderLayer,
33
+ )
34
+ from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS
35
+
36
+
37
class Transpose(nn.Identity):
    """Swap the time and feature axes: (N, T, D) -> (N, D, T)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Functional form of ``input.transpose(1, 2)``; returns a view with
        # dimensions 1 and 2 exchanged.
        return torch.transpose(input, 1, 2)
42
+
43
+
44
+ # NOTE: There are two ways to implement the model
45
+ # 1) [VALL-F] standard TransformerDecoder, use x as memory
46
+ # 2) [VALL-E] modified TransformerDecoder like GPT-x(e.g. causal TransformerEncoder),
47
+ # use x as the prefix of decoder inputs
48
class VALLF(nn.Module):
    """It implements https://arxiv.org/abs/2301.02111
    "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"

    Base class holding the AR (first quantizer) and NAR (remaining
    quantizers) sub-models; ``forward``/``inference`` are abstract here and
    implemented by subclasses (e.g. VALLE below).
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        norm_first: bool = True,
        add_prenet: bool = False,
        decoder_cls: Union[
            nn.TransformerDecoder, nn.TransformerEncoder
        ] = nn.TransformerDecoder,
        decoder_layer_cls: Union[
            TransformerDecoderLayer, TransformerEncoderLayer
        ] = TransformerDecoderLayer,
        prefix_mode: int = 0,
        share_embedding: bool = True,
        nar_scale_factor: float = 1.0,
        prepend_bos: bool = True,
        num_quantizers: int = 8,
    ):
        """
        Args:
          d_model:
            The number of expected features in the input (required).
          nhead:
            The number of heads in the multiheadattention models (required).
          num_layers:
            The number of sub-decoder-layers in the decoder (required).
          norm_first:
            If True, use pre-norm layers and a final norm after each stack.
          add_prenet:
            If True, add conv/MLP prenets before the text/audio embeddings.
          decoder_cls / decoder_layer_cls:
            Stack/layer classes -- VALL-F uses TransformerDecoder (text as
            memory); VALLE overrides these with encoder classes (text as
            decoder-input prefix).
          prefix_mode:
            NAR acoustic-prompt strategy (0..4); see _prepare_prompts.
          share_embedding:
            If True, tie NAR prediction-layer weights to the next stage's
            audio embedding weights.
          nar_scale_factor:
            Width/depth multiplier for the NAR sub-model.
          prepend_bos:
            If True, prepend a BOS token (id NUM_AUDIO_TOKENS + 1) to the AR
            decoder inputs.
          num_quantizers:
            Number of audio quantization layers (>= 1); 1 disables the NAR
            sub-model entirely.
        """
        super().__init__()
        nar_d_model = int(d_model * nar_scale_factor)

        self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS)  # W_x
        self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS)

        # ID NUM_AUDIO_TOKENS -> PAD
        # ID NUM_AUDIO_TOKENS + 1 -> BOS
        self.ar_audio_prepend_bos = prepend_bos
        self.ar_audio_embedding = TokenEmbedding(
            d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos)
        )

        # PreNet
        if add_prenet:
            # Text prenet: 3 x (conv5 + BN + ReLU + dropout) + linear;
            # Transpose() adapts (N, T, D) <-> (N, D, T) for Conv1d.
            self.ar_text_prenet = nn.Sequential(
                Transpose(),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                Transpose(),
                nn.Linear(d_model, d_model),
            )

            self.ar_audio_prenet = nn.Sequential(
                nn.Linear(d_model, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, d_model),
            )
        else:
            self.ar_text_prenet = nn.Identity()
            self.ar_audio_prenet = nn.Identity()

        self.ar_text_position = SinePositionalEmbedding(
            d_model,
            dropout=0.1,
            scale=False,
            alpha=True,
        )
        self.ar_audio_position = SinePositionalEmbedding(
            d_model,
            dropout=0.1,
            scale=False,
            alpha=True,
        )

        self.ar_decoder = decoder_cls(
            decoder_layer_cls(
                d_model,
                nhead,
                dim_feedforward=d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=norm_first,
            ),
            num_layers=num_layers,
            norm=LayerNorm(d_model) if norm_first else None,
        )
        # +1 output class for the EOS/PAD token id (NUM_AUDIO_TOKENS).
        self.ar_predict_layer = nn.Linear(d_model, NUM_AUDIO_TOKENS + 1, bias=False)

        # Seeded RNG so random prompt segments are reproducible across runs.
        self.rng = random.Random(0)
        self.num_heads = nhead
        self.prefix_mode = prefix_mode
        self.num_quantizers = num_quantizers

        assert num_quantizers >= 1
        if num_quantizers > 1:
            # Stage-0 embedding has one extra id (PAD); later stages do not.
            self.nar_audio_embeddings = nn.ModuleList(
                [TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)]
                + [
                    TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS)
                    for i in range(num_quantizers - 1)
                ]
            )  # W_a

            # PreNet
            if add_prenet:
                self.nar_text_prenet = nn.Sequential(
                    Transpose(),
                    nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
                    nn.BatchNorm1d(nar_d_model),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
                    nn.BatchNorm1d(nar_d_model),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
                    nn.BatchNorm1d(nar_d_model),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    Transpose(),
                    nn.Linear(nar_d_model, nar_d_model),
                )
                self.nar_audio_prenet = nn.Sequential(
                    nn.Linear(nar_d_model, 256),
                    nn.ReLU(),
                    nn.Dropout(0.25),
                    nn.Linear(256, 256),
                    nn.ReLU(),
                    nn.Dropout(0.25),
                    nn.Linear(256, nar_d_model),
                )
            else:
                self.nar_text_prenet = nn.Identity()
                self.nar_audio_prenet = nn.Identity()

            self.nar_text_position = SinePositionalEmbedding(
                nar_d_model,
                dropout=0.0,
                scale=False,
                alpha=False,
            )
            self.nar_audio_position = SinePositionalEmbedding(
                nar_d_model,
                dropout=0.1,
                scale=False,
                alpha=False,
            )

            self.nar_decoder = decoder_cls(
                decoder_layer_cls(
                    nar_d_model,
                    int(nhead * nar_scale_factor),
                    dim_feedforward=nar_d_model * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                    adaptive_layer_norm=True,
                ),
                num_layers=int(num_layers * nar_scale_factor),
                norm=AdaptiveLayerNorm(nar_d_model, norm=nn.LayerNorm(nar_d_model))
                if norm_first
                else None,
            )
            # One prediction head and one stage embedding per NAR stage.
            self.nar_predict_layers = nn.ModuleList(
                [
                    nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False)
                    for i in range(num_quantizers - 1)
                ]
            )
            self.nar_stage_embeddings = nn.ModuleList(
                [TokenEmbedding(nar_d_model, 1) for i in range(num_quantizers - 1)]
            )

            if share_embedding:
                # We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa
                # NOTE(Feiteng): In the experiment, this undermines accuracy
                # self.ar_predict_layer.weight = self.ar_audio_embedding.weight

                # We also share the parameters of the acoustic embedding layer and the output prediction layer,
                # which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer.
                for j in range(0, num_quantizers - 2):
                    self.nar_predict_layers[j].weight = self.nar_audio_embeddings[
                        j + 2
                    ].weight

    def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
        """Yield parameters for one training stage: 1 = AR ("ar_" prefix),
        2 = NAR ("nar_" prefix).  Prints each yielded parameter's name."""
        assert stage > 0
        if stage == 1:
            for name, param in self.named_parameters():
                if name.startswith("ar_"):
                    print(f" AR parameter: {name}")
                    yield param

        if stage == 2:
            for name, param in self.named_parameters():
                if name.startswith("nar_"):
                    print(f"NAR parameter: {name}")
                    yield param

    def stage_named_parameters(
        self, stage: int = 1
    ) -> Iterator[Tuple[str, nn.Parameter]]:
        """Like stage_parameters() but yields (name, parameter) pairs and
        does not print."""
        assert stage > 0
        if stage == 1:
            for pair in self.named_parameters():
                if pair[0].startswith("ar_"):
                    yield pair

        if stage == 2:
            for pair in self.named_parameters():
                if pair[0].startswith("nar_"):
                    yield pair

    def pad_y_eos(self, y, y_mask_int, eos_id):
        """Append EOS at each sequence's end (padding positions become
        eos_id) and build (inputs, targets) shifted by one step; a BOS column
        is prepended to the inputs when ar_audio_prepend_bos is set."""
        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
            y_mask_int, (0, 1), value=1
        )
        # inputs, targets
        if self.ar_audio_prepend_bos:
            return (
                F.pad(targets[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1),
                targets,
            )

        return targets[:, :-1], targets[:, 1:]

    def _prepare_prompts(
        self, y, y_lens, codes, nar_stage, y_prompts_codes, prefix_mode
    ):
        """Build the summed NAR input embedding and its acoustic-prompt
        length, according to prefix_mode (0: none, 1: random-length prefix at
        the start, 2: random 3s segment, 4: caller-supplied prompt codes).

        NOTE(review): mode 2 mutates ``codes`` in place (marks the prompt
        span of the current stage as PAD) -- callers must not reuse codes.
        """
        # 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds
        # from the same utterance.
        # We implement this differently.
        if prefix_mode == 0:
            # no prefix
            prefix_len = 0
            y_emb = self.nar_audio_embeddings[0](y)
            for j in range(1, nar_stage):
                # Formula (4) (5)
                y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
        elif prefix_mode == 1:
            # prefix at begining
            int_low = (0.25 * y_lens.min()).type(torch.int64).item()
            prefix_len = torch.randint(0, int_low * 2, size=()).item()
            prefix_len = min(prefix_len, 225)  # 24000/320 * 3s = 225 frames

            # Prompt region sums ALL quantizer embeddings; target region sums
            # only stages below the current one.
            y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
            y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
            for j in range(1, self.num_quantizers):
                y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j])
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j])
            y_emb = torch.concat([y_prompts, y_emb], axis=1)
        elif prefix_mode in [2, 4]:
            if prefix_mode == 2:
                # random prefix
                prefix_len = min(225, int(0.25 * y_lens.min().item()))

                y_prompts_codes = []
                for b in range(codes.shape[0]):
                    start = self.rng.randint(0, y_lens[b].item() - prefix_len)
                    y_prompts_codes.append(
                        torch.clone(codes[b, start : start + prefix_len])
                    )
                    codes[b, start : start + prefix_len, nar_stage] = NUM_AUDIO_TOKENS
                y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
            else:
                prefix_len = y_prompts_codes.shape[1]

            y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
            y_emb = self.nar_audio_embeddings[0](y)
            for j in range(1, self.num_quantizers):
                y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j])
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](codes[..., j])
            y_emb = torch.concat([y_prompts, y_emb], axis=1)
        else:
            raise ValueError

        return y_emb, prefix_len

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: Union[torch.Tensor, PromptedFeatures],
        y_lens: Union[torch.Tensor, PromptedFeatures],
        reduction: str = "sum",
        train_stage: int = 0,
        **kwargs,
    ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        # Abstract: subclasses implement the training forward pass.
        raise NotImplementedError

    def inference(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        enroll_x_lens: Union[torch.Tensor, None] = None,
        top_k: int = -100,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        # Abstract: subclasses implement autoregressive + NAR decoding.
        raise NotImplementedError

    def visualize(
        self,
        predicts: Tuple[torch.Tensor],
        batch: Dict[str, Union[List, torch.Tensor]],
        output_dir: str,
        limit: int = 4,
    ) -> None:
        # Abstract: subclasses implement prediction plotting.
        raise NotImplementedError
377
+
378
+
379
+ class VALLE(VALLF):
380
+ """It implements https://arxiv.org/abs/2301.02111
381
+ "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
382
+ """
383
+
384
+ def __init__(
385
+ self,
386
+ d_model: int,
387
+ nhead: int,
388
+ num_layers: int,
389
+ norm_first: bool = True,
390
+ add_prenet: bool = False,
391
+ prefix_mode: int = 0,
392
+ share_embedding: bool = True,
393
+ nar_scale_factor: float = 1.0,
394
+ **kwargs,
395
+ ):
396
+ """
397
+ Args:
398
+ d_model:
399
+ The number of expected features in the input (required).
400
+ nhead:
401
+ The number of heads in the multiheadattention models (required).
402
+ num_layers:
403
+ The number of sub-decoder-layers in the decoder (required).
404
+ """
405
+ super(VALLE, self).__init__(
406
+ d_model,
407
+ nhead,
408
+ num_layers,
409
+ norm_first=norm_first,
410
+ add_prenet=add_prenet,
411
+ decoder_cls=TransformerEncoder,
412
+ decoder_layer_cls=TransformerEncoderLayer,
413
+ prefix_mode=prefix_mode,
414
+ share_embedding=share_embedding,
415
+ nar_scale_factor=nar_scale_factor,
416
+ **kwargs,
417
+ )
418
+ self.language_ID = {
419
+ "en": 0,
420
+ "zh": 1,
421
+ "ja": 2,
422
+ }
423
+ self.ar_language_embedding = TokenEmbedding(d_model, len(self.language_ID))
424
+ self.nar_language_embedding = TokenEmbedding(d_model, len(self.language_ID))
425
+
426
+ def forward(
427
+ self,
428
+ x: torch.Tensor,
429
+ x_lens: torch.Tensor,
430
+ y: Union[torch.Tensor, PromptedFeatures],
431
+ y_lens: Union[torch.Tensor, PromptedFeatures],
432
+ reduction: str = "sum",
433
+ train_stage: int = 0,
434
+ **kwargs,
435
+ ):
436
+ raise NotImplementedError
437
+
438
+ def inference(
439
+ self,
440
+ x: torch.Tensor,
441
+ x_lens: torch.Tensor,
442
+ y: torch.Tensor,
443
+ enroll_x_lens: torch.Tensor,
444
+ top_k: int = -100,
445
+ temperature: float = 1.0,
446
+ prompt_language: str = None,
447
+ text_language: str = None,
448
+ best_of: int = 1,
449
+ length_penalty: float = 1.0,
450
+ return_worst: bool = False,
451
+ ) -> torch.Tensor:
452
+ """
453
+ Args:
454
+ x:
455
+ A 2-D tensor of shape (1, S).
456
+ x_lens:
457
+ A 1-D tensor of shape (1,). It contains the number of tokens in `x`
458
+ before padding.
459
+ y:
460
+ A 3-D tensor of shape (1, T, 8).
461
+ top_k: (`optional`) int
462
+ The number of highest probability tokens to keep for top-k-filtering. Default to -100.
463
+ temperature: (`optional`) float
464
+ The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
465
+ Returns:
466
+ Return the predicted audio code matrix.
467
+ """
468
+ assert x.ndim == 2, x.shape
469
+ assert x_lens.ndim == 1, x_lens.shape
470
+ assert y.ndim == 3, y.shape
471
+ assert y.shape[0] == 1, y.shape
472
+
473
+ assert torch.all(x_lens > 0)
474
+
475
+ # NOTE: x has been padded in TextTokenCollater
476
+ text = x
477
+ x = self.ar_text_embedding(text)
478
+ # Add language embedding
479
+ prompt_language_id = torch.LongTensor(
480
+ np.array([self.language_ID[prompt_language]])
481
+ ).to(x.device)
482
+ if isinstance(text_language, str):
483
+ text_language_id = torch.LongTensor(
484
+ np.array([self.language_ID[text_language]])
485
+ ).to(x.device)
486
+ elif isinstance(text_language, List):
487
+ text_language_id = torch.LongTensor(
488
+ np.array([self.language_ID[tl] for tl in text_language])
489
+ ).to(x.device)
490
+ x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
491
+ x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
492
+ x = self.ar_text_prenet(x)
493
+ x = self.ar_text_position(x)
494
+
495
+ text_len = x_lens.max()
496
+ prompts = y
497
+ prefix_len = y.shape[1]
498
+
499
+ # AR Decoder
500
+ # TODO: Managing decoder steps avoid repetitive computation
501
+ y = prompts[..., 0]
502
+ if self.ar_audio_prepend_bos:
503
+ y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1)
504
+
505
+ x_len = x_lens.max()
506
+ x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
507
+
508
+ kv_cache = None
509
+ use_kv_caching = True
510
+
511
+ sum_logprobs = torch.zeros(
512
+ best_of, device=y.device
513
+ ) # implement batch decoding here
514
+ x = x.repeat(best_of, 1, 1)
515
+ y = y.repeat(best_of, 1)
516
+ while True:
517
+ y_emb = self.ar_audio_embedding(y)
518
+ y_emb = self.ar_audio_prenet(y_emb)
519
+ y_pos = self.ar_audio_position(y_emb)
520
+ xy_pos = torch.concat([x, y_pos], dim=1)
521
+
522
+ y_len = y.shape[1]
523
+ x_attn_mask_pad = F.pad(
524
+ x_attn_mask,
525
+ (0, y_len),
526
+ value=True,
527
+ )
528
+ y_attn_mask = F.pad(
529
+ torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
530
+ (x_len, 0),
531
+ value=False,
532
+ )
533
+ xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
534
+ y.device
535
+ )
536
+
537
+ if use_kv_caching and kv_cache is not None:
538
+ xy_pos = xy_pos[:, [-1]]
539
+ else:
540
+ pass
541
+
542
+ xy_dec, kv_cache = self.ar_decoder.infer(
543
+ xy_pos,
544
+ mask=xy_attn_mask,
545
+ past_kv=kv_cache,
546
+ use_cache=use_kv_caching,
547
+ )
548
+ # xy_dec, _ = self.ar_decoder(
549
+ # (xy_pos, None),
550
+ # mask=xy_attn_mask,
551
+ # )
552
+
553
+ logits = self.ar_predict_layer(xy_dec[:, -1])
554
+ samples, current_logprobs = topk_sampling(
555
+ logits, top_k=top_k, top_p=1, temperature=temperature
556
+ )
557
+ sum_logprobs += current_logprobs * (y[:, -1] != NUM_AUDIO_TOKENS)
558
+ samples[y[:, -1] == NUM_AUDIO_TOKENS] = NUM_AUDIO_TOKENS
559
+ completed = (samples[:, -1] == NUM_AUDIO_TOKENS).all()
560
+ if completed or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16:
561
+ if prompts.shape[1] == y.shape[1]:
562
+ raise SyntaxError("well trained model shouldn't reach here.")
563
+ lengths = torch.sum(y != NUM_AUDIO_TOKENS, dim=1)
564
+ avg_logprobs = sum_logprobs / lengths**length_penalty
565
+ # choose the best beam according to sum_logprobs
566
+ best_beam = y[torch.argmax(avg_logprobs), :]
567
+ worst_beam = y[torch.argmin(avg_logprobs), :]
568
+ # strip all eos tokens
569
+ best_beam = best_beam[best_beam != NUM_AUDIO_TOKENS]
570
+ worst_beam = worst_beam[worst_beam != NUM_AUDIO_TOKENS]
571
+ if return_worst:
572
+ y = worst_beam.unsqueeze(0)
573
+ else:
574
+ y = best_beam.unsqueeze(0)
575
+ print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
576
+ break
577
+
578
+ y = torch.concat([y, samples], dim=1)
579
+
580
+ codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
581
+ if self.num_quantizers == 1:
582
+ return torch.stack(codes, dim=-1)
583
+
584
+ # Non-AR Decoders
585
+ y_emb = self.nar_audio_embeddings[0](y[:, int(self.ar_audio_prepend_bos) :])
586
+
587
+ if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes
588
+ enrolled_len = enroll_x_lens.max().item()
589
+ # SOS + Synthesis Text + EOS
590
+ text = torch.concat(
591
+ [
592
+ text[:, :1],
593
+ text[:, enrolled_len - 1 :],
594
+ ],
595
+ dim=1,
596
+ )
597
+ text_len = text_len - (enrolled_len - 2)
598
+ assert text.shape[0] == 1
599
+
600
+ x = self.nar_text_embedding(text)
601
+ # Add language embedding
602
+ prompt_language_id = torch.LongTensor(
603
+ np.array([self.language_ID[prompt_language]])
604
+ ).to(x.device)
605
+ if isinstance(text_language, str):
606
+ text_language_id = torch.LongTensor(
607
+ np.array([self.language_ID[text_language]])
608
+ ).to(x.device)
609
+ elif isinstance(text_language, List):
610
+ text_language_id = torch.LongTensor(
611
+ np.array([self.language_ID[tl] for tl in text_language])
612
+ ).to(x.device)
613
+ x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
614
+ x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
615
+ x = self.nar_text_prenet(x)
616
+ x = self.nar_text_position(x)
617
+
618
+ if self.prefix_mode == 0:
619
+ for i, (predict_layer, embedding_layer) in enumerate(
620
+ zip(
621
+ self.nar_predict_layers,
622
+ self.nar_audio_embeddings[1:],
623
+ )
624
+ ):
625
+ y_pos = self.nar_audio_prenet(y_emb)
626
+ y_pos = self.nar_audio_position(y_pos)
627
+ xy_pos = torch.concat([x, y_pos], dim=1)
628
+
629
+ xy_dec, _ = self.nar_decoder(
630
+ (xy_pos, self.nar_stage_embeddings[i].weight)
631
+ )
632
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
633
+
634
+ samples = torch.argmax(logits, dim=-1)
635
+ codes.append(samples)
636
+
637
+ if i < self.num_quantizers - 2:
638
+ y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
639
+ y_emb[:, prefix_len:] += embedding_layer(samples)
640
+ else:
641
+ for j in range(1, self.num_quantizers):
642
+ y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])
643
+
644
+ for i, (predict_layer, embedding_layer) in enumerate(
645
+ zip(
646
+ self.nar_predict_layers,
647
+ self.nar_audio_embeddings[1:],
648
+ )
649
+ ):
650
+ y_pos = self.nar_audio_prenet(y_emb)
651
+ y_pos = self.nar_audio_position(y_pos)
652
+ xy_pos = torch.concat([x, y_pos], dim=1)
653
+
654
+ xy_dec, _ = self.nar_decoder(
655
+ (xy_pos, self.nar_stage_embeddings[i].weight)
656
+ )
657
+ logits = predict_layer(xy_dec[:, text_len + prefix_len :])
658
+
659
+ samples = torch.argmax(logits, dim=-1)
660
+ codes.append(samples)
661
+
662
+ if i < self.num_quantizers - 2:
663
+ y_emb[:, prefix_len:] += embedding_layer(samples)
664
+
665
+ assert len(codes) == self.num_quantizers
666
+ return torch.stack(codes, dim=-1)
667
+
668
def continual(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    """
    Continuation synthesis: keep the first part of `y` (up to half its
    length, capped at 3 seconds at 75 frames/s) as the acoustic prompt and
    re-predict the remaining quantizer stages for the continuation.

    Args:
      x:
        A 2-D tensor of shape (1, S). Padded text tokens.
      x_lens:
        A 1-D tensor of shape (1,). It contains the number of tokens in `x`
        before padding.
      y:
        A 3-D tensor of shape (1, T, 8). Audio codes from all 8 quantizers.
    Returns:
      Return the predicted audio code matrix, shape (1, T', 8).
    """
    assert x.ndim == 2, x.shape
    assert x_lens.ndim == 1, x_lens.shape
    assert y.ndim == 3, y.shape
    assert y.shape[0] == 1, y.shape

    assert torch.all(x_lens > 0)
    # This entry point is hard-wired to the 8-quantizer codec setup.
    assert self.num_quantizers == 8

    # NOTE: x has been padded in TextTokenCollater
    text = x
    x = self.ar_text_embedding(text)
    x = self.ar_text_prenet(x)
    x = self.ar_text_position(x)

    text_len = x_lens.max()

    # Prompt: at most half of y, capped at 3 seconds (75 frames per second).
    prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)

    # AR Decoder
    prompts = y[:, :prefix_len]

    # Stage-1 codes of the continuation are taken directly from y.
    codes = [y[:, prefix_len:, 0]]
    # Non-AR Decoders
    x = self.nar_text_embedding(text)
    x = self.nar_text_prenet(x)
    x = self.nar_text_position(x)

    y_emb = self.nar_audio_embeddings[0](y[..., 0])

    if self.prefix_mode == 0:
        for i, (predict_layer, embedding_layer) in enumerate(
            zip(
                self.nar_predict_layers,
                self.nar_audio_embeddings[1:],
            )
        ):
            # Fix: apply the prenet before the positional encoding, matching
            # the order used everywhere else (inference() and the else-branch
            # below); the original had these two calls swapped here.
            y_pos = self.nar_audio_prenet(y_emb)
            y_pos = self.nar_audio_position(y_pos)
            xy_pos = torch.concat([x, y_pos], dim=1)

            xy_dec, _ = self.nar_decoder(
                (xy_pos, self.nar_stage_embeddings[i].weight)
            )
            logits = predict_layer(xy_dec[:, text_len + prefix_len :])

            samples = torch.argmax(logits, dim=-1)
            codes.append(samples)

            # All but the last stage feed their samples (and the matching
            # prompt codes) back into the running embedding sum.
            if i < self.num_quantizers - 2:
                y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
                y_emb[:, prefix_len:] += embedding_layer(samples)
    else:
        # Prompt region: accumulate the ground-truth embeddings of all
        # remaining stages up-front.
        for j in range(1, self.num_quantizers):
            y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])

        for i, (predict_layer, embedding_layer) in enumerate(
            zip(
                self.nar_predict_layers,
                self.nar_audio_embeddings[1:],
            )
        ):
            y_pos = self.nar_audio_prenet(y_emb)
            y_pos = self.nar_audio_position(y_pos)
            xy_pos = torch.concat([x, y_pos], dim=1)

            xy_dec, _ = self.nar_decoder(
                (xy_pos, self.nar_stage_embeddings[i].weight)
            )
            logits = predict_layer(xy_dec[:, text_len + prefix_len :])

            samples = torch.argmax(logits, dim=-1)
            codes.append(samples)

            if i < self.num_quantizers - 2:
                y_emb[:, prefix_len:] += embedding_layer(samples)

    assert len(codes) == self.num_quantizers
    return torch.stack(codes, dim=-1)
764
+
765
+
766
+ # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
767
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Mask a logits distribution in place using top-k and/or nucleus (top-p)
    filtering, setting rejected entries to ``filter_value``.

    Args:
        logits: distribution of shape (batch size, vocabulary size).
        top_k: if > 0, keep only the k highest-probability tokens.
        top_p: if < 1.0, keep the smallest set of tokens whose cumulative
            probability reaches top_p (Holtzman et al.,
            http://arxiv.org/abs/1904.09751).
        min_tokens_to_keep: lower bound on surviving tokens per example.

    Returns:
        The same ``logits`` tensor, mutated in place.
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        # Clamp k into [min_tokens_to_keep, vocab_size] for safety.
        k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        kth_best = torch.topk(logits, k)[0][..., -1, None]
        # Everything strictly below the k-th best logit is discarded.
        logits[logits < kth_best] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Tokens past the nucleus threshold get dropped (zero-prob ones stay).
        drop_sorted = cumulative > top_p
        if min_tokens_to_keep > 1:
            # Guarantee min_tokens_to_keep survivors (the shift below adds one).
            drop_sorted[..., :min_tokens_to_keep] = 0
        # Shift right so the first token crossing the threshold is retained.
        drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
        drop_sorted[..., 0] = 0

        # Map the sorted mask back onto the original vocabulary order.
        drop = drop_sorted.scatter(1, sorted_indices, drop_sorted)
        logits[drop] = filter_value
    return logits
804
+
805
+
806
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    """Sample one token per row after temperature scaling and top-k/top-p
    filtering.

    Args:
        logits: (batch, vocab) unnormalized scores.
        temperature: must be strictly positive; > 1 flattens the
            distribution, < 1 sharpens it. Default 1.0.
        top_k: number of highest-probability tokens kept. Default 50-ish
            upstream; 10 here.
        top_p: nucleus-sampling cumulative-probability cutoff in [0, 1].

    Returns:
        (token, current_logprobs): sampled token ids of shape (batch, 1)
        and the log-probability of each sampled token, shape (batch,).
    """
    if temperature != 1.0:
        logits = logits / temperature
    # Restrict the candidate set before sampling.
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
    # Log-prob of the chosen token, computed in float32 for stability.
    logprobs = F.log_softmax(logits.float(), dim=-1)
    rows = torch.arange(logprobs.shape[0])
    current_logprobs = logprobs[rows, token.squeeze(1)]
    return token, current_logprobs
apps/audio_cloning/vallex/models/visualizer.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2023 (authors: Feiteng Li)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ from typing import Dict, List, Tuple, Union
20
+
21
+ import matplotlib.pyplot as plt
22
+ import numpy as np
23
+ import torch
24
+
25
+
26
def visualize(
    predicts: Tuple[torch.Tensor],
    batch: Dict[str, Union[List, torch.Tensor]],
    output_dir: str,
    limit: int = 4,
) -> None:
    """Save per-utterance diagnostic heatmaps to ``{output_dir}/{utt_id}.png``.

    Each figure stacks three panels: encoder output, decoder output, and the
    decoder target, with a red vertical line at the valid-length boundary.
    At most ``limit`` utterances from the batch are rendered.
    """

    def _to_numpy(t):
        # Detach from the graph and move to host memory for plotting.
        return t.to("cpu").detach().numpy()

    def _panel(index, matrix, boundary, xlabel, title=None, clim=None):
        # Render one of the three stacked subplots.
        plt.subplot(3, 1, index)
        if title is not None:
            plt.title(title)
        extra = {} if clim is None else {"vmin": clim[0], "vmax": clim[1]}
        plt.imshow(
            X=np.transpose(matrix),
            cmap=plt.get_cmap("jet"),
            aspect="auto",
            interpolation="nearest",
            **extra,
        )
        plt.gca().invert_yaxis()
        plt.axvline(x=boundary - 0.4, linewidth=2, color="r")
        plt.xlabel(xlabel)
        plt.colorbar()

    text_tokens = _to_numpy(batch["text_tokens"])
    text_tokens_lens = _to_numpy(batch["text_tokens_lens"])
    audio_features = _to_numpy(batch["audio_features"])
    audio_features_lens = _to_numpy(batch["audio_features_lens"])
    assert text_tokens.ndim == 2

    utt_ids, texts = batch["utt_id"], batch["text"]

    encoder_outputs = _to_numpy(predicts[0].type(torch.float32))
    decoder_outputs = predicts[1]
    if isinstance(decoder_outputs, list):
        # Multi-stage decoders return a list; visualize the final stage.
        decoder_outputs = decoder_outputs[-1]
    decoder_outputs = _to_numpy(decoder_outputs.type(torch.float32))

    clim = (0, 1024)  # Encodec code-index range
    # NOTE(review): decoder_outputs was just cast to float32, so this branch
    # always fires and the Fbank color range (-6, 0) is always used —
    # preserved as-is from the original; confirm intent before changing.
    if decoder_outputs.dtype == np.float32:
        clim = (-6, 0)  # Fbank

    for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
        _ = plt.figure(figsize=(14, 8 * 3))

        S = text_tokens_lens[b]
        T = audio_features_lens[b]

        _panel(1, encoder_outputs[b], S, "Encoder Output", title=f"Text: {text}")
        _panel(2, decoder_outputs[b], T, "Decoder Output", clim=clim)
        _panel(3, audio_features[b], T, "Decoder Target", clim=clim)

        plt.savefig(f"{output_dir}/{utt_id}.png")
        plt.close()
+ plt.close()
apps/audio_cloning/vallex/modules/__init__.py ADDED
File without changes
apps/audio_cloning/vallex/modules/activation.py ADDED
@@ -0,0 +1,612 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, List
2
+ import math
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from torch.nn import Linear, Module
7
+ from torch.nn import functional as F
8
+ from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
9
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
10
+ from torch.nn.parameter import Parameter
11
+
12
+ def _in_projection_packed(
13
+ q: Tensor,
14
+ k: Tensor,
15
+ v: Tensor,
16
+ w: Tensor,
17
+ b: Optional[Tensor] = None,
18
+ ) -> List[Tensor]:
19
+ r"""
20
+ Performs the in-projection step of the attention operation, using packed weights.
21
+ Output is a triple containing projection tensors for query, key and value.
22
+
23
+ Args:
24
+ q, k, v: query, key and value tensors to be projected. For self-attention,
25
+ these are typically the same tensor; for encoder-decoder attention,
26
+ k and v are typically the same tensor. (We take advantage of these
27
+ identities for performance if they are present.) Regardless, q, k and v
28
+ must share a common embedding dimension; otherwise their shapes may vary.
29
+ w: projection weights for q, k and v, packed into a single tensor. Weights
30
+ are packed along dimension 0, in q, k, v order.
31
+ b: optional projection biases for q, k and v, packed into a single tensor
32
+ in q, k, v order.
33
+
34
+ Shape:
35
+ Inputs:
36
+ - q: :math:`(..., E)` where E is the embedding dimension
37
+ - k: :math:`(..., E)` where E is the embedding dimension
38
+ - v: :math:`(..., E)` where E is the embedding dimension
39
+ - w: :math:`(E * 3, E)` where E is the embedding dimension
40
+ - b: :math:`E * 3` where E is the embedding dimension
41
+
42
+ Output:
43
+ - in output list :math:`[q', k', v']`, each output tensor will have the
44
+ same shape as the corresponding input tensor.
45
+ """
46
+ E = q.size(-1)
47
+ if k is v:
48
+ if q is k:
49
+ # self-attention
50
+ return F.linear(q, w, b).chunk(3, dim=-1)
51
+ else:
52
+ # encoder-decoder attention
53
+ w_q, w_kv = w.split([E, E * 2])
54
+ if b is None:
55
+ b_q = b_kv = None
56
+ else:
57
+ b_q, b_kv = b.split([E, E * 2])
58
+ return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
59
+ else:
60
+ w_q, w_k, w_v = w.chunk(3)
61
+ if b is None:
62
+ b_q = b_k = b_v = None
63
+ else:
64
+ b_q, b_k, b_v = b.chunk(3)
65
+ return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
66
+
67
+ def _scaled_dot_product_attention(
68
+ q: Tensor,
69
+ k: Tensor,
70
+ v: Tensor,
71
+ attn_mask: Optional[Tensor] = None,
72
+ dropout_p: float = 0.0,
73
+ ) -> Tuple[Tensor, Tensor]:
74
+ r"""
75
+ Computes scaled dot product attention on query, key and value tensors, using
76
+ an optional attention mask if passed, and applying dropout if a probability
77
+ greater than 0.0 is specified.
78
+ Returns a tensor pair containing attended values and attention weights.
79
+
80
+ Args:
81
+ q, k, v: query, key and value tensors. See Shape section for shape details.
82
+ attn_mask: optional tensor containing mask values to be added to calculated
83
+ attention. May be 2D or 3D; see Shape section for details.
84
+ dropout_p: dropout probability. If greater than 0.0, dropout is applied.
85
+
86
+ Shape:
87
+ - q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
88
+ and E is embedding dimension.
89
+ - key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
90
+ and E is embedding dimension.
91
+ - value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
92
+ and E is embedding dimension.
93
+ - attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
94
+ shape :math:`(Nt, Ns)`.
95
+
96
+ - Output: attention values have shape :math:`(B, Nt, E)`; attention weights
97
+ have shape :math:`(B, Nt, Ns)`
98
+ """
99
+ B, Nt, E = q.shape
100
+ q = q / math.sqrt(E)
101
+ # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
102
+ if attn_mask is not None:
103
+ attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
104
+ else:
105
+ attn = torch.bmm(q, k.transpose(-2, -1))
106
+
107
+ attn = F.softmax(attn, dim=-1)
108
+ if dropout_p > 0.0:
109
+ attn = F.dropout(attn, p=dropout_p)
110
+ # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
111
+ output = torch.bmm(attn, v)
112
+ return output, attn
113
+
114
def multi_head_attention_forward(
    x,
    ipw,
    ipb,
    opw,
    opb,
    n_head,
    attn_mask,
    past_kv=None,
    use_cache=False,
):
    """Batch-first multi-head self-attention with optional KV caching.

    Args:
        x: input of shape (B, T, C).
        ipw, ipb: packed q/k/v in-projection weight (3C, C) and bias (3C,).
        opw, opb: output-projection weight (C, C) and bias (C,).
        n_head: number of attention heads (C must be divisible by n_head).
        attn_mask: boolean mask; True entries are blocked. Indexed as
            attn_mask[FULL_T - T : FULL_T, :FULL_T], so it must cover the
            total (cached + new) sequence length.
        past_kv: optional (key, value) tensors from previous steps, each
            (B, nh, T_past, hs), concatenated in front of the new k/v.
        use_cache: when True (identity check), also return (k, v) as
            ``present`` for reuse on the next step.

    Returns:
        (y, present): y of shape (B, T, C); present is (k, v) or None.
    """
    B, T, C = x.size()
    head_dim = C // n_head

    # Packed in-projection, then split into per-head (B, nh, T, hs) layout.
    q, k, v = F.linear(x, ipw, ipb).chunk(3, dim=-1)
    q = q.view(B, T, n_head, head_dim).transpose(1, 2)
    k = k.view(B, T, n_head, head_dim).transpose(1, 2)
    v = v.view(B, T, n_head, head_dim).transpose(1, 2)

    if past_kv is not None:
        # Prepend cached keys/values along the time axis.
        k = torch.cat((past_kv[0], k), dim=-2)
        v = torch.cat((past_kv[1], v), dim=-2)

    FULL_T = k.shape[-2]
    present = (k, v) if use_cache is True else None

    # Scaled scores, mask restricted to the rows for the new queries.
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = att.masked_fill(attn_mask[FULL_T - T : FULL_T, :FULL_T], float("-inf"))
    att = F.softmax(att, dim=-1)

    # (B, nh, T, FULL_T) x (B, nh, FULL_T, hs) -> (B, nh, T, hs),
    # then merge heads back side by side.
    y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
    return (F.linear(y, opw, opb), present)
168
+
169
+
170
+ class MultiheadAttention(Module):
171
+ r"""Allows the model to jointly attend to information
172
+ from different representation subspaces as described in the paper:
173
+ `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
174
+
175
+ Multi-Head Attention is defined as:
176
+
177
+ .. math::
178
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
179
+
180
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
181
+
182
+ ``forward()`` will use a special optimized implementation if all of the following
183
+ conditions are met:
184
+
185
+ - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
186
+ restriction will be loosened in the future.)
187
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
188
+ - training is disabled (using ``.eval()``)
189
+ - dropout is 0
190
+ - ``add_bias_kv`` is ``False``
191
+ - ``add_zero_attn`` is ``False``
192
+ - ``batch_first`` is ``True`` and the input is batched
193
+ - ``kdim`` and ``vdim`` are equal to ``embed_dim``
194
+ - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
195
+ - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
196
+ nor ``attn_mask`` is passed
197
+
198
+ If the optimized implementation is in use, a
199
+ `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
200
+ ``query``/``key``/``value`` to represent padding more efficiently than using a
201
+ padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
202
+ will be returned, and an additional speedup proportional to the fraction of the input
203
+ that is padding can be expected.
204
+
205
+ Args:
206
+ embed_dim: Total dimension of the model.
207
+ num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
208
+ across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
209
+ dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
210
+ bias: If specified, adds bias to input / output projection layers. Default: ``True``.
211
+ add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
212
+ add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
213
+ Default: ``False``.
214
+ kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
215
+ vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
216
+ batch_first: If ``True``, then the input and output tensors are provided
217
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
218
+
219
+ Examples::
220
+
221
+ >>> # xdoctest: +SKIP
222
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
223
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
224
+
225
+ """
226
+ __constants__ = ["batch_first"]
227
+ bias_k: Optional[torch.Tensor]
228
+ bias_v: Optional[torch.Tensor]
229
+
230
+ def __init__(
231
+ self,
232
+ embed_dim,
233
+ num_heads,
234
+ dropout=0.0,
235
+ bias=True,
236
+ add_bias_kv=False,
237
+ add_zero_attn=False,
238
+ kdim=None,
239
+ vdim=None,
240
+ batch_first=False,
241
+ linear1_cls=Linear,
242
+ linear2_cls=Linear,
243
+ device=None,
244
+ dtype=None,
245
+ ) -> None:
246
+ factory_kwargs = {"device": device, "dtype": dtype}
247
+ super(MultiheadAttention, self).__init__()
248
+ self.embed_dim = embed_dim
249
+ self.kdim = kdim if kdim is not None else embed_dim
250
+ self.vdim = vdim if vdim is not None else embed_dim
251
+ self._qkv_same_embed_dim = (
252
+ self.kdim == embed_dim and self.vdim == embed_dim
253
+ )
254
+
255
+ self.num_heads = num_heads
256
+ self.dropout = dropout
257
+ self.batch_first = batch_first
258
+ self.head_dim = embed_dim // num_heads
259
+ assert (
260
+ self.head_dim * num_heads == self.embed_dim
261
+ ), "embed_dim must be divisible by num_heads"
262
+
263
+ if add_bias_kv:
264
+ self.bias_k = Parameter(
265
+ torch.empty((1, 1, embed_dim), **factory_kwargs)
266
+ )
267
+ self.bias_v = Parameter(
268
+ torch.empty((1, 1, embed_dim), **factory_kwargs)
269
+ )
270
+ else:
271
+ self.bias_k = self.bias_v = None
272
+
273
+ if linear1_cls == Linear:
274
+ if not self._qkv_same_embed_dim:
275
+ self.q_proj_weight = Parameter(
276
+ torch.empty((embed_dim, embed_dim), **factory_kwargs)
277
+ )
278
+ self.k_proj_weight = Parameter(
279
+ torch.empty((embed_dim, self.kdim), **factory_kwargs)
280
+ )
281
+ self.v_proj_weight = Parameter(
282
+ torch.empty((embed_dim, self.vdim), **factory_kwargs)
283
+ )
284
+ self.register_parameter("in_proj_weight", None)
285
+ else:
286
+ self.in_proj_weight = Parameter(
287
+ torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
288
+ )
289
+ self.register_parameter("q_proj_weight", None)
290
+ self.register_parameter("k_proj_weight", None)
291
+ self.register_parameter("v_proj_weight", None)
292
+
293
+ if bias:
294
+ self.in_proj_bias = Parameter(
295
+ torch.empty(3 * embed_dim, **factory_kwargs)
296
+ )
297
+ else:
298
+ self.register_parameter("in_proj_bias", None)
299
+ self.out_proj = NonDynamicallyQuantizableLinear(
300
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
301
+ )
302
+
303
+ self._reset_parameters()
304
+ else:
305
+ if not self._qkv_same_embed_dim:
306
+ raise NotImplementedError
307
+ else:
308
+ self.in_proj_linear = linear1_cls(
309
+ embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
310
+ )
311
+ self.in_proj_weight = self.in_proj_linear.weight
312
+
313
+ self.register_parameter("q_proj_weight", None)
314
+ self.register_parameter("k_proj_weight", None)
315
+ self.register_parameter("v_proj_weight", None)
316
+
317
+ if bias:
318
+ self.in_proj_bias = self.in_proj_linear.bias
319
+ else:
320
+ self.register_parameter("in_proj_bias", None)
321
+
322
+ self.out_proj = linear2_cls(
323
+ embed_dim, embed_dim, bias=bias, **factory_kwargs
324
+ )
325
+
326
+ if self.bias_k is not None:
327
+ xavier_normal_(self.bias_k)
328
+ if self.bias_v is not None:
329
+ xavier_normal_(self.bias_v)
330
+
331
+ self.add_zero_attn = add_zero_attn
332
+
333
+ def _reset_parameters(self):
334
+ if self._qkv_same_embed_dim:
335
+ xavier_uniform_(self.in_proj_weight)
336
+ else:
337
+ xavier_uniform_(self.q_proj_weight)
338
+ xavier_uniform_(self.k_proj_weight)
339
+ xavier_uniform_(self.v_proj_weight)
340
+
341
+ if self.in_proj_bias is not None:
342
+ constant_(self.in_proj_bias, 0.0)
343
+ constant_(self.out_proj.bias, 0.0)
344
+
345
+ if self.bias_k is not None:
346
+ xavier_normal_(self.bias_k)
347
+ if self.bias_v is not None:
348
+ xavier_normal_(self.bias_v)
349
+
350
+ def __setstate__(self, state):
351
+ # Support loading old MultiheadAttention checkpoints generated by v1.1.0
352
+ if "_qkv_same_embed_dim" not in state:
353
+ state["_qkv_same_embed_dim"] = True
354
+
355
+ super(MultiheadAttention, self).__setstate__(state)
356
+
357
+ def forward(
358
+ self,
359
+ query: Tensor,
360
+ key: Tensor,
361
+ value: Tensor,
362
+ key_padding_mask: Optional[Tensor] = None,
363
+ need_weights: bool = True,
364
+ attn_mask: Optional[Tensor] = None,
365
+ average_attn_weights: bool = True,
366
+ ) -> Tuple[Tensor, Optional[Tensor]]:
367
+ r"""
368
+ Args:
369
+ query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
370
+ or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
371
+ :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
372
+ Queries are compared against key-value pairs to produce the output.
373
+ See "Attention Is All You Need" for more details.
374
+ key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
375
+ or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
376
+ :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
377
+ See "Attention Is All You Need" for more details.
378
+ value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
379
+ ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
380
+ sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
381
+ See "Attention Is All You Need" for more details.
382
+ key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
383
+ to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
384
+ Binary and byte masks are supported.
385
+ For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
386
+ the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
387
+ need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
388
+ Default: ``True``.
389
+ attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
390
+ :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
391
+ :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
392
+ broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
393
+ Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
394
+ corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
395
+ corresponding position is not allowed to attend. For a float mask, the mask values will be added to
396
+ the attention weight.
397
+ average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
398
+ heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
399
+ effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
400
+
401
+ Outputs:
402
+ - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
403
+ :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
404
+ where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
405
+ embedding dimension ``embed_dim``.
406
+ - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
407
+ returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
408
+ :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
409
+ :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
410
+ head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
411
+
412
+ .. note::
413
+ `batch_first` argument is ignored for unbatched inputs.
414
+ """
415
+ is_batched = query.dim() == 3
416
+ if key_padding_mask is not None:
417
+ _kpm_dtype = key_padding_mask.dtype
418
+ if _kpm_dtype != torch.bool and not torch.is_floating_point(
419
+ key_padding_mask
420
+ ):
421
+ raise AssertionError(
422
+ "only bool and floating types of key_padding_mask are supported"
423
+ )
424
+ why_not_fast_path = ""
425
+ if not is_batched:
426
+ why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
427
+ elif query is not key or key is not value:
428
+ # When lifting this restriction, don't forget to either
429
+ # enforce that the dtypes all match or test cases where
430
+ # they don't!
431
+ why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
432
+ elif (
433
+ self.in_proj_bias is not None
434
+ and query.dtype != self.in_proj_bias.dtype
435
+ ):
436
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
437
+ elif (
438
+ self.in_proj_weight is not None
439
+ and query.dtype != self.in_proj_weight.dtype
440
+ ):
441
+ # this case will fail anyway, but at least they'll get a useful error message.
442
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
443
+ elif self.training:
444
+ why_not_fast_path = "training is enabled"
445
+ elif not self.batch_first:
446
+ why_not_fast_path = "batch_first was not True"
447
+ elif self.bias_k is not None:
448
+ why_not_fast_path = "self.bias_k was not None"
449
+ elif self.bias_v is not None:
450
+ why_not_fast_path = "self.bias_v was not None"
451
+ elif self.dropout:
452
+ why_not_fast_path = f"dropout was {self.dropout}, required zero"
453
+ elif self.add_zero_attn:
454
+ why_not_fast_path = "add_zero_attn was enabled"
455
+ elif not self._qkv_same_embed_dim:
456
+ why_not_fast_path = "_qkv_same_embed_dim was not True"
457
+ elif attn_mask is not None:
458
+ why_not_fast_path = "attn_mask was not None"
459
+ elif query.is_nested and key_padding_mask is not None:
460
+ why_not_fast_path = (
461
+ "key_padding_mask is not supported with NestedTensor input"
462
+ )
463
+ elif self.num_heads % 2 == 1:
464
+ why_not_fast_path = "num_heads is odd"
465
+ elif torch.is_autocast_enabled():
466
+ why_not_fast_path = "autocast is enabled"
467
+
468
+ if not why_not_fast_path:
469
+ tensor_args = (
470
+ query,
471
+ key,
472
+ value,
473
+ self.in_proj_weight,
474
+ self.in_proj_bias,
475
+ self.out_proj.weight,
476
+ self.out_proj.bias,
477
+ )
478
+ # We have to use list comprehensions below because TorchScript does not support
479
+ # generator expressions.
480
+ if torch.overrides.has_torch_function(tensor_args):
481
+ why_not_fast_path = "some Tensor argument has_torch_function"
482
+ elif not all(
483
+ [
484
+ (x is None or x.is_cuda or "cpu" in str(x.device))
485
+ for x in tensor_args
486
+ ]
487
+ ):
488
+ why_not_fast_path = (
489
+ "some Tensor argument is neither CUDA nor CPU"
490
+ )
491
+ elif torch.is_grad_enabled() and any(
492
+ [x is not None and x.requires_grad for x in tensor_args]
493
+ ):
494
+ why_not_fast_path = (
495
+ "grad is enabled and at least one of query or the "
496
+ "input/output projection weights or biases requires_grad"
497
+ )
498
+ if not why_not_fast_path:
499
+ return torch._native_multi_head_attention(
500
+ query,
501
+ key,
502
+ value,
503
+ self.embed_dim,
504
+ self.num_heads,
505
+ self.in_proj_weight,
506
+ self.in_proj_bias,
507
+ self.out_proj.weight,
508
+ self.out_proj.bias,
509
+ key_padding_mask
510
+ if key_padding_mask is not None
511
+ else attn_mask,
512
+ need_weights,
513
+ average_attn_weights,
514
+ 1
515
+ if key_padding_mask is not None
516
+ else 0
517
+ if attn_mask is not None
518
+ else None,
519
+ )
520
+
521
+ any_nested = query.is_nested or key.is_nested or value.is_nested
522
+ assert not any_nested, (
523
+ "MultiheadAttention does not support NestedTensor outside of its fast path. "
524
+ + f"The fast path was not hit because {why_not_fast_path}"
525
+ )
526
+
527
+ if self.batch_first and is_batched:
528
+ # make sure that the transpose op does not affect the "is" property
529
+ if key is value:
530
+ if query is key:
531
+ query = key = value = query.transpose(1, 0)
532
+ else:
533
+ query, key = [x.transpose(1, 0) for x in (query, key)]
534
+ value = key
535
+ else:
536
+ query, key, value = [
537
+ x.transpose(1, 0) for x in (query, key, value)
538
+ ]
539
+
540
+ if not self._qkv_same_embed_dim:
541
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
542
+ query,
543
+ key,
544
+ value,
545
+ self.embed_dim,
546
+ self.num_heads,
547
+ self.in_proj_weight,
548
+ self.in_proj_bias,
549
+ self.bias_k,
550
+ self.bias_v,
551
+ self.add_zero_attn,
552
+ self.dropout,
553
+ self.out_proj.weight,
554
+ self.out_proj.bias,
555
+ training=self.training,
556
+ key_padding_mask=key_padding_mask,
557
+ need_weights=need_weights,
558
+ attn_mask=attn_mask,
559
+ use_separate_proj_weight=True,
560
+ q_proj_weight=self.q_proj_weight,
561
+ k_proj_weight=self.k_proj_weight,
562
+ v_proj_weight=self.v_proj_weight,
563
+ average_attn_weights=average_attn_weights,
564
+ )
565
+ else:
566
+ attn_output, attn_output_weights = F.multi_head_attention_forward(
567
+ query,
568
+ key,
569
+ value,
570
+ self.embed_dim,
571
+ self.num_heads,
572
+ self.in_proj_weight,
573
+ self.in_proj_bias,
574
+ self.bias_k,
575
+ self.bias_v,
576
+ self.add_zero_attn,
577
+ self.dropout,
578
+ self.out_proj.weight,
579
+ self.out_proj.bias,
580
+ training=self.training,
581
+ key_padding_mask=key_padding_mask,
582
+ need_weights=need_weights,
583
+ attn_mask=attn_mask,
584
+ average_attn_weights=average_attn_weights,
585
+ )
586
+ if self.batch_first and is_batched:
587
+ return attn_output.transpose(1, 0), attn_output_weights
588
+ else:
589
+ return attn_output, attn_output_weights
590
+
591
+ def infer(self,
592
+ x: Tensor,
593
+ key_padding_mask: Optional[Tensor] = None,
594
+ need_weights: bool = True,
595
+ attn_mask: Optional[Tensor] = None,
596
+ average_attn_weights: bool = True,
597
+ past_kv = None,
598
+ use_cache = False
599
+ ):
600
+ # x = x.transpose(1, 0)
601
+ y, kv = multi_head_attention_forward(
602
+ x=x,
603
+ ipw=self.in_proj_weight,
604
+ ipb=self.in_proj_bias,
605
+ opw=self.out_proj.weight,
606
+ opb=self.out_proj.bias,
607
+ n_head=self.num_heads,
608
+ attn_mask=attn_mask,
609
+ past_kv=past_kv,
610
+ use_cache=use_cache,
611
+ )
612
+ return (y, kv)
apps/audio_cloning/vallex/modules/embedding.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 (authors: Feiteng Li)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+
21
class TokenEmbedding(nn.Module):
    """Token-id -> embedding lookup with dropout applied to the looked-up vectors.

    Args:
        dim_model: Embedding dimension.
        vocab_size: Number of distinct token ids.
        dropout: Dropout probability applied to the embedded output.
    """

    def __init__(self, dim_model: int, vocab_size: int, dropout: float = 0.0):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim_model = dim_model

        self.dropout = torch.nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)

    @property
    def weight(self) -> torch.Tensor:
        """The raw embedding matrix of shape (vocab_size, dim_model)."""
        return self.word_embeddings.weight

    def embedding(self, index: int) -> torch.Tensor:
        """Return the single embedding row for `index`, keeping a leading dim of 1."""
        return self.word_embeddings.weight[index : index + 1]

    def forward(self, x: torch.Tensor):
        """Look up embeddings for token ids `x` and apply dropout."""
        return self.dropout(self.word_embeddings(x))
48
+
49
+
50
class SinePositionalEmbedding(nn.Module):
    """Additive sinusoidal positional encoding (Vaswani et al. style).

    The sin/cos table is cached in ``self.pe`` (a plain attribute, not a
    buffer) and grown on demand; it is pre-built for 4000 positions at
    construction time.

    Args:
        dim_model: Model/embedding dimension.
        dropout: Dropout probability applied to the summed output.
        scale: If True, multiply the input by sqrt(dim_model) before adding
            the positional term.
        alpha: If True, the scalar weight on the positional term is learnable;
            otherwise it is a fixed parameter initialized to 1.
    """

    def __init__(
        self,
        dim_model: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
    ):
        super().__init__()
        self.dim_model = dim_model
        self.x_scale = math.sqrt(dim_model) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)

        self.reverse = False
        self.pe = None
        # Pre-compute a 4000-position table up front.
        self.extend_pe(torch.tensor(0.0).expand(1, 4000))

    def extend_pe(self, x):
        """Ensure the cached table covers x.size(1) positions (rebuild if not)."""
        if self.pe is not None and self.pe.size(1) >= x.size(1):
            # Table already large enough; just align dtype/device with x.
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return
        num_pos = x.size(1)
        if self.reverse:
            position = torch.arange(
                num_pos - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(
                0, num_pos, dtype=torch.float32
            ).unsqueeze(1)
        # Geometric frequency progression over the even indices.
        div_term = torch.exp(
            torch.arange(0, self.dim_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.dim_model)
        )
        table = torch.zeros(num_pos, self.dim_model)
        table[:, 0::2] = torch.sin(position * div_term)
        table[:, 1::2] = torch.cos(position * div_term)
        self.pe = table.unsqueeze(0).to(device=x.device, dtype=x.dtype).detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add (scaled) positional encodings to x; 2-D inputs get a trailing dim."""
        self.extend_pe(x)
        out = x.unsqueeze(-1) if x.ndim == 2 else x
        out = out * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(out)
apps/audio_cloning/vallex/modules/optim.py ADDED
@@ -0,0 +1,1105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import contextlib
18
+ import logging
19
+ import random
20
+ from collections import defaultdict
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ from lhotse.utils import fix_random_seed
25
+ from torch import Tensor
26
+ from torch.optim import Optimizer
27
+
28
+
29
class BatchedOptimizer(Optimizer):
    """
    This class adds to class Optimizer the capability to optimize parameters in batches:
    it will stack the parameters and their grads for you so the optimizer can work
    on tensors with an extra leading dimension.  This is intended for speed with GPUs,
    as it reduces the number of kernels launched in the optimizer.

    Args:
      params: iterable of parameters (or dicts), as for torch.optim.Optimizer
      defaults: dict of default hyperparameter values, as for torch.optim.Optimizer
    """

    def __init__(self, params, defaults):
        super(BatchedOptimizer, self).__init__(params, defaults)

    @contextlib.contextmanager
    def batched_params(self, param_group, group_params_names):
        """
        This function returns (technically, yields) a list of
          of tuples (p, state), where
        p is a `fake` parameter that is stacked (over axis 0) from real parameters
        that share the same shape, and its gradient is also stacked;
        `state` is the state corresponding to this batch of parameters
        (it will be physically located in the "state" for one of the real
        parameters, the last one that has any particular shape and dtype).

        This function is decorated as a context manager so that it can
        write parameters back to their "real" locations on exit.

        The idea is, instead of doing:
        <code>
          for p in group["params"]:
             state = self.state[p]
             ...
        </code>
        you can do:
        <code>
          with self.batched_params(group["params"]) as batches:
             for p, state, p_names in batches:
                 ...
        </code>

        Args:
          param_group: a parameter group, which is a list of parameters; should be
                one of self.param_groups.
          group_params_names: name for each parameter in param_group,
                which is List[str].
        """
        # Group parameters (and their names) by (dtype, shape) so each group
        # can be stacked into a single tensor.
        batches = defaultdict(
            list
        )  # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
        batches_names = defaultdict(
            list
        )  # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str

        assert len(param_group) == len(group_params_names)
        for p, named_p in zip(param_group, group_params_names):
            key = (str(p.dtype), *p.shape)
            batches[key].append(p)
            batches_names[key].append(named_p)

        # Put the batches in a deterministic order (sorted by key) so that
        # state save/load is reproducible across runs.
        batches_names_keys = list(batches_names.keys())
        sorted_idx = sorted(
            range(len(batches_names)), key=lambda i: batches_names_keys[i]
        )
        batches_names = [
            batches_names[batches_names_keys[idx]] for idx in sorted_idx
        ]
        batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]

        # tuples will contain (stacked_param, state, stacked_params_names),
        # one per batch.
        # NOTE: the original code also filled a `stacked_params_dict` here,
        # but it was keyed by a variable leaked from the grouping loop above
        # (always the last key) and was never read; it has been removed.
        tuples = []

        for batch, batch_names in zip(batches, batches_names):
            p = batch[0]
            # we arbitrarily store the state in the
            # state corresponding to the 1st parameter in the
            # group.  class Optimizer will take care of saving/loading state.
            state = self.state[p]
            p_stacked = torch.stack(batch)
            # Missing grads are replaced by zeros so callers never see None.
            grad = torch.stack(
                [
                    torch.zeros_like(q) if q.grad is None else q.grad
                    for q in batch
                ]
            )
            p_stacked.grad = grad
            tuples.append((p_stacked, state, batch_names))

        yield tuples  # <-- calling code will do the actual optimization here!

        # Write the (possibly modified) stacked values back into the real
        # parameters.
        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
            for i, p in enumerate(batch):  # batch is list of Parameter
                p.copy_(stacked_params[i])
128
+
129
+ class ScaledAdam(BatchedOptimizer):
130
+ """
131
+ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
132
+ proportional to the norm of that parameter; and also learn the scale of the parameter,
133
+ in log space, subject to upper and lower limits (as if we had factored each parameter as
134
+ param = underlying_param * log_scale.exp())
135
+
136
+
137
+ Args:
138
+ params: The parameters or param_groups to optimize (like other Optimizer subclasses)
139
+ lr: The learning rate. We will typically use a learning rate schedule that starts
140
+ at 0.03 and decreases over time, i.e. much higher than other common
141
+ optimizers.
142
+ clipping_scale: (e.g. 2.0)
143
+ A scale for gradient-clipping: if specified, the normalized gradients
144
+ over the whole model will be clipped to have 2-norm equal to
145
+ `clipping_scale` times the median 2-norm over the most recent period
146
+ of `clipping_update_period` minibatches. By "normalized gradients",
147
+ we mean after multiplying by the rms parameter value for this tensor
148
+ [for non-scalars]; this is appropriate because our update is scaled
149
+ by this quantity.
150
+ betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
151
+ Must satisfy 0 < beta <= beta2 < 1.
152
+ scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
153
+ scale of each parameter tensor and scalar parameters of the mode..
154
+ If each parameter were decomposed
155
+ as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
156
+ would be a the scaling factor on the learning rate of p_scale.
157
+ eps: A general-purpose epsilon to prevent division by zero
158
+ param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
159
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
160
+ parameter tensor to be >= this value)
161
+ param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
162
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
163
+ parameter tensor to be <= this value)
164
+ scalar_max: Maximum absolute value for scalar parameters (applicable if your
165
+ model has any parameters with numel() == 1).
166
+ size_update_period: The periodicity, in steps, with which we update the size (scale)
167
+ of the parameter tensor. This is provided to save a little time
168
+ in the update.
169
+ clipping_update_period: if clipping_scale is specified, this is the period
170
+ """
171
+
172
+ def __init__(
173
+ self,
174
+ params,
175
+ lr=3e-02,
176
+ clipping_scale=None,
177
+ betas=(0.9, 0.98),
178
+ scalar_lr_scale=0.1,
179
+ eps=1.0e-08,
180
+ param_min_rms=1.0e-05,
181
+ param_max_rms=3.0,
182
+ scalar_max=10.0,
183
+ size_update_period=4,
184
+ clipping_update_period=100,
185
+ parameters_names=None,
186
+ show_dominant_parameters=True,
187
+ ):
188
+
189
+ assert parameters_names is not None, (
190
+ "Please prepare parameters_names,"
191
+ "which is a List[List[str]]. Each List[str] is for a group"
192
+ "and each str is for a parameter"
193
+ )
194
+ defaults = dict(
195
+ lr=lr,
196
+ clipping_scale=clipping_scale,
197
+ betas=betas,
198
+ scalar_lr_scale=scalar_lr_scale,
199
+ eps=eps,
200
+ param_min_rms=param_min_rms,
201
+ param_max_rms=param_max_rms,
202
+ scalar_max=scalar_max,
203
+ size_update_period=size_update_period,
204
+ clipping_update_period=clipping_update_period,
205
+ )
206
+
207
+ super(ScaledAdam, self).__init__(params, defaults)
208
+ assert len(self.param_groups) == len(parameters_names)
209
+ self.parameters_names = parameters_names
210
+ self.show_dominant_parameters = show_dominant_parameters
211
+
212
    def __setstate__(self, state):
        # Delegate to the base Optimizer; ScaledAdam keeps no extra pickled
        # state that needs rebuilding here.
        super(ScaledAdam, self).__setstate__(state)
214
+
215
+ @torch.no_grad()
216
+ def step(self, closure=None):
217
+ """Performs a single optimization step.
218
+
219
+ Arguments:
220
+ closure (callable, optional): A closure that reevaluates the model
221
+ and returns the loss.
222
+ """
223
+ loss = None
224
+ if closure is not None:
225
+ with torch.enable_grad():
226
+ loss = closure()
227
+
228
+ batch = True
229
+
230
+ for group, group_params_names in zip(
231
+ self.param_groups, self.parameters_names
232
+ ):
233
+
234
+ with self.batched_params(
235
+ group["params"], group_params_names
236
+ ) as batches:
237
+
238
+ # batches is list of pairs (stacked_param, state). stacked_param is like
239
+ # a regular parameter, and will have a .grad, but the 1st dim corresponds to
240
+ # a stacking dim, it is not a real dim.
241
+
242
+ if (
243
+ len(batches[0][1]) == 0
244
+ ): # if len(first state) == 0: not yet initialized
245
+ clipping_scale = 1
246
+ else:
247
+ clipping_scale = self._get_clipping_scale(group, batches)
248
+
249
+ for p, state, _ in batches:
250
+ # Perform optimization step.
251
+ # grad is not going to be None, we handled that when creating the batches.
252
+ grad = p.grad
253
+ if grad.is_sparse:
254
+ raise RuntimeError(
255
+ "ScaledAdam optimizer does not support sparse gradients"
256
+ )
257
+ # State initialization
258
+ if len(state) == 0:
259
+ self._init_state(group, p, state)
260
+
261
+ self._step_one_batch(group, p, state, clipping_scale)
262
+
263
+ return loss
264
+
265
+ def _init_state(self, group: dict, p: Tensor, state: dict):
266
+ """
267
+ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
268
+ is actually the batch dimension, corresponding to batched-together
269
+ parameters of a given shape.
270
+
271
+
272
+ Args:
273
+ group: Dict to look up configuration values.
274
+ p: The parameter that we are initializing the state for
275
+ state: Dict from string to whatever state we are initializing
276
+ """
277
+ size_update_period = group["size_update_period"]
278
+
279
+ state["step"] = 0
280
+
281
+ kwargs = {"device": p.device, "dtype": p.dtype}
282
+
283
+ # 'delta' implements conventional momentum. There are
284
+ # several different kinds of update going on, so rather than
285
+ # compute "exp_avg" like in Adam, we store and decay a
286
+ # parameter-change "delta", which combines all forms of
287
+ # update. this is equivalent to how it's done in Adam,
288
+ # except for the first few steps.
289
+ state["delta"] = torch.zeros_like(
290
+ p, memory_format=torch.preserve_format
291
+ )
292
+
293
+ batch_size = p.shape[0]
294
+ numel = p.numel() // batch_size
295
+ numel = p.numel()
296
+
297
+ if numel > 1:
298
+ # "param_rms" just periodically records the scalar root-mean-square value of
299
+ # the parameter tensor.
300
+ # it has a shape like (batch_size, 1, 1, 1, 1)
301
+ param_rms = (
302
+ (p ** 2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
303
+ )
304
+ state["param_rms"] = param_rms
305
+
306
+ state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
307
+ state["scale_grads"] = torch.zeros(
308
+ size_update_period, *param_rms.shape, **kwargs
309
+ )
310
+
311
+ # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
312
+ state["exp_avg_sq"] = torch.zeros_like(
313
+ p, memory_format=torch.preserve_format
314
+ )
315
+
316
    def _get_clipping_scale(
        self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
    ) -> float:
        """
        Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
        by this amount before applying the rest of the update.

        Args:
           group: the parameter group, an item in self.param_groups
           tuples: a list of tuples of (param, state, param_names)
                where param is a batched set of parameters,
                with a .grad (1st dim is batch dim)
                and state is the state-dict where optimization parameters are kept.
                param_names is a List[str] while each str is name for a parameter
                in batched set of parameters "param".
        """
        assert len(tuples) >= 1
        clipping_scale = group["clipping_scale"]
        (first_p, first_state, _) = tuples[0]
        # The first batch's step counter stands in for the global step.
        step = first_state["step"]
        if clipping_scale is None or step == 0:
            # no clipping.  return early on step == 0 because the other
            # parameters' state won't have been initialized yet.
            return 1.0
        clipping_update_period = group["clipping_update_period"]

        # Accumulate the squared 2-norm of the (rms-scaled) gradients over all
        # batches; scaling by param_rms matches how the update itself is scaled.
        tot_sumsq = torch.tensor(0.0, device=first_p.device)
        for (p, state, param_names) in tuples:
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError(
                    "ScaledAdam optimizer does not support sparse gradients"
                )
            if p.numel() == p.shape[0]:  # a batch of scalars
                tot_sumsq += (
                    grad ** 2
                ).sum()  # sum() to change shape [1] to []
            else:
                tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()

        tot_norm = tot_sumsq.sqrt()
        if "model_norms" not in first_state:
            # NOTE(review): `p` here is leaked from the loop above (the last
            # batch); harmless since tuples is non-empty and all tensors share
            # a device in practice, but `first_p.device` would be clearer.
            first_state["model_norms"] = torch.zeros(
                clipping_update_period, device=p.device
            )
        # Record this step's norm into a circular buffer of recent norms.
        first_state["model_norms"][step % clipping_update_period] = tot_norm

        if step % clipping_update_period == 0:
            # Print some stats and refresh the clipping threshold from the
            # median of the recorded norms.
            # We don't reach here if step == 0 because we would have returned
            # above.
            sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
            quartiles = []
            for n in range(0, 5):
                index = min(
                    clipping_update_period - 1,
                    (clipping_update_period // 4) * n,
                )
                quartiles.append(sorted_norms[index].item())

            median = quartiles[2]
            threshold = clipping_scale * median
            first_state["model_norm_threshold"] = threshold
            percent_clipped = (
                first_state["num_clipped"] * 100.0 / clipping_update_period
                if "num_clipped" in first_state
                else 0.0
            )
            first_state["num_clipped"] = 0
            quartiles = " ".join(["%.3e" % x for x in quartiles])
            logging.info(
                f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
                f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
            )

        if step < clipping_update_period:
            return 1.0  # We have not yet estimated a norm to clip to.
        else:
            try:
                model_norm_threshold = first_state["model_norm_threshold"]
            except KeyError:
                logging.info(
                    "Warning: model_norm_threshold not in state: possibly "
                    "you changed config when restarting, adding clipping_scale option?"
                )
                return 1.0
            # Clip so the (scaled) grad norm does not exceed the threshold.
            ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
            if ans < 1.0:
                first_state["num_clipped"] += 1
            if ans < 0.1:
                logging.warn(
                    f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
                )
                if self.show_dominant_parameters:
                    assert p.shape[0] == len(param_names)
                    self._show_gradient_dominating_parameter(tuples, tot_sumsq)
            return ans
413
+
414
    def _show_gradient_dominating_parameter(
        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
    ):
        """
        Show information of the parameter which dominates tot_sumsq
        (logged as a diagnostic when gradients are being clipped hard).

        Args:
            tuples: a list of tuples of (param, state, param_names)
                where param is a batched set of parameters,
                with a .grad (1st dim is batch dim)
                and state is the state-dict where optimization parameters are kept.
                param_names is a List[str] while each str is name for a parameter
                in batched set of parameters "param".
            tot_sumsq: sumsq of all parameters.  Though it could be calculated
                from tuples, we still pass it to save some time.
        """
        # Maps parameter name -> (proportion of tot_sumsq, sumsq, rms, grad).
        all_sumsq_orig = {}
        for (p, state, batch_param_names) in tuples:
            # p is a stacked batch of parameters.
            batch_grad = p.grad
            if p.numel() == p.shape[0]:  # a batch of scalars
                batch_sumsq_orig = batch_grad ** 2
                # Dummy values used by the following `zip` statement.
                batch_rms_orig = torch.ones(p.shape[0])
            else:
                batch_rms_orig = state["param_rms"]
                # Same rms-scaled squared norm as in _get_clipping_scale,
                # but kept per underlying parameter (reduced over dims 1..).
                batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
                    dim=list(range(1, batch_grad.ndim))
                )

            for name, sumsq_orig, rms, grad in zip(
                batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
            ):

                proportion_orig = sumsq_orig / tot_sumsq
                all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)

        # Sanity check: the proportions should sum to 1.
        assert torch.isclose(
            sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
            torch.tensor(1.0),
        )
        # Sort by proportion, descending; the first entry dominates.
        sorted_by_proportion = {
            k: v
            for k, v in sorted(
                all_sumsq_orig.items(),
                key=lambda item: item[1][0],
                reverse=True,
            )
        }
        dominant_param_name = next(iter(sorted_by_proportion))
        (
            dominant_proportion,
            dominant_sumsq,
            dominant_rms,
            dominant_grad,
        ) = sorted_by_proportion[dominant_param_name]
        logging.info(
            f"Parameter Dominanting tot_sumsq {dominant_param_name}"
            f" with proportion {dominant_proportion:.2f},"
            f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
            f"={dominant_sumsq:.3e},"
            f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
            f" orig_rms_sq={(dominant_rms**2).item():.3e}"
        )
478
+
479
    def _step_one_batch(
        self, group: dict, p: Tensor, state: dict, clipping_scale: float
    ):
        """
        Do the step for one parameter, which is actually going to be a batch of
        `real` parameters, with dim 0 as the batch dim.
        Args:
                  group:  dict to look up configuration values
                    p: parameter to update (actually multiple parameters stacked together
                       as a batch)
                  state: state-dict for p, to look up the optimizer state
                  clipping_scale: factor <= 1.0 (from _get_clipping_scale) by which
                       the gradient is multiplied before the update
        """
        # NOTE(review): `lr` is unused in this method; the helper methods read
        # the learning rate from `group` themselves.
        lr = group["lr"]
        size_update_period = group["size_update_period"]
        beta1 = group["betas"][0]

        grad = p.grad
        if clipping_scale != 1.0:
            grad = grad * clipping_scale
        step = state["step"]
        delta = state["delta"]

        # Decay the accumulated parameter-change (momentum) buffer.
        delta.mul_(beta1)
        batch_size = p.shape[0]
        # Elements per underlying parameter (dim 0 is the stacking dim).
        numel = p.numel() // batch_size
        if numel > 1:
            # Update the size/scale of p, and set param_rms
            scale_grads = state["scale_grads"]
            # (p * grad).sum over non-batch dims is the gradient of the loss
            # w.r.t. an overall multiplicative scale on p.
            scale_grads[step % size_update_period] = (p * grad).sum(
                dim=list(range(1, p.ndim)), keepdim=True
            )
            if step % size_update_period == size_update_period - 1:
                # Refresh the cached rms once per size_update_period steps.
                param_rms = state["param_rms"]  # shape: (batch_size, 1, 1, ..)
                param_rms.copy_(
                    (p ** 2)
                    .mean(dim=list(range(1, p.ndim)), keepdim=True)
                    .sqrt()
                )
                if step > 0:
                    # self._size_update() learns the overall scale on the
                    # parameter, by shrinking or expanding it.
                    self._size_update(group, scale_grads, p, state)

        if numel == 1:
            # For parameters with 1 element we just use regular Adam.
            # Updates delta.
            self._step_scalar(group, p, state)
        else:
            self._step(group, p, state)

        state["step"] = step + 1
530
+
531
    def _size_update(
        self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
    ) -> None:
        """
        Called only where p.numel() > 1, this updates the scale of the parameter.
        If we imagine: p = underlying_param * scale.exp(), and we are doing
        gradient descent on underlying param and on scale, this function does the update
        on `scale`.

        Args:
            group: dict to look up configuration values
            scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
                grads w.r.t. the scales.
            p: The parameter to update
            state: The state-dict of p
        """

        param_rms = state["param_rms"]
        beta1, beta2 = group["betas"]
        size_lr = group["lr"] * group["scalar_lr_scale"]
        param_min_rms = group["param_min_rms"]
        param_max_rms = group["param_max_rms"]
        eps = group["eps"]
        step = state["step"]
        batch_size = p.shape[0]

        size_update_period = scale_grads.shape[0]
        # correct beta2 for the size update period: we will have
        # faster decay at this level.
        beta2_corr = beta2 ** size_update_period

        # second-moment (Adam-style) statistic for the scale gradients
        scale_exp_avg_sq = state[
            "scale_exp_avg_sq"
        ]  # shape: (batch_size, 1, 1, ..)
        scale_exp_avg_sq.mul_(beta2_corr).add_(
            (scale_grads ** 2).mean(
                dim=0
            ),  # mean over dim `size_update_period`
            alpha=1 - beta2_corr,
        )  # shape is (batch_size, 1, 1, ...)

        # The 1st time we reach here is when size_step == 1.
        size_step = (step + 1) // size_update_period
        bias_correction2 = 1 - beta2_corr ** size_step
        # we don't bother with bias_correction1; this will help prevent divergence
        # at the start of training.

        denom = scale_exp_avg_sq.sqrt() + eps

        # Adam-like normalized step on the log-scale of the parameter.
        scale_step = (
            -size_lr
            * (bias_correction2 ** 0.5)
            * scale_grads.sum(dim=0)
            / denom
        )

        is_too_small = param_rms < param_min_rms
        is_too_large = param_rms > param_max_rms

        # when the param gets too small, just don't shrink it any further.
        scale_step.masked_fill_(is_too_small, 0.0)
        # when it gets too large, stop it from getting any larger.
        scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
        delta = state["delta"]
        # the factor of (1-beta1) relates to momentum.
        delta.add_(p * scale_step, alpha=(1 - beta1))
598
    def _step(self, group: dict, p: Tensor, state: dict):
        """
        This function does the core update of self.step(), in the case where the members of
        the batch have more than 1 element.

        Args:
            group: A dict which will be used to look up configuration values
            p: The parameter to be updated (its gradient is read from p.grad)
            state: The state-dict corresponding to parameter p

        This function modifies p.
        """
        grad = p.grad
        lr = group["lr"]
        beta1, beta2 = group["betas"]
        eps = group["eps"]
        param_min_rms = group["param_min_rms"]
        step = state["step"]

        exp_avg_sq = state["exp_avg_sq"]
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))

        # "zero_step", if present, records a step at which the exp_avg_sq stats
        # were reset; bias correction counts steps since then.
        this_step = state["step"] - (
            state["zero_step"] if "zero_step" in state else 0
        )
        bias_correction2 = 1 - beta2 ** (this_step + 1)
        if bias_correction2 < 0.99:
            # note: not in-place.
            exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)

        denom = exp_avg_sq.sqrt()
        denom += eps
        grad = grad / denom

        # Scale the step by the parameter's rms (clamped from below by
        # param_min_rms), so the relative change per step is controlled.
        alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)

        delta = state["delta"]
        delta.add_(grad * alpha)
        p.add_(delta)
639
    def _step_scalar(self, group: dict, p: Tensor, state: dict):
        """
        A simplified form of the core update for scalar tensors, where we cannot get a good
        estimate of the parameter rms.  This is plain Adam (without
        bias_correction1) with the learning rate scaled by scalar_lr_scale,
        plus clamping of the parameter to [-scalar_max, scalar_max].
        """
        beta1, beta2 = group["betas"]
        scalar_max = group["scalar_max"]
        eps = group["eps"]
        lr = group["lr"] * group["scalar_lr_scale"]
        grad = p.grad

        exp_avg_sq = state["exp_avg_sq"]  # shape: (batch_size,)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # bias_correction2 is like in Adam.  Don't bother with bias_correction1;
        # slower update at the start will help stability anyway.
        bias_correction2 = 1 - beta2 ** (state["step"] + 1)
        denom = (exp_avg_sq / bias_correction2).sqrt() + eps

        delta = state["delta"]
        delta.add_(grad / denom, alpha=-lr * (1 - beta1))
        # clamp is applied before adding the momentum term, so the stored value
        # may exceed scalar_max by at most |delta| until the next step.
        p.clamp_(min=-scalar_max, max=scalar_max)
        p.add_(delta)
663
+
664
class LRScheduler(object):
    """
    Base-class for learning rate schedulers where the learning-rate depends on both the
    batch and the epoch.

    Subclasses must override get_lr(); call step_batch() once per batch and
    step_epoch() once per epoch to keep the optimizer's learning rates updated.
    """

    def __init__(self, optimizer: Optimizer, verbose: bool = False):
        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError(
                "{} is not an Optimizer".format(type(optimizer).__name__)
            )
        self.optimizer = optimizer
        self.verbose = verbose

        # Remember each group's base lr so get_lr() can scale from it.
        for group in optimizer.param_groups:
            group.setdefault("base_lr", group["lr"])

        self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]

        self.epoch = 0
        self.batch = 0
        # Fix: initialize _last_lr so get_last_lr() is safe to call before the
        # first step_batch()/step_epoch(); previously it raised AttributeError.
        self._last_lr = [group["lr"] for group in optimizer.param_groups]

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        Contains the counters and base lrs needed to resume scheduling
        (the optimizer itself is deliberately excluded).
        """
        return {
            "base_lrs": self.base_lrs,
            "epoch": self.epoch,
            "batch": self.batch,
        }

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self) -> List[float]:
        """Return last computed learning rate by current scheduler. Will be a list of float."""
        return self._last_lr

    def get_lr(self):
        # Compute list of learning rates from self.epoch and self.batch and
        # self.base_lrs; this must be overloaded by the user.
        # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
        raise NotImplementedError

    def step_batch(self, batch: Optional[int] = None) -> None:
        # Step the batch index, or just set it. If `batch` is specified, it
        # must be the batch index from the start of training, i.e. summed over
        # all epochs.
        # You can call this in any order; if you don't provide 'batch', it should
        # of course be called once per batch.
        if batch is not None:
            self.batch = batch
        else:
            self.batch = self.batch + 1
        self._set_lrs()

    def step_epoch(self, epoch: Optional[int] = None):
        # Step the epoch index, or just set it. If you provide the 'epoch' arg,
        # you should call this at the start of the epoch; if you don't provide the 'epoch'
        # arg, you should call it at the end of the epoch.
        if epoch is not None:
            self.epoch = epoch
        else:
            self.epoch = self.epoch + 1
        self._set_lrs()

    def _set_lrs(self):
        # Push the freshly computed lrs into the optimizer's param groups.
        values = self.get_lr()
        assert len(values) == len(self.optimizer.param_groups)

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group["lr"] = lr
            self.print_lr(self.verbose, i, lr)
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def print_lr(self, is_verbose, group, lr):
        """Display the current learning rate."""
        if is_verbose:
            logging.info(
                f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
                f" of group {group} to {lr:.4e}."
            )
758
+
759
class Eden(LRScheduler):
    """
    Eden scheduler.
    The basic formula (before warmup) is:
      lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
                     (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
    where `warmup` increases linearly from 0.5 to 1 over `warmup_batches` batches
    and then stays constant at 1.

    E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam

    Args:
        optimizer: the optimizer to change the learning rates on
        lr_batches: the number of batches after which we start significantly
              decreasing the learning rate, suggest 5000.
        lr_epochs: the number of epochs after which we start significantly
              decreasing the learning rate, suggest 6 if you plan to do e.g.
              20 to 40 epochs, but may need smaller number if dataset is huge
              and you will do few epochs.
        warmup_batches: number of batches over which the warmup factor ramps
              from 0.5 to 1.0.
        verbose: if True, log every learning-rate adjustment.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        lr_batches: Union[int, float],
        lr_epochs: Union[int, float],
        warmup_batches: Union[int, float] = 500.0,
        verbose: bool = False,
    ):
        super(Eden, self).__init__(optimizer, verbose)
        self.lr_batches = lr_batches
        self.lr_epochs = lr_epochs
        self.warmup_batches = warmup_batches

    def get_lr(self):
        # Fourth-root decay driven by the batch counter ...
        batch_decay = (
            (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
        ) ** -0.25
        # ... and an identical-shaped decay driven by the epoch counter.
        epoch_decay = (
            (self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2
        ) ** -0.25
        # Linear ramp from 0.5 to 1.0 during the first warmup_batches batches.
        if self.batch >= self.warmup_batches:
            warmup = 1.0
        else:
            warmup = 0.5 + 0.5 * (self.batch / self.warmup_batches)

        scale = batch_decay * epoch_decay * warmup
        return [base_lr * scale for base_lr in self.base_lrs]
808
+
809
+
810
def _test_eden():
    """Smoke-test: drive the Eden scheduler and ScaledAdam on a tiny linear
    model for a few epochs, then log the final lr and scheduler state."""
    m = torch.nn.Linear(100, 100)
    optim = ScaledAdam(m.parameters(), lr=0.03)

    scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)

    for epoch in range(10):
        scheduler.step_epoch(epoch)  # sets epoch to `epoch`

        for step in range(20):
            x = torch.randn(200, 100).detach()
            x.requires_grad = True
            y = m(x)
            # random linear objective so gradients are nontrivial
            dy = torch.randn(200, 100).detach()
            f = (y * dy).sum()
            f.backward()

            optim.step()
            scheduler.step_batch()
            optim.zero_grad()

    logging.info(f"last lr = {scheduler.get_last_lr()}")
    logging.info(f"state dict = {scheduler.state_dict()}")
834
+
835
# This is included mostly as a baseline for ScaledAdam.
class Eve(Optimizer):
    """
    Implements Eve algorithm. This is a modified version of AdamW with a special
    way of setting the weight-decay / shrinkage-factor, which is designed to make the
    rms of the parameters approach a particular target_rms (default: 0.1). This is
    for use with networks with 'scaled' versions of modules (see scaling.py), which
    will be close to invariant to the absolute scale on the parameter matrix.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
    Eve is unpublished so far.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.98))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-3);
            not multiplied by the learning rate, and only applied while the
            RMS-value of the parameter exceeds target_rms.
        target_rms (float, optional): target root-mean-square value of
            parameters; below this, weight decay is not applied.

    .. _Adam: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.98),
        eps=1e-8,
        weight_decay=1e-3,
        target_rms=0.1,
    ):
        # Validate hyperparameters up front; the comparison forms below also
        # reject NaN values.
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0 <= weight_decay <= 0.1:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0 < target_rms <= 10.0:
            raise ValueError(f"Invalid target_rms value: {target_rms}")

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            target_rms=target_rms,
        )
        super(Eve, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Eve, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            eps = group["eps"]
            lr = group["lr"]
            target_rms = group["target_rms"]
            weight_decay = group["weight_decay"]

            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    continue
                if grad.is_sparse:
                    raise RuntimeError(
                        "AdamW does not support sparse gradients"
                    )

                state = self.state[p]

                # Lazily create per-parameter state on first use.
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]

                state["step"] += 1
                step_count = state["step"]
                bias_correction1 = 1 - beta1 ** step_count
                bias_correction2 = 1 - beta2 ** step_count

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = (
                    exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)
                ).add_(eps)

                step_size = lr / bias_correction1

                if p.numel() > 1:
                    # Apply weight decay only while the parameter's norm is
                    # above target_rms; scalar "scaling factors" are exempt.
                    is_above_target_rms = p.norm() > (
                        target_rms * (p.numel() ** 0.5)
                    )
                    p.mul_(1 - (weight_decay * is_above_target_rms))

                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
986
+
987
+
988
def _test_scaled_adam(hidden_dim: int):
    """Compare Eve (baseline, iter 0) against ScaledAdam+Eden (iter 1) on a
    small random regression problem, logging loss, lr, and wall-clock time.

    Args:
        hidden_dim: width of the hidden layers of the test MLP.
    """
    import timeit

    from scaling import ScaledLinear

    E = 100
    B = 4
    T = 2
    logging.info("in test_eve_cain")
    # device = torch.device('cuda')
    device = torch.device("cpu")
    dtype = torch.float32

    fix_random_seed(42)
    # these input_magnitudes and output_magnitudes are to test that
    # Abel is working as we expect and is able to adjust scales of
    # different dims differently.
    input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
    output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()

    for iter in [1, 0]:
        fix_random_seed(42)
        # iter 0 uses plain Linear + Eve; iter 1 uses ScaledLinear + ScaledAdam.
        Linear = torch.nn.Linear if iter == 0 else ScaledLinear

        m = torch.nn.Sequential(
            Linear(E, hidden_dim),
            torch.nn.PReLU(),
            Linear(hidden_dim, hidden_dim),
            torch.nn.PReLU(),
            Linear(hidden_dim, E),
        ).to(device)

        # Fixed set of (input, target) pairs, rescaled per-dim to exercise
        # scale adaptation.
        train_pairs = [
            (
                100.0
                * torch.randn(B, T, E, device=device, dtype=dtype)
                * input_magnitudes,
                torch.randn(B, T, E, device=device, dtype=dtype)
                * output_magnitudes,
            )
            for _ in range(20)
        ]

        if iter == 0:
            optim = Eve(m.parameters(), lr=0.003)
        elif iter == 1:
            optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
        scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)

        start = timeit.default_timer()
        avg_loss = 0.0
        for epoch in range(180):
            scheduler.step_epoch()
            # if epoch == 100 and iter in [2,3]:
            #     optim.reset_speedup()  # check it doesn't crash.

            # if epoch == 130:
            #     opts = diagnostics.TensorDiagnosticOptions(
            #         2 ** 22
            #     )  # allow 4 megabytes per sub-module
            #     diagnostic = diagnostics.attach_diagnostics(m, opts)

            for n, (x, y) in enumerate(train_pairs):
                y_out = m(x)
                loss = ((y_out - y) ** 2).mean() * 100.0
                # exponential moving average of the loss for smoother logging
                if epoch == 0 and n == 0:
                    avg_loss = loss.item()
                else:
                    avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
                if n == 0 and epoch % 5 == 0:
                    # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
                    # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
                    # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
                    # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
                    # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
                    # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
                    # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
                    # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
                    lr = scheduler.get_last_lr()[0]
                    logging.info(
                        f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}"
                    )  # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
                # note: backprop through log(loss), not loss itself.
                loss.log().backward()
                optim.step()
                optim.zero_grad()
                scheduler.step_batch()

            # diagnostic.print_diagnostics()

        stop = timeit.default_timer()
        logging.info(f"Iter={iter}, Time taken: {stop - start}")

        logging.info(f"last lr = {scheduler.get_last_lr()}")
        # logging.info("state dict = ", scheduler.state_dict())
        # logging.info("optim state_dict = ", optim.state_dict())
        logging.info(f"input_magnitudes = {input_magnitudes}")
        logging.info(f"output_magnitudes = {output_magnitudes}")
1086
+
1087
if __name__ == "__main__":
    # Single-threaded execution so the timing comparison below is meaningful.
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    logging.getLogger().setLevel(logging.INFO)
    import subprocess

    # Record the current git state in the log, for reproducibility of results.
    s = subprocess.check_output(
        "git status -uno .; git log -1; git diff HEAD .", shell=True
    )
    logging.info(s)
    import sys

    # Optional command-line argument: hidden dimension of the test model.
    if len(sys.argv) > 1:
        hidden_dim = int(sys.argv[1])
    else:
        hidden_dim = 200

    _test_scaled_adam(hidden_dim)
    _test_eden()
apps/audio_cloning/vallex/modules/scaling.py ADDED
@@ -0,0 +1,1369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
2
+ #
3
+ # See ../../../../LICENSE for clarification regarding multiple authors
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ import logging
19
+ import math
20
+ import random
21
+ from typing import Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ from torch import Tensor
26
+
27
+
28
class Transpose(nn.Identity):
    """(N, T, D) -> (N, D, T)"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Exchange axes 1 and 2 (time and feature).
        return torch.transpose(input, 1, 2)
34
+
35
class ActivationBalancerFunction(torch.autograd.Function):
    """
    Identity in the forward pass.  In the backward pass it subtracts a small
    correction from the gradient, proportional to |grad| times per-channel
    factors, in order to push activations toward magnitude/sign constraints
    (the factors are computed by _compute_scale_factor / _compute_sign_factor).
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        scale_factor: Tensor,
        sign_factor: Optional[Tensor],
        channel_dim: int,
    ) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        # remember which elements were positive; needed to direct the nudge
        xgt0 = x > 0
        if sign_factor is None:
            ctx.save_for_backward(xgt0, scale_factor)
        else:
            ctx.save_for_backward(xgt0, scale_factor, sign_factor)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
        if len(ctx.saved_tensors) == 3:
            xgt0, scale_factor, sign_factor = ctx.saved_tensors
            # broadcast the per-channel factors over the trailing dims of x_grad
            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                scale_factor = scale_factor.unsqueeze(-1)
                sign_factor = sign_factor.unsqueeze(-1)
            factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        else:
            xgt0, scale_factor = ctx.saved_tensors
            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                scale_factor = scale_factor.unsqueeze(-1)
            factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        # magnitude taken from |x_grad| so the nudge scales with gradient size
        neg_delta_grad = x_grad.abs() * factor
        return (
            x_grad - neg_delta_grad,
            None,
            None,
            None,
        )
75
+
76
+ def _compute_scale_factor(
77
+ x: Tensor,
78
+ channel_dim: int,
79
+ min_abs: float,
80
+ max_abs: float,
81
+ gain_factor: float,
82
+ max_factor: float,
83
+ ) -> Tensor:
84
+ if channel_dim < 0:
85
+ channel_dim += x.ndim
86
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
87
+ x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
88
+
89
+ if min_abs == 0.0:
90
+ below_threshold = 0.0
91
+ else:
92
+ # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
93
+ # x_abs)_mean , min_abs.
94
+ below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
95
+ min=0, max=max_factor
96
+ )
97
+
98
+ above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
99
+ min=0, max=max_factor
100
+ )
101
+
102
+ return below_threshold - above_threshold
103
+
104
+
105
+ def _compute_sign_factor(
106
+ x: Tensor,
107
+ channel_dim: int,
108
+ min_positive: float,
109
+ max_positive: float,
110
+ gain_factor: float,
111
+ max_factor: float,
112
+ ) -> Tensor:
113
+ if channel_dim < 0:
114
+ channel_dim += x.ndim
115
+ sum_dims = [d for d in range(x.ndim) if d != channel_dim]
116
+ proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
117
+ if min_positive == 0.0:
118
+ factor1 = 0.0
119
+ else:
120
+ # 0 if proportion_positive >= min_positive, else can be
121
+ # as large as max_factor.
122
+ factor1 = (
123
+ (min_positive - proportion_positive) * (gain_factor / min_positive)
124
+ ).clamp_(min=0, max=max_factor)
125
+
126
+ if max_positive == 1.0:
127
+ factor2 = 0.0
128
+ else:
129
+ # 0 if self.proportion_positive <= max_positive, else can be
130
+ # as large as -max_factor.
131
+ factor2 = (
132
+ (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
133
+ ).clamp_(min=0, max=max_factor)
134
+ sign_factor = factor1 - factor2
135
+ # require min_positive != 0 or max_positive != 1:
136
+ assert not isinstance(sign_factor, float)
137
+ return sign_factor
138
+
139
+
140
class ActivationScaleBalancerFunction(torch.autograd.Function):
    """
    This object is used in class ActivationBalancer when the user specified
    min_positive=0, max_positive=1, so there are no constraints on the signs
    of the activations and only the absolute value has a constraint.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        sign_factor: Tensor,
        scale_factor: Tensor,
        channel_dim: int,
    ) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        # remember sign pattern of x for the backward nudge
        xgt0 = x > 0
        ctx.save_for_backward(xgt0, sign_factor, scale_factor)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
        xgt0, sign_factor, scale_factor = ctx.saved_tensors
        # broadcast per-channel factors over the trailing dims of x_grad
        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
            sign_factor = sign_factor.unsqueeze(-1)
            scale_factor = scale_factor.unsqueeze(-1)

        factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        # nudge magnitude scales with |x_grad|
        neg_delta_grad = x_grad.abs() * factor
        return (
            x_grad - neg_delta_grad,
            None,
            None,
            None,
        )
178
+
179
class RandomClampFunction(torch.autograd.Function):
    """
    With probability `prob` per element, clamps x to [min, max]; other elements
    pass through unchanged.  If `reflect` != 0, the output is further mixed as
    ans = ans * (1 + reflect) - x * reflect, and the backward pass applies the
    matching transformation to the gradient.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        min: Optional[float],
        max: Optional[float],
        prob: float,
        reflect: float,
    ) -> Tensor:
        x_clamped = torch.clamp(x, min=min, max=max)
        # elementwise coin flip deciding which positions get clamped
        mask = torch.rand_like(x) < prob
        ans = torch.where(mask, x_clamped, x)
        if x.requires_grad:
            # gradient flows only where the output equals the input
            ctx.save_for_backward(ans == x)
            ctx.reflect = reflect
        if reflect != 0.0:
            ans = ans * (1.0 + reflect) - (x * reflect)
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
        (is_same,) = ctx.saved_tensors
        # zero the gradient where the value was clamped
        x_grad = ans_grad * is_same.to(ans_grad.dtype)
        reflect = ctx.reflect
        if reflect != 0.0:
            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
        return x_grad, None, None, None, None
208
+
209
def random_clamp(
    x: Tensor,
    min: Optional[float] = None,
    max: Optional[float] = None,
    prob: float = 0.5,
    reflect: float = 0.0,
):
    """Elementwise, with probability `prob`, clamp x to [min, max]; see
    RandomClampFunction for the meaning of `reflect`."""
    return RandomClampFunction.apply(x, min, max, prob, reflect)
217
+
218
+
219
def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
    """
    A randomized way of casting a floating point value to half precision.

    Values with |x| >= min_abs are cast directly.  Smaller values are replaced,
    stochastically, by +-min_abs with probability |x| / min_abs and by 0.0
    otherwise, which preserves the expected value of each element.
    """
    if x.dtype == torch.float16:
        # already half precision; nothing to do
        return x
    magnitude = x.abs()
    too_small = magnitude < min_abs
    # stochastic +-min_abs / 0.0 replacement for the tiny elements
    replacement = (
        min_abs * x.sign() * (torch.rand_like(x) * min_abs < magnitude)
    )
    return torch.where(too_small, replacement, x).to(torch.float16)
232
+
233
+
234
class RandomGradFunction(torch.autograd.Function):
    """
    Does nothing in forward pass; in backward pass, gets rid of very small grads using
    randomized approach that preserves expectations (intended to reduce roundoff).
    """

    @staticmethod
    def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
        ctx.min_abs = min_abs
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
        if ans_grad.dtype == torch.float16:
            # round-trip through float32 and re-cast with stochastic rounding
            # of the tiny elements (see random_cast_to_half)
            return (
                random_cast_to_half(
                    ans_grad.to(torch.float32), min_abs=ctx.min_abs
                ),
                None,
            )
        else:
            # full-precision gradients are passed through unchanged
            return ans_grad, None
255
+
256
class RandomGrad(torch.nn.Module):
    """
    Gets rid of very small gradients using an expectation-preserving method, intended to increase
    accuracy of training when using amp (automatic mixed precision)
    """

    def __init__(self, min_abs: float = 5.0e-06):
        super(RandomGrad, self).__init__()
        # threshold below which half-precision grads are randomized
        self.min_abs = min_abs

    def forward(self, x: Tensor):
        # Plain pass-through when not training, or under jit scripting/tracing.
        skip = (
            not self.training
            or torch.jit.is_scripting()
            or torch.jit.is_tracing()
        )
        if skip:
            return x
        return RandomGradFunction.apply(x, self.min_abs)
271
+
272
+
273
class SoftmaxFunction(torch.autograd.Function):
    """
    Tries to handle half-precision derivatives in a randomized way that should
    be more accurate for training than the default behavior.
    """

    @staticmethod
    def forward(ctx, x: Tensor, dim: int):
        ans = x.softmax(dim=dim)
        # if x dtype is float16, x.softmax() returns a float32 because
        # (presumably) that op does not support float16, and autocast
        # is enabled.
        if torch.is_autocast_enabled():
            ans = ans.to(torch.float16)
        ctx.save_for_backward(ans)
        ctx.x_dtype = x.dtype
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (ans,) = ctx.saved_tensors
        # do the backward math in float32 regardless of autocast state
        with torch.cuda.amp.autocast(enabled=False):
            ans_grad = ans_grad.to(torch.float32)
            ans = ans.to(torch.float32)
            # standard softmax Jacobian-vector product:
            # dx = s * (g - sum(s * g))
            x_grad = ans_grad * ans
            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
            return x_grad, None
302
+
303
def softmax(x: Tensor, dim: int):
    """Softmax over `dim` routed through SoftmaxFunction for a more accurate
    half-precision backward pass; falls back to the builtin under TorchScript."""
    scripted = torch.jit.is_scripting() or torch.jit.is_tracing()
    if not scripted:
        return SoftmaxFunction.apply(x, dim)
    return x.softmax(dim)
308
+
309
+
310
class MaxEigLimiterFunction(torch.autograd.Function):
    """
    Identity in the forward pass; in the backward pass, adds a gradient term
    that penalizes the proportion of the variance explained by the current
    top eigen-direction estimate (see MaxEig, which supplies `coeffs` and
    `direction`).
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        coeffs: Tensor,
        direction: Tensor,
        channel_dim: int,
        grad_scale: float,
    ) -> Tensor:
        ctx.channel_dim = channel_dim
        ctx.grad_scale = grad_scale
        # Everything is detached: the penalty gradient is recomputed from
        # scratch in backward() with a fresh autograd graph.
        ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
        return x

    @staticmethod
    def backward(ctx, x_grad, *args):
        with torch.enable_grad():
            (x_orig, coeffs, new_direction) = ctx.saved_tensors
            x_orig.requires_grad = True
            num_channels = x_orig.shape[ctx.channel_dim]
            # Flatten to (num_frames, num_channels), channels last.
            x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
            new_direction.requires_grad = False
            x = x - x.mean(dim=0)
            x_var = (x**2).mean()
            x_residual = x - coeffs * new_direction
            x_residual_var = (x_residual**2).mean()
            # `variance_proportion` is the proportion of the variance accounted for
            # by the top eigen-direction. This is to be minimized.
            variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
            variance_proportion.backward()
            x_orig_grad = x_orig.grad
            # Rescale the penalty gradient so its norm is grad_scale times
            # the norm of the incoming gradient.
            x_extra_grad = (
                x_orig.grad
                * ctx.grad_scale
                * x_grad.norm()
                / (x_orig_grad.norm() + 1.0e-20)
            )
            return x_grad + x_extra_grad.detach(), None, None, None, None
349
+
350
+
351
class BasicNorm(torch.nn.Module):
    """
    This is intended to be a simpler, and hopefully cheaper, replacement for
    LayerNorm.  The observation this is based on, is that Transformer-type
    networks, especially with pre-norm, sometimes seem to set one of the
    feature dimensions to a large constant value (e.g. 50), which "defeats"
    the LayerNorm because the output magnitude is then not strongly dependent
    on the other (useful) features.  Presumably the weight and bias of the
    LayerNorm are required to allow it to do this.

    So the idea is to introduce this large constant value as an explicit
    parameter, that takes the role of the "eps" in LayerNorm, so the network
    doesn't have to do this trick.  We make the "eps" learnable.

    Args:
       num_channels: the number of channels, e.g. 512.
       channel_dim: the axis/dimension corresponding to the channel,
         interpreted as an offset from the input's ndim if negative.
         This is NOT the num_channels; it should typically be one of
         {-2, -1, 0, 1, 2, 3}.
       eps: the initial "epsilon" that we add as ballast in:
             scale = ((input_vec**2).mean() + epsilon)**-0.5
          Note: our epsilon is actually large, but we keep the name
          to indicate the connection with conventional LayerNorm.
       learn_eps: if true, we learn epsilon; if false, we keep it
         at the initial value.
       eps_min: lower clamp for log-eps (applied stochastically in training).
       eps_max: upper clamp for log-eps (applied stochastically in training).
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int = -1,  # CAUTION: see documentation.
        eps: float = 0.25,
        learn_eps: bool = True,
        eps_min: float = -3.0,
        eps_max: float = 3.0,
    ) -> None:
        super(BasicNorm, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        # eps is stored in log space so that it stays positive after exp().
        log_eps = torch.tensor(eps).log().detach()
        if learn_eps:
            self.eps = nn.Parameter(log_eps)
        else:
            self.register_buffer("eps", log_eps)
        self.eps_min = eps_min
        self.eps_max = eps_max

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        eps = self.eps
        # With probability 0.25 during training, clamp log-eps into
        # [eps_min, eps_max]; a parameter outside the allowed range then
        # sees noisy gradients, encouraging it back inside.
        if self.training and random.random() < 0.25:
            eps = eps.clamp(min=self.eps_min, max=self.eps_max)
        mean_sq = torch.mean(x**2, dim=self.channel_dim, keepdim=True)
        return x * (mean_sq + eps.exp()) ** -0.5
416
+
417
+
418
def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
    """
    Construct an nn.Linear whose initial output magnitude is scaled.

    Accepts all the standard nn.Linear args/kwargs (in_features,
    out_features, bias=...).  `initial_scale` multiplies the
    freshly-initialized weight, and the bias (if present) is re-initialized
    uniformly in [-0.1 * initial_scale, 0.1 * initial_scale].  To do
    something more elaborate, re-initialize the parameters afterwards.
    """
    module = nn.Linear(*args, **kwargs)
    with torch.no_grad():
        module.weight[:] *= initial_scale
        if module.bias is not None:
            bound = 0.1 * initial_scale
            torch.nn.init.uniform_(module.bias, -bound, bound)
    return module
439
+
440
+
441
def ScaledConv1d(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Conv1d:
    """
    Construct an nn.Conv1d whose initial output magnitude is scaled.

    Accepts the standard nn.Conv1d args/kwargs (in_channels, out_channels,
    bias=...), with kernel_size=3 and padding="same" as defaults.
    `initial_scale` multiplies the freshly-initialized weight; the bias
    (if present) is re-initialized uniformly in
    [-0.1 * initial_scale, 0.1 * initial_scale].
    """
    module = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
    with torch.no_grad():
        module.weight[:] *= initial_scale
        if module.bias is not None:
            bound = 0.1 * initial_scale
            torch.nn.init.uniform_(module.bias, -bound, bound)
    return module
468
+
469
+
470
def TransposeScaledConv1d(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Sequential:
    """Sequential of Transpose followed by ScaledConv1d."""
    conv = ScaledConv1d(
        *args,
        initial_scale=initial_scale,
        kernel_size=kernel_size,
        padding=padding,
        **kwargs,
    )
    return nn.Sequential(Transpose(), conv)
490
+
491
+
492
def ScaledConv1dTranspose(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Sequential:
    """Sequential of ScaledConv1d followed by Transpose."""
    conv = ScaledConv1d(
        *args,
        initial_scale=initial_scale,
        kernel_size=kernel_size,
        padding=padding,
        **kwargs,
    )
    return nn.Sequential(conv, Transpose())
512
+
513
+
514
def TransposeConv1d(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """Sequential of Transpose followed by nn.Conv1d."""
    conv = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
    return nn.Sequential(Transpose(), conv)
524
+
525
+
526
def Conv1dTranspose(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """Sequential of nn.Conv1d followed by Transpose."""
    conv = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
    return nn.Sequential(conv, Transpose())
536
+
537
+
538
class SRLinear(nn.Linear):
    """Sigma-reparameterized linear layer, from
    https://arxiv.org/abs/2303.06296 ("Stabilizing Transformer Training by
    Preventing Attention Entropy Collapse").  The weight is rescaled so its
    spectral norm equals the learned scalar `sigma`; the current spectral
    norm is estimated with one power-method step per call."""

    def __init__(self, in_features, out_features, bias=True, **kwargs):
        super().__init__(in_features, out_features, bias=bias, **kwargs)
        # Power-method vector, kept normalized across calls.
        self.register_buffer(
            "u", nn.functional.normalize(torch.randn(in_features), dim=0)
        )
        with torch.no_grad():
            sigma = self.get_sigma()
            self.register_buffer("spectral_norm", sigma)
        self.sigma = nn.Parameter(torch.ones(1))

    def get_sigma(self):
        # One power-iteration step (no grad), then a differentiable
        # estimate of the spectral norm v^T W u.
        with torch.no_grad():
            v = nn.functional.normalize(self.weight.mv(self.u), dim=0)
            u = nn.functional.normalize(self.weight.T.mv(v), dim=0)
            self.u.data.copy_(u)
        return torch.einsum("c,cd,d->", v, self.weight, u)

    def get_weight(self):
        sigma = self.get_sigma()
        if self.training:
            # Track the most recent spectral-norm estimate.
            self.spectral_norm.data.copy_(sigma)
        return self.weight * (self.sigma / sigma)

    def forward(self, x):
        return nn.functional.linear(x, self.get_weight(), self.bias)
572
+
573
+
574
class SRConv1d(SRLinear):
    """1-D convolution built on SRLinear's sigma-reparameterized weight:
    the (out, in, kernel) conv weight is stored flattened as an
    (out, in * kernel) linear weight."""

    def __init__(
        self,
        in_features,
        out_features,
        kernel_size,
        stride: int = 1,
        padding: str = "same",
        bias: bool = True,
        **kwargs,
    ):
        # The flattened linear weight holds all kernel taps.
        super().__init__(in_features * kernel_size, out_features, bias=bias, **kwargs)
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        channels = self.in_features // self.kernel_size
        weight = self.get_weight().view(
            self.out_features, channels, self.kernel_size
        )
        return nn.functional.conv1d(
            x, weight, bias=self.bias, stride=self.stride, padding=self.padding
        )
600
+
601
+
602
def TransposeSRConv1d(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """Sequential of Transpose followed by SRConv1d."""
    conv = SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
    return nn.Sequential(Transpose(), conv)
612
+
613
+
614
def SRConv1dTranspose(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """Sequential of SRConv1d followed by Transpose."""
    conv = SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
    return nn.Sequential(conv, Transpose())
624
+
625
+
626
class ActivationBalancer(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to encourage, for
    each channel, that it is positive at least a proportion `threshold` of the
    time.  It does this by multiplying negative derivative values by up to
    (1+max_factor), and positive derivative values by up to (1-max_factor),
    interpolated from 1 at the threshold to those extremal values when none
    of the inputs are positive.

    Args:
       num_channels: the number of channels
       channel_dim: the dimension/axis corresponding to the channel, e.g.
           -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
       min_positive: the minimum, per channel, of the proportion of the time
           that (x > 0), below which we start to modify the derivatives.
       max_positive: the maximum, per channel, of the proportion of the time
           that (x > 0), above which we start to modify the derivatives.
       max_factor: the maximum factor by which we modify the derivatives for
          either the sign constraint or the magnitude constraint;
          e.g. with max_factor=0.02, the derivatives would be multiplied by
          values in the range [0.98..1.02].
       sign_gain_factor: determines the 'gain' with which we increase the
          change in gradient once the constraints on min_positive and max_positive
          are violated.
       scale_gain_factor: determines the 'gain' with which we increase the
          change in gradient once the constraints on min_abs and max_abs
          are violated.
       min_abs:  the minimum average-absolute-value difference from the mean
           value per channel, which we allow, before we start to modify
           the derivatives to prevent this.
       max_abs:  the maximum average-absolute-value difference from the mean
           value per channel, which we allow, before we start to modify
           the derivatives to prevent this.
       min_prob: determines the minimum probability with which we modify the
           gradients for the {min,max}_positive and {min,max}_abs constraints,
           on each forward().  This is done randomly to prevent all layers
           from doing it at the same time.  Early in training we may use
           higher probabilities than this; it will decay to this value.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int,
        min_positive: float = 0.05,
        max_positive: float = 0.95,
        max_factor: float = 0.04,
        sign_gain_factor: float = 0.01,
        scale_gain_factor: float = 0.02,
        min_abs: float = 0.2,
        max_abs: float = 100.0,
        min_prob: float = 0.1,
    ):
        super(ActivationBalancer, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.min_positive = min_positive
        self.max_positive = max_positive
        self.max_factor = max_factor
        self.min_abs = min_abs
        self.max_abs = max_abs
        self.min_prob = min_prob
        self.sign_gain_factor = sign_gain_factor
        self.scale_gain_factor = scale_gain_factor

        # cpu_count measures how many times the forward() function has been
        # called.  We occasionally sync it to the tensor buffer `count`, which
        # exists so the value is saved/restored with the model checkpoint.
        self.cpu_count = 0
        self.register_buffer("count", torch.tensor(0, dtype=torch.int64))

    def forward(self, x: Tensor) -> Tensor:
        # No-op under TorchScript or when no gradient will flow.
        if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
            return _no_op(x)

        count = self.cpu_count
        self.cpu_count += 1

        if random.random() < 0.01:
            # Occasionally sync self.cpu_count with self.count.
            # count affects the decay of 'prob'.  Don't do this on every iter,
            # because syncing with the GPU is slow.
            self.cpu_count = max(self.cpu_count, self.count.item())
            self.count.fill_(self.cpu_count)

        # The probability of doing some work exponentially decreases from 0.5
        # until it hits a floor at min_prob (==0.1, by default).
        prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))

        if random.random() < prob:
            # BUG FIX: a dead local `sign_gain_factor = 0.5` used to be
            # assigned here; it was never read (the real gain is
            # self.sign_gain_factor below) and has been removed.
            if self.min_positive != 0.0 or self.max_positive != 1.0:
                sign_factor = _compute_sign_factor(
                    x,
                    self.channel_dim,
                    self.min_positive,
                    self.max_positive,
                    gain_factor=self.sign_gain_factor / prob,
                    max_factor=self.max_factor,
                )
            else:
                sign_factor = None

            scale_factor = _compute_scale_factor(
                x.detach(),
                self.channel_dim,
                min_abs=self.min_abs,
                max_abs=self.max_abs,
                gain_factor=self.scale_gain_factor / prob,
                max_factor=self.max_factor,
            )
            return ActivationBalancerFunction.apply(
                x,
                scale_factor,
                sign_factor,
                self.channel_dim,
            )
        else:
            return _no_op(x)
745
+
746
+
747
def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
    """
    Returns x unmodified, but in backprop will put a penalty for the excess of
    the absolute values of elements of x over the limit "limit".  E.g. if
    limit == 10.0, then if x has any values over 10 it will get a penalty.

    Caution: the value of this penalty will be affected by grad scaling used
    in automatic mixed precision training.  For the purposes we use this for,
    that shouldn't really matter, or may even be helpful; we just use this
    to disallow really implausible values of scores to be given to softmax.
    """
    x_sign = x.sign()
    over_limit = (x.abs() - limit) > 0
    # The following is a memory efficient way to penalize the absolute values of
    # x that's over the limit.  (The memory efficiency comes when you think
    # about which items torch needs to cache for the autograd, and which ones it
    # can throw away).  The numerical value of aux_loss as computed here will
    # actually be larger than it should be, by limit * over_limit.sum(), but it
    # has the same derivative as the real aux_loss which is penalty * (x.abs() -
    # limit).relu().
    aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x)
    # note: we don't do sum() here on aux_loss, but it's as if we had done
    # sum() due to how with_loss() works.
    x = with_loss(x, aux_loss)
    # you must use x for something, or this will be ineffective.
    return x
773
+
774
+
775
+ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
776
+ if x.ndim == 2:
777
+ return x.diag()
778
+ else:
779
+ (batch, dim, dim) = x.shape
780
+ x = x.reshape(batch, dim * dim)
781
+ x = x[:, :: dim + 1]
782
+ assert x.shape == (batch, dim)
783
+ return x
784
+
785
+
786
def _whitening_metric(x: Tensor, num_groups: int):
    """
    Computes the "whitening metric": equals 1.0 if, within every group of
    channels, the centered feature covariance has all-equal eigenvalues (and
    equal traces across groups); larger than 1.0 otherwise.
    Args:
      x: a Tensor of shape (*, num_channels)
      num_groups: the number of channel groups, a number >= 1 that divides
        num_channels
    Returns:
      a scalar Tensor >= 1.0; larger values mean the data is less "white".
    """
    assert x.dtype != torch.float16
    x = x.reshape(-1, x.shape[-1])
    num_frames, num_channels = x.shape
    assert num_channels % num_groups == 0
    group_size = num_channels // num_groups
    # x: (num_groups, num_frames, group_size)
    x = x.reshape(num_frames, num_groups, group_size).transpose(0, 1)
    # Use the centered covariance: trying to move the mean around via the
    # gradient tends to cause instability, so subtract it first.
    x = x - x.mean(dim=1, keepdim=True)
    # covar: (num_groups, group_size, group_size)
    covar = torch.matmul(x.transpose(1, 2), x)
    mean_diag = _diag(covar).mean()
    # Equivalent to _diag(covar @ covar).mean(): the mean trace of the
    # squared covariance.
    sq_mean_diag = (covar**2).sum() / (num_groups * group_size)
    # >= 1.0 by Cauchy-Schwarz; larger means less white.
    return sq_mean_diag / (mean_diag**2 + 1.0e-20)
820
+
821
+
822
class WhiteningPenaltyFunction(torch.autograd.Function):
    """Identity in the forward pass; in backward, adds a gradient term that
    penalizes the whitening metric (see _whitening_metric) exceeding
    `whitening_limit`."""

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        num_groups: int,
        whitening_limit: float,
        grad_scale: float,
    ) -> Tensor:
        ctx.save_for_backward(x)
        ctx.num_groups = num_groups
        ctx.whitening_limit = whitening_limit
        ctx.grad_scale = grad_scale
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        (x_orig,) = ctx.saved_tensors
        # Recompute the penalty in float32 with a fresh autograd graph.
        with torch.enable_grad():
            with torch.cuda.amp.autocast(enabled=False):
                x_detached = x_orig.to(torch.float32).detach()
                x_detached.requires_grad = True

                metric = _whitening_metric(x_detached, ctx.num_groups)

                # Occasional diagnostic logging (always on when run as a script).
                if random.random() < 0.005 or __name__ == "__main__":
                    logging.info(
                        f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
                        f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
                    )

                # Penalty is only the excess over the limit (relu).
                (metric - ctx.whitening_limit).relu().backward()
                penalty_grad = x_detached.grad
                # Rescale so the penalty gradient's norm is grad_scale times
                # the incoming gradient's norm.
                scale = ctx.grad_scale * (
                    x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20)
                )
        penalty_grad = penalty_grad * scale
        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
860
+
861
+
862
class Whiten(nn.Module):
    def __init__(
        self,
        num_groups: int,
        whitening_limit: float,
        prob: Union[float, Tuple[float, float]],
        grad_scale: float,
    ):
        """
        Args:
          num_groups: the number of groups to divide the channel dim into before
            whitening.  We will attempt to make the feature covariance
            within each group, after mean subtraction, as "white" as possible,
            while having the same trace across all groups.
         whitening_limit: a value greater than 1.0, that dictates how much
           freedom we have to violate the constraints.  1.0 would mean perfectly
           white, with exactly the same trace across groups; larger values
           give more freedom.  E.g. 2.0.
         prob: the probability with which we apply the gradient modification
           (also affects the grad scale).  May be supplied as a float,
           or as a pair (min_prob, max_prob)

          grad_scale: determines the scale on the gradient term from this object,
            relative to the rest of the gradient on the attention weights.
            E.g. 0.02 (you may want to use smaller values than this if prob is large)
        """
        super(Whiten, self).__init__()
        assert num_groups >= 1
        assert whitening_limit >= 1
        assert grad_scale >= 0
        self.num_groups = num_groups
        self.whitening_limit = whitening_limit
        # prob may be a fixed float, or a (min_prob, max_prob) pair that the
        # forward pass adapts between (see forward()).
        if isinstance(prob, float):
            assert 0 < prob <= 1
            self.prob = prob
        else:
            (self.min_prob, self.max_prob) = prob
            assert 0 < self.min_prob < self.max_prob <= 1
            # Start at max_prob so the constraint is checked early on.
            self.prob = self.max_prob

        self.grad_scale = grad_scale

    def forward(self, x: Tensor) -> Tensor:
        """
        In the forward pass, this function just returns the input unmodified.
        In the backward pass, it will modify the gradients to ensure that the
        distribution in each group has close to (lambda times I) as the covariance
        after mean subtraction, with the same lambda across groups.
        For whitening_limit > 1, there will be more freedom to violate this
        constraint.

        Args:
           x: the input of shape (*, num_channels)

        Returns:
            x, unmodified.   You should make sure
        you use the returned value, or the graph will be freed
        and nothing will happen in backprop.
        """
        if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
            return _no_op(x)
        else:
            # Only present when `prob` was supplied as a (min, max) pair.
            if hasattr(self, "min_prob") and random.random() < 0.25:
                # occasionally switch between min_prob and max_prob, based on whether
                # we are above or below the threshold.
                if (
                    _whitening_metric(x.to(torch.float32), self.num_groups)
                    > self.whitening_limit
                ):
                    # there would be a change to the grad.
                    self.prob = self.max_prob
                else:
                    self.prob = self.min_prob

            return WhiteningPenaltyFunction.apply(
                x, self.num_groups, self.whitening_limit, self.grad_scale
            )
939
+
940
+
941
class WithLoss(torch.autograd.Function):
    """Identity on x in the forward pass; in backward, behaves as if y.sum()
    had been added to the loss (y receives an all-ones gradient)."""

    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor):
        # Only y's shape is needed to synthesize its gradient later.
        ctx.y_shape = y.shape
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        ones = torch.ones(
            ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
        )
        return ans_grad, ones
952
+
953
+
954
def with_loss(x, y):
    """Return x unchanged, but arrange for y.sum() to be added to the total
    loss during backprop (no-op under TorchScript).  See WithLoss."""
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        return x
    # returns x but adds y.sum() to the loss function.
    return WithLoss.apply(x, y)
959
+
960
+
961
+ def _no_op(x: Tensor) -> Tensor:
962
+ if torch.jit.is_scripting() or torch.jit.is_tracing():
963
+ return x
964
+ else:
965
+ # a no-op function that will have a node in the autograd graph,
966
+ # to avoid certain bugs relating to backward hooks
967
+ return x.chunk(1, dim=-1)[0]
968
+
969
+
970
class Identity(torch.nn.Module):
    """Module that returns its input unchanged (via _no_op, so the output is
    a fresh autograd-graph node)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return _no_op(x)
976
+
977
+
978
+ class MaxEig(torch.nn.Module):
979
+ """
980
+ Modifies the backpropped derivatives of a function to try to discourage
981
+ that any given direction in activation space accounts for more than
982
+ a specified proportion of the covariance (e.g. 0.2).
983
+
984
+
985
+ Args:
986
+ num_channels: the number of channels
987
+ channel_dim: the dimension/axis corresponding to the channel, e.g.
988
+ -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
989
+ max_var_per_eig: the maximum proportion of the variance of the
990
+ features/channels, after mean subtraction, that can come from
991
+ any given eigenvalue.
992
+ min_prob: the minimum probability with which we apply this during any invocation
993
+ of forward(), assuming last time we applied the constraint it was
994
+ not active; supplied for speed.
995
+ scale: determines the scale with which we modify the gradients, relative
996
+ to the existing / unmodified gradients
997
+ """
998
+
999
+ def __init__(
1000
+ self,
1001
+ num_channels: int,
1002
+ channel_dim: int,
1003
+ max_var_per_eig: float = 0.2,
1004
+ min_prob: float = 0.01,
1005
+ scale: float = 0.01,
1006
+ ):
1007
+ super(MaxEig, self).__init__()
1008
+ self.num_channels = num_channels
1009
+ self.channel_dim = channel_dim
1010
+ self.scale = scale
1011
+ assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
1012
+ self.max_var_per_eig = max_var_per_eig
1013
+
1014
+ # we figure out the dominant direction using the power method: starting with
1015
+ # a random vector, keep multiplying by the covariance and renormalizing.
1016
+ with torch.no_grad():
1017
+ # arbitrary.. would use randn() but want to leave the rest of the model's
1018
+ # random parameters unchanged for comparison
1019
+ direction = torch.arange(num_channels).to(torch.float)
1020
+ direction = direction / direction.norm()
1021
+ self.register_buffer("max_eig_direction", direction)
1022
+
1023
+ self.min_prob = min_prob
1024
+ # cur_prob is the current probability we'll use to apply the ActivationBalancer.
1025
+ # We'll regress this towards prob, each time we try to apply it and it is not
1026
+ # active.
1027
+ self.cur_prob = 1.0
1028
+
1029
+ def forward(self, x: Tensor) -> Tensor:
1030
+ if (
1031
+ torch.jit.is_scripting()
1032
+ or self.max_var_per_eig <= 0
1033
+ or random.random() > self.cur_prob
1034
+ or torch.jit.is_tracing()
1035
+ ):
1036
+ return _no_op(x)
1037
+
1038
+ with torch.cuda.amp.autocast(enabled=False):
1039
+ eps = 1.0e-20
1040
+ orig_x = x
1041
+ x = x.to(torch.float32)
1042
+ with torch.no_grad():
1043
+ x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels)
1044
+ x = x - x.mean(dim=0)
1045
+ new_direction, coeffs = self._find_direction_coeffs(
1046
+ x, self.max_eig_direction
1047
+ )
1048
+ x_var = (x**2).mean()
1049
+ x_residual = x - coeffs * new_direction
1050
+ x_residual_var = (x_residual**2).mean()
1051
+
1052
+ # `variance_proportion` is the proportion of the variance accounted for
1053
+ # by the top eigen-direction.
1054
+ variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
1055
+
1056
+ # ensure new direction is nonzero even if x == 0, by including `direction`.
1057
+ self._set_direction(0.1 * self.max_eig_direction + new_direction)
1058
+
1059
+ if random.random() < 0.01 or __name__ == "__main__":
1060
+ logging.info(
1061
+ f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
1062
+ )
1063
+
1064
+ if variance_proportion >= self.max_var_per_eig:
1065
+ # The constraint is active. Note, we should quite rarely
1066
+ # reach here, only near the beginning of training if we are
1067
+ # starting to diverge, should this constraint be active.
1068
+ cur_prob = self.cur_prob
1069
+ self.cur_prob = 1.0 # next time, do the update with probability 1.0.
1070
+ return MaxEigLimiterFunction.apply(
1071
+ orig_x, coeffs, new_direction, self.channel_dim, self.scale
1072
+ )
1073
+ else:
1074
+ # let self.cur_prob exponentially approach self.min_prob, as
1075
+ # long as the constraint is inactive.
1076
+ self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
1077
+ return orig_x
1078
+
1079
+ def _set_direction(self, direction: Tensor):
1080
+ """
1081
+ Sets self.max_eig_direction to a normalized version of `direction`
1082
+ """
1083
+ direction = direction.detach()
1084
+ direction = direction / direction.norm()
1085
+ direction_sum = direction.sum().item()
1086
+ if direction_sum - direction_sum == 0: # no inf/nan
1087
+ self.max_eig_direction[:] = direction
1088
+ else:
1089
+ logging.info(
1090
+ f"Warning: sum of direction in MaxEig is {direction_sum}, "
1091
+ "num_channels={self.num_channels}, channel_dim={self.channel_dim}"
1092
+ )
1093
+
1094
+ def _find_direction_coeffs(
1095
+ self, x: Tensor, prev_direction: Tensor
1096
+ ) -> Tuple[Tensor, Tensor, Tensor]:
1097
+ """
1098
+ Figure out (an approximation to) the proportion of the variance of a set of
1099
+ feature vectors that can be attributed to the top eigen-direction.
1100
+ Args:
1101
+ x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
1102
+ prev_direction: a Tensor of shape (num_channels,), that is our previous estimate
1103
+ of the top eigen-direction, or a random direction if this is the first
1104
+ iteration. Does not have to be normalized, but should be nonzero.
1105
+
1106
+ Returns: (cur_direction, coeffs), where:
1107
+ cur_direction: a Tensor of shape (num_channels,) that is the current
1108
+ estimate of the top eigen-direction.
1109
+ coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
1110
+ approximately minimizes, (x - coeffs * cur_direction).norm()
1111
+ """
1112
+ (num_frames, num_channels) = x.shape
1113
+ assert num_channels > 1 and num_frames > 1
1114
+ assert prev_direction.shape == (num_channels,)
1115
+ # `coeffs` are the coefficients of `prev_direction` in x.
1116
+ # actually represent the coeffs up to a constant positive factor.
1117
+ coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
1118
+ cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20)
1119
+ return cur_direction, coeffs
1120
+
1121
+
1122
class DoubleSwishFunction(torch.autograd.Function):
    """
    double_swish(x) = x * torch.sigmoid(x-1)
    This is a definition, originally motivated by its close numerical
    similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).

    Memory-efficient derivative computation:
     double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
     double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
     Now, s'(x) = s(x) * (1-s(x)).
     double_swish'(x) = x * s'(x) + s(x).
                      = x * s(x) * (1-s(x)) + s(x).
                      = double_swish(x) * (1-s(x)) + s(x)
     ... so we just need to remember s(x) but not x itself; the derivative is
     further compressed to uint8 with randomized rounding to save memory.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        requires_grad = x.requires_grad
        # NOTE(review): x_dtype is captured but not used below — looks vestigial.
        x_dtype = x.dtype
        if x.dtype == torch.float16:
            x = x.to(torch.float32)

        s = torch.sigmoid(x - 1.0)
        y = x * s

        if requires_grad:
            deriv = y * (1 - s) + s
            # notes on derivative of x * sigmoid(x - 1):
            # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
            # min \simeq -0.043638.  Take floor as -0.043637 so it's a lower bound.
            # max \simeq 1.1990.   Take ceil to be 1.2 so it's an upper bound.
            # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
            # floors), should be expectation-preserving.
            floor = -0.043637
            ceil = 1.2
            d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
                deriv
            )
            if __name__ == "__main__":
                # for self-testing only.
                assert d_scaled.min() >= 0.0
                assert d_scaled.max() < 256.0
            d_int = d_scaled.to(torch.uint8)
            ctx.save_for_backward(d_int)
        if x.dtype == torch.float16 or torch.is_autocast_enabled():
            y = y.to(torch.float16)
        return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        (d,) = ctx.saved_tensors
        # the same constants as used in forward pass: de-quantize the
        # uint8-encoded derivative back to its float range.
        floor = -0.043637
        ceil = 1.2
        d = d * ((ceil - floor) / 255.0) + floor
        return y_grad * d
1179
+
1180
+
1181
class DoubleSwish(torch.nn.Module):
    """Module form of double-swish, x * sigmoid(x - 1)."""

    def forward(self, x: Tensor) -> Tensor:
        """Return double-swish activation function which is an approximation to Swish(Swish(x)),
        that we approximate closely with x * sigmoid(x-1).
        """
        under_jit = torch.jit.is_scripting() or torch.jit.is_tracing()
        if under_jit:
            # TorchScript cannot trace the custom autograd function, so fall
            # back to the plain (more memory-hungry) expression.
            return x * torch.sigmoid(x - 1.0)
        return DoubleSwishFunction.apply(x)
1189
+
1190
+
1191
def BalancedDoubleSwish(
    d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
    """
    ActivationBalancer -> DoubleSwish
    """
    # Compose the balancer and the activation into one sequential unit so it
    # can drop in wherever a single activation module is expected.
    return nn.Sequential(
        ActivationBalancer(
            d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
        ),
        DoubleSwish(),
    )
1204
+
1205
+
1206
def _test_max_eig():
    """Smoke-test MaxEig: a weakly planted direction should leave gradients
    (almost) untouched, while a dominant one must modify them."""
    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"proportion = {proportion}")
        feats = torch.randn(100, 128)
        planted_dir = torch.randn(128)
        planted_coeffs = torch.randn(100, 1)
        feats += proportion * planted_dir * planted_coeffs

        feats.requires_grad = True

        module = MaxEig(
            128,  # num_channels
            1,  # channel_dim
            0.5,  # max_var_per_eig
            scale=0.1,  # grad_scale
        )

        out = feats
        for _ in range(4):
            out = module(feats)

        out_grad = torch.randn_like(feats)
        out.backward(gradient=out_grad)

        if proportion < 0.2:
            assert torch.allclose(feats.grad, out_grad, atol=1.0e-02)
        elif proportion > 1.0:
            assert not torch.allclose(feats.grad, out_grad)
1234
+
1235
+
1236
def _test_whiten():
    """Smoke-test Whiten: gradients pass through unchanged unless a planted
    direction dominates the covariance."""
    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"_test_whiten(): proportion = {proportion}")
        feats = torch.randn(100, 128)
        planted_dir = torch.randn(128)
        planted_coeffs = torch.randn(100, 1)
        feats += proportion * planted_dir * planted_coeffs

        feats.requires_grad = True

        module = Whiten(
            1,  # num_groups
            5.0,  # whitening_limit
            prob=1.0,
            grad_scale=0.1,  # grad_scale
        )

        out = feats
        for _ in range(4):
            out = module(feats)

        out_grad = torch.randn_like(feats)
        out.backward(gradient=out_grad)

        if proportion < 0.2:
            assert torch.allclose(feats.grad, out_grad)
        elif proportion > 1.0:
            assert not torch.allclose(feats.grad, out_grad)
1264
+
1265
+
1266
def _test_activation_balancer_sign():
    """Visual check of ActivationBalancer's sign-balancing gradient nudges
    (prints values for manual inspection; no assertions)."""
    probs = torch.arange(0, 1, 0.01)
    N = 1000
    # Each channel's entries are +/-1, positive with probability probs[c].
    x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
    x = x.detach()
    x.requires_grad = True

    balancer = ActivationBalancer(
        probs.numel(),
        channel_dim=0,
        min_positive=0.05,
        max_positive=0.95,
        max_factor=0.2,
        min_abs=0.0,
    )

    y_grad = torch.sign(torch.randn(probs.numel(), N))

    balancer(x).backward(gradient=y_grad)
    print("_test_activation_balancer_sign: x = ", x)
    print("_test_activation_balancer_sign: y grad = ", y_grad)
    print("_test_activation_balancer_sign: x grad = ", x.grad)
1288
+
1289
+
1290
def _test_activation_balancer_magnitude():
    """Visual check of ActivationBalancer's magnitude constraints
    (prints values for manual inspection; no assertions)."""
    magnitudes = torch.arange(0, 1, 0.01)
    N = 1000
    # Random signs scaled so channel c has magnitude magnitudes[c].
    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
    x = x.detach()
    x.requires_grad = True

    balancer = ActivationBalancer(
        magnitudes.numel(),
        channel_dim=0,
        min_positive=0.0,
        max_positive=1.0,
        max_factor=0.2,
        min_abs=0.2,
        max_abs=0.8,
        min_prob=1.0,
    )

    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))

    balancer(x).backward(gradient=y_grad)
    print("_test_activation_balancer_magnitude: x = ", x)
    print("_test_activation_balancer_magnitude: y grad = ", y_grad)
    print("_test_activation_balancer_magnitude: x grad = ", x.grad)
1314
+
1315
+
1316
def _test_basic_norm():
    """BasicNorm must preserve shape and reduce the RMS of unit-variance
    input, but by less than a factor of two."""
    num_channels = 128
    layer = BasicNorm(num_channels=num_channels, channel_dim=1)

    x = torch.randn(500, num_channels)
    y = layer(x)

    assert y.shape == x.shape
    x_rms = (x**2).mean().sqrt()
    y_rms = (y**2).mean().sqrt()
    print("x rms = ", x_rms)
    print("y rms = ", y_rms)
    assert y_rms < x_rms
    assert y_rms > 0.5 * x_rms
1331
+
1332
+
1333
def _test_double_swish_deriv():
    """Check DoubleSwish's quantized backward against numerical gradients,
    within the tolerance of one uint8 quantization bucket."""
    activation = DoubleSwish()

    x = torch.randn(10, 12, dtype=torch.double) * 3.0
    x.requires_grad = True
    # One quantization step of the stored derivative range.
    tol = (1.2 - (-0.043637)) / 255.0
    torch.autograd.gradcheck(activation, x, atol=tol)

    # for self-test.
    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
    x.requires_grad = True
    y = activation(x)
1345
+
1346
+
1347
def _test_softmax():
    """Compare the module's custom softmax against torch's built-in by
    backpropagating through the first column's sum on identical inputs."""
    a = torch.randn(2, 10, dtype=torch.float64)
    b = a.clone()
    a.requires_grad = True
    b.requires_grad = True

    # Reference path: built-in softmax.
    a.softmax(dim=1)[:, 0].sum().backward()
    print("a grad = ", a.grad)

    # Candidate path: this module's softmax.
    softmax(b, dim=1)[:, 0].sum().backward()
    print("b grad = ", b.grad)

    assert torch.allclose(a.grad, b.grad)
1357
+
1358
+
1359
+ if __name__ == "__main__":
1360
+ logging.getLogger().setLevel(logging.INFO)
1361
+ torch.set_num_threads(1)
1362
+ torch.set_num_interop_threads(1)
1363
+ _test_softmax()
1364
+ _test_whiten()
1365
+ _test_max_eig()
1366
+ _test_activation_balancer_sign()
1367
+ _test_activation_balancer_magnitude()
1368
+ _test_basic_norm()
1369
+ _test_double_swish_deriv()
apps/audio_cloning/vallex/modules/scheduler.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Copyright 2023 (authors: Feiteng Li)
3
+ #
4
+ # See ../../../../LICENSE for clarification regarding multiple authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ import torch
20
+
21
+ from modules.optim import Eden
22
+
23
+
24
def calc_lr(step, dim_embed, warmup_steps):
    """Return the Noam learning-rate multiplier for a training step.

    Implements lr = dim_embed^-0.5 * min(step^-0.5, step * warmup_steps^-1.5):
    linear warmup for ``warmup_steps`` steps, then inverse-square-root decay
    ("Attention Is All You Need", eq. 3).

    Args:
        step: 1-based global training step.  Values < 1 are clamped to 1;
            previously step 0 raised ZeroDivisionError from ``0 ** -0.5``.
        dim_embed: model/embedding dimension providing the global scale.
        warmup_steps: number of linear-warmup steps.
    """
    step = max(step, 1)  # guard: 0 ** (-0.5) raises ZeroDivisionError
    return dim_embed ** (-0.5) * min(
        step ** (-0.5), step * warmup_steps ** (-1.5)
    )
28
+
29
+
30
class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Noam ("Attention Is All You Need") learning-rate schedule.

    The LR grows linearly for ``warmup_steps`` steps, then decays with the
    inverse square root of the step count, scaled by ``base_lr`` and
    ``dim_embed ** -0.5`` (see ``calc_lr``).
    """

    def __init__(
        self,
        base_lr: float,
        optimizer: torch.optim.Optimizer,
        dim_embed: int,
        warmup_steps: int,
        last_epoch: int = -1,
        verbose: bool = False,
    ) -> None:

        self.dim_embed = dim_embed
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        # get_lr() returns one LR per optimizer param group.
        self.num_param_groups = len(optimizer.param_groups)

        # NOTE: the base-class constructor invokes self.get_lr(), so all
        # attributes read by get_lr() must be assigned before this call.
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self) -> list[float]:
        # _step_count is maintained by the base class (starts at 1 and is
        # incremented by each .step() call).
        lr = self.base_lr * calc_lr(
            self._step_count, self.dim_embed, self.warmup_steps
        )
        return [lr] * self.num_param_groups

    def set_step(self, step: int):
        # Jump the schedule to an arbitrary global step, e.g. when resuming
        # training from a checkpoint.
        self._step_count = step
56
+
57
+
58
def get_scheduler(params, optimizer):
    """Build the LR scheduler selected by ``params.scheduler_name``.

    Supported names (case-insensitive): "eden", "noam", "cosine".

    Args:
        params: config namespace providing ``scheduler_name``,
            ``warmup_steps``, ``base_lr`` and (for "noam") ``decoder_dim``.
        optimizer: the optimizer whose learning rate is scheduled.

    Returns:
        The constructed scheduler instance.

    Raises:
        NotImplementedError: if ``params.scheduler_name`` is unknown.
    """
    name = params.scheduler_name.lower()
    if name == "eden":
        scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
    elif name == "noam":
        scheduler = NoamScheduler(
            params.base_lr,
            optimizer,
            params.decoder_dim,
            warmup_steps=params.warmup_steps,
        )
        # scheduler.set_step(params.start_batch or params.batch_idx_train)
    elif name == "cosine":
        # BUG FIX: CosineAnnealingLR's signature is (optimizer, T_max, eta_min);
        # the optimizer and T_max arguments were previously swapped, which
        # crashed as soon as the "cosine" branch was taken.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=params.warmup_steps,
            eta_min=params.base_lr,
        )
    else:
        raise NotImplementedError(f"{params.scheduler_name}")

    return scheduler
apps/audio_cloning/vallex/modules/transformer.py ADDED
@@ -0,0 +1,683 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import numbers
3
+ from functools import partial
4
+ from typing import Any, Callable, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ from torch import Tensor, nn
8
+ from torch.nn import functional as F
9
+
10
+ from .activation import MultiheadAttention
11
+ from .scaling import ActivationBalancer, BalancedDoubleSwish
12
+ from .scaling import BasicNorm as _BasicNorm
13
+
14
+ _shape_t = Union[int, List[int], torch.Size]
15
+
16
+
17
class LayerNorm(nn.Module):
    """Layer normalization that also accepts a ``(input, embedding)`` tuple.

    Behaves like ``nn.LayerNorm`` on a plain tensor.  When given a tuple
    (the convention used between stages in this codebase's transformer
    layers), the first element is normalized and the tuple is re-formed so
    the stage embedding flows through sequential stacks unchanged.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
            self.bias = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Identity affine initialization, matching nn.LayerNorm.
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            # Tuple form: normalize the tensor, pass the embedding through.
            input, embedding = input
            return (
                F.layer_norm(
                    input,
                    self.normalized_shape,
                    self.weight,
                    self.bias,
                    self.eps,
                ),
                embedding,
            )

        # Plain-tensor form: a separate embedding argument is not supported.
        assert embedding is None
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )

    def extra_repr(self) -> str:
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )
81
+
82
+
83
class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization.

    Wraps an inner ``norm`` and applies a conditioning-dependent affine
    transform on top: scale and shift are predicted from ``embedding`` by a
    single linear projection (AdaLN-style conditioning).
    """

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        # Projects the conditioning embedding to (scale, shift), each of
        # width d_model; split along the last dim in forward().
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        if isinstance(input, tuple):
            # Tuple form: unpack, condition on the embedding, re-wrap.
            input, embedding = input
            weight, bias = torch.split(
                self.project_layer(embedding),
                split_size_or_sections=self.d_model,
                dim=-1,
            )
            return (weight * self.norm(input) + bias, embedding)

        # NOTE(review): this branch requires a non-None ``embedding``
        # (project_layer(None) would fail) — confirm callers always pass one.
        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        return weight * self.norm(input) + bias
109
+
110
+
111
class BasicNorm(_BasicNorm):
    """Tuple-aware adapter around the scaling module's ``BasicNorm``.

    Accepts the ``(d_model, eps, device, dtype)`` signature used by the
    other norm classes in this file and passes a stage embedding through
    unchanged when the input is a ``(tensor, embedding)`` tuple.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        # device/dtype are accepted for signature compatibility only; the
        # parent class does not take them.
        super(BasicNorm, self).__init__(d_model, eps=eps)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            # Tuple form: normalize the tensor, pass the embedding through.
            input, embedding = input
            return (
                super(BasicNorm, self).forward(input),
                embedding,
            )

        assert embedding is None
        return super(BasicNorm, self).forward(input)
131
+
132
+
133
class BalancedBasicNorm(nn.Module):
    """``ActivationBalancer`` followed by ``BasicNorm``.

    The balancer constrains the activation distribution (sign balance and
    magnitude) before BasicNorm normalizes it.  The tuple form threads a
    stage embedding through both submodules untouched.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        super(BalancedBasicNorm, self).__init__()
        self.balancer = ActivationBalancer(
            d_model,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            max_abs=6.0,
        )
        self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            input, embedding = input
            # Re-wrap as a tuple so BasicNorm forwards the embedding too.
            return self.norm((self.balancer(input), embedding))

        assert embedding is None
        return self.norm(self.balancer(input))
158
+
159
+
160
class IdentityNorm(nn.Module):
    """No-op stand-in for a norm layer.

    ``d_model`` and ``eps`` are accepted only so the constructor matches the
    other norm classes in this file; the input (tensor or tuple) is returned
    exactly as received.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        super(IdentityNorm, self).__init__()

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        # Only the plain-tensor form forbids a separate embedding argument;
        # the tuple form carries its embedding inside the tuple.
        if not isinstance(input, tuple):
            assert embedding is None
        return input
176
+
177
+
178
class TransformerEncoderLayer(nn.Module):
    """Transformer encoder layer with pluggable linear and norm classes.

    Differences from ``nn.TransformerEncoderLayer``:
      * the self-attention projections and feed-forward linears can be
        overridden via ``linear*_cls`` arguments,
      * the norms are configurable and may be wrapped in
        ``AdaptiveLayerNorm`` to condition on a stage embedding carried in
        ``(tensor, embedding)`` tuples,
      * ``infer`` provides a forward path with key/value caching for
        incremental decoding.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )

        # Implementation of Feedforward model
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)
        elif isinstance(activation, partial):
            # A partial is assumed to build an activation module from
            # d_model (e.g. partial(BalancedDoubleSwish)).
            activation = activation(d_model)
        elif activation == BalancedDoubleSwish:
            # The bare factory function was passed: instantiate it.
            activation = BalancedDoubleSwish(d_model)

        # # We can't test self.activation in forward() in TorchScript,
        # # so stash some information about it instead.
        # if activation is F.relu or isinstance(activation, torch.nn.ReLU):
        #     self.activation_relu_or_gelu = 1
        # elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
        #     self.activation_relu_or_gelu = 2
        # else:
        #     self.activation_relu_or_gelu = 0
        self.activation = activation

        norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
        if layer_norm_cls == IdentityNorm:
            # NOTE(review): with IdentityNorm requested, norm2 is replaced by
            # BalancedBasicNorm — presumably to keep pre-FFN activations
            # bounded; confirm this asymmetry is intended.
            norm2 = BalancedBasicNorm(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
        else:
            norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )

        if adaptive_layer_norm:
            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
        else:
            self.norm1 = norm1
            self.norm2 = norm2

    def __setstate__(self, state):
        # Checkpoints pickled before `activation` existed default to ReLU.
        super(TransformerEncoderLayer, self).__setstate__(state)
        if not hasattr(self, "activation"):
            self.activation = F.relu

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).  May be a
                ``(tensor, stage_embedding)`` tuple; the tuple form is
                returned in kind.
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        x, stage_embedding = src, None
        is_src_tuple = False
        if isinstance(src, tuple):
            x, stage_embedding = src
            is_src_tuple = True

        if src_key_padding_mask is not None:
            _skpm_dtype = src_key_padding_mask.dtype
            if _skpm_dtype != torch.bool and not torch.is_floating_point(
                src_key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )

        if self.norm_first:
            # Pre-norm: normalize, transform, then add the residual.
            x = x + self._sa_block(
                self.norm1(x, stage_embedding),
                src_mask,
                src_key_padding_mask,
            )
            x = x + self._ff_block(self.norm2(x, stage_embedding))
        else:
            # Post-norm: add the residual first, then normalize.
            x = self.norm1(
                x + self._sa_block(x, src_mask, src_key_padding_mask),
                stage_embedding,
            )
            x = self.norm2(x + self._ff_block(x), stage_embedding)

        if is_src_tuple:
            return (x, stage_embedding)
        return x

    def infer(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        past_kv: Optional[Tensor] = None,
        use_cache: bool = False,
    ):
        """Forward pass with key/value caching for incremental decoding.

        Returns ``(x, kv)`` — or ``(x, stage_embedding)`` when ``src`` was a
        tuple — where ``kv`` is this layer's updated attention cache.
        NOTE(review): only the pre-norm (``norm_first``) path is implemented
        here; confirm post-norm configurations never call infer().
        """
        x, stage_embedding = src, None
        is_src_tuple = False
        if isinstance(src, tuple):
            x, stage_embedding = src
            is_src_tuple = True

        if src_key_padding_mask is not None:
            _skpm_dtype = src_key_padding_mask.dtype
            if _skpm_dtype != torch.bool and not torch.is_floating_point(
                src_key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )

        if self.norm_first:
            x_attn_out, kv = self.self_attn.infer(
                self.norm1(x, stage_embedding),
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask,
                need_weights=False,
                past_kv=past_kv,
                use_cache=use_cache,
            )
            x = x + x_attn_out
            x = x + self._ff_block(self.norm2(x, stage_embedding))

        if is_src_tuple:
            return (x, stage_embedding)
        return (x, kv)

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
    ) -> Tensor:
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
374
+
375
+
376
class TransformerEncoder(nn.Module):
    r"""TransformerEncoder is a stack of N encoder layers. Users can build the
    BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
        enable_nested_tensor: if True, input will automatically convert to nested tensor
            (and convert back on output). This will improve the overall performance of
            TransformerEncoder when padding rate is high. Default: ``True`` (enabled).

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ["norm"]

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        # Each layer is an independent deep copy of the prototype layer.
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_layer_states: bool = False,
    ) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
            return_layer_states: return layers' state (optional).

        Shape:
            see the docs in Transformer class.
        """
        if return_layer_states:
            layer_states = []  # layers' output
            output = src
            for mod in self.layers:
                output = mod(
                    output,
                    src_mask=mask,
                    src_key_padding_mask=src_key_padding_mask,
                )
                # NOTE(review): output[0] assumes each layer returned a
                # (tensor, stage_embedding) tuple, i.e. `src` was a tuple —
                # confirm callers of return_layer_states pass tuple inputs.
                layer_states.append(output[0])

            if self.norm is not None:
                output = self.norm(output)

            return layer_states, output

        output = src
        for mod in self.layers:
            output = mod(
                output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
            )

        if self.norm is not None:
            output = self.norm(output)

        return output

    def infer(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_layer_states: bool = False,
        past_kv: Optional[Tensor] = None,
        use_cache: bool = False,
    ):
        """Incremental forward pass with per-layer key/value caching.

        Returns ``(output, new_kv)``: ``new_kv`` is a tuple of per-layer KV
        caches when ``use_cache`` is True, otherwise None.
        """
        if past_kv is None:
            past_length = 0
            # One (initially empty) cache slot per layer.
            past_kv = tuple([None] * self.num_layers)
        else:
            # NOTE(review): past_length is computed but not used below.
            past_length = past_kv[0][0].size(-2)
        new_kv = () if use_cache else None
        output = src
        for mod, past_layer_kv in zip(self.layers, past_kv):
            output, kv = mod.infer(
                output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past_kv=past_layer_kv, use_cache=use_cache
            )
            if use_cache:
                new_kv = new_kv + (kv,)

        if self.norm is not None:
            output = self.norm(output)

        return output, new_kv
474
+
475
+
476
class TransformerDecoderLayer(nn.Module):
    """Transformer decoder layer with pluggable linear and norm classes.

    Mirrors ``nn.TransformerDecoderLayer`` (self-attention, cross-attention
    over ``memory``, feed-forward) but allows overriding the linear layers
    and norms, including ``AdaptiveLayerNorm`` conditioning on a stage
    embedding carried in a ``(tensor, embedding)`` tuple for ``tgt``.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )
        # Cross-attention over the encoder memory.
        self.multihead_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )
        # Implementation of Feedforward model
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        elif isinstance(activation, partial):
            # A partial is assumed to build an activation module from d_model.
            self.activation = activation(d_model)
        elif activation == BalancedDoubleSwish:
            self.activation = BalancedDoubleSwish(d_model)
        else:
            self.activation = activation

        if adaptive_layer_norm:
            norm1 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            norm3 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )

            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
            self.norm3 = AdaptiveLayerNorm(d_model, norm3)
        else:
            self.norm1 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            self.norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            if layer_norm_cls == IdentityNorm:
                # NOTE(review): with IdentityNorm requested, norm3 is swapped
                # for BalancedBasicNorm — presumably to keep pre-FFN
                # activations bounded; confirm this asymmetry is intended.
                self.norm3 = BalancedBasicNorm(
                    d_model, eps=layer_norm_eps, **factory_kwargs
                )
            else:
                self.norm3 = layer_norm_cls(
                    d_model, eps=layer_norm_eps, **factory_kwargs
                )

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).  May be a
                ``(tensor, stage_embedding)`` tuple; the tuple form is
                returned in kind.
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        tgt_is_tuple = False
        if isinstance(tgt, tuple):
            x, stage_embedding = tgt
            tgt_is_tuple = True
        else:
            x, stage_embedding = tgt, None

        if self.norm_first:
            # Pre-norm: self-attention, cross-attention, then feed-forward,
            # each with a residual connection around the normalized input.
            x = x + self._sa_block(
                self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
            )
            x = x + self._mha_block(
                self.norm2(x, stage_embedding),
                memory,
                memory_mask,
                memory_key_padding_mask,
            )
            x = x + self._ff_block(self.norm3(x, stage_embedding))
        else:
            # Post-norm: normalize after each residual addition.
            x = self.norm1(
                x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
                stage_embedding,
            )
            x = self.norm2(
                x
                + self._mha_block(
                    x, memory, memory_mask, memory_key_padding_mask
                ),
                stage_embedding,
            )
            x = self.norm3(x + self._ff_block(x), stage_embedding)

        if tgt_is_tuple:
            return (x, stage_embedding)
        return x

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
    ) -> Tensor:
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(x)

    # multihead attention block
    def _mha_block(
        self,
        x: Tensor,
        mem: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
    ) -> Tensor:
        x = self.multihead_attn(
            x,
            mem,
            mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout2(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)
669
+
670
+
671
+ def _get_clones(module, N):
672
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
673
+
674
+
675
+ def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
676
+ if activation == "relu":
677
+ return F.relu
678
+ elif activation == "gelu":
679
+ return F.gelu
680
+
681
+ raise RuntimeError(
682
+ "activation should be relu/gelu, not {}".format(activation)
683
+ )
apps/audio_cloning/vallex/presets/acou_1.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:470ce66fc24a2d14e162343381f7d93ef0a3af51edf5fd37240c21f492b4e769
3
+ size 15650
apps/audio_cloning/vallex/presets/acou_2.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec1c5328751cadeed5356d4264759799ad96d33ea8dd4f8a3d0a80dd8ddb0e74
3
+ size 15426
apps/audio_cloning/vallex/presets/acou_3.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03f241b094a32b3f542e74374183c6d15e8b70ae73ceeafb11bfd4ee6b8b4a3a
3
+ size 15410
apps/audio_cloning/vallex/presets/acou_4.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52b96f32863f13f84cf7ac4a27d2bc95cea70c350a037f4d1890b20b8da9501e
3
+ size 15506
apps/audio_cloning/vallex/presets/alan.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28838c3f0b2f9f315b34e9b940f30641306f0cadc5c527857cd1cc408547ed1c
3
+ size 50002
apps/audio_cloning/vallex/presets/amused.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df3e882f3a62805b9aaf300d81822cd4eddeafee480503b7b78e32be2085fb11
3
+ size 20882
apps/audio_cloning/vallex/presets/anger.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:959cec6dc0b30219db0d70cdd165fe00bbdc098165cf9d67ccdd1ecf7a5da5be
3
+ size 22090
apps/audio_cloning/vallex/presets/babara.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8106b2a98c3f70587f23ab46ed5bf73b1c9a770481c3620ab140bd3256010376
3
+ size 11526
apps/audio_cloning/vallex/presets/bronya.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02eaada2c3d58866c813887ed9f871587ef5a7e976abc23382ce46a17b208001
3
+ size 18106
apps/audio_cloning/vallex/presets/cafe.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d78d96f5829da8f69c327ff25958da5b451305fdc9c308f7e67f13cf8d640fea
3
+ size 22442
apps/audio_cloning/vallex/presets/dingzhen.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4d19167c65eefef5e42dfaa1919ff5149ca0a93cb052396a47d1f42f9865f5f8
3
+ size 18154
apps/audio_cloning/vallex/presets/disgust.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4443f0a395072700f2ec6101dbf2ad9d28968aa3e5809e384ea131832f894d7f
3
+ size 39386