primepake committed
Commit f768eb3 · 1 Parent(s): 7940474

add s3 tokenizer

speech/tools/S3Tokenizer/.flake8 ADDED
@@ -0,0 +1,28 @@
+ [flake8]
+ # Suggested config from pytorch that we can adapt
+ select = B,C,E,F,N,P,T4,W,B9,TOR0,TOR1,TOR2
+ max-line-length = 120
+ # C408 ignored because we like the dict keyword argument syntax
+ # E501 is not flexible enough, we're using B950 instead
+ # N812 ignored because import torch.nn.functional as F is PyTorch convention
+ # N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
+ # E731 allow usage of assigning lambda expressions
+ # N803,N806 allow caps and mixed case in function params. This is to work with Triton kernel coding style.
+ ignore =
+     E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,N803,N806,
+     # shebang has extra meaning in fbcode lints, so I think it's not worth trying
+     # to line this up with executable bit
+     EXE001,
+     # these ignores are from flake8-bugbear; please fix!
+     B007,B008,
+ optional-ascii-coding = True
+ exclude =
+     ./.git,
+     ./docs,
+     ./build,
+     ./scripts,
+     ./venv,
+     *.pyi,
+     .pre-commit-config.yaml,
+     *.md,
+     .flake8
speech/tools/S3Tokenizer/.github/workflows/python-publish.yml ADDED
@@ -0,0 +1,37 @@
+ name: Release
+
+ on:
+   push:
+     branches:
+       - main
+ jobs:
+   deploy:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v3
+       - uses: actions-ecosystem/action-regex-match@v2
+         id: regex-match
+         with:
+           text: ${{ github.event.head_commit.message }}
+           regex: '^Release ([^ ]+)'
+       - name: Set up Python
+         uses: actions/setup-python@v4
+         with:
+           python-version: '3.8'
+       - name: Install dependencies
+         run: |
+           python -m pip install --upgrade pip
+           pip install build twine
+       - name: Release
+         if: ${{ steps.regex-match.outputs.match != '' }}
+         uses: softprops/action-gh-release@v1
+         with:
+           tag_name: v${{ steps.regex-match.outputs.group1 }}
+       - name: Build and publish
+         if: ${{ steps.regex-match.outputs.match != '' }}
+         env:
+           TWINE_USERNAME: __token__
+           TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
+         run: |
+           python -m build
+           twine upload dist/*
speech/tools/S3Tokenizer/.github/workflows/unit_test_cpu.yaml ADDED
@@ -0,0 +1,47 @@
+ name: CPU Unit Test
+
+ on:
+   push:
+     branches: [ main ]
+   pull_request:
+
+ concurrency:
+   group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
+   cancel-in-progress: true
+
+ jobs:
+   unit-test:
+     runs-on: ${{ matrix.os }}
+     strategy:
+       max-parallel: 20
+       matrix:
+         os: [ubuntu-22.04]
+         python-version: [3.10.16]
+     steps:
+       - name: Cache Python Packages
+         uses: actions/cache@v4
+         with:
+           path: ~/.cache/pip
+           key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
+       - name: Setup Python
+         uses: actions/setup-python@v4
+         with:
+           python-version: ${{ matrix.python-version }}
+           architecture: x64
+       - name: Fetch S3Tokenizer
+         uses: actions/checkout@v4
+         with:
+           fetch-depth: 0
+           ref: ${{ github.event.pull_request.head.ref || github.ref }}
+       - name: Install S3Tokenizer Dependencies
+         run: |
+           set -eux
+           sudo apt update && sudo apt install -y ffmpeg libsox-dev libsndfile1
+           pip install -e .
+       - name: Run Pytest
+         run: |
+           set -eux
+           pip install pytest onnxruntime
+           pytest --version
+           PYTHONPATH="${PYTHONPATH:-}:$(pwd)" pytest test/ -q
+           if [ $? != 0 ]; then exit 1; fi
speech/tools/S3Tokenizer/.gitignore ADDED
@@ -0,0 +1,162 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
speech/tools/S3Tokenizer/.pre-commit-config.yaml ADDED
@@ -0,0 +1,14 @@
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v4.5.0
+     hooks:
+       - id: trailing-whitespace
+         exclude: 's3tokenizer/assets/.*'
+   - repo: https://github.com/pre-commit/mirrors-yapf
+     rev: 'v0.32.0'
+     hooks:
+       - id: yapf
+   - repo: https://github.com/pycqa/flake8
+     rev: '3.8.2'
+     hooks:
+       - id: flake8
speech/tools/S3Tokenizer/LICENSE ADDED
@@ -0,0 +1,201 @@
+                                  Apache License
+                            Version 2.0, January 2004
+                         http://www.apache.org/licenses/
+
+    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+    1. Definitions.
+
+       "License" shall mean the terms and conditions for use, reproduction,
+       and distribution as defined by Sections 1 through 9 of this document.
+
+       "Licensor" shall mean the copyright owner or entity authorized by
+       the copyright owner that is granting the License.
+
+       "Legal Entity" shall mean the union of the acting entity and all
+       other entities that control, are controlled by, or are under common
+       control with that entity. For the purposes of this definition,
+       "control" means (i) the power, direct or indirect, to cause the
+       direction or management of such entity, whether by contract or
+       otherwise, or (ii) ownership of fifty percent (50%) or more of the
+       outstanding shares, or (iii) beneficial ownership of such entity.
+
+       "You" (or "Your") shall mean an individual or Legal Entity
+       exercising permissions granted by this License.
+
+       "Source" form shall mean the preferred form for making modifications,
+       including but not limited to software source code, documentation
+       source, and configuration files.
+
+       "Object" form shall mean any form resulting from mechanical
+       transformation or translation of a Source form, including but
+       not limited to compiled object code, generated documentation,
+       and conversions to other media types.
+
+       "Work" shall mean the work of authorship, whether in Source or
+       Object form, made available under the License, as indicated by a
+       copyright notice that is included in or attached to the work
+       (an example is provided in the Appendix below).
+
+       "Derivative Works" shall mean any work, whether in Source or Object
+       form, that is based on (or derived from) the Work and for which the
+       editorial revisions, annotations, elaborations, or other modifications
+       represent, as a whole, an original work of authorship. For the purposes
+       of this License, Derivative Works shall not include works that remain
+       separable from, or merely link (or bind by name) to the interfaces of,
+       the Work and Derivative Works thereof.
+
+       "Contribution" shall mean any work of authorship, including
+       the original version of the Work and any modifications or additions
+       to that Work or Derivative Works thereof, that is intentionally
+       submitted to Licensor for inclusion in the Work by the copyright owner
+       or by an individual or Legal Entity authorized to submit on behalf of
+       the copyright owner. For the purposes of this definition, "submitted"
+       means any form of electronic, verbal, or written communication sent
+       to the Licensor or its representatives, including but not limited to
+       communication on electronic mailing lists, source code control systems,
+       and issue tracking systems that are managed by, or on behalf of, the
+       Licensor for the purpose of discussing and improving the Work, but
+       excluding communication that is conspicuously marked or otherwise
+       designated in writing by the copyright owner as "Not a Contribution."
+
+       "Contributor" shall mean Licensor and any individual or Legal Entity
+       on behalf of whom a Contribution has been received by Licensor and
+       subsequently incorporated within the Work.
+
+    2. Grant of Copyright License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       copyright license to reproduce, prepare Derivative Works of,
+       publicly display, publicly perform, sublicense, and distribute the
+       Work and such Derivative Works in Source or Object form.
+
+    3. Grant of Patent License. Subject to the terms and conditions of
+       this License, each Contributor hereby grants to You a perpetual,
+       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+       (except as stated in this section) patent license to make, have made,
+       use, offer to sell, sell, import, and otherwise transfer the Work,
+       where such license applies only to those patent claims licensable
+       by such Contributor that are necessarily infringed by their
+       Contribution(s) alone or by combination of their Contribution(s)
+       with the Work to which such Contribution(s) was submitted. If You
+       institute patent litigation against any entity (including a
+       cross-claim or counterclaim in a lawsuit) alleging that the Work
+       or a Contribution incorporated within the Work constitutes direct
+       or contributory patent infringement, then any patent licenses
+       granted to You under this License for that Work shall terminate
+       as of the date such litigation is filed.
+
+    4. Redistribution. You may reproduce and distribute copies of the
+       Work or Derivative Works thereof in any medium, with or without
+       modifications, and in Source or Object form, provided that You
+       meet the following conditions:
+
+       (a) You must give any other recipients of the Work or
+           Derivative Works a copy of this License; and
+
+       (b) You must cause any modified files to carry prominent notices
+           stating that You changed the files; and
+
+       (c) You must retain, in the Source form of any Derivative Works
+           that You distribute, all copyright, patent, trademark, and
+           attribution notices from the Source form of the Work,
+           excluding those notices that do not pertain to any part of
+           the Derivative Works; and
+
+       (d) If the Work includes a "NOTICE" text file as part of its
+           distribution, then any Derivative Works that You distribute must
+           include a readable copy of the attribution notices contained
+           within such NOTICE file, excluding those notices that do not
+           pertain to any part of the Derivative Works, in at least one
+           of the following places: within a NOTICE text file distributed
+           as part of the Derivative Works; within the Source form or
+           documentation, if provided along with the Derivative Works; or,
+           within a display generated by the Derivative Works, if and
+           wherever such third-party notices normally appear. The contents
+           of the NOTICE file are for informational purposes only and
+           do not modify the License. You may add Your own attribution
+           notices within Derivative Works that You distribute, alongside
+           or as an addendum to the NOTICE text from the Work, provided
+           that such additional attribution notices cannot be construed
+           as modifying the License.
+
+       You may add Your own copyright statement to Your modifications and
+       may provide additional or different license terms and conditions
+       for use, reproduction, or distribution of Your modifications, or
+       for any such Derivative Works as a whole, provided Your use,
+       reproduction, and distribution of the Work otherwise complies with
+       the conditions stated in this License.
+
+    5. Submission of Contributions. Unless You explicitly state otherwise,
+       any Contribution intentionally submitted for inclusion in the Work
+       by You to the Licensor shall be under the terms and conditions of
+       this License, without any additional terms or conditions.
+       Notwithstanding the above, nothing herein shall supersede or modify
+       the terms of any separate license agreement you may have executed
+       with Licensor regarding such Contributions.
+
+    6. Trademarks. This License does not grant permission to use the trade
+       names, trademarks, service marks, or product names of the Licensor,
+       except as required for reasonable and customary use in describing the
+       origin of the Work and reproducing the content of the NOTICE file.
+
+    7. Disclaimer of Warranty. Unless required by applicable law or
+       agreed to in writing, Licensor provides the Work (and each
+       Contributor provides its Contributions) on an "AS IS" BASIS,
+       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+       implied, including, without limitation, any warranties or conditions
+       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+       PARTICULAR PURPOSE. You are solely responsible for determining the
+       appropriateness of using or redistributing the Work and assume any
+       risks associated with Your exercise of permissions under this License.
+
+    8. Limitation of Liability. In no event and under no legal theory,
+       whether in tort (including negligence), contract, or otherwise,
+       unless required by applicable law (such as deliberate and grossly
+       negligent acts) or agreed to in writing, shall any Contributor be
+       liable to You for damages, including any direct, indirect, special,
+       incidental, or consequential damages of any character arising as a
+       result of this License or out of the use or inability to use the
+       Work (including but not limited to damages for loss of goodwill,
+       work stoppage, computer failure or malfunction, or any and all
+       other commercial damages or losses), even if such Contributor
+       has been advised of the possibility of such damages.
+
+    9. Accepting Warranty or Additional Liability. While redistributing
+       the Work or Derivative Works thereof, You may choose to offer,
+       and charge a fee for, acceptance of support, warranty, indemnity,
+       or other liability obligations and/or rights consistent with this
+       License. However, in accepting such obligations, You may act only
+       on Your own behalf and on Your sole responsibility, not on behalf
+       of any other Contributor, and only if You agree to indemnify,
+       defend, and hold each Contributor harmless for any liability
+       incurred by, or claims asserted against, such Contributor by reason
+       of your accepting any such warranty or additional liability.
+
+    END OF TERMS AND CONDITIONS
+
+    APPENDIX: How to apply the Apache License to your work.
+
+       To apply the Apache License to your work, attach the following
+       boilerplate notice, with the fields enclosed by brackets "[]"
+       replaced with your own identifying information. (Don't include
+       the brackets!) The text should be enclosed in the appropriate
+       comment syntax for the file format. We also recommend that a
+       file or class name and description of purpose be included on the
+       same "printed page" as the copyright notice for easier
+       identification within third-party archives.
+
+    Copyright [yyyy] [name of copyright owner]
+
+    Licensed under the Apache License, Version 2.0 (the "License");
+    you may not use this file except in compliance with the License.
+    You may obtain a copy of the License at
+
+        http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
speech/tools/S3Tokenizer/MANIFEST.in ADDED
@@ -0,0 +1,4 @@
+ include requirements.txt
+ include README.md
+ include LICENSE
+ include s3tokenizer/assets/*
speech/tools/S3Tokenizer/README.md ADDED
@@ -0,0 +1,150 @@
+ # Reverse Engineering of S3Tokenizer
+
+ <div align="center">
+   <img src="https://arxiv.org/html/2407.04051v2/x1.png" alt="Description" width="35%" />
+   <p><em>Supervised Semantic Speech Tokenizer (S3Tokenizer)</em></p>
+ </div>
+
+ S3Tokenizer was initially introduced in CosyVoice [[Paper]](https://arxiv.org/abs/2407.04051v2) [[Repo]](https://github.com/FunAudioLLM/CosyVoice). It is a supervised semantic speech tokenizer based on the pre-trained SenseVoice-Large model, which enhances the semantic relationship of the extracted tokens to textual and paralinguistic information, is robust to data noise, and reduces the reliance on clean data collection, thereby enabling a broader range of data to be used for model training.
+
+ However, as indicated in this [[issue]](https://github.com/FunAudioLLM/CosyVoice/issues/70), the authors have no intention of open-sourcing the PyTorch implementation of the S3Tokenizer and only plan to release an ONNX file. Additionally, users aiming to fine-tune CosyVoice must extract speech codes offline with the batch size restricted to 1, a process that is notably time-consuming (refer to [[cosyvoice/tools/extract_speech_token.py]](https://github.com/FunAudioLLM/CosyVoice/blob/main/tools/extract_speech_token.py)).
+
+ This repository undertakes a reverse engineering of the S3Tokenizer, offering:
+ 1. A pure PyTorch implementation of S3Tokenizer (see [[model.py]](https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/model.py)), compatible with initializing weights from the released ONNX file (see [[utils.py::onnx2torch()]](https://github.com/xingchensong/S3Tokenizer/blob/main/s3tokenizer/utils.py)).
+ 2. High-throughput (distributed) batch inference, achieving a ~790x speedup compared to the original inference pipeline in [[cosyvoice/tools/extract_speech_token.py]](https://github.com/FunAudioLLM/CosyVoice/blob/main/tools/extract_speech_token.py).
+ 3. The capability to perform online speech code extraction during SpeechLLM training.
+
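+ For a quick check of what the package exposes, `available_models()` (defined in `s3tokenizer/__init__.py`) lists the supported checkpoints, and `load_model()` accepts either one of these names (the ONNX file is downloaded to `~/.cache/s3tokenizer` by default and converted to PyTorch weights via `utils.py::onnx2torch()`) or a path to an already-downloaded ONNX file:
+
+ ```py
+ import s3tokenizer
+
+ print(s3tokenizer.available_models())
+ # ['speech_tokenizer_v1', 'speech_tokenizer_v1_25hz', 'speech_tokenizer_v2_25hz']
+
+ model = s3tokenizer.load_model("speech_tokenizer_v2_25hz")
+ ```
+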
+ ## Latest News 🎉
+ - [2025/07/07] S3Tokenizer now has built-in **long audio processing** capabilities, requiring no additional operations from users!
+
+ ## Supported Models 🔥
+ - [x] Model: [S3Tokenizer V1 50hz](https://modelscope.cn/models/iic/CosyVoice-300M)
+ - [x] Model: [S3Tokenizer V1 25hz](https://modelscope.cn/models/iic/CosyVoice-300M-25Hz)
+ - [x] Model: [S3Tokenizer V2 25hz](https://modelscope.cn/models/iic/CosyVoice2-0.5B)
+
+ # Setup
+
+ ```sh
+ pip install s3tokenizer
+ ```
+
+ # Usage-1: Offline batch inference
+
+ ```py
+ import s3tokenizer
+
+ tokenizer = s3tokenizer.load_model("speech_tokenizer_v1").cuda()  # or "speech_tokenizer_v1_25hz" / "speech_tokenizer_v2_25hz"
+
+ mels = []
+ wav_paths = ["s3tokenizer/assets/BAC009S0764W0121.wav", "s3tokenizer/assets/BAC009S0764W0122.wav"]
+ for wav_path in wav_paths:
+     audio = s3tokenizer.load_audio(wav_path)
+     mels.append(s3tokenizer.log_mel_spectrogram(audio))
+ mels, mels_lens = s3tokenizer.padding(mels)
+ codes, codes_lens = tokenizer.quantize(mels.cuda(), mels_lens.cuda())  # Automatically handles long audio internally!
+
+ for i in range(len(wav_paths)):
+     print(codes[i, :codes_lens[i].item()])
+ ```
+
+ # Usage-2: Distributed offline batch inference via command-line tools
+
+ ## 2.1 CPU batch inference
+
+ ```sh
+ s3tokenizer --wav_scp xxx.scp \
+             --device "cpu" \
+             --output_dir "./" \
+             --batch_size 32 \
+             --model "speech_tokenizer_v1"  # or "speech_tokenizer_v1_25hz" / "speech_tokenizer_v2_25hz"
+ ```
+
+ https://github.com/user-attachments/assets/d37d10fd-0e13-46a3-86b0-4cbec309086f
+
+ ## 2.2 (Multi) GPU batch inference (a.k.a. distributed inference)
+
+ ```sh
+ torchrun --nproc_per_node=8 --nnodes=1 \
+          --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
+     `which s3tokenizer` --wav_scp xxx.scp \
+                         --device "cuda" \
+                         --output_dir "./" \
+                         --batch_size 32 \
+                         --model "speech_tokenizer_v1"  # or "speech_tokenizer_v1_25hz" / "speech_tokenizer_v2_25hz"
+ ```
+
+ https://github.com/user-attachments/assets/79a3fb11-7199-4ee2-8a35-9682a3b4d94a
+
+ ## 2.3 Performance Benchmark
+
+ | Method | Time cost on AISHELL test set | Relative speedup | Miss rate |
+ |:------:|:-----------------------------:|:----------------:|:---------:|
+ | [[cosyvoice/tools/extract_speech_token.py]](https://github.com/FunAudioLLM/CosyVoice/blob/main/tools/extract_speech_token.py), cpu | 9 hours | 1x (baseline) | N/A |
+ | cpu, batch size 32 | 1.5 hours | ~6x | 0.00% |
+ | 4 gpus (3090), batch size 32 per gpu | 41 s | ~790x | 0.00% |
+
+ The miss rate is the proportion of tokens that differ between the batch-inference predictions and the ONNX (batch=1) inference predictions.
+
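+ The evaluation script for this metric is not included in the repository; a minimal sketch of the computation, assuming the batch and ONNX decodings of an utterance have equal length, looks like:
+
+ ```py
+ def miss_rate(batch_codes, onnx_codes):
+     """Fraction of positions where batch inference disagrees with ONNX (batch=1)."""
+     assert len(batch_codes) == len(onnx_codes)
+     mismatches = sum(b != o for b, o in zip(batch_codes, onnx_codes))
+     return mismatches / max(len(onnx_codes), 1)
+ ```
+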
+ # Usage-3: Online speech code extraction
+
+ <table>
+ <tr>
+ <th>Before (extract code offline)</th>
+ <th>After (extract code online)</th>
+ </tr>
+ <tr>
+ <td>
+ <sub>
+
+ ```py
+ class SpeechLLM(nn.Module):
+     ...
+     def __init__(self, ...):
+         ...
+
+     def forward(self, speech_codes: Tensor, text_ids: Tensor, ...):
+         ...
+ ```
+
+ </sub>
+ </td>
+ <td>
+ <sub>
+
+ ```py
+ import s3tokenizer
+
+ class SpeechLLM(nn.Module):
+     ...
+     def __init__(self, ...):
+         ...
+         self.speech_tokenizer = s3tokenizer.load_model("speech_tokenizer_v1")  # or "speech_tokenizer_v1_25hz"
+         self.speech_tokenizer.freeze()
+
+     def forward(self, speech: Tensor, speech_lens: Tensor, text_ids: Tensor, ...):
+         ...
+         speech_codes, speech_codes_lens = self.speech_tokenizer.quantize(speech, speech_lens)
+         speech_codes = speech_codes.clone()  # for backward compatibility
+         speech_codes_lens = speech_codes_lens.clone()  # for backward compatibility
+ ```
+
+ </sub>
+ </td>
+ </tr>
+ </table>
+
+ # Usage-4: Long Audio Processing (Built-in Automatic Processing)
+
+ - **Automatic detection**: the model detects audio length automatically (>30 seconds triggers long-audio processing)
+ - **Sliding window**: a 30-second window with 4-second overlap automatically segments long audio
+ - **Batch processing**: segments are batched internally for improved efficiency
+ - **Complete transparency**: the calling code is identical to the short-audio case (see the sketch below)
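+
+ A sketch of the long-audio path (`long_audio.wav` is a placeholder for any recording longer than 30 seconds; the call is the same as in Usage-1):
+
+ ```py
+ import s3tokenizer
+
+ tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").cuda()
+
+ # No special handling needed: quantize() segments the mel spectrogram with
+ # the 30 s window / 4 s overlap internally and merges the resulting tokens.
+ audio = s3tokenizer.load_audio("long_audio.wav")
+ mels, mels_lens = s3tokenizer.padding([s3tokenizer.log_mel_spectrogram(audio)])
+ codes, codes_lens = tokenizer.quantize(mels.cuda(), mels_lens.cuda())
+ print(codes[0, :codes_lens[0].item()])
+ ```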
speech/tools/S3Tokenizer/requirements.txt ADDED
@@ -0,0 +1,7 @@
+ pre-commit
+ numpy
+ torch
+ onnx
+ tqdm
+ torchaudio
+ einops
speech/tools/S3Tokenizer/s3tokenizer/__init__.py ADDED
@@ -0,0 +1,153 @@
+ # Copyright (c) 2023 OpenAI. (authors: Whisper Team)
+ #               2024 Tsinghua Univ. (authors: Xingchen Song)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Modified from
+ https://github.com/openai/whisper/blob/main/whisper/__init__.py
+ """
+
+ import hashlib
+ import os
+ import urllib.request
+ import warnings
+ from typing import List, Union
+
+ from tqdm import tqdm
+
+ from s3tokenizer.model_v2 import S3TokenizerV2
+
+ from .model import S3Tokenizer
+ from .utils import (load_audio, log_mel_spectrogram, make_non_pad_mask,
+                     mask_to_bias, onnx2torch, padding, merge_tokenized_segments)
+
+ __all__ = [
+     'load_audio', 'log_mel_spectrogram', 'make_non_pad_mask', 'mask_to_bias',
+     'onnx2torch', 'padding', 'merge_tokenized_segments'
+ ]
+ _MODELS = {
+     "speech_tokenizer_v1":
+     "https://www.modelscope.cn/models/iic/cosyvoice-300m/"
+     "resolve/master/speech_tokenizer_v1.onnx",
+     "speech_tokenizer_v1_25hz":
+     "https://www.modelscope.cn/models/iic/CosyVoice-300M-25Hz/"
+     "resolve/master/speech_tokenizer_v1.onnx",
+     "speech_tokenizer_v2_25hz":
+     "https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/"
+     "resolve/master/speech_tokenizer_v2.onnx",
+ }
+
+ _SHA256S = {
+     "speech_tokenizer_v1":
+     "23b5a723ed9143aebfd9ffda14ac4c21231f31c35ef837b6a13bb9e5488abb1e",
+     "speech_tokenizer_v1_25hz":
+     "56285ddd4a83e883ee0cb9f8d69c1089b53a94b1f78ff7e4a0224a27eb4cb486",
+     "speech_tokenizer_v2_25hz":
+     "d43342aa12163a80bf07bffb94c9de2e120a8df2f9917cd2f642e7f4219c6f71",
+ }
+
+
+ def _download(name: str, root: str) -> Union[bytes, str]:
+     os.makedirs(root, exist_ok=True)
+
+     expected_sha256 = _SHA256S[name]
+     url = _MODELS[name]
+     download_target = os.path.join(root, f"{name}.onnx")
+
+     if os.path.exists(download_target) and not os.path.isfile(download_target):
+         raise RuntimeError(
+             f"{download_target} exists and is not a regular file")
+
+     if os.path.isfile(download_target):
+         with open(download_target, "rb") as f:
+             model_bytes = f.read()
+         if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
+             return download_target
+         else:
+             warnings.warn(
+                 f"{download_target} exists, but the SHA256 checksum does not"
+                 " match; re-downloading the file")
+
+     with urllib.request.urlopen(url) as source, open(download_target,
+                                                      "wb") as output:
+         with tqdm(
+                 total=int(source.info().get("Content-Length")),
+                 ncols=80,
+                 unit="iB",
+                 unit_scale=True,
+                 unit_divisor=1024,
+                 desc="Downloading onnx checkpoint",
+         ) as loop:
+             while True:
+                 buffer = source.read(8192)
+                 if not buffer:
+                     break
+
+                 output.write(buffer)
+                 loop.update(len(buffer))
+
+     model_bytes = open(download_target, "rb").read()
+     if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
+         raise RuntimeError(
+             "Model has been downloaded but the SHA256 checksum does not"
+             " match. Please retry loading the model.")
+
+     return download_target
+
+
+ def available_models() -> List[str]:
+     """Returns the names of available models"""
+     return list(_MODELS.keys())
+
+
+ def load_model(
+     name: str,
+     download_root: str = None,
+ ) -> S3Tokenizer:
+     """
+     Load a S3Tokenizer ASR model
+
+     Parameters
+     ----------
+     name : str
+         one of the official model names listed by
+         `s3tokenizer.available_models()`, or path to a model checkpoint
+         containing the model dimensions and the model state_dict.
+     download_root: str
+         path to download the model files; by default,
+         it uses "~/.cache/s3tokenizer"
+
+     Returns
+     -------
+     model : S3Tokenizer
+         The S3Tokenizer model instance
+     """
+
+     if download_root is None:
+         default = os.path.join(os.path.expanduser("~"), ".cache")
+         download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default),
+                                      "s3tokenizer")
+
+     if name in _MODELS:
+         checkpoint_file = _download(name, download_root)
+     elif os.path.isfile(name):
+         checkpoint_file = name
+     else:
+         raise RuntimeError(
+             f"Model {name} not found; available models = {available_models()}")
+     if 'v2' in name:
+         model = S3TokenizerV2(name)
+     else:
+         model = S3Tokenizer(name)
+     model.init_from_onnx(checkpoint_file)
+
+     return model
speech/tools/S3Tokenizer/s3tokenizer/assets/mel_filters.npz ADDED
Binary file (4.27 kB).
speech/tools/S3Tokenizer/s3tokenizer/cli.py ADDED
@@ -0,0 +1,212 @@
+ # Copyright (c) 2024 Tsinghua Univ. (authors: Xingchen Song)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ Example Usage
+ cpu:
+
+ s3tokenizer --root_path /path/to/audio/files \
+             --model speech_tokenizer_v2_25hz \
+             --device "cpu" \
+             --batch_size 32
+
+ gpu:
+
+ torchrun --nproc_per_node=1 --nnodes=1 \
+     --rdzv_id=2024 --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
+     `which s3tokenizer` --root_path /data/dataset \
+                         --model speech_tokenizer_v2_25hz \
+                         --device "cuda" \
+                         --batch_size 64
+
+ """
+
+ import argparse
+ import os
+ from pathlib import Path
+
+ import torch
+ import torch.distributed as dist
+ from torch.utils.data import DataLoader, Dataset, DistributedSampler
+ from tqdm import tqdm
+
+ import s3tokenizer
+
+
+ class AudioDataset(Dataset):
+
+     def __init__(self, root_path, extensions=['.wav', '.flac', '.mp3']):
+         self.data = []
+
+         # Recursively find all audio files
+         root = Path(root_path)
+         for ext in extensions:
+             self.data.extend(root.rglob(f'*{ext}'))
+
+         # Sort for consistent ordering
+         self.data.sort()
+
+         if len(self.data) == 0:
+             raise ValueError(f"No audio files found in {root_path}")
+
+         print(f"Found {len(self.data)} audio files")
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         file_path = self.data[idx]
+         audio = s3tokenizer.load_audio(str(file_path))
+         mel = s3tokenizer.log_mel_spectrogram(audio)
+         return file_path, mel
+
+
+ def collate_fn(batch):
+     file_paths = [item[0] for item in batch]
+     mels = [item[1] for item in batch]
+     mels, mels_lens = s3tokenizer.padding(mels)
+     return file_paths, mels, mels_lens
+
+
+ def init_distributed():
+     world_size = int(os.environ.get('WORLD_SIZE', 1))
+     local_rank = int(os.environ.get('LOCAL_RANK', 0))
+     rank = int(os.environ.get('RANK', 0))
+     print('Inference on multiple gpus, this gpu {}'.format(local_rank) +
+           ', rank {}, world_size {}'.format(rank, world_size))
+     torch.cuda.set_device(local_rank)
+     dist.init_process_group("nccl")
+     return world_size, local_rank, rank
+
+
+ def get_args():
+     parser = argparse.ArgumentParser(description='extract speech code')
+     parser.add_argument('--model',
+                         required=True,
+                         type=str,
+                         choices=[
+                             "speech_tokenizer_v1", "speech_tokenizer_v1_25hz",
+                             "speech_tokenizer_v2_25hz"
+                         ],
+                         help='model version')
+     parser.add_argument('--root_path',
+                         required=True,
+                         type=str,
+                         help='root directory containing audio files')
+     parser.add_argument('--device',
+                         required=True,
+                         type=str,
+                         choices=["cuda", "cpu"],
+                         help='device for inference')
+     parser.add_argument('--batch_size',
+                         required=True,
+                         type=int,
+                         help='batch size (per-device) for inference')
+     parser.add_argument('--num_workers',
+                         type=int,
+                         default=4,
+                         help='workers for dataloader')
+     parser.add_argument('--prefetch',
+                         type=int,
+                         default=5,
+                         help='prefetch for dataloader')
+     parser.add_argument('--extensions',
+                         nargs='+',
+                         default=['.wav', '.flac', '.mp3'],
+                         help='audio file extensions to process')
+     args = parser.parse_args()
+     return args
+
+
+ def save_tokens(file_path, codes, codes_len):
+     """Save tokens as .pt file with _fsq suffix"""
+     # Remove extension and add _fsq.pt
+     output_path = file_path.with_suffix('').with_suffix('.pt')
+     output_path = output_path.parent / f"{output_path.stem}_fsq.pt"
+
+     # Extract only valid codes (up to codes_len)
+     valid_codes = codes[:codes_len]
+     # convert valid codes to list
+     valid_codes = valid_codes.tolist()
+
+     # Save as tensor
+     torch.save(valid_codes, output_path)
+
+     return output_path
+
+
+ def main():
+     args = get_args()
+
+     if args.device == "cuda":
+         assert (torch.cuda.is_available())
+         world_size, local_rank, rank = init_distributed()
+     else:
+         world_size, local_rank, rank = 1, 0, 0
+
+     device = torch.device(args.device)
+     model = s3tokenizer.load_model(args.model).to(device)
+     dataset = AudioDataset(args.root_path, args.extensions)
+
+     if args.device == "cuda":
+         model = torch.nn.parallel.DistributedDataParallel(
+             model, device_ids=[local_rank])
+         sampler = DistributedSampler(dataset,
+                                      num_replicas=world_size,
+                                      rank=rank)
+     else:
+         sampler = None
+
+     dataloader = DataLoader(dataset,
+                             batch_size=args.batch_size,
+                             sampler=sampler,
+                             shuffle=False,
+                             num_workers=args.num_workers,
+                             prefetch_factor=args.prefetch,
+                             collate_fn=collate_fn)
+
+     total_steps = len(dataset)
+
+     if rank == 0:
+         progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
+
+     processed_count = 0
+     for file_paths, mels, mels_lens in dataloader:
+         codes, codes_lens = model(mels.to(device), mels_lens.to(device))
+
+         # Process each file in the batch
+         for i, file_path in enumerate(file_paths):
+             code = codes[i]
+             code_len = codes_lens[i].item()
+
+             # Save tokens as .pt file
+             output_path = save_tokens(file_path, code, code_len)
+
+             if rank == 0:
+                 tqdm.write(f"Saved: {file_path} -> {output_path}")
+
+         processed_count += len(file_paths)
+
+         if rank == 0:
+             progress_bar.update(world_size * len(file_paths))
+
+     if rank == 0:
+         progress_bar.close()
+         print(f"\nProcessed {processed_count} files on rank {rank}")
+
+     if args.device == "cuda":
+         dist.barrier()
+         dist.destroy_process_group()
+
+
+ if __name__ == "__main__":
+     main()
speech/tools/S3Tokenizer/s3tokenizer/model.py ADDED
@@ -0,0 +1,546 @@
+ # Copyright (c) 2023 OpenAI. (authors: Whisper Team)
+ #               2024 Tsinghua Univ. (authors: Xingchen Song)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Modified from https://github.com/openai/whisper/blob/main/whisper/model.py
+    Add EuclideanCodebook & VectorQuantization
+ """
+
+ from dataclasses import dataclass
+ from typing import Iterable, Optional, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from einops import rearrange
+ from torch import Tensor, nn
+
+ from .utils import make_non_pad_mask, mask_to_bias, onnx2torch, merge_tokenized_segments
+
+
+ @dataclass
+ class ModelConfig:
+     n_mels: int = 128
+     n_audio_ctx: int = 1500
+     n_audio_state: int = 1280
+     n_audio_head: int = 20
+     n_audio_layer: int = 6
+     n_codebook_size: int = 4096
+
+     use_sdpa: bool = False
+
+
+ class LayerNorm(nn.LayerNorm):
+
+     def forward(self, x: Tensor) -> Tensor:
+         return super().forward(x.float()).type(x.dtype)
+
+
+ class Linear(nn.Linear):
+
+     def forward(self, x: Tensor) -> Tensor:
+         return F.linear(
+             x,
+             self.weight.to(x.dtype),
+             None if self.bias is None else self.bias.to(x.dtype),
+         )
+
+
+ class Conv1d(nn.Conv1d):
+
+     def _conv_forward(self, x: Tensor, weight: Tensor,
+                       bias: Optional[Tensor]) -> Tensor:
+         return super()._conv_forward(
+             x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype))
+
+
+ def sinusoids(length, channels, max_timescale=10000):
+     """Returns sinusoids for positional embedding"""
+     assert channels % 2 == 0
+     log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+     inv_timescales = torch.exp(-log_timescale_increment *
+                                torch.arange(channels // 2))
+     scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[
+         np.newaxis, :]
+     return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+
+ class MultiHeadAttention(nn.Module):
+
+     def __init__(self, n_state: int, n_head: int, use_sdpa: bool = False):
+         super().__init__()
+         self.n_head = n_head
+         self.query = Linear(n_state, n_state)
+         self.key = Linear(n_state, n_state, bias=False)
+         self.value = Linear(n_state, n_state)
+         self.out = Linear(n_state, n_state)
+
+         self.use_sdpa = use_sdpa
+
+     def forward(
+         self,
+         x: Tensor,
+         mask: Optional[Tensor] = None,
+     ):
+         q = self.query(x)
+         k = self.key(x)
+         v = self.value(x)
+
+         wv, qk = self.qkv_attention(q, k, v, mask)
+         return self.out(wv), qk
+
+     def qkv_attention(self,
+                       q: Tensor,
+                       k: Tensor,
+                       v: Tensor,
+                       mask: Optional[Tensor] = None):
+         _, _, D = q.shape
+         scale = (D // self.n_head)**-0.25
+         q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
+         k = k.view(*k.shape[:2], self.n_head, -1)
+         v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+
+         if not self.use_sdpa:
+             k = k.permute(0, 2, 3, 1) * scale
+             qk = q @ k  # (B, n_head, T, T)
+             if mask is not None:
+                 qk = qk + mask
+             qk = qk.float()
+             w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
+             return (w @ v).permute(0, 2, 1,
+                                    3).flatten(start_dim=2), qk.detach()
+         else:
+             k = k.permute(0, 2, 1, 3) * scale
+             assert mask is not None
+             output = torch.nn.functional.scaled_dot_product_attention(
+                 q,
+                 k,
+                 v,
+                 attn_mask=mask,
+                 dropout_p=0.,
+                 scale=1.,
+             )
+             output = (output.transpose(1,
+                                        2).contiguous().view(q.size(0), -1, D)
+                       )  # (batch, time1, d_model)
+             return output, None
+
+
+ class ResidualAttentionBlock(nn.Module):
+
+     def __init__(self, n_state: int, n_head: int, use_sdpa: bool):
+         super().__init__()
+
+         self.attn = MultiHeadAttention(n_state, n_head, use_sdpa=use_sdpa)
+         self.attn_ln = LayerNorm(n_state)
+
+         n_mlp = n_state * 4
+         self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(),
+                                  Linear(n_mlp, n_state))
+         self.mlp_ln = LayerNorm(n_state)
+
+     def forward(
+         self,
+         x: Tensor,
+         mask: Optional[Tensor] = None,
+     ):
+         x = x + self.attn(self.attn_ln(x), mask=mask)[0]
+         x = x + self.mlp(self.mlp_ln(x))
+         return x
+
+
+ class AudioEncoder(nn.Module):
+
+     def __init__(
+         self,
+         n_mels: int,
+         n_ctx: int,
+         n_state: int,
+         n_head: int,
+         n_layer: int,
+         stride: int,
+         use_sdpa: bool,
+     ):
+         super().__init__()
+         self.stride = stride
+         self.conv1 = Conv1d(n_mels,
+                             n_state,
+                             kernel_size=3,
+                             stride=stride,
+                             padding=1)
+         self.conv2 = Conv1d(n_state,
+                             n_state,
+                             kernel_size=3,
+                             stride=2,
+                             padding=1)
+         self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
+
+         self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList([
+             ResidualAttentionBlock(n_state, n_head, use_sdpa=use_sdpa)
+             for _ in range(n_layer)
+         ])
+
+     def forward(self, x: Tensor, x_len: Tensor) -> Tuple[Tensor, Tensor]:
+         """
+         x : torch.Tensor, shape = (batch_size, n_mels, T)
+             the mel spectrogram of the audio
+         x_len: torch.Tensor, shape = (batch_size,)
+             length of each audio in x
+         """
+         mask = make_non_pad_mask(x_len).unsqueeze(1)
+         x = F.gelu(self.conv1(x * mask))
+         x_len = (x_len + 2 - 1 * (3 - 1) - 1) // self.stride + 1
+         mask = make_non_pad_mask(x_len).unsqueeze(1)
+         x = F.gelu(self.conv2(x * mask))
+         x_len = (x_len + 2 - 1 * (3 - 1) - 1) // 2 + 1
+         mask = make_non_pad_mask(x_len).unsqueeze(1)
+         x = x.permute(0, 2, 1)  # (B, T // 2, n_state)
+
+         mask = mask_to_bias(mask, x.dtype)
+
+         x = (x + self.positional_embedding[:x.shape[1], :]).to(x.dtype)
+
+         for block in self.blocks:
+             x = block(x, mask.unsqueeze(1))
+
+         return x, x_len
+
+
+ class EuclideanCodebook(nn.Module):
+     """Codebook with Euclidean distance (inference-only).
+     Args:
+         dim (int): Dimension.
+         codebook_size (int): Codebook size.
+     """
+
+     def __init__(self, dim: int, codebook_size: int):
+         super().__init__()
+         embed = torch.zeros(codebook_size, dim)
+         self.codebook_size = codebook_size
+         self.register_buffer("embed", embed)
+
+     @torch.inference_mode()
+     def preprocess(self, x: Tensor) -> Tensor:
+         x = rearrange(x, "... d -> (...) d")
+         return x
+
+     @torch.inference_mode()
+     def quantize(self, x: Tensor) -> Tensor:
+         embed = self.embed.t().to(x.dtype)
+         dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed +
+                  embed.pow(2).sum(0, keepdim=True))
+         embed_ind = dist.max(dim=-1).indices
+         return embed_ind
+
+     @torch.inference_mode()
+     def postprocess_emb(self, embed_ind, shape):
+         return embed_ind.view(*shape[:-1])
+
+     @torch.inference_mode()
+     def dequantize(self, embed_ind: Tensor) -> Tensor:
+         quantize = F.embedding(embed_ind, self.embed)
+         return quantize
+
+     @torch.inference_mode()
+     def encode(self, x: Tensor) -> Tensor:
+         shape = x.shape
+         # pre-process
+         x = self.preprocess(x)
+         # quantize
+         embed_ind = self.quantize(x)
+         # post-process
+         embed_ind = self.postprocess_emb(embed_ind, shape)
+         return embed_ind
+
+     @torch.inference_mode()
+     def decode(self, embed_ind: Tensor) -> Tensor:
+         quantize = self.dequantize(embed_ind)
+         return quantize
+
+
+ class VectorQuantization(nn.Module):
+     """Vector quantization implementation (inference-only).
+     Args:
+         dim (int): Dimension
+         codebook_size (int): Codebook size
+     """
+
+     def __init__(self, dim: int, codebook_size: int):
+         super().__init__()
+         self._codebook = EuclideanCodebook(dim=dim,
+                                            codebook_size=codebook_size)
+         self.codebook_size = codebook_size
+
+     @property
+     def codebook(self):
+         return self._codebook.embed
+
+     @torch.inference_mode()
+     def encode(self, x: Tensor) -> Tensor:
+         x = F.normalize(x.float(), p=2, dim=-1)
+         embed_in = self._codebook.encode(x)
+         return embed_in
+
+     @torch.inference_mode()
+     def decode(self, embed_ind: Tensor) -> Tensor:
+         quantize = self._codebook.decode(embed_ind)
+         quantize = rearrange(quantize, "b n d -> b d n")
+         return quantize
+
+
+ class S3Tokenizer(nn.Module):
+     """S3 tokenizer implementation (inference-only).
+     Args:
+         config (ModelConfig): Config
+     """
+
+     def __init__(self, name: str, config: ModelConfig = ModelConfig()):
+         super().__init__()
+         self.name = name  # Store model name for token_rate determination
+         self.config = config
+         self.encoder = AudioEncoder(
+             self.config.n_mels,
+             self.config.n_audio_ctx,
+             self.config.n_audio_state,
+             self.config.n_audio_head,
+             self.config.n_audio_layer,
+             2 if name == "speech_tokenizer_v1_25hz" else 1,
+             self.config.use_sdpa,
+         )
+         self.quantizer = VectorQuantization(self.config.n_audio_state,
+                                             self.config.n_codebook_size)
+
+     def forward(self, mel: Tensor, mel_len: Tensor) -> Tuple[Tensor, Tensor]:
+         return self.quantize(mel, mel_len)
+
+     @torch.inference_mode()
+     def quantize(self, mel: Tensor, mel_len: Tensor) -> Tuple[Tensor, Tensor]:
+         """
+         Quantize mel spectrogram to tokens, with automatic long audio handling.
+
+         Args:
+             mel: mel spectrogram tensor, shape (batch_size, n_mels, T)
+             mel_len: mel length tensor, shape (batch_size,)
+
+         Returns:
+             code: quantized tokens, shape (batch_size, T')
+             code_len: token length, shape (batch_size,)
+         """
+         # Check if any audio in the batch exceeds 30 seconds
+         # Assuming 16kHz sample rate and hop_length=160, 30s = 30*16000/160 = 3000 frames
+         max_frames = 3000
+
+         # Check which samples are long audio
+         long_audio_mask = mel_len > max_frames
+
+         if long_audio_mask.any():
+             # Has long audio - need special processing
+             return self._quantize_mixed_batch(mel, mel_len, long_audio_mask,
+                                               max_frames)
+         else:
+             # All short audio - use original method
+             hidden, code_len = self.encoder(mel, mel_len)
+             code = self.quantizer.encode(hidden)
+             return code, code_len
+
+     @torch.inference_mode()
+     def _quantize_mixed_batch(self, mel: Tensor, mel_len: Tensor,
+                               long_audio_mask: Tensor,
+                               max_frames: int) -> Tuple[Tensor, Tensor]:
+         """
+         Handle mixed batch with both short and long audio using unified batch processing.
+
+         Args:
+             mel: mel spectrogram tensor, shape (batch_size, n_mels, T)
+             mel_len: mel length tensor, shape (batch_size,)
+             long_audio_mask: boolean mask for long audio, shape (batch_size,)
+             max_frames: maximum frames for short audio
+
+         Returns:
+             code: quantized tokens, shape (batch_size, T')
+             code_len: token length, shape (batch_size,)
+         """
+         batch_size = mel.size(0)
+
+         # Parameters for sliding window
+         sample_rate = 16000
+         hop_length = 160  # Default hop length for mel spectrogram
+         window_size = 30  # seconds
+         overlap = 4  # seconds
+
+         # Calculate frame-based parameters
+         frames_per_window = window_size * sample_rate // hop_length  # 3000 frames
+         frames_per_overlap = overlap * sample_rate // hop_length  # 400 frames
+         frames_per_stride = frames_per_window - frames_per_overlap  # 2600 frames
+
+         # Collect all segments to process (including short and long audio segments)
+         all_segments = []
+         all_segments_len = []
+         segment_info = [
+         ]  # Record which audio each segment belongs to and whether it's long audio
+
+         # Process all audio in the batch
+         for batch_idx in range(batch_size):
+             audio_mel = mel[batch_idx]
+             audio_mel_len = mel_len[batch_idx]
+             is_long_audio = long_audio_mask[batch_idx].item()
+
+             if not is_long_audio:
+                 # Short audio: process directly as a single segment
+                 segment = audio_mel[:, :audio_mel_len]
+                 seg_len = audio_mel_len.item()
+
+                 # Pad to max_frames if necessary
+                 if seg_len < frames_per_window:
+                     pad_size = frames_per_window - seg_len
+                     segment = F.pad(segment, (0, pad_size))
+
+                 all_segments.append(segment)
+                 all_segments_len.append(
+                     torch.tensor(seg_len, device=mel.device))
+                 segment_info.append({
+                     'batch_idx': batch_idx,
+                     'is_long_audio': False,
+                     'segment_idx': 0,
+                     'total_segments': 1
+                 })
+             else:
+                 # Long audio: split into multiple segments
+                 start = 0
+                 segment_idx = 0
+                 while start < audio_mel_len:
+                     end = min(start + frames_per_window, audio_mel_len)
+                     segment = audio_mel[:, start:end]
+
+                     seg_len = segment.size(1)
+                     # Pad if necessary
+                     if seg_len < frames_per_window:
+                         pad_size = frames_per_window - seg_len
+                         segment = F.pad(segment, (0, pad_size))
+
+                     all_segments.append(segment)
+                     all_segments_len.append(
+                         torch.tensor(seg_len, device=mel.device))
+                     segment_info.append({
+                         'batch_idx': batch_idx,
+                         'is_long_audio': True,
+                         'segment_idx': segment_idx,
+                         'total_segments': None  # Will be filled later
+                     })
+
+                     segment_idx += 1
+                     start += frames_per_stride
+
+                 # Update total_segments info
+                 total_segments = segment_idx
+                 for info in segment_info:
+                     if info['batch_idx'] == batch_idx and info['is_long_audio']:
+                         info['total_segments'] = total_segments
+
+         if not all_segments:
+             # Fallback if no segments
+             return torch.zeros(batch_size,
+                                0,
+                                dtype=torch.long,
+                                device=mel.device), torch.zeros(
+                                    batch_size,
+                                    dtype=torch.long,
+                                    device=mel.device)
+
+         # Unified batch processing for all segments
+         unified_batch_mel = torch.stack(all_segments)
+         unified_batch_lens = torch.stack(all_segments_len)
+
+         # Process all segments at once
+         hidden, code_len = self.encoder(unified_batch_mel, unified_batch_lens)
+         codes = self.quantizer.encode(hidden)
+
+         # Reorganize results based on segment_info
+         results = {}  # batch_idx -> (code_tensor, code_len)
+
+         for seg_idx, info in enumerate(segment_info):
+             batch_idx = info['batch_idx']
+             is_long_audio = info['is_long_audio']
+             segment_idx = info['segment_idx']
+
+             # Get codes for current segment
+             segment_code = codes[
+                 seg_idx, :code_len[seg_idx].item()].cpu().numpy().tolist()
+
+             if not is_long_audio:
+                 # Short audio: use directly
+                 code_tensor = torch.tensor(segment_code,
+                                            dtype=torch.long,
+                                            device=mel.device)
+                 results[batch_idx] = (code_tensor, len(segment_code))
+             else:
+                 # Long audio: collect all segments
+                 if batch_idx not in results:
+                     results[batch_idx] = []
+                 results[batch_idx].append(segment_code)
+
+         # Process long audio segment merging
+         for batch_idx in range(batch_size):
+             if long_audio_mask[batch_idx].item():
+                 # Merge long audio segments
+                 audio_codes = results[batch_idx]
+
+                 # Determine token rate based on model name
+                 if hasattr(self,
+                            'name') and self.name == "speech_tokenizer_v1":
+                     token_rate = 50
+                 else:
+                     token_rate = 25
+
+                 merged_codes = merge_tokenized_segments(audio_codes,
+                                                         overlap=overlap,
+                                                         token_rate=token_rate)
+
+                 # Convert to tensor
+                 merged_codes_tensor = torch.tensor(merged_codes,
+                                                    dtype=torch.long,
+                                                    device=mel.device)
+                 results[batch_idx] = (merged_codes_tensor, len(merged_codes))
+
+         # Construct final output
+         max_code_len = max(code_info[1] for code_info in results.values())
+
+         output_codes = torch.zeros(batch_size,
+                                    max_code_len,
+                                    dtype=torch.long,
+                                    device=mel.device)
+         output_codes_len = torch.zeros(batch_size,
+                                        dtype=torch.long,
+                                        device=mel.device)
+
+         for batch_idx, (code_tensor, code_len) in results.items():
+             output_codes[batch_idx, :code_len] = code_tensor
+             output_codes_len[batch_idx] = code_len
+
+         return output_codes, output_codes_len
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+     def init_from_onnx(self, onnx_path: str):
+         ckpt = onnx2torch(onnx_path, None, False)
+         self.load_state_dict(ckpt, strict=True)
+
+     def init_from_pt(self, ckpt_path: str):
+         ckpt = torch.load(ckpt_path, map_location="cpu", mmap=True)
+         self.load_state_dict(ckpt, strict=True)
+
+     def freeze(self):
+         for _, param in self.named_parameters():
+             param.requires_grad = False
speech/tools/S3Tokenizer/s3tokenizer/model_v2.py ADDED
@@ -0,0 +1,604 @@
1
+ # Copyright (c) (Mddct: Dinghao Zhou)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+ from einops import rearrange
20
+
21
+ from s3tokenizer.model import Conv1d, LayerNorm, Linear, MultiHeadAttention
22
+ from s3tokenizer.utils import make_non_pad_mask, mask_to_bias, onnx2torch, merge_tokenized_segments
23
+
24
+
25
+ @dataclass
26
+ class ModelConfig:
27
+ n_mels: int = 128
28
+ n_audio_ctx: int = 1500
29
+ n_audio_state: int = 1280
30
+ n_audio_head: int = 20
31
+ n_audio_layer: int = 6
32
+ n_codebook_size: int = 3**8
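+ # 3**8 = 6561 codes: the FSQ codebook spans 8 projected dimensions with 3 levels each.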
33
+
34
+ use_sdpa: bool = False
35
+
36
+
37
+ def precompute_freqs_cis(dim: int,
38
+ end: int,
39
+ theta: float = 10000.0,
40
+ scaling=None):
41
+ freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
42
+ t = torch.arange(end, device=freqs.device) # type: ignore
43
+ if scaling is not None:
44
+ t = t * scaling
45
+ freqs = torch.outer(t, freqs).float() # type: ignore
46
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
47
+
48
+ return torch.cat((freqs_cis, freqs_cis), dim=-1)
49
+
50
+
51
+ def apply_rotary_emb(
52
+ xq: torch.Tensor,
53
+ xk: torch.Tensor,
54
+ freqs_cis: torch.Tensor,
55
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
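+ # Rotary position embedding: rotate(x) = x * cos + rotate_half(x) * sin, where
+ # rotate_half negates the second half of the last dim and swaps it with the first half.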
56
+ real = torch.view_as_real(freqs_cis)
57
+ cos, sin = real[:, :, 0], real[:, :, 1]
58
+ cos = cos.unsqueeze(0).unsqueeze(2)
59
+ sin = sin.unsqueeze(0).unsqueeze(2)
60
+
61
+ D = xq.shape[-1]
62
+ half_l, half_r = xq[:, :, :, :D // 2], xq[:, :, :, D // 2:]
63
+ xq_r = torch.cat((-half_r, half_l), dim=-1)
64
+
65
+ D = xk.shape[-1]
66
+
67
+ half_l, half_r = xk[:, :, :, :D // 2], xk[:, :, :, D // 2:]
68
+ xk_r = torch.cat((-half_r, half_l), dim=-1)
69
+
70
+ return xq * cos + xq_r * sin, xk * cos + xk_r * sin
71
+
72
+
73
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
74
+ ndim = x.ndim
75
+ assert 0 <= 1 < ndim
76
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
77
+ shape = [
78
+ d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)
79
+ ]
80
+ return freqs_cis.view(*shape)
81
+
82
+
83
+ class FSQCodebook(torch.nn.Module):
84
+
85
+ def __init__(self, dim: int, level: int = 3):
86
+ super().__init__()
87
+ self.project_down = torch.nn.Linear(dim, 8)
88
+ self.level = level
89
+ self.embed = None
90
+
91
+ @torch.inference_mode()
92
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
93
+ x = rearrange(x, "... d -> (...) d")
94
+ return x
95
+
96
+ @torch.inference_mode()
97
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
98
+ x_shape = x.shape
99
+ # pre-process
100
+ x = self.preprocess(x)
101
+ # quantize
102
+ h = self.project_down(x).float()
103
+ h = h.tanh()
104
+ h = h * 0.9990000128746033
105
+ h = h.round() + 1
106
+ # h = ((self.level - 1) * h).round() # range [-k, k]
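+ # After tanh, the 0.999 scale, round, and +1, each of the 8 projected dims lies in
+ # {0, 1, 2}; the base-3 weighted sum below packs them into a single index in [0, 3**8).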
107
+ powers = torch.pow(
108
+ self.level,
109
+ torch.arange(2**self.level, device=x.device, dtype=h.dtype))
110
+ mu = torch.sum(h * powers.unsqueeze(0), dim=-1)
111
+ ind = mu.reshape(x_shape[0], x_shape[1]).int()
112
+ return ind
113
+
114
+ @torch.inference_mode()
115
+ def decode(self, embed_ind: torch.Tensor) -> torch.Tensor:
116
+ raise NotImplementedError(
117
+ 'There is no official up-projection component provided')
118
+
119
+
120
+ class FSQVectorQuantization(torch.nn.Module):
121
+ """Vector quantization implementation (inference-only).
122
+ Args:
123
+ dim (int): Dimension
124
+ codebook_size (int): Codebook size
125
+ """
126
+
127
+ def __init__(
128
+ self,
129
+ dim: int,
130
+ codebook_size: int,
131
+ ):
132
+ super().__init__()
133
+ assert 3**8 == codebook_size
134
+ self._codebook = FSQCodebook(dim=dim, level=3)
135
+ self.codebook_size = codebook_size
136
+
137
+ @property
138
+ def codebook(self):
139
+ return self._codebook.embed
140
+
141
+ @torch.inference_mode()
142
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
143
+ return self._codebook.encode(x)
144
+
145
+ @torch.inference_mode()
146
+ def decode(self, embed_ind: torch.Tensor) -> torch.Tensor:
147
+ quantize = self._codebook.decode(embed_ind)
148
+ quantize = rearrange(quantize, "b n d -> b d n")
149
+ return quantize
150
+
151
+
152
+ class FSMNMultiHeadAttention(MultiHeadAttention):
153
+
154
+ def __init__(
155
+ self,
156
+ n_state: int,
157
+ n_head: int,
158
+ kernel_size: int = 31,
159
+ use_sdpa: bool = False,
160
+ ):
161
+ super().__init__(n_state, n_head)
162
+
163
+ self.fsmn_block = torch.nn.Conv1d(n_state,
164
+ n_state,
165
+ kernel_size,
166
+ stride=1,
167
+ padding=0,
168
+ groups=n_state,
169
+ bias=False)
170
+ self.left_padding = (kernel_size - 1) // 2
171
+ self.right_padding = kernel_size - 1 - self.left_padding
172
+ self.pad_fn = torch.nn.ConstantPad1d(
173
+ (self.left_padding, self.right_padding), 0.0)
174
+
175
+ self.use_sdpa = use_sdpa
176
+
177
+ def forward_fsmn(self,
178
+ inputs: torch.Tensor,
179
+ mask: Optional[torch.Tensor] = None):
180
+ b, t, _, _ = inputs.size()
181
+ inputs = inputs.view(b, t, -1)
182
+ if mask is not None and mask.size(2) > 0: # time2 > 0
183
+ inputs = inputs * mask
184
+ x = inputs.transpose(1, 2)
185
+ x = self.pad_fn(x)
186
+ x = self.fsmn_block(x)
187
+ x = x.transpose(1, 2)
188
+ x += inputs
189
+ return x * mask
190
+
191
+ def qkv_attention(self,
192
+ q: torch.Tensor,
193
+ k: torch.Tensor,
194
+ v: torch.Tensor,
195
+ mask: Optional[torch.Tensor] = None,
196
+ mask_pad: Optional[torch.Tensor] = None,
197
+ freqs_cis: Optional[torch.Tensor] = None):
198
+ _, _, D = q.shape
199
+ scale = (D // self.n_head)**-0.25
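+ # q and k are each scaled by d_head**-0.25, so q @ k.T carries the usual 1/sqrt(d_head) factor.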
200
+ q = q.view(*q.shape[:2], self.n_head, -1)
201
+ k = k.view(*k.shape[:2], self.n_head, -1)
202
+ v = v.view(*v.shape[:2], self.n_head, -1)
203
+
204
+ if freqs_cis is not None:
205
+ q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)
206
+
207
+ fsm_memory = self.forward_fsmn(v, mask_pad)
208
+
209
+ q = q.permute(0, 2, 1, 3) * scale
210
+ v = v.permute(0, 2, 1, 3)
211
+
212
+ if not self.use_sdpa:
213
+ k = k.permute(0, 2, 3, 1) * scale
214
+ qk = q @ k # (B, n_head, T, T)
215
+ if mask is not None:
216
+ qk = qk + mask
217
+ qk = qk.float()
218
+ w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
219
+ return (w @ v).permute(
220
+ 0, 2, 1, 3).flatten(start_dim=2), qk.detach(), fsm_memory
221
+ else:
222
+ k = k.permute(0, 2, 1, 3) * scale
223
+ assert mask is not None
224
+ output = torch.nn.functional.scaled_dot_product_attention(
225
+ q,
226
+ k,
227
+ v,
228
+ attn_mask=mask,
229
+ dropout_p=0.,
230
+ scale=1.,
231
+ )
232
+ output = (output.transpose(1,
233
+ 2).contiguous().view(q.size(0), -1, D)
234
+ ) # (batch, time1, d_model)
235
+ return output, None, fsm_memory
236
+
237
+ def forward(self,
238
+ x: torch.Tensor,
239
+ mask: Optional[torch.Tensor] = None,
240
+ mask_pad: Optional[torch.Tensor] = None,
241
+ freqs_cis: Optional[torch.Tensor] = None):
242
+
243
+ q = self.query(x)
244
+ k = self.key(x)
245
+ v = self.value(x)
246
+
247
+ wv, qk, fsm_memory = self.qkv_attention(q, k, v, mask, mask_pad,
248
+ freqs_cis)
249
+ return self.out(wv) + fsm_memory, qk
250
+
251
+
252
+ class ResidualAttentionBlock(torch.nn.Module):
253
+
254
+ def __init__(
255
+ self,
256
+ n_state: int,
257
+ n_head: int,
258
+ kernel_size: int = 31,
259
+ use_sdpa: bool = False,
260
+ ):
261
+ super().__init__()
262
+
263
+ self.attn = FSMNMultiHeadAttention(n_state,
264
+ n_head,
265
+ kernel_size,
266
+ use_sdpa=use_sdpa)
267
+ self.attn_ln = LayerNorm(n_state, eps=1e-6)
268
+
269
+ n_mlp = n_state * 4
270
+
271
+ self.mlp = torch.nn.Sequential(Linear(n_state, n_mlp), torch.nn.GELU(),
272
+ Linear(n_mlp, n_state))
273
+ self.mlp_ln = LayerNorm(n_state)
274
+
275
+ def forward(
276
+ self,
277
+ x: torch.Tensor,
278
+ mask: Optional[torch.Tensor] = None,
279
+ mask_pad: Optional[torch.Tensor] = None,
280
+ freqs_cis: Optional[torch.Tensor] = None,
281
+ ):
282
+ x = x + self.attn(
283
+ self.attn_ln(x), mask=mask, mask_pad=mask_pad,
284
+ freqs_cis=freqs_cis)[0]
285
+
286
+ x = x + self.mlp(self.mlp_ln(x))
287
+ return x
288
+
289
+
290
+ class AudioEncoderV2(torch.nn.Module):
291
+
292
+ def __init__(
293
+ self,
294
+ n_mels: int,
295
+ n_state: int,
296
+ n_head: int,
297
+ n_layer: int,
298
+ stride: int,
299
+ use_sdpa: bool,
300
+ ):
301
+ super().__init__()
302
+ self.stride = stride
303
+
304
+ self.conv1 = Conv1d(n_mels,
305
+ n_state,
306
+ kernel_size=3,
307
+ stride=stride,
308
+ padding=1)
309
+ self.conv2 = Conv1d(n_state,
310
+ n_state,
311
+ kernel_size=3,
312
+ stride=2,
313
+ padding=1)
314
+ self.freqs_cis = precompute_freqs_cis(64, 1024 * 2)
315
+ self.blocks = torch.nn.ModuleList([
316
+ ResidualAttentionBlock(n_state, n_head, use_sdpa=use_sdpa)
317
+ for _ in range(n_layer)
318
+ ])
319
+
320
+ def forward(self, x: torch.Tensor,
321
+ x_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
322
+ """
323
+ x : torch.Tensor, shape = (batch_size, n_mels, T)
324
+ the mel spectrogram of the audio
325
+ x_len: torch.Tensor, shape = (batch_size,)
326
+ length of each audio in x
327
+ """
328
+ mask = make_non_pad_mask(x_len).unsqueeze(1)
329
+ x = torch.nn.functional.gelu(self.conv1(x * mask))
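+ # Conv1d output length: (L_in + 2*padding - kernel_size) // stride + 1, with kernel_size=3, padding=1.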
330
+ x_len = (x_len + 2 - 1 * (3 - 1) - 1) // self.stride + 1
331
+ mask = make_non_pad_mask(x_len).unsqueeze(1)
332
+ x = torch.nn.functional.gelu(self.conv2(x * mask))
333
+ x_len = (x_len + 2 - 1 * (3 - 1) - 1) // 2 + 1
334
+ mask = make_non_pad_mask(x_len).unsqueeze(1)
335
+ x = x.permute(0, 2, 1) # (B, T // 2, n_state)
336
+ freqs_cis = self.freqs_cis.to(x.device)
337
+ mask_pad = mask.transpose(1, 2)
338
+ mask = mask_to_bias(mask, x.dtype)
339
+
340
+ tmp = torch.view_as_real(freqs_cis)
341
+ cos, sin = tmp[:, :, 0], tmp[:, :, 1]
342
+
343
+ cos = torch.cat((cos, cos), dim=-1)
344
+ sin = torch.cat((sin, sin), dim=-1)
345
+ cos = cos.unsqueeze(0).unsqueeze(2)
346
+ sin = sin.unsqueeze(0).unsqueeze(2)
347
+
348
+ for block in self.blocks:
349
+ x = block(x, mask.unsqueeze(1), mask_pad, freqs_cis[:x.size(1)])
350
+
351
+ return x, x_len
352
+
353
+
354
+ class S3TokenizerV2(torch.nn.Module):
355
+ """S3 tokenizer v2 implementation (inference-only).
356
+ Args:
357
+ config (ModelConfig): Config
358
+ """
359
+
360
+ def __init__(self, name: str, config: ModelConfig = ModelConfig()):
361
+ super().__init__()
362
+ self.name = name # Store model name for token_rate determination
363
+ if 'v1' not in name:
364
+ assert 'v2' in name
365
+ # TODO(Mddct): make it configurable
366
+ config.n_codebook_size = 3**8
367
+ self.config = config
368
+ self.encoder = AudioEncoderV2(
369
+ self.config.n_mels,
370
+ self.config.n_audio_state,
371
+ self.config.n_audio_head,
372
+ self.config.n_audio_layer,
373
+ 2,
374
+ self.config.use_sdpa,
375
+ )
376
+ self.quantizer = FSQVectorQuantization(
377
+ self.config.n_audio_state,
378
+ self.config.n_codebook_size,
379
+ )
380
+
381
+ def forward(self, mel: torch.Tensor,
382
+ mel_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
383
+ return self.quantize(mel, mel_len)
384
+
385
+ @torch.inference_mode()
386
+ def quantize(self, mel: torch.Tensor,
387
+ mel_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
388
+ """
389
+ Quantize mel spectrogram to tokens, with automatic long audio handling.
390
+
391
+ Args:
392
+ mel: mel spectrogram tensor, shape (batch_size, n_mels, T)
393
+ mel_len: mel length tensor, shape (batch_size,)
394
+
395
+ Returns:
396
+ code: quantized tokens, shape (batch_size, T')
397
+ code_len: token length, shape (batch_size,)
398
+ """
399
+ # Check if any audio in the batch exceeds 30 seconds
400
+ # Assuming 16kHz sample rate and hop_length=160, 30s = 30*16000/160 = 3000 frames
401
+ max_frames = 3000
402
+
403
+ # Check which samples are long audio
404
+ long_audio_mask = mel_len > max_frames
405
+
406
+ if long_audio_mask.any():
407
+ # Has long audio - need special processing
408
+ return self._quantize_mixed_batch(mel, mel_len, long_audio_mask,
409
+ max_frames)
410
+ else:
411
+ # All short audio - use original method
412
+ hidden, code_len = self.encoder(mel, mel_len)
413
+ code = self.quantizer.encode(hidden)
414
+ return code, code_len
415
+
416
+ @torch.inference_mode()
417
+ def _quantize_mixed_batch(
418
+ self, mel: torch.Tensor, mel_len: torch.Tensor,
419
+ long_audio_mask: torch.Tensor,
420
+ max_frames: int) -> Tuple[torch.Tensor, torch.Tensor]:
421
+ """
422
+ Handle mixed batch with both short and long audio using unified batch processing.
423
+
424
+ Args:
425
+ mel: mel spectrogram tensor, shape (batch_size, n_mels, T)
426
+ mel_len: mel length tensor, shape (batch_size,)
427
+ long_audio_mask: boolean mask for long audio, shape (batch_size,)
428
+ max_frames: maximum frames for short audio
429
+
430
+ Returns:
431
+ code: quantized tokens, shape (batch_size, T')
432
+ code_len: token length, shape (batch_size,)
433
+ """
434
+ batch_size = mel.size(0)
435
+
436
+ # Parameters for sliding window
437
+ sample_rate = 16000
438
+ hop_length = 160 # Default hop length for mel spectrogram
439
+ window_size = 30 # seconds
440
+ overlap = 4 # seconds
441
+
442
+ # Calculate frame-based parameters
443
+ frames_per_window = window_size * sample_rate // hop_length # 3000 frames
444
+ frames_per_overlap = overlap * sample_rate // hop_length # 400 frames
445
+ frames_per_stride = frames_per_window - frames_per_overlap # 2600 frames
446
+
447
+ # Collect all segments to process (including short and long audio segments)
448
+ all_segments = []
449
+ all_segments_len = []
450
+ segment_info = [
451
+ ] # Record which audio each segment belongs to and whether it's long audio
452
+
453
+ # Process all audio in the batch
454
+ for batch_idx in range(batch_size):
455
+ audio_mel = mel[batch_idx]
456
+ audio_mel_len = mel_len[batch_idx]
457
+ is_long_audio = long_audio_mask[batch_idx].item()
458
+
459
+ if not is_long_audio:
460
+ # Short audio: process directly as a single segment
461
+ segment = audio_mel[:, :audio_mel_len]
462
+ seg_len = audio_mel_len.item()
463
+
464
+ # Pad to max_frames if necessary
465
+ if seg_len < frames_per_window:
466
+ pad_size = frames_per_window - seg_len
467
+ segment = torch.nn.functional.pad(segment, (0, pad_size))
468
+
469
+ all_segments.append(segment)
470
+ all_segments_len.append(
471
+ torch.tensor(seg_len, device=mel.device))
472
+ segment_info.append({
473
+ 'batch_idx': batch_idx,
474
+ 'is_long_audio': False,
475
+ 'segment_idx': 0,
476
+ 'total_segments': 1
477
+ })
478
+ else:
479
+ # Long audio: split into multiple segments
480
+ start = 0
481
+ segment_idx = 0
482
+ while start < audio_mel_len:
483
+ end = min(start + frames_per_window, audio_mel_len)
484
+ segment = audio_mel[:, start:end]
485
+
486
+ seg_len = segment.size(1)
487
+ # Pad if necessary
488
+ if seg_len < frames_per_window:
489
+ pad_size = frames_per_window - seg_len
490
+ segment = torch.nn.functional.pad(
491
+ segment, (0, pad_size))
492
+
493
+ all_segments.append(segment)
494
+ all_segments_len.append(
495
+ torch.tensor(seg_len, device=mel.device))
496
+ segment_info.append({
497
+ 'batch_idx': batch_idx,
498
+ 'is_long_audio': True,
499
+ 'segment_idx': segment_idx,
500
+ 'total_segments': None # Will be filled later
501
+ })
502
+
503
+ segment_idx += 1
504
+ start += frames_per_stride
505
+
506
+ # Update total_segments info
507
+ total_segments = segment_idx
508
+ for info in segment_info:
509
+ if info['batch_idx'] == batch_idx and info['is_long_audio']:
510
+ info['total_segments'] = total_segments
511
+
512
+ if not all_segments:
513
+ # Fallback if no segments
514
+ return torch.zeros(batch_size,
515
+ 0,
516
+ dtype=torch.long,
517
+ device=mel.device), torch.zeros(
518
+ batch_size,
519
+ dtype=torch.long,
520
+ device=mel.device)
521
+
522
+ # Unified batch processing for all segments
523
+ unified_batch_mel = torch.stack(all_segments)
524
+ unified_batch_lens = torch.stack(all_segments_len)
525
+
526
+ # Process all segments at once
527
+ hidden, code_len = self.encoder(unified_batch_mel, unified_batch_lens)
528
+ codes = self.quantizer.encode(hidden)
529
+
530
+ # Reorganize results based on segment_info
531
+ results = {} # batch_idx -> (code_tensor, code_len)
532
+
533
+ for seg_idx, info in enumerate(segment_info):
534
+ batch_idx = info['batch_idx']
535
+ is_long_audio = info['is_long_audio']
536
+ segment_idx = info['segment_idx']
537
+
538
+ # Get codes for current segment
539
+ segment_code = codes[
540
+ seg_idx, :code_len[seg_idx].item()].cpu().numpy().tolist()
541
+
542
+ if not is_long_audio:
543
+ # Short audio: use directly
544
+ code_tensor = torch.tensor(segment_code,
545
+ dtype=torch.long,
546
+ device=mel.device)
547
+ results[batch_idx] = (code_tensor, len(segment_code))
548
+ else:
549
+ # Long audio: collect all segments
550
+ if batch_idx not in results:
551
+ results[batch_idx] = []
552
+ results[batch_idx].append(segment_code)
553
+
554
+ # Process long audio segment merging
555
+ for batch_idx in range(batch_size):
556
+ if long_audio_mask[batch_idx].item():
557
+ # Merge long audio segments
558
+ audio_codes = results[batch_idx]
559
+
560
+ # V2 models use 25Hz token rate
561
+ token_rate = 25
562
+
563
+ merged_codes = merge_tokenized_segments(audio_codes,
564
+ overlap=overlap,
565
+ token_rate=token_rate)
566
+
567
+ # Convert to tensor
568
+ merged_codes_tensor = torch.tensor(merged_codes,
569
+ dtype=torch.long,
570
+ device=mel.device)
571
+ results[batch_idx] = (merged_codes_tensor, len(merged_codes))
572
+
573
+ # Construct final output
574
+ max_code_len = max(code_info[1] for code_info in results.values())
575
+
576
+ output_codes = torch.zeros(batch_size,
577
+ max_code_len,
578
+ dtype=torch.long,
579
+ device=mel.device)
580
+ output_codes_len = torch.zeros(batch_size,
581
+ dtype=torch.long,
582
+ device=mel.device)
583
+
584
+ for batch_idx, (code_tensor, code_len) in results.items():
585
+ output_codes[batch_idx, :code_len] = code_tensor
586
+ output_codes_len[batch_idx] = code_len
587
+
588
+ return output_codes, output_codes_len
589
+
590
+ @property
591
+ def device(self):
592
+ return next(self.parameters()).device
593
+
594
+ def init_from_onnx(self, onnx_path: str):
595
+ ckpt = onnx2torch(onnx_path, None, False)
596
+ self.load_state_dict(ckpt, strict=True)
597
+
598
+ def init_from_pt(self, ckpt_path: str):
599
+ ckpt = torch.load(ckpt_path, map_location="cpu", mmap=True)
600
+ self.load_state_dict(ckpt, strict=True)
601
+
602
+ def freeze(self):
603
+ for _, param in self.named_parameters():
604
+ param.requires_grad = False
speech/tools/S3Tokenizer/s3tokenizer/utils.py ADDED
@@ -0,0 +1,390 @@
1
+ # Copyright (c) 2023 OpenAI. (authors: Whisper Team)
2
+ # 2024 Tsinghua Univ. (authors: Xingchen Song)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Modified from https://github.com/openai/whisper/blob/main/whisper/audio.py
16
+ Add _rename_weights() & onnx2torch() & make_non_pad_mask() & mask_to_bias()
17
+ Copy merge_tokenized_segments() from https://github.com/Mddct/s3tokenizer-long/blob/main/example.py
18
+ """
19
+
20
+ import os
21
+ from functools import lru_cache
22
+ from typing import List, Optional, Union
23
+
24
+ import numpy as np
25
+ import onnx
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torchaudio
29
+ from torch.nn.utils.rnn import pad_sequence
30
+
31
+
32
+ def _rename_weights(weights_dict: dict):
33
+ """
34
+ Rename onnx weights to pytorch format.
35
+
36
+ Parameters
37
+ ----------
38
+ weights_dict: dict
39
+ The dict containing weights in onnx format
40
+
41
+ Returns
42
+ -------
43
+ A new weight dict containing the weights in pytorch format.
44
+ """
45
+ new_weight_dict = {}
46
+ for k in weights_dict.keys():
47
+ if "quantizer" in k: # vq or fsq
48
+ if k == "/quantizer/rq/model/layers.0/_codebook/Pow_1":
49
+ new_weight_dict["quantizer._codebook.embed"] = weights_dict[k]
50
+ elif 'project_down' in k: # v2
51
+ new_weight_dict[k] = weights_dict[k]
52
+ elif "positional_embedding" in k: # positional emb
53
+ new_weight_dict[k] = weights_dict[k]
54
+ elif "conv" in k: # 1/2 or 1/4 subsample
55
+ new_weight_dict[k] = weights_dict[k]
56
+ else: # transformer blocks
57
+ assert "blocks" in k
58
+ new_k = (k[1:].replace('/', '.').replace(
59
+ 'MatMul', 'weight').replace('Add_1', 'bias').replace(
60
+ 'Mul', 'weight').replace('Add', 'bias').replace(
61
+ 'mlp.mlp', 'mlp')).replace('fsmn_block.Conv',
62
+ 'fsmn_block.weight')
63
+
64
+ new_weight_dict[f"encoder.{new_k}"] = weights_dict[k]
65
+ return new_weight_dict
66
+
67
+
68
+ def onnx2torch(onnx_path: str, torch_path: str = None, verbose: bool = False):
69
+ """
70
+ Open an onnx file and convert to pytorch format.
71
+
72
+ Parameters
73
+ ----------
74
+ onnx_path: str
75
+ The onnx file to open, typically `speech_tokenizer_v1.onnx`
76
+
77
+ torch_path: str
78
+ The path to save the torch-formatted checkpoint.
79
+
80
+ verbose: bool
81
+ Logging info or not.
82
+
83
+ Returns
84
+ -------
85
+ A checkpoint dict containing the weights and their names, if torch_path is
86
+ None. Otherwise, the checkpoint dict is saved to the given path.
87
+ """
88
+ onnx_model = onnx.load(onnx_path)
89
+ weights_dict = {}
90
+ initializer_map = {
91
+ initializer.name: initializer
92
+ for initializer in onnx_model.graph.initializer
93
+ }
94
+ for node in onnx_model.graph.node:
95
+ for input_name in node.input:
96
+ if input_name in initializer_map:
97
+ ln_bias_name, ln_weight_name = None, None # for v2 ln
98
+ initializer = initializer_map[input_name]
99
+ if input_name in [
100
+ "onnx::Conv_1519",
101
+ "encoders.conv1.weight",
102
+ "onnx::Conv_2216",
103
+ ]: # v1_50hz, v1_25hz, v2_25hz
104
+ weight_name = "encoder.conv1.weight"
105
+ elif input_name in [
106
+ "onnx::Conv_1520",
107
+ "encoders.conv1.bias",
108
+ "onnx::Conv_2217",
109
+ ]: # v1_50hz, v1_25hz, v2_25hz
110
+ weight_name = "encoder.conv1.bias"
111
+ elif input_name in [
112
+ "onnx::Conv_1521",
113
+ "encoders.conv2.weight",
114
+ "onnx::Conv_2218",
115
+ ]:
116
+ weight_name = "encoder.conv2.weight"
117
+ elif input_name in [
118
+ "onnx::Conv_1522",
119
+ "encoders.conv2.bias",
120
+ "onnx::Conv_2219",
121
+ ]:
122
+ weight_name = "encoder.conv2.bias"
123
+ elif input_name == "encoders.positional_embedding":
124
+ weight_name = "encoder.positional_embedding"
125
+ elif input_name == 'quantizer.project_in.bias':
126
+ weight_name = "quantizer._codebook.project_down.bias"
127
+ elif input_name == 'onnx::MatMul_2536':
128
+ weight_name = "quantizer._codebook.project_down.weight"
129
+ else:
130
+ if node.op_type == 'LayerNormalization': # in input_name:
131
+ ln_name = node.name.replace('/LayerNormalization', '')
132
+ ln_weight_name = ln_name + '.weight'
133
+ ln_bias_name = ln_name + '.bias'
134
+ else:
135
+ weight_name = node.name
136
+ if ln_weight_name is not None and ln_bias_name is not None:
137
+ ln_inputs = node.input
138
+ scale_name = ln_inputs[1]
139
+ bias_name = ln_inputs[2]
140
+ scale = onnx.numpy_helper.to_array(
141
+ initializer_map[scale_name]).copy(
142
+ ) if scale_name in initializer_map else None
143
+ bias = onnx.numpy_helper.to_array(
144
+ initializer_map[bias_name]).copy(
145
+ ) if bias_name in initializer_map else None
146
+ scale.flags.writeable = True
147
+ bias.flags.writeable = True
148
+ weight_tensor = torch.from_numpy(scale)
149
+ bias_tensor = torch.from_numpy(bias)
150
+
151
+ weights_dict[ln_bias_name] = bias_tensor
152
+ weights_dict[ln_weight_name] = weight_tensor
153
+ else:
154
+ weight_array = onnx.numpy_helper.to_array(
155
+ initializer).copy()
156
+ weight_array.flags.writeable = True
157
+ weight_tensor = torch.from_numpy(weight_array)
158
+ if len(weight_tensor.shape) > 2 or weight_name in [
159
+ "encoder.positional_embedding"
160
+ ]:
161
+ weights_dict[weight_name] = weight_tensor
162
+ else:
163
+ weights_dict[weight_name] = weight_tensor.t()
164
+
165
+ new_weights_dict = _rename_weights(weights_dict)
166
+ if verbose:
167
+ for k, v in new_weights_dict.items():
168
+ print(f"{k} : {v.shape} {v.dtype}")
169
+ print(f"PyTorch weights saved to {torch_path}")
170
+ del weights_dict, onnx_model
171
+ if torch_path:
172
+ torch.save(new_weights_dict, torch_path)
173
+ else:
174
+ return new_weights_dict
175
+
176
+
177
+ def load_audio(file: str, sr: int = 16000):
178
+ """
179
+ Open an audio file and read as mono waveform, resampling as necessary
180
+
181
+ Parameters
182
+ ----------
183
+ file: str
184
+ The audio file to open
185
+
186
+ sr: int
187
+ The sample rate to resample the audio if necessary
188
+
189
+ Returns
190
+ -------
191
+ A torch.Tensor containing the audio waveform, in float32 dtype.
192
+ """
193
+ audio, sample_rate = torchaudio.load(file)
194
+ if sample_rate != sr:
195
+ audio = torchaudio.transforms.Resample(sample_rate, sr)(audio)
196
+ audio = audio[0] # get the first channel
197
+ return audio
198
+
199
+
200
+ @lru_cache(maxsize=None)
201
+ def _mel_filters(device, n_mels: int) -> torch.Tensor:
202
+ """
203
+ Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
204
+ This allows decoupling the librosa dependency; the filters were saved using:
205
+
206
+ np.savez_compressed(
207
+ "mel_filters.npz",
208
+ mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
209
+ mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
210
+ )
211
+ """
212
+ assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
213
+
214
+ filters_path = os.path.join(os.path.dirname(__file__), "assets",
215
+ "mel_filters.npz")
216
+ with np.load(filters_path, allow_pickle=False) as f:
217
+ return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
218
+
219
+
220
+ def log_mel_spectrogram(
221
+ audio: Union[str, np.ndarray, torch.Tensor],
222
+ n_mels: int = 128,
223
+ padding: int = 0,
224
+ device: Optional[Union[str, torch.device]] = None,
225
+ ):
226
+ """
227
+ Compute the log-Mel spectrogram of the input audio.
228
+
229
+ Parameters
230
+ ----------
231
+ audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
232
+ The path to audio or either a NumPy array or Tensor containing the
233
+ audio waveform in 16 kHz
234
+
235
+ n_mels: int
236
+ The number of Mel-frequency filters; only 80 and 128 are supported
237
+
238
+ padding: int
239
+ Number of zero samples to pad to the right
240
+
241
+ device: Optional[Union[str, torch.device]]
242
+ If given, the audio tensor is moved to this device before STFT
243
+
244
+ Returns
245
+ -------
246
+ torch.Tensor, shape = (n_mels, n_frames)
247
+ A Tensor that contains the Mel spectrogram
248
+ """
249
+ if not torch.is_tensor(audio):
250
+ if isinstance(audio, str):
251
+ audio = load_audio(audio)
252
+
253
+ if device is not None:
254
+ audio = audio.to(device)
255
+ if padding > 0:
256
+ audio = F.pad(audio, (0, padding))
257
+ window = torch.hann_window(400).to(audio.device)
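+ # n_fft=400 (25 ms window) and hop_length=160 (10 ms) at 16 kHz give ~100 mel frames per second.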
258
+ stft = torch.stft(audio, 400, 160, window=window, return_complex=True)
259
+ magnitudes = stft[..., :-1].abs()**2
260
+
261
+ filters = _mel_filters(audio.device, n_mels)
262
+ mel_spec = filters @ magnitudes
263
+
264
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
265
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
266
+ log_spec = (log_spec + 4.0) / 4.0
267
+ return log_spec
268
+
269
+
270
+ def make_non_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
271
+ """Make mask tensor containing indices of non-padded part.
272
+
273
+ The sequences in a batch may have different lengths. To enable
274
+ batch computing, padding is needed to make all sequences the same
275
+ size. To keep the padded part from passing values into context-dependent
276
+ blocks such as attention or convolution, the padding part is
277
+ masked.
278
+
279
+ 1 for non-padded part and 0 for padded part.
280
+
281
+ Parameters
282
+ ----------
283
+ lengths (torch.Tensor): Batch of lengths (B,).
284
+
285
+ Returns:
286
+ -------
287
+ torch.Tensor: Boolean mask, 1 for non-padded and 0 for padded positions (B, max_T).
288
+
289
+ Examples:
290
+ >>> import torch
291
+ >>> import s3tokenizer
292
+ >>> lengths = torch.tensor([5, 3, 2])
293
+ >>> masks = s3tokenizer.make_non_pad_mask(lengths)
294
+ masks = [[1, 1, 1, 1, 1],
295
+ [1, 1, 1, 0, 0],
296
+ [1, 1, 0, 0, 0]]
297
+ """
298
+ batch_size = lengths.size(0)
299
+ max_len = max_len if max_len > 0 else lengths.max().item()
300
+ seq_range = torch.arange(0,
301
+ max_len,
302
+ dtype=torch.int64,
303
+ device=lengths.device)
304
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
305
+ seq_length_expand = lengths.unsqueeze(-1)
306
+ mask = seq_range_expand >= seq_length_expand
307
+ return ~mask
308
+
309
+
310
+ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
311
+ """Convert bool-tensor to float-tensor for flash attention.
312
+
313
+ Parameters
314
+ ----------
315
+ mask (torch.Tensor): Boolean non-pad mask (B, ?).
316
+
317
+ Returns:
318
+ -------
319
+ torch.Tensor: Float attention bias, 0.0 at valid positions and -1.0e+10 at padded positions (B, ?).
320
+
321
+ Examples:
322
+ >>> import torch
323
+ >>> import s3tokenizer
324
+ >>> lengths = torch.tensor([5, 3, 2])
325
+ >>> masks = s3tokenizer.make_non_pad_mask(lengths)
326
+ masks = [[1, 1, 1, 1, 1],
327
+ [1, 1, 1, 0, 0],
328
+ [1, 1, 0, 0, 0]]
329
+ >>> new_masks = s3tokenizer.mask_to_bias(masks, torch.float32)
330
+ new_masks =
331
+ [[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
332
+ [-0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10],
333
+ [-0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10, -1.0000e+10]]
334
+ """
335
+ assert mask.dtype == torch.bool
336
+ assert dtype in [torch.float32, torch.bfloat16, torch.float16]
337
+ mask = mask.to(dtype)
338
+
339
+ # attention mask bias
340
+ # NOTE(Mddct): torch.finfo jit issues
341
+ # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
342
+ mask = (1.0 - mask) * -1.0e+10
343
+ return mask
344
+
345
+
346
+ def padding(data: List[torch.Tensor]):
347
+ """ Padding the data into batch data
348
+
349
+ Parameters
350
+ ----------
351
+ data: List[Tensor], shape of Tensor (128, T)
352
+
353
+ Returns:
354
+ -------
355
+ feats [B, 128, T_max], feats lengths [B]
356
+ """
357
+ sample = data
358
+ assert isinstance(sample, list)
359
+ feats_lengths = torch.tensor([s.size(1) for s in sample],
360
+ dtype=torch.int32)
361
+ feats = [s.t() for s in sample]
362
+ padded_feats = pad_sequence(feats, batch_first=True, padding_value=0)
363
+
364
+ return padded_feats.transpose(1, 2), feats_lengths
365
+
366
+
367
+ def merge_tokenized_segments(tokenized_segments, overlap, token_rate):
368
+ """
369
+ Merges tokenized outputs by keeping the middle and dropping half of the overlapped tokens.
370
+
371
+ Args:
372
+ - tokenized_segments (List[List[int]]): List of tokenized sequences.
373
+ - overlap (int): Overlapping duration in seconds (typically 4 s).
374
+ - token_rate (int): Number of tokens per second.
375
+
376
+ Returns:
377
+ - List[int]: A single merged token sequence.
378
+ """
379
+ merged_tokens = []
380
+ overlap_tokens = (
381
+ overlap //
382
+ 2) * token_rate # Tokens corresponding to half of the overlap duration
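+ # e.g. overlap=4 s at token_rate=25 gives overlap_tokens=50: adjacent segments share
+ # 100 tokens, and trimming 50 from each side of the junction removes the duplication once.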
383
+
384
+ for i, tokens in enumerate(tokenized_segments):
385
+ l = 0 if i == 0 else overlap_tokens
386
+ r = -overlap_tokens if i != len(tokenized_segments) - 1 else len(tokens)
387
+ # Keep only the middle part (drop overlap / 2 from both sides)
388
+ merged_tokens.extend(tokens[l:r])
389
+
390
+ return merged_tokens
speech/tools/S3Tokenizer/setup.py ADDED
@@ -0,0 +1,37 @@
1
+ from pathlib import Path
2
+
3
+ from setuptools import find_packages, setup
4
+
5
+
6
+ def parse_requirements(filename):
7
+ """Load requirements from a pip requirements file."""
8
+ with open(filename, 'r') as file:
9
+ lines = (line.strip() for line in file)
10
+ return [line for line in lines if line and not line.startswith('#')]
11
+
12
+
13
+ setup(
14
+ name="s3tokenizer",
15
+ version="0.2.0",
16
+ description=\
17
+ "Reverse Engineering of Supervised Semantic Speech Tokenizer (S3Tokenizer) proposed in CosyVoice", # noqa
18
+ long_description=open("README.md", encoding="utf-8").read(),
19
+ long_description_content_type="text/markdown",
20
+ python_requires=">=3.8",
21
+ author="xingchensong",
22
+ url="https://github.com/xingchensong/S3Tokenizer",
23
+ license="Apache2.0",
24
+ packages=find_packages(),
25
+ install_requires=parse_requirements(
26
+ Path(__file__).with_name("requirements.txt")),
27
+ entry_points={
28
+ "console_scripts": ["s3tokenizer=s3tokenizer.cli:main"],
29
+ },
30
+ include_package_data=True,
31
+ extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]},
32
+ classifiers=[
33
+ "Programming Language :: Python :: 3",
34
+ "Operating System :: OS Independent",
35
+ "Topic :: Scientific/Engineering",
36
+ ],
37
+ )
speech/tools/S3Tokenizer/test/test_batch_efficiency.py ADDED
@@ -0,0 +1,272 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Batch processing efficiency test
4
+ Tests the efficiency improvement of the new batch-processing functionality on mixed long and short audio
5
+ """
6
+
7
+ import time
8
+ import torch
9
+ import pytest
10
+ import s3tokenizer
11
+
12
+
13
+ def create_test_audio(duration_seconds=20, sample_rate=16000):
14
+ """Create test audio"""
15
+ length = int(duration_seconds * sample_rate)
16
+ # Create meaningful audio signal (sine wave mixture)
17
+ t = torch.linspace(0, duration_seconds, length)
18
+ audio = 0.5 * torch.sin(2 * torch.pi * 440 * t) # 440Hz fundamental
19
+ audio += 0.3 * torch.sin(2 * torch.pi * 880 * t) # 880Hz second harmonic
20
+ audio += 0.1 * torch.randn(length) # Add some noise
21
+ return audio
22
+
23
+
24
+ @pytest.fixture
25
+ def test_audios():
26
+ """Create test audio dataset"""
27
+ return [
28
+ create_test_audio(10), # Short audio
29
+ create_test_audio(20), # Medium audio
30
+ create_test_audio(40), # Long audio
31
+ create_test_audio(60), # Long audio
32
+ create_test_audio(15), # Short audio
33
+ create_test_audio(35), # Long audio
34
+ create_test_audio(25), # Medium audio
35
+ create_test_audio(50), # Long audio
36
+ ]
37
+
38
+
39
+ @pytest.fixture
40
+ def long_audios():
41
+ """Create long audio dataset"""
42
+ return [
43
+ create_test_audio(45.5),
44
+ create_test_audio(60),
45
+ create_test_audio(91.2),
46
+ create_test_audio(120),
47
+ ]
48
+
49
+
50
+ @pytest.mark.parametrize("model_name", [
51
+ "speech_tokenizer_v1_25hz", "speech_tokenizer_v1",
52
+ "speech_tokenizer_v2_25hz"
53
+ ])
54
+ def test_batch_efficiency(test_audios, model_name):
55
+ """Test batch processing efficiency for different models"""
56
+ print(f"\n=== Batch Processing Efficiency Test for {model_name} ===")
57
+
58
+ # Load model
59
+ model = s3tokenizer.load_model(model_name)
60
+ model.eval()
61
+
62
+ # Method 1: Individual processing
63
+ print(f"\n--- Method 1: Individual Processing ({model_name}) ---")
64
+ start_time = time.time()
65
+ individual_results = []
66
+
67
+ for i, audio in enumerate(test_audios):
68
+ mel = s3tokenizer.log_mel_spectrogram(audio)
69
+ mels = mel.unsqueeze(0)
70
+ mels_lens = torch.tensor([mel.size(1)])
71
+
72
+ with torch.no_grad():
73
+ codes, codes_lens = model.quantize(mels, mels_lens)
74
+
75
+ final_codes = codes[0, :codes_lens[0].item()].tolist()
76
+ individual_results.append(final_codes)
77
+
78
+ duration = audio.shape[0] / 16000
79
+ processing_type = "Long audio" if duration > 30 else "Short audio"
80
+ print(
81
+ f"Audio {i+1}: {duration:.1f}s, {len(final_codes)} tokens, {processing_type}"
82
+ )
83
+
84
+ individual_time = time.time() - start_time
85
+ print(f"Individual processing total time: {individual_time:.2f}s")
86
+
87
+ # Method 2: Batch processing
88
+ print(f"\n--- Method 2: Batch Processing ({model_name}) ---")
89
+ start_time = time.time()
90
+
91
+ # Prepare batch input
92
+ mels = []
93
+ for audio in test_audios:
94
+ mel = s3tokenizer.log_mel_spectrogram(audio)
95
+ mels.append(mel)
96
+
97
+ # Use padding to handle different lengths of mel
98
+ mels, mels_lens = s3tokenizer.padding(mels)
99
+
100
+ # Batch processing
101
+ with torch.no_grad():
102
+ codes, codes_lens = model.quantize(mels, mels_lens)
103
+
104
+ # Process results
105
+ batch_results = []
106
+ for i in range(len(test_audios)):
107
+ final_codes = codes[i, :codes_lens[i].item()].tolist()
108
+ batch_results.append(final_codes)
109
+
110
+ duration = test_audios[i].shape[0] / 16000
111
+ processing_type = "Long audio" if duration > 30 else "Short audio"
112
+ print(
113
+ f"Audio {i+1}: {duration:.1f}s, {len(final_codes)} tokens, {processing_type}"
114
+ )
115
+
116
+ batch_time = time.time() - start_time
117
+ print(f"Batch processing total time: {batch_time:.2f}s")
118
+
119
+ # Verify result consistency
120
+ print(f"\n--- Result Verification for {model_name} ---")
121
+ all_ok = True
122
+ for i in range(len(test_audios)):
123
+ individual_tokens = individual_results[i]
124
+ batch_tokens = batch_results[i]
125
+
126
+ # Calculate miss rate
127
+ if len(individual_tokens) != len(batch_tokens):
128
+ print(
129
+ f"❌ Audio {i+1} length mismatch: individual={len(individual_tokens)}, batch={len(batch_tokens)}"
130
+ )
131
+ all_ok = False
132
+ else:
133
+ mismatches = sum(1 for a, b in zip(individual_tokens, batch_tokens)
134
+ if a != b)
135
+ miss_rate = mismatches / len(individual_tokens) * 100 if len(
136
+ individual_tokens) > 0 else 0
137
+
138
+ if miss_rate < 0.2: # Less than 0.2% is considered OK
139
+ print(f"✅ Audio {i+1} miss rate: {miss_rate:.4f}% (OK)")
140
+ else:
141
+ print(f"❌ Audio {i+1} miss rate: {miss_rate:.4f}% (Too high)")
142
+ all_ok = False
143
+
144
+ # Efficiency improvement
145
+ speedup = individual_time / batch_time
146
+ print(f"\n--- Efficiency Improvement for {model_name} ---")
147
+ print(f"Batch processing speedup: {speedup:.2f}x")
148
+ if speedup > 1:
149
+ print("✅ Batch processing indeed improves efficiency!")
150
+ else:
151
+ print("⚠️ Batch processing doesn't significantly improve efficiency")
152
+
153
+ # Assertions for pytest
154
+ assert all_ok, f"Results don't match for model {model_name}"
155
+ assert len(individual_results) == len(
156
+ batch_results), "Number of results don't match"
157
+ assert all(
158
+ len(individual_results[i]) == len(batch_results[i])
159
+ for i in range(len(test_audios))), "Token counts don't match"
160
+
161
+ # Performance assertion - batch should be at least as fast as individual (allowing for some variance)
162
+ # assert batch_time <= individual_time * 1.1, f"Batch processing should not be significantly slower than individual processing for {model_name}"
163
+
164
+
165
+ @pytest.mark.parametrize("model_name", [
166
+ "speech_tokenizer_v1_25hz", "speech_tokenizer_v1",
167
+ "speech_tokenizer_v2_25hz"
168
+ ])
169
+ def test_pure_long_audio_batch(long_audios, model_name):
170
+ """Test pure long audio batch processing for different models"""
171
+ print(f"\n=== Pure Long Audio Batch Processing Test for {model_name} ===")
172
+
173
+ model = s3tokenizer.load_model(model_name)
174
+ model.eval()
175
+
176
+ # Prepare batch input
177
+ mels = []
178
+ for audio in long_audios:
179
+ mel = s3tokenizer.log_mel_spectrogram(audio)
180
+ mels.append(mel)
181
+
182
+ mels, mels_lens = s3tokenizer.padding(mels)
183
+
184
+ # Batch process long audio
185
+ start_time = time.time()
186
+ with torch.no_grad():
187
+ codes, codes_lens = model.quantize(mels, mels_lens)
188
+ processing_time = time.time() - start_time
189
+
190
+ print(
191
+ f"Batch processing {len(long_audios)} long audios took: {processing_time:.2f}s"
192
+ )
193
+
194
+ results = []
195
+ for i in range(len(long_audios)):
196
+ duration = long_audios[i].shape[0] / 16000
197
+ tokens_count = codes_lens[i].item()
198
+ results.append((duration, tokens_count))
199
+ print(f"Long audio {i+1}: {duration:.1f}s → {tokens_count} tokens")
200
+
201
+ print(
202
+ f"✅ Pure long audio batch processing test completed for {model_name}")
203
+
204
+ # Assertions for pytest
205
+ assert codes is not None, f"Codes should not be None for model {model_name}"
206
+ assert codes_lens is not None, f"Codes lengths should not be None for model {model_name}"
207
+ assert len(results) == len(
208
+ long_audios), "Number of results should match number of input audios"
209
+ assert all(
210
+ tokens_count > 0
211
+ for _, tokens_count in results), "All audio should produce tokens"
212
+ assert processing_time > 0, "Processing time should be positive"
213
+
214
+
215
+ @pytest.mark.parametrize("model_name", [
216
+ "speech_tokenizer_v1_25hz", "speech_tokenizer_v1",
217
+ "speech_tokenizer_v2_25hz"
218
+ ])
219
+ def test_model_loading(model_name):
220
+ """Test that all models can be loaded successfully"""
221
+ print(f"\n=== Model Loading Test for {model_name} ===")
222
+
223
+ model = s3tokenizer.load_model(model_name)
224
+ assert model is not None, f"Model {model_name} should load successfully"
225
+
226
+ # Test model can be set to eval mode
227
+ model.eval()
228
+ print(f"✅ Model {model_name} loaded and set to eval mode successfully")
229
+
230
+
231
+ @pytest.mark.parametrize("model_name", [
232
+ "speech_tokenizer_v1_25hz", "speech_tokenizer_v1",
233
+ "speech_tokenizer_v2_25hz"
234
+ ])
235
+ def test_single_audio_processing(model_name):
236
+ """Test single audio processing for different models"""
237
+ print(f"\n=== Single Audio Processing Test for {model_name} ===")
238
+
239
+ # Create a single test audio
240
+ audio = create_test_audio(30) # 30 second audio
241
+
242
+ model = s3tokenizer.load_model(model_name)
243
+ model.eval()
244
+
245
+ # Process the audio
246
+ mel = s3tokenizer.log_mel_spectrogram(audio)
247
+ mels = mel.unsqueeze(0)
248
+ mels_lens = torch.tensor([mel.size(1)])
249
+
250
+ with torch.no_grad():
251
+ codes, codes_lens = model.quantize(mels, mels_lens)
252
+
253
+ final_codes = codes[0, :codes_lens[0].item()].tolist()
254
+
255
+ # Assertions
256
+ assert codes is not None, f"Codes should not be None for model {model_name}"
257
+ assert codes_lens is not None, f"Codes lengths should not be None for model {model_name}"
258
+ assert len(
259
+ final_codes) > 0, f"Should produce tokens for model {model_name}"
260
+ assert codes_lens[0].item() == len(
261
+ final_codes
262
+ ), f"Codes length should match actual codes for model {model_name}"
263
+
264
+ duration = audio.shape[0] / 16000
265
+ print(
266
+ f"✅ Single audio processing test completed for {model_name}: {duration:.1f}s → {len(final_codes)} tokens"
267
+ )
268
+
269
+
270
+ if __name__ == "__main__":
271
+ # Run tests with pytest
272
+ pytest.main([__file__, "-v"])
speech/tools/S3Tokenizer/test/test_onnx.py ADDED
@@ -0,0 +1,377 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright [2024-09-27] <sxc19@mails.tsinghua.edu.cn, Xingchen Song>
4
+
5
+ import os
6
+ import time
7
+ from typing import Dict, Any
8
+
9
+ import numpy as np
10
+ import onnxruntime
11
+ import pytest
12
+ import s3tokenizer
13
+ import torch
14
+
15
+
16
+ def create_test_audio(duration_seconds: float = 20,
17
+ sample_rate: int = 16000) -> torch.Tensor:
18
+ """Create synthetic test audio"""
19
+ length = int(duration_seconds * sample_rate)
20
+ # Create sinusoidal mixed audio
21
+ t = torch.linspace(0, duration_seconds, length)
22
+ audio = 0.5 * torch.sin(2 * torch.pi * 440 * t) # 440Hz fundamental
23
+ audio += 0.3 * torch.sin(2 * torch.pi * 880 * t) # 880Hz second harmonic
24
+ audio += 0.1 * torch.randn(length) # Add noise
25
+ return audio
26
+
27
+
28
+ @pytest.fixture
29
+ def test_audio_suite():
30
+ """Create a suite of test audios with different lengths"""
31
+ return {
32
+ "short_audio_1": create_test_audio(5.0), # 5 seconds
33
+ "short_audio_2": create_test_audio(15.0), # 15 seconds
34
+ "medium_audio": create_test_audio(25.0), # 25 seconds
35
+ "medium_audio_2": create_test_audio(30.0), # 30 seconds
36
+ "long_audio": create_test_audio(
37
+ 35.0), # 35 seconds - for torch and onnx, 2 segments with padding
38
+ "long_audio_2": create_test_audio(
39
+ 56.0
40
+ ), # 56 seconds - for torch and onnx, exactly 2 segments without padding
41
+ "very_long_audio": create_test_audio(
42
+ 60.0), # 60 seconds - for torch and onnx, 3 segments with padding
43
+ }
44
+
45
+
46
+ def onnx_inference_short_audio(model_name: str, mel: torch.Tensor,
47
+ mel_len: torch.Tensor) -> torch.Tensor:
48
+ """
49
+ ONNX inference for short audio (<=30s)
50
+ """
51
+ # Load ONNX model
52
+ default = os.path.join(os.path.expanduser("~"), ".cache")
53
+ download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default),
54
+ "s3tokenizer")
55
+
56
+ option = onnxruntime.SessionOptions()
57
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
58
+ option.intra_op_num_threads = 1
59
+ providers = ["CPUExecutionProvider"]
60
+
61
+ ort_session = onnxruntime.InferenceSession(
62
+ f"{download_root}/{model_name}.onnx",
63
+ sess_options=option,
64
+ providers=providers)
65
+
66
+ # Direct inference for short audio
67
+ onnx_output = ort_session.run(
68
+ None, {
69
+ ort_session.get_inputs()[0].name:
70
+ mel[:, :mel_len.item()].unsqueeze(0).detach().cpu().numpy(),
71
+ ort_session.get_inputs()[1].name:
72
+ np.array([mel_len.item()], dtype=np.int32)
73
+ })[0]
74
+
75
+ # Convert to numpy array to fix linter issues
76
+ onnx_output = np.array(onnx_output)
77
+
78
+ # Handle different output formats
79
+ if onnx_output.ndim == 2:
80
+ onnx_output = onnx_output[0, :]
81
+ elif onnx_output.ndim == 3:
82
+ onnx_output = onnx_output[0, 0, :]
83
+
84
+ return torch.tensor(onnx_output, dtype=torch.long)
85
+
86
+
87
+ def onnx_inference_long_audio(model_name: str, mel: torch.Tensor,
88
+ mel_len: torch.Tensor) -> torch.Tensor:
89
+ """
90
+ ONNX inference for long audio (>30s) using sliding window approach
91
+ Based on _quantize_mixed_batch logic
92
+
93
+ Note: This may fail due to ONNX model limitations with dynamic lengths
94
+ """
95
+ # Load ONNX model
96
+ default = os.path.join(os.path.expanduser("~"), ".cache")
97
+ download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default),
98
+ "s3tokenizer")
99
+
100
+ option = onnxruntime.SessionOptions()
101
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
102
+ option.intra_op_num_threads = 1
103
+ providers = ["CPUExecutionProvider"]
104
+
105
+ ort_session = onnxruntime.InferenceSession(
106
+ f"{download_root}/{model_name}.onnx",
107
+ sess_options=option,
108
+ providers=providers)
109
+
110
+ # Parameters for sliding window (same as _quantize_mixed_batch)
111
+ sample_rate = 16000
112
+ hop_length = 160
113
+ window_size = 30 # seconds
114
+ overlap = 4 # seconds
115
+
116
+ # Calculate frame-based parameters
117
+ frames_per_window = window_size * sample_rate // hop_length # 3000 frames
118
+ frames_per_overlap = overlap * sample_rate // hop_length # 400 frames
119
+ frames_per_stride = frames_per_window - frames_per_overlap # 2600 frames
120
+
121
+ # Split into segments
122
+ segments = []
123
+ segments_len = []
124
+ start = 0
125
+
126
+ while start < mel_len.item():
127
+ end = min(start + frames_per_window, mel_len.item())
128
+ segment = mel[:, start:end]
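+ # Unlike the torch path (which pads a short tail to 3000 frames), a trailing window
+ # shorter than 30 s is skipped here.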
129
+
130
+ if segment.size(1) < frames_per_window:
131
+ break
132
+
133
+ seg_len = segment.size(1)
134
+ segments.append(segment)
135
+ segments_len.append(seg_len)
136
+
137
+ start += frames_per_stride
138
+
139
+ if not segments:
140
+ raise ValueError("No valid segments for ONNX processing")
141
+
142
+ # Process each segment with ONNX
143
+ segment_results = []
144
+ for i, (segment, seg_len) in enumerate(zip(segments, segments_len)):
145
+ try:
146
+ onnx_output = ort_session.run(
147
+ None, {
148
+ ort_session.get_inputs()[0].name:
149
+ segment.unsqueeze(0).detach().cpu().numpy(),
150
+ ort_session.get_inputs()[1].name:
151
+ np.array([seg_len], dtype=np.int32)
152
+ })[0]
153
+
154
+ # Convert to numpy array to fix linter issues
155
+ onnx_output = np.array(onnx_output)
156
+
157
+ # Handle different output formats
158
+ if onnx_output.ndim == 2:
159
+ segment_codes = onnx_output[0, :].tolist()
160
+ elif onnx_output.ndim == 3:
161
+ segment_codes = onnx_output[0, 0, :].tolist()
162
+ else:
163
+ segment_codes = onnx_output.tolist()
164
+
165
+ segment_results.append(segment_codes)
166
+
167
+ except Exception as e:
168
+ print(f" ONNX error on segment {i+1}: {str(e)[:100]}...")
169
+ raise Exception(
170
+ f"ONNX inference failed on segment {i+1}: {str(e)}")
171
+
172
+ if not segment_results:
173
+ raise ValueError("All ONNX segments failed to process")
174
+
175
+ # Merge segments using the same logic as _quantize_mixed_batch
176
+ # Determine token rate based on model name
177
+ if model_name == "speech_tokenizer_v1":
178
+ token_rate = 50
179
+ else:
180
+ token_rate = 25
181
+
182
+ merged_codes = s3tokenizer.merge_tokenized_segments(
183
+ segment_results, overlap=overlap, token_rate=token_rate
184
+ )[:-overlap * token_rate] # NOTE(xcsong): drop the last overlap part.
185
+ return torch.tensor(merged_codes, dtype=torch.long)
186
+
187
+
188
+ def onnx_inference_with_long_audio_support(
189
+ model_name: str, mel: torch.Tensor,
190
+ mel_len: torch.Tensor) -> torch.Tensor:
191
+ """
192
+ ONNX inference with automatic long audio support
193
+ """
194
+ max_frames = 3000 # 30s * 16000 / 160 = 3000 frames
195
+
196
+ if mel_len.item() <= max_frames:
197
+ # Short audio - use direct inference
198
+ return onnx_inference_short_audio(model_name, mel, mel_len)
199
+ else:
200
+ # Long audio - use sliding window approach
201
+ return onnx_inference_long_audio(model_name, mel, mel_len)
202
+
+
+ def compare_torch_vs_onnx_single(model_name: str, audio: torch.Tensor,
+                                  audio_name: str) -> Dict[str, Any]:
+     """Run a single audio clip through both the torch and ONNX versions."""
+     duration = audio.shape[0] / 16000
+
+     # Load the torch model
+     tokenizer = s3tokenizer.load_model(model_name)
+     tokenizer.eval()
+
+     # Prepare the input mel spectrogram
+     mel = s3tokenizer.log_mel_spectrogram(audio)
+     mels = mel.unsqueeze(0)
+     mels_lens = torch.tensor([mel.size(1)])
+
+     # Time the torch version
+     start_time = time.time()
+     with torch.no_grad():
+         torch_codes, torch_codes_lens = tokenizer.quantize(mels, mels_lens)
+     torch_time = time.time() - start_time
+
+     torch_result = torch_codes[0, :torch_codes_lens[0].item()]
+
+     # Time the ONNX version with long-audio support
+     try:
+         start_time = time.time()
+         onnx_result = onnx_inference_with_long_audio_support(
+             model_name, mel, mels_lens[0])
+         onnx_time = time.time() - start_time
+
+         # Compare the two token sequences over their common prefix
+         min_len = min(len(torch_result), len(onnx_result))
+         torch_truncated = torch_result[:min_len]
+         onnx_truncated = onnx_result[:min_len]
+
+         are_equal = torch.equal(torch_truncated, onnx_truncated)
+         miss_rate = 0.0
+
+         if not are_equal:
+             miss_num = torch.sum(torch_truncated != onnx_truncated)
+             miss_rate = miss_num.item() * 100.0 / min_len
+
+         return {
+             "audio_name": audio_name,
+             "model_name": model_name,
+             "duration": duration,
+             "torch_tokens": torch_truncated,
+             "onnx_tokens": onnx_truncated,
+             "torch_time": torch_time,
+             "onnx_time": onnx_time,
+             "results_match": are_equal,
+             "miss_rate": miss_rate
+         }
+
+     except Exception as e:
+         return {
+             "audio_name": audio_name,
+             "model_name": model_name,
+             "duration": duration,
+             "torch_tokens": torch_result,
+             "onnx_tokens": [],
+             "torch_time": torch_time,
+             "onnx_time": 0.0,
+             "results_match": False,
+             "miss_rate": 100.0,
+             "error": str(e)
+         }
+
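+ # Reading the miss rate (a sketch with invented numbers): if min_len is
+ # 1000 tokens and 3 positions differ, miss_rate = 3 * 100.0 / 1000 = 0.3,
+ # which passes the < 0.5% threshold asserted in the tests below.
+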
+
+ @pytest.mark.parametrize("model_name", [
+     "speech_tokenizer_v1", "speech_tokenizer_v1_25hz",
+     "speech_tokenizer_v2_25hz"
+ ])
+ def test_torch_vs_onnx_short_audio(model_name, test_audio_suite):
+     """Test torch vs ONNX for short audio (<= 30 s)."""
+     print(f"\n=== Testing {model_name} on Short Audio ===")
+
+     short_audios = {
+         k: v
+         for k, v in test_audio_suite.items() if v.shape[0] / 16000 <= 30
+     }
+
+     results = []
+     for audio_name, audio in short_audios.items():
+         result = compare_torch_vs_onnx_single(model_name, audio, audio_name)
+         results.append(result)
+
+         duration = result["duration"]
+         torch_tokens = result["torch_tokens"]
+         onnx_tokens = result["onnx_tokens"]
+         match_status = "✅" if result["results_match"] else "❌"
+
+         print(f"{match_status} {audio_name}: {duration:.1f}s → "
+               f"torch:{len(torch_tokens)}, onnx:{len(onnx_tokens)}")
+
+         if not result["results_match"] and "error" not in result:
+             print(f"  Miss rate: {result['miss_rate']:.2f}%")
+             print(f"  torch_tokens:\n{torch_tokens}\n"
+                   f"onnx_tokens:\n{onnx_tokens}")
+
+     # Assertions
+     successful_tests = [r for r in results if "error" not in r]
+     assert len(successful_tests) == len(short_audios), (
+         f"successful tests ({len(successful_tests)}) for {model_name} "
+         f"should equal the number of short audios ({len(short_audios)})")
+
+     # For short audio we expect a near-exact match
+     for r in results:
+         assert r['miss_rate'] < 0.5, (
+             f"Miss rate too high for {model_name}: {r['miss_rate']:.2f}%")
+
+     print(f"\n{model_name} Short Audio Summary:")
+     print(f"  Successful tests: {len(successful_tests)}/{len(results)}")
+
+
+ @pytest.mark.parametrize("model_name", [
+     "speech_tokenizer_v1", "speech_tokenizer_v1_25hz",
+     "speech_tokenizer_v2_25hz"
+ ])
+ def test_torch_vs_onnx_long_audio(model_name, test_audio_suite):
+     """Test torch vs ONNX for long audio (> 30 s) using the ONNX
+     sliding-window implementation."""
+     print(f"\n=== Testing {model_name} on Long Audio "
+           f"(ONNX Sliding Window) ===")
+
+     long_audios = {
+         k: v
+         for k, v in test_audio_suite.items() if v.shape[0] / 16000 > 30
+     }
+
+     results = []
+     for audio_name, audio in long_audios.items():
+         result = compare_torch_vs_onnx_single(model_name, audio, audio_name)
+         results.append(result)
+
+         duration = result["duration"]
+         torch_tokens = result["torch_tokens"]
+         onnx_tokens = result["onnx_tokens"]
+         match_status = "✅" if result["results_match"] else "❌"
+
+         print(f"{match_status} {audio_name}: {duration:.1f}s → "
+               f"torch:{len(torch_tokens)}, onnx:{len(onnx_tokens)}")
+
+         if not result["results_match"] and "error" not in result:
+             print(f"  Miss rate: {result['miss_rate']:.2f}%")
+             print(f"  torch_tokens:\n{torch_tokens}\n"
+                   f"onnx_tokens:\n{onnx_tokens}")
+         elif "error" in result:
+             print(f"  Error: {result['error'][:100]}...")
+
+     # Every long-audio clip must process successfully through ONNX
+     successful_tests = [r for r in results if "error" not in r]
+     assert len(successful_tests) == len(long_audios), (
+         f"successful tests ({len(successful_tests)}) for {model_name} "
+         f"should equal the number of long audios ({len(long_audios)})")
+
+     print(f"\n{model_name} Long Audio Results:")
+     print(f"  Total tests: {len(results)}")
+     print(f"  Successful ONNX tests: {len(successful_tests)}")
+
+     for r in results:
+         # NOTE(xcsong): a miss rate below 0.5% is acceptable for long
+         # audio, since we drop the last overlap part.
+         assert r['miss_rate'] < 0.5, (
+             f"Miss rate too high for {model_name}: {r['miss_rate']}%")
+
+     # The main requirement is that the torch path always works
+     print("  ✅ Torch processing works reliably for all long audio")
+
374
+
375
+ if __name__ == "__main__":
376
+ # Run tests with pytest
377
+ pytest.main([__file__, "-v"])
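+
+ # A usage note: pytest's -k filter can select a subset of the
+ # parametrized cases from the CLI, e.g.
+ #     pytest -v -k "short_audio and v2_25hz" <path-to-this-file>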