Spaces:
Sleeping
Sleeping
primepake
commited on
Commit
·
f768eb3
1
Parent(s):
7940474
add s3 tokenizer
Browse files- speech/tools/S3Tokenizer/.flake8 +28 -0
- speech/tools/S3Tokenizer/.github/workflows/python-publish.yml +37 -0
- speech/tools/S3Tokenizer/.github/workflows/unit_test_cpu.yaml +47 -0
- speech/tools/S3Tokenizer/.gitignore +162 -0
- speech/tools/S3Tokenizer/.pre-commit-config.yaml +14 -0
- speech/tools/S3Tokenizer/LICENSE +201 -0
- speech/tools/S3Tokenizer/MANIFEST.in +4 -0
- speech/tools/S3Tokenizer/README.md +150 -0
- speech/tools/S3Tokenizer/requirements.txt +7 -0
- speech/tools/S3Tokenizer/s3tokenizer/__init__.py +153 -0
- speech/tools/S3Tokenizer/s3tokenizer/assets/mel_filters.npz +0 -0
- speech/tools/S3Tokenizer/s3tokenizer/cli.py +212 -0
- speech/tools/S3Tokenizer/s3tokenizer/model.py +546 -0
- speech/tools/S3Tokenizer/s3tokenizer/model_v2.py +604 -0
- speech/tools/S3Tokenizer/s3tokenizer/utils.py +390 -0
- speech/tools/S3Tokenizer/setup.py +37 -0
- speech/tools/S3Tokenizer/test/test_batch_efficiency.py +272 -0
- speech/tools/S3Tokenizer/test/test_onnx.py +377 -0
speech/tools/S3Tokenizer/.flake8
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[flake8]
|
| 2 |
+
# Suggested config from pytorch that we can adapt
|
| 3 |
+
select = B,C,E,F,N,P,T4,W,B9,TOR0,TOR1,TOR2
|
| 4 |
+
max-line-length = 120
|
| 5 |
+
# C408 ignored because we like the dict keyword argument syntax
|
| 6 |
+
# E501 is not flexible enough, we're using B950 instead
|
| 7 |
+
# N812 ignored because import torch.nn.functional as F is PyTorch convention
|
| 8 |
+
# N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
|
| 9 |
+
# E731 allow usage of assigning lambda expressions
|
| 10 |
+
# N803,N806 allow caps and mixed case in function params. This is to work with Triton kernel coding style.
|
| 11 |
+
ignore =
|
| 12 |
+
E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,N803,N806
|
| 13 |
+
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
|
| 14 |
+
# to line this up with executable bit
|
| 15 |
+
EXE001,
|
| 16 |
+
# these ignores are from flake8-bugbear; please fix!
|
| 17 |
+
B007,B008,
|
| 18 |
+
optional-ascii-coding = True
|
| 19 |
+
exclude =
|
| 20 |
+
./.git,
|
| 21 |
+
./docs
|
| 22 |
+
./build
|
| 23 |
+
./scripts,
|
| 24 |
+
./venv,
|
| 25 |
+
*.pyi
|
| 26 |
+
.pre-commit-config.yaml
|
| 27 |
+
*.md
|
| 28 |
+
.flake8
|
speech/tools/S3Tokenizer/.github/workflows/python-publish.yml
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Release
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- main
|
| 7 |
+
jobs:
|
| 8 |
+
deploy:
|
| 9 |
+
runs-on: ubuntu-latest
|
| 10 |
+
steps:
|
| 11 |
+
- uses: actions/checkout@v3
|
| 12 |
+
- uses: actions-ecosystem/action-regex-match@v2
|
| 13 |
+
id: regex-match
|
| 14 |
+
with:
|
| 15 |
+
text: ${{ github.event.head_commit.message }}
|
| 16 |
+
regex: '^Release ([^ ]+)'
|
| 17 |
+
- name: Set up Python
|
| 18 |
+
uses: actions/setup-python@v4
|
| 19 |
+
with:
|
| 20 |
+
python-version: '3.8'
|
| 21 |
+
- name: Install dependencies
|
| 22 |
+
run: |
|
| 23 |
+
python -m pip install --upgrade pip
|
| 24 |
+
pip install build twine
|
| 25 |
+
- name: Release
|
| 26 |
+
if: ${{ steps.regex-match.outputs.match != '' }}
|
| 27 |
+
uses: softprops/action-gh-release@v1
|
| 28 |
+
with:
|
| 29 |
+
tag_name: v${{ steps.regex-match.outputs.group1 }}
|
| 30 |
+
- name: Build and publish
|
| 31 |
+
if: ${{ steps.regex-match.outputs.match != '' }}
|
| 32 |
+
env:
|
| 33 |
+
TWINE_USERNAME: __token__
|
| 34 |
+
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
|
| 35 |
+
run: |
|
| 36 |
+
python -m build
|
| 37 |
+
twine upload dist/*
|
speech/tools/S3Tokenizer/.github/workflows/unit_test_cpu.yaml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: CPU Unit Test
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches: [ main ]
|
| 6 |
+
pull_request:
|
| 7 |
+
|
| 8 |
+
concurrency:
|
| 9 |
+
group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
|
| 10 |
+
cancel-in-progress: true
|
| 11 |
+
|
| 12 |
+
jobs:
|
| 13 |
+
unit-test:
|
| 14 |
+
runs-on: ${{ matrix.os }}
|
| 15 |
+
strategy:
|
| 16 |
+
max-parallel: 20
|
| 17 |
+
matrix:
|
| 18 |
+
os: [ubuntu-22.04]
|
| 19 |
+
python-version: [3.10.16]
|
| 20 |
+
steps:
|
| 21 |
+
- name: Cache Python Packages
|
| 22 |
+
uses: actions/cache@v4
|
| 23 |
+
with:
|
| 24 |
+
path: ~/.cache/pip
|
| 25 |
+
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
|
| 26 |
+
- name: Setup Python
|
| 27 |
+
uses: actions/setup-python@v4
|
| 28 |
+
with:
|
| 29 |
+
python-version: ${{ matrix.python-version }}
|
| 30 |
+
architecture: x64
|
| 31 |
+
- name: Fetch S3Tokenizer
|
| 32 |
+
uses: actions/checkout@v4
|
| 33 |
+
with:
|
| 34 |
+
fetch-depth: 0
|
| 35 |
+
ref: ${{ github.event.pull_request.head.ref || github.ref }}
|
| 36 |
+
- name: Install S3Tokenizer Dependencies
|
| 37 |
+
run: |
|
| 38 |
+
set -eux
|
| 39 |
+
sudo apt update && sudo apt install -y ffmpeg libsox-dev libsndfile1
|
| 40 |
+
pip install -e .
|
| 41 |
+
- name: Run Pytest
|
| 42 |
+
run: |
|
| 43 |
+
set -eux
|
| 44 |
+
pip install pytest onnxruntime
|
| 45 |
+
pytest --version
|
| 46 |
+
PYTHONPATH="${PYTHONPATH:-}:$(pwd)" pytest test/ -q
|
| 47 |
+
if [ $? != 0 ]; then exit 1; fi
|
speech/tools/S3Tokenizer/.gitignore
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 110 |
+
.pdm.toml
|
| 111 |
+
.pdm-python
|
| 112 |
+
.pdm-build/
|
| 113 |
+
|
| 114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 115 |
+
__pypackages__/
|
| 116 |
+
|
| 117 |
+
# Celery stuff
|
| 118 |
+
celerybeat-schedule
|
| 119 |
+
celerybeat.pid
|
| 120 |
+
|
| 121 |
+
# SageMath parsed files
|
| 122 |
+
*.sage.py
|
| 123 |
+
|
| 124 |
+
# Environments
|
| 125 |
+
.env
|
| 126 |
+
.venv
|
| 127 |
+
env/
|
| 128 |
+
venv/
|
| 129 |
+
ENV/
|
| 130 |
+
env.bak/
|
| 131 |
+
venv.bak/
|
| 132 |
+
|
| 133 |
+
# Spyder project settings
|
| 134 |
+
.spyderproject
|
| 135 |
+
.spyproject
|
| 136 |
+
|
| 137 |
+
# Rope project settings
|
| 138 |
+
.ropeproject
|
| 139 |
+
|
| 140 |
+
# mkdocs documentation
|
| 141 |
+
/site
|
| 142 |
+
|
| 143 |
+
# mypy
|
| 144 |
+
.mypy_cache/
|
| 145 |
+
.dmypy.json
|
| 146 |
+
dmypy.json
|
| 147 |
+
|
| 148 |
+
# Pyre type checker
|
| 149 |
+
.pyre/
|
| 150 |
+
|
| 151 |
+
# pytype static type analyzer
|
| 152 |
+
.pytype/
|
| 153 |
+
|
| 154 |
+
# Cython debug symbols
|
| 155 |
+
cython_debug/
|
| 156 |
+
|
| 157 |
+
# PyCharm
|
| 158 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 159 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 160 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 161 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 162 |
+
#.idea/
|
speech/tools/S3Tokenizer/.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
repos:
|
| 2 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
| 3 |
+
rev: v4.5.0
|
| 4 |
+
hooks:
|
| 5 |
+
- id: trailing-whitespace
|
| 6 |
+
exclude: 's3tokenizer/assets/.*'
|
| 7 |
+
- repo: https://github.com/pre-commit/mirrors-yapf
|
| 8 |
+
rev: 'v0.32.0'
|
| 9 |
+
hooks:
|
| 10 |
+
- id: yapf
|
| 11 |
+
- repo: https://github.com/pycqa/flake8
|
| 12 |
+
rev: '3.8.2'
|
| 13 |
+
hooks:
|
| 14 |
+
- id: flake8
|
speech/tools/S3Tokenizer/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
speech/tools/S3Tokenizer/MANIFEST.in
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
include requirements.txt
|
| 2 |
+
include README.md
|
| 3 |
+
include LICENSE
|
| 4 |
+
include s3tokenizer/assets/*
|
speech/tools/S3Tokenizer/README.md
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Reverse Engineering of S3Tokenizer
|
| 2 |
+
|
| 3 |
+
<div align="center">
|
| 4 |
+
<img src="https://arxiv.org/html/2407.04051v2/x1.png" alt="Description" width="35%" />
|
| 5 |
+
<p><em>Supervised Semantic Speech Tokenizer (S3Tokenizer)</em></p>
|
| 6 |
+
</div>
|
| 7 |
+
|
| 8 |
+
S3Tokenizer was initially introduced in CosyVoice [[Paper]](https://arxiv.org/abs/2407.04051v2) [[Repo]](https://github.com/FunAudioLLM/CosyVoice), it is a Supervised Semantic Speech Tokenizer based on the pre-trained SenseVoice-Large model, which enhances the semantic relationship of extracted tokens to textual and paralinguistic information, is robust to data noise, and reduces the reliance on clean data collection, thereby enabling the use of a broader range of data for model training.
|
| 9 |
+
|
| 10 |
+
However, as indicated in this [[issue]](https://github.com/FunAudioLLM/CosyVoice/issues/70), the authors have no intention to open-source the PyTorch implementation of the S3Tokenizer, and only plan to release an ONNX file. Additionally, users aiming to fine-tune CosyVoice must extract speech codes offline, with the batch size restricted to 1, a process that is notably time-consuming (refer to [[cosyvoice/tools/extract_speech_token.py]](https://github.com/FunAudioLLM/CosyVoice/blob/main/tools/extract_speech_token.py)).
|
| 11 |
+
|
| 12 |
+
This repository undertakes a reverse engineering of the S3Tokenizer, offering:
|
| 13 |
+
1. A pure PyTorch implementation of S3Tokenizer (see [[model.py]](https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/model.py)), compatible with initializing weights from the released ONNX file (see [[utils.py::onnx2torch()]](https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/utils.py)).
|
| 14 |
+
2. High-throughput (distributed) batch inference, achieving a ~790x speedup compared to the original inference pipeline in [[cosyvoice/tools/extract_speech_token.py]](https://github.com/FunAudioLLM/CosyVoice/blob/main/tools/extract_speech_token.py).
|
| 15 |
+
3. The capability to perform online speech code extraction during SpeechLLM training.
|
| 16 |
+
|
| 17 |
+
## Latest News 🎉
|
| 18 |
+
- [2025/07/07] S3Tokenizer now has built-in **long audio processing** capabilities, requiring no additional operations from users!
|
| 19 |
+
|
| 20 |
+
## Supported Models 🔥
|
| 21 |
+
- [x] Model: [S3Tokenizer V1 50hz](https://modelscope.cn/models/iic/CosyVoice-300M)
|
| 22 |
+
- [x] Model: [S3Tokenizer V1 25hz](https://modelscope.cn/models/iic/CosyVoice-300M-25Hz)
|
| 23 |
+
- [x] Model: [S3Tokenizer V2 25hz](https://modelscope.cn/models/iic/CosyVoice2-0.5B)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Setup
|
| 27 |
+
|
| 28 |
+
```sh
|
| 29 |
+
pip install s3tokenizer
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
# Usage-1: Offline batch inference
|
| 33 |
+
|
| 34 |
+
```py
|
| 35 |
+
import s3tokenizer
|
| 36 |
+
|
| 37 |
+
tokenizer = s3tokenizer.load_model("speech_tokenizer_v1").cuda() # or "speech_tokenizer_v1_25hz speech_tokenizer_v2_25hz"
|
| 38 |
+
|
| 39 |
+
mels = []
|
| 40 |
+
wav_paths = ["s3tokenizer/assets/BAC009S0764W0121.wav", "s3tokenizer/assets/BAC009S0764W0122.wav"]
|
| 41 |
+
for wav_path in wav_paths:
|
| 42 |
+
audio = s3tokenizer.load_audio(wav_path)
|
| 43 |
+
mels.append(s3tokenizer.log_mel_spectrogram(audio))
|
| 44 |
+
mels, mels_lens = s3tokenizer.padding(mels)
|
| 45 |
+
codes, codes_lens = tokenizer.quantize(mels.cuda(), mels_lens.cuda()) # Automatically handles long audio internally!
|
| 46 |
+
|
| 47 |
+
for i in range(len(wav_paths)):
|
| 48 |
+
print(codes[i, :codes_lens[i].item()])
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
# Usage-2: Distributed offline batch inference via command-line tools
|
| 52 |
+
|
| 53 |
+
## 2.1 CPU batch inference
|
| 54 |
+
|
| 55 |
+
```sh
|
| 56 |
+
s3tokenizer --wav_scp xxx.scp \
|
| 57 |
+
--device "cpu" \
|
| 58 |
+
--output_dir "./" \
|
| 59 |
+
--batch_size 32 \
|
| 60 |
+
--model "speech_tokenizer_v1" # or "speech_tokenizer_v1_25hz speech_tokenizer_v2_25hz"
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
https://github.com/user-attachments/assets/d37d10fd-0e13-46a3-86b0-4cbec309086f
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
## 2.2 (Multi) GPU batch inference (a.k.a Distributed inference)
|
| 70 |
+
|
| 71 |
+
```sh
|
| 72 |
+
torchrun --nproc_per_node=8 --nnodes=1 \
|
| 73 |
+
--rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
|
| 74 |
+
`which s3tokenizer` --wav_scp xxx.scp \
|
| 75 |
+
--device "cuda" \
|
| 76 |
+
--output_dir "./" \
|
| 77 |
+
--batch_size 32 \
|
| 78 |
+
--model "speech_tokenizer_v1" # or "speech_tokenizer_v1_25hz speech_tokenizer_v2_25hz"
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
https://github.com/user-attachments/assets/79a3fb11-7199-4ee2-8a35-9682a3b4d94a
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
## 2.3 Performance Benchmark
|
| 88 |
+
|
| 89 |
+
| Method | Time cost on Aishell Test Set | Relative speed up | Miss Rate |
|
| 90 |
+
|:------:|:----------:|:--------------:|:-----:|
|
| 91 |
+
| [[cosyvoice/tools/extract_speech_token.py]](https://github.com/FunAudioLLM/CosyVoice/blob/main/tools/extract_speech_token.py), cpu | 9 hours | ~ | ~ |
|
| 92 |
+
| cpu, batchsize 32 | 1.5h | ~6x | 0.00% |
|
| 93 |
+
| 4 gpus (3090), batchsize 32 per gpu | 41s | ~790x | 0.00% |
|
| 94 |
+
|
| 95 |
+
The miss rate represents the proportion of tokens that are inconsistent between the batch inference predictions and the ONNX (batch=1) inference predictions.
|
| 96 |
+
|
| 97 |
+
# Usage-3: Online speech code extraction
|
| 98 |
+
|
| 99 |
+
<table>
|
| 100 |
+
<tr>
|
| 101 |
+
<th>Before (extract code offline)</th>
|
| 102 |
+
<th>After (extract code online)</th>
|
| 103 |
+
</tr>
|
| 104 |
+
<tr>
|
| 105 |
+
<td>
|
| 106 |
+
<sub>
|
| 107 |
+
|
| 108 |
+
```py
|
| 109 |
+
|
| 110 |
+
class SpeechLLM(nn.Module):
|
| 111 |
+
...
|
| 112 |
+
def __init__(self, ...):
|
| 113 |
+
...
|
| 114 |
+
|
| 115 |
+
def forward(self, speech_codes: Tensor, text_ids: Tensor, ...):
|
| 116 |
+
...
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
</sub>
|
| 120 |
+
<td>
|
| 121 |
+
<sub>
|
| 122 |
+
|
| 123 |
+
```py
|
| 124 |
+
import s3tokenizer
|
| 125 |
+
|
| 126 |
+
class SpeechLLM(nn.Module):
|
| 127 |
+
...
|
| 128 |
+
def __init__(self, ...):
|
| 129 |
+
...
|
| 130 |
+
self.speech_tokenizer = s3tokenizer.load_model("speech_tokenizer_v1") # or "speech_tokenizer_v1_25hz"
|
| 131 |
+
self.speech_tokenizer.freeze()
|
| 132 |
+
|
| 133 |
+
def forward(self, speech: Tensor, speech_lens: Tensor, text_ids: Tensor, ...):
|
| 134 |
+
...
|
| 135 |
+
speech_codes, speech_codes_lens = self.speech_tokenizer.quantize(speech, speech_lens)
|
| 136 |
+
speech_codes = speech_codes.clone() # for backward compatbility
|
| 137 |
+
speech_codes_lens = speeech_codes_lens.clone() # for backward compatbility
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
</sub>
|
| 141 |
+
</td>
|
| 142 |
+
</tr>
|
| 143 |
+
</table>
|
| 144 |
+
|
| 145 |
+
# Usage-4: Long Audio Processing (Built-in Automatic Processing)
|
| 146 |
+
|
| 147 |
+
- **Automatic Detection**: Model automatically detects audio length (>30 seconds triggers long audio processing)
|
| 148 |
+
- **Sliding Window**: 30-second window with 4-second overlap, automatically segments long audio
|
| 149 |
+
- **Batch Processing**: Internal batch processing of multiple segments for improved efficiency
|
| 150 |
+
- **Complete Transparency**: User calling method is identical to short audio
|
speech/tools/S3Tokenizer/requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pre-commit
|
| 2 |
+
numpy
|
| 3 |
+
torch
|
| 4 |
+
onnx
|
| 5 |
+
tqdm
|
| 6 |
+
torchaudio
|
| 7 |
+
einops
|
speech/tools/S3Tokenizer/s3tokenizer/__init__.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 OpenAI. (authors: Whisper Team)
|
| 2 |
+
# 2024 Tsinghua Univ. (authors: Xingchen Song)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Modified from
|
| 16 |
+
https://github.com/openai/whisper/blob/main/whisper/__init__.py
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import hashlib
|
| 20 |
+
import os
|
| 21 |
+
import urllib
|
| 22 |
+
import warnings
|
| 23 |
+
from typing import List, Union
|
| 24 |
+
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
|
| 27 |
+
from s3tokenizer.model_v2 import S3TokenizerV2
|
| 28 |
+
|
| 29 |
+
from .model import S3Tokenizer
|
| 30 |
+
from .utils import (load_audio, log_mel_spectrogram, make_non_pad_mask,
|
| 31 |
+
mask_to_bias, onnx2torch, padding, merge_tokenized_segments)
|
| 32 |
+
|
| 33 |
+
__all__ = [
|
| 34 |
+
'load_audio', 'log_mel_spectrogram', 'make_non_pad_mask', 'mask_to_bias',
|
| 35 |
+
'onnx2torch', 'padding', 'merge_tokenized_segments'
|
| 36 |
+
]
|
| 37 |
+
_MODELS = {
|
| 38 |
+
"speech_tokenizer_v1":
|
| 39 |
+
"https://www.modelscope.cn/models/iic/cosyvoice-300m/"
|
| 40 |
+
"resolve/master/speech_tokenizer_v1.onnx",
|
| 41 |
+
"speech_tokenizer_v1_25hz":
|
| 42 |
+
"https://www.modelscope.cn/models/iic/CosyVoice-300M-25Hz/"
|
| 43 |
+
"resolve/master/speech_tokenizer_v1.onnx",
|
| 44 |
+
"speech_tokenizer_v2_25hz":
|
| 45 |
+
"https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/"
|
| 46 |
+
"resolve/master/speech_tokenizer_v2.onnx",
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
_SHA256S = {
|
| 50 |
+
"speech_tokenizer_v1":
|
| 51 |
+
"23b5a723ed9143aebfd9ffda14ac4c21231f31c35ef837b6a13bb9e5488abb1e",
|
| 52 |
+
"speech_tokenizer_v1_25hz":
|
| 53 |
+
"56285ddd4a83e883ee0cb9f8d69c1089b53a94b1f78ff7e4a0224a27eb4cb486",
|
| 54 |
+
"speech_tokenizer_v2_25hz":
|
| 55 |
+
"d43342aa12163a80bf07bffb94c9de2e120a8df2f9917cd2f642e7f4219c6f71",
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _download(name: str, root: str) -> Union[bytes, str]:
|
| 60 |
+
os.makedirs(root, exist_ok=True)
|
| 61 |
+
|
| 62 |
+
expected_sha256 = _SHA256S[name]
|
| 63 |
+
url = _MODELS[name]
|
| 64 |
+
download_target = os.path.join(root, f"{name}.onnx")
|
| 65 |
+
|
| 66 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
| 67 |
+
raise RuntimeError(
|
| 68 |
+
f"{download_target} exists and is not a regular file")
|
| 69 |
+
|
| 70 |
+
if os.path.isfile(download_target):
|
| 71 |
+
with open(download_target, "rb") as f:
|
| 72 |
+
model_bytes = f.read()
|
| 73 |
+
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
| 74 |
+
return download_target
|
| 75 |
+
else:
|
| 76 |
+
warnings.warn(
|
| 77 |
+
f"{download_target} exists, but the SHA256 checksum does not"
|
| 78 |
+
" match; re-downloading the file")
|
| 79 |
+
|
| 80 |
+
with urllib.request.urlopen(url) as source, open(download_target,
|
| 81 |
+
"wb") as output:
|
| 82 |
+
with tqdm(
|
| 83 |
+
total=int(source.info().get("Content-Length")),
|
| 84 |
+
ncols=80,
|
| 85 |
+
unit="iB",
|
| 86 |
+
unit_scale=True,
|
| 87 |
+
unit_divisor=1024,
|
| 88 |
+
desc="Downloading onnx checkpoint",
|
| 89 |
+
) as loop:
|
| 90 |
+
while True:
|
| 91 |
+
buffer = source.read(8192)
|
| 92 |
+
if not buffer:
|
| 93 |
+
break
|
| 94 |
+
|
| 95 |
+
output.write(buffer)
|
| 96 |
+
loop.update(len(buffer))
|
| 97 |
+
|
| 98 |
+
model_bytes = open(download_target, "rb").read()
|
| 99 |
+
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
| 100 |
+
raise RuntimeError(
|
| 101 |
+
"Model has been downloaded but the SHA256 checksum does not not"
|
| 102 |
+
" match. Please retry loading the model.")
|
| 103 |
+
|
| 104 |
+
return download_target
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def available_models() -> List[str]:
|
| 108 |
+
"""Returns the names of available models"""
|
| 109 |
+
return list(_MODELS.keys())
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def load_model(
|
| 113 |
+
name: str,
|
| 114 |
+
download_root: str = None,
|
| 115 |
+
) -> S3Tokenizer:
|
| 116 |
+
"""
|
| 117 |
+
Load a S3Tokenizer ASR model
|
| 118 |
+
|
| 119 |
+
Parameters
|
| 120 |
+
----------
|
| 121 |
+
name : str
|
| 122 |
+
one of the official model names listed by
|
| 123 |
+
`s3tokenizer.available_models()`, or path to a model checkpoint
|
| 124 |
+
containing the model dimensions and the model state_dict.
|
| 125 |
+
download_root: str
|
| 126 |
+
path to download the model files; by default,
|
| 127 |
+
it uses "~/.cache/s3tokenizer"
|
| 128 |
+
|
| 129 |
+
Returns
|
| 130 |
+
-------
|
| 131 |
+
model : S3Tokenizer
|
| 132 |
+
The S3Tokenizer model instance
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
if download_root is None:
|
| 136 |
+
default = os.path.join(os.path.expanduser("~"), ".cache")
|
| 137 |
+
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default),
|
| 138 |
+
"s3tokenizer")
|
| 139 |
+
|
| 140 |
+
if name in _MODELS:
|
| 141 |
+
checkpoint_file = _download(name, download_root)
|
| 142 |
+
elif os.path.isfile(name):
|
| 143 |
+
checkpoint_file = name
|
| 144 |
+
else:
|
| 145 |
+
raise RuntimeError(
|
| 146 |
+
f"Model {name} not found; available models = {available_models()}")
|
| 147 |
+
if 'v2' in name:
|
| 148 |
+
model = S3TokenizerV2(name)
|
| 149 |
+
else:
|
| 150 |
+
model = S3Tokenizer(name)
|
| 151 |
+
model.init_from_onnx(checkpoint_file)
|
| 152 |
+
|
| 153 |
+
return model
|
speech/tools/S3Tokenizer/s3tokenizer/assets/mel_filters.npz
ADDED
|
Binary file (4.27 kB). View file
|
|
|
speech/tools/S3Tokenizer/s3tokenizer/cli.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
""" Example Usage
|
| 15 |
+
cpu:
|
| 16 |
+
|
| 17 |
+
s3tokenizer --root_path /path/to/audio/files \
|
| 18 |
+
--model speech_tokenizer_v2_25hz \
|
| 19 |
+
--device "cpu" \
|
| 20 |
+
--batch_size 32
|
| 21 |
+
|
| 22 |
+
gpu:
|
| 23 |
+
|
| 24 |
+
torchrun --nproc_per_node=1 --nnodes=1 \
|
| 25 |
+
--rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
|
| 26 |
+
`which s3tokenizer` --root_path /data/dataset \
|
| 27 |
+
--model speech_tokenizer_v2_25hz \
|
| 28 |
+
--device "cuda" \
|
| 29 |
+
--batch_size 64
|
| 30 |
+
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
import argparse
|
| 34 |
+
import os
|
| 35 |
+
from pathlib import Path
|
| 36 |
+
|
| 37 |
+
import torch
|
| 38 |
+
import torch.distributed as dist
|
| 39 |
+
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
| 40 |
+
from tqdm import tqdm
|
| 41 |
+
|
| 42 |
+
import s3tokenizer
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class AudioDataset(Dataset):
|
| 46 |
+
|
| 47 |
+
def __init__(self, root_path, extensions=['.wav', '.flac', '.mp3']):
|
| 48 |
+
self.data = []
|
| 49 |
+
|
| 50 |
+
# Recursively find all audio files
|
| 51 |
+
root = Path(root_path)
|
| 52 |
+
for ext in extensions:
|
| 53 |
+
self.data.extend(root.rglob(f'*{ext}'))
|
| 54 |
+
|
| 55 |
+
# Sort for consistent ordering
|
| 56 |
+
self.data.sort()
|
| 57 |
+
|
| 58 |
+
if len(self.data) == 0:
|
| 59 |
+
raise ValueError(f"No audio files found in {root_path}")
|
| 60 |
+
|
| 61 |
+
print(f"Found {len(self.data)} audio files")
|
| 62 |
+
|
| 63 |
+
def __len__(self):
|
| 64 |
+
return len(self.data)
|
| 65 |
+
|
| 66 |
+
def __getitem__(self, idx):
|
| 67 |
+
file_path = self.data[idx]
|
| 68 |
+
audio = s3tokenizer.load_audio(str(file_path))
|
| 69 |
+
mel = s3tokenizer.log_mel_spectrogram(audio)
|
| 70 |
+
return file_path, mel
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def collate_fn(batch):
|
| 74 |
+
file_paths = [item[0] for item in batch]
|
| 75 |
+
mels = [item[1] for item in batch]
|
| 76 |
+
mels, mels_lens = s3tokenizer.padding(mels)
|
| 77 |
+
return file_paths, mels, mels_lens
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def init_distributed():
|
| 81 |
+
world_size = int(os.environ.get('WORLD_SIZE', 1))
|
| 82 |
+
local_rank = int(os.environ.get('LOCAL_RANK', 0))
|
| 83 |
+
rank = int(os.environ.get('RANK', 0))
|
| 84 |
+
print('Inference on multiple gpus, this gpu {}'.format(local_rank) +
|
| 85 |
+
', rank {}, world_size {}'.format(rank, world_size))
|
| 86 |
+
torch.cuda.set_device(local_rank)
|
| 87 |
+
dist.init_process_group("nccl")
|
| 88 |
+
return world_size, local_rank, rank
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_args():
|
| 92 |
+
parser = argparse.ArgumentParser(description='extract speech code')
|
| 93 |
+
parser.add_argument('--model',
|
| 94 |
+
required=True,
|
| 95 |
+
type=str,
|
| 96 |
+
choices=[
|
| 97 |
+
"speech_tokenizer_v1", "speech_tokenizer_v1_25hz",
|
| 98 |
+
"speech_tokenizer_v2_25hz"
|
| 99 |
+
],
|
| 100 |
+
help='model version')
|
| 101 |
+
parser.add_argument('--root_path',
|
| 102 |
+
required=True,
|
| 103 |
+
type=str,
|
| 104 |
+
help='root directory containing audio files')
|
| 105 |
+
parser.add_argument('--device',
|
| 106 |
+
required=True,
|
| 107 |
+
type=str,
|
| 108 |
+
choices=["cuda", "cpu"],
|
| 109 |
+
help='device for inference')
|
| 110 |
+
parser.add_argument('--batch_size',
|
| 111 |
+
required=True,
|
| 112 |
+
type=int,
|
| 113 |
+
help='batch size (per-device) for inference')
|
| 114 |
+
parser.add_argument('--num_workers',
|
| 115 |
+
type=int,
|
| 116 |
+
default=4,
|
| 117 |
+
help='workers for dataloader')
|
| 118 |
+
parser.add_argument('--prefetch',
|
| 119 |
+
type=int,
|
| 120 |
+
default=5,
|
| 121 |
+
help='prefetch for dataloader')
|
| 122 |
+
parser.add_argument('--extensions',
|
| 123 |
+
nargs='+',
|
| 124 |
+
default=['.wav', '.flac', '.mp3'],
|
| 125 |
+
help='audio file extensions to process')
|
| 126 |
+
args = parser.parse_args()
|
| 127 |
+
return args
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def save_tokens(file_path, codes, codes_len):
|
| 131 |
+
"""Save tokens as .pt file with _fsq suffix"""
|
| 132 |
+
# Remove extension and add _fsq.pt
|
| 133 |
+
output_path = file_path.with_suffix('').with_suffix('.pt')
|
| 134 |
+
output_path = output_path.parent / f"{output_path.stem}_fsq.pt"
|
| 135 |
+
|
| 136 |
+
# Extract only valid codes (up to codes_len)
|
| 137 |
+
valid_codes = codes[:codes_len]
|
| 138 |
+
# convert valid codes to list
|
| 139 |
+
valid_codes = valid_codes.tolist()
|
| 140 |
+
|
| 141 |
+
# Save as tensor
|
| 142 |
+
torch.save(valid_codes, output_path)
|
| 143 |
+
|
| 144 |
+
return output_path
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def main():
|
| 148 |
+
args = get_args()
|
| 149 |
+
|
| 150 |
+
if args.device == "cuda":
|
| 151 |
+
assert (torch.cuda.is_available())
|
| 152 |
+
world_size, local_rank, rank = init_distributed()
|
| 153 |
+
else:
|
| 154 |
+
world_size, local_rank, rank = 1, 0, 0
|
| 155 |
+
|
| 156 |
+
device = torch.device(args.device)
|
| 157 |
+
model = s3tokenizer.load_model(args.model).to(device)
|
| 158 |
+
dataset = AudioDataset(args.root_path, args.extensions)
|
| 159 |
+
|
| 160 |
+
if args.device == "cuda":
|
| 161 |
+
model = torch.nn.parallel.DistributedDataParallel(
|
| 162 |
+
model, device_ids=[local_rank])
|
| 163 |
+
sampler = DistributedSampler(dataset,
|
| 164 |
+
num_replicas=world_size,
|
| 165 |
+
rank=rank)
|
| 166 |
+
else:
|
| 167 |
+
sampler = None
|
| 168 |
+
|
| 169 |
+
dataloader = DataLoader(dataset,
|
| 170 |
+
batch_size=args.batch_size,
|
| 171 |
+
sampler=sampler,
|
| 172 |
+
shuffle=False,
|
| 173 |
+
num_workers=args.num_workers,
|
| 174 |
+
prefetch_factor=args.prefetch,
|
| 175 |
+
collate_fn=collate_fn)
|
| 176 |
+
|
| 177 |
+
total_steps = len(dataset)
|
| 178 |
+
|
| 179 |
+
if rank == 0:
|
| 180 |
+
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
|
| 181 |
+
|
| 182 |
+
processed_count = 0
|
| 183 |
+
for file_paths, mels, mels_lens in dataloader:
|
| 184 |
+
codes, codes_lens = model(mels.to(device), mels_lens.to(device))
|
| 185 |
+
|
| 186 |
+
# Process each file in the batch
|
| 187 |
+
for i, file_path in enumerate(file_paths):
|
| 188 |
+
code = codes[i]
|
| 189 |
+
code_len = codes_lens[i].item()
|
| 190 |
+
|
| 191 |
+
# Save tokens as .pt file
|
| 192 |
+
output_path = save_tokens(file_path, code, code_len)
|
| 193 |
+
|
| 194 |
+
if rank == 0:
|
| 195 |
+
tqdm.write(f"Saved: {file_path} -> {output_path}")
|
| 196 |
+
|
| 197 |
+
processed_count += len(file_paths)
|
| 198 |
+
|
| 199 |
+
if rank == 0:
|
| 200 |
+
progress_bar.update(world_size * len(file_paths))
|
| 201 |
+
|
| 202 |
+
if rank == 0:
|
| 203 |
+
progress_bar.close()
|
| 204 |
+
print(f"\nProcessed {processed_count} files on rank {rank}")
|
| 205 |
+
|
| 206 |
+
if args.device == "cuda":
|
| 207 |
+
dist.barrier()
|
| 208 |
+
dist.destroy_process_group()
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
if __name__ == "__main__":
|
| 212 |
+
main()
|
speech/tools/S3Tokenizer/s3tokenizer/model.py
ADDED
|
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 OpenAI. (authors: Whisper Team)
|
| 2 |
+
# 2024 Tsinghua Univ. (authors: Xingchen Song)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Modified from https://github.com/openai/whisper/blob/main/whisper/model.py
|
| 16 |
+
Add EuclideanCodebook & VectorQuantization
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from typing import Iterable, Optional, Tuple
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
from einops import rearrange
|
| 26 |
+
from torch import Tensor, nn
|
| 27 |
+
|
| 28 |
+
from .utils import make_non_pad_mask, mask_to_bias, onnx2torch, merge_tokenized_segments
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@dataclass
|
| 32 |
+
class ModelConfig:
|
| 33 |
+
n_mels: int = 128
|
| 34 |
+
n_audio_ctx: int = 1500
|
| 35 |
+
n_audio_state: int = 1280
|
| 36 |
+
n_audio_head: int = 20
|
| 37 |
+
n_audio_layer: int = 6
|
| 38 |
+
n_codebook_size: int = 4096
|
| 39 |
+
|
| 40 |
+
use_sdpa: bool = False
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class LayerNorm(nn.LayerNorm):
|
| 44 |
+
|
| 45 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 46 |
+
return super().forward(x.float()).type(x.dtype)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class Linear(nn.Linear):
|
| 50 |
+
|
| 51 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 52 |
+
return F.linear(
|
| 53 |
+
x,
|
| 54 |
+
self.weight.to(x.dtype),
|
| 55 |
+
None if self.bias is None else self.bias.to(x.dtype),
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class Conv1d(nn.Conv1d):
|
| 60 |
+
|
| 61 |
+
def _conv_forward(self, x: Tensor, weight: Tensor,
|
| 62 |
+
bias: Optional[Tensor]) -> Tensor:
|
| 63 |
+
return super()._conv_forward(
|
| 64 |
+
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def sinusoids(length, channels, max_timescale=10000):
|
| 68 |
+
"""Returns sinusoids for positional embedding"""
|
| 69 |
+
assert channels % 2 == 0
|
| 70 |
+
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
| 71 |
+
inv_timescales = torch.exp(-log_timescale_increment *
|
| 72 |
+
torch.arange(channels // 2))
|
| 73 |
+
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[
|
| 74 |
+
np.newaxis, :]
|
| 75 |
+
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class MultiHeadAttention(nn.Module):
|
| 79 |
+
|
| 80 |
+
def __init__(self, n_state: int, n_head: int, use_sdpa: bool = False):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.n_head = n_head
|
| 83 |
+
self.query = Linear(n_state, n_state)
|
| 84 |
+
self.key = Linear(n_state, n_state, bias=False)
|
| 85 |
+
self.value = Linear(n_state, n_state)
|
| 86 |
+
self.out = Linear(n_state, n_state)
|
| 87 |
+
|
| 88 |
+
self.use_sdpa = use_sdpa
|
| 89 |
+
|
| 90 |
+
def forward(
|
| 91 |
+
self,
|
| 92 |
+
x: Tensor,
|
| 93 |
+
mask: Optional[Tensor] = None,
|
| 94 |
+
):
|
| 95 |
+
q = self.query(x)
|
| 96 |
+
k = self.key(x)
|
| 97 |
+
v = self.value(x)
|
| 98 |
+
|
| 99 |
+
wv, qk = self.qkv_attention(q, k, v, mask)
|
| 100 |
+
return self.out(wv), qk
|
| 101 |
+
|
| 102 |
+
def qkv_attention(self,
|
| 103 |
+
q: Tensor,
|
| 104 |
+
k: Tensor,
|
| 105 |
+
v: Tensor,
|
| 106 |
+
mask: Optional[Tensor] = None):
|
| 107 |
+
_, _, D = q.shape
|
| 108 |
+
scale = (D // self.n_head)**-0.25
|
| 109 |
+
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
| 110 |
+
k = k.view(*k.shape[:2], self.n_head, -1)
|
| 111 |
+
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
| 112 |
+
|
| 113 |
+
if not self.use_sdpa:
|
| 114 |
+
k = k.permute(0, 2, 3, 1) * scale
|
| 115 |
+
qk = q @ k # (B, n_head, T, T)
|
| 116 |
+
if mask is not None:
|
| 117 |
+
qk = qk + mask
|
| 118 |
+
qk = qk.float()
|
| 119 |
+
w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
|
| 120 |
+
return (w @ v).permute(0, 2, 1,
|
| 121 |
+
3).flatten(start_dim=2), qk.detach()
|
| 122 |
+
else:
|
| 123 |
+
k = k.permute(0, 2, 1, 3) * scale
|
| 124 |
+
assert mask is not None
|
| 125 |
+
output = torch.nn.functional.scaled_dot_product_attention(
|
| 126 |
+
q,
|
| 127 |
+
k,
|
| 128 |
+
v,
|
| 129 |
+
attn_mask=mask,
|
| 130 |
+
dropout_p=0.,
|
| 131 |
+
scale=1.,
|
| 132 |
+
)
|
| 133 |
+
output = (output.transpose(1,
|
| 134 |
+
2).contiguous().view(q.size(0), -1, D)
|
| 135 |
+
) # (batch, time1, d_model)
|
| 136 |
+
return output, None
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class ResidualAttentionBlock(nn.Module):
|
| 140 |
+
|
| 141 |
+
def __init__(self, n_state: int, n_head: int, use_sdpa: bool):
|
| 142 |
+
super().__init__()
|
| 143 |
+
|
| 144 |
+
self.attn = MultiHeadAttention(n_state, n_head, use_sdpa=use_sdpa)
|
| 145 |
+
self.attn_ln = LayerNorm(n_state)
|
| 146 |
+
|
| 147 |
+
n_mlp = n_state * 4
|
| 148 |
+
self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(),
|
| 149 |
+
Linear(n_mlp, n_state))
|
| 150 |
+
self.mlp_ln = LayerNorm(n_state)
|
| 151 |
+
|
| 152 |
+
def forward(
|
| 153 |
+
self,
|
| 154 |
+
x: Tensor,
|
| 155 |
+
mask: Optional[Tensor] = None,
|
| 156 |
+
):
|
| 157 |
+
x = x + self.attn(self.attn_ln(x), mask=mask)[0]
|
| 158 |
+
x = x + self.mlp(self.mlp_ln(x))
|
| 159 |
+
return x
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class AudioEncoder(nn.Module):
|
| 163 |
+
|
| 164 |
+
def __init__(
|
| 165 |
+
self,
|
| 166 |
+
n_mels: int,
|
| 167 |
+
n_ctx: int,
|
| 168 |
+
n_state: int,
|
| 169 |
+
n_head: int,
|
| 170 |
+
n_layer: int,
|
| 171 |
+
stride: int,
|
| 172 |
+
use_sdpa: bool,
|
| 173 |
+
):
|
| 174 |
+
super().__init__()
|
| 175 |
+
self.stride = stride
|
| 176 |
+
self.conv1 = Conv1d(n_mels,
|
| 177 |
+
n_state,
|
| 178 |
+
kernel_size=3,
|
| 179 |
+
stride=stride,
|
| 180 |
+
padding=1)
|
| 181 |
+
self.conv2 = Conv1d(n_state,
|
| 182 |
+
n_state,
|
| 183 |
+
kernel_size=3,
|
| 184 |
+
stride=2,
|
| 185 |
+
padding=1)
|
| 186 |
+
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
| 187 |
+
|
| 188 |
+
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList([
|
| 189 |
+
ResidualAttentionBlock(n_state, n_head, use_sdpa=use_sdpa)
|
| 190 |
+
for _ in range(n_layer)
|
| 191 |
+
])
|
| 192 |
+
|
| 193 |
+
def forward(self, x: Tensor, x_len: Tensor) -> Tuple[Tensor, Tensor]:
|
| 194 |
+
"""
|
| 195 |
+
x : torch.Tensor, shape = (batch_size, n_mels, T)
|
| 196 |
+
the mel spectrogram of the audio
|
| 197 |
+
x_len: torch.Tensor, shape = (batch_size,)
|
| 198 |
+
length of each audio in x
|
| 199 |
+
"""
|
| 200 |
+
mask = make_non_pad_mask(x_len).unsqueeze(1)
|
| 201 |
+
x = F.gelu(self.conv1(x * mask))
|
| 202 |
+
x_len = (x_len + 2 - 1 * (3 - 1) - 1) // self.stride + 1
|
| 203 |
+
mask = make_non_pad_mask(x_len).unsqueeze(1)
|
| 204 |
+
x = F.gelu(self.conv2(x * mask))
|
| 205 |
+
x_len = (x_len + 2 - 1 * (3 - 1) - 1) // 2 + 1
|
| 206 |
+
mask = make_non_pad_mask(x_len).unsqueeze(1)
|
| 207 |
+
x = x.permute(0, 2, 1) # (B, T // 2, n_state)
|
| 208 |
+
|
| 209 |
+
mask = mask_to_bias(mask, x.dtype)
|
| 210 |
+
|
| 211 |
+
x = (x + self.positional_embedding[:x.shape[1], :]).to(x.dtype)
|
| 212 |
+
|
| 213 |
+
for block in self.blocks:
|
| 214 |
+
x = block(x, mask.unsqueeze(1))
|
| 215 |
+
|
| 216 |
+
return x, x_len
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class EuclideanCodebook(nn.Module):
|
| 220 |
+
"""Codebook with Euclidean distance (inference-only).
|
| 221 |
+
Args:
|
| 222 |
+
dim (int): Dimension.
|
| 223 |
+
codebook_size (int): Codebook size.
|
| 224 |
+
"""
|
| 225 |
+
|
| 226 |
+
def __init__(self, dim: int, codebook_size: int):
|
| 227 |
+
super().__init__()
|
| 228 |
+
embed = torch.zeros(codebook_size, dim)
|
| 229 |
+
self.codebook_size = codebook_size
|
| 230 |
+
self.register_buffer("embed", embed)
|
| 231 |
+
|
| 232 |
+
@torch.inference_mode()
|
| 233 |
+
def preprocess(self, x: Tensor) -> Tensor:
|
| 234 |
+
x = rearrange(x, "... d -> (...) d")
|
| 235 |
+
return x
|
| 236 |
+
|
| 237 |
+
@torch.inference_mode()
|
| 238 |
+
def quantize(self, x: Tensor) -> Tensor:
|
| 239 |
+
embed = self.embed.t().to(x.dtype)
|
| 240 |
+
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed +
|
| 241 |
+
embed.pow(2).sum(0, keepdim=True))
|
| 242 |
+
embed_ind = dist.max(dim=-1).indices
|
| 243 |
+
return embed_ind
|
| 244 |
+
|
| 245 |
+
@torch.inference_mode()
|
| 246 |
+
def postprocess_emb(self, embed_ind, shape):
|
| 247 |
+
return embed_ind.view(*shape[:-1])
|
| 248 |
+
|
| 249 |
+
@torch.inference_mode()
|
| 250 |
+
def dequantize(self, embed_ind: Tensor) -> Tensor:
|
| 251 |
+
quantize = F.embedding(embed_ind, self.embed)
|
| 252 |
+
return quantize
|
| 253 |
+
|
| 254 |
+
@torch.inference_mode()
|
| 255 |
+
def encode(self, x: Tensor) -> Tensor:
|
| 256 |
+
shape = x.shape
|
| 257 |
+
# pre-process
|
| 258 |
+
x = self.preprocess(x)
|
| 259 |
+
# quantize
|
| 260 |
+
embed_ind = self.quantize(x)
|
| 261 |
+
# post-process
|
| 262 |
+
embed_ind = self.postprocess_emb(embed_ind, shape)
|
| 263 |
+
return embed_ind
|
| 264 |
+
|
| 265 |
+
@torch.inference_mode()
|
| 266 |
+
def decode(self, embed_ind: Tensor) -> Tensor:
|
| 267 |
+
quantize = self.dequantize(embed_ind)
|
| 268 |
+
return quantize
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class VectorQuantization(nn.Module):
|
| 272 |
+
"""Vector quantization implementation (inference-only).
|
| 273 |
+
Args:
|
| 274 |
+
dim (int): Dimension
|
| 275 |
+
codebook_size (int): Codebook size
|
| 276 |
+
"""
|
| 277 |
+
|
| 278 |
+
def __init__(self, dim: int, codebook_size: int):
|
| 279 |
+
super().__init__()
|
| 280 |
+
self._codebook = EuclideanCodebook(dim=dim,
|
| 281 |
+
codebook_size=codebook_size)
|
| 282 |
+
self.codebook_size = codebook_size
|
| 283 |
+
|
| 284 |
+
@property
|
| 285 |
+
def codebook(self):
|
| 286 |
+
return self._codebook.embed
|
| 287 |
+
|
| 288 |
+
@torch.inference_mode()
|
| 289 |
+
def encode(self, x: Tensor) -> Tensor:
|
| 290 |
+
x = F.normalize(x.float(), p=2, dim=-1)
|
| 291 |
+
embed_in = self._codebook.encode(x)
|
| 292 |
+
return embed_in
|
| 293 |
+
|
| 294 |
+
@torch.inference_mode()
|
| 295 |
+
def decode(self, embed_ind: Tensor) -> Tensor:
|
| 296 |
+
quantize = self._codebook.decode(embed_ind)
|
| 297 |
+
quantize = rearrange(quantize, "b n d -> b d n")
|
| 298 |
+
return quantize
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class S3Tokenizer(nn.Module):
|
| 302 |
+
"""S3 tokenizer implementation (inference-only).
|
| 303 |
+
Args:
|
| 304 |
+
config (ModelConfig): Config
|
| 305 |
+
"""
|
| 306 |
+
|
| 307 |
+
def __init__(self, name: str, config: ModelConfig = ModelConfig()):
|
| 308 |
+
super().__init__()
|
| 309 |
+
self.name = name # Store model name for token_rate determination
|
| 310 |
+
self.config = config
|
| 311 |
+
self.encoder = AudioEncoder(
|
| 312 |
+
self.config.n_mels,
|
| 313 |
+
self.config.n_audio_ctx,
|
| 314 |
+
self.config.n_audio_state,
|
| 315 |
+
self.config.n_audio_head,
|
| 316 |
+
self.config.n_audio_layer,
|
| 317 |
+
2 if name == "speech_tokenizer_v1_25hz" else 1,
|
| 318 |
+
self.config.use_sdpa,
|
| 319 |
+
)
|
| 320 |
+
self.quantizer = VectorQuantization(self.config.n_audio_state,
|
| 321 |
+
self.config.n_codebook_size)
|
| 322 |
+
|
| 323 |
+
def forward(self, mel: Tensor, mel_len: Tensor) -> Tuple[Tensor, Tensor]:
|
| 324 |
+
return self.quantize(mel, mel_len)
|
| 325 |
+
|
| 326 |
+
@torch.inference_mode()
|
| 327 |
+
def quantize(self, mel: Tensor, mel_len: Tensor) -> Tuple[Tensor, Tensor]:
|
| 328 |
+
"""
|
| 329 |
+
Quantize mel spectrogram to tokens, with automatic long audio handling.
|
| 330 |
+
|
| 331 |
+
Args:
|
| 332 |
+
mel: mel spectrogram tensor, shape (batch_size, n_mels, T)
|
| 333 |
+
mel_len: mel length tensor, shape (batch_size,)
|
| 334 |
+
|
| 335 |
+
Returns:
|
| 336 |
+
code: quantized tokens, shape (batch_size, T')
|
| 337 |
+
code_len: token length, shape (batch_size,)
|
| 338 |
+
"""
|
| 339 |
+
# Check if any audio in the batch exceeds 30 seconds
|
| 340 |
+
# Assuming 16kHz sample rate and hop_length=160, 30s = 30*16000/160 = 3000 frames
|
| 341 |
+
max_frames = 3000
|
| 342 |
+
|
| 343 |
+
# Check which samples are long audio
|
| 344 |
+
long_audio_mask = mel_len > max_frames
|
| 345 |
+
|
| 346 |
+
if long_audio_mask.any():
|
| 347 |
+
# Has long audio - need special processing
|
| 348 |
+
return self._quantize_mixed_batch(mel, mel_len, long_audio_mask,
|
| 349 |
+
max_frames)
|
| 350 |
+
else:
|
| 351 |
+
# All short audio - use original method
|
| 352 |
+
hidden, code_len = self.encoder(mel, mel_len)
|
| 353 |
+
code = self.quantizer.encode(hidden)
|
| 354 |
+
return code, code_len
|
| 355 |
+
|
| 356 |
+
@torch.inference_mode()
|
| 357 |
+
def _quantize_mixed_batch(self, mel: Tensor, mel_len: Tensor,
|
| 358 |
+
long_audio_mask: Tensor,
|
| 359 |
+
max_frames: int) -> Tuple[Tensor, Tensor]:
|
| 360 |
+
"""
|
| 361 |
+
Handle mixed batch with both short and long audio using unified batch processing.
|
| 362 |
+
|
| 363 |
+
Args:
|
| 364 |
+
mel: mel spectrogram tensor, shape (batch_size, n_mels, T)
|
| 365 |
+
mel_len: mel length tensor, shape (batch_size,)
|
| 366 |
+
long_audio_mask: boolean mask for long audio, shape (batch_size,)
|
| 367 |
+
max_frames: maximum frames for short audio
|
| 368 |
+
|
| 369 |
+
Returns:
|
| 370 |
+
code: quantized tokens, shape (batch_size, T')
|
| 371 |
+
code_len: token length, shape (batch_size,)
|
| 372 |
+
"""
|
| 373 |
+
batch_size = mel.size(0)
|
| 374 |
+
|
| 375 |
+
# Parameters for sliding window
|
| 376 |
+
sample_rate = 16000
|
| 377 |
+
hop_length = 160 # Default hop length for mel spectrogram
|
| 378 |
+
window_size = 30 # seconds
|
| 379 |
+
overlap = 4 # seconds
|
| 380 |
+
|
| 381 |
+
# Calculate frame-based parameters
|
| 382 |
+
frames_per_window = window_size * sample_rate // hop_length # 3000 frames
|
| 383 |
+
frames_per_overlap = overlap * sample_rate // hop_length # 400 frames
|
| 384 |
+
frames_per_stride = frames_per_window - frames_per_overlap # 2600 frames
|
| 385 |
+
|
| 386 |
+
# Collect all segments to process (including short and long audio segments)
|
| 387 |
+
all_segments = []
|
| 388 |
+
all_segments_len = []
|
| 389 |
+
segment_info = [
|
| 390 |
+
] # Record which audio each segment belongs to and whether it's long audio
|
| 391 |
+
|
| 392 |
+
# Process all audio in the batch
|
| 393 |
+
for batch_idx in range(batch_size):
|
| 394 |
+
audio_mel = mel[batch_idx]
|
| 395 |
+
audio_mel_len = mel_len[batch_idx]
|
| 396 |
+
is_long_audio = long_audio_mask[batch_idx].item()
|
| 397 |
+
|
| 398 |
+
if not is_long_audio:
|
| 399 |
+
# Short audio: process directly as a single segment
|
| 400 |
+
segment = audio_mel[:, :audio_mel_len]
|
| 401 |
+
seg_len = audio_mel_len.item()
|
| 402 |
+
|
| 403 |
+
# Pad to max_frames if necessary
|
| 404 |
+
if seg_len < frames_per_window:
|
| 405 |
+
pad_size = frames_per_window - seg_len
|
| 406 |
+
segment = F.pad(segment, (0, pad_size))
|
| 407 |
+
|
| 408 |
+
all_segments.append(segment)
|
| 409 |
+
all_segments_len.append(
|
| 410 |
+
torch.tensor(seg_len, device=mel.device))
|
| 411 |
+
segment_info.append({
|
| 412 |
+
'batch_idx': batch_idx,
|
| 413 |
+
'is_long_audio': False,
|
| 414 |
+
'segment_idx': 0,
|
| 415 |
+
'total_segments': 1
|
| 416 |
+
})
|
| 417 |
+
else:
|
| 418 |
+
# Long audio: split into multiple segments
|
| 419 |
+
start = 0
|
| 420 |
+
segment_idx = 0
|
| 421 |
+
while start < audio_mel_len:
|
| 422 |
+
end = min(start + frames_per_window, audio_mel_len)
|
| 423 |
+
segment = audio_mel[:, start:end]
|
| 424 |
+
|
| 425 |
+
seg_len = segment.size(1)
|
| 426 |
+
# Pad if necessary
|
| 427 |
+
if seg_len < frames_per_window:
|
| 428 |
+
pad_size = frames_per_window - seg_len
|
| 429 |
+
segment = F.pad(segment, (0, pad_size))
|
| 430 |
+
|
| 431 |
+
all_segments.append(segment)
|
| 432 |
+
all_segments_len.append(
|
| 433 |
+
torch.tensor(seg_len, device=mel.device))
|
| 434 |
+
segment_info.append({
|
| 435 |
+
'batch_idx': batch_idx,
|
| 436 |
+
'is_long_audio': True,
|
| 437 |
+
'segment_idx': segment_idx,
|
| 438 |
+
'total_segments': None # Will be filled later
|
| 439 |
+
})
|
| 440 |
+
|
| 441 |
+
segment_idx += 1
|
| 442 |
+
start += frames_per_stride
|
| 443 |
+
|
| 444 |
+
# Update total_segments info
|
| 445 |
+
total_segments = segment_idx
|
| 446 |
+
for info in segment_info:
|
| 447 |
+
if info['batch_idx'] == batch_idx and info['is_long_audio']:
|
| 448 |
+
info['total_segments'] = total_segments
|
| 449 |
+
|
| 450 |
+
if not all_segments:
|
| 451 |
+
# Fallback if no segments
|
| 452 |
+
return torch.zeros(batch_size,
|
| 453 |
+
0,
|
| 454 |
+
dtype=torch.long,
|
| 455 |
+
device=mel.device), torch.zeros(
|
| 456 |
+
batch_size,
|
| 457 |
+
dtype=torch.long,
|
| 458 |
+
device=mel.device)
|
| 459 |
+
|
| 460 |
+
# Unified batch processing for all segments
|
| 461 |
+
unified_batch_mel = torch.stack(all_segments)
|
| 462 |
+
unified_batch_lens = torch.stack(all_segments_len)
|
| 463 |
+
|
| 464 |
+
# Process all segments at once
|
| 465 |
+
hidden, code_len = self.encoder(unified_batch_mel, unified_batch_lens)
|
| 466 |
+
codes = self.quantizer.encode(hidden)
|
| 467 |
+
|
| 468 |
+
# Reorganize results based on segment_info
|
| 469 |
+
results = {} # batch_idx -> (code_tensor, code_len)
|
| 470 |
+
|
| 471 |
+
for seg_idx, info in enumerate(segment_info):
|
| 472 |
+
batch_idx = info['batch_idx']
|
| 473 |
+
is_long_audio = info['is_long_audio']
|
| 474 |
+
segment_idx = info['segment_idx']
|
| 475 |
+
|
| 476 |
+
# Get codes for current segment
|
| 477 |
+
segment_code = codes[
|
| 478 |
+
seg_idx, :code_len[seg_idx].item()].cpu().numpy().tolist()
|
| 479 |
+
|
| 480 |
+
if not is_long_audio:
|
| 481 |
+
# Short audio: use directly
|
| 482 |
+
code_tensor = torch.tensor(segment_code,
|
| 483 |
+
dtype=torch.long,
|
| 484 |
+
device=mel.device)
|
| 485 |
+
results[batch_idx] = (code_tensor, len(segment_code))
|
| 486 |
+
else:
|
| 487 |
+
# Long audio: collect all segments
|
| 488 |
+
if batch_idx not in results:
|
| 489 |
+
results[batch_idx] = []
|
| 490 |
+
results[batch_idx].append(segment_code)
|
| 491 |
+
|
| 492 |
+
# Process long audio segment merging
|
| 493 |
+
for batch_idx in range(batch_size):
|
| 494 |
+
if long_audio_mask[batch_idx].item():
|
| 495 |
+
# Merge long audio segments
|
| 496 |
+
audio_codes = results[batch_idx]
|
| 497 |
+
|
| 498 |
+
# Determine token rate based on model name
|
| 499 |
+
if hasattr(self,
|
| 500 |
+
'name') and self.name == "speech_tokenizer_v1":
|
| 501 |
+
token_rate = 50
|
| 502 |
+
else:
|
| 503 |
+
token_rate = 25
|
| 504 |
+
|
| 505 |
+
merged_codes = merge_tokenized_segments(audio_codes,
|
| 506 |
+
overlap=overlap,
|
| 507 |
+
token_rate=token_rate)
|
| 508 |
+
|
| 509 |
+
# Convert to tensor
|
| 510 |
+
merged_codes_tensor = torch.tensor(merged_codes,
|
| 511 |
+
dtype=torch.long,
|
| 512 |
+
device=mel.device)
|
| 513 |
+
results[batch_idx] = (merged_codes_tensor, len(merged_codes))
|
| 514 |
+
|
| 515 |
+
# Construct final output
|
| 516 |
+
max_code_len = max(code_info[1] for code_info in results.values())
|
| 517 |
+
|
| 518 |
+
output_codes = torch.zeros(batch_size,
|
| 519 |
+
max_code_len,
|
| 520 |
+
dtype=torch.long,
|
| 521 |
+
device=mel.device)
|
| 522 |
+
output_codes_len = torch.zeros(batch_size,
|
| 523 |
+
dtype=torch.long,
|
| 524 |
+
device=mel.device)
|
| 525 |
+
|
| 526 |
+
for batch_idx, (code_tensor, code_len) in results.items():
|
| 527 |
+
output_codes[batch_idx, :code_len] = code_tensor
|
| 528 |
+
output_codes_len[batch_idx] = code_len
|
| 529 |
+
|
| 530 |
+
return output_codes, output_codes_len
|
| 531 |
+
|
| 532 |
+
@property
|
| 533 |
+
def device(self):
|
| 534 |
+
return next(self.parameters()).device
|
| 535 |
+
|
| 536 |
+
def init_from_onnx(self, onnx_path: str):
|
| 537 |
+
ckpt = onnx2torch(onnx_path, None, False)
|
| 538 |
+
self.load_state_dict(ckpt, strict=True)
|
| 539 |
+
|
| 540 |
+
def init_from_pt(self, ckpt_path: str):
|
| 541 |
+
ckpt = torch.load(ckpt_path, map_location="cpu", mmap=True)
|
| 542 |
+
self.load_state_dict(ckpt, strict=True)
|
| 543 |
+
|
| 544 |
+
def freeze(self):
|
| 545 |
+
for _, param in self.named_parameters():
|
| 546 |
+
param.requires_grad = False
|
speech/tools/S3Tokenizer/s3tokenizer/model_v2.py
ADDED
|
@@ -0,0 +1,604 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) (Mddct: Dinghao Zhou)
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
from typing import Optional, Tuple
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
|
| 21 |
+
from s3tokenizer.model import Conv1d, LayerNorm, Linear, MultiHeadAttention
|
| 22 |
+
from s3tokenizer.utils import make_non_pad_mask, mask_to_bias, onnx2torch, merge_tokenized_segments
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
|
| 26 |
+
class ModelConfig:
|
| 27 |
+
n_mels: int = 128
|
| 28 |
+
n_audio_ctx: int = 1500
|
| 29 |
+
n_audio_state: int = 1280
|
| 30 |
+
n_audio_head: int = 20
|
| 31 |
+
n_audio_layer: int = 6
|
| 32 |
+
n_codebook_size: int = 3**8
|
| 33 |
+
|
| 34 |
+
use_sdpa: bool = False
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def precompute_freqs_cis(dim: int,
|
| 38 |
+
end: int,
|
| 39 |
+
theta: float = 10000.0,
|
| 40 |
+
scaling=None):
|
| 41 |
+
freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
| 42 |
+
t = torch.arange(end, device=freqs.device) # type: ignore
|
| 43 |
+
if scaling is not None:
|
| 44 |
+
t = t * scaling
|
| 45 |
+
freqs = torch.outer(t, freqs).float() # type: ignore
|
| 46 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
| 47 |
+
|
| 48 |
+
return torch.cat((freqs_cis, freqs_cis), dim=-1)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def apply_rotary_emb(
|
| 52 |
+
xq: torch.Tensor,
|
| 53 |
+
xk: torch.Tensor,
|
| 54 |
+
freqs_cis: torch.Tensor,
|
| 55 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 56 |
+
real = torch.view_as_real(freqs_cis)
|
| 57 |
+
cos, sin = real[:, :, 0], real[:, :, 1]
|
| 58 |
+
cos = cos.unsqueeze(0).unsqueeze(2)
|
| 59 |
+
sin = sin.unsqueeze(0).unsqueeze(2)
|
| 60 |
+
|
| 61 |
+
D = xq.shape[-1]
|
| 62 |
+
half_l, half_r = xq[:, :, :, :D // 2], xq[:, :, :, D // 2:]
|
| 63 |
+
xq_r = torch.cat((-half_r, half_l), dim=-1)
|
| 64 |
+
|
| 65 |
+
D = xk.shape[-1]
|
| 66 |
+
|
| 67 |
+
half_l, half_r = xk[:, :, :, :D // 2], xk[:, :, :, D // 2:]
|
| 68 |
+
xk_r = torch.cat((-half_r, half_l), dim=-1)
|
| 69 |
+
|
| 70 |
+
return xq * cos + xq_r * sin, xk * cos + xk_r * sin
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
| 74 |
+
ndim = x.ndim
|
| 75 |
+
assert 0 <= 1 < ndim
|
| 76 |
+
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
| 77 |
+
shape = [
|
| 78 |
+
d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)
|
| 79 |
+
]
|
| 80 |
+
return freqs_cis.view(*shape)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class FSQCodebook(torch.nn.Module):
|
| 84 |
+
|
| 85 |
+
def __init__(self, dim: int, level: int = 3):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.project_down = torch.nn.Linear(dim, 8)
|
| 88 |
+
self.level = level
|
| 89 |
+
self.embed = None
|
| 90 |
+
|
| 91 |
+
@torch.inference_mode()
|
| 92 |
+
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
|
| 93 |
+
x = rearrange(x, "... d -> (...) d")
|
| 94 |
+
return x
|
| 95 |
+
|
| 96 |
+
@torch.inference_mode()
|
| 97 |
+
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
| 98 |
+
x_shape = x.shape
|
| 99 |
+
# pre-process
|
| 100 |
+
x = self.preprocess(x)
|
| 101 |
+
# quantize
|
| 102 |
+
h = self.project_down(x).float()
|
| 103 |
+
h = h.tanh()
|
| 104 |
+
h = h * 0.9990000128746033
|
| 105 |
+
h = h.round() + 1
|
| 106 |
+
# h = ((self.level - 1) * h).round() # range [-k, k]
|
| 107 |
+
powers = torch.pow(
|
| 108 |
+
self.level,
|
| 109 |
+
torch.arange(2**self.level, device=x.device, dtype=h.dtype))
|
| 110 |
+
mu = torch.sum(h * powers.unsqueeze(0), dim=-1)
|
| 111 |
+
ind = mu.reshape(x_shape[0], x_shape[1]).int()
|
| 112 |
+
return ind
|
| 113 |
+
|
| 114 |
+
@torch.inference_mode()
|
| 115 |
+
def decode(self, embed_ind: torch.Tensor) -> torch.Tensor:
|
| 116 |
+
raise NotImplementedError(
|
| 117 |
+
'There is no official up project component provided')
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class FSQVectorQuantization(torch.nn.Module):
|
| 121 |
+
"""Vector quantization implementation (inference-only).
|
| 122 |
+
Args:
|
| 123 |
+
dim (int): Dimension
|
| 124 |
+
codebook_size (int): Codebook size
|
| 125 |
+
"""
|
| 126 |
+
|
| 127 |
+
def __init__(
|
| 128 |
+
self,
|
| 129 |
+
dim: int,
|
| 130 |
+
codebook_size: int,
|
| 131 |
+
):
|
| 132 |
+
super().__init__()
|
| 133 |
+
assert 3**8 == codebook_size
|
| 134 |
+
self._codebook = FSQCodebook(dim=dim, level=3)
|
| 135 |
+
self.codebook_size = codebook_size
|
| 136 |
+
|
| 137 |
+
@property
|
| 138 |
+
def codebook(self):
|
| 139 |
+
return self._codebook.embed
|
| 140 |
+
|
| 141 |
+
@torch.inference_mode()
|
| 142 |
+
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
| 143 |
+
return self._codebook.encode(x)
|
| 144 |
+
|
| 145 |
+
@torch.inference_mode()
|
| 146 |
+
def decode(self, embed_ind: torch.Tensor) -> torch.Tensor:
|
| 147 |
+
quantize = self._codebook.decode(embed_ind)
|
| 148 |
+
quantize = rearrange(quantize, "b n d -> b d n")
|
| 149 |
+
return quantize
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class FSMNMultiHeadAttention(MultiHeadAttention):
|
| 153 |
+
|
| 154 |
+
def __init__(
|
| 155 |
+
self,
|
| 156 |
+
n_state: int,
|
| 157 |
+
n_head: int,
|
| 158 |
+
kernel_size: int = 31,
|
| 159 |
+
use_sdpa: bool = False,
|
| 160 |
+
):
|
| 161 |
+
super().__init__(n_state, n_head)
|
| 162 |
+
|
| 163 |
+
self.fsmn_block = torch.nn.Conv1d(n_state,
|
| 164 |
+
n_state,
|
| 165 |
+
kernel_size,
|
| 166 |
+
stride=1,
|
| 167 |
+
padding=0,
|
| 168 |
+
groups=n_state,
|
| 169 |
+
bias=False)
|
| 170 |
+
self.left_padding = (kernel_size - 1) // 2
|
| 171 |
+
self.right_padding = kernel_size - 1 - self.left_padding
|
| 172 |
+
self.pad_fn = torch.nn.ConstantPad1d(
|
| 173 |
+
(self.left_padding, self.right_padding), 0.0)
|
| 174 |
+
|
| 175 |
+
self.use_sdpa = use_sdpa
|
| 176 |
+
|
| 177 |
+
def forward_fsmn(self,
|
| 178 |
+
inputs: torch.Tensor,
|
| 179 |
+
mask: Optional[torch.Tensor] = None):
|
| 180 |
+
b, t, _, _ = inputs.size()
|
| 181 |
+
inputs = inputs.view(b, t, -1)
|
| 182 |
+
if mask is not None and mask.size(2) > 0: # time2 > 0
|
| 183 |
+
inputs = inputs * mask
|
| 184 |
+
x = inputs.transpose(1, 2)
|
| 185 |
+
x = self.pad_fn(x)
|
| 186 |
+
x = self.fsmn_block(x)
|
| 187 |
+
x = x.transpose(1, 2)
|
| 188 |
+
x += inputs
|
| 189 |
+
return x * mask
|
| 190 |
+
|
| 191 |
+
def qkv_attention(self,
|
| 192 |
+
q: torch.Tensor,
|
| 193 |
+
k: torch.Tensor,
|
| 194 |
+
v: torch.Tensor,
|
| 195 |
+
mask: Optional[torch.Tensor] = None,
|
| 196 |
+
mask_pad: Optional[torch.Tensor] = None,
|
| 197 |
+
freqs_cis: Optional[torch.Tensor] = None):
|
| 198 |
+
_, _, D = q.shape
|
| 199 |
+
scale = (D // self.n_head)**-0.25
|
| 200 |
+
q = q.view(*q.shape[:2], self.n_head, -1)
|
| 201 |
+
k = k.view(*k.shape[:2], self.n_head, -1)
|
| 202 |
+
v = v.view(*v.shape[:2], self.n_head, -1)
|
| 203 |
+
|
| 204 |
+
if freqs_cis is not None:
|
| 205 |
+
q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)
|
| 206 |
+
|
| 207 |
+
fsm_memory = self.forward_fsmn(v, mask_pad)
|
| 208 |
+
|
| 209 |
+
q = q.permute(0, 2, 1, 3) * scale
|
| 210 |
+
v = v.permute(0, 2, 1, 3)
|
| 211 |
+
|
| 212 |
+
if not self.use_sdpa:
|
| 213 |
+
k = k.permute(0, 2, 3, 1) * scale
|
| 214 |
+
qk = q @ k # (B, n_head, T, T)
|
| 215 |
+
if mask is not None:
|
| 216 |
+
qk = qk + mask
|
| 217 |
+
qk = qk.float()
|
| 218 |
+
w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
|
| 219 |
+
return (w @ v).permute(
|
| 220 |
+
0, 2, 1, 3).flatten(start_dim=2), qk.detach(), fsm_memory
|
| 221 |
+
else:
|
| 222 |
+
k = k.permute(0, 2, 1, 3) * scale
|
| 223 |
+
assert mask is not None
|
| 224 |
+
output = torch.nn.functional.scaled_dot_product_attention(
|
| 225 |
+
q,
|
| 226 |
+
k,
|
| 227 |
+
v,
|
| 228 |
+
attn_mask=mask,
|
| 229 |
+
dropout_p=0.,
|
| 230 |
+
scale=1.,
|
| 231 |
+
)
|
| 232 |
+
output = (output.transpose(1,
|
| 233 |
+
2).contiguous().view(q.size(0), -1, D)
|
| 234 |
+
) # (batch, time1, d_model)
|
| 235 |
+
return output, None, fsm_memory
|
| 236 |
+
|
| 237 |
+
def forward(self,
|
| 238 |
+
x: torch.Tensor,
|
| 239 |
+
mask: Optional[torch.Tensor] = None,
|
| 240 |
+
mask_pad: Optional[torch.Tensor] = None,
|
| 241 |
+
freqs_cis: Optional[torch.Tensor] = None):
|
| 242 |
+
|
| 243 |
+
q = self.query(x)
|
| 244 |
+
k = self.key(x)
|
| 245 |
+
v = self.value(x)
|
| 246 |
+
|
| 247 |
+
wv, qk, fsm_memory = self.qkv_attention(q, k, v, mask, mask_pad,
|
| 248 |
+
freqs_cis)
|
| 249 |
+
return self.out(wv) + fsm_memory, qk
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class ResidualAttentionBlock(torch.nn.Module):
|
| 253 |
+
|
| 254 |
+
def __init__(
|
| 255 |
+
self,
|
| 256 |
+
n_state: int,
|
| 257 |
+
n_head: int,
|
| 258 |
+
kernel_size: int = 31,
|
| 259 |
+
use_sdpa: bool = False,
|
| 260 |
+
):
|
| 261 |
+
super().__init__()
|
| 262 |
+
|
| 263 |
+
self.attn = FSMNMultiHeadAttention(n_state,
|
| 264 |
+
n_head,
|
| 265 |
+
kernel_size,
|
| 266 |
+
use_sdpa=use_sdpa)
|
| 267 |
+
self.attn_ln = LayerNorm(n_state, eps=1e-6)
|
| 268 |
+
|
| 269 |
+
n_mlp = n_state * 4
|
| 270 |
+
|
| 271 |
+
self.mlp = torch.nn.Sequential(Linear(n_state, n_mlp), torch.nn.GELU(),
|
| 272 |
+
Linear(n_mlp, n_state))
|
| 273 |
+
self.mlp_ln = LayerNorm(n_state)
|
| 274 |
+
|
| 275 |
+
def forward(
|
| 276 |
+
self,
|
| 277 |
+
x: torch.Tensor,
|
| 278 |
+
mask: Optional[torch.Tensor] = None,
|
| 279 |
+
mask_pad: Optional[torch.Tensor] = None,
|
| 280 |
+
freqs_cis: Optional[torch.Tensor] = None,
|
| 281 |
+
):
|
| 282 |
+
x = x + self.attn(
|
| 283 |
+
self.attn_ln(x), mask=mask, mask_pad=mask_pad,
|
| 284 |
+
freqs_cis=freqs_cis)[0]
|
| 285 |
+
|
| 286 |
+
x = x + self.mlp(self.mlp_ln(x))
|
| 287 |
+
return x
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class AudioEncoderV2(torch.nn.Module):
|
| 291 |
+
|
| 292 |
+
def __init__(
|
| 293 |
+
self,
|
| 294 |
+
n_mels: int,
|
| 295 |
+
n_state: int,
|
| 296 |
+
n_head: int,
|
| 297 |
+
n_layer: int,
|
| 298 |
+
stride: int,
|
| 299 |
+
use_sdpa: bool,
|
| 300 |
+
):
|
| 301 |
+
super().__init__()
|
| 302 |
+
self.stride = stride
|
| 303 |
+
|
| 304 |
+
self.conv1 = Conv1d(n_mels,
|
| 305 |
+
n_state,
|
| 306 |
+
kernel_size=3,
|
| 307 |
+
stride=stride,
|
| 308 |
+
padding=1)
|
| 309 |
+
self.conv2 = Conv1d(n_state,
|
| 310 |
+
n_state,
|
| 311 |
+
kernel_size=3,
|
| 312 |
+
stride=2,
|
| 313 |
+
padding=1)
|
| 314 |
+
self.freqs_cis = precompute_freqs_cis(64, 1024 * 2)
|
| 315 |
+
self.blocks = torch.nn.ModuleList([
|
| 316 |
+
ResidualAttentionBlock(n_state, n_head, use_sdpa=use_sdpa)
|
| 317 |
+
for _ in range(n_layer)
|
| 318 |
+
])
|
| 319 |
+
|
| 320 |
+
def forward(self, x: torch.Tensor,
|
| 321 |
+
x_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 322 |
+
"""
|
| 323 |
+
x : torch.Tensor, shape = (batch_size, n_mels, T)
|
| 324 |
+
the mel spectrogram of the audio
|
| 325 |
+
x_len: torch.Tensor, shape = (batch_size,)
|
| 326 |
+
length of each audio in x
|
| 327 |
+
"""
|
| 328 |
+
mask = make_non_pad_mask(x_len).unsqueeze(1)
|
| 329 |
+
x = torch.nn.functional.gelu(self.conv1(x * mask))
|
| 330 |
+
x_len = (x_len + 2 - 1 * (3 - 1) - 1) // self.stride + 1
|
| 331 |
+
mask = make_non_pad_mask(x_len).unsqueeze(1)
|
| 332 |
+
x = torch.nn.functional.gelu(self.conv2(x * mask))
|
| 333 |
+
x_len = (x_len + 2 - 1 * (3 - 1) - 1) // 2 + 1
|
| 334 |
+
mask = make_non_pad_mask(x_len).unsqueeze(1)
|
| 335 |
+
x = x.permute(0, 2, 1) # (B, T // 2, n_state)
|
| 336 |
+
freqs_cis = self.freqs_cis.to(x.device)
|
| 337 |
+
mask_pad = mask.transpose(1, 2)
|
| 338 |
+
mask = mask_to_bias(mask, x.dtype)
|
| 339 |
+
|
| 340 |
+
tmp = torch.view_as_real(freqs_cis)
|
| 341 |
+
cos, sin = tmp[:, :, 0], tmp[:, :, 1]
|
| 342 |
+
|
| 343 |
+
cos = torch.cat((cos, cos), dim=-1)
|
| 344 |
+
sin = torch.cat((sin, sin), dim=-1)
|
| 345 |
+
cos = cos.unsqueeze(0).unsqueeze(2)
|
| 346 |
+
sin = sin.unsqueeze(0).unsqueeze(2)
|
| 347 |
+
|
| 348 |
+
for block in self.blocks:
|
| 349 |
+
x = block(x, mask.unsqueeze(1), mask_pad, freqs_cis[:x.size(1)])
|
| 350 |
+
|
| 351 |
+
return x, x_len
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
class S3TokenizerV2(torch.nn.Module):
|
| 355 |
+
"""S3 tokenizer v2 implementation (inference-only).
|
| 356 |
+
Args:
|
| 357 |
+
config (ModelConfig): Config
|
| 358 |
+
"""
|
| 359 |
+
|
| 360 |
+
def __init__(self, name: str, config: ModelConfig = ModelConfig()):
|
| 361 |
+
super().__init__()
|
| 362 |
+
self.name = name # Store model name for token_rate determination
|
| 363 |
+
if 'v1' not in name:
|
| 364 |
+
assert 'v2' in name
|
| 365 |
+
# TODO(Mddct): make it configureable
|
| 366 |
+
config.n_codebook_size = 3**8
|
| 367 |
+
self.config = config
|
| 368 |
+
self.encoder = AudioEncoderV2(
|
| 369 |
+
self.config.n_mels,
|
| 370 |
+
self.config.n_audio_state,
|
| 371 |
+
self.config.n_audio_head,
|
| 372 |
+
self.config.n_audio_layer,
|
| 373 |
+
2,
|
| 374 |
+
self.config.use_sdpa,
|
| 375 |
+
)
|
| 376 |
+
self.quantizer = FSQVectorQuantization(
|
| 377 |
+
self.config.n_audio_state,
|
| 378 |
+
self.config.n_codebook_size,
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
def forward(self, mel: torch.Tensor,
|
| 382 |
+
mel_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 383 |
+
return self.quantize(mel, mel_len)
|
| 384 |
+
|
| 385 |
+
@torch.inference_mode()
|
| 386 |
+
def quantize(self, mel: torch.Tensor,
|
| 387 |
+
mel_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 388 |
+
"""
|
| 389 |
+
Quantize mel spectrogram to tokens, with automatic long audio handling.
|
| 390 |
+
|
| 391 |
+
Args:
|
| 392 |
+
mel: mel spectrogram tensor, shape (batch_size, n_mels, T)
|
| 393 |
+
mel_len: mel length tensor, shape (batch_size,)
|
| 394 |
+
|
| 395 |
+
Returns:
|
| 396 |
+
code: quantized tokens, shape (batch_size, T')
|
| 397 |
+
code_len: token length, shape (batch_size,)
|
| 398 |
+
"""
|
| 399 |
+
# Check if any audio in the batch exceeds 30 seconds
|
| 400 |
+
# Assuming 16kHz sample rate and hop_length=160, 30s = 30*16000/160 = 3000 frames
|
| 401 |
+
max_frames = 3000
|
| 402 |
+
|
| 403 |
+
# Check which samples are long audio
|
| 404 |
+
long_audio_mask = mel_len > max_frames
|
| 405 |
+
|
| 406 |
+
if long_audio_mask.any():
|
| 407 |
+
# Has long audio - need special processing
|
| 408 |
+
return self._quantize_mixed_batch(mel, mel_len, long_audio_mask,
|
| 409 |
+
max_frames)
|
| 410 |
+
else:
|
| 411 |
+
# All short audio - use original method
|
| 412 |
+
hidden, code_len = self.encoder(mel, mel_len)
|
| 413 |
+
code = self.quantizer.encode(hidden)
|
| 414 |
+
return code, code_len
|
| 415 |
+
|
| 416 |
+
@torch.inference_mode()
|
| 417 |
+
def _quantize_mixed_batch(
|
| 418 |
+
self, mel: torch.Tensor, mel_len: torch.Tensor,
|
| 419 |
+
long_audio_mask: torch.Tensor,
|
| 420 |
+
max_frames: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 421 |
+
"""
|
| 422 |
+
Handle mixed batch with both short and long audio using unified batch processing.
|
| 423 |
+
|
| 424 |
+
Args:
|
| 425 |
+
mel: mel spectrogram tensor, shape (batch_size, n_mels, T)
|
| 426 |
+
mel_len: mel length tensor, shape (batch_size,)
|
| 427 |
+
long_audio_mask: boolean mask for long audio, shape (batch_size,)
|
| 428 |
+
max_frames: maximum frames for short audio
|
| 429 |
+
|
| 430 |
+
Returns:
|
| 431 |
+
code: quantized tokens, shape (batch_size, T')
|
| 432 |
+
code_len: token length, shape (batch_size,)
|
| 433 |
+
"""
|
| 434 |
+
batch_size = mel.size(0)
|
| 435 |
+
|
| 436 |
+
# Parameters for sliding window
|
| 437 |
+
sample_rate = 16000
|
| 438 |
+
hop_length = 160 # Default hop length for mel spectrogram
|
| 439 |
+
window_size = 30 # seconds
|
| 440 |
+
overlap = 4 # seconds
|
| 441 |
+
|
| 442 |
+
# Calculate frame-based parameters
|
| 443 |
+
frames_per_window = window_size * sample_rate // hop_length # 3000 frames
|
| 444 |
+
frames_per_overlap = overlap * sample_rate // hop_length # 400 frames
|
| 445 |
+
frames_per_stride = frames_per_window - frames_per_overlap # 2600 frames
|
| 446 |
+
|
| 447 |
+
# Collect all segments to process (including short and long audio segments)
|
| 448 |
+
all_segments = []
|
| 449 |
+
all_segments_len = []
|
| 450 |
+
segment_info = [
|
| 451 |
+
] # Record which audio each segment belongs to and whether it's long audio
|
| 452 |
+
|
| 453 |
+
# Process all audio in the batch
|
| 454 |
+
for batch_idx in range(batch_size):
|
| 455 |
+
audio_mel = mel[batch_idx]
|
| 456 |
+
audio_mel_len = mel_len[batch_idx]
|
| 457 |
+
is_long_audio = long_audio_mask[batch_idx].item()
|
| 458 |
+
|
| 459 |
+
if not is_long_audio:
|
| 460 |
+
# Short audio: process directly as a single segment
|
| 461 |
+
segment = audio_mel[:, :audio_mel_len]
|
| 462 |
+
seg_len = audio_mel_len.item()
|
| 463 |
+
|
| 464 |
+
# Pad to max_frames if necessary
|
| 465 |
+
if seg_len < frames_per_window:
|
| 466 |
+
pad_size = frames_per_window - seg_len
|
| 467 |
+
segment = torch.nn.functional.pad(segment, (0, pad_size))
|
| 468 |
+
|
| 469 |
+
all_segments.append(segment)
|
| 470 |
+
all_segments_len.append(
|
| 471 |
+
torch.tensor(seg_len, device=mel.device))
|
| 472 |
+
segment_info.append({
|
| 473 |
+
'batch_idx': batch_idx,
|
| 474 |
+
'is_long_audio': False,
|
| 475 |
+
'segment_idx': 0,
|
| 476 |
+
'total_segments': 1
|
| 477 |
+
})
|
| 478 |
+
else:
|
| 479 |
+
# Long audio: split into multiple segments
|
| 480 |
+
start = 0
|
| 481 |
+
segment_idx = 0
|
| 482 |
+
while start < audio_mel_len:
|
| 483 |
+
end = min(start + frames_per_window, audio_mel_len)
|
| 484 |
+
segment = audio_mel[:, start:end]
|
| 485 |
+
|
| 486 |
+
seg_len = segment.size(1)
|
| 487 |
+
# Pad if necessary
|
| 488 |
+
if seg_len < frames_per_window:
|
| 489 |
+
pad_size = frames_per_window - seg_len
|
| 490 |
+
segment = torch.nn.functional.pad(
|
| 491 |
+
segment, (0, pad_size))
|
| 492 |
+
|
| 493 |
+
all_segments.append(segment)
|
| 494 |
+
all_segments_len.append(
|
| 495 |
+
torch.tensor(seg_len, device=mel.device))
|
| 496 |
+
segment_info.append({
|
| 497 |
+
'batch_idx': batch_idx,
|
| 498 |
+
'is_long_audio': True,
|
| 499 |
+
'segment_idx': segment_idx,
|
| 500 |
+
'total_segments': None # Will be filled later
|
| 501 |
+
})
|
| 502 |
+
|
| 503 |
+
segment_idx += 1
|
| 504 |
+
start += frames_per_stride
|
| 505 |
+
|
| 506 |
+
# Update total_segments info
|
| 507 |
+
total_segments = segment_idx
|
| 508 |
+
for info in segment_info:
|
| 509 |
+
if info['batch_idx'] == batch_idx and info['is_long_audio']:
|
| 510 |
+
info['total_segments'] = total_segments
|
| 511 |
+
|
| 512 |
+
if not all_segments:
|
| 513 |
+
# Fallback if no segments
|
| 514 |
+
return torch.zeros(batch_size,
|
| 515 |
+
0,
|
| 516 |
+
dtype=torch.long,
|
| 517 |
+
device=mel.device), torch.zeros(
|
| 518 |
+
batch_size,
|
| 519 |
+
dtype=torch.long,
|
| 520 |
+
device=mel.device)
|
| 521 |
+
|
| 522 |
+
# Unified batch processing for all segments
|
| 523 |
+
unified_batch_mel = torch.stack(all_segments)
|
| 524 |
+
unified_batch_lens = torch.stack(all_segments_len)
|
| 525 |
+
|
| 526 |
+
# Process all segments at once
|
| 527 |
+
hidden, code_len = self.encoder(unified_batch_mel, unified_batch_lens)
|
| 528 |
+
codes = self.quantizer.encode(hidden)
|
| 529 |
+
|
| 530 |
+
# Reorganize results based on segment_info
|
| 531 |
+
results = {} # batch_idx -> (code_tensor, code_len)
|
| 532 |
+
|
| 533 |
+
for seg_idx, info in enumerate(segment_info):
|
| 534 |
+
batch_idx = info['batch_idx']
|
| 535 |
+
is_long_audio = info['is_long_audio']
|
| 536 |
+
segment_idx = info['segment_idx']
|
| 537 |
+
|
| 538 |
+
# Get codes for current segment
|
| 539 |
+
segment_code = codes[
|
| 540 |
+
seg_idx, :code_len[seg_idx].item()].cpu().numpy().tolist()
|
| 541 |
+
|
| 542 |
+
if not is_long_audio:
|
| 543 |
+
# Short audio: use directly
|
| 544 |
+
code_tensor = torch.tensor(segment_code,
|
| 545 |
+
dtype=torch.long,
|
| 546 |
+
device=mel.device)
|
| 547 |
+
results[batch_idx] = (code_tensor, len(segment_code))
|
| 548 |
+
else:
|
| 549 |
+
# Long audio: collect all segments
|
| 550 |
+
if batch_idx not in results:
|
| 551 |
+
results[batch_idx] = []
|
| 552 |
+
results[batch_idx].append(segment_code)
|
| 553 |
+
|
| 554 |
+
# Process long audio segment merging
|
| 555 |
+
for batch_idx in range(batch_size):
|
| 556 |
+
if long_audio_mask[batch_idx].item():
|
| 557 |
+
# Merge long audio segments
|
| 558 |
+
audio_codes = results[batch_idx]
|
| 559 |
+
|
| 560 |
+
# V2 models use 25Hz token rate
|
| 561 |
+
token_rate = 25
|
| 562 |
+
|
| 563 |
+
merged_codes = merge_tokenized_segments(audio_codes,
|
| 564 |
+
overlap=overlap,
|
| 565 |
+
token_rate=token_rate)
|
| 566 |
+
|
| 567 |
+
# Convert to tensor
|
| 568 |
+
merged_codes_tensor = torch.tensor(merged_codes,
|
| 569 |
+
dtype=torch.long,
|
| 570 |
+
device=mel.device)
|
| 571 |
+
results[batch_idx] = (merged_codes_tensor, len(merged_codes))
|
| 572 |
+
|
| 573 |
+
# Construct final output
|
| 574 |
+
max_code_len = max(code_info[1] for code_info in results.values())
|
| 575 |
+
|
| 576 |
+
output_codes = torch.zeros(batch_size,
|
| 577 |
+
max_code_len,
|
| 578 |
+
dtype=torch.long,
|
| 579 |
+
device=mel.device)
|
| 580 |
+
output_codes_len = torch.zeros(batch_size,
|
| 581 |
+
dtype=torch.long,
|
| 582 |
+
device=mel.device)
|
| 583 |
+
|
| 584 |
+
for batch_idx, (code_tensor, code_len) in results.items():
|
| 585 |
+
output_codes[batch_idx, :code_len] = code_tensor
|
| 586 |
+
output_codes_len[batch_idx] = code_len
|
| 587 |
+
|
| 588 |
+
return output_codes, output_codes_len
|
| 589 |
+
|
| 590 |
+
@property
|
| 591 |
+
def device(self):
|
| 592 |
+
return next(self.parameters()).device
|
| 593 |
+
|
| 594 |
+
def init_from_onnx(self, onnx_path: str):
|
| 595 |
+
ckpt = onnx2torch(onnx_path, None, False)
|
| 596 |
+
self.load_state_dict(ckpt, strict=True)
|
| 597 |
+
|
| 598 |
+
def init_from_pt(self, ckpt_path: str):
|
| 599 |
+
ckpt = torch.load(ckpt_path, map_location="cpu", mmap=True)
|
| 600 |
+
self.load_state_dict(ckpt, strict=True)
|
| 601 |
+
|
| 602 |
+
def freeze(self):
|
| 603 |
+
for _, param in self.named_parameters():
|
| 604 |
+
param.requires_grad = False
|
speech/tools/S3Tokenizer/s3tokenizer/utils.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2023 OpenAI. (authors: Whisper Team)
|
| 2 |
+
# 2024 Tsinghua Univ. (authors: Xingchen Song)
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Modified from https://github.com/openai/whisper/blob/main/whisper/audio.py
|
| 16 |
+
Add rename_weights() & onnx2torch() & make_non_pad_mask() & mask_to_bias()
|
| 17 |
+
Copy merge_tokenized_segments() from https://github.com/Mddct/s3tokenizer-long/blob/main/example.py
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import os
|
| 21 |
+
from functools import lru_cache
|
| 22 |
+
from typing import List, Optional, Union
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import onnx
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
import torchaudio
|
| 29 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _rename_weights(weights_dict: dict):
|
| 33 |
+
"""
|
| 34 |
+
Rename onnx weights to pytorch format.
|
| 35 |
+
|
| 36 |
+
Parameters
|
| 37 |
+
----------
|
| 38 |
+
weight_dict: dict
|
| 39 |
+
The dict containing weights in onnx format
|
| 40 |
+
|
| 41 |
+
Returns
|
| 42 |
+
-------
|
| 43 |
+
A new weight dict containing the weights in pytorch format.
|
| 44 |
+
"""
|
| 45 |
+
new_weight_dict = {}
|
| 46 |
+
for k in weights_dict.keys():
|
| 47 |
+
if "quantizer" in k: # vq or fsq
|
| 48 |
+
if k == "/quantizer/rq/model/layers.0/_codebook/Pow_1":
|
| 49 |
+
new_weight_dict["quantizer._codebook.embed"] = weights_dict[k]
|
| 50 |
+
elif 'project_down' in k: # v2
|
| 51 |
+
new_weight_dict[k] = weights_dict[k]
|
| 52 |
+
elif "positional_embedding" in k: # positional emb
|
| 53 |
+
new_weight_dict[k] = weights_dict[k]
|
| 54 |
+
elif "conv" in k: # 1/2 or 1/4 subsample
|
| 55 |
+
new_weight_dict[k] = weights_dict[k]
|
| 56 |
+
else: # transformer blocks
|
| 57 |
+
assert "blocks" in k
|
| 58 |
+
new_k = (k[1:].replace('/', '.').replace(
|
| 59 |
+
'MatMul', 'weight').replace('Add_1', 'bias').replace(
|
| 60 |
+
'Mul', 'weight').replace('Add', 'bias').replace(
|
| 61 |
+
'mlp.mlp', 'mlp')).replace('fsmn_block.Conv',
|
| 62 |
+
'fsmn_block.weight')
|
| 63 |
+
|
| 64 |
+
new_weight_dict[f"encoder.{new_k}"] = weights_dict[k]
|
| 65 |
+
return new_weight_dict
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def onnx2torch(onnx_path: str, torch_path: str = None, verbose: bool = False):
|
| 69 |
+
"""
|
| 70 |
+
Open an onnx file and convert to pytorch format.
|
| 71 |
+
|
| 72 |
+
Parameters
|
| 73 |
+
----------
|
| 74 |
+
onnx_path: str
|
| 75 |
+
The onnx file to open, typically `speech_tokenizer_v1.onnx`
|
| 76 |
+
|
| 77 |
+
torch_path: str
|
| 78 |
+
The path to save the torch-formated checkpoint.
|
| 79 |
+
|
| 80 |
+
verbose: bool
|
| 81 |
+
Logging info or not.
|
| 82 |
+
|
| 83 |
+
Returns
|
| 84 |
+
-------
|
| 85 |
+
A checkpoint dict containing the weights and their names, if torch_path is
|
| 86 |
+
None. Otherwise save checkpoint dict to the desired path.
|
| 87 |
+
"""
|
| 88 |
+
onnx_model = onnx.load(onnx_path)
|
| 89 |
+
weights_dict = {}
|
| 90 |
+
initializer_map = {
|
| 91 |
+
initializer.name: initializer
|
| 92 |
+
for initializer in onnx_model.graph.initializer
|
| 93 |
+
}
|
| 94 |
+
for node in onnx_model.graph.node:
|
| 95 |
+
for input_name in node.input:
|
| 96 |
+
if input_name in initializer_map:
|
| 97 |
+
ln_bias_name, ln_weight_name = None, None # for v2 ln
|
| 98 |
+
initializer = initializer_map[input_name]
|
| 99 |
+
if input_name in [
|
| 100 |
+
"onnx::Conv_1519",
|
| 101 |
+
"encoders.conv1.weight",
|
| 102 |
+
"onnx::Conv_2216",
|
| 103 |
+
]: # v1_50hz, v1_25hz, v2_25hz
|
| 104 |
+
weight_name = "encoder.conv1.weight"
|
| 105 |
+
elif input_name in [
|
| 106 |
+
"onnx::Conv_1520",
|
| 107 |
+
"encoders.conv1.bias",
|
| 108 |
+
"onnx::Conv_2217",
|
| 109 |
+
]: # v1_50hz, v1_25hz, v2_25hz
|
| 110 |
+
weight_name = "encoder.conv1.bias"
|
| 111 |
+
elif input_name in [
|
| 112 |
+
"onnx::Conv_1521",
|
| 113 |
+
"encoders.conv2.weight",
|
| 114 |
+
"onnx::Conv_2218",
|
| 115 |
+
]:
|
| 116 |
+
weight_name = "encoder.conv2.weight"
|
| 117 |
+
elif input_name in [
|
| 118 |
+
"onnx::Conv_1522",
|
| 119 |
+
"encoders.conv2.bias",
|
| 120 |
+
"onnx::Conv_2219",
|
| 121 |
+
]:
|
| 122 |
+
weight_name = "encoder.conv2.bias"
|
| 123 |
+
elif input_name == "encoders.positional_embedding":
|
| 124 |
+
weight_name = "encoder.positional_embedding"
|
| 125 |
+
elif input_name == 'quantizer.project_in.bias':
|
| 126 |
+
weight_name = "quantizer._codebook.project_down.bias"
|
| 127 |
+
elif input_name == 'onnx::MatMul_2536':
|
| 128 |
+
weight_name = "quantizer._codebook.project_down.weight"
|
| 129 |
+
else:
|
| 130 |
+
if node.op_type == 'LayerNormalization': # in input_name:
|
| 131 |
+
ln_name = node.name.replace('/LayerNormalization', '')
|
| 132 |
+
ln_weight_name = ln_name + '.weight'
|
| 133 |
+
ln_bias_name = ln_name + '.bias'
|
| 134 |
+
else:
|
| 135 |
+
weight_name = node.name
|
| 136 |
+
if ln_weight_name is not None and ln_bias_name is not None:
|
| 137 |
+
ln_inputs = node.input
|
| 138 |
+
scale_name = ln_inputs[1]
|
| 139 |
+
bias_name = ln_inputs[2]
|
| 140 |
+
scale = onnx.numpy_helper.to_array(
|
| 141 |
+
initializer_map[scale_name]).copy(
|
| 142 |
+
) if scale_name in initializer_map else None
|
| 143 |
+
bias = onnx.numpy_helper.to_array(
|
| 144 |
+
initializer_map[bias_name]).copy(
|
| 145 |
+
) if bias_name in initializer_map else None
|
| 146 |
+
scale.flags.writeable = True
|
| 147 |
+
bias.flags.writeable = True
|
| 148 |
+
weight_tensor = torch.from_numpy(scale)
|
| 149 |
+
bias_tensor = torch.from_numpy(bias)
|
| 150 |
+
|
| 151 |
+
weights_dict[ln_bias_name] = bias_tensor
|
| 152 |
+
weights_dict[ln_weight_name] = weight_tensor
|
| 153 |
+
else:
|
| 154 |
+
weight_array = onnx.numpy_helper.to_array(
|
| 155 |
+
initializer).copy()
|
| 156 |
+
weight_array.flags.writeable = True
|
| 157 |
+
weight_tensor = torch.from_numpy(weight_array)
|
| 158 |
+
if len(weight_tensor.shape) > 2 or weight_name in [
|
| 159 |
+
"encoder.positional_embedding"
|
| 160 |
+
]:
|
| 161 |
+
weights_dict[weight_name] = weight_tensor
|
| 162 |
+
else:
|
| 163 |
+
weights_dict[weight_name] = weight_tensor.t()
|
| 164 |
+
|
| 165 |
+
new_weights_dict = _rename_weights(weights_dict)
|
| 166 |
+
if verbose:
|
| 167 |
+
for k, v in new_weights_dict.items():
|
| 168 |
+
print(f"{k} : {v.shape} {v.dtype}")
|
| 169 |
+
print(f"PyTorch weights saved to {torch_path}")
|
| 170 |
+
del weights_dict, onnx_model
|
| 171 |
+
if torch_path:
|
| 172 |
+
torch.save(new_weights_dict, torch_path)
|
| 173 |
+
else:
|
| 174 |
+
return new_weights_dict
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def load_audio(file: str, sr: int = 16000):
|
| 178 |
+
"""
|
| 179 |
+
Open an audio file and read as mono waveform, resampling as necessary
|
| 180 |
+
|
| 181 |
+
Parameters
|
| 182 |
+
----------
|
| 183 |
+
file: str
|
| 184 |
+
The audio file to open
|
| 185 |
+
|
| 186 |
+
sr: int
|
| 187 |
+
The sample rate to resample the audio if necessary
|
| 188 |
+
|
| 189 |
+
Returns
|
| 190 |
+
-------
|
| 191 |
+
A torch.Tensor containing the audio waveform, in float32 dtype.
|
| 192 |
+
"""
|
| 193 |
+
audio, sample_rate = torchaudio.load(file)
|
| 194 |
+
if sample_rate != sr:
|
| 195 |
+
audio = torchaudio.transforms.Resample(sample_rate, sr)(audio)
|
| 196 |
+
audio = audio[0] # get the first channel
|
| 197 |
+
return audio
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
@lru_cache(maxsize=None)
|
| 201 |
+
def _mel_filters(device, n_mels: int) -> torch.Tensor:
|
| 202 |
+
"""
|
| 203 |
+
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
| 204 |
+
Allows decoupling librosa dependency; saved using:
|
| 205 |
+
|
| 206 |
+
np.savez_compressed(
|
| 207 |
+
"mel_filters.npz",
|
| 208 |
+
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
| 209 |
+
mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
|
| 210 |
+
)
|
| 211 |
+
"""
|
| 212 |
+
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
|
| 213 |
+
|
| 214 |
+
filters_path = os.path.join(os.path.dirname(__file__), "assets",
|
| 215 |
+
"mel_filters.npz")
|
| 216 |
+
with np.load(filters_path, allow_pickle=False) as f:
|
| 217 |
+
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def log_mel_spectrogram(
|
| 221 |
+
audio: Union[str, np.ndarray, torch.Tensor],
|
| 222 |
+
n_mels: int = 128,
|
| 223 |
+
padding: int = 0,
|
| 224 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 225 |
+
):
|
| 226 |
+
"""
|
| 227 |
+
Compute the log-Mel spectrogram of
|
| 228 |
+
|
| 229 |
+
Parameters
|
| 230 |
+
----------
|
| 231 |
+
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
|
| 232 |
+
The path to audio or either a NumPy array or Tensor containing the
|
| 233 |
+
audio waveform in 16 kHz
|
| 234 |
+
|
| 235 |
+
n_mels: int
|
| 236 |
+
The number of Mel-frequency filters, only 80 is supported
|
| 237 |
+
|
| 238 |
+
padding: int
|
| 239 |
+
Number of zero samples to pad to the right
|
| 240 |
+
|
| 241 |
+
device: Optional[Union[str, torch.device]]
|
| 242 |
+
If given, the audio tensor is moved to this device before STFT
|
| 243 |
+
|
| 244 |
+
Returns
|
| 245 |
+
-------
|
| 246 |
+
torch.Tensor, shape = (128, n_frames)
|
| 247 |
+
A Tensor that contains the Mel spectrogram
|
| 248 |
+
"""
|
| 249 |
+
if not torch.is_tensor(audio):
|
| 250 |
+
if isinstance(audio, str):
|
| 251 |
+
audio = load_audio(audio)
|
| 252 |
+
|
| 253 |
+
if device is not None:
|
| 254 |
+
audio = audio.to(device)
|
| 255 |
+
if padding > 0:
|
| 256 |
+
audio = F.pad(audio, (0, padding))
|
| 257 |
+
window = torch.hann_window(400).to(audio.device)
|
| 258 |
+
stft = torch.stft(audio, 400, 160, window=window, return_complex=True)
|
| 259 |
+
magnitudes = stft[..., :-1].abs()**2
|
| 260 |
+
|
| 261 |
+
filters = _mel_filters(audio.device, n_mels)
|
| 262 |
+
mel_spec = filters @ magnitudes
|
| 263 |
+
|
| 264 |
+
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
| 265 |
+
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
| 266 |
+
log_spec = (log_spec + 4.0) / 4.0
|
| 267 |
+
return log_spec
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def make_non_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
| 271 |
+
"""Make mask tensor containing indices of non-padded part.
|
| 272 |
+
|
| 273 |
+
The sequences in a batch may have different lengths. To enable
|
| 274 |
+
batch computing, padding is need to make all sequence in same
|
| 275 |
+
size. To avoid the padding part pass value to context dependent
|
| 276 |
+
block such as attention or convolution , this padding part is
|
| 277 |
+
masked.
|
| 278 |
+
|
| 279 |
+
1 for non-padded part and 0 for padded part.
|
| 280 |
+
|
| 281 |
+
Parameters
|
| 282 |
+
----------
|
| 283 |
+
lengths (torch.Tensor): Batch of lengths (B,).
|
| 284 |
+
|
| 285 |
+
Returns:
|
| 286 |
+
-------
|
| 287 |
+
torch.Tensor: Mask tensor containing indices of padded part (B, max_T).
|
| 288 |
+
|
| 289 |
+
Examples:
|
| 290 |
+
>>> import torch
|
| 291 |
+
>>> import s3tokenizer
|
| 292 |
+
>>> lengths = torch.tensor([5, 3, 2])
|
| 293 |
+
>>> masks = s3tokenizer.make_non_pad_mask(lengths)
|
| 294 |
+
masks = [[1, 1, 1, 1, 1],
|
| 295 |
+
[1, 1, 1, 0, 0],
|
| 296 |
+
[1, 1, 0, 0, 0]]
|
| 297 |
+
"""
|
| 298 |
+
batch_size = lengths.size(0)
|
| 299 |
+
max_len = max_len if max_len > 0 else lengths.max().item()
|
| 300 |
+
seq_range = torch.arange(0,
|
| 301 |
+
max_len,
|
| 302 |
+
dtype=torch.int64,
|
| 303 |
+
device=lengths.device)
|
| 304 |
+
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
| 305 |
+
seq_length_expand = lengths.unsqueeze(-1)
|
| 306 |
+
mask = seq_range_expand >= seq_length_expand
|
| 307 |
+
return ~mask
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
| 311 |
+
"""Convert bool-tensor to float-tensor for flash attention.
|
| 312 |
+
|
| 313 |
+
Parameters
|
| 314 |
+
----------
|
| 315 |
+
lengths (torch.Tensor): Batch of lengths (B, ?).
|
| 316 |
+
|
| 317 |
+
Returns:
|
| 318 |
+
-------
|
| 319 |
+
torch.Tensor: Mask tensor containing indices of padded part (B, ?).
|
| 320 |
+
|
| 321 |
+
Examples:
|
| 322 |
+
>>> import torch
|
| 323 |
+
>>> import s3tokenizer
|
| 324 |
+
>>> lengths = torch.tensor([5, 3, 2])
|
| 325 |
+
>>> masks = s3tokenizer.make_non_pad_mask(lengths)
|
| 326 |
+
masks = [[1, 1, 1, 1, 1],
|
| 327 |
+
[1, 1, 1, 0, 0],
|
| 328 |
+
[1, 1, 0, 0, 0]]
|
| 329 |
+
>>> new_masks = s3tokenizer.mask_to_bias(masks, torch.float32)
|
| 330 |
+
new_masks =
|
| 331 |
+
[[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
|
| 332 |
+
[-0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10],
|
| 333 |
+
[-0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10, -1.0000e+10]]
|
| 334 |
+
"""
|
| 335 |
+
assert mask.dtype == torch.bool
|
| 336 |
+
assert dtype in [torch.float32, torch.bfloat16, torch.float16]
|
| 337 |
+
mask = mask.to(dtype)
|
| 338 |
+
|
| 339 |
+
# attention mask bias
|
| 340 |
+
# NOTE(Mddct): torch.finfo jit issues
|
| 341 |
+
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
|
| 342 |
+
mask = (1.0 - mask) * -1.0e+10
|
| 343 |
+
return mask
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def padding(data: List[torch.Tensor]):
|
| 347 |
+
""" Padding the data into batch data
|
| 348 |
+
|
| 349 |
+
Parameters
|
| 350 |
+
----------
|
| 351 |
+
data: List[Tensor], shape of Tensor (128, T)
|
| 352 |
+
|
| 353 |
+
Returns:
|
| 354 |
+
-------
|
| 355 |
+
feats [B, 128, T_max], feats lengths [B]
|
| 356 |
+
"""
|
| 357 |
+
sample = data
|
| 358 |
+
assert isinstance(sample, list)
|
| 359 |
+
feats_lengths = torch.tensor([s.size(1) for s in sample],
|
| 360 |
+
dtype=torch.int32)
|
| 361 |
+
feats = [s.t() for s in sample]
|
| 362 |
+
padded_feats = pad_sequence(feats, batch_first=True, padding_value=0)
|
| 363 |
+
|
| 364 |
+
return padded_feats.transpose(1, 2), feats_lengths
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def merge_tokenized_segments(tokenized_segments, overlap, token_rate):
|
| 368 |
+
"""
|
| 369 |
+
Merges tokenized outputs by keeping the middle and dropping half of the overlapped tokens.
|
| 370 |
+
|
| 371 |
+
Args:
|
| 372 |
+
- tokenized_segments (List[List[int]]): List of tokenized sequences.
|
| 373 |
+
- overlap (int): Overlapping duration in seconds (default: 4s).
|
| 374 |
+
- token_rate (int): Number of tokens per second.
|
| 375 |
+
|
| 376 |
+
Returns:
|
| 377 |
+
- List[int]: A single merged token sequence.
|
| 378 |
+
"""
|
| 379 |
+
merged_tokens = []
|
| 380 |
+
overlap_tokens = (
|
| 381 |
+
overlap //
|
| 382 |
+
2) * token_rate # Tokens corresponding to half of the overlap duration
|
| 383 |
+
|
| 384 |
+
for i, tokens in enumerate(tokenized_segments):
|
| 385 |
+
l = 0 if i == 0 else overlap_tokens
|
| 386 |
+
r = -overlap_tokens if i != len(tokenized_segments) - 1 else len(tokens)
|
| 387 |
+
# Keep only the middle part (drop overlap / 2 from both sides)
|
| 388 |
+
merged_tokens.extend(tokens[l:r])
|
| 389 |
+
|
| 390 |
+
return merged_tokens
|
speech/tools/S3Tokenizer/setup.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
from setuptools import find_packages, setup
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def parse_requirements(filename):
|
| 7 |
+
"""Load requirements from a pip requirements file."""
|
| 8 |
+
with open(filename, 'r') as file:
|
| 9 |
+
lines = (line.strip() for line in file)
|
| 10 |
+
return [line for line in lines if line and not line.startswith('#')]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
setup(
|
| 14 |
+
name="s3tokenizer",
|
| 15 |
+
version="0.2.0",
|
| 16 |
+
description=\
|
| 17 |
+
"Reverse Engineering of Supervised Semantic Speech Tokenizer (S3Tokenizer) proposed in CosyVoice", # noqa
|
| 18 |
+
long_description=open("README.md", encoding="utf-8").read(),
|
| 19 |
+
long_description_content_type="text/markdown",
|
| 20 |
+
python_requires=">=3.8",
|
| 21 |
+
author="xingchensong",
|
| 22 |
+
url="https://github.com/xingchensong/S3Tokenizer",
|
| 23 |
+
license="Apache2.0",
|
| 24 |
+
packages=find_packages(),
|
| 25 |
+
install_requires=parse_requirements(
|
| 26 |
+
Path(__file__).with_name("requirements.txt")),
|
| 27 |
+
entry_points={
|
| 28 |
+
"console_scripts": ["s3tokenizer=s3tokenizer.cli:main"],
|
| 29 |
+
},
|
| 30 |
+
include_package_data=True,
|
| 31 |
+
extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
|
| 32 |
+
classifiers=[
|
| 33 |
+
"Programming Language :: Python :: 3",
|
| 34 |
+
"Operating System :: OS Independent",
|
| 35 |
+
"Topic :: Scientific/Engineering",
|
| 36 |
+
],
|
| 37 |
+
)
|
speech/tools/S3Tokenizer/test/test_batch_efficiency.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Batch processing efficiency test
|
| 4 |
+
Test the efficiency improvement of new batch processing functionality for mixed long and short audio
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import time
|
| 8 |
+
import torch
|
| 9 |
+
import pytest
|
| 10 |
+
import s3tokenizer
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def create_test_audio(duration_seconds=20, sample_rate=16000):
|
| 14 |
+
"""Create test audio"""
|
| 15 |
+
length = int(duration_seconds * sample_rate)
|
| 16 |
+
# Create meaningful audio signal (sine wave mixture)
|
| 17 |
+
t = torch.linspace(0, duration_seconds, length)
|
| 18 |
+
audio = 0.5 * torch.sin(2 * torch.pi * 440 * t) # 440Hz fundamental
|
| 19 |
+
audio += 0.3 * torch.sin(2 * torch.pi * 880 * t) # 880Hz second harmonic
|
| 20 |
+
audio += 0.1 * torch.randn(length) # Add some noise
|
| 21 |
+
return audio
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@pytest.fixture
|
| 25 |
+
def test_audios():
|
| 26 |
+
"""Create test audio dataset"""
|
| 27 |
+
return [
|
| 28 |
+
create_test_audio(10), # Short audio
|
| 29 |
+
create_test_audio(20), # Medium audio
|
| 30 |
+
create_test_audio(40), # Long audio
|
| 31 |
+
create_test_audio(60), # Long audio
|
| 32 |
+
create_test_audio(15), # Short audio
|
| 33 |
+
create_test_audio(35), # Long audio
|
| 34 |
+
create_test_audio(25), # Medium audio
|
| 35 |
+
create_test_audio(50), # Long audio
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@pytest.fixture
|
| 40 |
+
def long_audios():
|
| 41 |
+
"""Create long audio dataset"""
|
| 42 |
+
return [
|
| 43 |
+
create_test_audio(45.5),
|
| 44 |
+
create_test_audio(60),
|
| 45 |
+
create_test_audio(91.2),
|
| 46 |
+
create_test_audio(120),
|
| 47 |
+
]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@pytest.mark.parametrize("model_name", [
|
| 51 |
+
"speech_tokenizer_v1_25hz", "speech_tokenizer_v1",
|
| 52 |
+
"speech_tokenizer_v2_25hz"
|
| 53 |
+
])
|
| 54 |
+
def test_batch_efficiency(test_audios, model_name):
|
| 55 |
+
"""Test batch processing efficiency for different models"""
|
| 56 |
+
print(f"\n=== Batch Processing Efficiency Test for {model_name} ===")
|
| 57 |
+
|
| 58 |
+
# Load model
|
| 59 |
+
model = s3tokenizer.load_model(model_name)
|
| 60 |
+
model.eval()
|
| 61 |
+
|
| 62 |
+
# Method 1: Individual processing
|
| 63 |
+
print(f"\n--- Method 1: Individual Processing ({model_name}) ---")
|
| 64 |
+
start_time = time.time()
|
| 65 |
+
individual_results = []
|
| 66 |
+
|
| 67 |
+
for i, audio in enumerate(test_audios):
|
| 68 |
+
mel = s3tokenizer.log_mel_spectrogram(audio)
|
| 69 |
+
mels = mel.unsqueeze(0)
|
| 70 |
+
mels_lens = torch.tensor([mel.size(1)])
|
| 71 |
+
|
| 72 |
+
with torch.no_grad():
|
| 73 |
+
codes, codes_lens = model.quantize(mels, mels_lens)
|
| 74 |
+
|
| 75 |
+
final_codes = codes[0, :codes_lens[0].item()].tolist()
|
| 76 |
+
individual_results.append(final_codes)
|
| 77 |
+
|
| 78 |
+
duration = audio.shape[0] / 16000
|
| 79 |
+
processing_type = "Long audio" if duration > 30 else "Short audio"
|
| 80 |
+
print(
|
| 81 |
+
f"Audio {i+1}: {duration:.1f}s, {len(final_codes)} tokens, {processing_type}"
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
individual_time = time.time() - start_time
|
| 85 |
+
print(f"Individual processing total time: {individual_time:.2f}s")
|
| 86 |
+
|
| 87 |
+
# Method 2: Batch processing
|
| 88 |
+
print(f"\n--- Method 2: Batch Processing ({model_name}) ---")
|
| 89 |
+
start_time = time.time()
|
| 90 |
+
|
| 91 |
+
# Prepare batch input
|
| 92 |
+
mels = []
|
| 93 |
+
for audio in test_audios:
|
| 94 |
+
mel = s3tokenizer.log_mel_spectrogram(audio)
|
| 95 |
+
mels.append(mel)
|
| 96 |
+
|
| 97 |
+
# Use padding to handle different lengths of mel
|
| 98 |
+
mels, mels_lens = s3tokenizer.padding(mels)
|
| 99 |
+
|
| 100 |
+
# Batch processing
|
| 101 |
+
with torch.no_grad():
|
| 102 |
+
codes, codes_lens = model.quantize(mels, mels_lens)
|
| 103 |
+
|
| 104 |
+
# Process results
|
| 105 |
+
batch_results = []
|
| 106 |
+
for i in range(len(test_audios)):
|
| 107 |
+
final_codes = codes[i, :codes_lens[i].item()].tolist()
|
| 108 |
+
batch_results.append(final_codes)
|
| 109 |
+
|
| 110 |
+
duration = test_audios[i].shape[0] / 16000
|
| 111 |
+
processing_type = "Long audio" if duration > 30 else "Short audio"
|
| 112 |
+
print(
|
| 113 |
+
f"Audio {i+1}: {duration:.1f}s, {len(final_codes)} tokens, {processing_type}"
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
batch_time = time.time() - start_time
|
| 117 |
+
print(f"Batch processing total time: {batch_time:.2f}s")
|
| 118 |
+
|
| 119 |
+
# Verify result consistency
|
| 120 |
+
print(f"\n--- Result Verification for {model_name} ---")
|
| 121 |
+
all_ok = True
|
| 122 |
+
for i in range(len(test_audios)):
|
| 123 |
+
individual_tokens = individual_results[i]
|
| 124 |
+
batch_tokens = batch_results[i]
|
| 125 |
+
|
| 126 |
+
# Calculate miss rate
|
| 127 |
+
if len(individual_tokens) != len(batch_tokens):
|
| 128 |
+
print(
|
| 129 |
+
f"❌ Audio {i+1} length mismatch: individual={len(individual_tokens)}, batch={len(batch_tokens)}"
|
| 130 |
+
)
|
| 131 |
+
all_ok = False
|
| 132 |
+
else:
|
| 133 |
+
mismatches = sum(1 for a, b in zip(individual_tokens, batch_tokens)
|
| 134 |
+
if a != b)
|
| 135 |
+
miss_rate = mismatches / len(individual_tokens) * 100 if len(
|
| 136 |
+
individual_tokens) > 0 else 0
|
| 137 |
+
|
| 138 |
+
if miss_rate < 0.2: # Less than 0.2% is considered OK
|
| 139 |
+
print(f"✅ Audio {i+1} miss rate: {miss_rate:.4f}% (OK)")
|
| 140 |
+
else:
|
| 141 |
+
print(f"❌ Audio {i+1} miss rate: {miss_rate:.4f}% (Too high)")
|
| 142 |
+
all_ok = False
|
| 143 |
+
|
| 144 |
+
# Efficiency improvement
|
| 145 |
+
speedup = individual_time / batch_time
|
| 146 |
+
print(f"\n--- Efficiency Improvement for {model_name} ---")
|
| 147 |
+
print(f"Batch processing speedup: {speedup:.2f}x")
|
| 148 |
+
if speedup > 1:
|
| 149 |
+
print("✅ Batch processing indeed improves efficiency!")
|
| 150 |
+
else:
|
| 151 |
+
print("⚠️ Batch processing doesn't significantly improve efficiency")
|
| 152 |
+
|
| 153 |
+
# Assertions for pytest
|
| 154 |
+
assert all_ok, f"Results don't match for model {model_name}"
|
| 155 |
+
assert len(individual_results) == len(
|
| 156 |
+
batch_results), "Number of results don't match"
|
| 157 |
+
assert all(
|
| 158 |
+
len(individual_results[i]) == len(batch_results[i])
|
| 159 |
+
for i in range(len(test_audios))), "Token counts don't match"
|
| 160 |
+
|
| 161 |
+
# Performance assertion - batch should be at least as fast as individual (allowing for some variance)
|
| 162 |
+
# assert batch_time <= individual_time * 1.1, f"Batch processing should not be significantly slower than individual processing for {model_name}"
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
@pytest.mark.parametrize("model_name", [
|
| 166 |
+
"speech_tokenizer_v1_25hz", "speech_tokenizer_v1",
|
| 167 |
+
"speech_tokenizer_v2_25hz"
|
| 168 |
+
])
|
| 169 |
+
def test_pure_long_audio_batch(long_audios, model_name):
|
| 170 |
+
"""Test pure long audio batch processing for different models"""
|
| 171 |
+
print(f"\n=== Pure Long Audio Batch Processing Test for {model_name} ===")
|
| 172 |
+
|
| 173 |
+
model = s3tokenizer.load_model(model_name)
|
| 174 |
+
model.eval()
|
| 175 |
+
|
| 176 |
+
# Prepare batch input
|
| 177 |
+
mels = []
|
| 178 |
+
for audio in long_audios:
|
| 179 |
+
mel = s3tokenizer.log_mel_spectrogram(audio)
|
| 180 |
+
mels.append(mel)
|
| 181 |
+
|
| 182 |
+
mels, mels_lens = s3tokenizer.padding(mels)
|
| 183 |
+
|
| 184 |
+
# Batch process long audio
|
| 185 |
+
start_time = time.time()
|
| 186 |
+
with torch.no_grad():
|
| 187 |
+
codes, codes_lens = model.quantize(mels, mels_lens)
|
| 188 |
+
processing_time = time.time() - start_time
|
| 189 |
+
|
| 190 |
+
print(
|
| 191 |
+
f"Batch processing {len(long_audios)} long audios took: {processing_time:.2f}s"
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
results = []
|
| 195 |
+
for i in range(len(long_audios)):
|
| 196 |
+
duration = long_audios[i].shape[0] / 16000
|
| 197 |
+
tokens_count = codes_lens[i].item()
|
| 198 |
+
results.append((duration, tokens_count))
|
| 199 |
+
print(f"Long audio {i+1}: {duration:.1f}s → {tokens_count} tokens")
|
| 200 |
+
|
| 201 |
+
print(
|
| 202 |
+
f"✅ Pure long audio batch processing test completed for {model_name}")
|
| 203 |
+
|
| 204 |
+
# Assertions for pytest
|
| 205 |
+
assert codes is not None, f"Codes should not be None for model {model_name}"
|
| 206 |
+
assert codes_lens is not None, f"Codes lengths should not be None for model {model_name}"
|
| 207 |
+
assert len(results) == len(
|
| 208 |
+
long_audios), "Number of results should match number of input audios"
|
| 209 |
+
assert all(
|
| 210 |
+
tokens_count > 0
|
| 211 |
+
for _, tokens_count in results), "All audio should produce tokens"
|
| 212 |
+
assert processing_time > 0, "Processing time should be positive"
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
@pytest.mark.parametrize("model_name", [
|
| 216 |
+
"speech_tokenizer_v1_25hz", "speech_tokenizer_v1",
|
| 217 |
+
"speech_tokenizer_v2_25hz"
|
| 218 |
+
])
|
| 219 |
+
def test_model_loading(model_name):
|
| 220 |
+
"""Test that all models can be loaded successfully"""
|
| 221 |
+
print(f"\n=== Model Loading Test for {model_name} ===")
|
| 222 |
+
|
| 223 |
+
model = s3tokenizer.load_model(model_name)
|
| 224 |
+
assert model is not None, f"Model {model_name} should load successfully"
|
| 225 |
+
|
| 226 |
+
# Test model can be set to eval mode
|
| 227 |
+
model.eval()
|
| 228 |
+
print(f"✅ Model {model_name} loaded and set to eval mode successfully")
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
@pytest.mark.parametrize("model_name", [
|
| 232 |
+
"speech_tokenizer_v1_25hz", "speech_tokenizer_v1",
|
| 233 |
+
"speech_tokenizer_v2_25hz"
|
| 234 |
+
])
|
| 235 |
+
def test_single_audio_processing(model_name):
|
| 236 |
+
"""Test single audio processing for different models"""
|
| 237 |
+
print(f"\n=== Single Audio Processing Test for {model_name} ===")
|
| 238 |
+
|
| 239 |
+
# Create a single test audio
|
| 240 |
+
audio = create_test_audio(30) # 30 second audio
|
| 241 |
+
|
| 242 |
+
model = s3tokenizer.load_model(model_name)
|
| 243 |
+
model.eval()
|
| 244 |
+
|
| 245 |
+
# Process the audio
|
| 246 |
+
mel = s3tokenizer.log_mel_spectrogram(audio)
|
| 247 |
+
mels = mel.unsqueeze(0)
|
| 248 |
+
mels_lens = torch.tensor([mel.size(1)])
|
| 249 |
+
|
| 250 |
+
with torch.no_grad():
|
| 251 |
+
codes, codes_lens = model.quantize(mels, mels_lens)
|
| 252 |
+
|
| 253 |
+
final_codes = codes[0, :codes_lens[0].item()].tolist()
|
| 254 |
+
|
| 255 |
+
# Assertions
|
| 256 |
+
assert codes is not None, f"Codes should not be None for model {model_name}"
|
| 257 |
+
assert codes_lens is not None, f"Codes lengths should not be None for model {model_name}"
|
| 258 |
+
assert len(
|
| 259 |
+
final_codes) > 0, f"Should produce tokens for model {model_name}"
|
| 260 |
+
assert codes_lens[0].item() == len(
|
| 261 |
+
final_codes
|
| 262 |
+
), f"Codes length should match actual codes for model {model_name}"
|
| 263 |
+
|
| 264 |
+
duration = audio.shape[0] / 16000
|
| 265 |
+
print(
|
| 266 |
+
f"✅ Single audio processing test completed for {model_name}: {duration:.1f}s → {len(final_codes)} tokens"
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
if __name__ == "__main__":
|
| 271 |
+
# Run tests with pytest
|
| 272 |
+
pytest.main([__file__, "-v"])
|
speech/tools/S3Tokenizer/test/test_onnx.py
ADDED
|
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
# Copyright [2024-09-27] <sxc19@mails.tsinghua.edu.cn, Xingchen Song>
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from typing import Dict, Any
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import onnxruntime
|
| 11 |
+
import pytest
|
| 12 |
+
import s3tokenizer
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def create_test_audio(duration_seconds: float = 20,
|
| 17 |
+
sample_rate: int = 16000) -> torch.Tensor:
|
| 18 |
+
"""Create synthetic test audio"""
|
| 19 |
+
length = int(duration_seconds * sample_rate)
|
| 20 |
+
# Create sinusoidal mixed audio
|
| 21 |
+
t = torch.linspace(0, duration_seconds, length)
|
| 22 |
+
audio = 0.5 * torch.sin(2 * torch.pi * 440 * t) # 440Hz fundamental
|
| 23 |
+
audio += 0.3 * torch.sin(2 * torch.pi * 880 * t) # 880Hz second harmonic
|
| 24 |
+
audio += 0.1 * torch.randn(length) # Add noise
|
| 25 |
+
return audio
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@pytest.fixture
|
| 29 |
+
def test_audio_suite():
|
| 30 |
+
"""Create a suite of test audios with different lengths"""
|
| 31 |
+
return {
|
| 32 |
+
"short_audio_1": create_test_audio(5.0), # 5 seconds
|
| 33 |
+
"short_audio_2": create_test_audio(15.0), # 15 seconds
|
| 34 |
+
"medium_audio": create_test_audio(25.0), # 25 seconds
|
| 35 |
+
"medium_audio_2": create_test_audio(30.0), # 30 seconds
|
| 36 |
+
"long_audio": create_test_audio(
|
| 37 |
+
35.0), # 35 seconds - for torch and onnx, 2 segments with padding
|
| 38 |
+
"long_audio_2": create_test_audio(
|
| 39 |
+
56.0
|
| 40 |
+
), # 56 seconds - for torch and onnx, exactly 2 segments without padding
|
| 41 |
+
"very_long_audio": create_test_audio(
|
| 42 |
+
60.0), # 60 seconds - for torch and onnx, 3 segments with padding
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def onnx_inference_short_audio(model_name: str, mel: torch.Tensor,
|
| 47 |
+
mel_len: torch.Tensor) -> torch.Tensor:
|
| 48 |
+
"""
|
| 49 |
+
ONNX inference for short audio (<=30s)
|
| 50 |
+
"""
|
| 51 |
+
# Load ONNX model
|
| 52 |
+
default = os.path.join(os.path.expanduser("~"), ".cache")
|
| 53 |
+
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default),
|
| 54 |
+
"s3tokenizer")
|
| 55 |
+
|
| 56 |
+
option = onnxruntime.SessionOptions()
|
| 57 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 58 |
+
option.intra_op_num_threads = 1
|
| 59 |
+
providers = ["CPUExecutionProvider"]
|
| 60 |
+
|
| 61 |
+
ort_session = onnxruntime.InferenceSession(
|
| 62 |
+
f"{download_root}/{model_name}.onnx",
|
| 63 |
+
sess_options=option,
|
| 64 |
+
providers=providers)
|
| 65 |
+
|
| 66 |
+
# Direct inference for short audio
|
| 67 |
+
onnx_output = ort_session.run(
|
| 68 |
+
None, {
|
| 69 |
+
ort_session.get_inputs()[0].name:
|
| 70 |
+
mel[:, :mel_len.item()].unsqueeze(0).detach().cpu().numpy(),
|
| 71 |
+
ort_session.get_inputs()[1].name:
|
| 72 |
+
np.array([mel_len.item()], dtype=np.int32)
|
| 73 |
+
})[0]
|
| 74 |
+
|
| 75 |
+
# Convert to numpy array to fix linter issues
|
| 76 |
+
onnx_output = np.array(onnx_output)
|
| 77 |
+
|
| 78 |
+
# Handle different output formats
|
| 79 |
+
if onnx_output.ndim == 2:
|
| 80 |
+
onnx_output = onnx_output[0, :]
|
| 81 |
+
elif onnx_output.ndim == 3:
|
| 82 |
+
onnx_output = onnx_output[0, 0, :]
|
| 83 |
+
|
| 84 |
+
return torch.tensor(onnx_output, dtype=torch.long)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def onnx_inference_long_audio(model_name: str, mel: torch.Tensor,
|
| 88 |
+
mel_len: torch.Tensor) -> torch.Tensor:
|
| 89 |
+
"""
|
| 90 |
+
ONNX inference for long audio (>30s) using sliding window approach
|
| 91 |
+
Based on _quantize_mixed_batch logic
|
| 92 |
+
|
| 93 |
+
Note: This may fail due to ONNX model limitations with dynamic lengths
|
| 94 |
+
"""
|
| 95 |
+
# Load ONNX model
|
| 96 |
+
default = os.path.join(os.path.expanduser("~"), ".cache")
|
| 97 |
+
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default),
|
| 98 |
+
"s3tokenizer")
|
| 99 |
+
|
| 100 |
+
option = onnxruntime.SessionOptions()
|
| 101 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 102 |
+
option.intra_op_num_threads = 1
|
| 103 |
+
providers = ["CPUExecutionProvider"]
|
| 104 |
+
|
| 105 |
+
ort_session = onnxruntime.InferenceSession(
|
| 106 |
+
f"{download_root}/{model_name}.onnx",
|
| 107 |
+
sess_options=option,
|
| 108 |
+
providers=providers)
|
| 109 |
+
|
| 110 |
+
# Parameters for sliding window (same as _quantize_mixed_batch)
|
| 111 |
+
sample_rate = 16000
|
| 112 |
+
hop_length = 160
|
| 113 |
+
window_size = 30 # seconds
|
| 114 |
+
overlap = 4 # seconds
|
| 115 |
+
|
| 116 |
+
# Calculate frame-based parameters
|
| 117 |
+
frames_per_window = window_size * sample_rate // hop_length # 3000 frames
|
| 118 |
+
frames_per_overlap = overlap * sample_rate // hop_length # 400 frames
|
| 119 |
+
frames_per_stride = frames_per_window - frames_per_overlap # 2600 frames
|
| 120 |
+
|
| 121 |
+
# Split into segments
|
| 122 |
+
segments = []
|
| 123 |
+
segments_len = []
|
| 124 |
+
start = 0
|
| 125 |
+
|
| 126 |
+
while start < mel_len.item():
|
| 127 |
+
end = min(start + frames_per_window, mel_len.item())
|
| 128 |
+
segment = mel[:, start:end]
|
| 129 |
+
|
| 130 |
+
if segment.size(1) < frames_per_window:
|
| 131 |
+
break
|
| 132 |
+
|
| 133 |
+
seg_len = segment.size(1)
|
| 134 |
+
segments.append(segment)
|
| 135 |
+
segments_len.append(seg_len)
|
| 136 |
+
|
| 137 |
+
start += frames_per_stride
|
| 138 |
+
|
| 139 |
+
if not segments:
|
| 140 |
+
raise ValueError("No valid segments for ONNX processing")
|
| 141 |
+
|
| 142 |
+
# Process each segment with ONNX
|
| 143 |
+
segment_results = []
|
| 144 |
+
for i, (segment, seg_len) in enumerate(zip(segments, segments_len)):
|
| 145 |
+
try:
|
| 146 |
+
onnx_output = ort_session.run(
|
| 147 |
+
None, {
|
| 148 |
+
ort_session.get_inputs()[0].name:
|
| 149 |
+
segment.unsqueeze(0).detach().cpu().numpy(),
|
| 150 |
+
ort_session.get_inputs()[1].name:
|
| 151 |
+
np.array([seg_len], dtype=np.int32)
|
| 152 |
+
})[0]
|
| 153 |
+
|
| 154 |
+
# Convert to numpy array to fix linter issues
|
| 155 |
+
onnx_output = np.array(onnx_output)
|
| 156 |
+
|
| 157 |
+
# Handle different output formats
|
| 158 |
+
if onnx_output.ndim == 2:
|
| 159 |
+
segment_codes = onnx_output[0, :].tolist()
|
| 160 |
+
elif onnx_output.ndim == 3:
|
| 161 |
+
segment_codes = onnx_output[0, 0, :].tolist()
|
| 162 |
+
else:
|
| 163 |
+
segment_codes = onnx_output.tolist()
|
| 164 |
+
|
| 165 |
+
segment_results.append(segment_codes)
|
| 166 |
+
|
| 167 |
+
except Exception as e:
|
| 168 |
+
print(f" ONNX error on segment {i+1}: {str(e)[:100]}...")
|
| 169 |
+
raise Exception(
|
| 170 |
+
f"ONNX inference failed on segment {i+1}: {str(e)}")
|
| 171 |
+
|
| 172 |
+
if not segment_results:
|
| 173 |
+
raise ValueError("All ONNX segments failed to process")
|
| 174 |
+
|
| 175 |
+
# Merge segments using the same logic as _quantize_mixed_batch
|
| 176 |
+
# Determine token rate based on model name
|
| 177 |
+
if model_name == "speech_tokenizer_v1":
|
| 178 |
+
token_rate = 50
|
| 179 |
+
else:
|
| 180 |
+
token_rate = 25
|
| 181 |
+
|
| 182 |
+
merged_codes = s3tokenizer.merge_tokenized_segments(
|
| 183 |
+
segment_results, overlap=overlap, token_rate=token_rate
|
| 184 |
+
)[:-overlap * token_rate] # NOTE(xcsong): drop the last overlap part.
|
| 185 |
+
return torch.tensor(merged_codes, dtype=torch.long)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def onnx_inference_with_long_audio_support(
|
| 189 |
+
model_name: str, mel: torch.Tensor,
|
| 190 |
+
mel_len: torch.Tensor) -> torch.Tensor:
|
| 191 |
+
"""
|
| 192 |
+
ONNX inference with automatic long audio support
|
| 193 |
+
"""
|
| 194 |
+
max_frames = 3000 # 30s * 16000 / 160 = 3000 frames
|
| 195 |
+
|
| 196 |
+
if mel_len.item() <= max_frames:
|
| 197 |
+
# Short audio - use direct inference
|
| 198 |
+
return onnx_inference_short_audio(model_name, mel, mel_len)
|
| 199 |
+
else:
|
| 200 |
+
# Long audio - use sliding window approach
|
| 201 |
+
return onnx_inference_long_audio(model_name, mel, mel_len)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def compare_torch_vs_onnx_single(model_name: str, audio: torch.Tensor,
|
| 205 |
+
audio_name: str) -> Dict[str, Any]:
|
| 206 |
+
"""Test single audio with both torch and onnx versions"""
|
| 207 |
+
duration = audio.shape[0] / 16000
|
| 208 |
+
|
| 209 |
+
# Load torch model
|
| 210 |
+
tokenizer = s3tokenizer.load_model(model_name)
|
| 211 |
+
tokenizer.eval()
|
| 212 |
+
|
| 213 |
+
# Prepare input
|
| 214 |
+
mel = s3tokenizer.log_mel_spectrogram(audio)
|
| 215 |
+
mels = mel.unsqueeze(0)
|
| 216 |
+
mels_lens = torch.tensor([mel.size(1)])
|
| 217 |
+
|
| 218 |
+
# Test torch version
|
| 219 |
+
start_time = time.time()
|
| 220 |
+
with torch.no_grad():
|
| 221 |
+
torch_codes, torch_codes_lens = tokenizer.quantize(mels, mels_lens)
|
| 222 |
+
torch_time = time.time() - start_time
|
| 223 |
+
|
| 224 |
+
torch_result = torch_codes[0, :torch_codes_lens[0].item()]
|
| 225 |
+
|
| 226 |
+
# Test onnx version with long audio support
|
| 227 |
+
try:
|
| 228 |
+
start_time = time.time()
|
| 229 |
+
onnx_result = onnx_inference_with_long_audio_support(
|
| 230 |
+
model_name, mel, mels_lens[0])
|
| 231 |
+
onnx_time = time.time() - start_time
|
| 232 |
+
|
| 233 |
+
# Compare results
|
| 234 |
+
min_len = min(len(torch_result), len(onnx_result))
|
| 235 |
+
torch_truncated = torch_result[:min_len]
|
| 236 |
+
onnx_truncated = onnx_result[:min_len]
|
| 237 |
+
|
| 238 |
+
are_equal = torch.equal(torch_truncated, onnx_truncated)
|
| 239 |
+
miss_rate = 0.0
|
| 240 |
+
|
| 241 |
+
if not are_equal:
|
| 242 |
+
miss_num = torch.sum(~(torch_truncated == onnx_truncated))
|
| 243 |
+
miss_rate = miss_num.item() * 100.0 / min_len
|
| 244 |
+
|
| 245 |
+
return {
|
| 246 |
+
"audio_name": audio_name,
|
| 247 |
+
"model_name": model_name,
|
| 248 |
+
"duration": duration,
|
| 249 |
+
"torch_tokens": torch_truncated,
|
| 250 |
+
"onnx_tokens": onnx_truncated,
|
| 251 |
+
"torch_time": torch_time,
|
| 252 |
+
"onnx_time": onnx_time,
|
| 253 |
+
"results_match": are_equal,
|
| 254 |
+
"miss_rate": miss_rate
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
except Exception as e:
|
| 258 |
+
return {
|
| 259 |
+
"audio_name": audio_name,
|
| 260 |
+
"model_name": model_name,
|
| 261 |
+
"duration": duration,
|
| 262 |
+
"torch_tokens": torch_result,
|
| 263 |
+
"onnx_tokens": [],
|
| 264 |
+
"torch_time": torch_time,
|
| 265 |
+
"onnx_time": 0.0,
|
| 266 |
+
"results_match": False,
|
| 267 |
+
"miss_rate": 100.0,
|
| 268 |
+
"error": str(e)
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
@pytest.mark.parametrize("model_name", [
|
| 273 |
+
"speech_tokenizer_v1", "speech_tokenizer_v1_25hz",
|
| 274 |
+
"speech_tokenizer_v2_25hz"
|
| 275 |
+
])
|
| 276 |
+
def test_torch_vs_onnx_short_audio(model_name, test_audio_suite):
|
| 277 |
+
"""Test torch vs onnx for short audio (<=30s)"""
|
| 278 |
+
print(f"\n=== Testing {model_name} on Short Audio ===")
|
| 279 |
+
|
| 280 |
+
short_audios = {
|
| 281 |
+
k: v
|
| 282 |
+
for k, v in test_audio_suite.items() if v.shape[0] / 16000 <= 30
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
results = []
|
| 286 |
+
for audio_name, audio in short_audios.items():
|
| 287 |
+
result = compare_torch_vs_onnx_single(model_name, audio, audio_name)
|
| 288 |
+
results.append(result)
|
| 289 |
+
|
| 290 |
+
duration = result["duration"]
|
| 291 |
+
torch_tokens = result["torch_tokens"]
|
| 292 |
+
onnx_tokens = result["onnx_tokens"]
|
| 293 |
+
match_status = "✅" if result["results_match"] else "❌"
|
| 294 |
+
|
| 295 |
+
print(
|
| 296 |
+
f"{match_status} {audio_name}: {duration:.1f}s → torch:{len(torch_tokens)}, onnx:{len(onnx_tokens)}"
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
if not result["results_match"] and "error" not in result:
|
| 300 |
+
print(f" Miss rate: {result['miss_rate']:.2f}%")
|
| 301 |
+
print(
|
| 302 |
+
f" torch_tokens:\n{torch_tokens}\nonnx_tokens:\n{onnx_tokens}"
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
# Assertions
|
| 306 |
+
successful_tests = [r for r in results if "error" not in r]
|
| 307 |
+
assert len(successful_tests) == len(
|
| 308 |
+
short_audios
|
| 309 |
+
), f"successful tests ({len(successful_tests)}) for {model_name} should be equal to number of short audios ({len(short_audios)})" # noqa
|
| 310 |
+
|
| 311 |
+
# For short audio, we expect reasonable match rate
|
| 312 |
+
for r in results:
|
| 313 |
+
assert r[
|
| 314 |
+
'miss_rate'] < 0.5, f"Miss rate too high for {model_name}: {r['miss_rate']:.2f}%"
|
| 315 |
+
|
| 316 |
+
print(f"\n{model_name} Short Audio Summary:")
|
| 317 |
+
print(f" Successful tests: {len(successful_tests)}/{len(results)}")
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
@pytest.mark.parametrize("model_name", [
|
| 321 |
+
"speech_tokenizer_v1", "speech_tokenizer_v1_25hz",
|
| 322 |
+
"speech_tokenizer_v2_25hz"
|
| 323 |
+
])
|
| 324 |
+
def test_torch_vs_onnx_long_audio(model_name, test_audio_suite):
|
| 325 |
+
"""Test torch vs onnx for long audio (>30s) with ONNX sliding window implementation"""
|
| 326 |
+
print(
|
| 327 |
+
f"\n=== Testing {model_name} on Long Audio (ONNX Sliding Window) ===")
|
| 328 |
+
|
| 329 |
+
long_audios = {
|
| 330 |
+
k: v
|
| 331 |
+
for k, v in test_audio_suite.items() if v.shape[0] / 16000 > 30
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
results = []
|
| 335 |
+
for audio_name, audio in long_audios.items():
|
| 336 |
+
result = compare_torch_vs_onnx_single(model_name, audio, audio_name)
|
| 337 |
+
results.append(result)
|
| 338 |
+
|
| 339 |
+
duration = result["duration"]
|
| 340 |
+
torch_tokens = result["torch_tokens"]
|
| 341 |
+
onnx_tokens = result["onnx_tokens"]
|
| 342 |
+
match_status = "✅" if result["results_match"] else "❌"
|
| 343 |
+
|
| 344 |
+
print(
|
| 345 |
+
f"{match_status} {audio_name}: {duration:.1f}s → torch:{len(torch_tokens)}, onnx:{len(onnx_tokens)}"
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
if not result["results_match"] and "error" not in result:
|
| 349 |
+
print(f" Miss rate: {result['miss_rate']:.2f}%")
|
| 350 |
+
print(
|
| 351 |
+
f" torch_tokens:\n{torch_tokens}\nonnx_tokens:\n{onnx_tokens}"
|
| 352 |
+
)
|
| 353 |
+
elif "error" in result:
|
| 354 |
+
print(f" Error: {result['error'][:100]}...")
|
| 355 |
+
|
| 356 |
+
# For long audio with ONNX, we document the current limitations
|
| 357 |
+
successful_tests = [r for r in results if "error" not in r]
|
| 358 |
+
assert len(successful_tests) == len(
|
| 359 |
+
long_audios
|
| 360 |
+
), f"successful tests ({len(successful_tests)}) for {model_name} should be equal to number of long audios ({len(long_audios)})" # noqa
|
| 361 |
+
|
| 362 |
+
print(f"\n{model_name} Long Audio Results:")
|
| 363 |
+
print(f" Total tests: {len(results)}")
|
| 364 |
+
print(f" Successful ONNX tests: {len(successful_tests)}")
|
| 365 |
+
|
| 366 |
+
for r in results:
|
| 367 |
+
# NOTE(xcsong): 0.5% is a reasonable miss rate for long audio, since we drop the last overlap part.
|
| 368 |
+
assert r[
|
| 369 |
+
'miss_rate'] < 0.5, f"Miss rate too high for {model_name}: {r['miss_rate']}%"
|
| 370 |
+
|
| 371 |
+
# The main requirement is that Torch always works
|
| 372 |
+
print(" ✅ Torch processing works reliably for all long audio")
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
if __name__ == "__main__":
|
| 376 |
+
# Run tests with pytest
|
| 377 |
+
pytest.main([__file__, "-v"])
|