Spaces:
Sleeping
Sleeping
nukopy commited on
Commit ·
5becb6b
1
Parent(s): 725afe9
Add audio files with Git LFS support
Browse files- Add .wav, .ogg, .mp3, .flac to .gitattributes for LFS tracking
- Migrate existing audio files to Git LFS
- This resolves the binary file rejection from Hugging Face Spaces
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +5 -0
- .gitignore +219 -0
- Makefile +5 -0
- README.md +2 -0
- app.py +18 -0
- apps/audio_cloning/main.py +15 -0
- apps/audio_cloning/vallex/data/__init__.py +3 -0
- apps/audio_cloning/vallex/data/collation.py +107 -0
- apps/audio_cloning/vallex/data/datamodule.py +419 -0
- apps/audio_cloning/vallex/data/dataset.py +242 -0
- apps/audio_cloning/vallex/data/fbank.py +212 -0
- apps/audio_cloning/vallex/data/input_strategies.py +159 -0
- apps/audio_cloning/vallex/data/symbol_table.py +289 -0
- apps/audio_cloning/vallex/data/tokenizer.py +121 -0
- apps/audio_cloning/vallex/descriptions.py +34 -0
- apps/audio_cloning/vallex/examples.py +108 -0
- apps/audio_cloning/vallex/g2p/__init__.py +84 -0
- apps/audio_cloning/vallex/g2p/bpe_1024.json +2049 -0
- apps/audio_cloning/vallex/g2p/bpe_69.json +141 -0
- apps/audio_cloning/vallex/g2p/cleaners.py +76 -0
- apps/audio_cloning/vallex/g2p/english.py +197 -0
- apps/audio_cloning/vallex/g2p/japanese.py +173 -0
- apps/audio_cloning/vallex/g2p/mandarin.py +337 -0
- apps/audio_cloning/vallex/g2p/symbols.py +76 -0
- apps/audio_cloning/vallex/macros.py +34 -0
- apps/audio_cloning/vallex/main.py +461 -0
- apps/audio_cloning/vallex/models/__init__.py +127 -0
- apps/audio_cloning/vallex/models/macros.py +11 -0
- apps/audio_cloning/vallex/models/transformer.py +386 -0
- apps/audio_cloning/vallex/models/vallex.py +823 -0
- apps/audio_cloning/vallex/models/visualizer.py +102 -0
- apps/audio_cloning/vallex/modules/__init__.py +0 -0
- apps/audio_cloning/vallex/modules/activation.py +612 -0
- apps/audio_cloning/vallex/modules/embedding.py +97 -0
- apps/audio_cloning/vallex/modules/optim.py +1105 -0
- apps/audio_cloning/vallex/modules/scaling.py +1369 -0
- apps/audio_cloning/vallex/modules/scheduler.py +78 -0
- apps/audio_cloning/vallex/modules/transformer.py +683 -0
- apps/audio_cloning/vallex/presets/acou_1.npz +3 -0
- apps/audio_cloning/vallex/presets/acou_2.npz +3 -0
- apps/audio_cloning/vallex/presets/acou_3.npz +3 -0
- apps/audio_cloning/vallex/presets/acou_4.npz +3 -0
- apps/audio_cloning/vallex/presets/alan.npz +3 -0
- apps/audio_cloning/vallex/presets/amused.npz +3 -0
- apps/audio_cloning/vallex/presets/anger.npz +3 -0
- apps/audio_cloning/vallex/presets/babara.npz +3 -0
- apps/audio_cloning/vallex/presets/bronya.npz +3 -0
- apps/audio_cloning/vallex/presets/cafe.npz +3 -0
- apps/audio_cloning/vallex/presets/dingzhen.npz +3 -0
- apps/audio_cloning/vallex/presets/disgust.npz +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
# Audio files
|
| 37 |
+
*.wav filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*.ogg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
*.mp3 filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
*.flac filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# pretrained models
|
| 2 |
+
models/**/*.pt
|
| 3 |
+
|
| 4 |
+
# Byte-compiled / optimized / DLL files
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.py[codz]
|
| 7 |
+
*$py.class
|
| 8 |
+
|
| 9 |
+
# C extensions
|
| 10 |
+
*.so
|
| 11 |
+
|
| 12 |
+
# Distribution / packaging
|
| 13 |
+
.Python
|
| 14 |
+
build/
|
| 15 |
+
develop-eggs/
|
| 16 |
+
dist/
|
| 17 |
+
downloads/
|
| 18 |
+
eggs/
|
| 19 |
+
.eggs/
|
| 20 |
+
lib/
|
| 21 |
+
lib64/
|
| 22 |
+
parts/
|
| 23 |
+
sdist/
|
| 24 |
+
var/
|
| 25 |
+
wheels/
|
| 26 |
+
share/python-wheels/
|
| 27 |
+
*.egg-info/
|
| 28 |
+
.installed.cfg
|
| 29 |
+
*.egg
|
| 30 |
+
MANIFEST
|
| 31 |
+
|
| 32 |
+
# PyInstaller
|
| 33 |
+
# Usually these files are written by a python script from a template
|
| 34 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 35 |
+
*.manifest
|
| 36 |
+
*.spec
|
| 37 |
+
|
| 38 |
+
# Installer logs
|
| 39 |
+
pip-log.txt
|
| 40 |
+
pip-delete-this-directory.txt
|
| 41 |
+
|
| 42 |
+
# Unit test / coverage reports
|
| 43 |
+
htmlcov/
|
| 44 |
+
.tox/
|
| 45 |
+
.nox/
|
| 46 |
+
.coverage
|
| 47 |
+
.coverage.*
|
| 48 |
+
.cache
|
| 49 |
+
nosetests.xml
|
| 50 |
+
coverage.xml
|
| 51 |
+
*.cover
|
| 52 |
+
*.py.cover
|
| 53 |
+
.hypothesis/
|
| 54 |
+
.pytest_cache/
|
| 55 |
+
cover/
|
| 56 |
+
|
| 57 |
+
# Translations
|
| 58 |
+
*.mo
|
| 59 |
+
*.pot
|
| 60 |
+
|
| 61 |
+
# Django stuff:
|
| 62 |
+
*.log
|
| 63 |
+
local_settings.py
|
| 64 |
+
db.sqlite3
|
| 65 |
+
db.sqlite3-journal
|
| 66 |
+
|
| 67 |
+
# Flask stuff:
|
| 68 |
+
instance/
|
| 69 |
+
.webassets-cache
|
| 70 |
+
|
| 71 |
+
# Scrapy stuff:
|
| 72 |
+
.scrapy
|
| 73 |
+
|
| 74 |
+
# Sphinx documentation
|
| 75 |
+
docs/_build/
|
| 76 |
+
|
| 77 |
+
# PyBuilder
|
| 78 |
+
.pybuilder/
|
| 79 |
+
target/
|
| 80 |
+
|
| 81 |
+
# Jupyter Notebook
|
| 82 |
+
.ipynb_checkpoints
|
| 83 |
+
|
| 84 |
+
# IPython
|
| 85 |
+
profile_default/
|
| 86 |
+
ipython_config.py
|
| 87 |
+
|
| 88 |
+
# pyenv
|
| 89 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 90 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 91 |
+
# .python-version
|
| 92 |
+
|
| 93 |
+
# pipenv
|
| 94 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 95 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 96 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 97 |
+
# install all needed dependencies.
|
| 98 |
+
# Pipfile.lock
|
| 99 |
+
|
| 100 |
+
# UV
|
| 101 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 102 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 103 |
+
# commonly ignored for libraries.
|
| 104 |
+
# uv.lock
|
| 105 |
+
|
| 106 |
+
# poetry
|
| 107 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 108 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 109 |
+
# commonly ignored for libraries.
|
| 110 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 111 |
+
# poetry.lock
|
| 112 |
+
# poetry.toml
|
| 113 |
+
|
| 114 |
+
# pdm
|
| 115 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 116 |
+
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
| 117 |
+
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
| 118 |
+
# pdm.lock
|
| 119 |
+
# pdm.toml
|
| 120 |
+
.pdm-python
|
| 121 |
+
.pdm-build/
|
| 122 |
+
|
| 123 |
+
# pixi
|
| 124 |
+
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
| 125 |
+
# pixi.lock
|
| 126 |
+
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
| 127 |
+
# in the .venv directory. It is recommended not to include this directory in version control.
|
| 128 |
+
.pixi
|
| 129 |
+
|
| 130 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 131 |
+
__pypackages__/
|
| 132 |
+
|
| 133 |
+
# Celery stuff
|
| 134 |
+
celerybeat-schedule
|
| 135 |
+
celerybeat.pid
|
| 136 |
+
|
| 137 |
+
# Redis
|
| 138 |
+
*.rdb
|
| 139 |
+
*.aof
|
| 140 |
+
*.pid
|
| 141 |
+
|
| 142 |
+
# RabbitMQ
|
| 143 |
+
mnesia/
|
| 144 |
+
rabbitmq/
|
| 145 |
+
rabbitmq-data/
|
| 146 |
+
|
| 147 |
+
# ActiveMQ
|
| 148 |
+
activemq-data/
|
| 149 |
+
|
| 150 |
+
# SageMath parsed files
|
| 151 |
+
*.sage.py
|
| 152 |
+
|
| 153 |
+
# Environments
|
| 154 |
+
.env
|
| 155 |
+
.envrc
|
| 156 |
+
.venv
|
| 157 |
+
env/
|
| 158 |
+
venv/
|
| 159 |
+
ENV/
|
| 160 |
+
env.bak/
|
| 161 |
+
venv.bak/
|
| 162 |
+
|
| 163 |
+
# Spyder project settings
|
| 164 |
+
.spyderproject
|
| 165 |
+
.spyproject
|
| 166 |
+
|
| 167 |
+
# Rope project settings
|
| 168 |
+
.ropeproject
|
| 169 |
+
|
| 170 |
+
# mkdocs documentation
|
| 171 |
+
/site
|
| 172 |
+
|
| 173 |
+
# mypy
|
| 174 |
+
.mypy_cache/
|
| 175 |
+
.dmypy.json
|
| 176 |
+
dmypy.json
|
| 177 |
+
|
| 178 |
+
# Pyre type checker
|
| 179 |
+
.pyre/
|
| 180 |
+
|
| 181 |
+
# pytype static type analyzer
|
| 182 |
+
.pytype/
|
| 183 |
+
|
| 184 |
+
# Cython debug symbols
|
| 185 |
+
cython_debug/
|
| 186 |
+
|
| 187 |
+
# PyCharm
|
| 188 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 189 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 190 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 191 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 192 |
+
# .idea/
|
| 193 |
+
|
| 194 |
+
# Abstra
|
| 195 |
+
# Abstra is an AI-powered process automation framework.
|
| 196 |
+
# Ignore directories containing user credentials, local state, and settings.
|
| 197 |
+
# Learn more at https://abstra.io/docs
|
| 198 |
+
.abstra/
|
| 199 |
+
|
| 200 |
+
# Visual Studio Code
|
| 201 |
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
| 202 |
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
| 203 |
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
| 204 |
+
# you could uncomment the following to ignore the entire vscode folder
|
| 205 |
+
# .vscode/
|
| 206 |
+
|
| 207 |
+
# Ruff stuff:
|
| 208 |
+
.ruff_cache/
|
| 209 |
+
|
| 210 |
+
# PyPI configuration file
|
| 211 |
+
.pypirc
|
| 212 |
+
|
| 213 |
+
# Marimo
|
| 214 |
+
marimo/_static/
|
| 215 |
+
marimo/_lsp/
|
| 216 |
+
__marimo__/
|
| 217 |
+
|
| 218 |
+
# Streamlit
|
| 219 |
+
.streamlit/secrets.toml
|
Makefile
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
run:
|
| 2 |
+
PYTHONPATH=. uv run gradio app.py
|
| 3 |
+
|
| 4 |
+
export-requirements:
|
| 5 |
+
uv export --format requirements-txt > requirements.txt
|
README.md
CHANGED
|
@@ -4,7 +4,9 @@ emoji: 🐨
|
|
| 4 |
colorFrom: purple
|
| 5 |
colorTo: pink
|
| 6 |
sdk: gradio
|
|
|
|
| 7 |
sdk_version: 5.49.1
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
---
|
|
|
|
| 4 |
colorFrom: purple
|
| 5 |
colorTo: pink
|
| 6 |
sdk: gradio
|
| 7 |
+
python_version: 3.13
|
| 8 |
sdk_version: 5.49.1
|
| 9 |
+
suggested_hardware: g4
|
| 10 |
app_file: app.py
|
| 11 |
pinned: false
|
| 12 |
---
|
app.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
|
| 3 |
+
from apps.audio_cloning.main import main as audio_cloning
|
| 4 |
+
from apps.dev.main import main as dev
|
| 5 |
+
|
| 6 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 7 |
+
audio_cloning()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
with demo.route(name="Dev", path="/dev"):
|
| 11 |
+
dev()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
if __name__ == "__main__":
|
| 15 |
+
# demo.queue(max_size=2, concurrency_limit=2, concurrency_id="gpu_queue")
|
| 16 |
+
# auth = ("charaxim", "chrmx-demo-wordpass")
|
| 17 |
+
# demo.launch(share=False, auth=auth)
|
| 18 |
+
demo.launch(share=False)
|
apps/audio_cloning/main.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
|
| 3 |
+
from .vallex.main import main as vallex
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def main():
|
| 7 |
+
gr.Markdown("# Charamix Audio Cloning Prototype")
|
| 8 |
+
|
| 9 |
+
# zero-shot audio cloning
|
| 10 |
+
with gr.Tab("Zero-shot Audio Cloning with VALL-E-X"):
|
| 11 |
+
vallex()
|
| 12 |
+
|
| 13 |
+
# fine-tuning audio cloning
|
| 14 |
+
# with gr.Tab("Fine-tuning Audio Cloning"):
|
| 15 |
+
# gr.Markdown("TODO")
|
apps/audio_cloning/vallex/data/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from .datamodule import *
|
| 2 |
+
# from .tokenizer import *
|
| 3 |
+
from .collation import *
|
apps/audio_cloning/vallex/data/collation.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class TextTokenCollater:
|
| 8 |
+
"""Collate list of text tokens
|
| 9 |
+
|
| 10 |
+
Map sentences to integers. Sentences are padded to equal length.
|
| 11 |
+
Beginning and end-of-sequence symbols can be added.
|
| 12 |
+
|
| 13 |
+
Example:
|
| 14 |
+
>>> token_collater = TextTokenCollater(text_tokens)
|
| 15 |
+
>>> tokens_batch, tokens_lens = token_collater(text)
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
tokens_batch: IntTensor of shape (B, L)
|
| 19 |
+
B: batch dimension, number of input sentences
|
| 20 |
+
L: length of the longest sentence
|
| 21 |
+
tokens_lens: IntTensor of shape (B,)
|
| 22 |
+
Length of each sentence after adding <eos> and <bos>
|
| 23 |
+
but before padding.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
text_tokens: List[str],
|
| 29 |
+
add_eos: bool = True,
|
| 30 |
+
add_bos: bool = True,
|
| 31 |
+
pad_symbol: str = "<pad>",
|
| 32 |
+
bos_symbol: str = "<bos>",
|
| 33 |
+
eos_symbol: str = "<eos>",
|
| 34 |
+
):
|
| 35 |
+
self.pad_symbol = pad_symbol
|
| 36 |
+
|
| 37 |
+
self.add_eos = add_eos
|
| 38 |
+
self.add_bos = add_bos
|
| 39 |
+
|
| 40 |
+
self.bos_symbol = bos_symbol
|
| 41 |
+
self.eos_symbol = eos_symbol
|
| 42 |
+
|
| 43 |
+
unique_tokens = (
|
| 44 |
+
[pad_symbol]
|
| 45 |
+
+ ([bos_symbol] if add_bos else [])
|
| 46 |
+
+ ([eos_symbol] if add_eos else [])
|
| 47 |
+
+ sorted(text_tokens)
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
|
| 51 |
+
self.idx2token = [token for token in unique_tokens]
|
| 52 |
+
|
| 53 |
+
def index(self, tokens_list: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 54 |
+
seqs, seq_lens = [], []
|
| 55 |
+
for tokens in tokens_list:
|
| 56 |
+
assert all([True if s in self.token2idx else False for s in tokens]) is True
|
| 57 |
+
seq = (
|
| 58 |
+
([self.bos_symbol] if self.add_bos else [])
|
| 59 |
+
+ list(tokens)
|
| 60 |
+
+ ([self.eos_symbol] if self.add_eos else [])
|
| 61 |
+
)
|
| 62 |
+
seqs.append(seq)
|
| 63 |
+
seq_lens.append(len(seq))
|
| 64 |
+
|
| 65 |
+
max_len = max(seq_lens)
|
| 66 |
+
for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)):
|
| 67 |
+
seq.extend([self.pad_symbol] * (max_len - seq_len))
|
| 68 |
+
|
| 69 |
+
tokens = torch.from_numpy(
|
| 70 |
+
np.array(
|
| 71 |
+
[[self.token2idx[token] for token in seq] for seq in seqs],
|
| 72 |
+
dtype=np.int64,
|
| 73 |
+
)
|
| 74 |
+
)
|
| 75 |
+
tokens_lens = torch.IntTensor(seq_lens)
|
| 76 |
+
|
| 77 |
+
return tokens, tokens_lens
|
| 78 |
+
|
| 79 |
+
def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 80 |
+
tokens_seqs = [[p for p in text] for text in texts]
|
| 81 |
+
max_len = len(max(tokens_seqs, key=len))
|
| 82 |
+
|
| 83 |
+
seqs = [
|
| 84 |
+
([self.bos_symbol] if self.add_bos else [])
|
| 85 |
+
+ list(seq)
|
| 86 |
+
+ ([self.eos_symbol] if self.add_eos else [])
|
| 87 |
+
+ [self.pad_symbol] * (max_len - len(seq))
|
| 88 |
+
for seq in tokens_seqs
|
| 89 |
+
]
|
| 90 |
+
|
| 91 |
+
tokens_batch = torch.from_numpy(
|
| 92 |
+
np.array(
|
| 93 |
+
[seq for seq in seqs],
|
| 94 |
+
dtype=np.int64,
|
| 95 |
+
)
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
tokens_lens = torch.IntTensor(
|
| 99 |
+
[len(seq) + int(self.add_eos) + int(self.add_bos) for seq in tokens_seqs]
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
return tokens_batch, tokens_lens
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def get_text_token_collater() -> TextTokenCollater:
|
| 106 |
+
collater = TextTokenCollater(["0"], add_bos=False, add_eos=False)
|
| 107 |
+
return collater
|
apps/audio_cloning/vallex/data/datamodule.py
ADDED
|
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 (authors: Feiteng Li)
|
| 2 |
+
#
|
| 3 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import inspect
|
| 20 |
+
import logging
|
| 21 |
+
from functools import lru_cache
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Any, Dict, Optional
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
# from icefall.utils import str2bool
|
| 27 |
+
# from lhotse import CutSet, load_manifest_lazy
|
| 28 |
+
# from lhotse.dataset import (
|
| 29 |
+
# CutConcatenate,
|
| 30 |
+
# DynamicBucketingSampler,
|
| 31 |
+
# PrecomputedFeatures,
|
| 32 |
+
# SingleCutSampler,
|
| 33 |
+
# SpecAugment,
|
| 34 |
+
# )
|
| 35 |
+
# from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
| 36 |
+
# from lhotse.utils import fix_random_seed
|
| 37 |
+
from torch.utils.data import DataLoader
|
| 38 |
+
|
| 39 |
+
from data.collation import get_text_token_collater
|
| 40 |
+
# from data.dataset import SpeechSynthesisDataset
|
| 41 |
+
from data.fbank import get_fbank_extractor
|
| 42 |
+
from data.input_strategies import PromptedPrecomputedFeatures
|
| 43 |
+
|
| 44 |
+
# PrecomputedFeatures = PrecomputedFeatures
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class _SeedWorkers:
|
| 48 |
+
def __init__(self, seed: int):
|
| 49 |
+
self.seed = seed
|
| 50 |
+
|
| 51 |
+
def __call__(self, worker_id: int):
|
| 52 |
+
fix_random_seed(self.seed + worker_id)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _get_input_strategy(input_strategy, dataset, cuts):
|
| 56 |
+
if input_strategy == "PromptedPrecomputedFeatures":
|
| 57 |
+
return PromptedPrecomputedFeatures(dataset, cuts)
|
| 58 |
+
|
| 59 |
+
return eval(input_strategy)()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class TtsDataModule:
|
| 63 |
+
"""
|
| 64 |
+
DataModule for VALL-E TTS experiments.
|
| 65 |
+
It assumes there is always one train and valid dataloader.
|
| 66 |
+
|
| 67 |
+
It contains all the common data pipeline modules used in TTS
|
| 68 |
+
experiments, e.g.:
|
| 69 |
+
- dynamic batch size,
|
| 70 |
+
- bucketing samplers,
|
| 71 |
+
- cut concatenation[not used & tested yet],
|
| 72 |
+
- augmentation[not used & tested yet],
|
| 73 |
+
- on-the-fly feature extraction[not used & tested yet]
|
| 74 |
+
|
| 75 |
+
This class should be derived for specific corpora used in TTS tasks.
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
def __init__(self, args: argparse.Namespace):
|
| 79 |
+
self.args = args
|
| 80 |
+
|
| 81 |
+
@classmethod
|
| 82 |
+
def add_arguments(cls, parser: argparse.ArgumentParser):
|
| 83 |
+
group = parser.add_argument_group(
|
| 84 |
+
title="TTS data related options",
|
| 85 |
+
description="These options are used for the preparation of "
|
| 86 |
+
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
| 87 |
+
"effective batch sizes, sampling strategies, applied data "
|
| 88 |
+
"augmentations, etc.",
|
| 89 |
+
)
|
| 90 |
+
group.add_argument(
|
| 91 |
+
"--manifest-dir",
|
| 92 |
+
type=Path,
|
| 93 |
+
default=Path("data/tokenized"),
|
| 94 |
+
help="Path to directory with train/valid/test cuts.",
|
| 95 |
+
)
|
| 96 |
+
group.add_argument(
|
| 97 |
+
"--max-duration",
|
| 98 |
+
type=int,
|
| 99 |
+
default=40.0,
|
| 100 |
+
help="Maximum pooled recordings duration (seconds) in a "
|
| 101 |
+
"single batch. You can reduce it if it causes CUDA OOM.",
|
| 102 |
+
)
|
| 103 |
+
group.add_argument(
|
| 104 |
+
"--bucketing-sampler",
|
| 105 |
+
type=str2bool,
|
| 106 |
+
default=True,
|
| 107 |
+
help="When enabled, the batches will come from buckets of "
|
| 108 |
+
"similar duration (saves padding frames).",
|
| 109 |
+
)
|
| 110 |
+
group.add_argument(
|
| 111 |
+
"--num-buckets",
|
| 112 |
+
type=int,
|
| 113 |
+
default=10,
|
| 114 |
+
help="The number of buckets for the DynamicBucketingSampler"
|
| 115 |
+
"(you might want to increase it for larger datasets).",
|
| 116 |
+
)
|
| 117 |
+
group.add_argument(
|
| 118 |
+
"--concatenate-cuts",
|
| 119 |
+
type=str2bool,
|
| 120 |
+
default=False,
|
| 121 |
+
help="When enabled, utterances (cuts) will be concatenated "
|
| 122 |
+
"to minimize the amount of padding.",
|
| 123 |
+
)
|
| 124 |
+
group.add_argument(
|
| 125 |
+
"--duration-factor",
|
| 126 |
+
type=float,
|
| 127 |
+
default=1.0,
|
| 128 |
+
help="Determines the maximum duration of a concatenated cut "
|
| 129 |
+
"relative to the duration of the longest cut in a batch.",
|
| 130 |
+
)
|
| 131 |
+
group.add_argument(
|
| 132 |
+
"--gap",
|
| 133 |
+
type=float,
|
| 134 |
+
default=0.1,
|
| 135 |
+
help="The amount of padding (in seconds) inserted between "
|
| 136 |
+
"concatenated cuts. This padding is filled with noise when "
|
| 137 |
+
"noise augmentation is used.",
|
| 138 |
+
)
|
| 139 |
+
group.add_argument(
|
| 140 |
+
"--on-the-fly-feats",
|
| 141 |
+
type=str2bool,
|
| 142 |
+
default=False,
|
| 143 |
+
help="When enabled, use on-the-fly cut mixing and feature "
|
| 144 |
+
"extraction. Will drop existing precomputed feature manifests "
|
| 145 |
+
"if available.",
|
| 146 |
+
)
|
| 147 |
+
group.add_argument(
|
| 148 |
+
"--shuffle",
|
| 149 |
+
type=str2bool,
|
| 150 |
+
default=True,
|
| 151 |
+
help="When enabled (=default), the examples will be "
|
| 152 |
+
"shuffled for each epoch.",
|
| 153 |
+
)
|
| 154 |
+
group.add_argument(
|
| 155 |
+
"--drop-last",
|
| 156 |
+
type=str2bool,
|
| 157 |
+
default=False,
|
| 158 |
+
help="Whether to drop last batch. Used by sampler.",
|
| 159 |
+
)
|
| 160 |
+
group.add_argument(
|
| 161 |
+
"--return-cuts",
|
| 162 |
+
type=str2bool,
|
| 163 |
+
default=True,
|
| 164 |
+
help="When enabled, each batch will have the "
|
| 165 |
+
"field: batch['supervisions']['cut'] with the cuts that "
|
| 166 |
+
"were used to construct it.",
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
group.add_argument(
|
| 170 |
+
"--num-workers",
|
| 171 |
+
type=int,
|
| 172 |
+
default=8,
|
| 173 |
+
help="The number of training dataloader workers that "
|
| 174 |
+
"collect the batches.",
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
group.add_argument(
|
| 178 |
+
"--enable-spec-aug",
|
| 179 |
+
type=str2bool,
|
| 180 |
+
default=False,
|
| 181 |
+
help="When enabled, use SpecAugment for training dataset.",
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
group.add_argument(
|
| 185 |
+
"--spec-aug-time-warp-factor",
|
| 186 |
+
type=int,
|
| 187 |
+
default=80,
|
| 188 |
+
help="Used only when --enable-spec-aug is True. "
|
| 189 |
+
"It specifies the factor for time warping in SpecAugment. "
|
| 190 |
+
"Larger values mean more warping. "
|
| 191 |
+
"A value less than 1 means to disable time warp.",
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
group.add_argument(
|
| 195 |
+
"--input-strategy",
|
| 196 |
+
type=str,
|
| 197 |
+
default="PrecomputedFeatures",
|
| 198 |
+
help="AudioSamples or PrecomputedFeatures or PromptedPrecomputedFeatures",
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
group.add_argument(
|
| 202 |
+
"--dataset",
|
| 203 |
+
type=str,
|
| 204 |
+
default="ljspeech",
|
| 205 |
+
help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.",
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
parser.add_argument(
|
| 209 |
+
"--text-tokens",
|
| 210 |
+
type=str,
|
| 211 |
+
default="data/tokenized/unique_text_tokens.k2symbols",
|
| 212 |
+
help="Path to the unique text tokens file",
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
parser.add_argument(
|
| 216 |
+
"--sampling-rate",
|
| 217 |
+
type=int,
|
| 218 |
+
default=24000,
|
| 219 |
+
help="""Audio sampling rate.""",
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
def train_dataloaders(
|
| 223 |
+
self,
|
| 224 |
+
cuts_train: CutSet,
|
| 225 |
+
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
| 226 |
+
) -> DataLoader:
|
| 227 |
+
"""
|
| 228 |
+
Args:
|
| 229 |
+
cuts_train:
|
| 230 |
+
CutSet for training.
|
| 231 |
+
sampler_state_dict:
|
| 232 |
+
The state dict for the training sampler.
|
| 233 |
+
"""
|
| 234 |
+
transforms = []
|
| 235 |
+
|
| 236 |
+
if self.args.concatenate_cuts:
|
| 237 |
+
logging.info(
|
| 238 |
+
f"Using cut concatenation with duration factor "
|
| 239 |
+
f"{self.args.duration_factor} and gap {self.args.gap}."
|
| 240 |
+
)
|
| 241 |
+
# Cut concatenation should be the first transform in the list,
|
| 242 |
+
# so that if we e.g. mix noise in, it will fill the gaps between
|
| 243 |
+
# different utterances.
|
| 244 |
+
transforms = [
|
| 245 |
+
CutConcatenate(
|
| 246 |
+
duration_factor=self.args.duration_factor, gap=self.args.gap
|
| 247 |
+
)
|
| 248 |
+
] + transforms
|
| 249 |
+
|
| 250 |
+
input_transforms = []
|
| 251 |
+
if self.args.enable_spec_aug:
|
| 252 |
+
logging.info("Enable SpecAugment")
|
| 253 |
+
logging.info(
|
| 254 |
+
f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
|
| 255 |
+
)
|
| 256 |
+
# Set the value of num_frame_masks according to Lhotse's version.
|
| 257 |
+
# In different Lhotse's versions, the default of num_frame_masks is
|
| 258 |
+
# different.
|
| 259 |
+
num_frame_masks = 10
|
| 260 |
+
num_frame_masks_parameter = inspect.signature(
|
| 261 |
+
SpecAugment.__init__
|
| 262 |
+
).parameters["num_frame_masks"]
|
| 263 |
+
if num_frame_masks_parameter.default == 1:
|
| 264 |
+
num_frame_masks = 2
|
| 265 |
+
logging.info(f"Num frame mask: {num_frame_masks}")
|
| 266 |
+
input_transforms.append(
|
| 267 |
+
SpecAugment(
|
| 268 |
+
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
| 269 |
+
num_frame_masks=num_frame_masks,
|
| 270 |
+
features_mask_size=27,
|
| 271 |
+
num_feature_masks=2,
|
| 272 |
+
frames_mask_size=100,
|
| 273 |
+
)
|
| 274 |
+
)
|
| 275 |
+
else:
|
| 276 |
+
logging.info("Disable SpecAugment")
|
| 277 |
+
|
| 278 |
+
logging.info("About to create train dataset")
|
| 279 |
+
if self.args.on_the_fly_feats:
|
| 280 |
+
# NOTE: the PerturbSpeed transform should be added only if we
|
| 281 |
+
# remove it from data prep stage.
|
| 282 |
+
# Add on-the-fly speed perturbation; since originally it would
|
| 283 |
+
# have increased epoch size by 3, we will apply prob 2/3 and use
|
| 284 |
+
# 3x more epochs.
|
| 285 |
+
# Speed perturbation probably should come first before
|
| 286 |
+
# concatenation, but in principle the transforms order doesn't have
|
| 287 |
+
# to be strict (e.g. could be randomized)
|
| 288 |
+
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
| 289 |
+
# Drop feats to be on the safe side.
|
| 290 |
+
train = SpeechSynthesisDataset(
|
| 291 |
+
get_text_token_collater(self.args.text_tokens),
|
| 292 |
+
cut_transforms=transforms,
|
| 293 |
+
feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
|
| 294 |
+
feature_transforms=input_transforms,
|
| 295 |
+
)
|
| 296 |
+
else:
|
| 297 |
+
train = SpeechSynthesisDataset(
|
| 298 |
+
get_text_token_collater(self.args.text_tokens),
|
| 299 |
+
feature_input_strategy=_get_input_strategy(
|
| 300 |
+
self.args.input_strategy, self.args.dataset, cuts_train
|
| 301 |
+
),
|
| 302 |
+
cut_transforms=transforms,
|
| 303 |
+
feature_transforms=input_transforms,
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
if self.args.bucketing_sampler:
|
| 307 |
+
logging.info("Using DynamicBucketingSampler")
|
| 308 |
+
train_sampler = DynamicBucketingSampler(
|
| 309 |
+
cuts_train,
|
| 310 |
+
max_duration=self.args.max_duration,
|
| 311 |
+
shuffle=self.args.shuffle,
|
| 312 |
+
num_buckets=self.args.num_buckets,
|
| 313 |
+
drop_last=self.args.drop_last,
|
| 314 |
+
)
|
| 315 |
+
else:
|
| 316 |
+
logging.info(
|
| 317 |
+
"Using SingleCutSampler and sort by duraton(ascending=True)."
|
| 318 |
+
)
|
| 319 |
+
cuts_train = cuts_train.to_eager().sort_by_duration(ascending=True)
|
| 320 |
+
train_sampler = SingleCutSampler(
|
| 321 |
+
cuts_train,
|
| 322 |
+
max_duration=self.args.max_duration,
|
| 323 |
+
shuffle=self.args.shuffle,
|
| 324 |
+
)
|
| 325 |
+
logging.info("About to create train dataloader")
|
| 326 |
+
|
| 327 |
+
if sampler_state_dict is not None:
|
| 328 |
+
logging.info("Loading sampler state dict")
|
| 329 |
+
train_sampler.load_state_dict(sampler_state_dict)
|
| 330 |
+
|
| 331 |
+
# 'seed' is derived from the current random state, which will have
|
| 332 |
+
# previously been set in the main process.
|
| 333 |
+
seed = torch.randint(0, 100000, ()).item()
|
| 334 |
+
worker_init_fn = _SeedWorkers(seed)
|
| 335 |
+
|
| 336 |
+
train_dl = DataLoader(
|
| 337 |
+
train,
|
| 338 |
+
sampler=train_sampler,
|
| 339 |
+
batch_size=None,
|
| 340 |
+
num_workers=self.args.num_workers,
|
| 341 |
+
persistent_workers=False,
|
| 342 |
+
worker_init_fn=worker_init_fn,
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
return train_dl
|
| 346 |
+
|
| 347 |
+
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
| 348 |
+
logging.info("About to create dev dataset")
|
| 349 |
+
if self.args.on_the_fly_feats:
|
| 350 |
+
validate = SpeechSynthesisDataset(
|
| 351 |
+
get_text_token_collater(self.args.text_tokens),
|
| 352 |
+
feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
|
| 353 |
+
cut_transforms=[],
|
| 354 |
+
)
|
| 355 |
+
else:
|
| 356 |
+
validate = SpeechSynthesisDataset(
|
| 357 |
+
get_text_token_collater(self.args.text_tokens),
|
| 358 |
+
feature_input_strategy=_get_input_strategy(
|
| 359 |
+
self.args.input_strategy, self.args.dataset, cuts_valid
|
| 360 |
+
),
|
| 361 |
+
cut_transforms=[],
|
| 362 |
+
)
|
| 363 |
+
valid_sampler = DynamicBucketingSampler(
|
| 364 |
+
cuts_valid,
|
| 365 |
+
max_duration=self.args.max_duration,
|
| 366 |
+
shuffle=False,
|
| 367 |
+
)
|
| 368 |
+
logging.info("About to create dev dataloader")
|
| 369 |
+
valid_dl = DataLoader(
|
| 370 |
+
validate,
|
| 371 |
+
sampler=valid_sampler,
|
| 372 |
+
batch_size=None,
|
| 373 |
+
num_workers=4,
|
| 374 |
+
persistent_workers=False,
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
return valid_dl
|
| 378 |
+
|
| 379 |
+
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
| 380 |
+
logging.debug("About to create test dataset")
|
| 381 |
+
test = SpeechSynthesisDataset(
|
| 382 |
+
get_text_token_collater(self.args.text_tokens),
|
| 383 |
+
feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor())
|
| 384 |
+
if self.args.on_the_fly_feats
|
| 385 |
+
else _get_input_strategy(
|
| 386 |
+
self.args.input_strategy, self.args.dataset, cuts
|
| 387 |
+
),
|
| 388 |
+
cut_transforms=[],
|
| 389 |
+
)
|
| 390 |
+
sampler = DynamicBucketingSampler(
|
| 391 |
+
cuts,
|
| 392 |
+
max_duration=self.args.max_duration,
|
| 393 |
+
shuffle=False,
|
| 394 |
+
)
|
| 395 |
+
logging.debug("About to create test dataloader")
|
| 396 |
+
test_dl = DataLoader(
|
| 397 |
+
test,
|
| 398 |
+
batch_size=None,
|
| 399 |
+
sampler=sampler,
|
| 400 |
+
num_workers=self.args.num_workers,
|
| 401 |
+
)
|
| 402 |
+
return test_dl
|
| 403 |
+
|
| 404 |
+
@lru_cache()
|
| 405 |
+
def train_cuts(self) -> CutSet:
|
| 406 |
+
logging.info("About to get train cuts")
|
| 407 |
+
return load_manifest_lazy(
|
| 408 |
+
self.args.manifest_dir / "cuts_train.jsonl.gz"
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
@lru_cache()
|
| 412 |
+
def dev_cuts(self) -> CutSet:
|
| 413 |
+
logging.info("About to get dev cuts")
|
| 414 |
+
return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
|
| 415 |
+
|
| 416 |
+
@lru_cache()
|
| 417 |
+
def test_cuts(self) -> CutSet:
|
| 418 |
+
logging.info("About to get test cuts")
|
| 419 |
+
return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz")
|
apps/audio_cloning/vallex/data/dataset.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 (authors: Feiteng Li)
|
| 2 |
+
#
|
| 3 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
modified from lhoste.dataset.speech_synthesis.py
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import math
|
| 23 |
+
import h5py
|
| 24 |
+
from tokenizers import Tokenizer
|
| 25 |
+
from typing import Union, List
|
| 26 |
+
import numpy as np
|
| 27 |
+
from tqdm import tqdm
|
| 28 |
+
|
| 29 |
+
_pad = '_'
_punctuation = ',.!?-~…'
_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
# Index -> phoneme symbol table: pad symbol first, then punctuation, then letters.
symbols = [_pad] + list(_punctuation) + list(_letters)

# Language code -> integer id used as a conditioning signal.
language_dict = {
    'en': 0,
    'zh': 1,
    'ja': 2,
}


def seq2phone(tokens: Union[List, np.ndarray]):
    """
    Convert tokenized phoneme ID sequence back to phoneme string
    :param tokens: phoneme tokens
    :return: recovered phoneme sequence
    """
    return "".join(symbols[token_id] for token_id in tokens)
|
| 47 |
+
|
| 48 |
+
class DynamicBatchSampler(torch.utils.data.Sampler):
    """Batch sampler that groups samples of similar length into buckets.

    Samples are assigned to ``num_buckets`` length buckets; a bucket is
    flushed as a batch once adding one more sample would exceed
    ``max_tokens`` (padded-token budget: batch size times the longest sample
    in the bucket) or ``max_sentences`` (maximum batch size).
    """

    def __init__(self, sampler, num_tokens_fn, num_buckets=100, min_size=0, max_size=1000,
                 max_tokens=None, max_sentences=None, drop_last=False):
        """
        :param sampler: base index sampler (e.g. a DistributedSampler).
        :param num_tokens_fn: callable mapping a sample index to that sample's length.
        :param num_buckets: number of buckets used to group samples of similar length.
        :param min_size: samples shorter than this are filtered out; also anchors
            the bucket boundaries.
        :param max_size: samples longer than this are filtered out.
        :param max_tokens: per-batch budget of padded tokens; None means unbounded.
        :param max_sentences: maximum number of samples per batch; ``max_tokens``
            and ``max_sentences`` jointly bound the batch, at least one must be given.
        :param drop_last: if True, drop the final incomplete batch.
        """
        super(DynamicBatchSampler, self).__init__(sampler)
        self.sampler = sampler
        self.num_tokens_fn = num_tokens_fn
        self.num_buckets = num_buckets

        self.min_size = min_size
        self.max_size = max_size

        assert max_tokens is not None or max_sentences is not None, \
            "max_tokens and max_sentences should not be null at the same time, please specify one parameter at least"
        self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
        self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
        # Compare against the resolved budget: the original asserted against the raw
        # argument before the None default was applied, raising TypeError whenever
        # max_tokens=None (unbounded) was combined with max_sentences.
        assert max_size <= self.max_tokens, "max_size should be smaller than max tokens"
        self.drop_last = drop_last

    def set_epoch(self, epoch):
        # Forward the epoch to the wrapped sampler (needed for DistributedSampler shuffling).
        self.sampler.set_epoch(epoch)

    def is_batch_full(self, num_tokens, batch):
        """Return True when the batch can accept no further sample."""
        if len(batch) == 0:
            return False
        if len(batch) == self.max_sentences:
            return True
        if num_tokens > self.max_tokens:
            return True
        return False

    def __iter__(self):
        buckets = [[] for _ in range(self.num_buckets)]
        # Longest sample currently held in each bucket (determines padded width).
        sample_len = [0] * self.num_buckets

        for idx in self.sampler:
            idx_length = self.num_tokens_fn(idx)
            if not (self.min_size <= idx_length <= self.max_size):
                # Original message wrongly blamed max_tokens; this filter is the
                # [min_size, max_size] length bound.
                print("sentence at index {} of size {} is outside [min_size, max_size] and is ignored".format(idx, idx_length))
                continue

            # Map the length linearly onto a bucket index in [0, num_buckets).
            index_buckets = math.floor((idx_length - self.min_size) / (self.max_size - self.min_size + 1)
                                       * self.num_buckets)
            sample_len[index_buckets] = max(sample_len[index_buckets], idx_length)

            # Padded token count if this sample were added to its bucket.
            num_tokens = (len(buckets[index_buckets]) + 1) * sample_len[index_buckets]
            if self.is_batch_full(num_tokens, buckets[index_buckets]):
                # yield this batch
                yield buckets[index_buckets]
                buckets[index_buckets] = []
                sample_len[index_buckets] = 0

            buckets[index_buckets].append(idx)

        # process left-over
        leftover_batch = []
        leftover_sample_len = 0
        leftover = [idx for bucket in buckets for idx in bucket]
        for idx in leftover:
            idx_length = self.num_tokens_fn(idx)
            leftover_sample_len = max(leftover_sample_len, idx_length)
            num_tokens = (len(leftover_batch) + 1) * leftover_sample_len
            if self.is_batch_full(num_tokens, leftover_batch):
                yield leftover_batch
                leftover_batch = []
                leftover_sample_len = 0
            leftover_batch.append(idx)

        if len(leftover_batch) > 0 and not self.drop_last:
            yield leftover_batch

    def __len__(self):
        # The number of batches depends on the data, so it is unknown up front;
        # do not call len() on this sampler or on a dataloader built from it.
        pass
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class AudioDataset(torch.utils.data.Dataset):
    """Dataset of pre-tokenized audio/text pairs stored in one HDF5 archive.

    The annotation file lists one sample per "|"-separated line:
    ``h5_key|duration|language|text|...`` (the trailing field is dropped).
    """

    def __init__(self, h5_path, ann_path, tokenizer_path):
        self.h5_path = h5_path
        with open(ann_path, 'r', encoding='utf-8') as f:
            rows = [line.split("|") for line in f]
        columns = list(zip(*rows))
        # Drop the last column (trailing field after the final "|").
        del columns[-1]
        self.h5_paths = list(columns[0])
        self.durations = [float(dur) for dur in columns[1]]
        self.langs = list(columns[2])
        self.texts = list(columns[3])
        self.tokenizer = Tokenizer.from_file(tokenizer_path)

        # The HDF5 handle is opened lazily (see `archive`) so each dataloader
        # worker gets its own file object.
        self._archive = None

    def __len__(self):
        return len(self.h5_paths)

    def get_dur(self, idx):
        """Duration of the sample at `idx`; used by the batch sampler."""
        return self.durations[idx]

    @property
    def archive(self):
        if self._archive is None:  # lazy loading here!
            self._archive = h5py.File(self.h5_path, "r")
        return self._archive

    def __getitem__(self, idx):
        h5_key = self.h5_paths[idx]
        group = self.archive[h5_key]
        audio_tokens = group['audio'][()]
        phone_tokens = group['text'][()]
        lang = self.langs[idx]
        text = self.texts[idx]
        # tokenization should be done within dataloader
        phones = seq2phone(phone_tokens).replace(" ", "_")
        # Fall back to the raw text when no phoneme sequence was stored.
        source = phones if len(phones) else text
        cptpho_tokens = self.tokenizer.encode(source).ids
        assert len(cptpho_tokens)
        return {
            'utt_id': h5_key,
            'text': text,
            'audio': None,
            'audio_lens': None,
            'audio_features': audio_tokens,
            'audio_features_lens': len(audio_tokens.T),
            'text_tokens': np.array(cptpho_tokens),
            'text_tokens_lens': len(cptpho_tokens),
            'language': language_dict[lang],
        }
|
| 185 |
+
|
| 186 |
+
def collate(batch, *, audio_pad_id=-1, text_pad_id=3, num_quantizers=8):
    """Collate ``AudioDataset`` samples into padded batch tensors.

    Generalizes the previously hard-coded constants into keyword-only
    parameters; defaults preserve the original behavior.

    :param batch: list of sample dicts produced by ``AudioDataset.__getitem__``.
    :param audio_pad_id: fill value for padded audio-codec positions (default -1).
    :param text_pad_id: fill value for padded text positions (default 3, the [PAD] id).
    :param num_quantizers: number of codec quantizer streams per frame (default 8).
    :return: dict of batched tensors and per-sample lists.
    """
    utt_id_s = [b['utt_id'] for b in batch]
    text_s = [b['text'] for b in batch]

    audio_s = [b['audio'] for b in batch]
    audio_lens_s = [b['audio_lens'] for b in batch]

    audio_features_lens_s = [b['audio_features_lens'] for b in batch]
    # (B, T_max, num_quantizers) tensor filled with the audio pad id.
    audio_features_s = torch.full(
        [len(batch), max(audio_features_lens_s), num_quantizers],
        audio_pad_id,
        dtype=torch.int64,
    )

    text_tokens_lens_s = [b['text_tokens_lens'] for b in batch]
    # (B, L_max) tensor filled with the text pad id.
    text_tokens_s = torch.full(
        [len(batch), max(text_tokens_lens_s)], text_pad_id, dtype=torch.int64
    )

    language_s = [b['language'] for b in batch]

    for i, b in enumerate(batch):
        audio_features = b['audio_features']
        audio_features_lens = b['audio_features_lens']
        # Stored as (num_quantizers, T); transpose to (T, num_quantizers).
        audio_features_s[i, :audio_features_lens, :] = torch.LongTensor(audio_features.T)

        text_tokens = b['text_tokens']
        text_tokens_lens = b['text_tokens_lens']
        text_tokens_s[i, :text_tokens_lens] = torch.LongTensor(text_tokens)

    batch = {
        'utt_id': utt_id_s,
        'text': text_s,
        'audio': audio_s,
        'audio_lens': audio_lens_s,
        'audio_features': audio_features_s,
        'audio_features_lens': torch.LongTensor(np.array(audio_features_lens_s)),
        'text_tokens': text_tokens_s,
        'text_tokens_lens': torch.LongTensor(np.array(text_tokens_lens_s)),
        'languages': torch.LongTensor(np.array(language_s)),
    }
    return batch
|
| 224 |
+
|
| 225 |
+
def create_dataloader(data_dir="/root/valle/egs/mix", n_gpus=1, rank=0, num_workers=0, num_buckets=10, max_duration=120):
    """Build the distributed, duration-bucketed training DataLoader."""
    train_dataset = AudioDataset(
        h5_path=f"{data_dir}/audio_sum.hdf5",
        ann_path=f"{data_dir}/audio_ann_sum.txt",
        tokenizer_path=f"{data_dir}/bpe_69.json",
    )
    # Shard indices across ranks, then bucket them by sample duration.
    ran_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
    dynamic_sampler = DynamicBatchSampler(
        ran_sampler,
        train_dataset.get_dur,
        num_buckets=num_buckets,
        max_size=20,
        max_tokens=max_duration,
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        num_workers=num_workers,
        collate_fn=collate,
        batch_sampler=dynamic_sampler,
    )
    return train_loader
|
apps/audio_cloning/vallex/data/fbank.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 (authors: Feiteng Li)
|
| 2 |
+
#
|
| 3 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
from dataclasses import asdict, dataclass
|
| 19 |
+
from typing import Any, Dict, Optional, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
# from lhotse.features.base import FeatureExtractor
|
| 24 |
+
# from lhotse.utils import EPSILON, Seconds, compute_num_frames
|
| 25 |
+
from librosa.filters import mel as librosa_mel_fn
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
class BigVGANFbankConfig:
    """Configuration for BigVGAN-compatible mel-spectrogram ("fbank") extraction.

    Values match the 24 kHz / 100-band BigVGAN vocoder setup.
    """

    # Spectogram-related part
    # Note that frame_length and frame_shift will be converted to milliseconds before torchaudio/Kaldi sees them
    # Durations are seconds as floats. The original annotated these with
    # lhotse's `Seconds` alias, but the lhotse import above is commented out,
    # so evaluating the annotation raised NameError at class-creation time.
    frame_length: float = 1024 / 24000.0
    frame_shift: float = 256 / 24000.0
    remove_dc_offset: bool = True
    round_to_power_of_two: bool = True

    # Fbank-related part
    low_freq: float = 0.0
    high_freq: float = 12000.0
    num_mel_bins: int = 100
    use_energy: bool = False

    def to_dict(self) -> Dict[str, Any]:
        """Return the config as a plain dict (for serialization)."""
        return asdict(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "BigVGANFbankConfig":
        """Reconstruct a config from ``to_dict`` output."""
        return BigVGANFbankConfig(**data)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress magnitudes: log(max(x, clip_val) * C)."""
    clamped = torch.clamp(x, min=clip_val)
    return torch.log(clamped * C)


def spectral_normalize_torch(magnitudes):
    """Apply BigVGAN's log dynamic-range compression to a magnitude spectrogram."""
    return dynamic_range_compression_torch(magnitudes)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# https://github.com/NVIDIA/BigVGAN
# bigvgan_24khz_100band https://drive.google.com/drive/folders/1EpxX6AsxjCbbk0mmAhE0td6eYiABr8Oz
class BigVGANFbank(FeatureExtractor):
    """Mel-spectrogram extractor matching the BigVGAN 24 kHz / 100-band vocoder.

    NOTE(review): `FeatureExtractor`, `Seconds`, `compute_num_frames`, and
    `EPSILON` come from lhotse, whose imports are commented out at the top of
    this file — confirm they are in scope before using this class.
    """

    name = "fbank"
    config_type = BigVGANFbankConfig

    def __init__(self, config: Optional[Any] = None):
        super(BigVGANFbank, self).__init__(config)
        sampling_rate = 24000
        # Mel filterbank for n_fft=1024 matching BigVGAN's training setup.
        # NOTE(review): positional librosa_mel_fn arguments only work on
        # librosa < 0.10; newer librosa requires keyword args (sr=, n_fft=,
        # n_mels=, fmin=, fmax=) — verify the pinned librosa version.
        self.mel_basis = torch.from_numpy(
            librosa_mel_fn(
                sampling_rate,
                1024,
                self.config.num_mel_bins,
                self.config.low_freq,
                self.config.high_freq,
            ).astype(np.float32)
        )
        self.hann_window = torch.hann_window(1024)

    def _feature_fn(self, samples, **kwargs):
        """Compute a (num_frames, num_mel_bins) log-mel spectrogram for one waveform."""
        win_length, n_fft = 1024, 1024
        hop_size = 256
        # `if True:` keeps the lhotse-frame-count padding path active; the
        # `else` branch (center-style reflect padding) is intentionally dead.
        if True:
            sampling_rate = 24000
            duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
            expected_num_frames = compute_num_frames(
                duration=duration,
                frame_shift=self.frame_shift,
                sampling_rate=sampling_rate,
            )
            # Right-pad so torch.stft(center=False) yields exactly the
            # lhotse-expected number of frames.
            pad_size = (
                (expected_num_frames - 1) * hop_size
                + win_length
                - samples.shape[-1]
            )
            assert pad_size >= 0

            y = torch.nn.functional.pad(
                samples,
                (0, pad_size),
                mode="constant",
            )
        else:
            y = torch.nn.functional.pad(
                samples,
                (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
                mode="reflect",
            )

        y = y.squeeze(1)

        # complex tensor as default, then use view_as_real for future pytorch compatibility
        spec = torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_length,
            window=self.hann_window,
            center=False,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)
        # Magnitude with a small epsilon for numerical stability.
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        # Project onto the mel basis, then log-compress.
        spec = torch.matmul(self.mel_basis, spec)
        spec = spectral_normalize_torch(spec)

        return spec.transpose(2, 1).squeeze(0)

    def extract(
        self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
    ) -> np.ndarray:
        """Extract log-mel features as a float32 numpy array; 24 kHz input only."""
        assert sampling_rate == 24000
        params = asdict(self.config)
        params.update({"sample_frequency": sampling_rate, "snip_edges": False})
        # Kaldi-style extractors take frame sizes in milliseconds.
        params["frame_shift"] *= 1000.0
        params["frame_length"] *= 1000.0
        if not isinstance(samples, torch.Tensor):
            samples = torch.from_numpy(samples)
        # Torchaudio Kaldi feature extractors expect the channel dimension to be first.
        if len(samples.shape) == 1:
            samples = samples.unsqueeze(0)
        features = self._feature_fn(samples, **params).to(torch.float32)
        return features.numpy()

    @property
    def frame_shift(self) -> Seconds:
        # Frame shift in seconds, taken straight from the config.
        return self.config.frame_shift

    def feature_dim(self, sampling_rate: int) -> int:
        # Feature dimension is the number of mel bins, independent of rate.
        return self.config.num_mel_bins

    @staticmethod
    def mix(
        features_a: np.ndarray,
        features_b: np.ndarray,
        energy_scaling_factor_b: float,
    ) -> np.ndarray:
        """Mix two log-energy feature matrices in the linear-energy domain."""
        return np.log(
            np.maximum(
                # protection against log(0); max with EPSILON is adequate since these are energies (always >= 0)
                EPSILON,
                np.exp(features_a)
                + energy_scaling_factor_b * np.exp(features_b),
            )
        )

    @staticmethod
    def compute_energy(features: np.ndarray) -> float:
        """Total linear-domain energy of a log-energy feature matrix."""
        return float(np.sum(np.exp(features)))
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def get_fbank_extractor() -> BigVGANFbank:
    """Build a BigVGANFbank extractor with the default configuration."""
    config = BigVGANFbankConfig()
    return BigVGANFbank(config)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
if __name__ == "__main__":
    # Smoke test: extract fbank features from random noise and a sample wav,
    # then save a spectrogram plot next to the prompt audio.
    extractor = BigVGANFbank(BigVGANFbankConfig())

    samples = torch.from_numpy(np.random.random([1000]).astype(np.float32))
    samples = torch.clip(samples, -1.0, 1.0)
    fbank = extractor.extract(samples, 24000.0)
    print(f"fbank {fbank.shape}")

    from scipy.io.wavfile import read

    # int16 full-scale value, used to normalize PCM samples to [-1, 1].
    MAX_WAV_VALUE = 32768.0

    sampling_rate, samples = read(
        "egs/libritts/prompts/5639_40744_000000_000002.wav"
    )
    print(f"samples: [{samples.min()}, {samples.max()}]")
    fbank = extractor.extract(samples.astype(np.float32) / MAX_WAV_VALUE, 24000)
    print(f"fbank {fbank.shape}")

    import matplotlib.pyplot as plt

    # Plot the mel spectrogram: time on x, mel bins on y (origin flipped).
    _ = plt.figure(figsize=(18, 10))
    plt.imshow(
        X=fbank.transpose(1, 0),
        cmap=plt.get_cmap("jet"),
        aspect="auto",
        interpolation="nearest",
    )
    plt.gca().invert_yaxis()
    plt.savefig("egs/libritts/prompts/5639_40744_000000_000002.png")
    plt.close()

    print("fbank test PASS!")
|
apps/audio_cloning/vallex/data/input_strategies.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 4 |
+
from typing import Tuple, Type
|
| 5 |
+
|
| 6 |
+
# from lhotse import CutSet
|
| 7 |
+
# from lhotse.dataset.collation import collate_features
|
| 8 |
+
# from lhotse.dataset.input_strategies import (
|
| 9 |
+
# ExecutorType,
|
| 10 |
+
# PrecomputedFeatures,
|
| 11 |
+
# _get_executor,
|
| 12 |
+
# )
|
| 13 |
+
# from lhotse.utils import fastcopy
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class PromptedFeatures:
    """Pair of (prompt features, target features) that moves between devices together."""

    def __init__(self, prompts, features):
        self.prompts = prompts
        self.features = features

    def to(self, device):
        """Return a new pair with both tensors moved to ``device``."""
        moved_prompts = self.prompts.to(device)
        moved_features = self.features.to(device)
        return PromptedFeatures(moved_prompts, moved_features)

    def sum(self):
        # Only the target features contribute to the sum.
        return self.features.sum()

    @property
    def ndim(self):
        return self.features.ndim

    @property
    def data(self):
        return (self.prompts, self.features)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# class PromptedPrecomputedFeatures(PrecomputedFeatures):
|
| 39 |
+
# """
|
| 40 |
+
# :class:`InputStrategy` that reads pre-computed features, whose manifests
|
| 41 |
+
# are attached to cuts, from disk.
|
| 42 |
+
#
|
| 43 |
+
# It automatically pads the feature matrices with pre or post feature.
|
| 44 |
+
#
|
| 45 |
+
# .. automethod:: __call__
|
| 46 |
+
# """
|
| 47 |
+
#
|
| 48 |
+
# def __init__(
|
| 49 |
+
# self,
|
| 50 |
+
# dataset: str,
|
| 51 |
+
# cuts: CutSet,
|
| 52 |
+
# num_workers: int = 0,
|
| 53 |
+
# executor_type: Type[ExecutorType] = ThreadPoolExecutor,
|
| 54 |
+
# ) -> None:
|
| 55 |
+
# super(PromptedPrecomputedFeatures, self).__init__(
|
| 56 |
+
# num_workers, executor_type
|
| 57 |
+
# )
|
| 58 |
+
#
|
| 59 |
+
# self.utt2neighbors = defaultdict(lambda: [])
|
| 60 |
+
#
|
| 61 |
+
# if dataset.lower() == "libritts":
|
| 62 |
+
# # 909_131041_000013_000002
|
| 63 |
+
# # 909_131041_000013_000003
|
| 64 |
+
# speaker2utts = defaultdict(lambda: [])
|
| 65 |
+
#
|
| 66 |
+
# utt2cut = {}
|
| 67 |
+
# for cut in cuts:
|
| 68 |
+
# speaker = cut.supervisions[0].speaker
|
| 69 |
+
# speaker2utts[speaker].append(cut.id)
|
| 70 |
+
# utt2cut[cut.id] = cut
|
| 71 |
+
#
|
| 72 |
+
# for spk in speaker2utts:
|
| 73 |
+
# uttids = sorted(speaker2utts[spk])
|
| 74 |
+
# # Using the property of sorted keys to find previous utterance
|
| 75 |
+
# # The keys has structure speaker_book_x_y e.g. 1089_134691_000004_000001
|
| 76 |
+
# if len(uttids) == 1:
|
| 77 |
+
# self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
|
| 78 |
+
# continue
|
| 79 |
+
#
|
| 80 |
+
# utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
|
| 81 |
+
# utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
|
| 82 |
+
#
|
| 83 |
+
# for utt in utt2prevutt:
|
| 84 |
+
# self.utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]])
|
| 85 |
+
#
|
| 86 |
+
# for utt in utt2postutt:
|
| 87 |
+
# self.utt2neighbors[utt].append(utt2cut[utt2postutt[utt]])
|
| 88 |
+
# elif dataset.lower() == "ljspeech":
|
| 89 |
+
# utt2cut = {}
|
| 90 |
+
# uttids = []
|
| 91 |
+
# for cut in cuts:
|
| 92 |
+
# uttids.append(cut.id)
|
| 93 |
+
# utt2cut[cut.id] = cut
|
| 94 |
+
#
|
| 95 |
+
# if len(uttids) == 1:
|
| 96 |
+
# self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
|
| 97 |
+
# else:
|
| 98 |
+
# # Using the property of sorted keys to find previous utterance
|
| 99 |
+
# # The keys has structure: LJ001-0010
|
| 100 |
+
# utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
|
| 101 |
+
# utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
|
| 102 |
+
#
|
| 103 |
+
# for utt in utt2postutt:
|
| 104 |
+
# postutt = utt2postutt[utt]
|
| 105 |
+
# if utt[:5] == postutt[:5]:
|
| 106 |
+
# self.utt2neighbors[utt].append(utt2cut[postutt])
|
| 107 |
+
#
|
| 108 |
+
# for utt in utt2prevutt:
|
| 109 |
+
# prevutt = utt2prevutt[utt]
|
| 110 |
+
# if utt[:5] == prevutt[:5] or not self.utt2neighbors[utt]:
|
| 111 |
+
# self.utt2neighbors[utt].append(utt2cut[prevutt])
|
| 112 |
+
# else:
|
| 113 |
+
# raise ValueError
|
| 114 |
+
#
|
| 115 |
+
# def __call__(
|
| 116 |
+
# self, cuts: CutSet
|
| 117 |
+
# ) -> Tuple[PromptedFeatures, PromptedFeatures]:
|
| 118 |
+
# """
|
| 119 |
+
# Reads the pre-computed features from disk/other storage.
|
| 120 |
+
# The returned shape is``(B, T, F) => (batch_size, num_frames, num_features)``.
|
| 121 |
+
#
|
| 122 |
+
# :return: a tensor with collated features, and a tensor of ``num_frames`` of each cut before padding.
|
| 123 |
+
# """
|
| 124 |
+
# features, features_lens = collate_features(
|
| 125 |
+
# cuts,
|
| 126 |
+
# executor=_get_executor(
|
| 127 |
+
# self.num_workers, executor_type=self._executor_type
|
| 128 |
+
# ),
|
| 129 |
+
# )
|
| 130 |
+
#
|
| 131 |
+
# prompts_cuts = []
|
| 132 |
+
# for k, cut in enumerate(cuts):
|
| 133 |
+
# prompts_cut = random.choice(self.utt2neighbors[cut.id])
|
| 134 |
+
# prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}"))
|
| 135 |
+
#
|
| 136 |
+
# mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0])
|
| 137 |
+
# # prompts_cuts = CutSet.from_cuts(prompts_cuts).truncate(
|
| 138 |
+
# # max_duration=mini_duration,
|
| 139 |
+
# # offset_type="random",
|
| 140 |
+
# # preserve_id=True,
|
| 141 |
+
# # )
|
| 142 |
+
# prompts_cuts = CutSet(
|
| 143 |
+
# cuts={k: cut for k, cut in enumerate(prompts_cuts)}
|
| 144 |
+
# ).truncate(
|
| 145 |
+
# max_duration=mini_duration,
|
| 146 |
+
# offset_type="random",
|
| 147 |
+
# preserve_id=False,
|
| 148 |
+
# )
|
| 149 |
+
#
|
| 150 |
+
# prompts, prompts_lens = collate_features(
|
| 151 |
+
# prompts_cuts,
|
| 152 |
+
# executor=_get_executor(
|
| 153 |
+
# self.num_workers, executor_type=self._executor_type
|
| 154 |
+
# ),
|
| 155 |
+
# )
|
| 156 |
+
#
|
| 157 |
+
# return PromptedFeatures(prompts, features), PromptedFeatures(
|
| 158 |
+
# prompts_lens, features_lens
|
| 159 |
+
# )
|
apps/audio_cloning/vallex/data/symbol_table.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
|
| 2 |
+
#
|
| 3 |
+
# See ../../../LICENSE for clarification regarding multiple authors
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from dataclasses import dataclass, field
from typing import Dict, Generic, List, Optional, TypeVar, Union

# Type of the objects stored in the table. They must be hashable; disk
# (de)serialization is only supported when the symbols are strings.
Symbol = TypeVar("Symbol")


# Disable __repr__ otherwise it could freeze e.g. Jupyter.
@dataclass(repr=False)
class SymbolTable(Generic[Symbol]):
    """SymbolTable that maps symbol IDs, found on the FSA arcs to
    actual objects. These objects can be arbitrary Python objects
    that can serve as keys in a dictionary (i.e. they need to be
    hashable and immutable).

    The SymbolTable can only be read to/written from disk if the
    symbols are strings.
    """

    # Map an integer ID to a symbol.
    _id2sym: Dict[int, Symbol] = field(default_factory=dict)

    # Map a symbol to an integer ID; kept as the exact inverse of _id2sym.
    _sym2id: Dict[Symbol, int] = field(default_factory=dict)

    # Internal helper so new symbols can be assigned IDs efficiently.
    _next_available_id: int = 1

    # Null symbol, always mapped to index 0.
    eps: Symbol = "<eps>"

    def __post_init__(self):
        """Validate that the two maps are mutual inverses and reserve ID 0 for eps."""
        for idx, sym in self._id2sym.items():
            assert self._sym2id[sym] == idx
            assert idx >= 0

        for sym, idx in self._sym2id.items():
            assert idx >= 0
            assert self._id2sym[idx] == sym

        # ID 0 is reserved for the epsilon (null) symbol.
        if 0 not in self._id2sym:
            self._id2sym[0] = self.eps
            self._sym2id[self.eps] = 0
        else:
            assert self._id2sym[0] == self.eps
            assert self._sym2id[self.eps] == 0

        self._next_available_id = max(self._id2sym) + 1

    @staticmethod
    def from_str(s: str) -> "SymbolTable":
        """Build a symbol table from a string.

        The string consists of lines. Every line has two fields separated
        by space(s), tab(s) or both. The first field is the symbol and the
        second the integer id of the symbol.

        Args:
          s:
            The input string with the format described above.
        Returns:
          An instance of :class:`SymbolTable`.
        """
        id2sym: Dict[int, str] = dict()
        sym2id: Dict[str, int] = dict()

        for line in s.split("\n"):
            fields = line.split()
            if len(fields) == 0:
                continue  # skip empty lines
            assert len(fields) == 2, (
                f"Expect a line with 2 fields. Given: {len(fields)}"
            )
            sym, idx = fields[0], int(fields[1])
            assert sym not in sym2id, f"Duplicated symbol {sym}"
            assert idx not in id2sym, f"Duplicated id {idx}"
            id2sym[idx] = sym
            sym2id[sym] = idx

        # If ID 0 is present its symbol becomes eps; otherwise use the default.
        eps = id2sym.get(0, "<eps>")

        return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps)

    @staticmethod
    def from_file(filename: str) -> "SymbolTable":
        """Build a symbol table from file.

        Every line in the symbol table file has two fields separated by
        space(s), tab(s) or both. The following is an example file:

        .. code-block::

            <eps> 0
            a 1
            b 2
            c 3

        Args:
          filename:
            Name of the symbol table file. Its format is documented above.

        Returns:
          An instance of :class:`SymbolTable`.
        """
        with open(filename, "r", encoding="utf-8") as f:
            return SymbolTable.from_str(f.read().strip())

    def to_str(self) -> str:
        """
        Returns:
          Return a string representation of this object. You can pass
          it to the method ``from_str`` to recreate an identical object.
        """
        s = ""
        for idx, symbol in sorted(self._id2sym.items()):
            s += f"{symbol} {idx}\n"
        return s

    def to_file(self, filename: str):
        """Serialize the SymbolTable to a file.

        Every line in the symbol table file has two fields separated by
        space(s), tab(s) or both. The following is an example file:

        .. code-block::

            <eps> 0
            a 1
            b 2
            c 3

        Args:
          filename:
            Name of the symbol table file. Its format is documented above.
        """
        with open(filename, "w", encoding="utf-8") as f:
            for idx, symbol in sorted(self._id2sym.items()):
                print(symbol, idx, file=f)

    def add(self, symbol: Symbol, index: Optional[int] = None) -> int:
        """Add a new symbol to the SymbolTable.

        Args:
          symbol:
            The symbol to be added.
          index:
            Optional int id to which the symbol should be assigned.
            If it is not available, a ValueError will be raised.

        Returns:
          The int id to which the symbol has been assigned.
        """
        # Already in the table? Return its ID.
        if symbol in self._sym2id:
            return self._sym2id[symbol]
        # Specific ID not provided - use next available.
        if index is None:
            index = self._next_available_id
        # Specific ID provided but not available.
        if index in self._id2sym:
            raise ValueError(
                f"Cannot assign id '{index}' to '{symbol}' - "
                f"already occupied by {self._id2sym[index]}"
            )
        self._sym2id[symbol] = index
        self._id2sym[index] = symbol

        # Update next available ID if needed
        if self._next_available_id <= index:
            self._next_available_id = index + 1

        return index

    def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]:
        """Get a symbol for an id or get an id for a symbol.

        Args:
          k:
            If it is an id, it tries to find the symbol corresponding
            to the id; if it is a symbol, it tries to find the id
            corresponding to the symbol.

        Returns:
          An id or a symbol depending on the given `k`.
        """
        if isinstance(k, int):
            return self._id2sym[k]
        else:
            return self._sym2id[k]

    def merge(self, other: "SymbolTable") -> "SymbolTable":
        """Create a union of two SymbolTables.
        Raises an AssertionError if the same IDs are occupied by
        different symbols.

        Args:
          other:
            A symbol table to merge with ``self``.

        Returns:
          A new symbol table.
        """
        self._check_compatible(other)

        id2sym = {**self._id2sym, **other._id2sym}
        sym2id = {**self._sym2id, **other._sym2id}

        return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps)

    def _check_compatible(self, other: "SymbolTable") -> None:
        """Assert that ``self`` and ``other`` agree on eps and every shared entry."""
        # Epsilon compatibility
        assert self.eps == other.eps, (
            f"Mismatched epsilon symbol: {self.eps} != {other.eps}"
        )
        # IDs compatibility
        common_ids = set(self._id2sym).intersection(other._id2sym)
        for idx in common_ids:
            assert self[idx] == other[idx], (
                f"ID conflict for id: {idx}, "
                f'self[idx] = "{self[idx]}", '
                f'other[idx] = "{other[idx]}"'
            )
        # Symbols compatibility
        common_symbols = set(self._sym2id).intersection(other._sym2id)
        for sym in common_symbols:
            # Fixed message: this loop reports a conflicting *symbol*, not an id.
            assert self[sym] == other[sym], (
                f"ID conflict for symbol: {sym}, "
                f'self[sym] = "{self[sym]}", '
                f'other[sym] = "{other[sym]}"'
            )

    def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]:
        return self.get(item)

    def __contains__(self, item: Union[int, Symbol]) -> bool:
        if isinstance(item, int):
            return item in self._id2sym
        else:
            return item in self._sym2id

    def __len__(self) -> int:
        return len(self._id2sym)

    def __eq__(self, other: "SymbolTable") -> bool:
        """Return True iff both tables contain exactly the same symbol->ID pairs.

        A symbol present in ``self`` but missing from ``other`` makes the
        tables unequal (previously this raised KeyError from the lookup).
        """
        if len(self) != len(other):
            return False

        for s in self.symbols:
            if s not in other or self[s] != other[s]:
                return False

        return True

    @property
    def ids(self) -> List[int]:
        """Returns a sorted list of integer IDs corresponding to the symbols."""
        ans = list(self._id2sym.keys())
        ans.sort()
        return ans

    @property
    def symbols(self) -> List[Symbol]:
        """Returns a sorted list of symbols (e.g., strings) corresponding to
        the integer IDs.
        """
        ans = list(self._sym2id.keys())
        ans.sort()
        return ans
|
apps/audio_cloning/vallex/data/tokenizer.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2023 (authors: Feiteng Li)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import torchaudio
|
| 21 |
+
from encodec import EncodecModel
|
| 22 |
+
from encodec.utils import convert_audio
|
| 23 |
+
|
| 24 |
+
# NOTE: a vestigial `try: pass / except Exception: pass` guard (left over
# from an optional import that was removed upstream) used to live here; it
# was a no-op and has been deleted.
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def remove_encodec_weight_norm(model):
    """Strip weight normalization from every convolution in an EnCodec model.

    Walks the SEANet encoder and decoder of ``model`` and calls
    ``remove_weight_norm`` on each wrapped convolution so the reparameterized
    weights are folded back into plain tensors.
    """
    from encodec.modules import SConv1d
    from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
    from torch.nn.utils import remove_weight_norm

    # Encoder: residual blocks carry a shortcut conv plus a stack of convs;
    # everything else of interest is a bare SConv1d.
    for module in model.encoder.model._modules.values():
        if isinstance(module, SEANetResnetBlock):
            remove_weight_norm(module.shortcut.conv.conv)
            for sub in module.block._modules.values():
                if isinstance(sub, SConv1d):
                    remove_weight_norm(sub.conv.conv)
        elif isinstance(module, SConv1d):
            remove_weight_norm(module.conv.conv)

    # Decoder: same layout, plus transposed convs used for upsampling.
    for module in model.decoder.model._modules.values():
        if isinstance(module, SEANetResnetBlock):
            remove_weight_norm(module.shortcut.conv.conv)
            for sub in module.block._modules.values():
                if isinstance(sub, SConv1d):
                    remove_weight_norm(sub.conv.conv)
        elif isinstance(module, SConvTranspose1d):
            remove_weight_norm(module.convtr.convtr)
        elif isinstance(module, SConv1d):
            remove_weight_norm(module.conv.conv)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class AudioTokenizer:
    """EnCodec audio.

    Wraps a pretrained 24 kHz EnCodec codec (6 kbps target bandwidth) and
    exposes ``encode``/``decode`` on a chosen torch device.
    """

    def __init__(
        self,
        device: Any = None,
    ) -> None:
        # Instantiate a pretrained EnCodec model
        model = EncodecModel.encodec_model_24khz()
        model.set_target_bandwidth(6.0)
        # Weight norm is a training-time reparameterization; fold it away
        # for inference.
        remove_encodec_weight_norm(model)

        if not device:
            # Auto-select a device. The checks run in sequence, so MPS wins
            # over CUDA when both report available.
            device = torch.device("cpu")
            if torch.cuda.is_available():
                device = torch.device("cuda:0")
            if torch.backends.mps.is_available():
                device = torch.device("mps")

        self._device = device

        self.codec = model.to(device)
        self.sample_rate = model.sample_rate  # codec's expected sample rate
        self.channels = model.channels  # codec's expected channel count

    @property
    def device(self):
        # The torch.device the codec was moved to.
        return self._device

    def encode(self, wav: torch.Tensor) -> torch.Tensor:
        # Moves the waveform onto the codec's device before encoding.
        return self.codec.encode(wav.to(self.device))

    def decode(self, frames: torch.Tensor) -> torch.Tensor:
        # Reconstructs a waveform from encoded frames.
        return self.codec.decode(frames)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def tokenize_audio(tokenizer: AudioTokenizer, audio):
    """Encode audio into discrete EnCodec frames with *tokenizer*.

    ``audio`` is either a file path (loaded via torchaudio) or an
    already-loaded ``(waveform, sample_rate)`` pair.
    """
    # Accept either a path or a (waveform, sample_rate) tuple.
    waveform, sample_rate = (
        torchaudio.load(audio) if isinstance(audio, str) else audio
    )
    # Resample/remix to the codec's expected rate and channel layout, then
    # add a batch dimension.
    waveform = convert_audio(
        waveform, sample_rate, tokenizer.sample_rate, tokenizer.channels
    )
    waveform = waveform.unsqueeze(0)

    # Inference only — no gradients needed for code extraction.
    with torch.no_grad():
        return tokenizer.encode(waveform)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
if __name__ == "__main__":
    # Smoke test: verify that stripping weight norm does not change the
    # discrete codes produced by the EnCodec encoder.
    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(6.0)

    # Random batch of 4 mono clips, 1600 samples each.
    samples = torch.from_numpy(np.random.random([4, 1, 1600])).type(torch.float32)
    codes_raw = model.encode(samples)  # codes with weight norm still in place

    remove_encodec_weight_norm(model)
    codes_norm = model.encode(samples)  # codes after folding weight norm away

    assert torch.allclose(codes_raw[0][0], codes_norm[0][0])
|
apps/audio_cloning/vallex/descriptions.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Markdown snippets shown in the web UI (presumably rendered by the Gradio
# app — confirm against app.py).

# English page header.
top_md_org = """
# VALL-E X
VALL-E X can synthesize high-quality personalized speech with only a 3-second enrolled recording of
an unseen speaker as an acoustic prompt, even in another language for a monolingual speaker.<br>
This implementation supports zero-shot, mono-lingual/cross-lingual text-to-speech functionality of three languages (English, Chinese, Japanese)<br>
See this [demo](https://plachtaa.github.io/) page for more details.
"""

# Japanese page header.
top_ja_md = """
# VALL-E X

VALL-E X は、未学習の話者でも 3 秒間の音声プロンプトだけで高品質なパーソナライズ音声を合成できます。<br>
単一言語話者であっても別の言語による音声合成が可能です。<br>
本実装は英語・中国語・日本語のゼロショット単言語/クロス言語テキスト読み上げをサポートしています。

## Reference

- [github.com/Plachtaa/VALL-E-X](https://github.com/Plachtaa/VALL-E-X/tree/master#readme)
- [github.com/gemelo-ai/vocos](https://github.com/gemelo-ai/vocos)
"""

# English usage notes for the "infer from audio" flow.
infer_from_audio_md_org = """
Upload a speech of 3~10 seconds as the audio prompt and type in the text you'd like to synthesize.<br>
The model will synthesize speech of given text with the same voice of your audio prompt.<br>
The model also tends to preserve the emotion & acoustic environment of your given speech.<br>
For faster inference, please use **"Make prompt"** to get a `.npz` file as the encoded audio prompt, and use it by **"Infer from prompt"**
"""

# Japanese usage notes for the "infer from audio" flow.
infer_from_audio_ja_md = """
3〜10 秒程度の音声をプロンプトとしてアップロードし、合成したいテキストを入力してください。<br>
モデルは、プロンプトと同じ声質でテキストを読み上げる音声を生成します。<br>
元の音声に含まれる感情や音響環境も比較的保持されます。<br>
推論を高速化したい場合は **"Make prompt"** で `.npz` ファイルを作成し、 **"Infer from prompt"** で利用してください。
"""
|
apps/audio_cloning/vallex/examples.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Directory holding the bundled audio prompt files.
_prompts_dir = "apps/audio_cloning/vallex/prompts"

# Example rows for the "infer from audio" UI.
# Each row appears to be:
#   [text to synthesize, language, accent, prompt audio path,
#    secondary audio slot (always None here), prompt transcript]
# TODO(review): confirm the field meanings against the UI component that
# consumes these rows.
infer_from_audio_examples = [
    [
        "私のクローンに騙されないでください。",
        "日本語",
        "no-accent",
        f"{_prompts_dir}/ja-okuwaki.wav",
        None,
        "こんにちは、私の名前はオクワキヨウスケです。",
    ],
    [
        "ぼくのクローンに騙されないでくれなのだ。",
        "日本語",
        "no-accent",
        f"{_prompts_dir}/ja-zundamon.wav",
        None,
        "はじめまして、ずんだもんなのだ",
    ],
    [
        "私のクローンに騙されないでください。",
        "日本語",
        "no-accent",
        f"{_prompts_dir}/ja-okuwaki-long.wav",
        None,
        "こんにちは、私の名前はオクワキヨウスケです。これは音声クローニング用のサンプルです。",
    ],
    [
        "私の声を真似するのはそんなに面白いですか?",
        "日本語",
        "no-accent",
        f"{_prompts_dir}/ja-2.ogg",
        None,
        "初めまして、朝武よしのです。",
    ],
    [
        "This is how this machine has taken my voice.",
        "English",
        "no-accent",
        f"{_prompts_dir}/en-2.wav",
        None,
        "Wow, look at that! That's no ordinary Teddy bear!",
    ],
    [
        "我喜欢抽电子烟,尤其是锐刻五代。",
        "中文",
        "no-accent",
        f"{_prompts_dir}/zh-1.wav",
        None,
        "今天我很荣幸,",
    ],
    [
        "你可以听得出来我有多困。",
        "中文",
        "no-accent",
        f"{_prompts_dir}/en-1.wav",
        None,
        "",
    ],
    [
        "この文は、クロスリンガル合成の例です。",
        "日本語",
        "no-accent",
        f"{_prompts_dir}/zh-2.wav",
        None,
        "",
    ],
    [
        "Actually, I can't speak English, but this machine helped me do it.",
        "English",
        "no-accent",
        f"{_prompts_dir}/ja-1.wav",
        None,
        "",
    ],
]

# Example rows for the "make prompt" UI: [prompt name, prompt audio path,
# secondary audio slot (unused here), transcript].
make_npz_prompt_examples = [
    [
        "Gem-trader",
        f"{_prompts_dir}/en-2.wav",
        None,
        "Wow, look at that! That's no ordinary Teddy bear!",
    ],
    ["Ding Zhen", f"{_prompts_dir}/zh-1.wav", None, "今天我很荣幸,"],
    ["Yoshino", f"{_prompts_dir}/ja-2.ogg", None, "初めまして、朝武よしのです。"],
    ["Sleepy-woman", f"{_prompts_dir}/en-1.wav", None, ""],
    ["Yae", f"{_prompts_dir}/zh-2.wav", None, ""],
    ["Cafe", f"{_prompts_dir}/ja-1.wav", None, ""],
]

# Example rows for the "infer from prompt" UI: [text, language, accent,
# prompt name or path (without extension), secondary slot].
infer_from_prompt_examples = [
    [
        "A prompt contains voice, prosody and emotion information of a certain speaker.",
        "English",
        "no-accent",
        f"{_prompts_dir}/vctk_1",
        None,
    ],
    [
        "This prompt is made with an audio of three seconds.",
        "English",
        "no-accent",
        f"{_prompts_dir}/librispeech_1",
        None,
    ],
    ["This prompt is made with Chinese speech", "English", "no-accent", "seel", None],
]
|
apps/audio_cloning/vallex/g2p/__init__.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""from https://github.com/keithito/tacotron"""
|
| 2 |
+
|
| 3 |
+
# import utils.g2p.cleaners
|
| 4 |
+
from tokenizers import Tokenizer
|
| 5 |
+
|
| 6 |
+
import apps.audio_cloning.vallex.g2p.cleaners as cleaners
|
| 7 |
+
|
| 8 |
+
from .symbols import symbols
|
| 9 |
+
|
| 10 |
+
# Mappings from symbol to numeric ID and vice versa:
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}

# Path to the trained phoneme BPE vocabulary, relative to the working
# directory the app is launched from (repo root, per the other paths here).
TOKENIZER_PATH = "./apps/audio_cloning/vallex/g2p/bpe_1024.json"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class PhonemeBpeTokenizer:
    """Turns raw text into phoneme BPE token ids plus per-token language tags."""

    def __init__(self, tokenizer_path=TOKENIZER_PATH):
        print(f"Initializing PhonemeBpeTokenizer with tokenizer path: {tokenizer_path}")
        self.tokenizer = Tokenizer.from_file(tokenizer_path)

    def tokenize(self, text):
        """Return ``(token_ids, langs)`` for *text*.

        Raises ValueError when the cleaned text yields no tokens.
        """
        # Text -> phoneme string plus per-phoneme language labels.
        phonemes, langs = _clean_text(text, ["cje_cleaners"])
        # The BPE vocabulary encodes word separators as underscores.
        phonemes = phonemes.replace(" ", "_")
        # Phoneme string -> BPE ids; must stay aligned with the language tags.
        token_ids = self.tokenizer.encode(phonemes).ids
        assert len(token_ids) == len(langs)
        if not token_ids:
            raise ValueError("Empty text is given")
        return token_ids, langs
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def text_to_sequence(text, cleaner_names):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
    Args:
      text: string to convert to a sequence
      cleaner_names: names of the cleaner functions to run the text through
    Returns:
      List of integers corresponding to the symbols in the text
    """
    symbol_to_id = {s: i for i, s in enumerate(symbols)}
    # _clean_text returns a (cleaned_text, langs) tuple; only the text part is
    # mapped to symbol IDs here. (Previously the whole tuple was iterated,
    # which tested the entire cleaned string — and the unhashable langs list —
    # against single-symbol keys instead of iterating the characters.)
    cleaned, _langs = _clean_text(text, cleaner_names)
    return [symbol_to_id[symbol] for symbol in cleaned if symbol in symbol_to_id]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def cleaned_text_to_sequence(cleaned_text):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
    Args:
      text: string to convert to a sequence
    Returns:
      List of integers corresponding to the symbols in the text
    """
    ids = []
    for ch in cleaned_text:
        # Characters without a known symbol ID are silently skipped.
        mapped = _symbol_to_id.get(ch)
        if mapped is not None:
            ids.append(mapped)
    return ids
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def sequence_to_text(sequence):
    """Converts a sequence of IDs back to a string.

    Raises KeyError for an ID outside the symbol table (same as before).
    """
    # str.join avoids quadratic repeated string concatenation on long sequences.
    return "".join(_id_to_symbol[symbol_id] for symbol_id in sequence)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _clean_text(text, cleaner_names):
|
| 79 |
+
for name in cleaner_names:
|
| 80 |
+
cleaner = getattr(cleaners, name)
|
| 81 |
+
if not cleaner:
|
| 82 |
+
raise Exception("Unknown cleaner: %s" % name)
|
| 83 |
+
text, langs = cleaner(text)
|
| 84 |
+
return text, langs
|
apps/audio_cloning/vallex/g2p/bpe_1024.json
ADDED
|
@@ -0,0 +1,2049 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": "1.0",
|
| 3 |
+
"truncation": null,
|
| 4 |
+
"padding": null,
|
| 5 |
+
"added_tokens": [
|
| 6 |
+
{
|
| 7 |
+
"id": 0,
|
| 8 |
+
"content": "[UNK]",
|
| 9 |
+
"single_word": false,
|
| 10 |
+
"lstrip": false,
|
| 11 |
+
"rstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"special": true
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"id": 1,
|
| 17 |
+
"content": "[CLS]",
|
| 18 |
+
"single_word": false,
|
| 19 |
+
"lstrip": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"normalized": false,
|
| 22 |
+
"special": true
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"id": 2,
|
| 26 |
+
"content": "[SEP]",
|
| 27 |
+
"single_word": false,
|
| 28 |
+
"lstrip": false,
|
| 29 |
+
"rstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"special": true
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"id": 3,
|
| 35 |
+
"content": "[PAD]",
|
| 36 |
+
"single_word": false,
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"rstrip": false,
|
| 39 |
+
"normalized": false,
|
| 40 |
+
"special": true
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"id": 4,
|
| 44 |
+
"content": "[MASK]",
|
| 45 |
+
"single_word": false,
|
| 46 |
+
"lstrip": false,
|
| 47 |
+
"rstrip": false,
|
| 48 |
+
"normalized": false,
|
| 49 |
+
"special": true
|
| 50 |
+
}
|
| 51 |
+
],
|
| 52 |
+
"normalizer": null,
|
| 53 |
+
"pre_tokenizer": {
|
| 54 |
+
"type": "Whitespace"
|
| 55 |
+
},
|
| 56 |
+
"post_processor": null,
|
| 57 |
+
"decoder": null,
|
| 58 |
+
"model": {
|
| 59 |
+
"type": "BPE",
|
| 60 |
+
"dropout": null,
|
| 61 |
+
"unk_token": "[UNK]",
|
| 62 |
+
"continuing_subword_prefix": null,
|
| 63 |
+
"end_of_word_suffix": null,
|
| 64 |
+
"fuse_unk": false,
|
| 65 |
+
"byte_fallback": false,
|
| 66 |
+
"vocab": {
|
| 67 |
+
"[UNK]": 0,
|
| 68 |
+
"[CLS]": 1,
|
| 69 |
+
"[SEP]": 2,
|
| 70 |
+
"[PAD]": 3,
|
| 71 |
+
"[MASK]": 4,
|
| 72 |
+
"!": 5,
|
| 73 |
+
"#": 6,
|
| 74 |
+
"*": 7,
|
| 75 |
+
",": 8,
|
| 76 |
+
"-": 9,
|
| 77 |
+
".": 10,
|
| 78 |
+
"=": 11,
|
| 79 |
+
"?": 12,
|
| 80 |
+
"N": 13,
|
| 81 |
+
"Q": 14,
|
| 82 |
+
"^": 15,
|
| 83 |
+
"_": 16,
|
| 84 |
+
"`": 17,
|
| 85 |
+
"a": 18,
|
| 86 |
+
"b": 19,
|
| 87 |
+
"d": 20,
|
| 88 |
+
"e": 21,
|
| 89 |
+
"f": 22,
|
| 90 |
+
"g": 23,
|
| 91 |
+
"h": 24,
|
| 92 |
+
"i": 25,
|
| 93 |
+
"j": 26,
|
| 94 |
+
"k": 27,
|
| 95 |
+
"l": 28,
|
| 96 |
+
"m": 29,
|
| 97 |
+
"n": 30,
|
| 98 |
+
"o": 31,
|
| 99 |
+
"p": 32,
|
| 100 |
+
"s": 33,
|
| 101 |
+
"t": 34,
|
| 102 |
+
"u": 35,
|
| 103 |
+
"v": 36,
|
| 104 |
+
"w": 37,
|
| 105 |
+
"x": 38,
|
| 106 |
+
"y": 39,
|
| 107 |
+
"z": 40,
|
| 108 |
+
"~": 41,
|
| 109 |
+
"æ": 42,
|
| 110 |
+
"ç": 43,
|
| 111 |
+
"ð": 44,
|
| 112 |
+
"ŋ": 45,
|
| 113 |
+
"ɑ": 46,
|
| 114 |
+
"ɔ": 47,
|
| 115 |
+
"ə": 48,
|
| 116 |
+
"ɛ": 49,
|
| 117 |
+
"ɥ": 50,
|
| 118 |
+
"ɪ": 51,
|
| 119 |
+
"ɫ": 52,
|
| 120 |
+
"ɯ": 53,
|
| 121 |
+
"ɸ": 54,
|
| 122 |
+
"ɹ": 55,
|
| 123 |
+
"ɾ": 56,
|
| 124 |
+
"ʃ": 57,
|
| 125 |
+
"ʊ": 58,
|
| 126 |
+
"ʑ": 59,
|
| 127 |
+
"ʒ": 60,
|
| 128 |
+
"ʰ": 61,
|
| 129 |
+
"ˈ": 62,
|
| 130 |
+
"ˌ": 63,
|
| 131 |
+
"θ": 64,
|
| 132 |
+
"…": 65,
|
| 133 |
+
"⁼": 66,
|
| 134 |
+
"↑": 67,
|
| 135 |
+
"→": 68,
|
| 136 |
+
"↓": 69,
|
| 137 |
+
"_t": 70,
|
| 138 |
+
"↓↑": 71,
|
| 139 |
+
"_ˈ": 72,
|
| 140 |
+
"ən": 73,
|
| 141 |
+
"_s": 74,
|
| 142 |
+
"aɪ": 75,
|
| 143 |
+
"əɹ": 76,
|
| 144 |
+
"eɪ": 77,
|
| 145 |
+
"oʊ": 78,
|
| 146 |
+
"_k": 79,
|
| 147 |
+
"ʃi": 80,
|
| 148 |
+
"_w": 81,
|
| 149 |
+
"_ð": 82,
|
| 150 |
+
"ts": 83,
|
| 151 |
+
"tʃ": 84,
|
| 152 |
+
"_ts": 85,
|
| 153 |
+
"_h": 86,
|
| 154 |
+
"_ə": 87,
|
| 155 |
+
"_m": 88,
|
| 156 |
+
"an": 89,
|
| 157 |
+
"_n": 90,
|
| 158 |
+
"_ðə": 91,
|
| 159 |
+
"ɛn": 92,
|
| 160 |
+
"ɑʊ": 93,
|
| 161 |
+
"ɑŋ": 94,
|
| 162 |
+
"`⁼": 95,
|
| 163 |
+
"_p": 96,
|
| 164 |
+
"_i": 97,
|
| 165 |
+
"_ɪ": 98,
|
| 166 |
+
"_tʃ": 99,
|
| 167 |
+
"_l": 100,
|
| 168 |
+
"jɛn": 101,
|
| 169 |
+
"_d": 102,
|
| 170 |
+
"_f": 103,
|
| 171 |
+
"_j": 104,
|
| 172 |
+
"wo": 105,
|
| 173 |
+
"_b": 106,
|
| 174 |
+
"ta": 107,
|
| 175 |
+
"`↓": 108,
|
| 176 |
+
"te": 109,
|
| 177 |
+
"ənd": 110,
|
| 178 |
+
"_ʃi": 111,
|
| 179 |
+
"wa": 112,
|
| 180 |
+
"ka": 113,
|
| 181 |
+
"ɪŋ": 114,
|
| 182 |
+
"in": 115,
|
| 183 |
+
"st": 116,
|
| 184 |
+
"li": 117,
|
| 185 |
+
"ʊŋ": 118,
|
| 186 |
+
"_tɪ": 119,
|
| 187 |
+
"to": 120,
|
| 188 |
+
"weɪ": 121,
|
| 189 |
+
"_ənd": 122,
|
| 190 |
+
"ʰi": 123,
|
| 191 |
+
"_əv": 124,
|
| 192 |
+
"əŋ": 125,
|
| 193 |
+
"no": 126,
|
| 194 |
+
"_x": 127,
|
| 195 |
+
"ɾɯ": 128,
|
| 196 |
+
"na": 129,
|
| 197 |
+
"_a": 130,
|
| 198 |
+
"_ɹ": 131,
|
| 199 |
+
"ɪn": 132,
|
| 200 |
+
"ga": 133,
|
| 201 |
+
"de": 134,
|
| 202 |
+
"joʊ": 135,
|
| 203 |
+
"æn": 136,
|
| 204 |
+
"kɯ": 137,
|
| 205 |
+
"ɾe": 138,
|
| 206 |
+
"ma": 139,
|
| 207 |
+
"_ðə_ˈ": 140,
|
| 208 |
+
"ɾa": 141,
|
| 209 |
+
"ɛɹ": 142,
|
| 210 |
+
"mo": 143,
|
| 211 |
+
"ɔɹ": 144,
|
| 212 |
+
"əɫ": 145,
|
| 213 |
+
"_g": 146,
|
| 214 |
+
"da": 147,
|
| 215 |
+
"*↑": 148,
|
| 216 |
+
"ɪˈ": 149,
|
| 217 |
+
"_o": 150,
|
| 218 |
+
"_ʃ": 151,
|
| 219 |
+
"iŋ": 152,
|
| 220 |
+
"ja": 153,
|
| 221 |
+
"əm": 154,
|
| 222 |
+
"_ˌ": 155,
|
| 223 |
+
"aʊ": 156,
|
| 224 |
+
"_əˈ": 157,
|
| 225 |
+
"`↑": 158,
|
| 226 |
+
"ət": 159,
|
| 227 |
+
"_aɪ": 160,
|
| 228 |
+
"oo": 161,
|
| 229 |
+
"sɯ": 162,
|
| 230 |
+
"↓.": 163,
|
| 231 |
+
"_ɪn": 164,
|
| 232 |
+
"_hi": 165,
|
| 233 |
+
"_wɪ": 166,
|
| 234 |
+
"ɪz": 167,
|
| 235 |
+
"_na": 168,
|
| 236 |
+
"wan": 169,
|
| 237 |
+
"_ko": 170,
|
| 238 |
+
"_wo": 171,
|
| 239 |
+
"ɪd": 172,
|
| 240 |
+
"ɾi": 173,
|
| 241 |
+
"_ju": 174,
|
| 242 |
+
"mə": 175,
|
| 243 |
+
"_lə": 176,
|
| 244 |
+
"_hæ": 177,
|
| 245 |
+
"_ðət": 178,
|
| 246 |
+
"ɑɹ": 179,
|
| 247 |
+
"tʰ": 180,
|
| 248 |
+
"ki": 181,
|
| 249 |
+
"……": 182,
|
| 250 |
+
"ɑz": 183,
|
| 251 |
+
"_ɔ": 184,
|
| 252 |
+
"_mi": 185,
|
| 253 |
+
"_wɑz": 186,
|
| 254 |
+
"_ˈs": 187,
|
| 255 |
+
"↓,": 188,
|
| 256 |
+
"_tʰ": 189,
|
| 257 |
+
"əˈ": 190,
|
| 258 |
+
"dʑ": 191,
|
| 259 |
+
"ɪt": 192,
|
| 260 |
+
"_kʰ": 193,
|
| 261 |
+
"iɛ": 194,
|
| 262 |
+
"_ma": 195,
|
| 263 |
+
"ɪs": 196,
|
| 264 |
+
"tsɯ": 197,
|
| 265 |
+
"_ni": 198,
|
| 266 |
+
"_ɪt": 199,
|
| 267 |
+
"ke": 200,
|
| 268 |
+
"iɑʊ": 201,
|
| 269 |
+
"_ka": 202,
|
| 270 |
+
"_əɹ": 203,
|
| 271 |
+
"nd": 204,
|
| 272 |
+
"_ˈp": 205,
|
| 273 |
+
"ko": 206,
|
| 274 |
+
"jo": 207,
|
| 275 |
+
"ɹi": 208,
|
| 276 |
+
"mən": 209,
|
| 277 |
+
"ʊd": 210,
|
| 278 |
+
"_ˈm": 211,
|
| 279 |
+
"_fəɹ": 212,
|
| 280 |
+
"tʃʰi": 213,
|
| 281 |
+
"sa": 214,
|
| 282 |
+
"ʰɥ": 215,
|
| 283 |
+
"kʰ": 216,
|
| 284 |
+
"ˈs": 217,
|
| 285 |
+
"ɑt": 218,
|
| 286 |
+
"ɛd": 219,
|
| 287 |
+
"se": 220,
|
| 288 |
+
"tʃi": 221,
|
| 289 |
+
"ɛɫ": 222,
|
| 290 |
+
"_ˈk": 223,
|
| 291 |
+
"_joʊ": 224,
|
| 292 |
+
"təɹ": 225,
|
| 293 |
+
"ɛz": 226,
|
| 294 |
+
"--": 227,
|
| 295 |
+
"vəɹ": 228,
|
| 296 |
+
"`→": 229,
|
| 297 |
+
"ʃən": 230,
|
| 298 |
+
"_ɪz": 231,
|
| 299 |
+
"_meɪ": 232,
|
| 300 |
+
"_æ": 233,
|
| 301 |
+
"dʒ": 234,
|
| 302 |
+
"_ki": 235,
|
| 303 |
+
"_hɪz": 236,
|
| 304 |
+
"_bi": 237,
|
| 305 |
+
"uɑŋ": 238,
|
| 306 |
+
"_ˈf": 239,
|
| 307 |
+
"↓↑.": 240,
|
| 308 |
+
"_wɪθ": 241,
|
| 309 |
+
"ju": 242,
|
| 310 |
+
"iɑŋ": 243,
|
| 311 |
+
"→.": 244,
|
| 312 |
+
"_so": 245,
|
| 313 |
+
"_həɹ": 246,
|
| 314 |
+
"↑.": 247,
|
| 315 |
+
"ni": 248,
|
| 316 |
+
"_mo": 249,
|
| 317 |
+
"_maɪ": 250,
|
| 318 |
+
"laɪ": 251,
|
| 319 |
+
"ɥɛ": 252,
|
| 320 |
+
"_ta": 253,
|
| 321 |
+
"ənt": 254,
|
| 322 |
+
"_tʃʰi": 255,
|
| 323 |
+
"_sɯ": 256,
|
| 324 |
+
"_θ": 257,
|
| 325 |
+
"_ɛz": 258,
|
| 326 |
+
"wən": 259,
|
| 327 |
+
"me": 260,
|
| 328 |
+
"mi": 261,
|
| 329 |
+
"_hæd": 262,
|
| 330 |
+
"_ha": 263,
|
| 331 |
+
"əs": 264,
|
| 332 |
+
"_ˈl": 265,
|
| 333 |
+
"_st": 266,
|
| 334 |
+
"ðəɹ": 267,
|
| 335 |
+
"oʊn": 268,
|
| 336 |
+
"_wa": 269,
|
| 337 |
+
"ʰəŋ": 270,
|
| 338 |
+
"_nɑt": 271,
|
| 339 |
+
"*.": 272,
|
| 340 |
+
"kt": 273,
|
| 341 |
+
"_ˈh": 274,
|
| 342 |
+
"do": 275,
|
| 343 |
+
"ɥæn": 276,
|
| 344 |
+
"ne": 277,
|
| 345 |
+
"_to": 278,
|
| 346 |
+
"_wən": 279,
|
| 347 |
+
"_no": 280,
|
| 348 |
+
"_laɪ": 281,
|
| 349 |
+
"_wəɹ": 282,
|
| 350 |
+
"↑,": 283,
|
| 351 |
+
"→,": 284,
|
| 352 |
+
"ɛs": 285,
|
| 353 |
+
"↓↑,": 286,
|
| 354 |
+
"_ɔn": 287,
|
| 355 |
+
"ʰu": 288,
|
| 356 |
+
"so": 289,
|
| 357 |
+
"_ˈb": 290,
|
| 358 |
+
"ɫd": 291,
|
| 359 |
+
"ɪk": 292,
|
| 360 |
+
"ɪst": 293,
|
| 361 |
+
"_fɹ": 294,
|
| 362 |
+
"_ðɛɹ": 295,
|
| 363 |
+
"_weɪ": 296,
|
| 364 |
+
"kaɾa": 297,
|
| 365 |
+
"_ˈd": 298,
|
| 366 |
+
"_hæv": 299,
|
| 367 |
+
"tsʰ": 300,
|
| 368 |
+
"waɪ": 301,
|
| 369 |
+
"ɾo": 302,
|
| 370 |
+
"ɛm": 303,
|
| 371 |
+
"_æt": 304,
|
| 372 |
+
"ʊɹ": 305,
|
| 373 |
+
"_ˈw": 306,
|
| 374 |
+
"ba": 307,
|
| 375 |
+
"_noʊ": 308,
|
| 376 |
+
"ʰjɛn": 309,
|
| 377 |
+
"ɹeɪ": 310,
|
| 378 |
+
"_jo": 311,
|
| 379 |
+
"ɸɯ": 312,
|
| 380 |
+
"_sa": 313,
|
| 381 |
+
"_ɹɪˈ": 314,
|
| 382 |
+
"_ˈn": 315,
|
| 383 |
+
"ai": 316,
|
| 384 |
+
"_bət": 317,
|
| 385 |
+
"ɪɹ": 318,
|
| 386 |
+
"tʃʰɥ": 319,
|
| 387 |
+
"_dʑ": 320,
|
| 388 |
+
"əˌ": 321,
|
| 389 |
+
"_ðɪs": 322,
|
| 390 |
+
"..": 323,
|
| 391 |
+
"xwa": 324,
|
| 392 |
+
"_ɪm": 325,
|
| 393 |
+
"_dɪˈ": 326,
|
| 394 |
+
"_kən": 327,
|
| 395 |
+
"dʑi": 328,
|
| 396 |
+
"*,": 329,
|
| 397 |
+
"ɑn": 330,
|
| 398 |
+
"_ʃiɑŋ": 331,
|
| 399 |
+
"_kɯ": 332,
|
| 400 |
+
"ʃin": 333,
|
| 401 |
+
"_soʊ": 334,
|
| 402 |
+
"bi": 335,
|
| 403 |
+
"tʰjɛn": 336,
|
| 404 |
+
"te_i": 337,
|
| 405 |
+
"_tsʰ": 338,
|
| 406 |
+
"_ɯ": 339,
|
| 407 |
+
"aɪt": 340,
|
| 408 |
+
"ʰiŋ": 341,
|
| 409 |
+
"ðə": 342,
|
| 410 |
+
"_ɔɫ": 343,
|
| 411 |
+
"_ˈɹ": 344,
|
| 412 |
+
"nai": 345,
|
| 413 |
+
"əɹd": 346,
|
| 414 |
+
"_ˈt": 347,
|
| 415 |
+
"_ən": 348,
|
| 416 |
+
"_tʃʰɥ": 349,
|
| 417 |
+
"_iɛ": 350,
|
| 418 |
+
"leɪ": 351,
|
| 419 |
+
"ɛɹi": 352,
|
| 420 |
+
"ˈt": 353,
|
| 421 |
+
"ha": 354,
|
| 422 |
+
"ʃiŋ": 355,
|
| 423 |
+
"ɛvəɹ": 356,
|
| 424 |
+
"zɯ": 357,
|
| 425 |
+
"_wi": 358,
|
| 426 |
+
"_ja": 359,
|
| 427 |
+
"ɛk": 360,
|
| 428 |
+
"ʰɑŋ": 361,
|
| 429 |
+
"_tsɯ": 362,
|
| 430 |
+
"_əv_ðə": 363,
|
| 431 |
+
"taʃi": 364,
|
| 432 |
+
"_sɛd": 365,
|
| 433 |
+
"_xə": 366,
|
| 434 |
+
"_li": 367,
|
| 435 |
+
"_si": 368,
|
| 436 |
+
"desɯ": 369,
|
| 437 |
+
"_ˌɪn": 370,
|
| 438 |
+
"ʃjɛn": 371,
|
| 439 |
+
"_baɪ": 372,
|
| 440 |
+
"on": 373,
|
| 441 |
+
"_xɑʊ": 374,
|
| 442 |
+
"_ðeɪ": 375,
|
| 443 |
+
"_xaɪ": 376,
|
| 444 |
+
"`↓↑": 377,
|
| 445 |
+
"xweɪ": 378,
|
| 446 |
+
"hi": 379,
|
| 447 |
+
"_se": 380,
|
| 448 |
+
"ə_s": 381,
|
| 449 |
+
"_fɹəm": 382,
|
| 450 |
+
"ʊt": 383,
|
| 451 |
+
"di": 384,
|
| 452 |
+
"aʊt": 385,
|
| 453 |
+
"əb": 386,
|
| 454 |
+
"sɹ": 387,
|
| 455 |
+
"əz": 388,
|
| 456 |
+
"_xweɪ": 389,
|
| 457 |
+
"_kʰə": 390,
|
| 458 |
+
"ɹu": 391,
|
| 459 |
+
"_u": 392,
|
| 460 |
+
"_de": 393,
|
| 461 |
+
"aɪd": 394,
|
| 462 |
+
"ɪv": 395,
|
| 463 |
+
"bɯ": 396,
|
| 464 |
+
"_ho": 397,
|
| 465 |
+
"əɹz": 398,
|
| 466 |
+
"joo": 399,
|
| 467 |
+
"_bɪˈ": 400,
|
| 468 |
+
"_tʰa": 401,
|
| 469 |
+
"ɛt": 402,
|
| 470 |
+
"en": 403,
|
| 471 |
+
"ɛni": 404,
|
| 472 |
+
"əst": 405,
|
| 473 |
+
"æk": 406,
|
| 474 |
+
"ə_ts": 407,
|
| 475 |
+
"_ˈɪn": 408,
|
| 476 |
+
"ti": 409,
|
| 477 |
+
"ɥn": 410,
|
| 478 |
+
"_dʒ": 411,
|
| 479 |
+
"xɑʊ": 412,
|
| 480 |
+
"_ˈv": 413,
|
| 481 |
+
"ʃiɑŋ": 414,
|
| 482 |
+
"pʰ": 415,
|
| 483 |
+
"_wɪtʃ": 416,
|
| 484 |
+
"eɪm": 417,
|
| 485 |
+
"oʊz": 418,
|
| 486 |
+
"əðəɹ": 419,
|
| 487 |
+
"fɑŋ": 420,
|
| 488 |
+
"_ˈg": 421,
|
| 489 |
+
"_do": 422,
|
| 490 |
+
"_ʃiɑʊ": 423,
|
| 491 |
+
"_ˈæ": 424,
|
| 492 |
+
"_jʊɹ": 425,
|
| 493 |
+
"_ðɛm": 426,
|
| 494 |
+
"ɪm": 427,
|
| 495 |
+
"ɛst": 428,
|
| 496 |
+
"ænd": 429,
|
| 497 |
+
"_du": 430,
|
| 498 |
+
"ɯɯ": 431,
|
| 499 |
+
"kan": 432,
|
| 500 |
+
"_da": 433,
|
| 501 |
+
"ino": 434,
|
| 502 |
+
"_e": 435,
|
| 503 |
+
"_wʊd": 436,
|
| 504 |
+
"ɛnd": 437,
|
| 505 |
+
"meɪ": 438,
|
| 506 |
+
"θɪŋ": 439,
|
| 507 |
+
"_ʃjɛn": 440,
|
| 508 |
+
"iz": 441,
|
| 509 |
+
"aɪm": 442,
|
| 510 |
+
"_hu": 443,
|
| 511 |
+
"_əˈb": 444,
|
| 512 |
+
"əns": 445,
|
| 513 |
+
"_wɪɫ": 446,
|
| 514 |
+
"tʰi": 447,
|
| 515 |
+
"go": 448,
|
| 516 |
+
"ɛnt": 449,
|
| 517 |
+
"fu": 450,
|
| 518 |
+
"æp": 451,
|
| 519 |
+
"xoʊ": 452,
|
| 520 |
+
"eɪk": 453,
|
| 521 |
+
"ʊk": 454,
|
| 522 |
+
"əɹˈ": 455,
|
| 523 |
+
"_θɪŋ": 456,
|
| 524 |
+
"əl": 457,
|
| 525 |
+
"pɹ": 458,
|
| 526 |
+
"ətʃ": 459,
|
| 527 |
+
"nt": 460,
|
| 528 |
+
"_ɸɯ": 461,
|
| 529 |
+
"lu": 462,
|
| 530 |
+
"_ˈɔ": 463,
|
| 531 |
+
"_iɑʊ": 464,
|
| 532 |
+
"lə": 465,
|
| 533 |
+
"tu": 466,
|
| 534 |
+
"_dʑi": 467,
|
| 535 |
+
"eɪt": 468,
|
| 536 |
+
"_ʃin": 469,
|
| 537 |
+
"nna": 470,
|
| 538 |
+
"_ˈpɹ": 471,
|
| 539 |
+
"fən": 472,
|
| 540 |
+
"_əp": 473,
|
| 541 |
+
"njɛn": 474,
|
| 542 |
+
"_aʊt": 475,
|
| 543 |
+
"fɔɹ": 476,
|
| 544 |
+
"_tu": 477,
|
| 545 |
+
"eɪʃən": 478,
|
| 546 |
+
"ɪɫ": 479,
|
| 547 |
+
"_wət": 480,
|
| 548 |
+
"_ɪf": 481,
|
| 549 |
+
"_ɥ": 482,
|
| 550 |
+
"_fa": 483,
|
| 551 |
+
"ˈw": 484,
|
| 552 |
+
"tʃʰjɛn": 485,
|
| 553 |
+
"_wɪn": 486,
|
| 554 |
+
"oʊɫd": 487,
|
| 555 |
+
"_əˈp": 488,
|
| 556 |
+
"aʊnd": 489,
|
| 557 |
+
"san": 490,
|
| 558 |
+
"he": 491,
|
| 559 |
+
"_bɪn": 492,
|
| 560 |
+
"fa": 493,
|
| 561 |
+
"ɪf": 494,
|
| 562 |
+
"ɔŋ": 495,
|
| 563 |
+
"ge": 496,
|
| 564 |
+
"_ɪn_ðə": 497,
|
| 565 |
+
"miŋ": 498,
|
| 566 |
+
"_pɹ": 499,
|
| 567 |
+
"ina": 500,
|
| 568 |
+
"ano": 501,
|
| 569 |
+
"əbəɫ": 502,
|
| 570 |
+
"kˈs": 503,
|
| 571 |
+
"_ˈɛni": 504,
|
| 572 |
+
"nəŋ": 505,
|
| 573 |
+
"əd": 506,
|
| 574 |
+
"_əv_ðə_ˈ": 507,
|
| 575 |
+
"_waɪ": 508,
|
| 576 |
+
"_taɪm": 509,
|
| 577 |
+
"ˈsɛɫ": 510,
|
| 578 |
+
"ʃiɛ": 511,
|
| 579 |
+
"_kəm": 512,
|
| 580 |
+
"æst": 513,
|
| 581 |
+
"_goʊ": 514,
|
| 582 |
+
"mɯ": 515,
|
| 583 |
+
"ˈp": 516,
|
| 584 |
+
"_ˈst": 517,
|
| 585 |
+
"ə_t": 518,
|
| 586 |
+
"pt": 519,
|
| 587 |
+
"_pʰ": 520,
|
| 588 |
+
"ʰɹ": 521,
|
| 589 |
+
"ʃja": 522,
|
| 590 |
+
"iwa": 523,
|
| 591 |
+
"ɪl": 524,
|
| 592 |
+
"bət": 525,
|
| 593 |
+
"_fɑŋ": 526,
|
| 594 |
+
"ho": 527,
|
| 595 |
+
"iv": 528,
|
| 596 |
+
"loʊ": 529,
|
| 597 |
+
"be": 530,
|
| 598 |
+
"_laɪk": 531,
|
| 599 |
+
"ɪʃ": 532,
|
| 600 |
+
"_fu": 533,
|
| 601 |
+
"ze": 534,
|
| 602 |
+
"ə_tʃ": 535,
|
| 603 |
+
"ɑɹt": 536,
|
| 604 |
+
"ɔɹd": 537,
|
| 605 |
+
"tʃʰiŋ": 538,
|
| 606 |
+
"mp": 539,
|
| 607 |
+
"_ðə_s": 540,
|
| 608 |
+
"_əˈbaʊt": 541,
|
| 609 |
+
"_ˈoʊ": 542,
|
| 610 |
+
"kʰə": 543,
|
| 611 |
+
"d_tɪ": 544,
|
| 612 |
+
"ŋga": 545,
|
| 613 |
+
"əli": 546,
|
| 614 |
+
"_kʰan": 547,
|
| 615 |
+
"çi": 548,
|
| 616 |
+
"_ˈju": 549,
|
| 617 |
+
"_kʊd": 550,
|
| 618 |
+
"ɔɫ": 551,
|
| 619 |
+
"ɔt": 552,
|
| 620 |
+
"_ɪts": 553,
|
| 621 |
+
"_san": 554,
|
| 622 |
+
"tʃa": 555,
|
| 623 |
+
"i_na": 556,
|
| 624 |
+
"xə": 557,
|
| 625 |
+
"ɛkt": 558,
|
| 626 |
+
"_mɔɹ": 559,
|
| 627 |
+
"te_kɯ": 560,
|
| 628 |
+
"ɪdʒ": 561,
|
| 629 |
+
"jʊŋ": 562,
|
| 630 |
+
"_wan": 563,
|
| 631 |
+
"æt": 564,
|
| 632 |
+
"kat": 565,
|
| 633 |
+
"ˈsɛɫf": 566,
|
| 634 |
+
"_ke": 567,
|
| 635 |
+
"aɪnd": 568,
|
| 636 |
+
"it": 569,
|
| 637 |
+
"_ɑɹ": 570,
|
| 638 |
+
"sp": 571,
|
| 639 |
+
"oʊnt": 572,
|
| 640 |
+
"_tʃi": 573,
|
| 641 |
+
"tsʰɹ": 574,
|
| 642 |
+
"_xən": 575,
|
| 643 |
+
"_əˈg": 576,
|
| 644 |
+
"ə_k": 577,
|
| 645 |
+
"to_i": 578,
|
| 646 |
+
"_tʰi": 579,
|
| 647 |
+
"_iŋ": 580,
|
| 648 |
+
"aʊn": 581,
|
| 649 |
+
"gɯ": 582,
|
| 650 |
+
"_ɪkˈs": 583,
|
| 651 |
+
"ɛv": 584,
|
| 652 |
+
"gi": 585,
|
| 653 |
+
"ks": 586,
|
| 654 |
+
"_səm": 587,
|
| 655 |
+
"ana": 588,
|
| 656 |
+
"ɪtəɫ": 589,
|
| 657 |
+
"nan": 590,
|
| 658 |
+
"_ˈɪntu": 591,
|
| 659 |
+
"_hiɹ": 592,
|
| 660 |
+
"_te": 593,
|
| 661 |
+
"_naʊ": 594,
|
| 662 |
+
"ʃiɑʊ": 595,
|
| 663 |
+
"ʃo": 596,
|
| 664 |
+
"ɹe": 597,
|
| 665 |
+
"xaɪ": 598,
|
| 666 |
+
"_tʃʰiŋ": 599,
|
| 667 |
+
"_sɹ": 600,
|
| 668 |
+
"_haʊ": 601,
|
| 669 |
+
"?.": 602,
|
| 670 |
+
"_feɪ": 603,
|
| 671 |
+
"liŋ": 604,
|
| 672 |
+
"_ʃja": 605,
|
| 673 |
+
"_ˈdʒ": 606,
|
| 674 |
+
"_seɪ": 607,
|
| 675 |
+
"ˈn": 608,
|
| 676 |
+
"soʊ": 609,
|
| 677 |
+
"tʰʊŋ": 610,
|
| 678 |
+
"_ljoʊ": 611,
|
| 679 |
+
"maɪ": 612,
|
| 680 |
+
"_bɹ": 613,
|
| 681 |
+
"ɹeɪt": 614,
|
| 682 |
+
"_nəŋ": 615,
|
| 683 |
+
"ʰə": 616,
|
| 684 |
+
"æns": 617,
|
| 685 |
+
"_ˈɔl": 618,
|
| 686 |
+
"tatʃi": 619,
|
| 687 |
+
"nto": 620,
|
| 688 |
+
"_ˌɪnˈ": 621,
|
| 689 |
+
"le": 622,
|
| 690 |
+
"nde": 623,
|
| 691 |
+
"_ˈvɛɹi": 624,
|
| 692 |
+
"mənt": 625,
|
| 693 |
+
"ɾima": 626,
|
| 694 |
+
"_ðɛn": 627,
|
| 695 |
+
"_həz": 628,
|
| 696 |
+
"_ɹi": 629,
|
| 697 |
+
"ftəɹ": 630,
|
| 698 |
+
"_sp": 631,
|
| 699 |
+
"ɾewa": 632,
|
| 700 |
+
"ga_a": 633,
|
| 701 |
+
"z_əv": 634,
|
| 702 |
+
"_miŋ": 635,
|
| 703 |
+
"_tɪ_ðə": 636,
|
| 704 |
+
"ɹaɪ": 637,
|
| 705 |
+
"ɛl": 638,
|
| 706 |
+
"ɹæ": 639,
|
| 707 |
+
"_hoʊ": 640,
|
| 708 |
+
"xu": 641,
|
| 709 |
+
"oʊnli": 642,
|
| 710 |
+
"ŋk": 643,
|
| 711 |
+
"i_i": 644,
|
| 712 |
+
"_dɪd": 645,
|
| 713 |
+
"_dʒɪst": 646,
|
| 714 |
+
"ing": 647,
|
| 715 |
+
"kai": 648,
|
| 716 |
+
"_mæn": 649,
|
| 717 |
+
"_in": 650,
|
| 718 |
+
"zo": 651,
|
| 719 |
+
"əf": 652,
|
| 720 |
+
"dake": 653,
|
| 721 |
+
"_ˈsəm": 654,
|
| 722 |
+
"ɾɯ_no": 655,
|
| 723 |
+
"_go": 656,
|
| 724 |
+
"tʃəɹ": 657,
|
| 725 |
+
"ite": 658,
|
| 726 |
+
"`↓.": 659,
|
| 727 |
+
"_kʰaɪ": 660,
|
| 728 |
+
"sk": 661,
|
| 729 |
+
"ɔɹs": 662,
|
| 730 |
+
"_tʰiŋ": 663,
|
| 731 |
+
"_nə": 664,
|
| 732 |
+
"pəɫ": 665,
|
| 733 |
+
"_tɪ_bi": 666,
|
| 734 |
+
"ˈfɔɹ": 667,
|
| 735 |
+
"mu": 668,
|
| 736 |
+
"su": 669,
|
| 737 |
+
"aa": 670,
|
| 738 |
+
"ɪstəɹ": 671,
|
| 739 |
+
"ʰan": 672,
|
| 740 |
+
"pəɹ": 673,
|
| 741 |
+
"ə_p": 674,
|
| 742 |
+
"liɑŋ": 675,
|
| 743 |
+
"_v": 676,
|
| 744 |
+
"oʊst": 677,
|
| 745 |
+
"_əˈgɛn": 678,
|
| 746 |
+
"ənz": 679,
|
| 747 |
+
"No": 680,
|
| 748 |
+
"ɔɹt": 681,
|
| 749 |
+
"_səˈ": 682,
|
| 750 |
+
"_mɯ": 683,
|
| 751 |
+
"tʃʰ": 684,
|
| 752 |
+
"_ˈlɪtəɫ": 685,
|
| 753 |
+
"_xwo": 686,
|
| 754 |
+
"_ˌbi": 687,
|
| 755 |
+
"_ˈoʊvəɹ": 688,
|
| 756 |
+
"_çi": 689,
|
| 757 |
+
"_deɪ": 690,
|
| 758 |
+
"aɪn": 691,
|
| 759 |
+
"_ʃiŋ": 692,
|
| 760 |
+
"i_ʃi": 693,
|
| 761 |
+
"_tsʰaɪ": 694,
|
| 762 |
+
"ʃoo": 695,
|
| 763 |
+
"ɾoo": 696,
|
| 764 |
+
"bəɹ": 697,
|
| 765 |
+
"ʰa": 698,
|
| 766 |
+
"ˈɛs": 699,
|
| 767 |
+
"_ɪn_ðə_ˈ": 700,
|
| 768 |
+
"Nwa": 701,
|
| 769 |
+
"_ðən": 702,
|
| 770 |
+
"saɪ": 703,
|
| 771 |
+
"_ˈjuˈɛs": 704,
|
| 772 |
+
"nda": 705,
|
| 773 |
+
"_pleɪ": 706,
|
| 774 |
+
"ɪŋ_tɪ": 707,
|
| 775 |
+
"ɪti": 708,
|
| 776 |
+
"_me": 709,
|
| 777 |
+
"_ʃʊd": 710,
|
| 778 |
+
"_nu": 711,
|
| 779 |
+
"_ðə_k": 712,
|
| 780 |
+
"za": 713,
|
| 781 |
+
"_ˈɛvəɹ": 714,
|
| 782 |
+
"əɹn": 715,
|
| 783 |
+
"æd": 716,
|
| 784 |
+
"ˈm": 717,
|
| 785 |
+
"_doʊnt": 718,
|
| 786 |
+
"_məst": 719,
|
| 787 |
+
"jɯɯ": 720,
|
| 788 |
+
"ɑɹd": 721,
|
| 789 |
+
"_jɛn": 722,
|
| 790 |
+
"ʃɥ": 723,
|
| 791 |
+
"_ˈoʊnli": 724,
|
| 792 |
+
"_ʃo": 725,
|
| 793 |
+
"_liŋ": 726,
|
| 794 |
+
"ss": 727,
|
| 795 |
+
"ɑl": 728,
|
| 796 |
+
"dea": 729,
|
| 797 |
+
"ɾeta": 730,
|
| 798 |
+
"mjɛn": 731,
|
| 799 |
+
"_gʊd": 732,
|
| 800 |
+
"_wɔ": 733,
|
| 801 |
+
"imo": 734,
|
| 802 |
+
"no_ko": 735,
|
| 803 |
+
"_ɥæn": 736,
|
| 804 |
+
"ndʒ": 737,
|
| 805 |
+
"ɪʃən": 738,
|
| 806 |
+
"o_ʃi": 739,
|
| 807 |
+
"_θɪŋk": 740,
|
| 808 |
+
"_nan": 741,
|
| 809 |
+
"to_o": 742,
|
| 810 |
+
"_tʰʊŋ": 743,
|
| 811 |
+
"ljoʊ": 744,
|
| 812 |
+
"tai": 745,
|
| 813 |
+
"mə_s": 746,
|
| 814 |
+
"_jɯ": 747,
|
| 815 |
+
"_uɑŋ": 748,
|
| 816 |
+
"_ˌbiˈfɔɹ": 749,
|
| 817 |
+
"æs": 750,
|
| 818 |
+
"_tʃʰjɛn": 751,
|
| 819 |
+
"ik": 752,
|
| 820 |
+
"_bæk": 753,
|
| 821 |
+
"_ˈiv": 754,
|
| 822 |
+
"eɪn": 755,
|
| 823 |
+
"un": 756,
|
| 824 |
+
"la": 757,
|
| 825 |
+
"ˈk": 758,
|
| 826 |
+
"_daʊn": 759,
|
| 827 |
+
"anai": 760,
|
| 828 |
+
"_lɛ": 761,
|
| 829 |
+
"əɹt": 762,
|
| 830 |
+
"ðɛɹ": 763,
|
| 831 |
+
"_ˈæftəɹ": 764,
|
| 832 |
+
"dat": 765,
|
| 833 |
+
"fan": 766,
|
| 834 |
+
"bəɫ": 767,
|
| 835 |
+
"temo": 768,
|
| 836 |
+
"tʰa": 769,
|
| 837 |
+
"ɾɯ_ko": 770,
|
| 838 |
+
"ˈv": 771,
|
| 839 |
+
"feɪ": 772,
|
| 840 |
+
"_mətʃ": 773,
|
| 841 |
+
"xwo": 774,
|
| 842 |
+
"ɹoʊ": 775,
|
| 843 |
+
"_ba": 776,
|
| 844 |
+
"_ˈnɛvəɹ": 777,
|
| 845 |
+
"_meɪd": 778,
|
| 846 |
+
"_jʊŋ": 779,
|
| 847 |
+
"_əˈpɑn": 780,
|
| 848 |
+
"!?": 781,
|
| 849 |
+
"_ˈʃ": 782,
|
| 850 |
+
"_ðə_ˈk": 783,
|
| 851 |
+
"ft": 784,
|
| 852 |
+
"_bo": 785,
|
| 853 |
+
"_ɪn_ə": 786,
|
| 854 |
+
"tʃʰɥæn": 787,
|
| 855 |
+
"ˈz": 788,
|
| 856 |
+
"`↓,": 789,
|
| 857 |
+
"_bɪˈk": 790,
|
| 858 |
+
"ɪg": 791,
|
| 859 |
+
"kin": 792,
|
| 860 |
+
"_kl": 793,
|
| 861 |
+
"ɾɯ_n": 794,
|
| 862 |
+
"_lɑʊ": 795,
|
| 863 |
+
"----": 796,
|
| 864 |
+
"ika": 797,
|
| 865 |
+
"_ɹaɪt": 798,
|
| 866 |
+
"zd": 799,
|
| 867 |
+
"z_ənd": 800,
|
| 868 |
+
"_kjo": 801,
|
| 869 |
+
"xwan": 802,
|
| 870 |
+
"too": 803,
|
| 871 |
+
"_gɪt": 804,
|
| 872 |
+
"_liɑŋ": 805,
|
| 873 |
+
"ta_n": 806,
|
| 874 |
+
"_keɪm": 807,
|
| 875 |
+
"_ˈəðəɹ": 808,
|
| 876 |
+
"_wɛɫ": 809,
|
| 877 |
+
"teki": 810,
|
| 878 |
+
"see": 811,
|
| 879 |
+
"jɯ": 812,
|
| 880 |
+
"i_o": 813,
|
| 881 |
+
"to_ʃi": 814,
|
| 882 |
+
"fəɫ": 815,
|
| 883 |
+
"bo": 816,
|
| 884 |
+
"ˌt": 817,
|
| 885 |
+
"ɪp": 818,
|
| 886 |
+
"ane": 819,
|
| 887 |
+
"_tʰjɛn": 820,
|
| 888 |
+
"_tʃo": 821,
|
| 889 |
+
"ɾjo": 822,
|
| 890 |
+
"ɪns": 823,
|
| 891 |
+
"_he": 824,
|
| 892 |
+
"ŋka": 825,
|
| 893 |
+
"ʃɥɛ": 826,
|
| 894 |
+
"dʑa": 827,
|
| 895 |
+
"vd": 828,
|
| 896 |
+
"ʰwan": 829,
|
| 897 |
+
"_gɹeɪt": 830,
|
| 898 |
+
"_əv_ə": 831,
|
| 899 |
+
"əndəɹ": 832,
|
| 900 |
+
"kedo": 833,
|
| 901 |
+
"_ðə_b": 834,
|
| 902 |
+
"ək": 835,
|
| 903 |
+
"_teɪk": 836,
|
| 904 |
+
"kʰan": 837,
|
| 905 |
+
"_ˈɔlˌ": 838,
|
| 906 |
+
"swo": 839,
|
| 907 |
+
"_ɪt_wɑz": 840,
|
| 908 |
+
"_ʃɥ": 841,
|
| 909 |
+
"_sim": 842,
|
| 910 |
+
"_ˈfɑ": 843,
|
| 911 |
+
"min": 844,
|
| 912 |
+
"i_a": 845,
|
| 913 |
+
"soo": 846,
|
| 914 |
+
"ɛns": 847,
|
| 915 |
+
"_sətʃ": 848,
|
| 916 |
+
"tʰaɪ": 849,
|
| 917 |
+
"_ga": 850,
|
| 918 |
+
"i_ka": 851,
|
| 919 |
+
"koo": 852,
|
| 920 |
+
"_fəɹst": 853,
|
| 921 |
+
"_ˈtʃ": 854,
|
| 922 |
+
"nno": 855,
|
| 923 |
+
"ə_ɹ": 856,
|
| 924 |
+
"taɾa": 857,
|
| 925 |
+
"tʃʰjoʊ": 858,
|
| 926 |
+
"_æm": 859,
|
| 927 |
+
"_mu": 860,
|
| 928 |
+
"_meɪk": 861,
|
| 929 |
+
"↓…": 862,
|
| 930 |
+
"ɪˈθ": 863,
|
| 931 |
+
"ɑb": 864,
|
| 932 |
+
"ɹa": 865,
|
| 933 |
+
"_wɛɹ": 866,
|
| 934 |
+
"_ðə_ˈs": 867,
|
| 935 |
+
"_əˈl": 868,
|
| 936 |
+
"_oʊɫd": 869,
|
| 937 |
+
"æl": 870,
|
| 938 |
+
"_ˈpi": 871,
|
| 939 |
+
"_lɔŋ": 872,
|
| 940 |
+
"dʑo": 873,
|
| 941 |
+
"_tʰaɪ": 874,
|
| 942 |
+
"ɔɹn": 875,
|
| 943 |
+
"əɫz": 876,
|
| 944 |
+
"_təˈ": 877,
|
| 945 |
+
"_əˈweɪ": 878,
|
| 946 |
+
"pa": 879,
|
| 947 |
+
"_ðiz": 880,
|
| 948 |
+
"_ˈsp": 881,
|
| 949 |
+
"nn": 882,
|
| 950 |
+
"mae": 883,
|
| 951 |
+
"towa": 884,
|
| 952 |
+
"ta_no": 885,
|
| 953 |
+
"_an": 886,
|
| 954 |
+
"kʰaɪ": 887,
|
| 955 |
+
"ɾaɾe": 888,
|
| 956 |
+
"eɪs": 889,
|
| 957 |
+
"ɑd": 890,
|
| 958 |
+
"_wɪˈθ": 891,
|
| 959 |
+
"_ˈivɪn": 892,
|
| 960 |
+
"_lu": 893,
|
| 961 |
+
"ɔɪ": 894,
|
| 962 |
+
"lɪŋ": 895,
|
| 963 |
+
"əti": 896,
|
| 964 |
+
"_ðə_f": 897,
|
| 965 |
+
"oʃi": 898,
|
| 966 |
+
"_la": 899,
|
| 967 |
+
"si": 900,
|
| 968 |
+
"tɪd": 901,
|
| 969 |
+
"haʊ": 902,
|
| 970 |
+
"pʰin": 903,
|
| 971 |
+
"ˈst": 904,
|
| 972 |
+
"_ˈpəɹ": 905,
|
| 973 |
+
"eɹ": 906,
|
| 974 |
+
"*!": 907,
|
| 975 |
+
"_ˈmɪstəɹ": 908,
|
| 976 |
+
"ʃa": 909,
|
| 977 |
+
"_ˌɪm": 910,
|
| 978 |
+
"ˌθɪŋ": 911,
|
| 979 |
+
"_neɪ": 912,
|
| 980 |
+
"_nɥ": 913,
|
| 981 |
+
"ɑk": 914,
|
| 982 |
+
"_ɹu": 915,
|
| 983 |
+
"_ʃɯ": 916,
|
| 984 |
+
"_ðə_ˈm": 917,
|
| 985 |
+
"demo": 918,
|
| 986 |
+
"_dɹ": 919,
|
| 987 |
+
"dʑoo": 920,
|
| 988 |
+
"_stɪɫ": 921,
|
| 989 |
+
"_pʰiŋ": 922,
|
| 990 |
+
"ə_i": 923,
|
| 991 |
+
"_ɪkˈsp": 924,
|
| 992 |
+
"_wɛnt": 925,
|
| 993 |
+
"ɪɹi": 926,
|
| 994 |
+
"əˈm": 927,
|
| 995 |
+
"o_ka": 928,
|
| 996 |
+
"_əˈk": 929,
|
| 997 |
+
"ɔk": 930,
|
| 998 |
+
"_ɥɛ": 931,
|
| 999 |
+
"_lʊk": 932,
|
| 1000 |
+
"ˈd": 933,
|
| 1001 |
+
"kaʃi": 934,
|
| 1002 |
+
"_wɪθ_ə": 935,
|
| 1003 |
+
"ljɛn": 936,
|
| 1004 |
+
"ɔn": 937,
|
| 1005 |
+
"_ljɛn": 938,
|
| 1006 |
+
"_hɛɫ": 939,
|
| 1007 |
+
"uɹ": 940,
|
| 1008 |
+
"_tʰoʊ": 941,
|
| 1009 |
+
"_tʃʰɥæn": 942,
|
| 1010 |
+
"_sk": 943,
|
| 1011 |
+
"tsʰaɪ": 944,
|
| 1012 |
+
"ɛtəɹ": 945,
|
| 1013 |
+
"_min": 946,
|
| 1014 |
+
"noʊ": 947,
|
| 1015 |
+
"ʃɯ": 948,
|
| 1016 |
+
"_θɹu": 949,
|
| 1017 |
+
"_θɔt": 950,
|
| 1018 |
+
"dajo": 951,
|
| 1019 |
+
"wi": 952,
|
| 1020 |
+
"i_ko": 953,
|
| 1021 |
+
"_tɹ": 954,
|
| 1022 |
+
"_fan": 955,
|
| 1023 |
+
"ɹɛ": 956,
|
| 1024 |
+
"saN": 957,
|
| 1025 |
+
"_hi_wɑz": 958,
|
| 1026 |
+
"_ɾe": 959,
|
| 1027 |
+
"_əm": 960,
|
| 1028 |
+
"te_ki": 961,
|
| 1029 |
+
"_xoʊ": 962,
|
| 1030 |
+
"ˈl": 963,
|
| 1031 |
+
"ˈg": 964,
|
| 1032 |
+
"ga_i": 965,
|
| 1033 |
+
"_ɔn_ðə": 966,
|
| 1034 |
+
"_xwa": 967,
|
| 1035 |
+
"vɪŋ": 968,
|
| 1036 |
+
"man": 969,
|
| 1037 |
+
"fəɹ": 970,
|
| 1038 |
+
"_oʊn": 971,
|
| 1039 |
+
"ˈɹ": 972,
|
| 1040 |
+
"_kɹ": 973,
|
| 1041 |
+
"te_o": 974,
|
| 1042 |
+
"ɪli": 975,
|
| 1043 |
+
"_ʃɥɛ": 976,
|
| 1044 |
+
"_fəŋ": 977,
|
| 1045 |
+
"æɫ": 978,
|
| 1046 |
+
"ɑp": 979,
|
| 1047 |
+
"_ˈɛv": 980,
|
| 1048 |
+
"eɪndʒ": 981,
|
| 1049 |
+
"iɫ": 982,
|
| 1050 |
+
"wət": 983,
|
| 1051 |
+
"ɛðəɹ": 984,
|
| 1052 |
+
"_fən": 985,
|
| 1053 |
+
"ɾee": 986,
|
| 1054 |
+
"_hi_hæd": 987,
|
| 1055 |
+
"_maɪt": 988,
|
| 1056 |
+
"_ge": 989,
|
| 1057 |
+
"ækt": 990,
|
| 1058 |
+
"ɪts": 991,
|
| 1059 |
+
"_hɪm": 992,
|
| 1060 |
+
"_ze": 993,
|
| 1061 |
+
"ii": 994,
|
| 1062 |
+
"_N": 995,
|
| 1063 |
+
"_əv_hɪz": 996,
|
| 1064 |
+
"_gɹ": 997,
|
| 1065 |
+
"ænt": 998,
|
| 1066 |
+
"ɪˌ": 999,
|
| 1067 |
+
"_hɪmˈsɛɫf": 1000,
|
| 1068 |
+
"wa_na": 1001,
|
| 1069 |
+
"aɪəɹ": 1002,
|
| 1070 |
+
"dʑanai": 1003,
|
| 1071 |
+
"kana": 1004,
|
| 1072 |
+
"aɪz": 1005,
|
| 1073 |
+
"_ɪt_ɪz": 1006,
|
| 1074 |
+
"mase": 1007,
|
| 1075 |
+
"wɪn": 1008,
|
| 1076 |
+
"əθɪŋ": 1009,
|
| 1077 |
+
"_pɹəˈ": 1010,
|
| 1078 |
+
"kɯn": 1011,
|
| 1079 |
+
"ˈju": 1012,
|
| 1080 |
+
"_fɔɹ": 1013,
|
| 1081 |
+
"pʰi": 1014,
|
| 1082 |
+
"pʰiŋ": 1015,
|
| 1083 |
+
"o_i": 1016,
|
| 1084 |
+
"vz": 1017,
|
| 1085 |
+
"ɔɪn": 1018,
|
| 1086 |
+
"tʰiŋ": 1019,
|
| 1087 |
+
"_ne": 1020,
|
| 1088 |
+
"gəɹ": 1021,
|
| 1089 |
+
"æts": 1022,
|
| 1090 |
+
"_ˈɹi": 1023
|
| 1091 |
+
},
|
| 1092 |
+
"merges": [
|
| 1093 |
+
"_ t",
|
| 1094 |
+
"↓ ↑",
|
| 1095 |
+
"_ ˈ",
|
| 1096 |
+
"ə n",
|
| 1097 |
+
"_ s",
|
| 1098 |
+
"a ɪ",
|
| 1099 |
+
"ə ɹ",
|
| 1100 |
+
"e ɪ",
|
| 1101 |
+
"o ʊ",
|
| 1102 |
+
"_ k",
|
| 1103 |
+
"ʃ i",
|
| 1104 |
+
"_ w",
|
| 1105 |
+
"_ ð",
|
| 1106 |
+
"t s",
|
| 1107 |
+
"t ʃ",
|
| 1108 |
+
"_t s",
|
| 1109 |
+
"_ h",
|
| 1110 |
+
"_ ə",
|
| 1111 |
+
"_ m",
|
| 1112 |
+
"a n",
|
| 1113 |
+
"_ n",
|
| 1114 |
+
"_ð ə",
|
| 1115 |
+
"ɛ n",
|
| 1116 |
+
"ɑ ʊ",
|
| 1117 |
+
"ɑ ŋ",
|
| 1118 |
+
"` ⁼",
|
| 1119 |
+
"_ p",
|
| 1120 |
+
"_ i",
|
| 1121 |
+
"_ ɪ",
|
| 1122 |
+
"_t ʃ",
|
| 1123 |
+
"_ l",
|
| 1124 |
+
"j ɛn",
|
| 1125 |
+
"_ d",
|
| 1126 |
+
"_ f",
|
| 1127 |
+
"_ j",
|
| 1128 |
+
"w o",
|
| 1129 |
+
"_ b",
|
| 1130 |
+
"t a",
|
| 1131 |
+
"` ↓",
|
| 1132 |
+
"t e",
|
| 1133 |
+
"ən d",
|
| 1134 |
+
"_ ʃi",
|
| 1135 |
+
"w a",
|
| 1136 |
+
"k a",
|
| 1137 |
+
"ɪ ŋ",
|
| 1138 |
+
"i n",
|
| 1139 |
+
"s t",
|
| 1140 |
+
"l i",
|
| 1141 |
+
"ʊ ŋ",
|
| 1142 |
+
"_t ɪ",
|
| 1143 |
+
"t o",
|
| 1144 |
+
"w eɪ",
|
| 1145 |
+
"_ ənd",
|
| 1146 |
+
"ʰ i",
|
| 1147 |
+
"_ə v",
|
| 1148 |
+
"ə ŋ",
|
| 1149 |
+
"n o",
|
| 1150 |
+
"_ x",
|
| 1151 |
+
"ɾ ɯ",
|
| 1152 |
+
"n a",
|
| 1153 |
+
"_ a",
|
| 1154 |
+
"_ ɹ",
|
| 1155 |
+
"ɪ n",
|
| 1156 |
+
"g a",
|
| 1157 |
+
"d e",
|
| 1158 |
+
"j oʊ",
|
| 1159 |
+
"æ n",
|
| 1160 |
+
"k ɯ",
|
| 1161 |
+
"ɾ e",
|
| 1162 |
+
"m a",
|
| 1163 |
+
"_ðə _ˈ",
|
| 1164 |
+
"ɾ a",
|
| 1165 |
+
"ɛ ɹ",
|
| 1166 |
+
"m o",
|
| 1167 |
+
"ɔ ɹ",
|
| 1168 |
+
"ə ɫ",
|
| 1169 |
+
"_ g",
|
| 1170 |
+
"d a",
|
| 1171 |
+
"* ↑",
|
| 1172 |
+
"ɪ ˈ",
|
| 1173 |
+
"_ o",
|
| 1174 |
+
"_ ʃ",
|
| 1175 |
+
"i ŋ",
|
| 1176 |
+
"j a",
|
| 1177 |
+
"ə m",
|
| 1178 |
+
"_ ˌ",
|
| 1179 |
+
"a ʊ",
|
| 1180 |
+
"_ə ˈ",
|
| 1181 |
+
"` ↑",
|
| 1182 |
+
"ə t",
|
| 1183 |
+
"_ aɪ",
|
| 1184 |
+
"o o",
|
| 1185 |
+
"s ɯ",
|
| 1186 |
+
"↓ .",
|
| 1187 |
+
"_ɪ n",
|
| 1188 |
+
"_h i",
|
| 1189 |
+
"_w ɪ",
|
| 1190 |
+
"ɪ z",
|
| 1191 |
+
"_n a",
|
| 1192 |
+
"w an",
|
| 1193 |
+
"_k o",
|
| 1194 |
+
"_w o",
|
| 1195 |
+
"ɪ d",
|
| 1196 |
+
"ɾ i",
|
| 1197 |
+
"_j u",
|
| 1198 |
+
"m ə",
|
| 1199 |
+
"_l ə",
|
| 1200 |
+
"_h æ",
|
| 1201 |
+
"_ðə t",
|
| 1202 |
+
"ɑ ɹ",
|
| 1203 |
+
"t ʰ",
|
| 1204 |
+
"k i",
|
| 1205 |
+
"… …",
|
| 1206 |
+
"ɑ z",
|
| 1207 |
+
"_ ɔ",
|
| 1208 |
+
"_m i",
|
| 1209 |
+
"_w ɑz",
|
| 1210 |
+
"_ˈ s",
|
| 1211 |
+
"↓ ,",
|
| 1212 |
+
"_t ʰ",
|
| 1213 |
+
"ə ˈ",
|
| 1214 |
+
"d ʑ",
|
| 1215 |
+
"ɪ t",
|
| 1216 |
+
"_k ʰ",
|
| 1217 |
+
"i ɛ",
|
| 1218 |
+
"_m a",
|
| 1219 |
+
"ɪ s",
|
| 1220 |
+
"ts ɯ",
|
| 1221 |
+
"_n i",
|
| 1222 |
+
"_ɪ t",
|
| 1223 |
+
"k e",
|
| 1224 |
+
"i ɑʊ",
|
| 1225 |
+
"_k a",
|
| 1226 |
+
"_ əɹ",
|
| 1227 |
+
"n d",
|
| 1228 |
+
"_ˈ p",
|
| 1229 |
+
"k o",
|
| 1230 |
+
"j o",
|
| 1231 |
+
"ɹ i",
|
| 1232 |
+
"m ən",
|
| 1233 |
+
"ʊ d",
|
| 1234 |
+
"_ˈ m",
|
| 1235 |
+
"_f əɹ",
|
| 1236 |
+
"tʃ ʰi",
|
| 1237 |
+
"s a",
|
| 1238 |
+
"ʰ ɥ",
|
| 1239 |
+
"k ʰ",
|
| 1240 |
+
"ˈ s",
|
| 1241 |
+
"ɑ t",
|
| 1242 |
+
"ɛ d",
|
| 1243 |
+
"s e",
|
| 1244 |
+
"t ʃi",
|
| 1245 |
+
"ɛ ɫ",
|
| 1246 |
+
"_ˈ k",
|
| 1247 |
+
"_j oʊ",
|
| 1248 |
+
"t əɹ",
|
| 1249 |
+
"ɛ z",
|
| 1250 |
+
"- -",
|
| 1251 |
+
"v əɹ",
|
| 1252 |
+
"` →",
|
| 1253 |
+
"ʃ ən",
|
| 1254 |
+
"_ɪ z",
|
| 1255 |
+
"_m eɪ",
|
| 1256 |
+
"_ æ",
|
| 1257 |
+
"d ʒ",
|
| 1258 |
+
"_k i",
|
| 1259 |
+
"_h ɪz",
|
| 1260 |
+
"_b i",
|
| 1261 |
+
"u ɑŋ",
|
| 1262 |
+
"_ˈ f",
|
| 1263 |
+
"↓↑ .",
|
| 1264 |
+
"_wɪ θ",
|
| 1265 |
+
"j u",
|
| 1266 |
+
"i ɑŋ",
|
| 1267 |
+
"→ .",
|
| 1268 |
+
"_s o",
|
| 1269 |
+
"_h əɹ",
|
| 1270 |
+
"↑ .",
|
| 1271 |
+
"n i",
|
| 1272 |
+
"_m o",
|
| 1273 |
+
"_m aɪ",
|
| 1274 |
+
"l aɪ",
|
| 1275 |
+
"ɥ ɛ",
|
| 1276 |
+
"_t a",
|
| 1277 |
+
"ən t",
|
| 1278 |
+
"_tʃ ʰi",
|
| 1279 |
+
"_s ɯ",
|
| 1280 |
+
"_ θ",
|
| 1281 |
+
"_ ɛz",
|
| 1282 |
+
"w ən",
|
| 1283 |
+
"m e",
|
| 1284 |
+
"m i",
|
| 1285 |
+
"_hæ d",
|
| 1286 |
+
"_h a",
|
| 1287 |
+
"ə s",
|
| 1288 |
+
"_ˈ l",
|
| 1289 |
+
"_s t",
|
| 1290 |
+
"ð əɹ",
|
| 1291 |
+
"oʊ n",
|
| 1292 |
+
"_w a",
|
| 1293 |
+
"ʰ əŋ",
|
| 1294 |
+
"_n ɑt",
|
| 1295 |
+
"* .",
|
| 1296 |
+
"k t",
|
| 1297 |
+
"_ˈ h",
|
| 1298 |
+
"d o",
|
| 1299 |
+
"ɥ æn",
|
| 1300 |
+
"n e",
|
| 1301 |
+
"_t o",
|
| 1302 |
+
"_w ən",
|
| 1303 |
+
"_n o",
|
| 1304 |
+
"_l aɪ",
|
| 1305 |
+
"_w əɹ",
|
| 1306 |
+
"↑ ,",
|
| 1307 |
+
"→ ,",
|
| 1308 |
+
"ɛ s",
|
| 1309 |
+
"↓↑ ,",
|
| 1310 |
+
"_ɔ n",
|
| 1311 |
+
"ʰ u",
|
| 1312 |
+
"s o",
|
| 1313 |
+
"_ˈ b",
|
| 1314 |
+
"ɫ d",
|
| 1315 |
+
"ɪ k",
|
| 1316 |
+
"ɪ st",
|
| 1317 |
+
"_f ɹ",
|
| 1318 |
+
"_ð ɛɹ",
|
| 1319 |
+
"_w eɪ",
|
| 1320 |
+
"ka ɾa",
|
| 1321 |
+
"_ˈ d",
|
| 1322 |
+
"_hæ v",
|
| 1323 |
+
"ts ʰ",
|
| 1324 |
+
"w aɪ",
|
| 1325 |
+
"ɾ o",
|
| 1326 |
+
"ɛ m",
|
| 1327 |
+
"_æ t",
|
| 1328 |
+
"ʊ ɹ",
|
| 1329 |
+
"_ˈ w",
|
| 1330 |
+
"b a",
|
| 1331 |
+
"_n oʊ",
|
| 1332 |
+
"ʰ jɛn",
|
| 1333 |
+
"ɹ eɪ",
|
| 1334 |
+
"_j o",
|
| 1335 |
+
"ɸ ɯ",
|
| 1336 |
+
"_s a",
|
| 1337 |
+
"_ɹ ɪˈ",
|
| 1338 |
+
"_ˈ n",
|
| 1339 |
+
"a i",
|
| 1340 |
+
"_b ət",
|
| 1341 |
+
"ɪ ɹ",
|
| 1342 |
+
"tʃ ʰɥ",
|
| 1343 |
+
"_d ʑ",
|
| 1344 |
+
"ə ˌ",
|
| 1345 |
+
"_ð ɪs",
|
| 1346 |
+
". .",
|
| 1347 |
+
"x wa",
|
| 1348 |
+
"_ɪ m",
|
| 1349 |
+
"_d ɪˈ",
|
| 1350 |
+
"_k ən",
|
| 1351 |
+
"dʑ i",
|
| 1352 |
+
"* ,",
|
| 1353 |
+
"ɑ n",
|
| 1354 |
+
"_ʃi ɑŋ",
|
| 1355 |
+
"_k ɯ",
|
| 1356 |
+
"ʃi n",
|
| 1357 |
+
"_s oʊ",
|
| 1358 |
+
"b i",
|
| 1359 |
+
"tʰ jɛn",
|
| 1360 |
+
"te _i",
|
| 1361 |
+
"_ts ʰ",
|
| 1362 |
+
"_ ɯ",
|
| 1363 |
+
"aɪ t",
|
| 1364 |
+
"ʰi ŋ",
|
| 1365 |
+
"ð ə",
|
| 1366 |
+
"_ɔ ɫ",
|
| 1367 |
+
"_ˈ ɹ",
|
| 1368 |
+
"na i",
|
| 1369 |
+
"əɹ d",
|
| 1370 |
+
"_ˈ t",
|
| 1371 |
+
"_ ən",
|
| 1372 |
+
"_tʃ ʰɥ",
|
| 1373 |
+
"_i ɛ",
|
| 1374 |
+
"l eɪ",
|
| 1375 |
+
"ɛɹ i",
|
| 1376 |
+
"ˈ t",
|
| 1377 |
+
"h a",
|
| 1378 |
+
"ʃi ŋ",
|
| 1379 |
+
"ɛ vəɹ",
|
| 1380 |
+
"z ɯ",
|
| 1381 |
+
"_w i",
|
| 1382 |
+
"_j a",
|
| 1383 |
+
"ɛ k",
|
| 1384 |
+
"ʰ ɑŋ",
|
| 1385 |
+
"_ts ɯ",
|
| 1386 |
+
"_əv _ðə",
|
| 1387 |
+
"ta ʃi",
|
| 1388 |
+
"_s ɛd",
|
| 1389 |
+
"_x ə",
|
| 1390 |
+
"_l i",
|
| 1391 |
+
"_s i",
|
| 1392 |
+
"de sɯ",
|
| 1393 |
+
"_ˌ ɪn",
|
| 1394 |
+
"ʃ jɛn",
|
| 1395 |
+
"_b aɪ",
|
| 1396 |
+
"o n",
|
| 1397 |
+
"_x ɑʊ",
|
| 1398 |
+
"_ð eɪ",
|
| 1399 |
+
"_x aɪ",
|
| 1400 |
+
"` ↓↑",
|
| 1401 |
+
"x weɪ",
|
| 1402 |
+
"h i",
|
| 1403 |
+
"_s e",
|
| 1404 |
+
"ə _s",
|
| 1405 |
+
"_fɹ əm",
|
| 1406 |
+
"ʊ t",
|
| 1407 |
+
"d i",
|
| 1408 |
+
"aʊ t",
|
| 1409 |
+
"ə b",
|
| 1410 |
+
"s ɹ",
|
| 1411 |
+
"ə z",
|
| 1412 |
+
"_x weɪ",
|
| 1413 |
+
"_kʰ ə",
|
| 1414 |
+
"ɹ u",
|
| 1415 |
+
"_ u",
|
| 1416 |
+
"_d e",
|
| 1417 |
+
"aɪ d",
|
| 1418 |
+
"ɪ v",
|
| 1419 |
+
"b ɯ",
|
| 1420 |
+
"_h o",
|
| 1421 |
+
"əɹ z",
|
| 1422 |
+
"j oo",
|
| 1423 |
+
"_b ɪˈ",
|
| 1424 |
+
"_tʰ a",
|
| 1425 |
+
"ɛ t",
|
| 1426 |
+
"e n",
|
| 1427 |
+
"ɛn i",
|
| 1428 |
+
"ə st",
|
| 1429 |
+
"æ k",
|
| 1430 |
+
"ə _ts",
|
| 1431 |
+
"_ˈ ɪn",
|
| 1432 |
+
"t i",
|
| 1433 |
+
"ɥ n",
|
| 1434 |
+
"_d ʒ",
|
| 1435 |
+
"x ɑʊ",
|
| 1436 |
+
"_ˈ v",
|
| 1437 |
+
"ʃi ɑŋ",
|
| 1438 |
+
"p ʰ",
|
| 1439 |
+
"_wɪ tʃ",
|
| 1440 |
+
"eɪ m",
|
| 1441 |
+
"oʊ z",
|
| 1442 |
+
"ə ðəɹ",
|
| 1443 |
+
"f ɑŋ",
|
| 1444 |
+
"_ˈ g",
|
| 1445 |
+
"_d o",
|
| 1446 |
+
"_ʃi ɑʊ",
|
| 1447 |
+
"_ˈ æ",
|
| 1448 |
+
"_j ʊɹ",
|
| 1449 |
+
"_ð ɛm",
|
| 1450 |
+
"ɪ m",
|
| 1451 |
+
"ɛ st",
|
| 1452 |
+
"æn d",
|
| 1453 |
+
"_d u",
|
| 1454 |
+
"ɯ ɯ",
|
| 1455 |
+
"k an",
|
| 1456 |
+
"_d a",
|
| 1457 |
+
"in o",
|
| 1458 |
+
"_ e",
|
| 1459 |
+
"_w ʊd",
|
| 1460 |
+
"ɛn d",
|
| 1461 |
+
"m eɪ",
|
| 1462 |
+
"θ ɪŋ",
|
| 1463 |
+
"_ʃ jɛn",
|
| 1464 |
+
"i z",
|
| 1465 |
+
"aɪ m",
|
| 1466 |
+
"_h u",
|
| 1467 |
+
"_əˈ b",
|
| 1468 |
+
"ən s",
|
| 1469 |
+
"_wɪ ɫ",
|
| 1470 |
+
"t ʰi",
|
| 1471 |
+
"g o",
|
| 1472 |
+
"ɛn t",
|
| 1473 |
+
"f u",
|
| 1474 |
+
"æ p",
|
| 1475 |
+
"x oʊ",
|
| 1476 |
+
"eɪ k",
|
| 1477 |
+
"ʊ k",
|
| 1478 |
+
"əɹ ˈ",
|
| 1479 |
+
"_θ ɪŋ",
|
| 1480 |
+
"ə l",
|
| 1481 |
+
"p ɹ",
|
| 1482 |
+
"ə tʃ",
|
| 1483 |
+
"n t",
|
| 1484 |
+
"_ ɸɯ",
|
| 1485 |
+
"l u",
|
| 1486 |
+
"_ˈ ɔ",
|
| 1487 |
+
"_i ɑʊ",
|
| 1488 |
+
"l ə",
|
| 1489 |
+
"t u",
|
| 1490 |
+
"_dʑ i",
|
| 1491 |
+
"eɪ t",
|
| 1492 |
+
"_ʃi n",
|
| 1493 |
+
"n na",
|
| 1494 |
+
"_ˈp ɹ",
|
| 1495 |
+
"f ən",
|
| 1496 |
+
"_ə p",
|
| 1497 |
+
"n jɛn",
|
| 1498 |
+
"_a ʊt",
|
| 1499 |
+
"f ɔɹ",
|
| 1500 |
+
"_t u",
|
| 1501 |
+
"eɪ ʃən",
|
| 1502 |
+
"ɪ ɫ",
|
| 1503 |
+
"_w ət",
|
| 1504 |
+
"_ɪ f",
|
| 1505 |
+
"_ ɥ",
|
| 1506 |
+
"_f a",
|
| 1507 |
+
"ˈ w",
|
| 1508 |
+
"tʃ ʰjɛn",
|
| 1509 |
+
"_w ɪn",
|
| 1510 |
+
"oʊ ɫd",
|
| 1511 |
+
"_əˈ p",
|
| 1512 |
+
"aʊ nd",
|
| 1513 |
+
"s an",
|
| 1514 |
+
"h e",
|
| 1515 |
+
"_b ɪn",
|
| 1516 |
+
"f a",
|
| 1517 |
+
"ɪ f",
|
| 1518 |
+
"ɔ ŋ",
|
| 1519 |
+
"g e",
|
| 1520 |
+
"_ɪn _ðə",
|
| 1521 |
+
"m iŋ",
|
| 1522 |
+
"_p ɹ",
|
| 1523 |
+
"in a",
|
| 1524 |
+
"an o",
|
| 1525 |
+
"əb əɫ",
|
| 1526 |
+
"k ˈs",
|
| 1527 |
+
"_ˈ ɛni",
|
| 1528 |
+
"n əŋ",
|
| 1529 |
+
"ə d",
|
| 1530 |
+
"_əv _ðə_ˈ",
|
| 1531 |
+
"_w aɪ",
|
| 1532 |
+
"_t aɪm",
|
| 1533 |
+
"ˈs ɛɫ",
|
| 1534 |
+
"ʃi ɛ",
|
| 1535 |
+
"_k əm",
|
| 1536 |
+
"æ st",
|
| 1537 |
+
"_g oʊ",
|
| 1538 |
+
"m ɯ",
|
| 1539 |
+
"ˈ p",
|
| 1540 |
+
"_ˈ st",
|
| 1541 |
+
"ə _t",
|
| 1542 |
+
"p t",
|
| 1543 |
+
"_p ʰ",
|
| 1544 |
+
"ʰ ɹ",
|
| 1545 |
+
"ʃ ja",
|
| 1546 |
+
"i wa",
|
| 1547 |
+
"ɪ l",
|
| 1548 |
+
"b ət",
|
| 1549 |
+
"_f ɑŋ",
|
| 1550 |
+
"h o",
|
| 1551 |
+
"i v",
|
| 1552 |
+
"l oʊ",
|
| 1553 |
+
"b e",
|
| 1554 |
+
"_laɪ k",
|
| 1555 |
+
"ɪ ʃ",
|
| 1556 |
+
"_f u",
|
| 1557 |
+
"z e",
|
| 1558 |
+
"ə _tʃ",
|
| 1559 |
+
"ɑɹ t",
|
| 1560 |
+
"ɔɹ d",
|
| 1561 |
+
"tʃʰi ŋ",
|
| 1562 |
+
"m p",
|
| 1563 |
+
"_ðə _s",
|
| 1564 |
+
"_əˈb aʊt",
|
| 1565 |
+
"_ˈ oʊ",
|
| 1566 |
+
"kʰ ə",
|
| 1567 |
+
"d _tɪ",
|
| 1568 |
+
"ŋ ga",
|
| 1569 |
+
"ə li",
|
| 1570 |
+
"_kʰ an",
|
| 1571 |
+
"ç i",
|
| 1572 |
+
"_ˈ ju",
|
| 1573 |
+
"_k ʊd",
|
| 1574 |
+
"ɔ ɫ",
|
| 1575 |
+
"ɔ t",
|
| 1576 |
+
"_ɪ ts",
|
| 1577 |
+
"_s an",
|
| 1578 |
+
"tʃ a",
|
| 1579 |
+
"i _na",
|
| 1580 |
+
"x ə",
|
| 1581 |
+
"ɛ kt",
|
| 1582 |
+
"_m ɔɹ",
|
| 1583 |
+
"te _kɯ",
|
| 1584 |
+
"ɪd ʒ",
|
| 1585 |
+
"j ʊŋ",
|
| 1586 |
+
"_w an",
|
| 1587 |
+
"æ t",
|
| 1588 |
+
"ka t",
|
| 1589 |
+
"ˈsɛɫ f",
|
| 1590 |
+
"_k e",
|
| 1591 |
+
"aɪ nd",
|
| 1592 |
+
"i t",
|
| 1593 |
+
"_ ɑɹ",
|
| 1594 |
+
"s p",
|
| 1595 |
+
"oʊn t",
|
| 1596 |
+
"_t ʃi",
|
| 1597 |
+
"tsʰ ɹ",
|
| 1598 |
+
"_x ən",
|
| 1599 |
+
"_əˈ g",
|
| 1600 |
+
"ə _k",
|
| 1601 |
+
"to _i",
|
| 1602 |
+
"_t ʰi",
|
| 1603 |
+
"_i ŋ",
|
| 1604 |
+
"aʊ n",
|
| 1605 |
+
"g ɯ",
|
| 1606 |
+
"_ɪ kˈs",
|
| 1607 |
+
"ɛ v",
|
| 1608 |
+
"g i",
|
| 1609 |
+
"k s",
|
| 1610 |
+
"_s əm",
|
| 1611 |
+
"an a",
|
| 1612 |
+
"ɪt əɫ",
|
| 1613 |
+
"n an",
|
| 1614 |
+
"_ˈɪn tu",
|
| 1615 |
+
"_hi ɹ",
|
| 1616 |
+
"_t e",
|
| 1617 |
+
"_n aʊ",
|
| 1618 |
+
"ʃi ɑʊ",
|
| 1619 |
+
"ʃ o",
|
| 1620 |
+
"ɹ e",
|
| 1621 |
+
"x aɪ",
|
| 1622 |
+
"_tʃʰi ŋ",
|
| 1623 |
+
"_s ɹ",
|
| 1624 |
+
"_h aʊ",
|
| 1625 |
+
"? .",
|
| 1626 |
+
"_f eɪ",
|
| 1627 |
+
"li ŋ",
|
| 1628 |
+
"_ʃ ja",
|
| 1629 |
+
"_ˈ dʒ",
|
| 1630 |
+
"_s eɪ",
|
| 1631 |
+
"ˈ n",
|
| 1632 |
+
"s oʊ",
|
| 1633 |
+
"tʰ ʊŋ",
|
| 1634 |
+
"_l joʊ",
|
| 1635 |
+
"m aɪ",
|
| 1636 |
+
"_b ɹ",
|
| 1637 |
+
"ɹeɪ t",
|
| 1638 |
+
"_n əŋ",
|
| 1639 |
+
"ʰ ə",
|
| 1640 |
+
"æn s",
|
| 1641 |
+
"_ˈɔ l",
|
| 1642 |
+
"ta tʃi",
|
| 1643 |
+
"n to",
|
| 1644 |
+
"_ˌɪn ˈ",
|
| 1645 |
+
"l e",
|
| 1646 |
+
"n de",
|
| 1647 |
+
"_ˈv ɛɹi",
|
| 1648 |
+
"mən t",
|
| 1649 |
+
"ɾi ma",
|
| 1650 |
+
"_ð ɛn",
|
| 1651 |
+
"_h əz",
|
| 1652 |
+
"_ɹ i",
|
| 1653 |
+
"f təɹ",
|
| 1654 |
+
"_s p",
|
| 1655 |
+
"ɾe wa",
|
| 1656 |
+
"ga _a",
|
| 1657 |
+
"z _əv",
|
| 1658 |
+
"_m iŋ",
|
| 1659 |
+
"_tɪ _ðə",
|
| 1660 |
+
"ɹ aɪ",
|
| 1661 |
+
"ɛ l",
|
| 1662 |
+
"ɹ æ",
|
| 1663 |
+
"_h oʊ",
|
| 1664 |
+
"x u",
|
| 1665 |
+
"oʊn li",
|
| 1666 |
+
"ŋ k",
|
| 1667 |
+
"i _i",
|
| 1668 |
+
"_d ɪd",
|
| 1669 |
+
"_dʒ ɪst",
|
| 1670 |
+
"in g",
|
| 1671 |
+
"ka i",
|
| 1672 |
+
"_m æn",
|
| 1673 |
+
"_i n",
|
| 1674 |
+
"z o",
|
| 1675 |
+
"ə f",
|
| 1676 |
+
"da ke",
|
| 1677 |
+
"_ˈs əm",
|
| 1678 |
+
"ɾɯ _no",
|
| 1679 |
+
"_g o",
|
| 1680 |
+
"tʃ əɹ",
|
| 1681 |
+
"i te",
|
| 1682 |
+
"`↓ .",
|
| 1683 |
+
"_kʰ aɪ",
|
| 1684 |
+
"s k",
|
| 1685 |
+
"ɔɹ s",
|
| 1686 |
+
"_t ʰiŋ",
|
| 1687 |
+
"_n ə",
|
| 1688 |
+
"p əɫ",
|
| 1689 |
+
"_tɪ _bi",
|
| 1690 |
+
"ˈ fɔɹ",
|
| 1691 |
+
"m u",
|
| 1692 |
+
"s u",
|
| 1693 |
+
"a a",
|
| 1694 |
+
"ɪst əɹ",
|
| 1695 |
+
"ʰ an",
|
| 1696 |
+
"p əɹ",
|
| 1697 |
+
"ə _p",
|
| 1698 |
+
"li ɑŋ",
|
| 1699 |
+
"_ v",
|
| 1700 |
+
"oʊ st",
|
| 1701 |
+
"_əˈg ɛn",
|
| 1702 |
+
"ən z",
|
| 1703 |
+
"N o",
|
| 1704 |
+
"ɔɹ t",
|
| 1705 |
+
"_s əˈ",
|
| 1706 |
+
"_m ɯ",
|
| 1707 |
+
"tʃ ʰ",
|
| 1708 |
+
"_ˈl ɪtəɫ",
|
| 1709 |
+
"_x wo",
|
| 1710 |
+
"_ˌ bi",
|
| 1711 |
+
"_ˈoʊ vəɹ",
|
| 1712 |
+
"_ çi",
|
| 1713 |
+
"_d eɪ",
|
| 1714 |
+
"aɪ n",
|
| 1715 |
+
"_ʃi ŋ",
|
| 1716 |
+
"i _ʃi",
|
| 1717 |
+
"_tsʰ aɪ",
|
| 1718 |
+
"ʃ oo",
|
| 1719 |
+
"ɾ oo",
|
| 1720 |
+
"b əɹ",
|
| 1721 |
+
"ʰ a",
|
| 1722 |
+
"ˈ ɛs",
|
| 1723 |
+
"_ɪn _ðə_ˈ",
|
| 1724 |
+
"N wa",
|
| 1725 |
+
"_ð ən",
|
| 1726 |
+
"s aɪ",
|
| 1727 |
+
"_ˈju ˈɛs",
|
| 1728 |
+
"n da",
|
| 1729 |
+
"_p leɪ",
|
| 1730 |
+
"ɪŋ _tɪ",
|
| 1731 |
+
"ɪt i",
|
| 1732 |
+
"_m e",
|
| 1733 |
+
"_ʃ ʊd",
|
| 1734 |
+
"_n u",
|
| 1735 |
+
"_ðə _k",
|
| 1736 |
+
"z a",
|
| 1737 |
+
"_ˈ ɛvəɹ",
|
| 1738 |
+
"əɹ n",
|
| 1739 |
+
"æ d",
|
| 1740 |
+
"ˈ m",
|
| 1741 |
+
"_d oʊnt",
|
| 1742 |
+
"_m əst",
|
| 1743 |
+
"j ɯɯ",
|
| 1744 |
+
"ɑɹ d",
|
| 1745 |
+
"_ jɛn",
|
| 1746 |
+
"ʃ ɥ",
|
| 1747 |
+
"_ˈ oʊnli",
|
| 1748 |
+
"_ʃ o",
|
| 1749 |
+
"_l iŋ",
|
| 1750 |
+
"s s",
|
| 1751 |
+
"ɑ l",
|
| 1752 |
+
"de a",
|
| 1753 |
+
"ɾe ta",
|
| 1754 |
+
"m jɛn",
|
| 1755 |
+
"_g ʊd",
|
| 1756 |
+
"_w ɔ",
|
| 1757 |
+
"i mo",
|
| 1758 |
+
"no _ko",
|
| 1759 |
+
"_ ɥæn",
|
| 1760 |
+
"nd ʒ",
|
| 1761 |
+
"ɪ ʃən",
|
| 1762 |
+
"o _ʃi",
|
| 1763 |
+
"_θɪŋ k",
|
| 1764 |
+
"_n an",
|
| 1765 |
+
"to _o",
|
| 1766 |
+
"_tʰ ʊŋ",
|
| 1767 |
+
"l joʊ",
|
| 1768 |
+
"ta i",
|
| 1769 |
+
"mə _s",
|
| 1770 |
+
"_j ɯ",
|
| 1771 |
+
"_ uɑŋ",
|
| 1772 |
+
"_ˌbi ˈfɔɹ",
|
| 1773 |
+
"æ s",
|
| 1774 |
+
"_tʃ ʰjɛn",
|
| 1775 |
+
"i k",
|
| 1776 |
+
"_b æk",
|
| 1777 |
+
"_ˈ iv",
|
| 1778 |
+
"eɪ n",
|
| 1779 |
+
"u n",
|
| 1780 |
+
"l a",
|
| 1781 |
+
"ˈ k",
|
| 1782 |
+
"_d aʊn",
|
| 1783 |
+
"an ai",
|
| 1784 |
+
"_l ɛ",
|
| 1785 |
+
"əɹ t",
|
| 1786 |
+
"ð ɛɹ",
|
| 1787 |
+
"_ˈæ ftəɹ",
|
| 1788 |
+
"da t",
|
| 1789 |
+
"f an",
|
| 1790 |
+
"b əɫ",
|
| 1791 |
+
"te mo",
|
| 1792 |
+
"tʰ a",
|
| 1793 |
+
"ɾɯ _ko",
|
| 1794 |
+
"ˈ v",
|
| 1795 |
+
"f eɪ",
|
| 1796 |
+
"_m ətʃ",
|
| 1797 |
+
"x wo",
|
| 1798 |
+
"ɹ oʊ",
|
| 1799 |
+
"_b a",
|
| 1800 |
+
"_ˈn ɛvəɹ",
|
| 1801 |
+
"_meɪ d",
|
| 1802 |
+
"_j ʊŋ",
|
| 1803 |
+
"_əˈp ɑn",
|
| 1804 |
+
"! ?",
|
| 1805 |
+
"_ˈ ʃ",
|
| 1806 |
+
"_ðə_ˈ k",
|
| 1807 |
+
"f t",
|
| 1808 |
+
"_b o",
|
| 1809 |
+
"_ɪn _ə",
|
| 1810 |
+
"tʃʰɥ æn",
|
| 1811 |
+
"ˈ z",
|
| 1812 |
+
"`↓ ,",
|
| 1813 |
+
"_bɪˈ k",
|
| 1814 |
+
"ɪ g",
|
| 1815 |
+
"k in",
|
| 1816 |
+
"_k l",
|
| 1817 |
+
"ɾɯ _n",
|
| 1818 |
+
"_l ɑʊ",
|
| 1819 |
+
"-- --",
|
| 1820 |
+
"i ka",
|
| 1821 |
+
"_ɹ aɪt",
|
| 1822 |
+
"z d",
|
| 1823 |
+
"z _ənd",
|
| 1824 |
+
"_k jo",
|
| 1825 |
+
"x wan",
|
| 1826 |
+
"to o",
|
| 1827 |
+
"_g ɪt",
|
| 1828 |
+
"_l iɑŋ",
|
| 1829 |
+
"ta _n",
|
| 1830 |
+
"_k eɪm",
|
| 1831 |
+
"_ˈ əðəɹ",
|
| 1832 |
+
"_w ɛɫ",
|
| 1833 |
+
"te ki",
|
| 1834 |
+
"se e",
|
| 1835 |
+
"j ɯ",
|
| 1836 |
+
"i _o",
|
| 1837 |
+
"to _ʃi",
|
| 1838 |
+
"f əɫ",
|
| 1839 |
+
"b o",
|
| 1840 |
+
"ˌ t",
|
| 1841 |
+
"ɪ p",
|
| 1842 |
+
"an e",
|
| 1843 |
+
"_tʰ jɛn",
|
| 1844 |
+
"_tʃ o",
|
| 1845 |
+
"ɾ jo",
|
| 1846 |
+
"ɪn s",
|
| 1847 |
+
"_h e",
|
| 1848 |
+
"ŋ ka",
|
| 1849 |
+
"ʃ ɥɛ",
|
| 1850 |
+
"dʑ a",
|
| 1851 |
+
"v d",
|
| 1852 |
+
"ʰ wan",
|
| 1853 |
+
"_g ɹeɪt",
|
| 1854 |
+
"_əv _ə",
|
| 1855 |
+
"ənd əɹ",
|
| 1856 |
+
"ke do",
|
| 1857 |
+
"_ðə _b",
|
| 1858 |
+
"ə k",
|
| 1859 |
+
"_t eɪk",
|
| 1860 |
+
"kʰ an",
|
| 1861 |
+
"_ˈɔl ˌ",
|
| 1862 |
+
"s wo",
|
| 1863 |
+
"_ɪt _wɑz",
|
| 1864 |
+
"_ʃ ɥ",
|
| 1865 |
+
"_si m",
|
| 1866 |
+
"_ˈf ɑ",
|
| 1867 |
+
"m in",
|
| 1868 |
+
"i _a",
|
| 1869 |
+
"s oo",
|
| 1870 |
+
"ɛn s",
|
| 1871 |
+
"_s ətʃ",
|
| 1872 |
+
"tʰ aɪ",
|
| 1873 |
+
"_ ga",
|
| 1874 |
+
"i _ka",
|
| 1875 |
+
"k oo",
|
| 1876 |
+
"_fəɹ st",
|
| 1877 |
+
"_ˈ tʃ",
|
| 1878 |
+
"n no",
|
| 1879 |
+
"ə _ɹ",
|
| 1880 |
+
"ta ɾa",
|
| 1881 |
+
"tʃʰ joʊ",
|
| 1882 |
+
"_æ m",
|
| 1883 |
+
"_m u",
|
| 1884 |
+
"_meɪ k",
|
| 1885 |
+
"↓ …",
|
| 1886 |
+
"ɪˈ θ",
|
| 1887 |
+
"ɑ b",
|
| 1888 |
+
"ɹ a",
|
| 1889 |
+
"_w ɛɹ",
|
| 1890 |
+
"_ðə_ˈ s",
|
| 1891 |
+
"_əˈ l",
|
| 1892 |
+
"_ oʊɫd",
|
| 1893 |
+
"æ l",
|
| 1894 |
+
"_ˈp i",
|
| 1895 |
+
"_l ɔŋ",
|
| 1896 |
+
"dʑ o",
|
| 1897 |
+
"_tʰ aɪ",
|
| 1898 |
+
"ɔɹ n",
|
| 1899 |
+
"əɫ z",
|
| 1900 |
+
"_t əˈ",
|
| 1901 |
+
"_əˈ weɪ",
|
| 1902 |
+
"p a",
|
| 1903 |
+
"_ð iz",
|
| 1904 |
+
"_ˈs p",
|
| 1905 |
+
"n n",
|
| 1906 |
+
"ma e",
|
| 1907 |
+
"to wa",
|
| 1908 |
+
"ta _no",
|
| 1909 |
+
"_ an",
|
| 1910 |
+
"kʰ aɪ",
|
| 1911 |
+
"ɾa ɾe",
|
| 1912 |
+
"eɪ s",
|
| 1913 |
+
"ɑ d",
|
| 1914 |
+
"_w ɪˈθ",
|
| 1915 |
+
"_ˈiv ɪn",
|
| 1916 |
+
"_l u",
|
| 1917 |
+
"ɔ ɪ",
|
| 1918 |
+
"l ɪŋ",
|
| 1919 |
+
"ət i",
|
| 1920 |
+
"_ðə _f",
|
| 1921 |
+
"o ʃi",
|
| 1922 |
+
"_l a",
|
| 1923 |
+
"s i",
|
| 1924 |
+
"t ɪd",
|
| 1925 |
+
"h aʊ",
|
| 1926 |
+
"pʰ in",
|
| 1927 |
+
"ˈ st",
|
| 1928 |
+
"_ˈp əɹ",
|
| 1929 |
+
"e ɹ",
|
| 1930 |
+
"* !",
|
| 1931 |
+
"_ˈm ɪstəɹ",
|
| 1932 |
+
"ʃ a",
|
| 1933 |
+
"_ˌ ɪm",
|
| 1934 |
+
"ˌ θɪŋ",
|
| 1935 |
+
"_n eɪ",
|
| 1936 |
+
"_n ɥ",
|
| 1937 |
+
"ɑ k",
|
| 1938 |
+
"_ɹ u",
|
| 1939 |
+
"_ʃ ɯ",
|
| 1940 |
+
"_ðə_ˈ m",
|
| 1941 |
+
"de mo",
|
| 1942 |
+
"_d ɹ",
|
| 1943 |
+
"dʑ oo",
|
| 1944 |
+
"_st ɪɫ",
|
| 1945 |
+
"_p ʰiŋ",
|
| 1946 |
+
"ə _i",
|
| 1947 |
+
"_ɪkˈs p",
|
| 1948 |
+
"_w ɛnt",
|
| 1949 |
+
"ɪ ɹi",
|
| 1950 |
+
"əˈ m",
|
| 1951 |
+
"o _ka",
|
| 1952 |
+
"_əˈ k",
|
| 1953 |
+
"ɔ k",
|
| 1954 |
+
"_ ɥɛ",
|
| 1955 |
+
"_l ʊk",
|
| 1956 |
+
"ˈ d",
|
| 1957 |
+
"ka ʃi",
|
| 1958 |
+
"_wɪθ _ə",
|
| 1959 |
+
"l jɛn",
|
| 1960 |
+
"ɔ n",
|
| 1961 |
+
"_l jɛn",
|
| 1962 |
+
"_h ɛɫ",
|
| 1963 |
+
"u ɹ",
|
| 1964 |
+
"_tʰ oʊ",
|
| 1965 |
+
"_tʃʰɥ æn",
|
| 1966 |
+
"_s k",
|
| 1967 |
+
"tsʰ aɪ",
|
| 1968 |
+
"ɛ təɹ",
|
| 1969 |
+
"_m in",
|
| 1970 |
+
"n oʊ",
|
| 1971 |
+
"ʃ ɯ",
|
| 1972 |
+
"_θ ɹu",
|
| 1973 |
+
"_θ ɔt",
|
| 1974 |
+
"da jo",
|
| 1975 |
+
"w i",
|
| 1976 |
+
"i _ko",
|
| 1977 |
+
"_t ɹ",
|
| 1978 |
+
"_f an",
|
| 1979 |
+
"ɹ ɛ",
|
| 1980 |
+
"sa N",
|
| 1981 |
+
"_hi _wɑz",
|
| 1982 |
+
"_ ɾe",
|
| 1983 |
+
"_ə m",
|
| 1984 |
+
"te _ki",
|
| 1985 |
+
"_x oʊ",
|
| 1986 |
+
"ˈ l",
|
| 1987 |
+
"ˈ g",
|
| 1988 |
+
"ga _i",
|
| 1989 |
+
"_ɔn _ðə",
|
| 1990 |
+
"_x wa",
|
| 1991 |
+
"v ɪŋ",
|
| 1992 |
+
"m an",
|
| 1993 |
+
"f əɹ",
|
| 1994 |
+
"_ oʊn",
|
| 1995 |
+
"ˈ ɹ",
|
| 1996 |
+
"_k ɹ",
|
| 1997 |
+
"te _o",
|
| 1998 |
+
"ɪ li",
|
| 1999 |
+
"_ʃ ɥɛ",
|
| 2000 |
+
"_f əŋ",
|
| 2001 |
+
"æ ɫ",
|
| 2002 |
+
"ɑ p",
|
| 2003 |
+
"_ˈ ɛv",
|
| 2004 |
+
"eɪ ndʒ",
|
| 2005 |
+
"i ɫ",
|
| 2006 |
+
"w ət",
|
| 2007 |
+
"ɛ ðəɹ",
|
| 2008 |
+
"_f ən",
|
| 2009 |
+
"ɾe e",
|
| 2010 |
+
"_hi _hæd",
|
| 2011 |
+
"_maɪ t",
|
| 2012 |
+
"_g e",
|
| 2013 |
+
"æ kt",
|
| 2014 |
+
"ɪ ts",
|
| 2015 |
+
"_h ɪm",
|
| 2016 |
+
"_ ze",
|
| 2017 |
+
"i i",
|
| 2018 |
+
"_ N",
|
| 2019 |
+
"_əv _hɪz",
|
| 2020 |
+
"_g ɹ",
|
| 2021 |
+
"æn t",
|
| 2022 |
+
"ɪ ˌ",
|
| 2023 |
+
"_hɪm ˈsɛɫf",
|
| 2024 |
+
"wa _na",
|
| 2025 |
+
"aɪ əɹ",
|
| 2026 |
+
"dʑ anai",
|
| 2027 |
+
"kan a",
|
| 2028 |
+
"aɪ z",
|
| 2029 |
+
"_ɪt _ɪz",
|
| 2030 |
+
"ma se",
|
| 2031 |
+
"w ɪn",
|
| 2032 |
+
"ə θɪŋ",
|
| 2033 |
+
"_pɹ əˈ",
|
| 2034 |
+
"kɯ n",
|
| 2035 |
+
"ˈ ju",
|
| 2036 |
+
"_f ɔɹ",
|
| 2037 |
+
"p ʰi",
|
| 2038 |
+
"p ʰiŋ",
|
| 2039 |
+
"o _i",
|
| 2040 |
+
"v z",
|
| 2041 |
+
"ɔ ɪn",
|
| 2042 |
+
"t ʰiŋ",
|
| 2043 |
+
"_n e",
|
| 2044 |
+
"g əɹ",
|
| 2045 |
+
"æ ts",
|
| 2046 |
+
"_ˈ ɹi"
|
| 2047 |
+
]
|
| 2048 |
+
}
|
| 2049 |
+
}
|
apps/audio_cloning/vallex/g2p/bpe_69.json
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"version": "1.0",
|
| 3 |
+
"truncation": null,
|
| 4 |
+
"padding": null,
|
| 5 |
+
"added_tokens": [
|
| 6 |
+
{
|
| 7 |
+
"id": 0,
|
| 8 |
+
"content": "[UNK]",
|
| 9 |
+
"single_word": false,
|
| 10 |
+
"lstrip": false,
|
| 11 |
+
"rstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"special": true
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"id": 1,
|
| 17 |
+
"content": "[CLS]",
|
| 18 |
+
"single_word": false,
|
| 19 |
+
"lstrip": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"normalized": false,
|
| 22 |
+
"special": true
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"id": 2,
|
| 26 |
+
"content": "[SEP]",
|
| 27 |
+
"single_word": false,
|
| 28 |
+
"lstrip": false,
|
| 29 |
+
"rstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"special": true
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"id": 3,
|
| 35 |
+
"content": "[PAD]",
|
| 36 |
+
"single_word": false,
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"rstrip": false,
|
| 39 |
+
"normalized": false,
|
| 40 |
+
"special": true
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"id": 4,
|
| 44 |
+
"content": "[MASK]",
|
| 45 |
+
"single_word": false,
|
| 46 |
+
"lstrip": false,
|
| 47 |
+
"rstrip": false,
|
| 48 |
+
"normalized": false,
|
| 49 |
+
"special": true
|
| 50 |
+
}
|
| 51 |
+
],
|
| 52 |
+
"normalizer": null,
|
| 53 |
+
"pre_tokenizer": {
|
| 54 |
+
"type": "Whitespace"
|
| 55 |
+
},
|
| 56 |
+
"post_processor": null,
|
| 57 |
+
"decoder": null,
|
| 58 |
+
"model": {
|
| 59 |
+
"type": "BPE",
|
| 60 |
+
"dropout": null,
|
| 61 |
+
"unk_token": "[UNK]",
|
| 62 |
+
"continuing_subword_prefix": null,
|
| 63 |
+
"end_of_word_suffix": null,
|
| 64 |
+
"fuse_unk": false,
|
| 65 |
+
"byte_fallback": false,
|
| 66 |
+
"vocab": {
|
| 67 |
+
"[UNK]": 0,
|
| 68 |
+
"[CLS]": 1,
|
| 69 |
+
"[SEP]": 2,
|
| 70 |
+
"[PAD]": 3,
|
| 71 |
+
"[MASK]": 4,
|
| 72 |
+
"!": 5,
|
| 73 |
+
"#": 6,
|
| 74 |
+
"*": 7,
|
| 75 |
+
",": 8,
|
| 76 |
+
"-": 9,
|
| 77 |
+
".": 10,
|
| 78 |
+
"=": 11,
|
| 79 |
+
"?": 12,
|
| 80 |
+
"N": 13,
|
| 81 |
+
"Q": 14,
|
| 82 |
+
"^": 15,
|
| 83 |
+
"_": 16,
|
| 84 |
+
"`": 17,
|
| 85 |
+
"a": 18,
|
| 86 |
+
"b": 19,
|
| 87 |
+
"d": 20,
|
| 88 |
+
"e": 21,
|
| 89 |
+
"f": 22,
|
| 90 |
+
"g": 23,
|
| 91 |
+
"h": 24,
|
| 92 |
+
"i": 25,
|
| 93 |
+
"j": 26,
|
| 94 |
+
"k": 27,
|
| 95 |
+
"l": 28,
|
| 96 |
+
"m": 29,
|
| 97 |
+
"n": 30,
|
| 98 |
+
"o": 31,
|
| 99 |
+
"p": 32,
|
| 100 |
+
"s": 33,
|
| 101 |
+
"t": 34,
|
| 102 |
+
"u": 35,
|
| 103 |
+
"v": 36,
|
| 104 |
+
"w": 37,
|
| 105 |
+
"x": 38,
|
| 106 |
+
"y": 39,
|
| 107 |
+
"z": 40,
|
| 108 |
+
"~": 41,
|
| 109 |
+
"æ": 42,
|
| 110 |
+
"ç": 43,
|
| 111 |
+
"ð": 44,
|
| 112 |
+
"ŋ": 45,
|
| 113 |
+
"ɑ": 46,
|
| 114 |
+
"ɔ": 47,
|
| 115 |
+
"ə": 48,
|
| 116 |
+
"ɛ": 49,
|
| 117 |
+
"ɥ": 50,
|
| 118 |
+
"ɪ": 51,
|
| 119 |
+
"ɫ": 52,
|
| 120 |
+
"ɯ": 53,
|
| 121 |
+
"ɸ": 54,
|
| 122 |
+
"ɹ": 55,
|
| 123 |
+
"ɾ": 56,
|
| 124 |
+
"ʃ": 57,
|
| 125 |
+
"ʊ": 58,
|
| 126 |
+
"ʑ": 59,
|
| 127 |
+
"ʒ": 60,
|
| 128 |
+
"ʰ": 61,
|
| 129 |
+
"ˈ": 62,
|
| 130 |
+
"ˌ": 63,
|
| 131 |
+
"θ": 64,
|
| 132 |
+
"…": 65,
|
| 133 |
+
"⁼": 66,
|
| 134 |
+
"↑": 67,
|
| 135 |
+
"→": 68,
|
| 136 |
+
"↓": 69
|
| 137 |
+
},
|
| 138 |
+
"merges": [
|
| 139 |
+
]
|
| 140 |
+
}
|
| 141 |
+
}
|
apps/audio_cloning/vallex/g2p/cleaners.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
from .english import english_to_ipa2
|
| 4 |
+
from .japanese import japanese_to_ipa2, japanese_to_romaji_with_accent
|
| 5 |
+
from .mandarin import (
|
| 6 |
+
chinese_to_bopomofo,
|
| 7 |
+
chinese_to_ipa,
|
| 8 |
+
latin_to_bopomofo,
|
| 9 |
+
number_to_chinese,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
patterns = [r"\[EN\](.*?)\[EN\]", r"\[ZH\](.*?)\[ZH\]", r"\[JA\](.*?)\[JA\]"]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def japanese_cleaners(text):
    """Convert Japanese text to accented romaji, ensuring terminal punctuation.

    Runs the text through ``japanese_to_romaji_with_accent`` and appends a
    "." when the converted string ends in a bare ASCII letter, so every
    cleaned utterance terminates with punctuation.
    """
    romaji = japanese_to_romaji_with_accent(text)
    return re.sub(r"([A-Za-z])$", r"\1.", romaji)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def japanese_cleaners2(text):
    """Variant of ``japanese_cleaners`` with symbol substitutions.

    Replaces the digraph "ts" with the single IPA symbol "ʦ" and the
    three-dot sequence "..." with a true ellipsis "…".
    """
    cleaned = japanese_cleaners(text)
    cleaned = cleaned.replace("ts", "ʦ")
    return cleaned.replace("...", "…")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def chinese_cleaners(text):
    """Pipeline for Chinese text.

    Expands digits to Chinese numerals, converts Chinese characters and any
    Latin letters to bopomofo, and appends "。" when the result ends with a
    tone mark so the cleaned text terminates with punctuation.
    """
    # Apply the conversion stages in their required order.
    for stage in (number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo):
        text = stage(text)
    return re.sub(r"([ˉˊˇˋ˙])$", r"\1。", text)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def cje_cleaners(text):
    """Clean mixed language-tagged text into phones plus per-character languages.

    Locates every ``[EN]...[EN]``, ``[ZH]...[ZH]`` and ``[JA]...[JA]``
    segment in *text*, converts each to its phonetic form with
    ``clean_one``, and returns a tuple ``(phones, langs)`` where ``langs``
    assigns a language code ("en"/"zh"/"ja") to every character of
    ``phones``.

    Raises:
        ValueError: if a matched segment carries none of the known tags
            (should be impossible; indicates a bug).
    """
    found = []
    for pat in patterns:
        found.extend(re.finditer(pat, text))

    # Process segments in the order they appear in the input string.
    found.sort(key=lambda m: m.start())

    phones = []
    langs = []
    for m in found:
        segment = text[m.start() : m.end()]
        phon = clean_one(segment)
        if "[EN]" in segment:
            tag = "en"
        elif "[ZH]" in segment:
            tag = "zh"
        elif "[JA]" in segment:
            tag = "ja"
        else:
            raise ValueError("If you see this error, please report this bug to issues.")
        phones.append(phon)
        langs.extend([tag] * len(phon))

    joined = "".join(phones)
    assert len(joined) == len(langs)
    return joined, langs
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def clean_one(text):
    """Convert one language-tagged string to IPA and normalize its ending.

    Expands ``[ZH]...[ZH]``, ``[JA]...[JA]`` and ``[EN]...[EN]`` spans via
    the matching language converter (each expansion leaves a trailing
    space), then strips trailing whitespace and appends "." unless the
    text already ends with punctuation.
    """
    converters = (
        (r"\[ZH\](.*?)\[ZH\]", "[ZH]", chinese_to_ipa),
        (r"\[JA\](.*?)\[JA\]", "[JA]", japanese_to_ipa2),
        (r"\[EN\](.*?)\[EN\]", "[EN]", english_to_ipa2),
    )
    # Same ZH -> JA -> EN order as before; `c=convert` binds the converter
    # eagerly so the lambda does not late-bind to the loop variable.
    for pat, tag, convert in converters:
        if tag in text:
            text = re.sub(pat, lambda m, c=convert: c(m.group(1)) + " ", text)
    text = re.sub(r"\s+$", "", text)
    return re.sub(r"([^\.,!\?\-…~])$", r"\1.", text)
|
apps/audio_cloning/vallex/g2p/english.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
import inflect
|
| 4 |
+
from unidecode import unidecode
|
| 5 |
+
|
| 6 |
+
"""from https://github.com/keithito/tacotron"""
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
Cleaners are transformations that run over the input text at both training and eval time.
|
| 10 |
+
|
| 11 |
+
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
|
| 12 |
+
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
|
| 13 |
+
1. "english_cleaners" for English text
|
| 14 |
+
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
|
| 15 |
+
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
|
| 16 |
+
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
|
| 17 |
+
the symbols in symbols.py to match your data).
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
_inflect = inflect.engine()
|
| 22 |
+
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
|
| 23 |
+
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
|
| 24 |
+
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
|
| 25 |
+
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
|
| 26 |
+
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
|
| 27 |
+
_number_re = re.compile(r"[0-9]+")
|
| 28 |
+
|
| 29 |
+
# List of (regular expression, replacement) pairs for abbreviations:
|
| 30 |
+
_abbreviations = [
|
| 31 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
| 32 |
+
for x in [
|
| 33 |
+
("mrs", "misess"),
|
| 34 |
+
("mr", "mister"),
|
| 35 |
+
("dr", "doctor"),
|
| 36 |
+
("st", "saint"),
|
| 37 |
+
("co", "company"),
|
| 38 |
+
("jr", "junior"),
|
| 39 |
+
("maj", "major"),
|
| 40 |
+
("gen", "general"),
|
| 41 |
+
("drs", "doctors"),
|
| 42 |
+
("rev", "reverend"),
|
| 43 |
+
("lt", "lieutenant"),
|
| 44 |
+
("hon", "honorable"),
|
| 45 |
+
("sgt", "sergeant"),
|
| 46 |
+
("capt", "captain"),
|
| 47 |
+
("esq", "esquire"),
|
| 48 |
+
("ltd", "limited"),
|
| 49 |
+
("col", "colonel"),
|
| 50 |
+
("ft", "fort"),
|
| 51 |
+
]
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# List of (ipa, lazy ipa) pairs:
|
| 56 |
+
_lazy_ipa = [
|
| 57 |
+
(re.compile("%s" % x[0]), x[1])
|
| 58 |
+
for x in [
|
| 59 |
+
("r", "ɹ"),
|
| 60 |
+
("æ", "e"),
|
| 61 |
+
("ɑ", "a"),
|
| 62 |
+
("ɔ", "o"),
|
| 63 |
+
("ð", "z"),
|
| 64 |
+
("θ", "s"),
|
| 65 |
+
("ɛ", "e"),
|
| 66 |
+
("ɪ", "i"),
|
| 67 |
+
("ʊ", "u"),
|
| 68 |
+
("ʒ", "ʥ"),
|
| 69 |
+
("ʤ", "ʥ"),
|
| 70 |
+
("ˈ", "↓"),
|
| 71 |
+
]
|
| 72 |
+
]
|
| 73 |
+
|
| 74 |
+
# List of (ipa, lazy ipa2) pairs:
|
| 75 |
+
_lazy_ipa2 = [
|
| 76 |
+
(re.compile("%s" % x[0]), x[1])
|
| 77 |
+
for x in [
|
| 78 |
+
("r", "ɹ"),
|
| 79 |
+
("ð", "z"),
|
| 80 |
+
("θ", "s"),
|
| 81 |
+
("ʒ", "ʑ"),
|
| 82 |
+
("ʤ", "dʑ"),
|
| 83 |
+
("ˈ", "↓"),
|
| 84 |
+
]
|
| 85 |
+
]
|
| 86 |
+
|
| 87 |
+
# List of (ipa, ipa2) pairs
|
| 88 |
+
_ipa_to_ipa2 = [
|
| 89 |
+
(re.compile("%s" % x[0]), x[1]) for x in [("r", "ɹ"), ("ʤ", "dʒ"), ("ʧ", "tʃ")]
|
| 90 |
+
]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def expand_abbreviations(text):
    """Expand common English abbreviations ("dr." -> "doctor", etc.)."""
    for pattern, expansion in _abbreviations:
        text = pattern.sub(expansion, text)
    return text
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def collapse_whitespace(text):
    """Replace each run of whitespace with a single space character."""
    return re.compile(r"\s+").sub(" ", text)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _remove_commas(m):
    """Strip thousands separators from a matched number ("1,000" -> "1000")."""
    return "".join(ch for ch in m.group(1) if ch != ",")
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def _expand_decimal_point(m):
    """Spell out a decimal point ("1.5" -> "1 point 5")."""
    return " point ".join(m.group(1).split("."))
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _expand_dollars(m):
    """Spell out a matched dollar amount ("2.50" -> "2 dollars, 50 cents").

    The match group is the numeric part captured by ``_dollars_re``
    (without the leading "$").  Amounts with more than one decimal point
    are left unchanged with "dollars" appended.
    """
    match = m.group(1)
    parts = match.split(".")
    if len(parts) > 2:
        return match + " dollars"  # Unexpected format
    dollars = int(parts[0]) if parts[0] else 0
    if len(parts) > 1 and parts[1]:
        # Pad/truncate the fractional part to two digits so that "$2.5"
        # reads as 50 cents rather than 5 cents (int("5") == 5 bug).
        cents = int(parts[1].ljust(2, "0")[:2])
    else:
        cents = 0
    if dollars and cents:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
    elif dollars:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        return "%s %s" % (dollars, dollar_unit)
    elif cents:
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s" % (cents, cent_unit)
    else:
        return "zero dollars"
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _expand_ordinal(m):
    """Spell out an ordinal number ("2nd" -> "second")."""
    word = m.group(0)
    return _inflect.number_to_words(word)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _expand_number(m):
    """Spell out a matched integer in a speech-friendly way.

    Numbers between 1001 and 2999 are read year-style in two groups
    ("1984" -> "nineteen eighty four"); everything else is read as a
    plain cardinal number.
    """
    num = int(m.group(0))
    if not (1000 < num < 3000):
        return _inflect.number_to_words(num, andword="")
    if num == 2000:
        return "two thousand"
    if 2000 < num < 2010:
        return "two thousand " + _inflect.number_to_words(num % 100)
    if num % 100 == 0:
        return _inflect.number_to_words(num // 100) + " hundred"
    grouped = _inflect.number_to_words(num, andword="", zero="oh", group=2)
    return grouped.replace(", ", " ")
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def normalize_numbers(text):
    """Rewrite digits, currency amounts and ordinals in *text* as words."""
    substitutions = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r"\1 pounds"),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    )
    # Order matters: separators and currency are handled before plain numbers.
    for pattern, repl in substitutions:
        text = re.sub(pattern, repl, text)
    return text
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def mark_dark_l(text):
    """Mark syllable-final "l" (no vowel before the next space/end) as dark "ɫ"."""
    pattern = r"l([^aeiouæɑɔəɛɪʊ ]*(?: |$))"
    return re.sub(pattern, lambda m: "ɫ" + m.group(1), text)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def english_to_ipa(text):
    """Transcribe English text to IPA via the ``eng_to_ipa`` package.

    The text is ASCII-folded and lower-cased, then abbreviations and
    numbers are expanded before conversion; whitespace in the result is
    collapsed.
    """
    import eng_to_ipa as ipa

    normalized = unidecode(text).lower()
    normalized = expand_abbreviations(normalized)
    normalized = normalize_numbers(normalized)
    return collapse_whitespace(ipa.convert(normalized))
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def english_to_lazy_ipa(text):
    """IPA transcription mapped to the simplified ("lazy") symbol set."""
    result = english_to_ipa(text)
    for pattern, replacement in _lazy_ipa:
        result = pattern.sub(replacement, result)
    return result
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def english_to_ipa2(text):
    """IPA transcription using the ipa2 symbol conventions.

    Adds dark-l marking, rewrites a few rhotic/affricate symbols, and
    normalizes the ASCII ellipsis "..." to "…".
    """
    result = mark_dark_l(english_to_ipa(text))
    for pattern, replacement in _ipa_to_ipa2:
        result = pattern.sub(replacement, result)
    return result.replace("...", "…")
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def english_to_lazy_ipa2(text):
    """IPA transcription mapped to the second simplified ("lazy") symbol set."""
    result = english_to_ipa(text)
    for pattern, replacement in _lazy_ipa2:
        result = pattern.sub(replacement, result)
    return result
|
apps/audio_cloning/vallex/g2p/japanese.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
from unidecode import unidecode
|
| 4 |
+
|
| 5 |
+
# Regular expression matching Japanese without punctuation marks:
|
| 6 |
+
_japanese_characters = re.compile(
|
| 7 |
+
r"[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
# Regular expression matching non-Japanese characters or punctuation marks:
|
| 11 |
+
_japanese_marks = re.compile(
|
| 12 |
+
r"[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
# List of (symbol, Japanese) pairs for marks:
|
| 16 |
+
_symbols_to_japanese = [(re.compile("%s" % x[0]), x[1]) for x in [("%", "パーセント")]]
|
| 17 |
+
|
| 18 |
+
# List of (romaji, ipa) pairs for marks:
|
| 19 |
+
_romaji_to_ipa = [
|
| 20 |
+
(re.compile("%s" % x[0]), x[1])
|
| 21 |
+
for x in [
|
| 22 |
+
("ts", "ʦ"),
|
| 23 |
+
("u", "ɯ"),
|
| 24 |
+
("j", "ʥ"),
|
| 25 |
+
("y", "j"),
|
| 26 |
+
("ni", "n^i"),
|
| 27 |
+
("nj", "n^"),
|
| 28 |
+
("hi", "çi"),
|
| 29 |
+
("hj", "ç"),
|
| 30 |
+
("f", "ɸ"),
|
| 31 |
+
("I", "i*"),
|
| 32 |
+
("U", "ɯ*"),
|
| 33 |
+
("r", "ɾ"),
|
| 34 |
+
]
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
# List of (romaji, ipa2) pairs for marks:
|
| 38 |
+
_romaji_to_ipa2 = [
|
| 39 |
+
(re.compile("%s" % x[0]), x[1])
|
| 40 |
+
for x in [
|
| 41 |
+
("u", "ɯ"),
|
| 42 |
+
("ʧ", "tʃ"),
|
| 43 |
+
("j", "dʑ"),
|
| 44 |
+
("y", "j"),
|
| 45 |
+
("ni", "n^i"),
|
| 46 |
+
("nj", "n^"),
|
| 47 |
+
("hi", "çi"),
|
| 48 |
+
("hj", "ç"),
|
| 49 |
+
("f", "ɸ"),
|
| 50 |
+
("I", "i*"),
|
| 51 |
+
("U", "ɯ*"),
|
| 52 |
+
("r", "ɾ"),
|
| 53 |
+
]
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
# List of (consonant, sokuon) pairs:
|
| 57 |
+
_real_sokuon = [
|
| 58 |
+
(re.compile("%s" % x[0]), x[1])
|
| 59 |
+
for x in [
|
| 60 |
+
(r"Q([↑↓]*[kg])", r"k#\1"),
|
| 61 |
+
(r"Q([↑↓]*[tdjʧ])", r"t#\1"),
|
| 62 |
+
(r"Q([↑↓]*[sʃ])", r"s\1"),
|
| 63 |
+
(r"Q([↑↓]*[pb])", r"p#\1"),
|
| 64 |
+
]
|
| 65 |
+
]
|
| 66 |
+
|
| 67 |
+
# List of (consonant, hatsuon) pairs:
|
| 68 |
+
_real_hatsuon = [
|
| 69 |
+
(re.compile("%s" % x[0]), x[1])
|
| 70 |
+
for x in [
|
| 71 |
+
(r"N([↑↓]*[pbm])", r"m\1"),
|
| 72 |
+
(r"N([↑↓]*[ʧʥj])", r"n^\1"),
|
| 73 |
+
(r"N([↑↓]*[tdn])", r"n\1"),
|
| 74 |
+
(r"N([↑↓]*[kg])", r"ŋ\1"),
|
| 75 |
+
]
|
| 76 |
+
]
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def symbols_to_japanese(text):
    """Replace symbols with their spoken Japanese form ("%" -> "パーセント")."""
    for pattern, reading in _symbols_to_japanese:
        text = pattern.sub(reading, text)
    return text
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def japanese_to_romaji_with_accent(text):
    """Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html

    Convert Japanese text to romaji annotated with pitch-accent marks,
    using pyopenjtalk full-context labels: "↑" marks a rise, "↓" a fall,
    and a space separates accent phrases.
    """
    import pyopenjtalk

    text = symbols_to_japanese(text)
    # Split into Japanese runs and the punctuation/marks between them;
    # marks are re-attached (ASCII-folded, spaces removed) after each run.
    sentences = re.split(_japanese_marks, text)
    marks = re.findall(_japanese_marks, text)
    text = ""
    for i, sentence in enumerate(sentences):
        if re.match(_japanese_characters, sentence):
            if text != "":
                text += " "
            labels = pyopenjtalk.extract_fullcontext(sentence)
            for n, label in enumerate(labels):
                # Current phoneme is the "-x+" field of the full-context label.
                phoneme = re.search(r"\-([^\+]*)\+", label).group(1)
                if phoneme not in ["sil", "pau"]:
                    # Digraphs mapped to single symbols; "cl" (sokuon) -> "Q".
                    text += (
                        phoneme.replace("ch", "ʧ").replace("sh", "ʃ").replace("cl", "Q")
                    )
                else:
                    continue
                # n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
                # A-fields of the label: a1 = offset from the accent nucleus,
                # a2 = mora position in the accent phrase, a3 = moras remaining.
                a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
                a2 = int(re.search(r"\+(\d+)\+", label).group(1))
                a3 = int(re.search(r"\+(\d+)/", label).group(1))
                # Peek at the next label; sil/pau means phrase end (a2_next = -1).
                # NOTE(review): labels[n + 1] assumes the final label is always
                # sil/pau (skipped by the `continue` above) — appears to hold
                # for pyopenjtalk output; confirm.
                if re.search(r"\-([^\+]*)\+", labels[n + 1]).group(1) in ["sil", "pau"]:
                    a2_next = -1
                else:
                    a2_next = int(re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
                # Accent phrase boundary
                if a3 == 1 and a2_next == 1:
                    text += " "
                # Falling
                elif a1 == 0 and a2_next == a2 + 1:
                    text += "↓"
                # Rising
                elif a2 == 1 and a2_next == 2:
                    text += "↑"
        if i < len(marks):
            text += unidecode(marks[i]).replace(" ", "")
    return text
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def get_real_sokuon(text):
    """Resolve the sokuon placeholder "Q" into a concrete geminate consonant."""
    for pattern, replacement in _real_sokuon:
        text = pattern.sub(replacement, text)
    return text
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def get_real_hatsuon(text):
    """Resolve the moraic-nasal placeholder "N" into its assimilated form."""
    for pattern, replacement in _real_hatsuon:
        text = pattern.sub(replacement, text)
    return text
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def japanese_to_ipa(text):
    """Convert Japanese text to the (lazy) IPA symbol set."""
    romaji = japanese_to_romaji_with_accent(text).replace("...", "…")
    # Collapse repeated vowels into a long-vowel mark ("aa" -> "aː").
    romaji = re.sub(
        r"([aiueo])\1+",
        lambda m: m.group(0)[0] + "ː" * (len(m.group(0)) - 1),
        romaji,
    )
    romaji = get_real_sokuon(romaji)
    romaji = get_real_hatsuon(romaji)
    for pattern, replacement in _romaji_to_ipa:
        romaji = pattern.sub(replacement, romaji)
    return romaji
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def japanese_to_ipa2(text):
    """Convert Japanese text to IPA using the ipa2 symbol conventions."""
    romaji = japanese_to_romaji_with_accent(text).replace("...", "…")
    romaji = get_real_sokuon(romaji)
    romaji = get_real_hatsuon(romaji)
    for pattern, replacement in _romaji_to_ipa2:
        romaji = pattern.sub(replacement, romaji)
    return romaji
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def japanese_to_ipa3(text):
    """Convert Japanese text to IPA with narrower phonetic detail."""
    result = japanese_to_ipa2(text)
    # Narrow transcription: palatal nasal, alveolo-palatal fricative, and
    # combining voiceless/apical diacritics for the "*" / "#" placeholders.
    for old, new in (("n^", "ȵ"), ("ʃ", "ɕ"), ("*", "\u0325"), ("#", "\u031a")):
        result = result.replace(old, new)
    # Collapse repeated vowels into a long-vowel mark.
    result = re.sub(
        r"([aiɯeo])\1+",
        lambda m: m.group(0)[0] + "ː" * (len(m.group(0)) - 1),
        result,
    )
    # Aspirate utterance/word-initial voiceless stops and affricates.
    result = re.sub(r"((?:^|\s)(?:ts|tɕ|[kpt]))", r"\1ʰ", result)
    return result
|
apps/audio_cloning/vallex/g2p/mandarin.py
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
import cn2an
|
| 4 |
+
import jieba
|
| 5 |
+
|
| 6 |
+
# List of (Latin alphabet, bopomofo) pairs:
|
| 7 |
+
_latin_to_bopomofo = [
|
| 8 |
+
(re.compile("%s" % x[0], re.IGNORECASE), x[1])
|
| 9 |
+
for x in [
|
| 10 |
+
("a", "ㄟˉ"),
|
| 11 |
+
("b", "ㄅㄧˋ"),
|
| 12 |
+
("c", "ㄙㄧˉ"),
|
| 13 |
+
("d", "ㄉㄧˋ"),
|
| 14 |
+
("e", "ㄧˋ"),
|
| 15 |
+
("f", "ㄝˊㄈㄨˋ"),
|
| 16 |
+
("g", "ㄐㄧˋ"),
|
| 17 |
+
("h", "ㄝˇㄑㄩˋ"),
|
| 18 |
+
("i", "ㄞˋ"),
|
| 19 |
+
("j", "ㄐㄟˋ"),
|
| 20 |
+
("k", "ㄎㄟˋ"),
|
| 21 |
+
("l", "ㄝˊㄛˋ"),
|
| 22 |
+
("m", "ㄝˊㄇㄨˋ"),
|
| 23 |
+
("n", "ㄣˉ"),
|
| 24 |
+
("o", "ㄡˉ"),
|
| 25 |
+
("p", "ㄆㄧˉ"),
|
| 26 |
+
("q", "ㄎㄧㄡˉ"),
|
| 27 |
+
("r", "ㄚˋ"),
|
| 28 |
+
("s", "ㄝˊㄙˋ"),
|
| 29 |
+
("t", "ㄊㄧˋ"),
|
| 30 |
+
("u", "ㄧㄡˉ"),
|
| 31 |
+
("v", "ㄨㄧˉ"),
|
| 32 |
+
("w", "ㄉㄚˋㄅㄨˋㄌㄧㄡˋ"),
|
| 33 |
+
("x", "ㄝˉㄎㄨˋㄙˋ"),
|
| 34 |
+
("y", "ㄨㄞˋ"),
|
| 35 |
+
("z", "ㄗㄟˋ"),
|
| 36 |
+
]
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
# List of (bopomofo, romaji) pairs:
|
| 40 |
+
_bopomofo_to_romaji = [
|
| 41 |
+
(re.compile("%s" % x[0]), x[1])
|
| 42 |
+
for x in [
|
| 43 |
+
("ㄅㄛ", "p⁼wo"),
|
| 44 |
+
("ㄆㄛ", "pʰwo"),
|
| 45 |
+
("ㄇㄛ", "mwo"),
|
| 46 |
+
("ㄈㄛ", "fwo"),
|
| 47 |
+
("ㄅ", "p⁼"),
|
| 48 |
+
("ㄆ", "pʰ"),
|
| 49 |
+
("ㄇ", "m"),
|
| 50 |
+
("ㄈ", "f"),
|
| 51 |
+
("ㄉ", "t⁼"),
|
| 52 |
+
("ㄊ", "tʰ"),
|
| 53 |
+
("ㄋ", "n"),
|
| 54 |
+
("ㄌ", "l"),
|
| 55 |
+
("ㄍ", "k⁼"),
|
| 56 |
+
("ㄎ", "kʰ"),
|
| 57 |
+
("ㄏ", "h"),
|
| 58 |
+
("ㄐ", "ʧ⁼"),
|
| 59 |
+
("ㄑ", "ʧʰ"),
|
| 60 |
+
("ㄒ", "ʃ"),
|
| 61 |
+
("ㄓ", "ʦ`⁼"),
|
| 62 |
+
("ㄔ", "ʦ`ʰ"),
|
| 63 |
+
("ㄕ", "s`"),
|
| 64 |
+
("ㄖ", "ɹ`"),
|
| 65 |
+
("ㄗ", "ʦ⁼"),
|
| 66 |
+
("ㄘ", "ʦʰ"),
|
| 67 |
+
("ㄙ", "s"),
|
| 68 |
+
("ㄚ", "a"),
|
| 69 |
+
("ㄛ", "o"),
|
| 70 |
+
("ㄜ", "ə"),
|
| 71 |
+
("ㄝ", "e"),
|
| 72 |
+
("ㄞ", "ai"),
|
| 73 |
+
("ㄟ", "ei"),
|
| 74 |
+
("ㄠ", "au"),
|
| 75 |
+
("ㄡ", "ou"),
|
| 76 |
+
("ㄧㄢ", "yeNN"),
|
| 77 |
+
("ㄢ", "aNN"),
|
| 78 |
+
("ㄧㄣ", "iNN"),
|
| 79 |
+
("ㄣ", "əNN"),
|
| 80 |
+
("ㄤ", "aNg"),
|
| 81 |
+
("ㄧㄥ", "iNg"),
|
| 82 |
+
("ㄨㄥ", "uNg"),
|
| 83 |
+
("ㄩㄥ", "yuNg"),
|
| 84 |
+
("ㄥ", "əNg"),
|
| 85 |
+
("ㄦ", "əɻ"),
|
| 86 |
+
("ㄧ", "i"),
|
| 87 |
+
("ㄨ", "u"),
|
| 88 |
+
("ㄩ", "ɥ"),
|
| 89 |
+
("ˉ", "→"),
|
| 90 |
+
("ˊ", "↑"),
|
| 91 |
+
("ˇ", "↓↑"),
|
| 92 |
+
("ˋ", "↓"),
|
| 93 |
+
("˙", ""),
|
| 94 |
+
(",", ","),
|
| 95 |
+
("。", "."),
|
| 96 |
+
("!", "!"),
|
| 97 |
+
("?", "?"),
|
| 98 |
+
("—", "-"),
|
| 99 |
+
]
|
| 100 |
+
]
|
| 101 |
+
|
| 102 |
+
# List of (romaji, ipa) pairs:
|
| 103 |
+
_romaji_to_ipa = [
|
| 104 |
+
(re.compile("%s" % x[0], re.IGNORECASE), x[1])
|
| 105 |
+
for x in [
|
| 106 |
+
("ʃy", "ʃ"),
|
| 107 |
+
("ʧʰy", "ʧʰ"),
|
| 108 |
+
("ʧ⁼y", "ʧ⁼"),
|
| 109 |
+
("NN", "n"),
|
| 110 |
+
("Ng", "ŋ"),
|
| 111 |
+
("y", "j"),
|
| 112 |
+
("h", "x"),
|
| 113 |
+
]
|
| 114 |
+
]
|
| 115 |
+
|
| 116 |
+
# List of (bopomofo, ipa) pairs:
|
| 117 |
+
_bopomofo_to_ipa = [
|
| 118 |
+
(re.compile("%s" % x[0]), x[1])
|
| 119 |
+
for x in [
|
| 120 |
+
("ㄅㄛ", "p⁼wo"),
|
| 121 |
+
("ㄆㄛ", "pʰwo"),
|
| 122 |
+
("ㄇㄛ", "mwo"),
|
| 123 |
+
("ㄈㄛ", "fwo"),
|
| 124 |
+
("ㄅ", "p⁼"),
|
| 125 |
+
("ㄆ", "pʰ"),
|
| 126 |
+
("ㄇ", "m"),
|
| 127 |
+
("ㄈ", "f"),
|
| 128 |
+
("ㄉ", "t⁼"),
|
| 129 |
+
("ㄊ", "tʰ"),
|
| 130 |
+
("ㄋ", "n"),
|
| 131 |
+
("ㄌ", "l"),
|
| 132 |
+
("ㄍ", "k⁼"),
|
| 133 |
+
("ㄎ", "kʰ"),
|
| 134 |
+
("ㄏ", "x"),
|
| 135 |
+
("ㄐ", "tʃ⁼"),
|
| 136 |
+
("ㄑ", "tʃʰ"),
|
| 137 |
+
("ㄒ", "ʃ"),
|
| 138 |
+
("ㄓ", "ts`⁼"),
|
| 139 |
+
("ㄔ", "ts`ʰ"),
|
| 140 |
+
("ㄕ", "s`"),
|
| 141 |
+
("ㄖ", "ɹ`"),
|
| 142 |
+
("ㄗ", "ts⁼"),
|
| 143 |
+
("ㄘ", "tsʰ"),
|
| 144 |
+
("ㄙ", "s"),
|
| 145 |
+
("ㄚ", "a"),
|
| 146 |
+
("ㄛ", "o"),
|
| 147 |
+
("ㄜ", "ə"),
|
| 148 |
+
("ㄝ", "ɛ"),
|
| 149 |
+
("ㄞ", "aɪ"),
|
| 150 |
+
("ㄟ", "eɪ"),
|
| 151 |
+
("ㄠ", "ɑʊ"),
|
| 152 |
+
("ㄡ", "oʊ"),
|
| 153 |
+
("ㄧㄢ", "jɛn"),
|
| 154 |
+
("ㄩㄢ", "ɥæn"),
|
| 155 |
+
("ㄢ", "an"),
|
| 156 |
+
("ㄧㄣ", "in"),
|
| 157 |
+
("ㄩㄣ", "ɥn"),
|
| 158 |
+
("ㄣ", "ən"),
|
| 159 |
+
("ㄤ", "ɑŋ"),
|
| 160 |
+
("ㄧㄥ", "iŋ"),
|
| 161 |
+
("ㄨㄥ", "ʊŋ"),
|
| 162 |
+
("ㄩㄥ", "jʊŋ"),
|
| 163 |
+
("ㄥ", "əŋ"),
|
| 164 |
+
("ㄦ", "əɻ"),
|
| 165 |
+
("ㄧ", "i"),
|
| 166 |
+
("ㄨ", "u"),
|
| 167 |
+
("ㄩ", "ɥ"),
|
| 168 |
+
("ˉ", "→"),
|
| 169 |
+
("ˊ", "↑"),
|
| 170 |
+
("ˇ", "↓↑"),
|
| 171 |
+
("ˋ", "↓"),
|
| 172 |
+
("˙", ""),
|
| 173 |
+
(",", ","),
|
| 174 |
+
("。", "."),
|
| 175 |
+
("!", "!"),
|
| 176 |
+
("?", "?"),
|
| 177 |
+
("—", "-"),
|
| 178 |
+
]
|
| 179 |
+
]
|
| 180 |
+
|
| 181 |
+
# List of (bopomofo, ipa2) pairs:
|
| 182 |
+
_bopomofo_to_ipa2 = [
|
| 183 |
+
(re.compile("%s" % x[0]), x[1])
|
| 184 |
+
for x in [
|
| 185 |
+
("ㄅㄛ", "pwo"),
|
| 186 |
+
("ㄆㄛ", "pʰwo"),
|
| 187 |
+
("ㄇㄛ", "mwo"),
|
| 188 |
+
("ㄈㄛ", "fwo"),
|
| 189 |
+
("ㄅ", "p"),
|
| 190 |
+
("ㄆ", "pʰ"),
|
| 191 |
+
("ㄇ", "m"),
|
| 192 |
+
("ㄈ", "f"),
|
| 193 |
+
("ㄉ", "t"),
|
| 194 |
+
("ㄊ", "tʰ"),
|
| 195 |
+
("ㄋ", "n"),
|
| 196 |
+
("ㄌ", "l"),
|
| 197 |
+
("ㄍ", "k"),
|
| 198 |
+
("ㄎ", "kʰ"),
|
| 199 |
+
("ㄏ", "h"),
|
| 200 |
+
("ㄐ", "tɕ"),
|
| 201 |
+
("ㄑ", "tɕʰ"),
|
| 202 |
+
("ㄒ", "ɕ"),
|
| 203 |
+
("ㄓ", "tʂ"),
|
| 204 |
+
("ㄔ", "tʂʰ"),
|
| 205 |
+
("ㄕ", "ʂ"),
|
| 206 |
+
("ㄖ", "ɻ"),
|
| 207 |
+
("ㄗ", "ts"),
|
| 208 |
+
("ㄘ", "tsʰ"),
|
| 209 |
+
("ㄙ", "s"),
|
| 210 |
+
("ㄚ", "a"),
|
| 211 |
+
("ㄛ", "o"),
|
| 212 |
+
("ㄜ", "ɤ"),
|
| 213 |
+
("ㄝ", "ɛ"),
|
| 214 |
+
("ㄞ", "aɪ"),
|
| 215 |
+
("ㄟ", "eɪ"),
|
| 216 |
+
("ㄠ", "ɑʊ"),
|
| 217 |
+
("ㄡ", "oʊ"),
|
| 218 |
+
("ㄧㄢ", "jɛn"),
|
| 219 |
+
("ㄩㄢ", "yæn"),
|
| 220 |
+
("ㄢ", "an"),
|
| 221 |
+
("ㄧㄣ", "in"),
|
| 222 |
+
("ㄩㄣ", "yn"),
|
| 223 |
+
("ㄣ", "ən"),
|
| 224 |
+
("ㄤ", "ɑŋ"),
|
| 225 |
+
("ㄧㄥ", "iŋ"),
|
| 226 |
+
("ㄨㄥ", "ʊŋ"),
|
| 227 |
+
("ㄩㄥ", "jʊŋ"),
|
| 228 |
+
("ㄥ", "ɤŋ"),
|
| 229 |
+
("ㄦ", "əɻ"),
|
| 230 |
+
("ㄧ", "i"),
|
| 231 |
+
("ㄨ", "u"),
|
| 232 |
+
("ㄩ", "y"),
|
| 233 |
+
("ˉ", "˥"),
|
| 234 |
+
("ˊ", "˧˥"),
|
| 235 |
+
("ˇ", "˨˩˦"),
|
| 236 |
+
("ˋ", "˥˩"),
|
| 237 |
+
("˙", ""),
|
| 238 |
+
(",", ","),
|
| 239 |
+
("。", "."),
|
| 240 |
+
("!", "!"),
|
| 241 |
+
("?", "?"),
|
| 242 |
+
("—", "-"),
|
| 243 |
+
]
|
| 244 |
+
]
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def number_to_chinese(text):
    """Rewrite each Arabic-numeral run in *text* as Chinese numerals (cn2an)."""
    for number in re.findall(r"\d+(?:\.?\d+)?", text):
        # Replace one occurrence at a time so repeated numbers are each handled.
        text = text.replace(number, cn2an.an2cn(number), 1)
    return text
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def chinese_to_bopomofo(text):
    """Transcribe Chinese text to bopomofo (zhuyin) with tone marks.

    The text is segmented with jieba; words containing no CJK characters
    pass through unchanged.  Syllables with no explicit tone mark get a
    first-tone mark appended.
    """
    from pypinyin import BOPOMOFO, lazy_pinyin

    text = text.replace("、", ",").replace(";", ",").replace(":", ",")
    result = ""
    for word in jieba.lcut(text, cut_all=False):
        syllables = lazy_pinyin(word, BOPOMOFO)
        if not re.search("[\u4e00-\u9fff]", word):
            result += word
            continue
        for idx in range(len(syllables)):
            # A trailing bare bopomofo letter means no tone mark was emitted;
            # append the first-tone mark "ˉ" in that case.
            syllables[idx] = re.sub(r"([\u3105-\u3129])$", r"\1ˉ", syllables[idx])
        if result != "":
            result += " "
        result += "".join(syllables)
    return result
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def latin_to_bopomofo(text):
    """Replace Latin letters with bopomofo spellings of their English names."""
    for pattern, replacement in _latin_to_bopomofo:
        text = pattern.sub(replacement, text)
    return text
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def bopomofo_to_romaji(text):
    """Map bopomofo symbols and tone marks to the romaji-flavoured phone set."""
    for pattern, replacement in _bopomofo_to_romaji:
        text = pattern.sub(replacement, text)
    return text
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def bopomofo_to_ipa(text):
    """Map bopomofo symbols and tone marks to the IPA phone set (tone arrows)."""
    for pattern, replacement in _bopomofo_to_ipa:
        text = pattern.sub(replacement, text)
    return text
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def bopomofo_to_ipa2(text):
    """Map bopomofo symbols to the ipa2 phone set (tone-letter contours)."""
    for pattern, replacement in _bopomofo_to_ipa2:
        text = pattern.sub(replacement, text)
    return text
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def chinese_to_romaji(text):
    """Convert Chinese text to the romaji-flavoured phone set."""
    for step in (
        number_to_chinese,
        chinese_to_bopomofo,
        latin_to_bopomofo,
        bopomofo_to_romaji,
    ):
        text = step(text)
    # Glide insertion: "ia" -> "ya", "ua" -> "wa", etc.
    text = re.sub("i([aoe])", r"y\1", text)
    text = re.sub("u([aoəe])", r"w\1", text)
    # Insert the apical vowel after bare sibilant/retroflex initials.
    text = re.sub("([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ`\2", text).replace("ɻ", "ɹ`")
    text = re.sub("([ʦs][⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text)
    return text
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def chinese_to_lazy_ipa(text):
    """Convert Chinese text to the simplified ("lazy") IPA symbol set."""
    result = chinese_to_romaji(text)
    for pattern, replacement in _romaji_to_ipa:
        result = pattern.sub(replacement, result)
    return result
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def chinese_to_ipa(text):
    """Convert Chinese text to IPA with tone arrows ("→↑↓")."""
    for step in (
        number_to_chinese,
        chinese_to_bopomofo,
        latin_to_bopomofo,
        bopomofo_to_ipa,
    ):
        text = step(text)
    # Glide insertion: "ia" -> "ja", "ua" -> "wa", etc.
    text = re.sub("i([aoe])", r"j\1", text)
    text = re.sub("u([aoəe])", r"w\1", text)
    # Insert the apical vowel after bare sibilant/retroflex initials.
    text = re.sub("([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ`\2", text).replace("ɻ", "ɹ`")
    text = re.sub("([s][⁼ʰ]?)([→↓↑ ]+|$)", r"\1ɹ\2", text)
    return text
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def chinese_to_ipa2(text):
    """Convert Chinese text to IPA with tone-letter contours ("˥˩…")."""
    for step in (
        number_to_chinese,
        chinese_to_bopomofo,
        latin_to_bopomofo,
        bopomofo_to_ipa2,
    ):
        text = step(text)
    # Glide insertion: "ia" -> "ja", "ua" -> "wa", etc.
    text = re.sub(r"i([aoe])", r"j\1", text)
    text = re.sub(r"u([aoəe])", r"w\1", text)
    # Apical vowels after retroflex / dental sibilant initials.
    text = re.sub(r"([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)", r"\1ʅ\2", text)
    text = re.sub(r"(sʰ?)([˩˨˧˦˥ ]+|$)", r"\1ɿ\2", text)
    return text
|
apps/audio_cloning/vallex/g2p/symbols.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
Defines the set of symbols used in text input to the model.
|
| 3 |
+
'''
|
| 4 |
+
|
| 5 |
+
# japanese_cleaners
|
| 6 |
+
# _pad = '_'
|
| 7 |
+
# _punctuation = ',.!?-'
|
| 8 |
+
# _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
'''# japanese_cleaners2
|
| 12 |
+
_pad = '_'
|
| 13 |
+
_punctuation = ',.!?-~…'
|
| 14 |
+
_letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
|
| 15 |
+
'''
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
'''# korean_cleaners
|
| 19 |
+
_pad = '_'
|
| 20 |
+
_punctuation = ',.!?…~'
|
| 21 |
+
_letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
|
| 22 |
+
'''
|
| 23 |
+
|
| 24 |
+
'''# chinese_cleaners
|
| 25 |
+
_pad = '_'
|
| 26 |
+
_punctuation = ',。!?—…'
|
| 27 |
+
_letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
|
| 28 |
+
'''
|
| 29 |
+
|
| 30 |
+
# # zh_ja_mixture_cleaners
|
| 31 |
+
# _pad = '_'
|
| 32 |
+
# _punctuation = ',.!?-~…'
|
| 33 |
+
# _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
'''# sanskrit_cleaners
|
| 37 |
+
_pad = '_'
|
| 38 |
+
_punctuation = '।'
|
| 39 |
+
_letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ '
|
| 40 |
+
'''
|
| 41 |
+
|
| 42 |
+
'''# cjks_cleaners
|
| 43 |
+
_pad = '_'
|
| 44 |
+
_punctuation = ',.!?-~…'
|
| 45 |
+
_letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ '
|
| 46 |
+
'''
|
| 47 |
+
|
| 48 |
+
'''# thai_cleaners
|
| 49 |
+
_pad = '_'
|
| 50 |
+
_punctuation = '.!? '
|
| 51 |
+
_letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์'
|
| 52 |
+
'''
|
| 53 |
+
|
| 54 |
+
# # cjke_cleaners2
# Active symbol set (cjke_cleaners2): pad token, punctuation, and the
# IPA letter inventory used by the zh/ja/en cleaners above.
_pad = '_'
_punctuation = ',.!?-~…'
_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
'''# shanghainese_cleaners
|
| 61 |
+
_pad = '_'
|
| 62 |
+
_punctuation = ',.!?…'
|
| 63 |
+
_letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 '
|
| 64 |
+
'''
|
| 65 |
+
|
| 66 |
+
'''# chinese_dialect_cleaners
|
| 67 |
+
_pad = '_'
|
| 68 |
+
_punctuation = ',.!?~…─'
|
| 69 |
+
_letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ '
|
| 70 |
+
'''
|
| 71 |
+
|
| 72 |
+
# Export all symbols:
# The position of each symbol in this list is its integer id
# (pad first, then punctuation, then letters).
symbols = [_pad] + list(_punctuation) + list(_letters)

# Special symbol ids
SPACE_ID = symbols.index(" ")
|
apps/audio_cloning/vallex/macros.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model/runtime constants shared across the VALL-E X inference code.
NUM_LAYERS = 12  # transformer layer count used when constructing the model
NUM_HEAD = 16  # attention head count
N_DIM = 1024  # model embedding dimension
PREFIX_MODE = 1  # prompt-prefix conditioning mode passed to the model
NUM_QUANTIZERS = 8  # number of audio codec quantizer stages — presumably EnCodec; confirm
SAMPLE_RATE = 24000  # audio sample rate in Hz

# Language code -> inline tag wrapped around text segments ("mix" = untagged).
lang2token = {
    "zh": "[ZH]",
    "ja": "[JA]",
    "en": "[EN]",
    "mix": "",
}

# Language code -> integer language id.
lang2code = {
    "zh": 0,
    "ja": 1,
    "en": 2,
}

# Inverse of lang2token: inline tag -> language code.
token2lang = {"[ZH]": "zh", "[JA]": "ja", "[EN]": "en", "": "mix"}

# Inverse of lang2code: integer id -> language code.
code2lang = {
    0: "zh",
    1: "ja",
    2: "en",
}

# UI dropdown label -> inline language tag.
langdropdown2token = {
    "English": "[EN]",
    "中文": "[ZH]",
    "日本語": "[JA]",
    "Mix": "",
}
|
apps/audio_cloning/vallex/main.py
ADDED
|
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import multiprocessing
|
| 3 |
+
import os
|
| 4 |
+
import pathlib
|
| 5 |
+
import platform
|
| 6 |
+
import sys
|
| 7 |
+
import tempfile
|
| 8 |
+
import time
|
| 9 |
+
|
| 10 |
+
import gradio as gr
|
| 11 |
+
import langid
|
| 12 |
+
import nltk
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torchaudio
|
| 16 |
+
import whisper
|
| 17 |
+
from vocos import Vocos
|
| 18 |
+
|
| 19 |
+
from .data.collation import get_text_token_collater
|
| 20 |
+
from .data.tokenizer import (
|
| 21 |
+
AudioTokenizer,
|
| 22 |
+
tokenize_audio,
|
| 23 |
+
)
|
| 24 |
+
from .descriptions import infer_from_audio_ja_md, top_ja_md
|
| 25 |
+
from .examples import infer_from_audio_examples
|
| 26 |
+
from .g2p import PhonemeBpeTokenizer
|
| 27 |
+
from .macros import (
|
| 28 |
+
N_DIM,
|
| 29 |
+
NUM_HEAD,
|
| 30 |
+
NUM_LAYERS,
|
| 31 |
+
NUM_QUANTIZERS,
|
| 32 |
+
PREFIX_MODE,
|
| 33 |
+
lang2code,
|
| 34 |
+
lang2token,
|
| 35 |
+
langdropdown2token,
|
| 36 |
+
token2lang,
|
| 37 |
+
)
|
| 38 |
+
from .models.vallex import VALLE
|
| 39 |
+
|
| 40 |
+
logger = logging.getLogger(__name__)

# set languages
# Restrict langid's classifier to the three languages VALL-E-X supports.
langid.set_languages(["en", "zh", "ja"])

# set nltk data path
# Search the repo-local nltk_data directory in addition to the defaults.
nltk.data.path = nltk.data.path + [os.path.join(os.getcwd(), "nltk_data")]
logger.info("nltk_data path: %s", nltk.data.path)

# get encoding
logger.info(
    "default encoding is %s,file system encoding is %s",
    sys.getdefaultencoding(),
    sys.getfilesystemencoding(),
)

# check python version
logger.info("You are using Python version %s", platform.python_version())
if sys.version_info[0] < 3 or sys.version_info[1] < 7:
    logger.warning("The Python version is too low and may cause problems")
# Alias the "other" OS's pathlib class to the native one.
# NOTE(review): presumably so pickled objects referencing PosixPath/WindowsPath
# load on either OS — confirm which artifact requires this.
if platform.system().lower() == "windows":
    temp = pathlib.PosixPath
    pathlib.PosixPath = pathlib.WindowsPath
else:
    temp = pathlib.WindowsPath
    pathlib.WindowsPath = pathlib.PosixPath
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

# set torch threads (guarded for hot-reload)
thread_count = multiprocessing.cpu_count()
logger.info("Use %d cpu cores for computing", thread_count)
if not getattr(torch, "_vallex_threads_configured", False):
    torch.set_num_threads(thread_count)
    try:
        torch.set_num_interop_threads(thread_count)
    except RuntimeError as err:
        # set_num_interop_threads raises once a parallel region has already run.
        logger.warning("Skipping set_num_interop_threads: %s", err)
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_set_profiling_mode(False)
    torch._C._set_graph_executor_optimize(False)

    # Calling torch.set_num_interop_threads again on a gradio hot-reload raises
    # an error, so remember that the thread setup has already been done.
    setattr(torch, "_vallex_threads_configured", True)
else:
    logger.info("Torch threads already configured; skipping reconfiguration")
|
| 85 |
+
|
| 86 |
+
# set text tokenizer and collater
logger.info("Setting text tokenizer and collater...")
# Phoneme-level BPE tokenizer for the input text; the collater pads/stacks
# token sequences into batches.
tokenizer_path = "./apps/audio_cloning/vallex/g2p/bpe_69.json"
text_tokenizer = PhonemeBpeTokenizer(tokenizer_path=tokenizer_path)
text_collater = get_text_token_collater()

# set device
# Prefer the first CUDA device when available, otherwise fall back to CPU.
logger.info("Setting device...")
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
# if torch.backends.mps.is_available():
#     device = torch.device("mps")
logger.info("Device set to %s", device)
|
| 100 |
+
|
| 101 |
+
# Download VALL-E-X model weights if not exists.
OUTPUT_DIR_CHECKPOINTS = "./models/checkpoints"
if platform.system().lower() == "linux":
    # In the docker (linux) environment the checkpoints live under /app.
    OUTPUT_DIR_CHECKPOINTS = "/app/models/checkpoints"

OUTPUT_FILENAME_CHECKPOINTS = "vallex-checkpoint.pt"
OUTPUT_PATH_CHECKPOINTS = os.path.join(
    OUTPUT_DIR_CHECKPOINTS, OUTPUT_FILENAME_CHECKPOINTS
)
# Canonical download location of the pretrained checkpoint.
CHECKPOINT_URL = (
    "https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt"
)
os.makedirs(OUTPUT_DIR_CHECKPOINTS, exist_ok=True)
if not os.path.exists(OUTPUT_PATH_CHECKPOINTS):
    import wget

    logger.info("Downloading model from %s ...", CHECKPOINT_URL)
    try:
        wget.download(
            CHECKPOINT_URL,
            out=OUTPUT_PATH_CHECKPOINTS,
            bar=wget.bar_adaptive,
        )
    except Exception as download_err:
        # BUG FIX: the original raised this Exception unconditionally, even
        # after a successful download, so a fresh install could never start.
        # Raise only when the download actually fails, chaining the cause.
        raise Exception(
            "\n Model weights download failed, please go to 'https://huggingface.co/Plachta/VALL-E-X/resolve/main/vallex-checkpoint.pt'"
            "\n manually download model weights and put it to {} .".format(
                os.getcwd() + f"{OUTPUT_DIR_CHECKPOINTS}"
            )
        ) from download_err
|
| 130 |
+
|
| 131 |
+
# initialize VALL-E-X model
model = VALLE(
    N_DIM,
    NUM_HEAD,
    NUM_LAYERS,
    norm_first=True,
    add_prenet=False,
    prefix_mode=PREFIX_MODE,
    share_embedding=True,
    nar_scale_factor=1.0,
    prepend_bos=True,
    num_quantizers=NUM_QUANTIZERS,
)
# weights_only=False: the checkpoint is a trusted local file downloaded above.
checkpoint = torch.load(OUTPUT_PATH_CHECKPOINTS, map_location="cpu", weights_only=False)
# strict=True already raises on any key mismatch; the assert below is
# belt-and-braces (unexpected_keys is intentionally ignored).
missing_keys, unexpected_keys = model.load_state_dict(checkpoint["model"], strict=True)
assert not missing_keys
model.eval()

# Encodec-based tokenizer: converts reference audio into discrete conditioning tokens for VALLE
logger.info("Initializing Encodec-based tokenizer...")
audio_tokenizer = AudioTokenizer(device)

# Vocos vocoder: decodes VALLE's discrete acoustic codes back into a 24 kHz waveform
vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(device)
|
| 155 |
+
|
| 156 |
+
# initialize ASR model (Whisper) used to transcribe reference audio.
OUTPUT_DIR_WHISPER = "./models/whisper"
if platform.system().lower() == "linux":
    # In the docker (linux) environment the models live under /app.
    OUTPUT_DIR_WHISPER = "/app/models/whisper"

os.makedirs(OUTPUT_DIR_WHISPER, exist_ok=True)
try:
    logger.info("Loading Whisper model...")
    model_name = "tiny"
    whisper_model = whisper.load_model(
        model_name, download_root=OUTPUT_DIR_WHISPER
    ).cpu()
    logger.info("Whisper model loaded successfully")
except Exception as e:
    # FIX: log the full traceback at error level (was logging.info(e)) and
    # chain the original exception so the root cause is not lost.
    logger.exception("Whisper model load failed")
    raise Exception(
        "\n Whisper download failed or damaged, please go to "
        "'https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt'"
        "\n manually download model and put it to {} .".format(os.getcwd() + "/whisper")
    ) from e

# Initialize Voice Presets: list every *.npz in PRESETS_DIR by its stem.
logger.info("Initializing Voice Presets...")
PRESETS_DIR = "apps/audio_cloning/vallex/presets"
# next(os.walk(...)) is the idiomatic form of os.walk(...).__next__().
preset_list = next(os.walk(PRESETS_DIR))[2]
preset_list = [preset[:-4] for preset in preset_list if preset.endswith(".npz")]
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def clear_prompts():
    """Delete stale ``.npz`` prompt files (older than 60 s) from the temp dir.

    Any failure is logged and swallowed; the function never raises.
    """
    try:
        tmp_dir = tempfile.gettempdir()
        for entry in os.listdir(tmp_dir):
            candidate = os.path.join(tmp_dir, entry)
            if not (os.path.isfile(candidate) and candidate.endswith(".npz")):
                continue
            modified_at = os.stat(candidate).st_mtime
            cutoff = time.time() - 60
            if cutoff > modified_at:
                os.remove(candidate)
    except Exception as e:
        logger.error("Error clearing prompts: %s", e)
        return
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def transcribe_one(model, audio_path):
    """Transcribe one audio file with a Whisper model.

    Args:
        model: a loaded Whisper model (already on its target device).
        audio_path: path to the audio file to transcribe.

    Returns:
        (lang, text) — the detected language code and the transcription,
        guaranteed to end with sentence punctuation.
    """
    # Load audio and pad/trim it to fit Whisper's fixed 30-second window.
    audio = whisper.load_audio(audio_path)
    audio = whisper.pad_or_trim(audio)

    # Make the log-Mel spectrogram and move it to the model's device.
    mel = whisper.log_mel_spectrogram(audio).to(model.device)

    # Detect the spoken language (most probable candidate).
    _, probs = model.detect_language(mel)
    lang = max(probs, key=probs.get)  # computed once, not twice as before
    logger.info("Detected language: %s", lang)

    # Decode the audio; fp16 only makes sense off the CPU.
    options = whisper.DecodingOptions(
        temperature=1.0,
        best_of=5,
        fp16=device != torch.device("cpu"),
        sample_len=150,
    )
    result = whisper.decode(model, mel, options)
    logger.info("Recognized text: %s", result.text)

    text_pr = result.text
    # Ensure the text ends with sentence punctuation for downstream prompt
    # building. BUG FIX: guard against an empty transcription, where the
    # original `text_pr.strip(" ")[-1]` raised IndexError.
    stripped = text_pr.strip(" ")
    if not stripped or stripped[-1] not in "?!.,。,?!。、":
        text_pr += "."
    return lang, text_pr
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
    """Build a reusable voice-prompt ``.npz`` from uploaded or recorded audio.

    Args:
        name: basename for the saved prompt file.
        uploaded_audio / recorded_audio: gradio audio tuples ``(sr, wav)``;
            the uploaded one takes precedence when both are present.
        transcript_content: transcript of the prompt audio; when empty the
            audio is transcribed with Whisper via ``make_prompt``.

    Returns:
        (message, path) — a status string and the path of the saved ``.npz``
        containing ``audio_tokens``, ``text_tokens`` and ``lang_code``.
    """
    global model, text_collater, text_tokenizer, audio_tokenizer
    clear_prompts()
    audio_prompt = uploaded_audio if uploaded_audio is not None else recorded_audio
    sr, wav_pr = audio_prompt
    if not isinstance(wav_pr, torch.FloatTensor):
        wav_pr = torch.FloatTensor(wav_pr)
    # Peak-normalize into [-1, 1].
    if wav_pr.abs().max() > 1:
        wav_pr /= wav_pr.abs().max()
    # Stereo input: keep the left channel.
    # NOTE(review): make_prompt averages the channels instead — confirm which
    # behaviour is intended.
    if wav_pr.size(-1) == 2:
        wav_pr = wav_pr[:, 0]
    if wav_pr.ndim == 1:
        wav_pr = wav_pr.unsqueeze(0)
    # BUG FIX: was `assert wav_pr.ndim and ...`, which is truthy for any
    # non-scalar tensor; the intent is a single-channel (1, T) tensor.
    assert wav_pr.ndim == 2 and wav_pr.size(0) == 1

    if transcript_content == "":
        # No transcript supplied: transcribe with Whisper (no files kept).
        text_pr, lang_pr = make_prompt(name, wav_pr, sr, save=False)
    else:
        lang_pr = langid.classify(str(transcript_content))[0]
        lang_token = lang2token[lang_pr]
        text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"
    # tokenize audio
    encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
    audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy()

    # tokenize text
    phonemes, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
    text_tokens, enroll_x_lens = text_collater([phonemes])

    message = f"Detected language: {lang_pr}\n Detected text {text_pr}\n"

    # save as npz file
    np.savez(
        os.path.join(tempfile.gettempdir(), f"{name}.npz"),
        audio_tokens=audio_tokens,
        text_tokens=text_tokens,
        lang_code=lang2code[lang_pr],
    )
    return message, os.path.join(tempfile.gettempdir(), f"{name}.npz")
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def make_prompt(name, wav, sr, save=True):
    """Transcribe ``wav`` with Whisper and return language-tagged text.

    Writes ``./prompts/{name}.wav`` and ``.txt``; removes them again when
    ``save`` is False.

    Returns:
        (text, lang) — the transcript wrapped in language tokens, and the
        detected language code.
    """
    global whisper_model
    whisper_model.to(device)
    if not isinstance(wav, torch.FloatTensor):
        wav = torch.tensor(wav)
    # Peak-normalize into [-1, 1].
    if wav.abs().max() > 1:
        wav /= wav.abs().max()
    # Stereo input: downmix by averaging the channels.
    if wav.size(-1) == 2:
        wav = wav.mean(-1, keepdim=False)
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)
    # BUG FIX: was `assert wav.ndim and ...` (always truthy); require (1, T).
    assert wav.ndim == 2 and wav.size(0) == 1
    # BUG FIX: make sure the prompts directory exists before saving into it.
    os.makedirs("./prompts", exist_ok=True)
    torchaudio.save(f"./prompts/{name}.wav", wav, sr)
    lang, text = transcribe_one(whisper_model, f"./prompts/{name}.wav")
    lang_token = lang2token[lang]
    text = lang_token + text + lang_token
    with open(f"./prompts/{name}.txt", "w", encoding="utf-8") as f:
        f.write(text)
    if not save:
        os.remove(f"./prompts/{name}.wav")
        os.remove(f"./prompts/{name}.txt")

    # Move Whisper off the GPU again and release cached CUDA memory
    # (empty_cache is a no-op on CPU-only setups).
    whisper_model.cpu()
    torch.cuda.empty_cache()
    return text, lang
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
@torch.no_grad()
def infer_from_audio(
    text, language, accent, audio_prompt, record_audio_prompt, transcript_content
):
    """Synthesize ``text`` in the voice of the given reference audio.

    Args:
        text: the sentence to synthesize.
        language: UI language choice or "auto-detect".
        accent: UI accent choice or "no-accent".
        audio_prompt / record_audio_prompt: gradio audio tuples ``(sr, wav)``;
            the uploaded one takes precedence.
        transcript_content: transcript of the reference audio; when empty the
            audio is transcribed with Whisper via ``make_prompt``.

    Returns:
        (message, (24000, samples)) — a status string and the synthesized
        24 kHz waveform as a numpy array.
    """
    global model, text_collater, text_tokenizer, audio_tokenizer
    audio_prompt = audio_prompt if audio_prompt is not None else record_audio_prompt
    sr, wav_pr = audio_prompt
    if not isinstance(wav_pr, torch.FloatTensor):
        wav_pr = torch.FloatTensor(wav_pr)
    # Peak-normalize into [-1, 1].
    if wav_pr.abs().max() > 1:
        wav_pr /= wav_pr.abs().max()
    # Stereo input: keep the left channel.
    # NOTE(review): make_prompt averages the channels instead — confirm which
    # behaviour is intended.
    if wav_pr.size(-1) == 2:
        wav_pr = wav_pr[:, 0]
    if wav_pr.ndim == 1:
        wav_pr = wav_pr.unsqueeze(0)
    # NOTE(review): `wav_pr.ndim` is truthy for any non-scalar tensor; this
    # assert probably intends `wav_pr.ndim == 2`.
    assert wav_pr.ndim and wav_pr.size(0) == 1

    if transcript_content == "":
        # No transcript supplied: transcribe the prompt with Whisper.
        text_pr, lang_pr = make_prompt("dummy", wav_pr, sr, save=False)
    else:
        lang_pr = langid.classify(str(transcript_content))[0]
        lang_token = lang2token[lang_pr]
        text_pr = f"{lang_token}{str(transcript_content)}{lang_token}"

    # Resolve the target language and wrap the text in language tokens.
    if language == "auto-detect":
        lang_token = lang2token[langid.classify(text)[0]]
    else:
        lang_token = langdropdown2token[language]
    lang = token2lang[lang_token]
    text = lang_token + text + lang_token

    # onload model
    model.to(device)

    # tokenize audio
    encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
    audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)

    # tokenize text
    logging.info(f"synthesize text: {text}")
    phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
    text_tokens, text_tokens_lens = text_collater([phone_tokens])

    # Prepend the prompt's text tokens so the model conditions on them.
    enroll_x_lens = None
    if text_pr:
        text_prompts, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
        text_prompts, enroll_x_lens = text_collater([text_prompts])
        text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
        text_tokens_lens += enroll_x_lens
    # An explicit accent choice overrides the detected language.
    lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
    encoded_frames = model.inference(
        text_tokens.to(device),
        text_tokens_lens.to(device),
        audio_prompts,
        enroll_x_lens=enroll_x_lens,
        top_k=-100,
        temperature=1,
        prompt_language=lang_pr,
        text_language=langs if accent == "no-accent" else lang,
        best_of=5,
    )
    # Decode with Vocos
    frames = encoded_frames.permute(2, 0, 1)
    features = vocos.codes_to_features(frames)
    samples = vocos.decode(features, bandwidth_id=torch.tensor([2], device=device))

    # offload model
    model.to("cpu")
    torch.cuda.empty_cache()

    # NOTE(review): "sythesized" is a typo, but it is a runtime string shown
    # in the UI — left unchanged here.
    message = f"text prompt: {text_pr}\nsythesized text: {text}"
    return message, (24000, samples.squeeze(0).cpu().numpy())
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def main():
    """Build the gradio Blocks UI for VALL-E X audio cloning.

    Returns:
        The constructed ``gr.Blocks`` app. BUG FIX: the original built the
        UI but neither launched nor returned it, so the caller had no way
        to start it; returning it is backward-compatible.
    """
    app = gr.Blocks(title="VALL-E X")
    with app:
        gr.Markdown(top_ja_md)
        with gr.Tab("Infer from audio"):
            gr.Markdown(infer_from_audio_ja_md)
            with gr.Row():
                with gr.Column():
                    # Input text to synthesize.
                    textbox = gr.TextArea(
                        label="音声合成で喋らせたいテキスト",
                        # placeholder="Type your sentence here",
                        placeholder="ここに音声合成で喋らせたいテキストを入力してください。",
                        value="Welcome back, Master. What can I do for you today?",
                        elem_id="tts-input",
                    )
                    language_dropdown = gr.Dropdown(
                        choices=["auto-detect", "English", "中文", "日本語"],
                        value="auto-detect",
                        label="language",
                    )
                    accent_dropdown = gr.Dropdown(
                        choices=["no-accent", "English", "中文", "日本語"],
                        value="no-accent",
                        label="accent",
                    )
                    # Optional transcript of the reference audio (empty =>
                    # transcribe with Whisper).
                    textbox_transcript = gr.TextArea(
                        label="Transcript",
                        # placeholder="Write transcript here. (leave empty to use whisper)",
                        placeholder="アップロードした音声、または録音した音声のテキストを入力してください。(whisper を使用する場合は空のままにしてください。)",
                        value="",
                        elem_id="prompt-name",
                    )
                    upload_audio_prompt = gr.Audio(
                        label="音声アップロード",
                        sources=["upload"],
                        interactive=True,
                    )
                    record_audio_prompt = gr.Audio(
                        label="音声を録音する",
                        sources=["microphone"],
                        interactive=True,
                    )
                with gr.Column():
                    text_output = gr.Textbox(label="Message")
                    audio_output = gr.Audio(label="Output Audio", elem_id="tts-audio")
                    btn = gr.Button("音声合成を開始する")
                    btn.click(
                        infer_from_audio,
                        inputs=[
                            textbox,
                            language_dropdown,
                            accent_dropdown,
                            upload_audio_prompt,
                            record_audio_prompt,
                            textbox_transcript,
                        ],
                        outputs=[text_output, audio_output],
                    )
                    textbox_mp = gr.TextArea(
                        label="Prompt name",
                        placeholder="Name your prompt here",
                        value="prompt_1",
                        elem_id="prompt-name",
                    )
                    btn_mp = gr.Button("Make prompt!")
                    prompt_output = gr.File(interactive=False)
                    btn_mp.click(
                        make_npz_prompt,
                        inputs=[
                            textbox_mp,
                            upload_audio_prompt,
                            record_audio_prompt,
                            textbox_transcript,
                        ],
                        outputs=[text_output, prompt_output],
                    )
            gr.Examples(
                examples=infer_from_audio_examples,
                inputs=[
                    textbox,
                    language_dropdown,
                    accent_dropdown,
                    upload_audio_prompt,
                    record_audio_prompt,
                    textbox_transcript,
                ],
                outputs=[text_output, audio_output],
                fn=infer_from_audio,
                cache_examples=False,
            )
    return app
|
apps/audio_cloning/vallex/models/__init__.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
from .transformer import Transformer
|
| 6 |
+
from .vallex import VALLE, VALLF
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _str2bool(value):
    """Parse a CLI boolean: accepts true/false-style strings (case-insensitive)."""
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def add_model_arguments(parser: argparse.ArgumentParser):
    """Register all model-selection/architecture CLI arguments on ``parser``.

    BUG FIX: boolean flags previously used ``type=bool``, which treats ANY
    non-empty string (including "False") as True; they now use a proper
    string-to-bool parser while keeping the same flag names and defaults.
    """
    parser.add_argument(
        "--model-name",
        type=str,
        default="VALL-E",
        help="VALL-E, VALL-F, Transformer.",
    )
    parser.add_argument(
        "--decoder-dim",
        type=int,
        default=1024,
        help="Embedding dimension in the decoder model.",
    )
    parser.add_argument(
        "--nhead",
        type=int,
        default=16,
        help="Number of attention heads in the Decoder layers.",
    )
    parser.add_argument(
        "--num-decoder-layers",
        type=int,
        default=12,
        help="Number of Decoder layers.",
    )
    parser.add_argument(
        "--scale-factor",
        type=float,
        default=1.0,
        help="Model scale factor which will be assigned different meanings in different models.",
    )
    parser.add_argument(
        "--norm-first",
        type=_str2bool,
        default=True,
        help="Pre or Post Normalization.",
    )
    parser.add_argument(
        "--add-prenet",
        type=_str2bool,
        default=False,
        help="Whether add PreNet after Inputs.",
    )

    # VALL-E & F
    parser.add_argument(
        "--prefix-mode",
        type=int,
        default=1,
        help="The mode for how to prefix VALL-E NAR Decoder, "
        "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.",
    )
    parser.add_argument(
        "--share-embedding",
        type=_str2bool,
        default=True,
        help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.",
    )
    parser.add_argument(
        "--prepend-bos",
        type=_str2bool,
        default=False,
        help="Whether prepend <BOS> to the acoustic tokens -> AR Decoder inputs.",
    )
    parser.add_argument(
        "--num-quantizers",
        type=int,
        default=8,
        help="Number of Audio/Semantic quantization layers.",
    )

    # Transformer
    parser.add_argument(
        "--scaling-xformers",
        type=_str2bool,
        default=False,
        help="Apply Reworked Conformer scaling on Transformers.",
    )
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def get_model(params) -> nn.Module:
    """Instantiate the model selected by ``params.model_name``.

    Recognized names (case-insensitive): VALL-F, VALL-E; anything else must
    be exactly "Transformer".
    """
    requested = params.model_name.lower()

    if requested in ("vall-f", "vallf"):
        variant = VALLF
    elif requested in ("vall-e", "valle"):
        variant = VALLE
    else:
        assert params.model_name in ["Transformer"]
        return Transformer(
            params.decoder_dim,
            params.nhead,
            params.num_decoder_layers,
            norm_first=params.norm_first,
            add_prenet=params.add_prenet,
            scaling_xformers=params.scaling_xformers,
        )

    # VALL-F and VALL-E share the same constructor signature.
    return variant(
        params.decoder_dim,
        params.nhead,
        params.num_decoder_layers,
        norm_first=params.norm_first,
        add_prenet=params.add_prenet,
        prefix_mode=params.prefix_mode,
        share_embedding=params.share_embedding,
        nar_scale_factor=params.scale_factor,
        prepend_bos=params.prepend_bos,
        num_quantizers=params.num_quantizers,
    )
|
apps/audio_cloning/vallex/models/macros.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model-wide constants shared by the model definitions in this package.

# Text
NUM_TEXT_TOKENS = 2048  # size of the text/phoneme embedding table

# Audio
NUM_AUDIO_TOKENS = 1024  # EnCodec RVQ bins
NUM_MEL_BINS = 100  # BigVGAN bigvgan_24khz_100band


# Speaker
NUM_SPEAKER_CLASSES = 4096
SPEAKER_EMBEDDING_DIM = 64
|
apps/audio_cloning/vallex/models/transformer.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 (authors: Feiteng Li)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from functools import partial
|
| 16 |
+
from typing import Any, Dict, List, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
from ..modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
| 23 |
+
from ..modules.scaling import BalancedDoubleSwish, ScaledLinear
|
| 24 |
+
from ..modules.transformer import (
|
| 25 |
+
BalancedBasicNorm,
|
| 26 |
+
IdentityNorm,
|
| 27 |
+
TransformerDecoderLayer,
|
| 28 |
+
TransformerEncoder,
|
| 29 |
+
TransformerEncoderLayer,
|
| 30 |
+
)
|
| 31 |
+
from .macros import NUM_MEL_BINS, NUM_TEXT_TOKENS
|
| 32 |
+
|
| 33 |
+
# from icefall.utils import make_pad_mask
|
| 34 |
+
# from torchmetrics.classification import BinaryAccuracy
|
| 35 |
+
from .vallex import Transpose
|
| 36 |
+
from .visualizer import visualize
|
| 37 |
+
|
| 38 |
+
# Self-assignment after the relative import above — presumably to mark the
# import as intentionally re-exported / used so linters don't strip it.
IdentityNorm = IdentityNorm
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Transformer(nn.Module):
|
| 42 |
+
"""It implements seq2seq Transformer TTS for debug(No StopPredictor and SpeakerEmbeding)
|
| 43 |
+
Neural Speech Synthesis with Transformer Network
|
| 44 |
+
https://arxiv.org/abs/1809.08895
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        norm_first: bool = True,
        add_prenet: bool = False,
        scaling_xformers: bool = False,
    ):
        """Build the seq2seq Transformer TTS model.

        Args:
          d_model:
            The number of expected features in the input (required).
          nhead:
            The number of heads in the multiheadattention models (required).
          num_layers:
            The number of sub-decoder-layers in the decoder (required).
          norm_first:
            Pre-normalization (True) vs post-normalization (False).
          add_prenet:
            Add convolutional/linear prenets after the inputs.
          scaling_xformers:
            Use the scaled (reworked-Conformer) transformer variants.
        """
        super().__init__()
        self.text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS)  # W_x

        if add_prenet:
            # Conv prenet over text embeddings (operates channel-first, hence
            # the Transpose() bookends).
            self.encoder_prenet = nn.Sequential(
                Transpose(),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
                nn.BatchNorm1d(d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                Transpose(),
                nn.Linear(d_model, d_model),
            )

            # MLP prenet projecting mel frames into the model dimension.
            self.decoder_prenet = nn.Sequential(
                nn.Linear(NUM_MEL_BINS, 256),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(256, d_model),
            )

            assert scaling_xformers is False  # TODO: update this block
        else:
            self.encoder_prenet = nn.Identity()
            if scaling_xformers:
                self.decoder_prenet = ScaledLinear(NUM_MEL_BINS, d_model)
            else:
                self.decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model)

        self.encoder_position = SinePositionalEmbedding(
            d_model,
            dropout=0.1,
            scale=False,
        )
        self.decoder_position = SinePositionalEmbedding(
            d_model, dropout=0.1, scale=False
        )

        if scaling_xformers:
            # Scaled variants: custom encoder/decoder layers with ScaledLinear
            # projections and BalancedDoubleSwish activations.
            self.encoder = TransformerEncoder(
                TransformerEncoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                    linear1_self_attention_cls=ScaledLinear,
                    linear2_self_attention_cls=partial(
                        ScaledLinear, initial_scale=0.01
                    ),
                    linear1_feedforward_cls=ScaledLinear,
                    linear2_feedforward_cls=partial(ScaledLinear, initial_scale=0.01),
                    activation=partial(
                        BalancedDoubleSwish,
                        channel_dim=-1,
                        max_abs=10.0,
                        min_prob=0.25,
                    ),
                    layer_norm_cls=IdentityNorm,
                ),
                num_layers=num_layers,
                norm=BalancedBasicNorm(d_model) if norm_first else None,
            )

            self.decoder = nn.TransformerDecoder(
                TransformerDecoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                    linear1_self_attention_cls=ScaledLinear,
                    linear2_self_attention_cls=partial(
                        ScaledLinear, initial_scale=0.01
                    ),
                    linear1_feedforward_cls=ScaledLinear,
                    linear2_feedforward_cls=partial(ScaledLinear, initial_scale=0.01),
                    activation=partial(
                        BalancedDoubleSwish,
                        channel_dim=-1,
                        max_abs=10.0,
                        min_prob=0.25,
                    ),
                    layer_norm_cls=IdentityNorm,
                ),
                num_layers=num_layers,
                norm=BalancedBasicNorm(d_model) if norm_first else None,
            )

            self.predict_layer = ScaledLinear(d_model, NUM_MEL_BINS)
            self.stop_layer = nn.Linear(d_model, 1)
        else:
            # Plain PyTorch encoder/decoder stacks with ReLU feed-forwards.
            self.encoder = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    activation=F.relu,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                ),
                num_layers=num_layers,
                norm=nn.LayerNorm(d_model) if norm_first else None,
            )

            self.decoder = nn.TransformerDecoder(
                nn.TransformerDecoderLayer(
                    d_model,
                    nhead,
                    dim_feedforward=d_model * 4,
                    activation=F.relu,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                ),
                num_layers=num_layers,
                norm=nn.LayerNorm(d_model) if norm_first else None,
            )

            self.predict_layer = nn.Linear(d_model, NUM_MEL_BINS)
            self.stop_layer = nn.Linear(d_model, 1)

        # NOTE(review): BinaryAccuracy's import is commented out at the top of
        # this file, so constructing Transformer raises NameError unless
        # torchmetrics is imported elsewhere — confirm and restore the import.
        self.stop_accuracy_metric = BinaryAccuracy(
            threshold=0.5, multidim_average="global"
        )
|
| 204 |
+
|
| 205 |
+
# self.apply(self._init_weights)
|
| 206 |
+
|
| 207 |
+
# def _init_weights(self, module):
|
| 208 |
+
# if isinstance(module, (nn.Linear)):
|
| 209 |
+
# module.weight.data.normal_(mean=0.0, std=0.02)
|
| 210 |
+
# if isinstance(module, nn.Linear) and module.bias is not None:
|
| 211 |
+
# module.bias.data.zero_()
|
| 212 |
+
# elif isinstance(module, nn.LayerNorm):
|
| 213 |
+
# module.bias.data.zero_()
|
| 214 |
+
# module.weight.data.fill_(1.0)
|
| 215 |
+
# elif isinstance(module, nn.Embedding):
|
| 216 |
+
# module.weight.data.normal_(mean=0.0, std=0.02)
|
| 217 |
+
|
| 218 |
+
def forward(
|
| 219 |
+
self,
|
| 220 |
+
x: torch.Tensor,
|
| 221 |
+
x_lens: torch.Tensor,
|
| 222 |
+
y: torch.Tensor,
|
| 223 |
+
y_lens: torch.Tensor,
|
| 224 |
+
reduction: str = "sum",
|
| 225 |
+
train_stage: int = 0,
|
| 226 |
+
**kwargs,
|
| 227 |
+
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
|
| 228 |
+
"""
|
| 229 |
+
Args:
|
| 230 |
+
x:
|
| 231 |
+
A 2-D tensor of shape (N, S).
|
| 232 |
+
x_lens:
|
| 233 |
+
A 1-D tensor of shape (N,). It contains the number of tokens in `x`
|
| 234 |
+
before padding.
|
| 235 |
+
y:
|
| 236 |
+
A 3-D tensor of shape (N, T, 8).
|
| 237 |
+
y_lens:
|
| 238 |
+
A 1-D tensor of shape (N,). It contains the number of tokens in `x`
|
| 239 |
+
before padding.
|
| 240 |
+
train_stage:
|
| 241 |
+
Not used in this model.
|
| 242 |
+
Returns:
|
| 243 |
+
Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy.
|
| 244 |
+
"""
|
| 245 |
+
del train_stage
|
| 246 |
+
|
| 247 |
+
assert x.ndim == 2, x.shape
|
| 248 |
+
assert x_lens.ndim == 1, x_lens.shape
|
| 249 |
+
assert y.ndim == 3, y.shape
|
| 250 |
+
assert y_lens.ndim == 1, y_lens.shape
|
| 251 |
+
|
| 252 |
+
assert torch.all(x_lens > 0)
|
| 253 |
+
|
| 254 |
+
# NOTE: x has been padded in TextTokenCollater
|
| 255 |
+
x_mask = make_pad_mask(x_lens).to(x.device)
|
| 256 |
+
|
| 257 |
+
x = self.text_embedding(x)
|
| 258 |
+
x = self.encoder_prenet(x)
|
| 259 |
+
x = self.encoder_position(x)
|
| 260 |
+
x = self.encoder(x, src_key_padding_mask=x_mask)
|
| 261 |
+
|
| 262 |
+
total_loss, metrics = 0.0, {}
|
| 263 |
+
|
| 264 |
+
y_mask = make_pad_mask(y_lens).to(y.device)
|
| 265 |
+
y_mask_float = y_mask.type(torch.float32)
|
| 266 |
+
data_mask = 1.0 - y_mask_float.unsqueeze(-1)
|
| 267 |
+
|
| 268 |
+
# Training
|
| 269 |
+
# AR Decoder
|
| 270 |
+
def pad_y(y):
|
| 271 |
+
y = F.pad(y, (0, 0, 1, 0, 0, 0), value=0).detach()
|
| 272 |
+
# inputs, targets
|
| 273 |
+
return y[:, :-1], y[:, 1:]
|
| 274 |
+
|
| 275 |
+
y, targets = pad_y(y * data_mask) # mask padding as zeros
|
| 276 |
+
|
| 277 |
+
y_emb = self.decoder_prenet(y)
|
| 278 |
+
y_pos = self.decoder_position(y_emb)
|
| 279 |
+
|
| 280 |
+
y_len = y_lens.max()
|
| 281 |
+
tgt_mask = torch.triu(
|
| 282 |
+
torch.ones(y_len, y_len, device=y.device, dtype=torch.bool),
|
| 283 |
+
diagonal=1,
|
| 284 |
+
)
|
| 285 |
+
y_dec = self.decoder(
|
| 286 |
+
y_pos,
|
| 287 |
+
x,
|
| 288 |
+
tgt_mask=tgt_mask,
|
| 289 |
+
memory_key_padding_mask=x_mask,
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
predict = self.predict_layer(y_dec)
|
| 293 |
+
# loss
|
| 294 |
+
total_loss = F.mse_loss(predict, targets, reduction=reduction)
|
| 295 |
+
|
| 296 |
+
logits = self.stop_layer(y_dec).squeeze(-1)
|
| 297 |
+
stop_loss = F.binary_cross_entropy_with_logits(
|
| 298 |
+
logits,
|
| 299 |
+
y_mask_float.detach(),
|
| 300 |
+
weight=1.0 + y_mask_float.detach() * 4.0,
|
| 301 |
+
reduction=reduction,
|
| 302 |
+
)
|
| 303 |
+
metrics["stop_loss"] = stop_loss.detach()
|
| 304 |
+
|
| 305 |
+
stop_accuracy = self.stop_accuracy_metric(
|
| 306 |
+
(torch.sigmoid(logits) >= 0.5).type(torch.int64),
|
| 307 |
+
y_mask.type(torch.int64),
|
| 308 |
+
)
|
| 309 |
+
# icefall MetricsTracker.norm_items()
|
| 310 |
+
metrics["stop_accuracy"] = stop_accuracy.item() * y_lens.sum().type(
|
| 311 |
+
torch.float32
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
return ((x, predict), total_loss + 100.0 * stop_loss, metrics)
|
| 315 |
+
|
| 316 |
+
def inference(
|
| 317 |
+
self,
|
| 318 |
+
x: torch.Tensor,
|
| 319 |
+
x_lens: torch.Tensor,
|
| 320 |
+
y: Any = None,
|
| 321 |
+
**kwargs,
|
| 322 |
+
) -> torch.Tensor:
|
| 323 |
+
"""
|
| 324 |
+
Args:
|
| 325 |
+
x:
|
| 326 |
+
A 2-D tensor of shape (1, S).
|
| 327 |
+
x_lens:
|
| 328 |
+
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
| 329 |
+
before padding.
|
| 330 |
+
Returns:
|
| 331 |
+
Return the predicted audio code matrix and cross-entropy loss.
|
| 332 |
+
"""
|
| 333 |
+
assert x.ndim == 2, x.shape
|
| 334 |
+
assert x_lens.ndim == 1, x_lens.shape
|
| 335 |
+
|
| 336 |
+
assert torch.all(x_lens > 0)
|
| 337 |
+
|
| 338 |
+
x_mask = make_pad_mask(x_lens).to(x.device)
|
| 339 |
+
|
| 340 |
+
x = self.text_embedding(x)
|
| 341 |
+
x = self.encoder_prenet(x)
|
| 342 |
+
x = self.encoder_position(x)
|
| 343 |
+
x = self.encoder(x, src_key_padding_mask=x_mask)
|
| 344 |
+
|
| 345 |
+
x_mask = make_pad_mask(x_lens).to(x.device)
|
| 346 |
+
|
| 347 |
+
# AR Decoder
|
| 348 |
+
# TODO: Managing decoder steps avoid repetitive computation
|
| 349 |
+
y = torch.zeros(
|
| 350 |
+
[x.shape[0], 1, NUM_MEL_BINS], dtype=torch.float32, device=x.device
|
| 351 |
+
)
|
| 352 |
+
while True:
|
| 353 |
+
y_emb = self.decoder_prenet(y)
|
| 354 |
+
y_pos = self.decoder_position(y_emb)
|
| 355 |
+
|
| 356 |
+
tgt_mask = torch.triu(
|
| 357 |
+
torch.ones(y.shape[1], y.shape[1], device=y.device, dtype=torch.bool),
|
| 358 |
+
diagonal=1,
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
y_dec = self.decoder(
|
| 362 |
+
y_pos,
|
| 363 |
+
x,
|
| 364 |
+
tgt_mask=tgt_mask,
|
| 365 |
+
memory_mask=None,
|
| 366 |
+
memory_key_padding_mask=x_mask,
|
| 367 |
+
)
|
| 368 |
+
predict = self.predict_layer(y_dec[:, -1:])
|
| 369 |
+
|
| 370 |
+
logits = self.stop_layer(y_dec[:, -1:]) > 0 # sigmoid(0.0) = 0.5
|
| 371 |
+
if y.shape[1] > x_lens.max() * 10 or all(logits.cpu().numpy()):
|
| 372 |
+
print(f"TransformerTTS EOS [Text {x_lens[0]} -> Audio {y.shape[1]}]")
|
| 373 |
+
break
|
| 374 |
+
|
| 375 |
+
y = torch.concat([y, predict], dim=1)
|
| 376 |
+
|
| 377 |
+
return y[:, 1:]
|
| 378 |
+
|
| 379 |
+
def visualize(
|
| 380 |
+
self,
|
| 381 |
+
predicts: Tuple[torch.Tensor],
|
| 382 |
+
batch: Dict[str, Union[List, torch.Tensor]],
|
| 383 |
+
output_dir: str,
|
| 384 |
+
limit: int = 4,
|
| 385 |
+
) -> None:
|
| 386 |
+
visualize(predicts, batch, output_dir, limit=limit)
|
apps/audio_cloning/vallex/models/vallex.py
ADDED
|
@@ -0,0 +1,823 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 (authors: Feiteng Li)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import random
|
| 16 |
+
from typing import Dict, Iterator, List, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
|
| 23 |
+
# from icefall.utils import make_pad_mask
|
| 24 |
+
# from torchmetrics.classification import MulticlassAccuracy
|
| 25 |
+
from ..data.input_strategies import PromptedFeatures
|
| 26 |
+
from ..modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
| 27 |
+
from ..modules.transformer import (
|
| 28 |
+
AdaptiveLayerNorm,
|
| 29 |
+
LayerNorm,
|
| 30 |
+
TransformerDecoderLayer,
|
| 31 |
+
TransformerEncoder,
|
| 32 |
+
TransformerEncoderLayer,
|
| 33 |
+
)
|
| 34 |
+
from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Transpose(nn.Identity):
|
| 38 |
+
"""(N, T, D) -> (N, D, T)"""
|
| 39 |
+
|
| 40 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
| 41 |
+
return input.transpose(1, 2)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# NOTE: There are two ways to implement the model
|
| 45 |
+
# 1) [VALL-F] standard TransformerDecoder, use x as memory
|
| 46 |
+
# 2) [VALL-E] modified TransformerDecoder like GPT-x(e.g. causal TransformerEncoder),
|
| 47 |
+
# use x as the prefix of decoder inputs
|
| 48 |
+
class VALLF(nn.Module):
|
| 49 |
+
"""It implements https://arxiv.org/abs/2301.02111
|
| 50 |
+
"Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
def __init__(
|
| 54 |
+
self,
|
| 55 |
+
d_model: int,
|
| 56 |
+
nhead: int,
|
| 57 |
+
num_layers: int,
|
| 58 |
+
norm_first: bool = True,
|
| 59 |
+
add_prenet: bool = False,
|
| 60 |
+
decoder_cls: Union[
|
| 61 |
+
nn.TransformerDecoder, nn.TransformerEncoder
|
| 62 |
+
] = nn.TransformerDecoder,
|
| 63 |
+
decoder_layer_cls: Union[
|
| 64 |
+
TransformerDecoderLayer, TransformerEncoderLayer
|
| 65 |
+
] = TransformerDecoderLayer,
|
| 66 |
+
prefix_mode: int = 0,
|
| 67 |
+
share_embedding: bool = True,
|
| 68 |
+
nar_scale_factor: float = 1.0,
|
| 69 |
+
prepend_bos: bool = True,
|
| 70 |
+
num_quantizers: int = 8,
|
| 71 |
+
):
|
| 72 |
+
"""
|
| 73 |
+
Args:
|
| 74 |
+
d_model:
|
| 75 |
+
The number of expected features in the input (required).
|
| 76 |
+
nhead:
|
| 77 |
+
The number of heads in the multiheadattention models (required).
|
| 78 |
+
num_layers:
|
| 79 |
+
The number of sub-decoder-layers in the decoder (required).
|
| 80 |
+
"""
|
| 81 |
+
super().__init__()
|
| 82 |
+
nar_d_model = int(d_model * nar_scale_factor)
|
| 83 |
+
|
| 84 |
+
self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x
|
| 85 |
+
self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS)
|
| 86 |
+
|
| 87 |
+
# ID NUM_AUDIO_TOKENS -> PAD
|
| 88 |
+
# ID NUM_AUDIO_TOKENS + 1 -> BOS
|
| 89 |
+
self.ar_audio_prepend_bos = prepend_bos
|
| 90 |
+
self.ar_audio_embedding = TokenEmbedding(
|
| 91 |
+
d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos)
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# PreNet
|
| 95 |
+
if add_prenet:
|
| 96 |
+
self.ar_text_prenet = nn.Sequential(
|
| 97 |
+
Transpose(),
|
| 98 |
+
nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
|
| 99 |
+
nn.BatchNorm1d(d_model),
|
| 100 |
+
nn.ReLU(),
|
| 101 |
+
nn.Dropout(0.5),
|
| 102 |
+
nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
|
| 103 |
+
nn.BatchNorm1d(d_model),
|
| 104 |
+
nn.ReLU(),
|
| 105 |
+
nn.Dropout(0.5),
|
| 106 |
+
nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
|
| 107 |
+
nn.BatchNorm1d(d_model),
|
| 108 |
+
nn.ReLU(),
|
| 109 |
+
nn.Dropout(0.5),
|
| 110 |
+
Transpose(),
|
| 111 |
+
nn.Linear(d_model, d_model),
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
self.ar_audio_prenet = nn.Sequential(
|
| 115 |
+
nn.Linear(d_model, 256),
|
| 116 |
+
nn.ReLU(),
|
| 117 |
+
nn.Dropout(0.25),
|
| 118 |
+
nn.Linear(256, 256),
|
| 119 |
+
nn.ReLU(),
|
| 120 |
+
nn.Dropout(0.25),
|
| 121 |
+
nn.Linear(256, d_model),
|
| 122 |
+
)
|
| 123 |
+
else:
|
| 124 |
+
self.ar_text_prenet = nn.Identity()
|
| 125 |
+
self.ar_audio_prenet = nn.Identity()
|
| 126 |
+
|
| 127 |
+
self.ar_text_position = SinePositionalEmbedding(
|
| 128 |
+
d_model,
|
| 129 |
+
dropout=0.1,
|
| 130 |
+
scale=False,
|
| 131 |
+
alpha=True,
|
| 132 |
+
)
|
| 133 |
+
self.ar_audio_position = SinePositionalEmbedding(
|
| 134 |
+
d_model,
|
| 135 |
+
dropout=0.1,
|
| 136 |
+
scale=False,
|
| 137 |
+
alpha=True,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
self.ar_decoder = decoder_cls(
|
| 141 |
+
decoder_layer_cls(
|
| 142 |
+
d_model,
|
| 143 |
+
nhead,
|
| 144 |
+
dim_feedforward=d_model * 4,
|
| 145 |
+
dropout=0.1,
|
| 146 |
+
batch_first=True,
|
| 147 |
+
norm_first=norm_first,
|
| 148 |
+
),
|
| 149 |
+
num_layers=num_layers,
|
| 150 |
+
norm=LayerNorm(d_model) if norm_first else None,
|
| 151 |
+
)
|
| 152 |
+
self.ar_predict_layer = nn.Linear(d_model, NUM_AUDIO_TOKENS + 1, bias=False)
|
| 153 |
+
|
| 154 |
+
self.rng = random.Random(0)
|
| 155 |
+
self.num_heads = nhead
|
| 156 |
+
self.prefix_mode = prefix_mode
|
| 157 |
+
self.num_quantizers = num_quantizers
|
| 158 |
+
|
| 159 |
+
assert num_quantizers >= 1
|
| 160 |
+
if num_quantizers > 1:
|
| 161 |
+
self.nar_audio_embeddings = nn.ModuleList(
|
| 162 |
+
[TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)]
|
| 163 |
+
+ [
|
| 164 |
+
TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS)
|
| 165 |
+
for i in range(num_quantizers - 1)
|
| 166 |
+
]
|
| 167 |
+
) # W_a
|
| 168 |
+
|
| 169 |
+
# PreNet
|
| 170 |
+
if add_prenet:
|
| 171 |
+
self.nar_text_prenet = nn.Sequential(
|
| 172 |
+
Transpose(),
|
| 173 |
+
nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
|
| 174 |
+
nn.BatchNorm1d(nar_d_model),
|
| 175 |
+
nn.ReLU(),
|
| 176 |
+
nn.Dropout(0.5),
|
| 177 |
+
nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
|
| 178 |
+
nn.BatchNorm1d(nar_d_model),
|
| 179 |
+
nn.ReLU(),
|
| 180 |
+
nn.Dropout(0.5),
|
| 181 |
+
nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"),
|
| 182 |
+
nn.BatchNorm1d(nar_d_model),
|
| 183 |
+
nn.ReLU(),
|
| 184 |
+
nn.Dropout(0.5),
|
| 185 |
+
Transpose(),
|
| 186 |
+
nn.Linear(nar_d_model, nar_d_model),
|
| 187 |
+
)
|
| 188 |
+
self.nar_audio_prenet = nn.Sequential(
|
| 189 |
+
nn.Linear(nar_d_model, 256),
|
| 190 |
+
nn.ReLU(),
|
| 191 |
+
nn.Dropout(0.25),
|
| 192 |
+
nn.Linear(256, 256),
|
| 193 |
+
nn.ReLU(),
|
| 194 |
+
nn.Dropout(0.25),
|
| 195 |
+
nn.Linear(256, nar_d_model),
|
| 196 |
+
)
|
| 197 |
+
else:
|
| 198 |
+
self.nar_text_prenet = nn.Identity()
|
| 199 |
+
self.nar_audio_prenet = nn.Identity()
|
| 200 |
+
|
| 201 |
+
self.nar_text_position = SinePositionalEmbedding(
|
| 202 |
+
nar_d_model,
|
| 203 |
+
dropout=0.0,
|
| 204 |
+
scale=False,
|
| 205 |
+
alpha=False,
|
| 206 |
+
)
|
| 207 |
+
self.nar_audio_position = SinePositionalEmbedding(
|
| 208 |
+
nar_d_model,
|
| 209 |
+
dropout=0.1,
|
| 210 |
+
scale=False,
|
| 211 |
+
alpha=False,
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
self.nar_decoder = decoder_cls(
|
| 215 |
+
decoder_layer_cls(
|
| 216 |
+
nar_d_model,
|
| 217 |
+
int(nhead * nar_scale_factor),
|
| 218 |
+
dim_feedforward=nar_d_model * 4,
|
| 219 |
+
dropout=0.1,
|
| 220 |
+
batch_first=True,
|
| 221 |
+
norm_first=norm_first,
|
| 222 |
+
adaptive_layer_norm=True,
|
| 223 |
+
),
|
| 224 |
+
num_layers=int(num_layers * nar_scale_factor),
|
| 225 |
+
norm=AdaptiveLayerNorm(nar_d_model, norm=nn.LayerNorm(nar_d_model))
|
| 226 |
+
if norm_first
|
| 227 |
+
else None,
|
| 228 |
+
)
|
| 229 |
+
self.nar_predict_layers = nn.ModuleList(
|
| 230 |
+
[
|
| 231 |
+
nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False)
|
| 232 |
+
for i in range(num_quantizers - 1)
|
| 233 |
+
]
|
| 234 |
+
)
|
| 235 |
+
self.nar_stage_embeddings = nn.ModuleList(
|
| 236 |
+
[TokenEmbedding(nar_d_model, 1) for i in range(num_quantizers - 1)]
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
if share_embedding:
|
| 240 |
+
# We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa
|
| 241 |
+
# NOTE(Feiteng): In the experiment, this undermines accuracy
|
| 242 |
+
# self.ar_predict_layer.weight = self.ar_audio_embedding.weight
|
| 243 |
+
|
| 244 |
+
# We also share the parameters of the acoustic embedding layer and the output prediction layer,
|
| 245 |
+
# which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer.
|
| 246 |
+
for j in range(0, num_quantizers - 2):
|
| 247 |
+
self.nar_predict_layers[j].weight = self.nar_audio_embeddings[
|
| 248 |
+
j + 2
|
| 249 |
+
].weight
|
| 250 |
+
|
| 251 |
+
def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
|
| 252 |
+
assert stage > 0
|
| 253 |
+
if stage == 1:
|
| 254 |
+
for name, param in self.named_parameters():
|
| 255 |
+
if name.startswith("ar_"):
|
| 256 |
+
print(f" AR parameter: {name}")
|
| 257 |
+
yield param
|
| 258 |
+
|
| 259 |
+
if stage == 2:
|
| 260 |
+
for name, param in self.named_parameters():
|
| 261 |
+
if name.startswith("nar_"):
|
| 262 |
+
print(f"NAR parameter: {name}")
|
| 263 |
+
yield param
|
| 264 |
+
|
| 265 |
+
def stage_named_parameters(
|
| 266 |
+
self, stage: int = 1
|
| 267 |
+
) -> Iterator[Tuple[str, nn.Parameter]]:
|
| 268 |
+
assert stage > 0
|
| 269 |
+
if stage == 1:
|
| 270 |
+
for pair in self.named_parameters():
|
| 271 |
+
if pair[0].startswith("ar_"):
|
| 272 |
+
yield pair
|
| 273 |
+
|
| 274 |
+
if stage == 2:
|
| 275 |
+
for pair in self.named_parameters():
|
| 276 |
+
if pair[0].startswith("nar_"):
|
| 277 |
+
yield pair
|
| 278 |
+
|
| 279 |
+
def pad_y_eos(self, y, y_mask_int, eos_id):
|
| 280 |
+
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
|
| 281 |
+
y_mask_int, (0, 1), value=1
|
| 282 |
+
)
|
| 283 |
+
# inputs, targets
|
| 284 |
+
if self.ar_audio_prepend_bos:
|
| 285 |
+
return (
|
| 286 |
+
F.pad(targets[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1),
|
| 287 |
+
targets,
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
return targets[:, :-1], targets[:, 1:]
|
| 291 |
+
|
| 292 |
+
def _prepare_prompts(
|
| 293 |
+
self, y, y_lens, codes, nar_stage, y_prompts_codes, prefix_mode
|
| 294 |
+
):
|
| 295 |
+
# 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds
|
| 296 |
+
# from the same utterance.
|
| 297 |
+
# We implement this differently.
|
| 298 |
+
if prefix_mode == 0:
|
| 299 |
+
# no prefix
|
| 300 |
+
prefix_len = 0
|
| 301 |
+
y_emb = self.nar_audio_embeddings[0](y)
|
| 302 |
+
for j in range(1, nar_stage):
|
| 303 |
+
# Formula (4) (5)
|
| 304 |
+
y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
|
| 305 |
+
elif prefix_mode == 1:
|
| 306 |
+
# prefix at begining
|
| 307 |
+
int_low = (0.25 * y_lens.min()).type(torch.int64).item()
|
| 308 |
+
prefix_len = torch.randint(0, int_low * 2, size=()).item()
|
| 309 |
+
prefix_len = min(prefix_len, 225) # 24000/320 * 3s = 225 frames
|
| 310 |
+
|
| 311 |
+
y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
|
| 312 |
+
y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
|
| 313 |
+
for j in range(1, self.num_quantizers):
|
| 314 |
+
y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j])
|
| 315 |
+
if j < nar_stage:
|
| 316 |
+
y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j])
|
| 317 |
+
y_emb = torch.concat([y_prompts, y_emb], axis=1)
|
| 318 |
+
elif prefix_mode in [2, 4]:
|
| 319 |
+
if prefix_mode == 2:
|
| 320 |
+
# random prefix
|
| 321 |
+
prefix_len = min(225, int(0.25 * y_lens.min().item()))
|
| 322 |
+
|
| 323 |
+
y_prompts_codes = []
|
| 324 |
+
for b in range(codes.shape[0]):
|
| 325 |
+
start = self.rng.randint(0, y_lens[b].item() - prefix_len)
|
| 326 |
+
y_prompts_codes.append(
|
| 327 |
+
torch.clone(codes[b, start : start + prefix_len])
|
| 328 |
+
)
|
| 329 |
+
codes[b, start : start + prefix_len, nar_stage] = NUM_AUDIO_TOKENS
|
| 330 |
+
y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
|
| 331 |
+
else:
|
| 332 |
+
prefix_len = y_prompts_codes.shape[1]
|
| 333 |
+
|
| 334 |
+
y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
|
| 335 |
+
y_emb = self.nar_audio_embeddings[0](y)
|
| 336 |
+
for j in range(1, self.num_quantizers):
|
| 337 |
+
y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j])
|
| 338 |
+
if j < nar_stage:
|
| 339 |
+
y_emb += self.nar_audio_embeddings[j](codes[..., j])
|
| 340 |
+
y_emb = torch.concat([y_prompts, y_emb], axis=1)
|
| 341 |
+
else:
|
| 342 |
+
raise ValueError
|
| 343 |
+
|
| 344 |
+
return y_emb, prefix_len
|
| 345 |
+
|
| 346 |
+
def forward(
|
| 347 |
+
self,
|
| 348 |
+
x: torch.Tensor,
|
| 349 |
+
x_lens: torch.Tensor,
|
| 350 |
+
y: Union[torch.Tensor, PromptedFeatures],
|
| 351 |
+
y_lens: Union[torch.Tensor, PromptedFeatures],
|
| 352 |
+
reduction: str = "sum",
|
| 353 |
+
train_stage: int = 0,
|
| 354 |
+
**kwargs,
|
| 355 |
+
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
|
| 356 |
+
raise NotImplementedError
|
| 357 |
+
|
| 358 |
+
def inference(
|
| 359 |
+
self,
|
| 360 |
+
x: torch.Tensor,
|
| 361 |
+
x_lens: torch.Tensor,
|
| 362 |
+
y: torch.Tensor,
|
| 363 |
+
enroll_x_lens: Union[torch.Tensor, None] = None,
|
| 364 |
+
top_k: int = -100,
|
| 365 |
+
temperature: float = 1.0,
|
| 366 |
+
) -> torch.Tensor:
|
| 367 |
+
raise NotImplementedError
|
| 368 |
+
|
| 369 |
+
def visualize(
|
| 370 |
+
self,
|
| 371 |
+
predicts: Tuple[torch.Tensor],
|
| 372 |
+
batch: Dict[str, Union[List, torch.Tensor]],
|
| 373 |
+
output_dir: str,
|
| 374 |
+
limit: int = 4,
|
| 375 |
+
) -> None:
|
| 376 |
+
raise NotImplementedError
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
class VALLE(VALLF):
|
| 380 |
+
"""It implements https://arxiv.org/abs/2301.02111
|
| 381 |
+
"Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"
|
| 382 |
+
"""
|
| 383 |
+
|
| 384 |
+
def __init__(
|
| 385 |
+
self,
|
| 386 |
+
d_model: int,
|
| 387 |
+
nhead: int,
|
| 388 |
+
num_layers: int,
|
| 389 |
+
norm_first: bool = True,
|
| 390 |
+
add_prenet: bool = False,
|
| 391 |
+
prefix_mode: int = 0,
|
| 392 |
+
share_embedding: bool = True,
|
| 393 |
+
nar_scale_factor: float = 1.0,
|
| 394 |
+
**kwargs,
|
| 395 |
+
):
|
| 396 |
+
"""
|
| 397 |
+
Args:
|
| 398 |
+
d_model:
|
| 399 |
+
The number of expected features in the input (required).
|
| 400 |
+
nhead:
|
| 401 |
+
The number of heads in the multiheadattention models (required).
|
| 402 |
+
num_layers:
|
| 403 |
+
The number of sub-decoder-layers in the decoder (required).
|
| 404 |
+
"""
|
| 405 |
+
super(VALLE, self).__init__(
|
| 406 |
+
d_model,
|
| 407 |
+
nhead,
|
| 408 |
+
num_layers,
|
| 409 |
+
norm_first=norm_first,
|
| 410 |
+
add_prenet=add_prenet,
|
| 411 |
+
decoder_cls=TransformerEncoder,
|
| 412 |
+
decoder_layer_cls=TransformerEncoderLayer,
|
| 413 |
+
prefix_mode=prefix_mode,
|
| 414 |
+
share_embedding=share_embedding,
|
| 415 |
+
nar_scale_factor=nar_scale_factor,
|
| 416 |
+
**kwargs,
|
| 417 |
+
)
|
| 418 |
+
self.language_ID = {
|
| 419 |
+
"en": 0,
|
| 420 |
+
"zh": 1,
|
| 421 |
+
"ja": 2,
|
| 422 |
+
}
|
| 423 |
+
self.ar_language_embedding = TokenEmbedding(d_model, len(self.language_ID))
|
| 424 |
+
self.nar_language_embedding = TokenEmbedding(d_model, len(self.language_ID))
|
| 425 |
+
|
| 426 |
+
def forward(
|
| 427 |
+
self,
|
| 428 |
+
x: torch.Tensor,
|
| 429 |
+
x_lens: torch.Tensor,
|
| 430 |
+
y: Union[torch.Tensor, PromptedFeatures],
|
| 431 |
+
y_lens: Union[torch.Tensor, PromptedFeatures],
|
| 432 |
+
reduction: str = "sum",
|
| 433 |
+
train_stage: int = 0,
|
| 434 |
+
**kwargs,
|
| 435 |
+
):
|
| 436 |
+
raise NotImplementedError
|
| 437 |
+
|
| 438 |
+
def inference(
|
| 439 |
+
self,
|
| 440 |
+
x: torch.Tensor,
|
| 441 |
+
x_lens: torch.Tensor,
|
| 442 |
+
y: torch.Tensor,
|
| 443 |
+
enroll_x_lens: torch.Tensor,
|
| 444 |
+
top_k: int = -100,
|
| 445 |
+
temperature: float = 1.0,
|
| 446 |
+
prompt_language: str = None,
|
| 447 |
+
text_language: str = None,
|
| 448 |
+
best_of: int = 1,
|
| 449 |
+
length_penalty: float = 1.0,
|
| 450 |
+
return_worst: bool = False,
|
| 451 |
+
) -> torch.Tensor:
|
| 452 |
+
"""
|
| 453 |
+
Args:
|
| 454 |
+
x:
|
| 455 |
+
A 2-D tensor of shape (1, S).
|
| 456 |
+
x_lens:
|
| 457 |
+
A 1-D tensor of shape (1,). It contains the number of tokens in `x`
|
| 458 |
+
before padding.
|
| 459 |
+
y:
|
| 460 |
+
A 3-D tensor of shape (1, T, 8).
|
| 461 |
+
top_k: (`optional`) int
|
| 462 |
+
The number of highest probability tokens to keep for top-k-filtering. Default to -100.
|
| 463 |
+
temperature: (`optional`) float
|
| 464 |
+
The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
|
| 465 |
+
Returns:
|
| 466 |
+
Return the predicted audio code matrix.
|
| 467 |
+
"""
|
| 468 |
+
assert x.ndim == 2, x.shape
|
| 469 |
+
assert x_lens.ndim == 1, x_lens.shape
|
| 470 |
+
assert y.ndim == 3, y.shape
|
| 471 |
+
assert y.shape[0] == 1, y.shape
|
| 472 |
+
|
| 473 |
+
assert torch.all(x_lens > 0)
|
| 474 |
+
|
| 475 |
+
# NOTE: x has been padded in TextTokenCollater
|
| 476 |
+
text = x
|
| 477 |
+
x = self.ar_text_embedding(text)
|
| 478 |
+
# Add language embedding
|
| 479 |
+
prompt_language_id = torch.LongTensor(
|
| 480 |
+
np.array([self.language_ID[prompt_language]])
|
| 481 |
+
).to(x.device)
|
| 482 |
+
if isinstance(text_language, str):
|
| 483 |
+
text_language_id = torch.LongTensor(
|
| 484 |
+
np.array([self.language_ID[text_language]])
|
| 485 |
+
).to(x.device)
|
| 486 |
+
elif isinstance(text_language, List):
|
| 487 |
+
text_language_id = torch.LongTensor(
|
| 488 |
+
np.array([self.language_ID[tl] for tl in text_language])
|
| 489 |
+
).to(x.device)
|
| 490 |
+
x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
|
| 491 |
+
x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
|
| 492 |
+
x = self.ar_text_prenet(x)
|
| 493 |
+
x = self.ar_text_position(x)
|
| 494 |
+
|
| 495 |
+
text_len = x_lens.max()
|
| 496 |
+
prompts = y
|
| 497 |
+
prefix_len = y.shape[1]
|
| 498 |
+
|
| 499 |
+
# AR Decoder
|
| 500 |
+
# TODO: Managing decoder steps avoid repetitive computation
|
| 501 |
+
y = prompts[..., 0]
|
| 502 |
+
if self.ar_audio_prepend_bos:
|
| 503 |
+
y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1)
|
| 504 |
+
|
| 505 |
+
x_len = x_lens.max()
|
| 506 |
+
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
| 507 |
+
|
| 508 |
+
kv_cache = None
|
| 509 |
+
use_kv_caching = True
|
| 510 |
+
|
| 511 |
+
sum_logprobs = torch.zeros(
|
| 512 |
+
best_of, device=y.device
|
| 513 |
+
) # implement batch decoding here
|
| 514 |
+
x = x.repeat(best_of, 1, 1)
|
| 515 |
+
y = y.repeat(best_of, 1)
|
| 516 |
+
while True:
|
| 517 |
+
y_emb = self.ar_audio_embedding(y)
|
| 518 |
+
y_emb = self.ar_audio_prenet(y_emb)
|
| 519 |
+
y_pos = self.ar_audio_position(y_emb)
|
| 520 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
| 521 |
+
|
| 522 |
+
y_len = y.shape[1]
|
| 523 |
+
x_attn_mask_pad = F.pad(
|
| 524 |
+
x_attn_mask,
|
| 525 |
+
(0, y_len),
|
| 526 |
+
value=True,
|
| 527 |
+
)
|
| 528 |
+
y_attn_mask = F.pad(
|
| 529 |
+
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
| 530 |
+
(x_len, 0),
|
| 531 |
+
value=False,
|
| 532 |
+
)
|
| 533 |
+
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
|
| 534 |
+
y.device
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
if use_kv_caching and kv_cache is not None:
|
| 538 |
+
xy_pos = xy_pos[:, [-1]]
|
| 539 |
+
else:
|
| 540 |
+
pass
|
| 541 |
+
|
| 542 |
+
xy_dec, kv_cache = self.ar_decoder.infer(
|
| 543 |
+
xy_pos,
|
| 544 |
+
mask=xy_attn_mask,
|
| 545 |
+
past_kv=kv_cache,
|
| 546 |
+
use_cache=use_kv_caching,
|
| 547 |
+
)
|
| 548 |
+
# xy_dec, _ = self.ar_decoder(
|
| 549 |
+
# (xy_pos, None),
|
| 550 |
+
# mask=xy_attn_mask,
|
| 551 |
+
# )
|
| 552 |
+
|
| 553 |
+
logits = self.ar_predict_layer(xy_dec[:, -1])
|
| 554 |
+
samples, current_logprobs = topk_sampling(
|
| 555 |
+
logits, top_k=top_k, top_p=1, temperature=temperature
|
| 556 |
+
)
|
| 557 |
+
sum_logprobs += current_logprobs * (y[:, -1] != NUM_AUDIO_TOKENS)
|
| 558 |
+
samples[y[:, -1] == NUM_AUDIO_TOKENS] = NUM_AUDIO_TOKENS
|
| 559 |
+
completed = (samples[:, -1] == NUM_AUDIO_TOKENS).all()
|
| 560 |
+
if completed or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16:
|
| 561 |
+
if prompts.shape[1] == y.shape[1]:
|
| 562 |
+
raise SyntaxError("well trained model shouldn't reach here.")
|
| 563 |
+
lengths = torch.sum(y != NUM_AUDIO_TOKENS, dim=1)
|
| 564 |
+
avg_logprobs = sum_logprobs / lengths**length_penalty
|
| 565 |
+
# choose the best beam according to sum_logprobs
|
| 566 |
+
best_beam = y[torch.argmax(avg_logprobs), :]
|
| 567 |
+
worst_beam = y[torch.argmin(avg_logprobs), :]
|
| 568 |
+
# strip all eos tokens
|
| 569 |
+
best_beam = best_beam[best_beam != NUM_AUDIO_TOKENS]
|
| 570 |
+
worst_beam = worst_beam[worst_beam != NUM_AUDIO_TOKENS]
|
| 571 |
+
if return_worst:
|
| 572 |
+
y = worst_beam.unsqueeze(0)
|
| 573 |
+
else:
|
| 574 |
+
y = best_beam.unsqueeze(0)
|
| 575 |
+
print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
|
| 576 |
+
break
|
| 577 |
+
|
| 578 |
+
y = torch.concat([y, samples], dim=1)
|
| 579 |
+
|
| 580 |
+
codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
|
| 581 |
+
if self.num_quantizers == 1:
|
| 582 |
+
return torch.stack(codes, dim=-1)
|
| 583 |
+
|
| 584 |
+
# Non-AR Decoders
|
| 585 |
+
y_emb = self.nar_audio_embeddings[0](y[:, int(self.ar_audio_prepend_bos) :])
|
| 586 |
+
|
| 587 |
+
if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes
|
| 588 |
+
enrolled_len = enroll_x_lens.max().item()
|
| 589 |
+
# SOS + Synthesis Text + EOS
|
| 590 |
+
text = torch.concat(
|
| 591 |
+
[
|
| 592 |
+
text[:, :1],
|
| 593 |
+
text[:, enrolled_len - 1 :],
|
| 594 |
+
],
|
| 595 |
+
dim=1,
|
| 596 |
+
)
|
| 597 |
+
text_len = text_len - (enrolled_len - 2)
|
| 598 |
+
assert text.shape[0] == 1
|
| 599 |
+
|
| 600 |
+
x = self.nar_text_embedding(text)
|
| 601 |
+
# Add language embedding
|
| 602 |
+
prompt_language_id = torch.LongTensor(
|
| 603 |
+
np.array([self.language_ID[prompt_language]])
|
| 604 |
+
).to(x.device)
|
| 605 |
+
if isinstance(text_language, str):
|
| 606 |
+
text_language_id = torch.LongTensor(
|
| 607 |
+
np.array([self.language_ID[text_language]])
|
| 608 |
+
).to(x.device)
|
| 609 |
+
elif isinstance(text_language, List):
|
| 610 |
+
text_language_id = torch.LongTensor(
|
| 611 |
+
np.array([self.language_ID[tl] for tl in text_language])
|
| 612 |
+
).to(x.device)
|
| 613 |
+
x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
|
| 614 |
+
x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
|
| 615 |
+
x = self.nar_text_prenet(x)
|
| 616 |
+
x = self.nar_text_position(x)
|
| 617 |
+
|
| 618 |
+
if self.prefix_mode == 0:
|
| 619 |
+
for i, (predict_layer, embedding_layer) in enumerate(
|
| 620 |
+
zip(
|
| 621 |
+
self.nar_predict_layers,
|
| 622 |
+
self.nar_audio_embeddings[1:],
|
| 623 |
+
)
|
| 624 |
+
):
|
| 625 |
+
y_pos = self.nar_audio_prenet(y_emb)
|
| 626 |
+
y_pos = self.nar_audio_position(y_pos)
|
| 627 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
| 628 |
+
|
| 629 |
+
xy_dec, _ = self.nar_decoder(
|
| 630 |
+
(xy_pos, self.nar_stage_embeddings[i].weight)
|
| 631 |
+
)
|
| 632 |
+
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
|
| 633 |
+
|
| 634 |
+
samples = torch.argmax(logits, dim=-1)
|
| 635 |
+
codes.append(samples)
|
| 636 |
+
|
| 637 |
+
if i < self.num_quantizers - 2:
|
| 638 |
+
y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
|
| 639 |
+
y_emb[:, prefix_len:] += embedding_layer(samples)
|
| 640 |
+
else:
|
| 641 |
+
for j in range(1, self.num_quantizers):
|
| 642 |
+
y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])
|
| 643 |
+
|
| 644 |
+
for i, (predict_layer, embedding_layer) in enumerate(
|
| 645 |
+
zip(
|
| 646 |
+
self.nar_predict_layers,
|
| 647 |
+
self.nar_audio_embeddings[1:],
|
| 648 |
+
)
|
| 649 |
+
):
|
| 650 |
+
y_pos = self.nar_audio_prenet(y_emb)
|
| 651 |
+
y_pos = self.nar_audio_position(y_pos)
|
| 652 |
+
xy_pos = torch.concat([x, y_pos], dim=1)
|
| 653 |
+
|
| 654 |
+
xy_dec, _ = self.nar_decoder(
|
| 655 |
+
(xy_pos, self.nar_stage_embeddings[i].weight)
|
| 656 |
+
)
|
| 657 |
+
logits = predict_layer(xy_dec[:, text_len + prefix_len :])
|
| 658 |
+
|
| 659 |
+
samples = torch.argmax(logits, dim=-1)
|
| 660 |
+
codes.append(samples)
|
| 661 |
+
|
| 662 |
+
if i < self.num_quantizers - 2:
|
| 663 |
+
y_emb[:, prefix_len:] += embedding_layer(samples)
|
| 664 |
+
|
| 665 |
+
assert len(codes) == self.num_quantizers
|
| 666 |
+
return torch.stack(codes, dim=-1)
|
| 667 |
+
|
| 668 |
+
    def continual(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
    ) -> torch.Tensor:
        """
        Continuation synthesis: keep a prefix of the given audio codes as the
        acoustic prompt and re-predict the remaining frames with the
        non-autoregressive decoders.

        Args:
          x:
            A 2-D tensor of shape (1, S).
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (1, T, 8).
        Returns:
          Return the predicted audio code matrix.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y.shape[0] == 1, y.shape

        assert torch.all(x_lens > 0)
        assert self.num_quantizers == 8

        # NOTE: x has been padded in TextTokenCollater
        text = x
        x = self.ar_text_embedding(text)
        x = self.ar_text_prenet(x)
        x = self.ar_text_position(x)

        text_len = x_lens.max()

        # Prompt length: half the utterance, capped at 3 * 75 frames
        # (presumably 75 codec frames per second, i.e. a 3-second cap --
        # TODO confirm the codec frame rate).
        prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)

        # AR Decoder
        prompts = y[:, :prefix_len]

        # First-quantizer codes of the continuation are taken verbatim from
        # the input `y`; only the higher quantizer stages are predicted.
        codes = [y[:, prefix_len:, 0]]
        # Non-AR Decoders
        x = self.nar_text_embedding(text)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        # Stage-0 embedding of the full code sequence (prompt + continuation).
        y_emb = self.nar_audio_embeddings[0](y[..., 0])

        if self.prefix_mode == 0:
            # Stage-by-stage refinement: the prompt slice of y_emb is updated
            # from the ground-truth prompt codes at every stage.
            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                # NOTE(review): position is applied before prenet here -- the
                # reverse of the order used in the `else` branch below.
                # Confirm this asymmetry is intentional.
                y_pos = self.nar_audio_position(y_emb)
                y_pos = self.nar_audio_prenet(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                # Score only the audio-continuation slice of the decoder
                # output; text and prompt positions are skipped.
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                # 8 quantizers -> 7 NAR stages; no embedding update is needed
                # after the last stage (i == 6).
                if i < 6:
                    y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
                    y_emb[:, prefix_len:] += embedding_layer(samples)
        else:
            # All prompt-quantizer embeddings are summed into y_emb up front...
            for j in range(1, 8):
                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])

            # ...then only the continuation slice accumulates per stage.
            for i, (predict_layer, embedding_layer) in enumerate(
                zip(
                    self.nar_predict_layers,
                    self.nar_audio_embeddings[1:],
                )
            ):
                y_pos = self.nar_audio_prenet(y_emb)
                y_pos = self.nar_audio_position(y_pos)
                xy_pos = torch.concat([x, y_pos], dim=1)

                xy_dec, _ = self.nar_decoder(
                    (xy_pos, self.nar_stage_embeddings[i].weight)
                )
                logits = predict_layer(xy_dec[:, text_len + prefix_len :])

                samples = torch.argmax(logits, dim=-1)
                codes.append(samples)

                if i < 6:
                    y_emb[:, prefix_len:] += embedding_layer(samples)

        assert len(codes) == 8
        # (1, T - prefix_len, 8) predicted code matrix.
        return torch.stack(codes, dim=-1)
|
| 764 |
+
|
| 765 |
+
|
| 766 |
+
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
|
| 767 |
+
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Mask out low-probability entries of a logits distribution, in place.

    Applies top-k filtering (keep only the ``top_k`` highest logits) and/or
    nucleus (top-p) filtering (keep the smallest set of tokens whose
    cumulative softmax probability reaches ``top_p``), writing
    ``filter_value`` into every removed position.  At least
    ``min_tokens_to_keep`` tokens per example survive.

    Args:
        logits: tensor of shape (batch size, vocabulary size); mutated in place.
        top_k: if > 0, keep only the k most likely tokens per example.
        top_p: if < 1.0, keep the smallest nucleus with cumulative
            probability >= top_p.
        filter_value: value written into filtered positions.
        min_tokens_to_keep: lower bound on surviving tokens per example.

    Returns:
        The (same, mutated) logits tensor.

    Nucleus filtering is described in Holtzman et al.
    (http://arxiv.org/abs/1904.09751).
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        # Clamp k into [min_tokens_to_keep, vocab size] as a safety check.
        k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
        # Everything strictly below the k-th largest logit is dropped.
        kth_value = torch.topk(logits, k)[0][..., -1, None]
        logits[logits < kth_value] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Positions whose cumulative probability already exceeds the nucleus
        # (tokens with cumulative probability exactly 0 are kept).
        remove_sorted = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Protect the first min_tokens_to_keep entries; the shift below
            # re-admits one more, matching the reference implementation.
            remove_sorted[..., :min_tokens_to_keep] = 0
        # Shift right so the token that crosses the threshold is kept too.
        remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
        remove_sorted[..., 0] = 0

        # Map the sorted-order mask back onto the original token positions.
        remove_mask = remove_sorted.scatter(1, sorted_indices, remove_sorted)
        logits[remove_mask] = filter_value
    return logits
|
| 804 |
+
|
| 805 |
+
|
| 806 |
+
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    """Sample one token per row from temperature/top-k/top-p filtered logits.

    Args:
        logits: (batch size, vocabulary size) unnormalized scores.
        top_k: number of highest-probability tokens kept for top-k filtering;
            between 1 and infinity.
        top_p: cumulative-probability threshold for nucleus sampling;
            between 0 and 1.
        temperature: softmax temperature used to module the next-token
            probabilities; must be strictly positive.  Values > 1 make
            low-probability tokens more likely.

    Returns:
        A tuple ``(token, current_logprobs)``: ``token`` has shape
        (batch, 1); ``current_logprobs`` holds the log-probability of the
        sampled token for each row.
    """
    # Apply temperature before any filtering.
    if temperature != 1.0:
        logits = logits / temperature
    # Top-p / top-k filtering.
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    # Draw one sample per row from the filtered distribution.
    token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
    logprobs = F.log_softmax(logits.float(), dim=-1)
    current_logprobs = logprobs[torch.arange(logprobs.shape[0]), token.squeeze(1)]
    return token, current_logprobs
|
apps/audio_cloning/vallex/models/visualizer.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2023 (authors: Feiteng Li)
|
| 3 |
+
#
|
| 4 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
from typing import Dict, List, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import matplotlib.pyplot as plt
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def visualize(
    predicts: Tuple[torch.Tensor],
    batch: Dict[str, Union[List, torch.Tensor]],
    output_dir: str,
    limit: int = 4,
) -> None:
    """Dump per-utterance diagnostic plots as ``{output_dir}/{utt_id}.png``.

    Each figure stacks three panels: encoder output, decoder output and the
    decoder target (ground-truth audio features), with a red vertical line
    marking the unpadded length of each sequence.

    Args:
        predicts: (encoder_outputs, decoder_outputs); decoder_outputs may be
            a list, in which case only its last element is plotted.
        batch: batch dict with "text_tokens", "audio_features", their
            "*_lens" tensors, plus "utt_id" and "text" lists.
        output_dir: destination directory for the PNG files.
        limit: plot at most this many utterances from the batch.
    """
    text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
    text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
    audio_features = batch["audio_features"].to("cpu").detach().numpy()
    audio_features_lens = batch["audio_features_lens"].to("cpu").detach().numpy()
    assert text_tokens.ndim == 2

    utt_ids, texts = batch["utt_id"], batch["text"]

    encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
    decoder_outputs = predicts[1]
    if isinstance(decoder_outputs, list):
        decoder_outputs = decoder_outputs[-1]
    decoder_outputs = decoder_outputs.to("cpu").type(torch.float32).detach().numpy()

    # Color range: Encodec code indices by default, log-fbank range when the
    # decoder emits float features.
    vmin, vmax = 0, 1024  # Encodec
    if decoder_outputs.dtype == np.float32:
        vmin, vmax = -6, 0  # Fbank

    num_figures = 3
    for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
        _ = plt.figure(figsize=(14, 8 * num_figures))

        S = text_tokens_lens[b]
        T = audio_features_lens[b]

        # (matrix, xlabel, length boundary, clip to [vmin, vmax]?)
        panels = [
            (encoder_outputs[b], "Encoder Output", S, False),
            (decoder_outputs[b], "Decoder Output", T, True),
            (audio_features[b], "Decoder Target", T, True),
        ]
        for idx, (matrix, xlabel, boundary, clipped) in enumerate(panels):
            plt.subplot(num_figures, 1, idx + 1)
            if idx == 0:
                # Title only on the top (encoder) panel.
                plt.title(f"Text: {text}")
            range_kwargs = dict(vmin=vmin, vmax=vmax) if clipped else {}
            plt.imshow(
                X=np.transpose(matrix),
                cmap=plt.get_cmap("jet"),
                aspect="auto",
                interpolation="nearest",
                **range_kwargs,
            )
            plt.gca().invert_yaxis()
            # Red line just before the first padded frame/token.
            plt.axvline(x=boundary - 0.4, linewidth=2, color="r")
            plt.xlabel(xlabel)
            plt.colorbar()

        plt.savefig(f"{output_dir}/{utt_id}.png")
        plt.close()
|
apps/audio_cloning/vallex/modules/__init__.py
ADDED
|
File without changes
|
apps/audio_cloning/vallex/modules/activation.py
ADDED
|
@@ -0,0 +1,612 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Tuple, List
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
from torch.nn import Linear, Module
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
| 9 |
+
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
| 10 |
+
from torch.nn.parameter import Parameter
|
| 11 |
+
|
| 12 |
+
def _in_projection_packed(
|
| 13 |
+
q: Tensor,
|
| 14 |
+
k: Tensor,
|
| 15 |
+
v: Tensor,
|
| 16 |
+
w: Tensor,
|
| 17 |
+
b: Optional[Tensor] = None,
|
| 18 |
+
) -> List[Tensor]:
|
| 19 |
+
r"""
|
| 20 |
+
Performs the in-projection step of the attention operation, using packed weights.
|
| 21 |
+
Output is a triple containing projection tensors for query, key and value.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
q, k, v: query, key and value tensors to be projected. For self-attention,
|
| 25 |
+
these are typically the same tensor; for encoder-decoder attention,
|
| 26 |
+
k and v are typically the same tensor. (We take advantage of these
|
| 27 |
+
identities for performance if they are present.) Regardless, q, k and v
|
| 28 |
+
must share a common embedding dimension; otherwise their shapes may vary.
|
| 29 |
+
w: projection weights for q, k and v, packed into a single tensor. Weights
|
| 30 |
+
are packed along dimension 0, in q, k, v order.
|
| 31 |
+
b: optional projection biases for q, k and v, packed into a single tensor
|
| 32 |
+
in q, k, v order.
|
| 33 |
+
|
| 34 |
+
Shape:
|
| 35 |
+
Inputs:
|
| 36 |
+
- q: :math:`(..., E)` where E is the embedding dimension
|
| 37 |
+
- k: :math:`(..., E)` where E is the embedding dimension
|
| 38 |
+
- v: :math:`(..., E)` where E is the embedding dimension
|
| 39 |
+
- w: :math:`(E * 3, E)` where E is the embedding dimension
|
| 40 |
+
- b: :math:`E * 3` where E is the embedding dimension
|
| 41 |
+
|
| 42 |
+
Output:
|
| 43 |
+
- in output list :math:`[q', k', v']`, each output tensor will have the
|
| 44 |
+
same shape as the corresponding input tensor.
|
| 45 |
+
"""
|
| 46 |
+
E = q.size(-1)
|
| 47 |
+
if k is v:
|
| 48 |
+
if q is k:
|
| 49 |
+
# self-attention
|
| 50 |
+
return F.linear(q, w, b).chunk(3, dim=-1)
|
| 51 |
+
else:
|
| 52 |
+
# encoder-decoder attention
|
| 53 |
+
w_q, w_kv = w.split([E, E * 2])
|
| 54 |
+
if b is None:
|
| 55 |
+
b_q = b_kv = None
|
| 56 |
+
else:
|
| 57 |
+
b_q, b_kv = b.split([E, E * 2])
|
| 58 |
+
return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
|
| 59 |
+
else:
|
| 60 |
+
w_q, w_k, w_v = w.chunk(3)
|
| 61 |
+
if b is None:
|
| 62 |
+
b_q = b_k = b_v = None
|
| 63 |
+
else:
|
| 64 |
+
b_q, b_k, b_v = b.chunk(3)
|
| 65 |
+
return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
|
| 66 |
+
|
| 67 |
+
def _scaled_dot_product_attention(
|
| 68 |
+
q: Tensor,
|
| 69 |
+
k: Tensor,
|
| 70 |
+
v: Tensor,
|
| 71 |
+
attn_mask: Optional[Tensor] = None,
|
| 72 |
+
dropout_p: float = 0.0,
|
| 73 |
+
) -> Tuple[Tensor, Tensor]:
|
| 74 |
+
r"""
|
| 75 |
+
Computes scaled dot product attention on query, key and value tensors, using
|
| 76 |
+
an optional attention mask if passed, and applying dropout if a probability
|
| 77 |
+
greater than 0.0 is specified.
|
| 78 |
+
Returns a tensor pair containing attended values and attention weights.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
q, k, v: query, key and value tensors. See Shape section for shape details.
|
| 82 |
+
attn_mask: optional tensor containing mask values to be added to calculated
|
| 83 |
+
attention. May be 2D or 3D; see Shape section for details.
|
| 84 |
+
dropout_p: dropout probability. If greater than 0.0, dropout is applied.
|
| 85 |
+
|
| 86 |
+
Shape:
|
| 87 |
+
- q: :math:`(B, Nt, E)` where B is batch size, Nt is the target sequence length,
|
| 88 |
+
and E is embedding dimension.
|
| 89 |
+
- key: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
|
| 90 |
+
and E is embedding dimension.
|
| 91 |
+
- value: :math:`(B, Ns, E)` where B is batch size, Ns is the source sequence length,
|
| 92 |
+
and E is embedding dimension.
|
| 93 |
+
- attn_mask: either a 3D tensor of shape :math:`(B, Nt, Ns)` or a 2D tensor of
|
| 94 |
+
shape :math:`(Nt, Ns)`.
|
| 95 |
+
|
| 96 |
+
- Output: attention values have shape :math:`(B, Nt, E)`; attention weights
|
| 97 |
+
have shape :math:`(B, Nt, Ns)`
|
| 98 |
+
"""
|
| 99 |
+
B, Nt, E = q.shape
|
| 100 |
+
q = q / math.sqrt(E)
|
| 101 |
+
# (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
|
| 102 |
+
if attn_mask is not None:
|
| 103 |
+
attn = torch.baddbmm(attn_mask, q, k.transpose(-2, -1))
|
| 104 |
+
else:
|
| 105 |
+
attn = torch.bmm(q, k.transpose(-2, -1))
|
| 106 |
+
|
| 107 |
+
attn = F.softmax(attn, dim=-1)
|
| 108 |
+
if dropout_p > 0.0:
|
| 109 |
+
attn = F.dropout(attn, p=dropout_p)
|
| 110 |
+
# (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
|
| 111 |
+
output = torch.bmm(attn, v)
|
| 112 |
+
return output, attn
|
| 113 |
+
|
| 114 |
+
def multi_head_attention_forward(
    x,
    ipw,
    ipb,
    opw,
    opb,
    n_head,
    attn_mask,
    past_kv=None,
    use_cache=False,
):
    """Batch-first multi-head self-attention with optional KV caching.

    Args:
        x: (B, T, C) input sequence.
        ipw / ipb: packed input-projection weight (3C, C) and bias (3C,).
        opw / opb: output-projection weight (C, C) and bias (C,).
        n_head: number of attention heads; C must be divisible by n_head.
        attn_mask: boolean mask indexed as (query position, key position);
            True entries are masked out (set to -inf before softmax).
        past_kv: optional (key, value) pair from previous decoding steps,
            each (B, n_head, T_past, C // n_head); prepended to the new
            keys/values along the time axis.
        use_cache: when True, return the concatenated (k, v) as ``present``
            for reuse on the next incremental step.

    Returns:
        ``(y, present)``: y is (B, T, C); present is (k, v) or None.
    """
    B, T, C = x.size()
    head_dim = C // n_head

    # Fused in-projection, then split into heads: (B, n_head, T, head_dim).
    q, k, v = torch._C._nn.linear(x, ipw, ipb).chunk(3, dim=-1)
    q = q.view(B, T, n_head, head_dim).transpose(1, 2)
    k = k.view(B, T, n_head, head_dim).transpose(1, 2)
    v = v.view(B, T, n_head, head_dim).transpose(1, 2)

    # Prepend cached keys/values so attention spans the full history.
    if past_kv is not None:
        k = torch.cat((past_kv[0], k), dim=-2)
        v = torch.cat((past_kv[1], v), dim=-2)

    FULL_T = k.shape[-2]

    present = (k, v) if use_cache is True else None

    # Scaled dot-product attention over the full (cached + new) length.
    # The mask slice selects the rows for the T newest query positions.
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = att.masked_fill(attn_mask[FULL_T - T:FULL_T, :FULL_T], float('-inf'))
    att = F.softmax(att, dim=-1)
    # (B, nh, T, FULL_T) x (B, nh, FULL_T, hs) -> (B, nh, T, hs)
    y = att @ v
    # Merge heads side by side back into (B, T, C), then project out.
    y = y.transpose(1, 2).contiguous().view(B, T, C)
    y = torch._C._nn.linear(y, opw, opb)
    return (y, present)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class MultiheadAttention(Module):
|
| 171 |
+
r"""Allows the model to jointly attend to information
|
| 172 |
+
from different representation subspaces as described in the paper:
|
| 173 |
+
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
|
| 174 |
+
|
| 175 |
+
Multi-Head Attention is defined as:
|
| 176 |
+
|
| 177 |
+
.. math::
|
| 178 |
+
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
| 179 |
+
|
| 180 |
+
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
|
| 181 |
+
|
| 182 |
+
``forward()`` will use a special optimized implementation if all of the following
|
| 183 |
+
conditions are met:
|
| 184 |
+
|
| 185 |
+
- self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
|
| 186 |
+
restriction will be loosened in the future.)
|
| 187 |
+
- Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
|
| 188 |
+
- training is disabled (using ``.eval()``)
|
| 189 |
+
- dropout is 0
|
| 190 |
+
- ``add_bias_kv`` is ``False``
|
| 191 |
+
- ``add_zero_attn`` is ``False``
|
| 192 |
+
- ``batch_first`` is ``True`` and the input is batched
|
| 193 |
+
- ``kdim`` and ``vdim`` are equal to ``embed_dim``
|
| 194 |
+
- at most one of ``key_padding_mask`` or ``attn_mask`` is passed
|
| 195 |
+
- if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
|
| 196 |
+
nor ``attn_mask`` is passed
|
| 197 |
+
|
| 198 |
+
If the optimized implementation is in use, a
|
| 199 |
+
`NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
|
| 200 |
+
``query``/``key``/``value`` to represent padding more efficiently than using a
|
| 201 |
+
padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
|
| 202 |
+
will be returned, and an additional speedup proportional to the fraction of the input
|
| 203 |
+
that is padding can be expected.
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
embed_dim: Total dimension of the model.
|
| 207 |
+
num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
|
| 208 |
+
across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
|
| 209 |
+
dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
|
| 210 |
+
bias: If specified, adds bias to input / output projection layers. Default: ``True``.
|
| 211 |
+
add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
|
| 212 |
+
add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
|
| 213 |
+
Default: ``False``.
|
| 214 |
+
kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
|
| 215 |
+
vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
|
| 216 |
+
batch_first: If ``True``, then the input and output tensors are provided
|
| 217 |
+
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
| 218 |
+
|
| 219 |
+
Examples::
|
| 220 |
+
|
| 221 |
+
>>> # xdoctest: +SKIP
|
| 222 |
+
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
| 223 |
+
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
| 224 |
+
|
| 225 |
+
"""
|
| 226 |
+
__constants__ = ["batch_first"]
|
| 227 |
+
bias_k: Optional[torch.Tensor]
|
| 228 |
+
bias_v: Optional[torch.Tensor]
|
| 229 |
+
|
| 230 |
+
def __init__(
    self,
    embed_dim,
    num_heads,
    dropout=0.0,
    bias=True,
    add_bias_kv=False,
    add_zero_attn=False,
    kdim=None,
    vdim=None,
    batch_first=False,
    linear1_cls=Linear,
    linear2_cls=Linear,
    device=None,
    dtype=None,
) -> None:
    """Set up projection weights and biases for multi-head attention.

    Mirrors ``torch.nn.MultiheadAttention.__init__`` but additionally accepts
    custom linear-layer classes (``linear1_cls``/``linear2_cls``) for the
    input and output projections.
    """
    tensor_kwargs = {"device": device, "dtype": dtype}
    super(MultiheadAttention, self).__init__()

    self.embed_dim = embed_dim
    self.kdim = embed_dim if kdim is None else kdim
    self.vdim = embed_dim if vdim is None else vdim
    # Separate q/k/v projection weights are only needed when key/value dims
    # differ from the query dim.
    self._qkv_same_embed_dim = (
        self.kdim == embed_dim and self.vdim == embed_dim
    )

    self.num_heads = num_heads
    self.dropout = dropout
    self.batch_first = batch_first
    self.head_dim = embed_dim // num_heads
    assert (
        self.head_dim * num_heads == self.embed_dim
    ), "embed_dim must be divisible by num_heads"

    if add_bias_kv:
        self.bias_k = Parameter(
            torch.empty((1, 1, embed_dim), **tensor_kwargs)
        )
        self.bias_v = Parameter(
            torch.empty((1, 1, embed_dim), **tensor_kwargs)
        )
    else:
        self.bias_k = self.bias_v = None

    if linear1_cls == Linear:
        # Stock path: behaves exactly like nn.MultiheadAttention.
        if not self._qkv_same_embed_dim:
            self.q_proj_weight = Parameter(
                torch.empty((embed_dim, embed_dim), **tensor_kwargs)
            )
            self.k_proj_weight = Parameter(
                torch.empty((embed_dim, self.kdim), **tensor_kwargs)
            )
            self.v_proj_weight = Parameter(
                torch.empty((embed_dim, self.vdim), **tensor_kwargs)
            )
            self.register_parameter("in_proj_weight", None)
        else:
            # Fused (3*E, E) projection for q, k and v together.
            self.in_proj_weight = Parameter(
                torch.empty((3 * embed_dim, embed_dim), **tensor_kwargs)
            )
            for unused in ("q_proj_weight", "k_proj_weight", "v_proj_weight"):
                self.register_parameter(unused, None)

        if bias:
            self.in_proj_bias = Parameter(
                torch.empty(3 * embed_dim, **tensor_kwargs)
            )
        else:
            self.register_parameter("in_proj_bias", None)
        self.out_proj = NonDynamicallyQuantizableLinear(
            embed_dim, embed_dim, bias=bias, **tensor_kwargs
        )

        self._reset_parameters()
    else:
        # Custom linear classes: only the fused-projection layout is supported.
        if not self._qkv_same_embed_dim:
            raise NotImplementedError
        else:
            self.in_proj_linear = linear1_cls(
                embed_dim, 3 * embed_dim, bias=bias, **tensor_kwargs
            )
            self.in_proj_weight = self.in_proj_linear.weight

            for unused in ("q_proj_weight", "k_proj_weight", "v_proj_weight"):
                self.register_parameter(unused, None)

            if bias:
                self.in_proj_bias = self.in_proj_linear.bias
            else:
                self.register_parameter("in_proj_bias", None)

            self.out_proj = linear2_cls(
                embed_dim, embed_dim, bias=bias, **tensor_kwargs
            )

            # Custom path skips _reset_parameters(); initialize the k/v
            # biases here instead.
            if self.bias_k is not None:
                xavier_normal_(self.bias_k)
            if self.bias_v is not None:
                xavier_normal_(self.bias_v)

    self.add_zero_attn = add_zero_attn
|
| 332 |
+
|
| 333 |
+
def _reset_parameters(self):
    """Initialize projection weights (Xavier uniform) and zero the biases."""
    if self._qkv_same_embed_dim:
        xavier_uniform_(self.in_proj_weight)
    else:
        for proj in (self.q_proj_weight, self.k_proj_weight, self.v_proj_weight):
            xavier_uniform_(proj)

    if self.in_proj_bias is not None:
        constant_(self.in_proj_bias, 0.0)
        constant_(self.out_proj.bias, 0.0)

    # Optional per-sequence key/value biases get a Xavier-normal init.
    if self.bias_k is not None:
        xavier_normal_(self.bias_k)
    if self.bias_v is not None:
        xavier_normal_(self.bias_v)
|
| 349 |
+
|
| 350 |
+
def __setstate__(self, state):
    """Restore pickled state, filling in fields missing from old checkpoints."""
    # Checkpoints generated by v1.1.0 predate the fused-projection flag.
    state.setdefault("_qkv_same_embed_dim", True)

    super(MultiheadAttention, self).__setstate__(state)
|
| 356 |
+
|
| 357 |
+
def forward(
    self,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    average_attn_weights: bool = True,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Compute multi-head attention.

    Args:
        query/key/value: embeddings of shape ``(L, N, E)`` / ``(S, N, E)``
            (or ``(N, L, E)`` / ``(N, S, E)`` when ``batch_first=True``;
            2-D for unbatched input).
        key_padding_mask: bool or float mask of shape ``(N, S)`` marking
            ``key`` positions to ignore.
        need_weights: also return the attention weights.
        attn_mask: 2-D ``(L, S)`` or 3-D ``(N*num_heads, L, S)`` mask that
            blocks attention to certain positions.
        average_attn_weights: average the returned weights across heads.

    Returns:
        Tuple of ``attn_output`` and (when ``need_weights``) the attention
        weights; otherwise the second element is ``None``.

    Raises:
        AssertionError: if ``key_padding_mask`` has an unsupported dtype, or
            a NestedTensor input misses the fast path.
    """
    is_batched = query.dim() == 3
    if key_padding_mask is not None:
        _kpm_dtype = key_padding_mask.dtype
        if _kpm_dtype != torch.bool and not torch.is_floating_point(
            key_padding_mask
        ):
            raise AssertionError(
                "only bool and floating types of key_padding_mask are supported"
            )

    # Decide whether the optimized native kernel can be used; record the
    # first disqualifying reason for the NestedTensor assertion below.
    why_not_fast_path = ""
    if not is_batched:
        why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
    elif query is not key or key is not value:
        # When lifting this restriction, don't forget to either
        # enforce that the dtypes all match or test cases where
        # they don't!
        why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
    elif (
        self.in_proj_bias is not None
        and query.dtype != self.in_proj_bias.dtype
    ):
        why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
    elif (
        self.in_proj_weight is not None
        and query.dtype != self.in_proj_weight.dtype
    ):
        # this case will fail anyway, but at least they'll get a useful error message.
        why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
    elif self.training:
        why_not_fast_path = "training is enabled"
    elif not self.batch_first:
        why_not_fast_path = "batch_first was not True"
    elif self.bias_k is not None:
        why_not_fast_path = "self.bias_k was not None"
    elif self.bias_v is not None:
        why_not_fast_path = "self.bias_v was not None"
    elif self.dropout:
        why_not_fast_path = f"dropout was {self.dropout}, required zero"
    elif self.add_zero_attn:
        why_not_fast_path = "add_zero_attn was enabled"
    elif not self._qkv_same_embed_dim:
        why_not_fast_path = "_qkv_same_embed_dim was not True"
    elif attn_mask is not None:
        why_not_fast_path = "attn_mask was not None"
    elif query.is_nested and key_padding_mask is not None:
        why_not_fast_path = (
            "key_padding_mask is not supported with NestedTensor input"
        )
    elif self.num_heads % 2 == 1:
        why_not_fast_path = "num_heads is odd"
    elif torch.is_autocast_enabled():
        why_not_fast_path = "autocast is enabled"

    if not why_not_fast_path:
        tensor_args = (
            query,
            key,
            value,
            self.in_proj_weight,
            self.in_proj_bias,
            self.out_proj.weight,
            self.out_proj.bias,
        )
        # List comprehensions (not generator expressions) below because
        # TorchScript does not support the latter.
        if torch.overrides.has_torch_function(tensor_args):
            why_not_fast_path = "some Tensor argument has_torch_function"
        elif not all(
            [
                (x is None or x.is_cuda or "cpu" in str(x.device))
                for x in tensor_args
            ]
        ):
            why_not_fast_path = (
                "some Tensor argument is neither CUDA nor CPU"
            )
        elif torch.is_grad_enabled() and any(
            [x is not None and x.requires_grad for x in tensor_args]
        ):
            why_not_fast_path = (
                "grad is enabled and at least one of query or the "
                "input/output projection weights or biases requires_grad"
            )
        if not why_not_fast_path:
            merged_mask = (
                key_padding_mask if key_padding_mask is not None else attn_mask
            )
            mask_type = (
                1
                if key_padding_mask is not None
                else 0
                if attn_mask is not None
                else None
            )
            return torch._native_multi_head_attention(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.out_proj.weight,
                self.out_proj.bias,
                merged_mask,
                need_weights,
                average_attn_weights,
                mask_type,
            )

    any_nested = query.is_nested or key.is_nested or value.is_nested
    assert not any_nested, (
        "MultiheadAttention does not support NestedTensor outside of its fast path. "
        + f"The fast path was not hit because {why_not_fast_path}"
    )

    if self.batch_first and is_batched:
        # Transpose to (seq, batch, feature) while preserving the `is`
        # identity relations among query/key/value.
        if key is value:
            if query is key:
                query = key = value = query.transpose(1, 0)
            else:
                query, key = [x.transpose(1, 0) for x in (query, key)]
                value = key
        else:
            query, key, value = [
                x.transpose(1, 0) for x in (query, key, value)
            ]

    # Only the separate-projection path needs extra keyword arguments; the
    # rest of the call is shared.
    extra_kwargs = {}
    if not self._qkv_same_embed_dim:
        extra_kwargs = dict(
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj_weight,
            k_proj_weight=self.k_proj_weight,
            v_proj_weight=self.v_proj_weight,
        )
    attn_output, attn_output_weights = F.multi_head_attention_forward(
        query,
        key,
        value,
        self.embed_dim,
        self.num_heads,
        self.in_proj_weight,
        self.in_proj_bias,
        self.bias_k,
        self.bias_v,
        self.add_zero_attn,
        self.dropout,
        self.out_proj.weight,
        self.out_proj.bias,
        training=self.training,
        key_padding_mask=key_padding_mask,
        need_weights=need_weights,
        attn_mask=attn_mask,
        average_attn_weights=average_attn_weights,
        **extra_kwargs,
    )
    if self.batch_first and is_batched:
        return attn_output.transpose(1, 0), attn_output_weights
    else:
        return attn_output, attn_output_weights
|
| 590 |
+
|
| 591 |
+
def infer(self,
          x: Tensor,
          key_padding_mask: Optional[Tensor] = None,
          need_weights: bool = True,
          attn_mask: Optional[Tensor] = None,
          average_attn_weights: bool = True,
          past_kv = None,
          use_cache = False
          ):
    """Self-attention for incremental decoding with an optional KV cache.

    NOTE(review): delegates to the module-level `multi_head_attention_forward`
    helper defined in this file (KV-cache aware), not the one from
    torch.nn.functional. `key_padding_mask`, `need_weights` and
    `average_attn_weights` are accepted for interface parity but are not
    forwarded to it.
    """
    attn_out, present_kv = multi_head_attention_forward(
        x=x,
        ipw=self.in_proj_weight,
        ipb=self.in_proj_bias,
        opw=self.out_proj.weight,
        opb=self.out_proj.bias,
        n_head=self.num_heads,
        attn_mask=attn_mask,
        past_kv=past_kv,
        use_cache=use_cache,
    )
    return (attn_out, present_kv)
|
apps/audio_cloning/vallex/modules/embedding.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 (authors: Feiteng Li)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class TokenEmbedding(nn.Module):
    """Token-id to dense-vector lookup table with optional dropout."""

    def __init__(
        self,
        dim_model: int,
        vocab_size: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.dim_model = dim_model

        self.dropout = torch.nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)

    @property
    def weight(self) -> torch.Tensor:
        # The full (vocab_size, dim_model) embedding matrix.
        return self.word_embeddings.weight

    def embedding(self, index: int) -> torch.Tensor:
        # A single row, kept 2-D by slicing instead of plain indexing.
        return self.word_embeddings.weight[index : index + 1]

    def forward(self, x: torch.Tensor):
        """Look up embeddings for index tensor ``x`` and apply dropout."""
        return self.dropout(self.word_embeddings(x))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class SinePositionalEmbedding(nn.Module):
    """Additive sinusoidal positional encoding.

    Adds the standard sin/cos table (Vaswani et al., 2017) to the input,
    optionally scaling the input by sqrt(dim_model) (``scale``) and learning
    a gain on the positional term (``alpha``).
    """

    def __init__(
        self,
        dim_model: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
    ):
        super().__init__()
        self.dim_model = dim_model
        self.x_scale = math.sqrt(dim_model) if scale else 1.0
        # Gain on the positional term: trainable iff alpha=True, else frozen at 1.
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)

        self.reverse = False
        self.pe = None
        # Pre-compute a table covering up to 4000 positions.
        self.extend_pe(torch.tensor(0.0).expand(1, 4000))

    def extend_pe(self, x):
        """Reset the positional encodings so they cover x's length."""
        if self.pe is not None and self.pe.size(1) >= x.size(1):
            # Table already long enough; only retarget dtype/device if needed.
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return
        length = x.size(1)
        if self.reverse:
            position = torch.arange(
                length - 1, -1, -1.0, dtype=torch.float32
            ).unsqueeze(1)
        else:
            position = torch.arange(
                0, length, dtype=torch.float32
            ).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dim_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.dim_model)
        )
        table = torch.zeros(length, self.dim_model)
        table[:, 0::2] = torch.sin(position * div_term)  # even dims: sin
        table[:, 1::2] = torch.cos(position * div_term)  # odd dims: cos
        self.pe = table.unsqueeze(0).to(device=x.device, dtype=x.dtype).detach()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings (and dropout) to ``x``."""
        self.extend_pe(x)
        # 2-D input gets a trailing feature axis before the addition.
        output = x.unsqueeze(-1) if x.ndim == 2 else x
        output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
        return self.dropout(output)
|
apps/audio_cloning/vallex/modules/optim.py
ADDED
|
@@ -0,0 +1,1105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
| 2 |
+
#
|
| 3 |
+
# See ../LICENSE for clarification regarding multiple authors
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import contextlib
|
| 18 |
+
import logging
|
| 19 |
+
import random
|
| 20 |
+
from collections import defaultdict
|
| 21 |
+
from typing import List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
from lhotse.utils import fix_random_seed
|
| 25 |
+
from torch import Tensor
|
| 26 |
+
from torch.optim import Optimizer
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class BatchedOptimizer(Optimizer):
    """
    This class adds to class Optimizer the capability to optimize parameters in batches:
    it will stack the parameters and their grads for you so the optimizer can work
    on tensors with an extra leading dimension.  This is intended for speed with GPUs,
    as it reduces the number of kernels launched in the optimizer.

    Args:
      params: iterable of parameters (or dicts defining param groups), as for Optimizer.
      defaults: dict of default hyperparameter values, as for Optimizer.
    """

    def __init__(self, params, defaults):
        super(BatchedOptimizer, self).__init__(params, defaults)

    @contextlib.contextmanager
    def batched_params(self, param_group, group_params_names):
        """
        This function returns (technically, yields) a list of
          of tuples (p, state), where
        p is a `fake` parameter that is stacked (over axis 0) from real parameters
        that share the same shape, and its gradient is also stacked;
        `state` is the state corresponding to this batch of parameters
        (it will be physically located in the "state" for one of the real
        parameters, the last one that has any particular shape and dtype).

        This function is decorated as a context manager so that it can
        write parameters back to their "real" locations.

        The idea is, instead of doing:
        <code>
          for p in group["params"]:
             state = self.state[p]
             ...
        </code>
        you can do:
        <code>
          with self.batched_params(group["params"]) as batches:
             for p, state, p_names in batches:
                 ...
        </code>

        Args:
          group: a parameter group, which is a list of parameters; should be
                one of self.param_groups.
          group_params_names: name for each parameter in group,
                which is List[str].
        """
        # Group parameters (and their names) by (dtype, shape) so that each
        # group can be stacked into a single tensor.
        batches = defaultdict(
            list
        )  # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
        batches_names = defaultdict(
            list
        )  # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str

        assert len(param_group) == len(group_params_names)
        for p, named_p in zip(param_group, group_params_names):
            key = (str(p.dtype), *p.shape)
            batches[key].append(p)
            batches_names[key].append(named_p)

        # Sort the groups by key so iteration order is deterministic.
        batches_names_keys = list(batches_names.keys())
        sorted_idx = sorted(
            range(len(batches_names)), key=lambda i: batches_names_keys[i]
        )
        batches_names = [
            batches_names[batches_names_keys[idx]] for idx in sorted_idx
        ]
        batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]

        # `tuples` will contain one (stacked_param, state, stacked_params_names)
        # entry per batch in `batches`, in deterministic order.
        tuples = []

        for batch, batch_names in zip(batches, batches_names):
            p = batch[0]
            # we arbitrarily store the state in the
            # state corresponding to the 1st parameter in the
            # group.  class Optimizer will take care of saving/loading state.
            state = self.state[p]
            p_stacked = torch.stack(batch)
            grad = torch.stack(
                [
                    torch.zeros_like(q) if q.grad is None else q.grad
                    for q in batch
                ]
            )
            p_stacked.grad = grad
            # BUGFIX: removed `stacked_params_dict[key] = p_stacked` here.  It
            # read `key` left over from the grouping loop above (i.e. always
            # the *last* key, not this batch's key), and the dict it populated
            # was a local that was never read anywhere — dead code relying on
            # a stale loop variable.
            tuples.append((p_stacked, state, batch_names))

        yield tuples  # <-- calling code will do the actual optimization here!

        # Copy the (possibly updated) stacked values back into the real params.
        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
            for i, p in enumerate(batch):  # batch is list of Parameter
                p.copy_(stacked_params[i])
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class ScaledAdam(BatchedOptimizer):
    """
    Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
    proportional to the norm of that parameter; and also learn the scale of the parameter,
    in log space, subject to upper and lower limits (as if we had factored each parameter as
    param = underlying_param * log_scale.exp()).

    Args:
        params: The parameters or param_groups to optimize (like other Optimizer subclasses)
        lr: The learning rate.  We will typically use a learning rate schedule that starts
            at 0.03 and decreases over time, i.e. much higher than other common optimizers.
        clipping_scale: (e.g. 2.0) A scale for gradient-clipping: if specified, the
            normalized gradients over the whole model will be clipped to have 2-norm equal
            to `clipping_scale` times the median 2-norm over the most recent period of
            `clipping_update_period` minibatches.  By "normalized gradients", we mean after
            multiplying by the rms parameter value for this tensor [for non-scalars]; this
            is appropriate because our update is scaled by this quantity.
        betas: beta1, beta2 are momentum constants for regular momentum, and moving
            sum-sq grad.  Must satisfy 0 < beta <= beta2 < 1.
        scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
            scale of each parameter tensor and scalar parameters of the model.  If each
            parameter were decomposed as p * p_scale.exp(), where
            (p**2).mean().sqrt() == 1.0, scalar_lr_scale would be the scaling factor on
            the learning rate of p_scale.
        eps: A general-purpose epsilon to prevent division by zero.
        param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
            learning the scale on the parameters (we'll constrain the rms of each
            non-scalar parameter tensor to be >= this value).
        param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
            learning the scale on the parameters (we'll constrain the rms of each
            non-scalar parameter tensor to be <= this value).
        scalar_max: Maximum absolute value for scalar parameters (applicable if your
            model has any parameters with numel() == 1).
        size_update_period: The periodicity, in steps, with which we update the size
            (scale) of the parameter tensor.  This is provided to save a little time
            in the update.
        clipping_update_period: if clipping_scale is specified, this is the period
            (in steps) over which the gradient-norm median is re-estimated.
        parameters_names: A List[List[str]]: one List[str] per param group, one str
            per parameter.  Required (used for diagnostics).
        show_dominant_parameters: if True, log which parameter dominates the gradient
            norm whenever clipping kicks in strongly.
    """

    def __init__(
        self,
        params,
        lr=3e-02,
        clipping_scale=None,
        betas=(0.9, 0.98),
        scalar_lr_scale=0.1,
        eps=1.0e-08,
        param_min_rms=1.0e-05,
        param_max_rms=3.0,
        scalar_max=10.0,
        size_update_period=4,
        clipping_update_period=100,
        parameters_names=None,
        show_dominant_parameters=True,
    ):
        assert parameters_names is not None, (
            "Please prepare parameters_names,"
            "which is a List[List[str]]. Each List[str] is for a group"
            "and each str is for a parameter"
        )
        defaults = dict(
            lr=lr,
            clipping_scale=clipping_scale,
            betas=betas,
            scalar_lr_scale=scalar_lr_scale,
            eps=eps,
            param_min_rms=param_min_rms,
            param_max_rms=param_max_rms,
            scalar_max=scalar_max,
            size_update_period=size_update_period,
            clipping_update_period=clipping_update_period,
        )

        super(ScaledAdam, self).__init__(params, defaults)
        # One list of names per param group.
        assert len(self.param_groups) == len(parameters_names)
        self.parameters_names = parameters_names
        self.show_dominant_parameters = show_dominant_parameters

    def __setstate__(self, state):
        super(ScaledAdam, self).__setstate__(state)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group, group_params_names in zip(
            self.param_groups, self.parameters_names
        ):

            with self.batched_params(
                group["params"], group_params_names
            ) as batches:

                # batches is a list of tuples (stacked_param, state, param_names).
                # stacked_param is like a regular parameter, and will have a .grad,
                # but the 1st dim corresponds to a stacking dim, it is not a real dim.

                if (
                    len(batches[0][1]) == 0
                ):  # if len(first state) == 0: not yet initialized
                    clipping_scale = 1
                else:
                    clipping_scale = self._get_clipping_scale(group, batches)

                for p, state, _ in batches:
                    # Perform optimization step.
                    # grad is not going to be None, we handled that when creating the batches.
                    grad = p.grad
                    if grad.is_sparse:
                        raise RuntimeError(
                            "ScaledAdam optimizer does not support sparse gradients"
                        )
                    # State initialization
                    if len(state) == 0:
                        self._init_state(group, p, state)

                    self._step_one_batch(group, p, state, clipping_scale)

        return loss

    def _init_state(self, group: dict, p: Tensor, state: dict):
        """
        Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
        is actually the batch dimension, corresponding to batched-together
        parameters of a given shape.

        Args:
            group: Dict to look up configuration values.
            p: The parameter that we are initializing the state for
            state: Dict from string to whatever state we are initializing
        """
        size_update_period = group["size_update_period"]

        state["step"] = 0

        kwargs = {"device": p.device, "dtype": p.dtype}

        # 'delta' implements conventional momentum.  There are
        # several different kinds of update going on, so rather than
        # compute "exp_avg" like in Adam, we store and decay a
        # parameter-change "delta", which combines all forms of
        # update.  this is equivalent to how it's done in Adam,
        # except for the first few steps.
        state["delta"] = torch.zeros_like(
            p, memory_format=torch.preserve_format
        )

        # NOTE(review): the original code assigned `numel = p.numel() // batch_size`
        # and then immediately overwrote it with `numel = p.numel()`; the dead first
        # assignment (and the then-unused `batch_size`) have been removed.  The
        # effective value -- the numel of the whole stacked batch -- is unchanged.
        numel = p.numel()

        if numel > 1:
            # "param_rms" just periodically records the scalar root-mean-square value of
            # the parameter tensor.
            # it has a shape like (batch_size, 1, 1, 1, 1)
            param_rms = (
                (p ** 2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
            )
            state["param_rms"] = param_rms

            state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
            # Buffer of the last `size_update_period` per-tensor scale gradients.
            state["scale_grads"] = torch.zeros(
                size_update_period, *param_rms.shape, **kwargs
            )

        # exp_avg_sq is the weighted sum of scaled gradients, as in Adam.
        state["exp_avg_sq"] = torch.zeros_like(
            p, memory_format=torch.preserve_format
        )

    def _get_clipping_scale(
        self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
    ) -> float:
        """
        Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
        by this amount before applying the rest of the update.

        Args:
            group: the parameter group, an item in self.param_groups
            tuples: a list of tuples of (param, state, param_names)
                where param is a batched set of parameters,
                with a .grad (1st dim is batch dim)
                and state is the state-dict where optimization parameters are kept.
                param_names is a List[str] while each str is name for a parameter
                in batched set of parameters "param".
        """
        assert len(tuples) >= 1
        clipping_scale = group["clipping_scale"]
        (first_p, first_state, _) = tuples[0]
        step = first_state["step"]
        if clipping_scale is None or step == 0:
            # no clipping.  return early on step == 0 because the other
            # parameters' state won't have been initialized yet.
            return 1.0
        clipping_update_period = group["clipping_update_period"]

        # Accumulate the squared 2-norm of the (rms-normalized) gradient over
        # the whole model.
        tot_sumsq = torch.tensor(0.0, device=first_p.device)
        for (p, state, param_names) in tuples:
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError(
                    "ScaledAdam optimizer does not support sparse gradients"
                )
            if p.numel() == p.shape[0]:  # a batch of scalars
                tot_sumsq += (
                    grad ** 2
                ).sum()  # sum() to change shape [1] to []
            else:
                tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()

        tot_norm = tot_sumsq.sqrt()
        if "model_norms" not in first_state:
            first_state["model_norms"] = torch.zeros(
                clipping_update_period, device=p.device
            )
        first_state["model_norms"][step % clipping_update_period] = tot_norm

        if step % clipping_update_period == 0:
            # Print some stats.
            # We don't reach here if step == 0 because we would have returned
            # above.
            sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
            quartiles = []
            for n in range(0, 5):
                index = min(
                    clipping_update_period - 1,
                    (clipping_update_period // 4) * n,
                )
                quartiles.append(sorted_norms[index].item())

            median = quartiles[2]
            threshold = clipping_scale * median
            first_state["model_norm_threshold"] = threshold
            percent_clipped = (
                first_state["num_clipped"] * 100.0 / clipping_update_period
                if "num_clipped" in first_state
                else 0.0
            )
            first_state["num_clipped"] = 0
            quartiles = " ".join(["%.3e" % x for x in quartiles])
            logging.info(
                f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
                f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
            )

        if step < clipping_update_period:
            return 1.0  # We have not yet estimated a norm to clip to.
        else:
            try:
                model_norm_threshold = first_state["model_norm_threshold"]
            except KeyError:
                logging.info(
                    "Warning: model_norm_threshold not in state: possibly "
                    "you changed config when restarting, adding clipping_scale option?"
                )
                return 1.0
            ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
            if ans < 1.0:
                first_state["num_clipped"] += 1
            if ans < 0.1:
                # logging.warn is a deprecated alias; use logging.warning.
                logging.warning(
                    f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
                )
                if self.show_dominant_parameters:
                    # NOTE: `p` / `param_names` here refer to the last batch of the
                    # loop above (same behavior as the original code).
                    assert p.shape[0] == len(param_names)
                    self._show_gradient_dominating_parameter(tuples, tot_sumsq)
            return ans

    def _show_gradient_dominating_parameter(
        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
    ):
        """
        Show information of the parameter which dominates tot_sumsq.

        Args:
            tuples: a list of tuples of (param, state, param_names)
                where param is a batched set of parameters,
                with a .grad (1st dim is batch dim)
                and state is the state-dict where optimization parameters are kept.
                param_names is a List[str] while each str is name for a parameter
                in batched set of parameters "param".
            tot_sumsq: sumsq of all parameters. Though it could be calculated
                from tuples, we still pass it to save some time.
        """
        all_sumsq_orig = {}
        for (p, state, batch_param_names) in tuples:
            # p is a stacked batch of parameters.
            batch_grad = p.grad
            if p.numel() == p.shape[0]:  # a batch of scalars
                batch_sumsq_orig = batch_grad ** 2
                # Dummy values used by the following `zip` statement.
                batch_rms_orig = torch.ones(p.shape[0])
            else:
                batch_rms_orig = state["param_rms"]
                batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
                    dim=list(range(1, batch_grad.ndim))
                )

            for name, sumsq_orig, rms, grad in zip(
                batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
            ):

                proportion_orig = sumsq_orig / tot_sumsq
                all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)

        # The proportions should sum to 1 (sanity check on tot_sumsq).
        assert torch.isclose(
            sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
            torch.tensor(1.0),
        )
        sorted_by_proportion = {
            k: v
            for k, v in sorted(
                all_sumsq_orig.items(),
                key=lambda item: item[1][0],
                reverse=True,
            )
        }
        dominant_param_name = next(iter(sorted_by_proportion))
        (
            dominant_proportion,
            dominant_sumsq,
            dominant_rms,
            dominant_grad,
        ) = sorted_by_proportion[dominant_param_name]
        logging.info(
            f"Parameter Dominanting tot_sumsq {dominant_param_name}"
            f" with proportion {dominant_proportion:.2f},"
            f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
            f"={dominant_sumsq:.3e},"
            f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
            f" orig_rms_sq={(dominant_rms**2).item():.3e}"
        )

    def _step_one_batch(
        self, group: dict, p: Tensor, state: dict, clipping_scale: float
    ):
        """
        Do the step for one parameter, which is actually going to be a batch of
        `real` parameters, with dim 0 as the batch dim.
        Args:
            group: dict to look up configuration values
            p: parameter to update (actually multiple parameters stacked together
                as a batch)
            state: state-dict for p, to look up the optimizer state
            clipping_scale: factor <= 1.0 to scale the gradient by, from
                _get_clipping_scale().
        """
        lr = group["lr"]
        size_update_period = group["size_update_period"]
        beta1 = group["betas"][0]

        grad = p.grad
        if clipping_scale != 1.0:
            grad = grad * clipping_scale
        step = state["step"]
        delta = state["delta"]

        delta.mul_(beta1)
        batch_size = p.shape[0]
        numel = p.numel() // batch_size  # numel of each real parameter
        if numel > 1:
            # Update the size/scale of p, and set param_rms
            scale_grads = state["scale_grads"]
            # d(loss)/d(log scale) for each real parameter in the batch.
            scale_grads[step % size_update_period] = (p * grad).sum(
                dim=list(range(1, p.ndim)), keepdim=True
            )
            if step % size_update_period == size_update_period - 1:
                param_rms = state["param_rms"]  # shape: (batch_size, 1, 1, ..)
                param_rms.copy_(
                    (p ** 2)
                    .mean(dim=list(range(1, p.ndim)), keepdim=True)
                    .sqrt()
                )
                if step > 0:
                    # self._size_update() learns the overall scale on the
                    # parameter, by shrinking or expanding it.
                    self._size_update(group, scale_grads, p, state)

        if numel == 1:
            # For parameters with 1 element we just use regular Adam.
            # Updates delta.
            self._step_scalar(group, p, state)
        else:
            self._step(group, p, state)

        state["step"] = step + 1

    def _size_update(
        self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
    ) -> None:
        """
        Called only where p.numel() > 1, this updates the scale of the parameter.
        If we imagine: p = underlying_param * scale.exp(), and we are doing
        gradient descent on underlying param and on scale, this function does the update
        on `scale`.

        Args:
            group: dict to look up configuration values
            scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
                grads w.r.t. the scales.
            p: The parameter to update
            state: The state-dict of p
        """

        param_rms = state["param_rms"]
        beta1, beta2 = group["betas"]
        size_lr = group["lr"] * group["scalar_lr_scale"]
        param_min_rms = group["param_min_rms"]
        param_max_rms = group["param_max_rms"]
        eps = group["eps"]
        step = state["step"]
        batch_size = p.shape[0]

        size_update_period = scale_grads.shape[0]
        # correct beta2 for the size update period: we will have
        # faster decay at this level.
        beta2_corr = beta2 ** size_update_period

        scale_exp_avg_sq = state[
            "scale_exp_avg_sq"
        ]  # shape: (batch_size, 1, 1, ..)
        scale_exp_avg_sq.mul_(beta2_corr).add_(
            (scale_grads ** 2).mean(
                dim=0
            ),  # mean over dim `size_update_period`
            alpha=1 - beta2_corr,
        )  # shape is (batch_size, 1, 1, ...)

        # The 1st time we reach here is when size_step == 1.
        size_step = (step + 1) // size_update_period
        bias_correction2 = 1 - beta2_corr ** size_step
        # we don't bother with bias_correction1; this will help prevent divergence
        # at the start of training.

        denom = scale_exp_avg_sq.sqrt() + eps

        scale_step = (
            -size_lr
            * (bias_correction2 ** 0.5)
            * scale_grads.sum(dim=0)
            / denom
        )

        is_too_small = param_rms < param_min_rms
        is_too_large = param_rms > param_max_rms

        # when the param gets too small, just don't shrink it any further.
        scale_step.masked_fill_(is_too_small, 0.0)
        # when it gets too large, stop it from getting any larger.
        scale_step.masked_fill_(is_too_large, -size_lr * size_update_period)
        delta = state["delta"]
        # the factor of (1-beta1) relates to momentum.
        delta.add_(p * scale_step, alpha=(1 - beta1))

    def _step(self, group: dict, p: Tensor, state: dict):
        """
        This function does the core update of self.step(), in the case where the members of
        the batch have more than 1 element.

        Args:
            group: A dict which will be used to look up configuration values
            p: The parameter to be updated (its .grad is used)
            state: The state-dict corresponding to parameter p

        This function modifies p.
        """
        grad = p.grad
        lr = group["lr"]
        beta1, beta2 = group["betas"]
        eps = group["eps"]
        param_min_rms = group["param_min_rms"]
        step = state["step"]

        exp_avg_sq = state["exp_avg_sq"]
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))

        # "zero_step" allows restarting the bias correction after a state reset
        # (set elsewhere, if at all); defaults to 0.
        this_step = state["step"] - (
            state["zero_step"] if "zero_step" in state else 0
        )
        bias_correction2 = 1 - beta2 ** (this_step + 1)
        if bias_correction2 < 0.99:
            # note: not in-place.
            exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)

        denom = exp_avg_sq.sqrt()
        denom += eps
        grad = grad / denom

        alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)

        delta = state["delta"]
        delta.add_(grad * alpha)
        p.add_(delta)

    def _step_scalar(self, group: dict, p: Tensor, state: dict):
        """
        A simplified form of the core update for scalar tensors, where we cannot get a good
        estimate of the parameter rms.
        """
        beta1, beta2 = group["betas"]
        scalar_max = group["scalar_max"]
        eps = group["eps"]
        lr = group["lr"] * group["scalar_lr_scale"]
        grad = p.grad

        exp_avg_sq = state["exp_avg_sq"]  # shape: (batch_size,)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # bias_correction2 is like in Adam.  Don't bother with bias_correction1;
        # slower update at the start will help stability anyway.
        bias_correction2 = 1 - beta2 ** (state["step"] + 1)
        denom = (exp_avg_sq / bias_correction2).sqrt() + eps

        delta = state["delta"]
        delta.add_(grad / denom, alpha=-lr * (1 - beta1))
        p.clamp_(min=-scalar_max, max=scalar_max)
        p.add_(delta)
| 662 |
+
|
| 663 |
+
|
| 664 |
+
class LRScheduler(object):
    """
    Base-class for learning rate schedulers where the learning-rate depends on both the
    batch and the epoch.
    """

    def __init__(self, optimizer: Optimizer, verbose: bool = False):
        # Validate and attach the optimizer.
        if not isinstance(optimizer, Optimizer):
            raise TypeError(
                "{} is not an Optimizer".format(type(optimizer).__name__)
            )
        self.optimizer = optimizer
        self.verbose = verbose

        # Remember each group's initial learning rate as "base_lr"
        # (setdefault: only recorded the first time a scheduler is attached).
        for pg in optimizer.param_groups:
            pg.setdefault("base_lr", pg["lr"])

        self.base_lrs = [pg["base_lr"] for pg in optimizer.param_groups]

        self.epoch = 0
        self.batch = 0

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return dict(
            base_lrs=self.base_lrs,
            epoch=self.epoch,
            batch=self.batch,
        )

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self) -> List[float]:
        """Return last computed learning rate by current scheduler. Will be a list of float."""
        return self._last_lr

    def get_lr(self):
        # Subclasses must compute the list of learning rates from self.epoch,
        # self.batch and self.base_lrs, e.g.
        # return [formula(self.batch, self.epoch, b) for b in self.base_lrs]
        raise NotImplementedError

    def step_batch(self, batch: Optional[int] = None) -> None:
        # Advance (or directly set) the batch counter, then refresh the lrs.
        # If `batch` is given it must be the batch index counted from the
        # start of training, i.e. summed over all epochs; otherwise call this
        # once per batch.
        self.batch = self.batch + 1 if batch is None else batch
        self._set_lrs()

    def step_epoch(self, epoch: Optional[int] = None):
        # Advance (or directly set) the epoch counter, then refresh the lrs.
        # With an explicit `epoch`, call at the start of the epoch; without,
        # call at the end of the epoch.
        self.epoch = self.epoch + 1 if epoch is None else epoch
        self._set_lrs()

    def _set_lrs(self):
        # Push freshly-computed learning rates into the optimizer's groups.
        values = self.get_lr()
        assert len(values) == len(self.optimizer.param_groups)

        for idx, (pg, new_lr) in enumerate(
            zip(self.optimizer.param_groups, values)
        ):
            pg["lr"] = new_lr
            self.print_lr(self.verbose, idx, new_lr)
        self._last_lr = [pg["lr"] for pg in self.optimizer.param_groups]

    def print_lr(self, is_verbose, group, lr):
        """Display the current learning rate."""
        if is_verbose:
            logging.info(
                f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
                f" of group {group} to {lr:.4e}."
            )
| 757 |
+
|
| 758 |
+
|
| 759 |
+
class Eden(LRScheduler):
    """
    Eden scheduler.
    The basic formula (before warmup) is:
      lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
                     (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
    where `warmup` increases linearly from 0.5 to 1 over `warmup_batches` batches
    and then stays constant at 1.

    E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam

    Args:
        optimizer: the optimizer to change the learning rates on
        lr_batches: the number of batches after which we start significantly
              decreasing the learning rate, suggest 5000.
        lr_epochs: the number of epochs after which we start significantly
              decreasing the learning rate, suggest 6 if you plan to do e.g.
              20 to 40 epochs, but may need smaller number if dataset is huge
              and you will do few epochs.
        warmup_batches: number of batches over which the warmup factor rises
              from 0.5 to 1.0.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        lr_batches: Union[int, float],
        lr_epochs: Union[int, float],
        warmup_batches: Union[int, float] = 500.0,
        verbose: bool = False,
    ):
        super(Eden, self).__init__(optimizer, verbose)
        self.lr_batches = lr_batches
        self.lr_epochs = lr_epochs
        self.warmup_batches = warmup_batches

    def get_lr(self):
        # Decay contribution from the batch counter: ~1 while
        # batch << lr_batches, then falls off as batch**-0.5.
        batch_decay = (
            (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2
        ) ** -0.25
        # Analogous decay contribution from the epoch counter.
        epoch_decay = (
            (self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2
        ) ** -0.25
        factor = batch_decay * epoch_decay

        # Linear warmup from 0.5 to 1.0 over the first warmup_batches batches.
        if self.batch >= self.warmup_batches:
            warmup_factor = 1.0
        else:
            warmup_factor = 0.5 + 0.5 * (self.batch / self.warmup_batches)

        return [x * factor * warmup_factor for x in self.base_lrs]
| 808 |
+
|
| 809 |
+
|
| 810 |
+
def _test_eden():
    """Smoke-test for the Eden scheduler driving a ScaledAdam optimizer:
    runs a few epochs of dummy gradient steps and logs the final lr/state."""
    m = torch.nn.Linear(100, 100)
    # ScaledAdam asserts that `parameters_names` is supplied (one List[str]
    # per param group, one name per parameter); the original call omitted it
    # and would fail with an AssertionError.  Derive the names from the
    # module, in the same order as m.parameters().
    names = [name for name, _ in m.named_parameters()]
    optim = ScaledAdam(m.parameters(), lr=0.03, parameters_names=[names])

    scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)

    for epoch in range(10):
        scheduler.step_epoch(epoch)  # sets epoch to `epoch`

        for step in range(20):
            # Dummy forward/backward: random input, random linear loss.
            x = torch.randn(200, 100).detach()
            x.requires_grad = True
            y = m(x)
            dy = torch.randn(200, 100).detach()
            f = (y * dy).sum()
            f.backward()

            optim.step()
            scheduler.step_batch()
            optim.zero_grad()

    logging.info(f"last lr = {scheduler.get_last_lr()}")
    logging.info(f"state dict = {scheduler.state_dict()}")
| 833 |
+
|
| 834 |
+
|
| 835 |
+
# This is included mostly as a baseline for ScaledAdam.
|
| 836 |
+
class Eve(Optimizer):
|
| 837 |
+
"""
|
| 838 |
+
Implements Eve algorithm. This is a modified version of AdamW with a special
|
| 839 |
+
way of setting the weight-decay / shrinkage-factor, which is designed to make the
|
| 840 |
+
rms of the parameters approach a particular target_rms (default: 0.1). This is
|
| 841 |
+
for use with networks with 'scaled' versions of modules (see scaling.py), which
|
| 842 |
+
will be close to invariant to the absolute scale on the parameter matrix.
|
| 843 |
+
|
| 844 |
+
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
|
| 845 |
+
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
|
| 846 |
+
Eve is unpublished so far.
|
| 847 |
+
|
| 848 |
+
Arguments:
|
| 849 |
+
params (iterable): iterable of parameters to optimize or dicts defining
|
| 850 |
+
parameter groups
|
| 851 |
+
lr (float, optional): learning rate (default: 1e-3)
|
| 852 |
+
betas (Tuple[float, float], optional): coefficients used for computing
|
| 853 |
+
running averages of gradient and its square (default: (0.9, 0.999))
|
| 854 |
+
eps (float, optional): term added to the denominator to improve
|
| 855 |
+
numerical stability (default: 1e-8)
|
| 856 |
+
weight_decay (float, optional): weight decay coefficient (default: 3e-4;
|
| 857 |
+
this value means that the weight would decay significantly after
|
| 858 |
+
about 3k minibatches. Is not multiplied by learning rate, but
|
| 859 |
+
is conditional on RMS-value of parameter being > target_rms.
|
| 860 |
+
target_rms (float, optional): target root-mean-square value of
|
| 861 |
+
parameters, if they fall below this we will stop applying weight decay.
|
| 862 |
+
|
| 863 |
+
|
| 864 |
+
.. _Adam: A Method for Stochastic Optimization:
|
| 865 |
+
https://arxiv.org/abs/1412.6980
|
| 866 |
+
.. _Decoupled Weight Decay Regularization:
|
| 867 |
+
https://arxiv.org/abs/1711.05101
|
| 868 |
+
.. _On the Convergence of Adam and Beyond:
|
| 869 |
+
https://openreview.net/forum?id=ryQu7f-RZ
|
| 870 |
+
"""
|
| 871 |
+
|
| 872 |
+
def __init__(
    self,
    params,
    lr=1e-3,
    betas=(0.9, 0.98),
    eps=1e-8,
    weight_decay=1e-3,
    target_rms=0.1,
):
    """Validate hyperparameters and delegate to the base Optimizer.

    Args:
        params: iterable of parameters to optimize, or dicts defining
            parameter groups.
        lr (float): learning rate.
        betas (Tuple[float, float]): EMA coefficients for the gradient
            and its square.  NOTE(review): the default here is
            (0.9, 0.98), not the (0.9, 0.999) quoted in the class
            docstring — confirm which is intended.
        eps (float): term added to the denominator for numerical
            stability.
        weight_decay (float): shrinkage factor applied per step while a
            parameter's RMS exceeds ``target_rms``; not multiplied by
            the learning rate (unlike AdamW).
        target_rms (float): target root-mean-square parameter value
            below which weight decay is suspended.

    Raises:
        ValueError: if any hyperparameter is outside its allowed range.
    """
    if not 0.0 <= lr:
        raise ValueError("Invalid learning rate: {}".format(lr))
    if not 0.0 <= eps:
        raise ValueError("Invalid epsilon value: {}".format(eps))
    if not 0.0 <= betas[0] < 1.0:
        raise ValueError(
            "Invalid beta parameter at index 0: {}".format(betas[0])
        )
    if not 0.0 <= betas[1] < 1.0:
        raise ValueError(
            "Invalid beta parameter at index 1: {}".format(betas[1])
        )
    # weight_decay is capped at 0.1 because it acts as a direct per-step
    # multiplicative shrinkage, not an lr-scaled decay.
    if not 0 <= weight_decay <= 0.1:
        raise ValueError(
            "Invalid weight_decay value: {}".format(weight_decay)
        )
    if not 0 < target_rms <= 10.0:
        raise ValueError("Invalid target_rms value: {}".format(target_rms))
    defaults = dict(
        lr=lr,
        betas=betas,
        eps=eps,
        weight_decay=weight_decay,
        target_rms=target_rms,
    )
    super(Eve, self).__init__(params, defaults)
|
| 907 |
+
|
| 908 |
+
def __setstate__(self, state):
    # No extra state beyond the base Optimizer's; delegate unpickling.
    super(Eve, self).__setstate__(state)
|
| 910 |
+
|
| 911 |
+
@torch.no_grad()
def step(self, closure=None):
    """Performs a single optimization step.

    Standard Adam moment updates, followed by the Eve twist: a
    conditional multiplicative shrinkage of each (non-scalar) parameter
    whenever its RMS exceeds ``target_rms``.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        # Re-enable grad inside the closure since step() runs under no_grad.
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue

            # Perform optimization step
            grad = p.grad
            if grad.is_sparse:
                raise RuntimeError(
                    "AdamW does not support sparse gradients"
                )

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = 0
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )

            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

            beta1, beta2 = group["betas"]

            state["step"] += 1
            # Standard Adam bias corrections for the EMA warm-up.
            bias_correction1 = 1 - beta1 ** state["step"]
            bias_correction2 = 1 - beta2 ** state["step"]

            # Decay the first and second moment running average coefficient
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
            # Bias-corrected second-moment denominator, as in Adam.
            denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_(
                group["eps"]
            )

            step_size = group["lr"] / bias_correction1
            target_rms = group["target_rms"]
            weight_decay = group["weight_decay"]

            if p.numel() > 1:
                # avoid applying this weight-decay on "scaling factors"
                # (which are scalar).
                is_above_target_rms = p.norm() > (
                    target_rms * (p.numel() ** 0.5)
                )
                # Shrink the parameter only while its RMS exceeds
                # target_rms.  Note: NOT scaled by the learning rate
                # (this is the "Eve" modification to AdamW).
                p.mul_(1 - (weight_decay * is_above_target_rms))

            p.addcdiv_(exp_avg, denom, value=-step_size)

            # if random.random() < 0.0005:
            #     step = (exp_avg / denom) * step_size
            #     logging.info(
            #         f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}"
            #     )

    return loss
|
| 986 |
+
|
| 987 |
+
|
| 988 |
+
def _test_scaled_adam(hidden_dim: int):
    """Smoke-test Eve (iter==0) vs ScaledAdam+Eden (iter==1) on a tiny MLP.

    Trains a 3-layer PReLU network to map randomly-scaled inputs to
    randomly-scaled targets and logs the smoothed loss for each setup.
    Purely a manual-inspection harness — no assertions.

    Args:
        hidden_dim: width of the two hidden layers.
    """
    import timeit

    from scaling import ScaledLinear

    E = 100  # feature dimension
    B = 4  # batch size
    T = 2  # sequence/"time" length
    logging.info("in test_eve_cain")
    # device = torch.device('cuda')
    device = torch.device("cpu")
    dtype = torch.float32

    fix_random_seed(42)
    # these input_magnitudes and output_magnitudes are to test that
    # Abel is working as we expect and is able to adjust scales of
    # different dims differently.
    input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
    output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()

    for iter in [1, 0]:
        # Re-seed so both optimizer setups see identical init and data.
        fix_random_seed(42)
        # iter 0: plain nn.Linear + Eve; iter 1: ScaledLinear + ScaledAdam.
        Linear = torch.nn.Linear if iter == 0 else ScaledLinear

        m = torch.nn.Sequential(
            Linear(E, hidden_dim),
            torch.nn.PReLU(),
            Linear(hidden_dim, hidden_dim),
            torch.nn.PReLU(),
            Linear(hidden_dim, E),
        ).to(device)

        train_pairs = [
            (
                100.0
                * torch.randn(B, T, E, device=device, dtype=dtype)
                * input_magnitudes,
                torch.randn(B, T, E, device=device, dtype=dtype)
                * output_magnitudes,
            )
            for _ in range(20)
        ]

        if iter == 0:
            optim = Eve(m.parameters(), lr=0.003)
        elif iter == 1:
            optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
        scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)

        start = timeit.default_timer()
        avg_loss = 0.0
        for epoch in range(180):
            scheduler.step_epoch()
            # if epoch == 100 and iter in [2,3]:
            #     optim.reset_speedup() # check it doesn't crash.

            # if epoch == 130:
            #     opts = diagnostics.TensorDiagnosticOptions(
            #         2 ** 22
            #     ) # allow 4 megabytes per sub-module
            #     diagnostic = diagnostics.attach_diagnostics(m, opts)

            for n, (x, y) in enumerate(train_pairs):
                y_out = m(x)
                loss = ((y_out - y) ** 2).mean() * 100.0
                if epoch == 0 and n == 0:
                    avg_loss = loss.item()
                else:
                    # Exponentially-smoothed loss purely for logging.
                    avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
                if n == 0 and epoch % 5 == 0:
                    # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
                    # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
                    # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
                    # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
                    # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
                    # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
                    # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
                    # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
                    lr = scheduler.get_last_lr()[0]
                    logging.info(
                        f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}"
                    )  # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
                # NOTE(review): backprops through log(loss), which rescales
                # gradients by 1/loss — presumably deliberate for this
                # comparison; confirm before reusing elsewhere.
                loss.log().backward()
                optim.step()
                optim.zero_grad()
                scheduler.step_batch()

            # diagnostic.print_diagnostics()

        stop = timeit.default_timer()
        logging.info(f"Iter={iter}, Time taken: {stop - start}")

        logging.info(f"last lr = {scheduler.get_last_lr()}")
        # logging.info("state dict = ", scheduler.state_dict())
        # logging.info("optim state_dict = ", optim.state_dict())
        logging.info(f"input_magnitudes = {input_magnitudes}")
        logging.info(f"output_magnitudes = {output_magnitudes}")
|
| 1085 |
+
|
| 1086 |
+
|
| 1087 |
+
if __name__ == "__main__":
    # Run single-threaded so timing comparisons between optimizers are fair.
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    logging.getLogger().setLevel(logging.INFO)
    import subprocess

    # Record the current git state in the log for reproducibility.
    # NOTE(review): shell=True with a fixed, non-user-supplied command
    # string; acceptable here, but never interpolate untrusted input.
    s = subprocess.check_output(
        "git status -uno .; git log -1; git diff HEAD .", shell=True
    )
    logging.info(s)
    import sys

    # Optional CLI argument: hidden dimension of the test network.
    if len(sys.argv) > 1:
        hidden_dim = int(sys.argv[1])
    else:
        hidden_dim = 200

    _test_scaled_adam(hidden_dim)
    _test_eden()
|
apps/audio_cloning/vallex/modules/scaling.py
ADDED
|
@@ -0,0 +1,1369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
|
| 2 |
+
#
|
| 3 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import logging
|
| 19 |
+
import math
|
| 20 |
+
import random
|
| 21 |
+
from typing import Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
from torch import Tensor
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class Transpose(nn.Identity):
    """Swap dims 1 and 2 of the input: (N, T, D) -> (N, D, T).

    NOTE(review): subclasses nn.Identity rather than nn.Module —
    presumably so tooling that special-cases Identity treats this as a
    parameter-free op; confirm before changing the base class.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Functional form of Tensor.transpose — identical semantics.
        return torch.transpose(input, 1, 2)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class ActivationBalancerFunction(torch.autograd.Function):
    """Identity forward; backward nudges gradients to steer activation stats.

    The forward pass returns ``x`` unchanged.  The backward pass subtracts
    ``|x_grad| * factor`` from the incoming gradient, where ``factor`` is
    built from a per-channel ``scale_factor`` (magnitude constraint) and an
    optional ``sign_factor`` (sign-proportion constraint), broadcast over
    the dims after ``channel_dim``.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        scale_factor: Tensor,
        sign_factor: Optional[Tensor],
        channel_dim: int,
    ) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        xgt0 = x > 0  # sign pattern, needed to direct the gradient nudge
        # Save a 2- or 3-tuple; backward distinguishes the two cases by
        # the number of saved tensors.
        if sign_factor is None:
            ctx.save_for_backward(xgt0, scale_factor)
        else:
            ctx.save_for_backward(xgt0, scale_factor, sign_factor)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
        if len(ctx.saved_tensors) == 3:
            xgt0, scale_factor, sign_factor = ctx.saved_tensors
            # Unsqueeze so per-channel factors broadcast over trailing dims.
            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                scale_factor = scale_factor.unsqueeze(-1)
                sign_factor = sign_factor.unsqueeze(-1)
            factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        else:
            xgt0, scale_factor = ctx.saved_tensors
            for _ in range(ctx.channel_dim, x_grad.ndim - 1):
                scale_factor = scale_factor.unsqueeze(-1)
            factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        # Perturbation is proportional to |x_grad| so it vanishes wherever
        # the true gradient vanishes.
        neg_delta_grad = x_grad.abs() * factor
        return (
            x_grad - neg_delta_grad,
            None,
            None,
            None,
        )
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _compute_scale_factor(
|
| 77 |
+
x: Tensor,
|
| 78 |
+
channel_dim: int,
|
| 79 |
+
min_abs: float,
|
| 80 |
+
max_abs: float,
|
| 81 |
+
gain_factor: float,
|
| 82 |
+
max_factor: float,
|
| 83 |
+
) -> Tensor:
|
| 84 |
+
if channel_dim < 0:
|
| 85 |
+
channel_dim += x.ndim
|
| 86 |
+
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
| 87 |
+
x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32)
|
| 88 |
+
|
| 89 |
+
if min_abs == 0.0:
|
| 90 |
+
below_threshold = 0.0
|
| 91 |
+
else:
|
| 92 |
+
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
|
| 93 |
+
# x_abs)_mean , min_abs.
|
| 94 |
+
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
|
| 95 |
+
min=0, max=max_factor
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
|
| 99 |
+
min=0, max=max_factor
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
return below_threshold - above_threshold
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _compute_sign_factor(
|
| 106 |
+
x: Tensor,
|
| 107 |
+
channel_dim: int,
|
| 108 |
+
min_positive: float,
|
| 109 |
+
max_positive: float,
|
| 110 |
+
gain_factor: float,
|
| 111 |
+
max_factor: float,
|
| 112 |
+
) -> Tensor:
|
| 113 |
+
if channel_dim < 0:
|
| 114 |
+
channel_dim += x.ndim
|
| 115 |
+
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
|
| 116 |
+
proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims)
|
| 117 |
+
if min_positive == 0.0:
|
| 118 |
+
factor1 = 0.0
|
| 119 |
+
else:
|
| 120 |
+
# 0 if proportion_positive >= min_positive, else can be
|
| 121 |
+
# as large as max_factor.
|
| 122 |
+
factor1 = (
|
| 123 |
+
(min_positive - proportion_positive) * (gain_factor / min_positive)
|
| 124 |
+
).clamp_(min=0, max=max_factor)
|
| 125 |
+
|
| 126 |
+
if max_positive == 1.0:
|
| 127 |
+
factor2 = 0.0
|
| 128 |
+
else:
|
| 129 |
+
# 0 if self.proportion_positive <= max_positive, else can be
|
| 130 |
+
# as large as -max_factor.
|
| 131 |
+
factor2 = (
|
| 132 |
+
(proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
|
| 133 |
+
).clamp_(min=0, max=max_factor)
|
| 134 |
+
sign_factor = factor1 - factor2
|
| 135 |
+
# require min_positive != 0 or max_positive != 1:
|
| 136 |
+
assert not isinstance(sign_factor, float)
|
| 137 |
+
return sign_factor
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class ActivationScaleBalancerFunction(torch.autograd.Function):
    """
    This object is used in class ActivationBalancer when the user specified
    min_positive=0, max_positive=1, so there are no constraints on the signs
    of the activations and only the absolute value has a constraint.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        sign_factor: Tensor,
        scale_factor: Tensor,
        channel_dim: int,
    ) -> Tensor:
        if channel_dim < 0:
            channel_dim += x.ndim
        ctx.channel_dim = channel_dim
        xgt0 = x > 0  # sign pattern, used to direct the backward nudge
        ctx.save_for_backward(xgt0, sign_factor, scale_factor)
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]:
        xgt0, sign_factor, scale_factor = ctx.saved_tensors
        # Unsqueeze so the per-channel factors broadcast over trailing dims.
        for _ in range(ctx.channel_dim, x_grad.ndim - 1):
            sign_factor = sign_factor.unsqueeze(-1)
            scale_factor = scale_factor.unsqueeze(-1)

        factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
        # Perturbation scales with |x_grad|, vanishing with the gradient.
        neg_delta_grad = x_grad.abs() * factor
        return (
            x_grad - neg_delta_grad,
            None,
            None,
            None,
        )
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class RandomClampFunction(torch.autograd.Function):
    """Element-wise clamp applied with probability ``prob`` per element.

    Gradient is straight-through where the output equals the input and
    zero where a clamp took effect; an optional ``reflect`` coefficient
    mixes a negated copy of the clamped-out part into both the output and
    the gradient.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        min: Optional[float],
        max: Optional[float],
        prob: float,
        reflect: float,
    ) -> Tensor:
        x_clamped = torch.clamp(x, min=min, max=max)
        # Each element is clamped independently with probability `prob`.
        mask = torch.rand_like(x) < prob
        ans = torch.where(mask, x_clamped, x)
        if x.requires_grad:
            # Gradient passes through only where nothing was clamped.
            ctx.save_for_backward(ans == x)
            ctx.reflect = reflect
        if reflect != 0.0:
            # ans -> (1 + r) * ans - r * x
            ans = ans * (1.0 + reflect) - (x * reflect)
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]:
        (is_same,) = ctx.saved_tensors
        x_grad = ans_grad * is_same.to(ans_grad.dtype)
        reflect = ctx.reflect
        if reflect != 0.0:
            # Mirror the forward-pass reflection in the gradient.
            x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect)
        return x_grad, None, None, None, None
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def random_clamp(
    x: Tensor,
    min: Optional[float] = None,
    max: Optional[float] = None,
    prob: float = 0.5,
    reflect: float = 0.0,
):
    """Functional wrapper around RandomClampFunction; see that class for
    the semantics of prob and reflect."""
    return RandomClampFunction.apply(x, min, max, prob, reflect)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor:
    """Cast to float16 with stochastic rounding of near-zero elements.

    Elements with ``|x| >= min_abs`` are cast directly.  Smaller elements
    become ``+/-min_abs`` with probability ``|x| / min_abs`` and 0
    otherwise, so the expected value of each element is preserved
    (avoiding systematic underflow to zero in half precision).
    """
    if x.dtype == torch.float16:
        return x  # already half precision; nothing to do
    magnitude = x.abs()
    tiny = magnitude < min_abs
    # For tiny elements: keep is True with probability |x| / min_abs,
    # so E[snapped] == x element-wise.
    keep = torch.rand_like(x) * min_abs < magnitude
    snapped = min_abs * x.sign() * keep
    return torch.where(tiny, snapped, x).to(torch.float16)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class RandomGradFunction(torch.autograd.Function):
    """
    Does nothing in forward pass; in backward pass, gets rid of very small grads using
    randomized approach that preserves expectations (intended to reduce roundoff).
    """

    @staticmethod
    def forward(ctx, x: Tensor, min_abs: float) -> Tensor:
        # Only the threshold is needed later; no tensors to save.
        ctx.min_abs = min_abs
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]:
        if ans_grad.dtype == torch.float16:
            # Stochastically round tiny grads while casting back to fp16,
            # preserving the expected value (see random_cast_to_half).
            return (
                random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs),
                None,
            )
        else:
            # Full-precision grads don't suffer the underflow problem.
            return ans_grad, None
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
class RandomGrad(torch.nn.Module):
    """
    Gets rid of very small gradients using an expectation-preserving method, intended to increase
    accuracy of training when using amp (automatic mixed precision)
    """

    def __init__(self, min_abs: float = 5.0e-06):
        super(RandomGrad, self).__init__()
        # Gradients with magnitude below this are stochastically rounded
        # in the backward pass (see RandomGradFunction).
        self.min_abs = min_abs

    def forward(self, x: Tensor):
        # No-op under TorchScript or in eval mode: the trick only matters
        # for fp16 gradients during training.
        if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
            return x
        else:
            return RandomGradFunction.apply(x, self.min_abs)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class SoftmaxFunction(torch.autograd.Function):
    """
    Tries to handle half-precision derivatives in a randomized way that should
    be more accurate for training than the default behavior.
    """

    @staticmethod
    def forward(ctx, x: Tensor, dim: int):
        ans = x.softmax(dim=dim)
        # if x dtype is float16, x.softmax() returns a float32 because
        # (presumably) that op does not support float16, and autocast
        # is enabled.
        if torch.is_autocast_enabled():
            ans = ans.to(torch.float16)
        # Only the output is needed for the softmax backward formula.
        ctx.save_for_backward(ans)
        # NOTE(review): x_dtype is stored but not used in backward —
        # confirm whether the grad should be cast back to it.
        ctx.x_dtype = x.dtype
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (ans,) = ctx.saved_tensors
        # Compute the Jacobian-vector product in float32 for accuracy,
        # even when autocast would otherwise use fp16.
        with torch.cuda.amp.autocast(enabled=False):
            ans_grad = ans_grad.to(torch.float32)
            ans = ans.to(torch.float32)
            x_grad = ans_grad * ans
            # Softmax backward: y * (g - sum(g * y)) along the softmax dim.
            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
            return x_grad, None
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def softmax(x: Tensor, dim: int):
    """Softmax whose backward pass runs in float32 (see SoftmaxFunction);
    falls back to the plain op under TorchScript scripting/tracing."""
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        return x.softmax(dim)

    return SoftmaxFunction.apply(x, dim)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class MaxEigLimiterFunction(torch.autograd.Function):
    """Identity forward; backward adds an auxiliary gradient that penalizes
    the proportion of activation variance captured by the supplied (top)
    eigen-direction, discouraging a single direction from dominating.
    """

    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        coeffs: Tensor,
        direction: Tensor,
        channel_dim: int,
        grad_scale: float,
    ) -> Tensor:
        ctx.channel_dim = channel_dim
        ctx.grad_scale = grad_scale
        # Detach so the auxiliary loss below doesn't extend the main graph.
        ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach())
        return x

    @staticmethod
    def backward(ctx, x_grad, *args):
        # Re-run a small differentiable computation to get the gradient of
        # the variance-proportion penalty w.r.t. the activations.
        with torch.enable_grad():
            (x_orig, coeffs, new_direction) = ctx.saved_tensors
            x_orig.requires_grad = True
            num_channels = x_orig.shape[ctx.channel_dim]
            # Flatten to (frames, channels) with the channel dim last.
            x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels)
            new_direction.requires_grad = False
            x = x - x.mean(dim=0)
            x_var = (x**2).mean()
            x_residual = x - coeffs * new_direction
            x_residual_var = (x_residual**2).mean()
            # `variance_proportion` is the proportion of the variance accounted for
            # by the top eigen-direction. This is to be minimized.
            variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)
            variance_proportion.backward()
        x_orig_grad = x_orig.grad
        # Rescale the auxiliary gradient relative to the main gradient's
        # norm so grad_scale controls its relative, not absolute, size.
        x_extra_grad = (
            x_orig.grad
            * ctx.grad_scale
            * x_grad.norm()
            / (x_orig_grad.norm() + 1.0e-20)
        )
        return x_grad + x_extra_grad.detach(), None, None, None, None
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
class BasicNorm(torch.nn.Module):
    """A cheaper replacement for LayerNorm, with no per-channel weight/bias.

    Motivation: pre-norm Transformers sometimes pin one feature dimension
    to a large constant (e.g. 50), effectively defeating LayerNorm because
    the output magnitude stops depending on the useful features.  Here that
    constant becomes an explicit (optionally learnable) "epsilon" ballast:

        output = x * (mean(x**2, channel_dim) + exp(eps)) ** -0.5

    where ``eps`` is stored in log space.  The name "eps" is kept to show
    the connection with conventional LayerNorm even though the value is
    typically large.

    Args:
        num_channels: number of channels, e.g. 512.
        channel_dim: axis of the channel dimension; negative values are
            interpreted as an offset from the input's ndim.  CAUTION: this
            is the axis index, NOT num_channels; typically one of
            {-2, -1, 0, 1, 2, 3}.
        eps: initial ballast value (stored as log(eps)).
        learn_eps: if True, eps is a trainable parameter; otherwise a
            fixed buffer.
        eps_min: lower clamp bound for log-eps (see forward).
        eps_max: upper clamp bound for log-eps.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int = -1,  # CAUTION: see documentation.
        eps: float = 0.25,
        learn_eps: bool = True,
        eps_min: float = -3.0,
        eps_max: float = 3.0,
    ) -> None:
        super(BasicNorm, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        # eps lives in log space so it stays positive after exp().
        log_eps = torch.tensor(eps).log().detach()
        if learn_eps:
            self.eps = nn.Parameter(log_eps)
        else:
            self.register_buffer("eps", log_eps)
        self.eps_min = eps_min
        self.eps_max = eps_max

    def forward(self, x: Tensor) -> Tensor:
        assert x.shape[self.channel_dim] == self.num_channels
        log_eps = self.eps
        # With probability 0.25 during training, clamp log-eps into
        # [eps_min, eps_max].  This makes out-of-range values noisy,
        # nudging the learned parameter back into the allowed region.
        if self.training and random.random() < 0.25:
            log_eps = log_eps.clamp(min=self.eps_min, max=self.eps_max)
        mean_sq = torch.mean(x**2, dim=self.channel_dim, keepdim=True)
        return x * (mean_sq + log_eps.exp()) ** -0.5
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
    """
    Behaves like a constructor of a modified version of nn.Linear
    that gives an easy way to set the default initial parameter scale.

    Args:
        Accepts the standard args and kwargs that nn.Linear accepts
        e.g. in_features, out_features, bias=False.

        initial_scale: you can override this if you want to increase
           or decrease the initial magnitude of the module's output
           (affects the initialization of weight_scale and bias_scale).
           Another option, if you want to do something like this, is
           to re-initialize the parameters.
    """
    linear = nn.Linear(*args, **kwargs)
    with torch.no_grad():
        # Shrink/grow the default-initialized weight in place.
        linear.weight.mul_(initial_scale)
        if linear.bias is not None:
            bound = 0.1 * initial_scale
            torch.nn.init.uniform_(linear.bias, -bound, bound)
    return linear
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def ScaledConv1d(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Conv1d:
    """
    Behaves like a constructor of a modified version of nn.Conv1d
    that gives an easy way to set the default initial parameter scale.

    Args:
        Accepts the standard args and kwargs that nn.Conv1d accepts
        e.g. in_channels, out_channels, bias=False.

        initial_scale: you can override this if you want to increase
           or decrease the initial magnitude of the module's output
           (affects the initialization of weight_scale and bias_scale).
           Another option, if you want to do something like this, is
           to re-initialize the parameters.
    """
    conv = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
    with torch.no_grad():
        # Shrink/grow the default-initialized kernel in place.
        conv.weight.mul_(initial_scale)
        if conv.bias is not None:
            bound = 0.1 * initial_scale
            torch.nn.init.uniform_(conv.bias, -bound, bound)
    return conv
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def TransposeScaledConv1d(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Sequential:
    """
    Transpose -> ScaledConv1d
    """
    conv = ScaledConv1d(
        *args,
        initial_scale=initial_scale,
        kernel_size=kernel_size,
        padding=padding,
        **kwargs,
    )
    return nn.Sequential(Transpose(), conv)
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
def ScaledConv1dTranspose(
    *args,
    initial_scale: float = 1.0,
    kernel_size: int = 3,
    padding: str = "same",
    **kwargs,
) -> nn.Sequential:
    """
    ScaledConv1d -> Transpose
    """
    # Note: the original docstring said "Transpose -> ScaledConv1d", but the
    # sequence below applies the convolution first and transposes afterwards.
    return nn.Sequential(
        ScaledConv1d(
            *args,
            initial_scale=initial_scale,
            kernel_size=kernel_size,
            padding=padding,
            **kwargs,
        ),
        Transpose(),
    )
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def TransposeConv1d(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    Transpose -> Conv1d
    """
    conv = nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
    return nn.Sequential(Transpose(), conv)
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
def Conv1dTranspose(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    Conv1d -> Transpose
    """
    # The original docstring said "ScaledConv1d -> Transpose", but this factory
    # builds a plain nn.Conv1d (no initial_scale handling) followed by Transpose.
    return nn.Sequential(
        nn.Conv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs),
        Transpose(),
    )
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
class SRLinear(nn.Linear):
    """https://arxiv.org/abs/2303.06296
    Stabilizing Transformer Training by Preventing Attention Entropy Collapse

    A linear layer whose effective weight is (sigma / spectral_norm(W)) * W,
    where the spectral norm is estimated by one power-method step per call and
    sigma is a learnable scalar.
    """

    def __init__(self, in_features, out_features, bias=True, **kwargs):
        super().__init__(in_features, out_features, bias=bias, **kwargs)
        # `u` is the running estimate of the top right-singular vector of the
        # weight, used by the power method in get_sigma().
        self.register_buffer(
            "u", nn.functional.normalize(torch.randn(in_features), dim=0)
        )
        with torch.no_grad():
            sigma = self.get_sigma()
            # `spectral_norm` records the latest spectral-norm estimate; it is
            # refreshed during training in get_weight().
            self.register_buffer("spectral_norm", sigma)
        # Learnable target scale for the reparameterized weight.
        self.sigma = nn.Parameter(torch.ones(1))

    def get_sigma(self):
        # One power-method iteration: v ~ left singular vector, u ~ right
        # singular vector; `u` is updated in place (no grad through it).
        with torch.no_grad():
            u = self.u
            v = self.weight.mv(u)
            v = nn.functional.normalize(v, dim=0)
            u = self.weight.T.mv(v)
            u = nn.functional.normalize(u, dim=0)
            self.u.data.copy_(u)
        # v^T W u: the spectral-norm estimate (differentiable w.r.t. weight).
        return torch.einsum("c,cd,d->", v, self.weight, u)

    def get_weight(self):
        sigma = self.get_sigma()
        if self.training:
            # Keep the buffered estimate in sync while training.
            self.spectral_norm.data.copy_(sigma)
        # Rescale so the weight's spectral norm equals the learnable `sigma`.
        weight = (self.sigma / sigma) * self.weight
        return weight

    def forward(self, x):
        return nn.functional.linear(x, self.get_weight(), self.bias)
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
class SRConv1d(SRLinear):
    """1-D convolution with the spectrally-reparameterized weight of SRLinear.

    The conv kernel is stored flattened inside the parent's weight matrix, so
    the parent's ``in_features`` is ``in_features * kernel_size``.
    """

    def __init__(
        self,
        in_features,
        out_features,
        kernel_size,
        stride: int = 1,
        padding: str = "same",
        bias: bool = True,
        **kwargs,
    ):
        in_features = in_features * kernel_size
        super().__init__(in_features, out_features, bias=bias, **kwargs)
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        # Recover the conv kernel shape from the flattened parent weight.
        in_channels = self.in_features // self.kernel_size
        kernel = self.get_weight().view(
            self.out_features, in_channels, self.kernel_size
        )
        return nn.functional.conv1d(
            x, kernel, bias=self.bias, stride=self.stride, padding=self.padding
        )
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
def TransposeSRConv1d(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    Transpose -> SRConv1d
    """
    conv = SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
    return nn.Sequential(Transpose(), conv)
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
def SRConv1dTranspose(
    *args, kernel_size: int = 3, padding: str = "same", **kwargs
) -> nn.Sequential:
    """
    SRConv1d -> Transpose
    """
    conv = SRConv1d(*args, kernel_size=kernel_size, padding=padding, **kwargs)
    return nn.Sequential(conv, Transpose())
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
class ActivationBalancer(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to encourage, for
    each channel, that it is positive at least a proportion `threshold` of the
    time.  It does this by multiplying negative derivative values by up to
    (1+max_factor), and positive derivative values by up to (1-max_factor),
    interpolated from 1 at the threshold to those extremal values when none
    of the inputs are positive.

    Args:
           num_channels: the number of channels
           channel_dim: the dimension/axis corresponding to the channel, e.g.
               -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
           min_positive: the minimum, per channel, of the proportion of the time
               that (x > 0), below which we start to modify the derivatives.
           max_positive: the maximum, per channel, of the proportion of the time
               that (x > 0), above which we start to modify the derivatives.
           max_factor: the maximum factor by which we modify the derivatives for
              either the sign constraint or the magnitude constraint;
              e.g. with max_factor=0.02, the derivatives would be multiplied by
              values in the range [0.98..1.02].
           sign_gain_factor: determines the 'gain' with which we increase the
              change in gradient once the constraints on min_positive and max_positive
              are violated.
           scale_gain_factor: determines the 'gain' with which we increase the
              change in gradient once the constraints on min_abs and max_abs
              are violated.
           min_abs:  the minimum average-absolute-value difference from the mean
               value per channel, which we allow, before we start to modify
               the derivatives to prevent this.
           max_abs:  the maximum average-absolute-value difference from the mean
               value per channel, which we allow, before we start to modify
               the derivatives to prevent this.
          min_prob: determines the minimum probability with which we modify the
             gradients for the {min,max}_positive and {min,max}_abs constraints,
             on each forward().  This is done randomly to prevent all layers
             from doing it at the same time.  Early in training we may use
             higher probabilities than this; it will decay to this value.
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int,
        min_positive: float = 0.05,
        max_positive: float = 0.95,
        max_factor: float = 0.04,
        sign_gain_factor: float = 0.01,
        scale_gain_factor: float = 0.02,
        min_abs: float = 0.2,
        max_abs: float = 100.0,
        min_prob: float = 0.1,
    ):
        super(ActivationBalancer, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.min_positive = min_positive
        self.max_positive = max_positive
        self.max_factor = max_factor
        self.min_abs = min_abs
        self.max_abs = max_abs
        self.min_prob = min_prob
        self.sign_gain_factor = sign_gain_factor
        self.scale_gain_factor = scale_gain_factor

        # count measures how many times the forward() function has been called.
        # We occasionally sync this to a tensor called `count`, that exists to
        # make sure it is synced to disk when we load and save the model.
        self.cpu_count = 0
        self.register_buffer("count", torch.tensor(0, dtype=torch.int64))

    def forward(self, x: Tensor) -> Tensor:
        if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
            return _no_op(x)

        count = self.cpu_count
        self.cpu_count += 1

        if random.random() < 0.01:
            # Occasionally sync self.cpu_count with self.count.
            # count affects the decay of 'prob'.  don't do this on every iter,
            # because syncing with the GPU is slow.
            self.cpu_count = max(self.cpu_count, self.count.item())
            self.count.fill_(self.cpu_count)

        # the prob of doing some work exponentially decreases from 0.5 till it hits
        # a floor at min_prob (==0.1, by default)
        prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))

        if random.random() < prob:
            # (a dead local assignment `sign_gain_factor = 0.5` was removed
            # here; self.sign_gain_factor is what is actually used below.)
            if self.min_positive != 0.0 or self.max_positive != 1.0:
                sign_factor = _compute_sign_factor(
                    x,
                    self.channel_dim,
                    self.min_positive,
                    self.max_positive,
                    gain_factor=self.sign_gain_factor / prob,
                    max_factor=self.max_factor,
                )
            else:
                sign_factor = None

            scale_factor = _compute_scale_factor(
                x.detach(),
                self.channel_dim,
                min_abs=self.min_abs,
                max_abs=self.max_abs,
                gain_factor=self.scale_gain_factor / prob,
                max_factor=self.max_factor,
            )
            return ActivationBalancerFunction.apply(
                x,
                scale_factor,
                sign_factor,
                self.channel_dim,
            )
        else:
            return _no_op(x)
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor:
    """
    Returns x unmodified, but in backprop will put a penalty for the excess of
    the absolute values of elements of x over the limit "limit".  E.g. if
    limit == 10.0, then if x has any values over 10 it will get a penalty.

    Caution: the value of this penalty will be affected by grad scaling used
    in automatic mixed precision training.  For this reasons we use this,
    it shouldn't really matter, or may even be helpful; we just use this
    to disallow really implausible values of scores to be given to softmax.
    """
    sign = x.sign()
    exceeds_limit = (x.abs() - limit) > 0
    # Memory-efficient way to penalize |x| over the limit: autograd only needs
    # to cache the cheap int8 mask, not extra float intermediates.  The value
    # of aux_loss is larger than the true penalty by limit * exceeds_limit.sum(),
    # but its derivative matches the real aux_loss, penalty * (x.abs() - limit).relu().
    aux_loss = penalty * ((sign * exceeds_limit).to(torch.int8) * x)
    # No explicit sum() needed: with_loss() behaves as if aux_loss.sum() were
    # added to the total loss.  The returned x must actually be used by the
    # caller, or this is ineffective.
    return with_loss(x, aux_loss)
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims.
|
| 776 |
+
if x.ndim == 2:
|
| 777 |
+
return x.diag()
|
| 778 |
+
else:
|
| 779 |
+
(batch, dim, dim) = x.shape
|
| 780 |
+
x = x.reshape(batch, dim * dim)
|
| 781 |
+
x = x[:, :: dim + 1]
|
| 782 |
+
assert x.shape == (batch, dim)
|
| 783 |
+
return x
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
def _whitening_metric(x: Tensor, num_groups: int):
|
| 787 |
+
"""
|
| 788 |
+
Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of
|
| 789 |
+
of the centered feature covariance are the same within each group's covariance matrix
|
| 790 |
+
and also between groups.
|
| 791 |
+
Args:
|
| 792 |
+
x: a Tensor of shape (*, num_channels)
|
| 793 |
+
num_groups: the number of groups of channels, a number >=1 that divides num_channels
|
| 794 |
+
Returns:
|
| 795 |
+
Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and
|
| 796 |
+
greater than 1.0 otherwise.
|
| 797 |
+
"""
|
| 798 |
+
assert x.dtype != torch.float16
|
| 799 |
+
x = x.reshape(-1, x.shape[-1])
|
| 800 |
+
(num_frames, num_channels) = x.shape
|
| 801 |
+
assert num_channels % num_groups == 0
|
| 802 |
+
channels_per_group = num_channels // num_groups
|
| 803 |
+
x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1)
|
| 804 |
+
# x now has shape (num_groups, num_frames, channels_per_group)
|
| 805 |
+
# subtract the mean so we use the centered, not uncentered, covariance.
|
| 806 |
+
# My experience has been that when we "mess with the gradients" like this,
|
| 807 |
+
# it's better not do anything that tries to move the mean around, because
|
| 808 |
+
# that can easily cause instability.
|
| 809 |
+
x = x - x.mean(dim=1, keepdim=True)
|
| 810 |
+
# x_covar: (num_groups, channels_per_group, channels_per_group)
|
| 811 |
+
x_covar = torch.matmul(x.transpose(1, 2), x)
|
| 812 |
+
x_covar_mean_diag = _diag(x_covar).mean()
|
| 813 |
+
# the following expression is what we'd get if we took the matrix product
|
| 814 |
+
# of each covariance and measured the mean of its trace, i.e.
|
| 815 |
+
# the same as _diag(torch.matmul(x_covar, x_covar)).mean().
|
| 816 |
+
x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group)
|
| 817 |
+
# this metric will be >= 1.0; the larger it is, the less 'white' the data was.
|
| 818 |
+
metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20)
|
| 819 |
+
return metric
|
| 820 |
+
|
| 821 |
+
|
| 822 |
+
class WhiteningPenaltyFunction(torch.autograd.Function):
    # Identity in the forward pass; in backward, adds a gradient term that
    # penalizes the whitening metric of the saved activations when it exceeds
    # `whitening_limit`, rescaled to `grad_scale` times the incoming gradient norm.
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        num_groups: int,
        whitening_limit: float,
        grad_scale: float,
    ) -> Tensor:
        ctx.save_for_backward(x)
        ctx.num_groups = num_groups
        ctx.whitening_limit = whitening_limit
        ctx.grad_scale = grad_scale
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        (x_orig,) = ctx.saved_tensors
        with torch.enable_grad():
            # Compute the penalty in float32 regardless of AMP settings.
            with torch.cuda.amp.autocast(enabled=False):
                x_detached = x_orig.to(torch.float32).detach()
                x_detached.requires_grad = True

                metric = _whitening_metric(x_detached, ctx.num_groups)

                # Occasional diagnostic logging (always when run as a script).
                if random.random() < 0.005 or __name__ == "__main__":
                    logging.info(
                        f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, "
                        f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}"
                    )

                # Penalty is active only above the limit (relu); its gradient
                # lands in x_detached.grad.
                (metric - ctx.whitening_limit).relu().backward()
                penalty_grad = x_detached.grad
                # Normalize the penalty gradient to grad_scale times the norm
                # of the incoming gradient, so its influence is controlled.
                scale = ctx.grad_scale * (
                    x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20)
                )
                penalty_grad = penalty_grad * scale
        return x_grad + penalty_grad.to(x_grad.dtype), None, None, None
|
| 860 |
+
|
| 861 |
+
|
| 862 |
+
class Whiten(nn.Module):
    def __init__(
        self,
        num_groups: int,
        whitening_limit: float,
        prob: Union[float, Tuple[float, float]],
        grad_scale: float,
    ):
        """
        Args:
          num_groups: the number of groups to divide the channel dim into before
            whitening.  We will attempt to make the feature covariance
            within each group, after mean subtraction, as "white" as possible,
            while having the same trace across all groups.
         whitening_limit: a value greater than 1.0, that dictates how much
           freedom we have to violate the constraints.  1.0 would mean perfectly
           white, with exactly the same trace across groups; larger values
           give more freedom.  E.g. 2.0.
         prob: the probability with which we apply the gradient modification
           (also affects the grad scale).  May be supplied as a float,
           or as a pair (min_prob, max_prob)

          grad_scale: determines the scale on the gradient term from this object,
            relative to the rest of the gradient on the attention weights.
            E.g. 0.02 (you may want to use smaller values than this if prob is large)
        """
        super(Whiten, self).__init__()
        assert num_groups >= 1
        assert whitening_limit >= 1
        assert grad_scale >= 0
        self.num_groups = num_groups
        self.whitening_limit = whitening_limit
        if isinstance(prob, float):
            assert 0 < prob <= 1
            self.prob = prob
        else:
            (self.min_prob, self.max_prob) = prob
            assert 0 < self.min_prob < self.max_prob <= 1
            # Start at the high end; forward() adapts self.prob between the
            # two bounds based on whether the constraint is being violated.
            self.prob = self.max_prob

        self.grad_scale = grad_scale

    def forward(self, x: Tensor) -> Tensor:
        """
        In the forward pass, this function just returns the input unmodified.
        In the backward pass, it will modify the gradients to ensure that the
        distribution in each group has close to (lambda times I) as the covariance
        after mean subtraction, with the same lambda across groups.
        For whitening_limit > 1, there will be more freedom to violate this
        constraint.

        Args:
          x: the input of shape (*, num_channels)

        Returns:
            x, unmodified.   You should make sure
            you use the returned value, or the graph will be freed
            and nothing will happen in backprop.
        """
        if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
            return _no_op(x)
        else:
            if hasattr(self, "min_prob") and random.random() < 0.25:
                # occasionally switch between min_prob and max_prob, based on whether
                # we are above or below the threshold.
                if (
                    _whitening_metric(x.to(torch.float32), self.num_groups)
                    > self.whitening_limit
                ):
                    # there would be a change to the grad.
                    self.prob = self.max_prob
                else:
                    self.prob = self.min_prob

            return WhiteningPenaltyFunction.apply(
                x, self.num_groups, self.whitening_limit, self.grad_scale
            )
|
| 939 |
+
|
| 940 |
+
|
| 941 |
+
class WithLoss(torch.autograd.Function):
    """Identity on `x` in forward; in backward, `y` additionally receives an
    all-ones gradient, as if y.sum() had been added to the loss."""

    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor):
        # Only y's shape is needed to synthesize its gradient later.
        ctx.y_shape = y.shape
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        y_grad = torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device)
        return ans_grad, y_grad
|
| 952 |
+
|
| 953 |
+
|
| 954 |
+
def with_loss(x, y):
    """Return x, but behave as if y.sum() were added to the loss in backprop."""
    if torch.jit.is_scripting() or torch.jit.is_tracing():
        # Custom autograd functions are not supported under scripting/tracing.
        return x
    return WithLoss.apply(x, y)
|
| 959 |
+
|
| 960 |
+
|
| 961 |
+
def _no_op(x: Tensor) -> Tensor:
|
| 962 |
+
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
| 963 |
+
return x
|
| 964 |
+
else:
|
| 965 |
+
# a no-op function that will have a node in the autograd graph,
|
| 966 |
+
# to avoid certain bugs relating to backward hooks
|
| 967 |
+
return x.chunk(1, dim=-1)[0]
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
class Identity(torch.nn.Module):
    """Module wrapper around `_no_op`: returns its input unchanged while
    keeping a node in the autograd graph."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return _no_op(x)
|
| 976 |
+
|
| 977 |
+
|
| 978 |
+
class MaxEig(torch.nn.Module):
    """
    Modifies the backpropped derivatives of a function to try to discourage
    that any given direction in activation space accounts for more than
    a specified proportion of the covariance (e.g. 0.2).


    Args:
           num_channels: the number of channels
           channel_dim: the dimension/axis corresponding to the channel, e.g.
               -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative.
           max_var_per_eig:  the maximum proportion of the variance of the
               features/channels, after mean subtraction, that can come from
               any given eigenvalue.
           min_prob: the minimum probability with which we apply this during any invocation
               of forward(), assuming last time we applied the constraint it was
               not active; supplied for speed.
           scale: determines the scale with which we modify the gradients, relative
               to the existing / unmodified gradients
    """

    def __init__(
        self,
        num_channels: int,
        channel_dim: int,
        max_var_per_eig: float = 0.2,
        min_prob: float = 0.01,
        scale: float = 0.01,
    ):
        super(MaxEig, self).__init__()
        self.num_channels = num_channels
        self.channel_dim = channel_dim
        self.scale = scale
        assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels
        self.max_var_per_eig = max_var_per_eig

        # we figure out the dominant direction using the power method: starting with
        # a random vector, keep multiplying by the covariance and renormalizing.
        with torch.no_grad():
            # arbitrary.. would use randn() but want to leave the rest of the model's
            # random parameters unchanged for comparison
            direction = torch.arange(num_channels).to(torch.float)
            direction = direction / direction.norm()
            self.register_buffer("max_eig_direction", direction)

        self.min_prob = min_prob
        # cur_prob is the current probability we'll use to apply the ActivationBalancer.
        # We'll regress this towards prob, each time we try to apply it and it is not
        # active.
        self.cur_prob = 1.0

    def forward(self, x: Tensor) -> Tensor:
        if (
            torch.jit.is_scripting()
            or self.max_var_per_eig <= 0
            or random.random() > self.cur_prob
            or torch.jit.is_tracing()
        ):
            return _no_op(x)

        with torch.cuda.amp.autocast(enabled=False):
            eps = 1.0e-20
            orig_x = x
            x = x.to(torch.float32)
            with torch.no_grad():
                x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels)
                x = x - x.mean(dim=0)
                new_direction, coeffs = self._find_direction_coeffs(
                    x, self.max_eig_direction
                )
                x_var = (x**2).mean()
                x_residual = x - coeffs * new_direction
                x_residual_var = (x_residual**2).mean()

                # `variance_proportion` is the proportion of the variance accounted for
                # by the top eigen-direction.
                variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20)

                # ensure new direction is nonzero even if x == 0, by including `direction`.
                self._set_direction(0.1 * self.max_eig_direction + new_direction)

            if random.random() < 0.01 or __name__ == "__main__":
                logging.info(
                    f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}"
                )

            if variance_proportion >= self.max_var_per_eig:
                # The constraint is active.  Note, we should quite rarely
                # reach here, only near the beginning of training if we are
                # starting to diverge, should this constraint be active.
                cur_prob = self.cur_prob
                self.cur_prob = 1.0  # next time, do the update with probability 1.0.
                return MaxEigLimiterFunction.apply(
                    orig_x, coeffs, new_direction, self.channel_dim, self.scale
                )
            else:
                # let self.cur_prob exponentially approach self.min_prob, as
                # long as the constraint is inactive.
                self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob
                return orig_x

    def _set_direction(self, direction: Tensor):
        """
        Sets self.max_eig_direction to a normalized version of `direction`
        """
        direction = direction.detach()
        direction = direction / direction.norm()
        direction_sum = direction.sum().item()
        if direction_sum - direction_sum == 0:  # no inf/nan
            self.max_eig_direction[:] = direction
        else:
            # BUGFIX: the second literal was missing its f-prefix, so the
            # placeholders were logged verbatim instead of being interpolated.
            logging.info(
                f"Warning: sum of direction in MaxEig is {direction_sum}, "
                f"num_channels={self.num_channels}, channel_dim={self.channel_dim}"
            )

    def _find_direction_coeffs(
        self, x: Tensor, prev_direction: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # NOTE: return annotation corrected from Tuple[Tensor, Tensor, Tensor];
        # the function returns exactly two tensors.
        """
        Figure out (an approximation to) the proportion of the variance of a set of
        feature vectors that can be attributed to the top eigen-direction.
        Args:
         x: a Tensor of shape (num_frames, num_channels), with num_frames > 1.
         prev_direction:  a Tensor of shape (num_channels,), that is our previous estimate
               of the top eigen-direction, or a random direction if this is the first
               iteration.  Does not have to be normalized, but should be nonzero.

        Returns: (cur_direction, coeffs), where:
             cur_direction: a Tensor of shape (num_channels,) that is the current
                estimate of the top eigen-direction.
             coeffs: a Tensor of shape (num_frames, 1) that minimizes, or
                approximately minimizes, (x - coeffs * cur_direction).norm()
        """
        (num_frames, num_channels) = x.shape
        assert num_channels > 1 and num_frames > 1
        assert prev_direction.shape == (num_channels,)
        # `coeffs` are the coefficients of `prev_direction` in x.
        # actually represent the coeffs up to a constant positive factor.
        coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10
        cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20)
        return cur_direction, coeffs
|
| 1120 |
+
|
| 1121 |
+
|
| 1122 |
+
class DoubleSwishFunction(torch.autograd.Function):
    """
    double_swish(x) = x * torch.sigmoid(x-1)
    This is a definition, originally motivated by its close numerical
    similarity to swish(swish(x)), where swish(x) = x * sigmoid(x).

    Memory-efficient derivative computation:
     double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1)
     double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x).
     Now, s'(x) = s(x) * (1-s(x)).
     double_swish'(x) =  x * s'(x) + s(x).
                      =  x * s(x) * (1-s(x)) + s(x).
                      = double_swish(x) * (1-s(x)) + s(x)
     ... so we just need to remember s(x) but not x itself.

    The derivative is stored quantized to uint8 (with dithering so the
    quantization is expectation-preserving), which halves activation memory.
    """

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        requires_grad = x.requires_grad
        # Remember the ORIGINAL dtype: we may upcast fp16 inputs below, and
        # the final cast-back check must use this saved value.
        x_dtype = x.dtype
        if x.dtype == torch.float16:
            x = x.to(torch.float32)

        s = torch.sigmoid(x - 1.0)
        y = x * s

        if requires_grad:
            deriv = y * (1 - s) + s
            # notes on derivative of x * sigmoid(x - 1):
            # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29
            # min \simeq -0.043638.  Take floor as -0.043637 so it's a lower bound.
            # max \simeq 1.1990.   Take ceil to be 1.2 so it's an upper bound.
            # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which
            # floors), should be expectation-preserving.
            floor = -0.043637
            ceil = 1.2
            d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
                deriv
            )
            if __name__ == "__main__":
                # for self-testing only.
                assert d_scaled.min() >= 0.0
                assert d_scaled.max() < 256.0
            d_int = d_scaled.to(torch.uint8)
            ctx.save_for_backward(d_int)
        # Fix: compare against the original dtype.  The previous check used
        # `x.dtype`, which can never be float16 here because `x` was cast to
        # float32 above — so fp16 inputs were silently returned as fp32
        # (outside autocast).
        if x_dtype == torch.float16 or torch.is_autocast_enabled():
            y = y.to(torch.float16)
        return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        (d,) = ctx.saved_tensors
        # Dequantize using the same constants as the forward pass.
        floor = -0.043637
        ceil = 1.2
        d = d * ((ceil - floor) / 255.0) + floor
        return y_grad * d
|
| 1179 |
+
|
| 1180 |
+
|
| 1181 |
+
class DoubleSwish(torch.nn.Module):
    """Double-swish activation: x * sigmoid(x - 1), a close approximation
    to Swish(Swish(x))."""

    def forward(self, x: Tensor) -> Tensor:
        """Return double-swish activation function which is an approximation to Swish(Swish(x)),
        that we approximate closely with x * sigmoid(x-1).
        """
        scripted = torch.jit.is_scripting() or torch.jit.is_tracing()
        if not scripted:
            # Eager mode: use the memory-efficient custom autograd Function,
            # which stores a quantized derivative instead of the input.
            return DoubleSwishFunction.apply(x)
        # Scripting/tracing cannot go through a custom Function; use the
        # plain formula instead.
        return x * torch.sigmoid(x - 1.0)
|
| 1189 |
+
|
| 1190 |
+
|
| 1191 |
+
def BalancedDoubleSwish(
    d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
    """
    ActivationBalancer -> DoubleSwish

    Returns a Sequential that first balances activations (keeping their
    magnitudes within `max_abs`), then applies the double-swish nonlinearity.
    """
    return nn.Sequential(
        ActivationBalancer(
            d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
        ),
        DoubleSwish(),
    )
|
| 1204 |
+
|
| 1205 |
+
|
| 1206 |
+
def _test_max_eig():
    """Self-test for MaxEig: when the injected rank-1 component is weak the
    gradient should pass through (almost) unchanged; when it dominates, the
    gradient must be modified."""
    num_channels = 128
    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"proportion = {proportion}")
        # Random features plus a rank-1 component of controllable strength.
        x = torch.randn(100, num_channels)
        direction = torch.randn(num_channels)
        coeffs = torch.randn(100, 1)
        x = x + proportion * direction * coeffs
        x.requires_grad = True

        m = MaxEig(
            num_channels,
            1,    # channel_dim
            0.5,  # max_var_per_eig
            scale=0.1,  # grad_scale
        )

        # Run a few forward passes so the direction estimate converges.
        for _ in range(4):
            y = m(x)

        y_grad = torch.randn_like(x)
        y.backward(gradient=y_grad)

        if proportion < 0.2:
            assert torch.allclose(x.grad, y_grad, atol=1.0e-02)
        elif proportion > 1.0:
            assert not torch.allclose(x.grad, y_grad)
|
| 1234 |
+
|
| 1235 |
+
|
| 1236 |
+
def _test_whiten():
    """Self-test for Whiten: gradients pass through unchanged for weak
    rank-1 contamination and are modified once it exceeds the whitening
    limit."""
    for proportion in [0.1, 0.5, 10.0]:
        logging.info(f"_test_whiten(): proportion = {proportion}")
        # Random features plus a rank-1 component of controllable strength.
        x = torch.randn(100, 128)
        direction = torch.randn(128)
        coeffs = torch.randn(100, 1)
        x = x + proportion * direction * coeffs
        x.requires_grad = True

        m = Whiten(
            1,    # num_groups
            5.0,  # whitening_limit
            prob=1.0,
            grad_scale=0.1,  # grad_scale
        )

        for _ in range(4):
            y = m(x)

        y_grad = torch.randn_like(x)
        y.backward(gradient=y_grad)

        if proportion < 0.2:
            assert torch.allclose(x.grad, y_grad)
        elif proportion > 1.0:
            assert not torch.allclose(x.grad, y_grad)
|
| 1264 |
+
|
| 1265 |
+
|
| 1266 |
+
def _test_activation_balancer_sign():
    """Self-test: exercise ActivationBalancer's sign (min/max_positive)
    penalty and print the resulting gradients for manual inspection."""
    probs = torch.arange(0, 1, 0.01)
    N = 1000
    # Row i holds +/-1 values that are positive with probability probs[i].
    x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
    x = x.detach()
    x.requires_grad = True
    m = ActivationBalancer(
        probs.numel(),
        channel_dim=0,
        min_positive=0.05,
        max_positive=0.95,
        max_factor=0.2,
        min_abs=0.0,
    )

    y_grad = torch.sign(torch.randn(probs.numel(), N))
    y = m(x)
    y.backward(gradient=y_grad)

    print("_test_activation_balancer_sign: x = ", x)
    print("_test_activation_balancer_sign: y grad = ", y_grad)
    print("_test_activation_balancer_sign: x grad = ", x.grad)
|
| 1288 |
+
|
| 1289 |
+
|
| 1290 |
+
def _test_activation_balancer_magnitude():
    """Self-test: exercise ActivationBalancer's magnitude (min_abs/max_abs)
    penalty and print the resulting gradients for manual inspection."""
    magnitudes = torch.arange(0, 1, 0.01)
    N = 1000
    # Row i holds values of magnitude magnitudes[i] with random signs.
    x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
    x = x.detach()
    x.requires_grad = True
    m = ActivationBalancer(
        magnitudes.numel(),
        channel_dim=0,
        min_positive=0.0,
        max_positive=1.0,
        max_factor=0.2,
        min_abs=0.2,
        max_abs=0.8,
        min_prob=1.0,
    )

    y_grad = torch.sign(torch.randn(magnitudes.numel(), N))
    y = m(x)
    y.backward(gradient=y_grad)

    print("_test_activation_balancer_magnitude: x = ", x)
    print("_test_activation_balancer_magnitude: y grad = ", y_grad)
    print("_test_activation_balancer_magnitude: x grad = ", x.grad)
|
| 1314 |
+
|
| 1315 |
+
|
| 1316 |
+
def _test_basic_norm():
    """Self-test: BasicNorm should shrink the RMS of random input, but by
    less than a factor of two."""
    num_channels = 128
    m = BasicNorm(num_channels=num_channels, channel_dim=1)

    x = torch.randn(500, num_channels)
    y = m(x)
    assert y.shape == x.shape

    x_rms = (x**2).mean().sqrt()
    y_rms = (y**2).mean().sqrt()
    print("x rms = ", x_rms)
    print("y rms = ", y_rms)
    assert y_rms < x_rms
    assert y_rms > 0.5 * x_rms
|
| 1331 |
+
|
| 1332 |
+
|
| 1333 |
+
def _test_double_swish_deriv():
    """Self-test: gradcheck DoubleSwish.  The tolerance equals the uint8
    quantization step used for the stored derivative."""
    m = DoubleSwish()
    x = torch.randn(10, 12, dtype=torch.double) * 3.0
    x.requires_grad = True
    quant_step = (1.2 - (-0.043637)) / 255.0
    torch.autograd.gradcheck(m, x, atol=quant_step)

    # A larger input exercises the in-forward self-test assertions
    # (d_scaled range checks) across a wide range of values.
    x = torch.randn(1000, 1000, dtype=torch.double) * 3.0
    x.requires_grad = True
    y = m(x)
|
| 1345 |
+
|
| 1346 |
+
|
| 1347 |
+
def _test_softmax():
    """Self-test: the custom `softmax` must produce the same gradients as
    torch's builtin softmax."""
    a = torch.randn(2, 10, dtype=torch.float64)
    b = a.clone()
    a.requires_grad = True
    b.requires_grad = True
    # Backprop through both implementations from the same scalar objective.
    a.softmax(dim=1)[:, 0].sum().backward()
    print("a grad = ", a.grad)
    softmax(b, dim=1)[:, 0].sum().backward()
    print("b grad = ", b.grad)
    assert torch.allclose(a.grad, b.grad)
|
| 1357 |
+
|
| 1358 |
+
|
| 1359 |
+
if __name__ == "__main__":
    # Run the module self-tests.  Single-threaded torch keeps the runs
    # deterministic across machines.
    logging.getLogger().setLevel(logging.INFO)
    torch.set_num_threads(1)
    torch.set_num_interop_threads(1)
    _test_softmax()
    _test_whiten()
    _test_max_eig()
    _test_activation_balancer_sign()
    _test_activation_balancer_magnitude()
    _test_basic_norm()
    _test_double_swish_deriv()
|
apps/audio_cloning/vallex/modules/scheduler.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright 2023 (authors: Feiteng Li)
|
| 3 |
+
#
|
| 4 |
+
# See ../../../../LICENSE for clarification regarding multiple authors
|
| 5 |
+
#
|
| 6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 7 |
+
# you may not use this file except in compliance with the License.
|
| 8 |
+
# You may obtain a copy of the License at
|
| 9 |
+
#
|
| 10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 11 |
+
#
|
| 12 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 15 |
+
# See the License for the specific language governing permissions and
|
| 16 |
+
# limitations under the License.
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from modules.optim import Eden
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def calc_lr(step, dim_embed, warmup_steps):
    """Noam learning-rate schedule ("Attention Is All You Need", sec. 5.3):
    the rate grows linearly for the first `warmup_steps` steps, then decays
    as step**-0.5, all scaled by dim_embed**-0.5.  `step` must be >= 1."""
    warmup_term = step * warmup_steps ** (-1.5)
    decay_term = step ** (-0.5)
    return dim_embed ** (-0.5) * min(decay_term, warmup_term)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Noam ("Attention Is All You Need") learning-rate scheduler.

    The learning rate warms up linearly for `warmup_steps` steps and then
    decays proportionally to step**-0.5, scaled by
    base_lr * dim_embed**-0.5 (see module-level `calc_lr`).
    """

    def __init__(
        self,
        base_lr: float,
        optimizer: torch.optim.Optimizer,
        dim_embed: int,
        warmup_steps: int,
        last_epoch: int = -1,
        verbose: bool = False,
    ) -> None:
        self.dim_embed = dim_embed
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        # All parameter groups share the same learning rate.
        self.num_param_groups = len(optimizer.param_groups)

        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self) -> list:
        # Fix: the original annotation said `-> float`, but _LRScheduler
        # requires get_lr to return one learning rate per parameter group.
        lr = self.base_lr * calc_lr(
            self._step_count, self.dim_embed, self.warmup_steps
        )
        return [lr] * self.num_param_groups

    def set_step(self, step: int):
        """Fast-forward (or rewind) the scheduler to an absolute step count,
        e.g. when resuming training from a checkpoint."""
        self._step_count = step
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def get_scheduler(params, optimizer):
    """Build the LR scheduler selected by `params.scheduler_name`
    (case-insensitive: "eden", "noam" or "cosine") for `optimizer`.

    Raises:
        NotImplementedError: for any other scheduler name.
    """
    name = params.scheduler_name.lower()
    if name == "eden":
        scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
    elif name == "noam":
        scheduler = NoamScheduler(
            params.base_lr,
            optimizer,
            params.decoder_dim,
            warmup_steps=params.warmup_steps,
        )
        # scheduler.set_step(params.start_batch or params.batch_idx_train)
    elif name == "cosine":
        # Fix: CosineAnnealingLR takes (optimizer, T_max, ...).  The original
        # call passed (params.warmup_steps, optimizer, ...), which raised at
        # construction time because an int has no param_groups.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            params.warmup_steps,
            eta_min=params.base_lr,
        )
    else:
        raise NotImplementedError(f"{params.scheduler_name}")

    return scheduler
|
apps/audio_cloning/vallex/modules/transformer.py
ADDED
|
@@ -0,0 +1,683 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import numbers
|
| 3 |
+
from functools import partial
|
| 4 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from torch import Tensor, nn
|
| 8 |
+
from torch.nn import functional as F
|
| 9 |
+
|
| 10 |
+
from .activation import MultiheadAttention
|
| 11 |
+
from .scaling import ActivationBalancer, BalancedDoubleSwish
|
| 12 |
+
from .scaling import BasicNorm as _BasicNorm
|
| 13 |
+
|
| 14 |
+
_shape_t = Union[int, List[int], torch.Size]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class LayerNorm(nn.Module):
    """Layer normalization that also accepts (input, embedding) tuples.

    Behaves like torch.nn.LayerNorm, except that forward() may receive a
    (tensor, embedding) tuple, in which case the tensor is normalized and
    the embedding is passed through untouched.  This lets the layer be
    composed with stages that thread a conditioning embedding alongside the
    activations (see AdaptiveLayerNorm below).
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            # mypy error: incompatible types in assignment
            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
            self.bias = nn.Parameter(
                torch.empty(self.normalized_shape, **factory_kwargs)
            )
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Identity-initialize the affine transform (weight=1, bias=0).
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            # Tuple form: normalize the tensor, pass the embedding through.
            input, embedding = input
            return (
                F.layer_norm(
                    input,
                    self.normalized_shape,
                    self.weight,
                    self.bias,
                    self.eps,
                ),
                embedding,
            )

        # Plain-tensor form: an embedding argument is not meaningful here.
        assert embedding is None
        return F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )

    def extra_repr(self) -> str:
        return (
            "{normalized_shape}, eps={eps}, "
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
        )
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization.

    Normalizes the input with the wrapped `norm` module, then applies an
    affine transform whose per-channel scale and shift are predicted from a
    conditioning `embedding` via a linear projection
    (d_model -> 2*d_model, split into weight and bias halves).
    Accepts either a (input, embedding) tuple or separate arguments.
    """

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        # Projects the embedding to concatenated (weight, bias) vectors.
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
        if isinstance(input, tuple):
            input, embedding = input
            weight, bias = torch.split(
                self.project_layer(embedding),
                split_size_or_sections=self.d_model,
                dim=-1,
            )
            return (weight * self.norm(input) + bias, embedding)

        # NOTE(review): in this branch `embedding` must not be None —
        # project_layer(None) would raise.  Callers appear to always supply
        # it; confirm before relying on the default.
        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        return weight * self.norm(input) + bias
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class BasicNorm(_BasicNorm):
    """Wrapper around scaling.BasicNorm that accepts (input, embedding)
    tuples, passing the embedding through unchanged.

    NOTE(review): `device` and `dtype` are accepted only for interface
    compatibility with the other norm classes here; they are NOT forwarded
    to the base class.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        super(BasicNorm, self).__init__(d_model, eps=eps)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if isinstance(input, tuple):
            # Tuple form: normalize the tensor, pass the embedding through.
            input, embedding = input
            return (
                super(BasicNorm, self).forward(input),
                embedding,
            )

        assert embedding is None
        return super(BasicNorm, self).forward(input)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class BalancedBasicNorm(nn.Module):
    """ActivationBalancer followed by BasicNorm; accepts (input, embedding)
    tuples and threads the embedding through unchanged."""

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        super(BalancedBasicNorm, self).__init__()
        self.balancer = ActivationBalancer(
            d_model,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            max_abs=6.0,
        )
        self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if not isinstance(input, tuple):
            # Plain-tensor form.
            assert embedding is None
            return self.norm(self.balancer(input))

        # Tuple form: balance the tensor and re-wrap for the norm.
        input, embedding = input
        return self.norm((self.balancer(input), embedding))
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class IdentityNorm(nn.Module):
    """No-op normalization: returns its input (tensor or tuple) unchanged.

    `d_model`, `eps`, `device` and `dtype` are accepted only so this class
    can be used interchangeably with the other norm classes in this module.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        super(IdentityNorm, self).__init__()

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        if not isinstance(input, tuple):
            # An embedding is only legal in the tuple form.
            assert embedding is None
        return input
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class TransformerEncoderLayer(nn.Module):
    """A single Transformer encoder layer: self-attention + feed-forward.

    Extends the torch.nn version with pluggable linear/norm classes, an
    optional AdaptiveLayerNorm conditioned on a "stage embedding" (threaded
    through forward() as a (tensor, embedding) tuple), and an `infer` path
    with a key/value cache for incremental decoding.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )

        # Implementation of Feedforward model
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)
        elif isinstance(activation, partial):
            # A partial activation factory is instantiated with the model
            # dimension (e.g. partial(BalancedDoubleSwish)).
            activation = activation(d_model)
        elif activation == BalancedDoubleSwish:
            activation = BalancedDoubleSwish(d_model)

        # # We can't test self.activation in forward() in TorchScript,
        # # so stash some information about it instead.
        # if activation is F.relu or isinstance(activation, torch.nn.ReLU):
        #     self.activation_relu_or_gelu = 1
        # elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
        #     self.activation_relu_or_gelu = 2
        # else:
        #     self.activation_relu_or_gelu = 0
        self.activation = activation

        norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
        if layer_norm_cls == IdentityNorm:
            # With IdentityNorm selected, the pre-feed-forward norm still
            # uses BalancedBasicNorm to keep activations bounded.
            norm2 = BalancedBasicNorm(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
        else:
            norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )

        if adaptive_layer_norm:
            # Wrap both norms so they are conditioned on the stage embedding.
            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
        else:
            self.norm1 = norm1
            self.norm2 = norm2

    def __setstate__(self, state):
        # Backward-compat for checkpoints pickled before `activation` existed.
        super(TransformerEncoderLayer, self).__setstate__(state)
        if not hasattr(self, "activation"):
            self.activation = F.relu

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).  May be a
                (tensor, stage_embedding) tuple, in which case the tuple
                form is also returned.
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        x, stage_embedding = src, None
        is_src_tuple = False
        if isinstance(src, tuple):
            x, stage_embedding = src
            is_src_tuple = True

        if src_key_padding_mask is not None:
            _skpm_dtype = src_key_padding_mask.dtype
            if _skpm_dtype != torch.bool and not torch.is_floating_point(
                src_key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )

        if self.norm_first:
            # Pre-norm: normalize before each sub-block, residual outside.
            x = x + self._sa_block(
                self.norm1(x, stage_embedding),
                src_mask,
                src_key_padding_mask,
            )
            x = x + self._ff_block(self.norm2(x, stage_embedding))
        else:
            # Post-norm: residual first, then normalize.
            x = self.norm1(
                x + self._sa_block(x, src_mask, src_key_padding_mask),
                stage_embedding,
            )
            x = self.norm2(x + self._ff_block(x), stage_embedding)

        if is_src_tuple:
            return (x, stage_embedding)
        return x

    def infer(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        past_kv: Optional[Tensor] = None,
        use_cache: bool = False,
    ):
        """Incremental-decoding forward pass with a key/value cache.

        Returns (x, kv) where kv is the (possibly updated) attention cache.

        NOTE(review): only the pre-norm (norm_first=True) path is
        implemented here; with norm_first=False, `kv` would be unbound and
        `x` unchanged — confirm callers always construct pre-norm layers.
        """
        x, stage_embedding = src, None
        is_src_tuple = False
        if isinstance(src, tuple):
            x, stage_embedding = src
            is_src_tuple = True

        if src_key_padding_mask is not None:
            _skpm_dtype = src_key_padding_mask.dtype
            if _skpm_dtype != torch.bool and not torch.is_floating_point(
                src_key_padding_mask
            ):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported"
                )

        if self.norm_first:
            x_attn_out, kv = self.self_attn.infer(
                self.norm1(x, stage_embedding),
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask,
                need_weights=False,
                past_kv=past_kv,
                use_cache=use_cache,
            )
            x = x + x_attn_out
            x = x + self._ff_block(self.norm2(x, stage_embedding))

        if is_src_tuple:
            return (x, stage_embedding)
        return (x, kv)

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
    ) -> Tensor:
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class TransformerEncoder(nn.Module):
|
| 377 |
+
r"""TransformerEncoder is a stack of N encoder layers. Users can build the
|
| 378 |
+
BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
|
| 379 |
+
|
| 380 |
+
Args:
|
| 381 |
+
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
|
| 382 |
+
num_layers: the number of sub-encoder-layers in the encoder (required).
|
| 383 |
+
norm: the layer normalization component (optional).
|
| 384 |
+
enable_nested_tensor: if True, input will automatically convert to nested tensor
|
| 385 |
+
(and convert back on output). This will improve the overall performance of
|
| 386 |
+
TransformerEncoder when padding rate is high. Default: ``True`` (enabled).
|
| 387 |
+
|
| 388 |
+
Examples::
|
| 389 |
+
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
|
| 390 |
+
>>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
|
| 391 |
+
>>> src = torch.rand(10, 32, 512)
|
| 392 |
+
>>> out = transformer_encoder(src)
|
| 393 |
+
"""
|
| 394 |
+
__constants__ = ["norm"]
|
| 395 |
+
|
| 396 |
+
def __init__(self, encoder_layer, num_layers, norm=None):
    """Stack `num_layers` deep copies of `encoder_layer`, with an optional
    final normalization module applied to the stack's output."""
    super(TransformerEncoder, self).__init__()
    self.layers = _get_clones(encoder_layer, num_layers)
    self.num_layers = num_layers
    self.norm = norm
|
| 401 |
+
|
| 402 |
+
def forward(
|
| 403 |
+
self,
|
| 404 |
+
src: Tensor,
|
| 405 |
+
mask: Optional[Tensor] = None,
|
| 406 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
| 407 |
+
return_layer_states: bool = False,
|
| 408 |
+
) -> Tensor:
|
| 409 |
+
r"""Pass the input through the encoder layers in turn.
|
| 410 |
+
|
| 411 |
+
Args:
|
| 412 |
+
src: the sequence to the encoder (required).
|
| 413 |
+
mask: the mask for the src sequence (optional).
|
| 414 |
+
src_key_padding_mask: the mask for the src keys per batch (optional).
|
| 415 |
+
return_layer_states: return layers' state (optional).
|
| 416 |
+
|
| 417 |
+
Shape:
|
| 418 |
+
see the docs in Transformer class.
|
| 419 |
+
"""
|
| 420 |
+
if return_layer_states:
|
| 421 |
+
layer_states = [] # layers' output
|
| 422 |
+
output = src
|
| 423 |
+
for mod in self.layers:
|
| 424 |
+
output = mod(
|
| 425 |
+
output,
|
| 426 |
+
src_mask=mask,
|
| 427 |
+
src_key_padding_mask=src_key_padding_mask,
|
| 428 |
+
)
|
| 429 |
+
layer_states.append(output[0])
|
| 430 |
+
|
| 431 |
+
if self.norm is not None:
|
| 432 |
+
output = self.norm(output)
|
| 433 |
+
|
| 434 |
+
return layer_states, output
|
| 435 |
+
|
| 436 |
+
output = src
|
| 437 |
+
for mod in self.layers:
|
| 438 |
+
output = mod(
|
| 439 |
+
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
if self.norm is not None:
|
| 443 |
+
output = self.norm(output)
|
| 444 |
+
|
| 445 |
+
return output
|
| 446 |
+
|
| 447 |
+
def infer(
|
| 448 |
+
self,
|
| 449 |
+
src: Tensor,
|
| 450 |
+
mask: Optional[Tensor] = None,
|
| 451 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
| 452 |
+
return_layer_states: bool = False,
|
| 453 |
+
past_kv: Optional[Tensor] = None,
|
| 454 |
+
use_cache: bool = False,
|
| 455 |
+
):
|
| 456 |
+
if past_kv is None:
|
| 457 |
+
past_length = 0
|
| 458 |
+
past_kv = tuple([None] * self.num_layers)
|
| 459 |
+
else:
|
| 460 |
+
past_length = past_kv[0][0].size(-2)
|
| 461 |
+
new_kv = () if use_cache else None
|
| 462 |
+
output = src
|
| 463 |
+
for mod, past_layer_kv in zip(self.layers, past_kv):
|
| 464 |
+
output, kv = mod.infer(
|
| 465 |
+
output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, past_kv=past_layer_kv, use_cache=use_cache
|
| 466 |
+
)
|
| 467 |
+
if use_cache:
|
| 468 |
+
new_kv = new_kv + (kv,)
|
| 469 |
+
|
| 470 |
+
if self.norm is not None:
|
| 471 |
+
output = self.norm(output)
|
| 472 |
+
|
| 473 |
+
return output, new_kv
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
class TransformerDecoderLayer(nn.Module):
    """Transformer decoder layer with pluggable linear and norm classes.

    Follows the standard decoder-layer recipe (self-attention over ``tgt``,
    cross-attention over ``memory``, then a feed-forward block) but lets the
    caller inject the linear-projection classes used inside attention and the
    feed-forward network, the layer-norm class, and — via
    ``adaptive_layer_norm`` — an :class:`AdaptiveLayerNorm` wrapper that is
    modulated by a per-stage embedding passed alongside ``tgt`` as an
    ``(x, stage_embedding)`` tuple.

    NOTE(review): the non-adaptive path still calls each norm as
    ``norm(x, stage_embedding)`` in ``forward``, so the project's
    ``LayerNorm``/``IdentityNorm``/``BalancedBasicNorm`` variants (defined
    elsewhere in this package) presumably accept and ignore the second
    argument — confirm against their definitions.
    """

    __constants__ = ["batch_first", "norm_first"]

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        linear1_self_attention_cls: nn.Module = nn.Linear,
        linear2_self_attention_cls: nn.Module = nn.Linear,
        linear1_feedforward_cls: nn.Module = nn.Linear,
        linear2_feedforward_cls: nn.Module = nn.Linear,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
        layer_norm_cls: nn.Module = LayerNorm,
        layer_norm_eps: float = 1e-5,
        adaptive_layer_norm=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )
        # NOTE(review): cross-attention reuses the *self*-attention linear
        # classes; there is no dedicated cross-attention cls parameter.
        self.multihead_attn = MultiheadAttention(
            d_model,
            nhead,
            dropout=dropout,
            batch_first=batch_first,
            linear1_cls=linear1_self_attention_cls,
            linear2_cls=linear2_self_attention_cls,
            **factory_kwargs,
        )
        # Implementation of Feedforward model
        self.linear1 = linear1_feedforward_cls(
            d_model, dim_feedforward, **factory_kwargs
        )
        self.dropout = nn.Dropout(dropout)
        self.linear2 = linear2_feedforward_cls(
            dim_feedforward, d_model, **factory_kwargs
        )

        self.norm_first = norm_first
        # Separate dropouts for the self-attn, cross-attn and FF residuals.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        elif isinstance(activation, partial):
            # A partial is treated as a factory parameterized by model width.
            self.activation = activation(d_model)
        elif activation == BalancedDoubleSwish:
            self.activation = BalancedDoubleSwish(d_model)
        else:
            self.activation = activation

        if adaptive_layer_norm:
            # Wrap each norm so it can be conditioned on the stage embedding.
            norm1 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            norm3 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )

            self.norm1 = AdaptiveLayerNorm(d_model, norm1)
            self.norm2 = AdaptiveLayerNorm(d_model, norm2)
            self.norm3 = AdaptiveLayerNorm(d_model, norm3)
        else:
            self.norm1 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            self.norm2 = layer_norm_cls(
                d_model, eps=layer_norm_eps, **factory_kwargs
            )
            # IdentityNorm gets a balanced norm on the FF path instead;
            # why only norm3 differs is not evident from this file alone.
            if layer_norm_cls == IdentityNorm:
                self.norm3 = BalancedBasicNorm(
                    d_model, eps=layer_norm_eps, **factory_kwargs
                )
            else:
                self.norm3 = layer_norm_cls(
                    d_model, eps=layer_norm_eps, **factory_kwargs
                )

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).  May also be a
                ``(x, stage_embedding)`` tuple; the embedding is then fed to
                each (adaptive) norm and the output is returned as a tuple.
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        tgt_is_tuple = False
        if isinstance(tgt, tuple):
            x, stage_embedding = tgt
            tgt_is_tuple = True
        else:
            x, stage_embedding = tgt, None

        if self.norm_first:
            # Pre-norm: normalize before each sub-block, plain residual add.
            x = x + self._sa_block(
                self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
            )
            x = x + self._mha_block(
                self.norm2(x, stage_embedding),
                memory,
                memory_mask,
                memory_key_padding_mask,
            )
            x = x + self._ff_block(self.norm3(x, stage_embedding))
        else:
            # Post-norm: residual add first, then normalize.
            x = self.norm1(
                x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
                stage_embedding,
            )
            x = self.norm2(
                x
                + self._mha_block(
                    x, memory, memory_mask, memory_key_padding_mask
                ),
                stage_embedding,
            )
            x = self.norm3(x + self._ff_block(x), stage_embedding)

        if tgt_is_tuple:
            return (x, stage_embedding)
        return x

    # self-attention block
    def _sa_block(
        self,
        x: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
    ) -> Tensor:
        """Self-attention over ``x`` followed by residual-path dropout."""
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(x)

    # multihead attention block
    def _mha_block(
        self,
        x: Tensor,
        mem: Tensor,
        attn_mask: Optional[Tensor],
        key_padding_mask: Optional[Tensor],
    ) -> Tensor:
        """Cross-attention: queries from ``x``, keys/values from ``mem``."""
        x = self.multihead_attn(
            x,
            mem,
            mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout2(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        """linear1 -> activation -> dropout -> linear2 -> dropout3."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
def _get_clones(module, N):
    """Return a ModuleList holding ``N`` independent deep copies of *module*."""
    return nn.ModuleList(copy.deepcopy(module) for _ in range(N))
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
    """Map an activation name (``"relu"`` or ``"gelu"``) to its functional form.

    Raises:
        RuntimeError: if *activation* is not one of the supported names.
    """
    known = {"relu": F.relu, "gelu": F.gelu}
    try:
        return known[activation]
    except KeyError:
        pass
    raise RuntimeError(
        "activation should be relu/gelu, not {}".format(activation)
    )
|
apps/audio_cloning/vallex/presets/acou_1.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:470ce66fc24a2d14e162343381f7d93ef0a3af51edf5fd37240c21f492b4e769
|
| 3 |
+
size 15650
|
apps/audio_cloning/vallex/presets/acou_2.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ec1c5328751cadeed5356d4264759799ad96d33ea8dd4f8a3d0a80dd8ddb0e74
|
| 3 |
+
size 15426
|
apps/audio_cloning/vallex/presets/acou_3.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:03f241b094a32b3f542e74374183c6d15e8b70ae73ceeafb11bfd4ee6b8b4a3a
|
| 3 |
+
size 15410
|
apps/audio_cloning/vallex/presets/acou_4.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:52b96f32863f13f84cf7ac4a27d2bc95cea70c350a037f4d1890b20b8da9501e
|
| 3 |
+
size 15506
|
apps/audio_cloning/vallex/presets/alan.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:28838c3f0b2f9f315b34e9b940f30641306f0cadc5c527857cd1cc408547ed1c
|
| 3 |
+
size 50002
|
apps/audio_cloning/vallex/presets/amused.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:df3e882f3a62805b9aaf300d81822cd4eddeafee480503b7b78e32be2085fb11
|
| 3 |
+
size 20882
|
apps/audio_cloning/vallex/presets/anger.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:959cec6dc0b30219db0d70cdd165fe00bbdc098165cf9d67ccdd1ecf7a5da5be
|
| 3 |
+
size 22090
|
apps/audio_cloning/vallex/presets/babara.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8106b2a98c3f70587f23ab46ed5bf73b1c9a770481c3620ab140bd3256010376
|
| 3 |
+
size 11526
|
apps/audio_cloning/vallex/presets/bronya.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:02eaada2c3d58866c813887ed9f871587ef5a7e976abc23382ce46a17b208001
|
| 3 |
+
size 18106
|
apps/audio_cloning/vallex/presets/cafe.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d78d96f5829da8f69c327ff25958da5b451305fdc9c308f7e67f13cf8d640fea
|
| 3 |
+
size 22442
|
apps/audio_cloning/vallex/presets/dingzhen.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4d19167c65eefef5e42dfaa1919ff5149ca0a93cb052396a47d1f42f9865f5f8
|
| 3 |
+
size 18154
|
apps/audio_cloning/vallex/presets/disgust.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4443f0a395072700f2ec6101dbf2ad9d28968aa3e5809e384ea131832f894d7f
|
| 3 |
+
size 39386
|