ArnoLiu commited on
Commit
d155c6a
·
verified ·
1 Parent(s): 0690e2b

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +27 -0
  2. .github/workflows/test.yml +20 -0
  3. .gitignore +201 -0
  4. .gitmodules +0 -0
  5. LICENSE +175 -0
  6. README.md +298 -3
  7. boson_multimodal/__init__.py +0 -0
  8. boson_multimodal/audio_processing/LICENSE +51 -0
  9. boson_multimodal/audio_processing/descriptaudiocodec/__init__.py +0 -0
  10. boson_multimodal/audio_processing/descriptaudiocodec/dac/model/base.py +286 -0
  11. boson_multimodal/audio_processing/descriptaudiocodec/dac/model/dac.py +365 -0
  12. boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/layers.py +33 -0
  13. boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/quantize.py +251 -0
  14. boson_multimodal/audio_processing/higgs_audio_tokenizer.py +329 -0
  15. boson_multimodal/audio_processing/quantization/__init__.py +8 -0
  16. boson_multimodal/audio_processing/quantization/ac.py +292 -0
  17. boson_multimodal/audio_processing/quantization/core_vq.py +360 -0
  18. boson_multimodal/audio_processing/quantization/core_vq_lsx_version.py +425 -0
  19. boson_multimodal/audio_processing/quantization/ddp_utils.py +197 -0
  20. boson_multimodal/audio_processing/quantization/distrib.py +123 -0
  21. boson_multimodal/audio_processing/quantization/vq.py +116 -0
  22. boson_multimodal/audio_processing/semantic_module.py +282 -0
  23. boson_multimodal/constants.py +3 -0
  24. boson_multimodal/data_collator/__init__.py +0 -0
  25. boson_multimodal/data_collator/higgs_audio_collator.py +509 -0
  26. boson_multimodal/data_types.py +38 -0
  27. boson_multimodal/dataset/__init__.py +0 -0
  28. boson_multimodal/dataset/chatml_dataset.py +533 -0
  29. boson_multimodal/model/__init__.py +0 -0
  30. boson_multimodal/model/higgs_audio/__init__.py +9 -0
  31. boson_multimodal/model/higgs_audio/audio_head.py +129 -0
  32. boson_multimodal/model/higgs_audio/common.py +27 -0
  33. boson_multimodal/model/higgs_audio/configuration_higgs_audio.py +235 -0
  34. boson_multimodal/model/higgs_audio/cuda_graph_runner.py +129 -0
  35. boson_multimodal/model/higgs_audio/custom_modules.py +155 -0
  36. boson_multimodal/model/higgs_audio/modeling_higgs_audio.py +0 -0
  37. boson_multimodal/model/higgs_audio/utils.py +756 -0
  38. boson_multimodal/serve/serve_engine.py +491 -0
  39. boson_multimodal/serve/utils.py +246 -0
  40. cmd.sh +35 -0
  41. examples/README.md +166 -0
  42. examples/generation.py +768 -0
  43. examples/interactive_generation.py +800 -0
  44. examples/scene_prompts/quiet_indoor.txt +1 -0
  45. examples/scene_prompts/reading_blog.txt +1 -0
  46. examples/serve_engine/README.md +25 -0
  47. examples/serve_engine/input_samples.py +87 -0
  48. examples/serve_engine/run_hf_example.py +48 -0
  49. examples/serve_engine/voice_examples/old_man.wav +3 -0
  50. examples/transcript/multi_speaker/en_argument.txt +4 -0
.gitattributes CHANGED
@@ -33,3 +33,30 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/serve_engine/voice_examples/old_man.wav filter=lfs diff=lfs merge=lfs -text
37
+ examples/voice_prompts/MYHHHHH.flac filter=lfs diff=lfs merge=lfs -text
38
+ examples/voice_prompts/MYHHHHH.wav filter=lfs diff=lfs merge=lfs -text
39
+ examples/voice_prompts/belinda.wav filter=lfs diff=lfs merge=lfs -text
40
+ examples/voice_prompts/bigbang_amy.wav filter=lfs diff=lfs merge=lfs -text
41
+ examples/voice_prompts/bigbang_sheldon.wav filter=lfs diff=lfs merge=lfs -text
42
+ examples/voice_prompts/broom_salesman.wav filter=lfs diff=lfs merge=lfs -text
43
+ examples/voice_prompts/en_man.wav filter=lfs diff=lfs merge=lfs -text
44
+ examples/voice_prompts/en_woman.wav filter=lfs diff=lfs merge=lfs -text
45
+ examples/voice_prompts/fiftyshades_anna.wav filter=lfs diff=lfs merge=lfs -text
46
+ examples/voice_prompts/mabaoguo.wav filter=lfs diff=lfs merge=lfs -text
47
+ examples/voice_prompts/xiaohei.wav filter=lfs diff=lfs merge=lfs -text
48
+ examples/voice_prompts/zh_man_sichuan.wav filter=lfs diff=lfs merge=lfs -text
49
+ figures/emergent-tts-emotions-win-rate.png filter=lfs diff=lfs merge=lfs -text
50
+ figures/higgs_audio_tokenizer_architecture.png filter=lfs diff=lfs merge=lfs -text
51
+ figures/higgs_audio_v2_architecture_combined.png filter=lfs diff=lfs merge=lfs -text
52
+ generation.wav filter=lfs diff=lfs merge=lfs -text
53
+ higgs-audio-v2-generation-3B-base/emergent-tts-emotions-win-rate.png filter=lfs diff=lfs merge=lfs -text
54
+ higgs-audio-v2-generation-3B-base/higgs_audio_tokenizer_architecture.png filter=lfs diff=lfs merge=lfs -text
55
+ higgs-audio-v2-generation-3B-base/higgs_audio_v2_architecture_combined.png filter=lfs diff=lfs merge=lfs -text
56
+ higgs-audio-v2-generation-3B-base/open_source_repo_demo.mp4 filter=lfs diff=lfs merge=lfs -text
57
+ higgs-audio-v2-generation-3B-base/tokenizer.json filter=lfs diff=lfs merge=lfs -text
58
+ higgs-audio-v2-tokenizer/higgs_audio_tokenizer_architecture.png filter=lfs diff=lfs merge=lfs -text
59
+ my_outputs/generation_001.wav filter=lfs diff=lfs merge=lfs -text
60
+ my_outputs/generation_002.wav filter=lfs diff=lfs merge=lfs -text
61
+ my_outputs/generation_003.wav filter=lfs diff=lfs merge=lfs -text
62
+ my_outputs/xiaohegeneration_002.wav filter=lfs diff=lfs merge=lfs -text
.github/workflows/test.yml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Unit Test
2
+ on:
3
+ push:
4
+ branches: [ main ]
5
+ pull_request:
6
+ branches: [ main ]
7
+
8
+ jobs:
9
+ lint:
10
+ name: Lint
11
+ runs-on: ubuntu-22.04
12
+ steps:
13
+ - name: Checkout code
14
+ uses: actions/checkout@v4
15
+
16
+ - name: Check Code Formatting with Ruff
17
+ run: |
18
+ echo "python version: $(python --version)"
19
+ pip install ruff==0.12.2 # Ensure ruff is installed
20
+ ruff format --check .
.gitignore ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Temporary files generated in training
2
+ dpo_samples*
3
+ scoring_results
4
+ results/
5
+ hf_slurm_logs/
6
+ slurm_results/
7
+ enroot_images/
8
+ slurm*.out
9
+ cache_*
10
+ mlruns/
11
+ local_download_dir/
12
+ audioverse/data
13
+ # the folder pattern is sft_{year}.
14
+ sft_20*
15
+ data/
16
+ audioverse/cache
17
+ # vim ipython plugin generated files
18
+ .jukit
19
+
20
+ # node
21
+ node_modules
22
+ package.json
23
+ package-lock.json
24
+
25
+ # Byte-compiled / optimized / DLL files
26
+ __pycache__/
27
+ *.py[cod]
28
+ *$py.class
29
+
30
+ # C extensions
31
+ *.so
32
+
33
+ # Distribution / packaging
34
+ .Python
35
+ build/
36
+ develop-eggs/
37
+ dist/
38
+ downloads/
39
+ eggs/
40
+ .eggs/
41
+ lib/
42
+ lib64/
43
+ parts/
44
+ sdist/
45
+ var/
46
+ wheels/
47
+ share/python-wheels/
48
+ *.egg-info/
49
+ .installed.cfg
50
+ *.egg
51
+ MANIFEST
52
+
53
+ # PyInstaller
54
+ # Usually these files are written by a python script from a template
55
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
56
+ *.manifest
57
+ *.spec
58
+
59
+ # Installer logs
60
+ pip-log.txt
61
+ pip-delete-this-directory.txt
62
+
63
+ # Unit test / coverage reports
64
+ !tests/*
65
+ htmlcov/
66
+ .tox/
67
+ .nox/
68
+ .coverage
69
+ .coverage.*
70
+ .cache
71
+ nosetests.xml
72
+ coverage.xml
73
+ *.cover
74
+ *.py,cover
75
+ .hypothesis/
76
+ .pytest_cache/
77
+ cover/
78
+
79
+ # Translations
80
+ *.mo
81
+ *.pot
82
+
83
+ # Django stuff:
84
+ *.log
85
+ local_settings.py
86
+ db.sqlite3
87
+ db.sqlite3-journal
88
+
89
+ # Flask stuff:
90
+ instance/
91
+ .webassets-cache
92
+
93
+ # Scrapy stuff:
94
+ .scrapy
95
+
96
+ # Sphinx documentation
97
+ docs/_build/
98
+
99
+ # PyBuilder
100
+ .pybuilder/
101
+ target/
102
+
103
+ # Jupyter Notebook
104
+ .ipynb_checkpoints
105
+
106
+ # IPython
107
+ profile_default/
108
+ ipython_config.py
109
+
110
+ # pyenv
111
+ # For a library or package, you might want to ignore these files since the code is
112
+ # intended to run in multiple environments; otherwise, check them in:
113
+ # .python-version
114
+
115
+ # pipenv
116
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
117
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
118
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
119
+ # install all needed dependencies.
120
+ #Pipfile.lock
121
+
122
+ # poetry
123
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
124
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
125
+ # commonly ignored for libraries.
126
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
127
+ #poetry.lock
128
+
129
+ # pdm
130
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
131
+ #pdm.lock
132
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
133
+ # in version control.
134
+ # https://pdm.fming.dev/#use-with-ide
135
+ .pdm.toml
136
+
137
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
138
+ __pypackages__/
139
+
140
+ # Celery stuff
141
+ celerybeat-schedule
142
+ celerybeat.pid
143
+
144
+ # SageMath parsed files
145
+ *.sage.py
146
+
147
+ # Environments
148
+ /.conda_env*
149
+ /.env*
150
+ /.higgs_audio_env*
151
+ /.venv*
152
+ /conda_env*
153
+ /env*
154
+ /ENV*
155
+ /higgs_audio_env*
156
+ /venv*
157
+
158
+ # Spyder project settings
159
+ .spyderproject
160
+ .spyproject
161
+
162
+ # Rope project settings
163
+ .ropeproject
164
+
165
+ # mkdocs documentation
166
+ /site
167
+
168
+ # mypy
169
+ .mypy_cache/
170
+ .dmypy.json
171
+ dmypy.json
172
+
173
+ # Pyre type checker
174
+ .pyre/
175
+
176
+ # pytype static type analyzer
177
+ .pytype/
178
+
179
+ # Cython debug symbols
180
+ cython_debug/
181
+
182
+ # PyCharm
183
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
184
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
185
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
186
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
187
+ #.idea/
188
+
189
+ *.jsonl
190
+ download
191
+ .DS_Store
192
+ *entry.py
193
+
194
+ # Pytorch
195
+ torch_compile_debug/
196
+
197
+ # Out Dir
198
+ result/
199
+
200
+ # Ruff
201
+ .ruff_cache/
.gitmodules ADDED
File without changes
LICENSE ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
README.md CHANGED
@@ -1,3 +1,298 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <h1 align="center">Higgs Audio V2: Redefining Expressiveness in Audio Generation</h1>
2
+
3
+ <div align="center" style="display: flex; justify-content: center; margin-top: 10px;">
4
+ <a href="https://boson.ai/blog/higgs-audio-v2"><img src='https://img.shields.io/badge/🚀-Launch Blogpost-228B22' style="margin-right: 5px;"></a>
5
+ <a href="https://boson.ai/demo/tts"><img src="https://img.shields.io/badge/🕹️-Boson%20AI%20Playground-9C276A" style="margin-right: 5px;"></a>
6
+ <a href="https://huggingface.co/spaces/smola/higgs_audio_v2"><img src="https://img.shields.io/badge/🎮-HF%20Space%20Playground-8A2BE2" style="margin-right: 5px;"></a>
7
+ <a href="https://huggingface.co/bosonai/higgs-audio-v2-generation-3B-base"><img src="https://img.shields.io/badge/🤗-Checkpoints (3.6B LLM + 2.2B audio adapter)-ED5A22.svg" style="margin-right: 5px;"></a>
8
+ </div>
9
+
10
+
11
+ We are open-sourcing Higgs Audio v2, a powerful audio foundation model pretrained on over 10 million hours of audio data and a diverse set of text data. Despite having no post-training or fine-tuning, Higgs Audio v2 excels in expressive audio generation, thanks to its deep language and acoustic understanding.
12
+
13
+ On [EmergentTTS-Eval](https://github.com/boson-ai/emergenttts-eval-public), it achieves win rates of **75.7%** and **55.7%** over "gpt-4o-mini-tts" on the "Emotions" and "Questions" categories, respectively. It also obtains state-of-the-art performance on traditional TTS benchmarks like Seed-TTS Eval and Emotional Speech Dataset (ESD). Moreover, the model demonstrates capabilities rarely seen in previous systems, including generating natural multi-speaker dialogues in multiple languages, automatic prosody adaptation during narration, melodic humming with the cloned voice, and simultaneous generation of speech and background music.
14
+
15
+ <p align="center">
16
+ <img src="figures/emergent-tts-emotions-win-rate.png" width=900>
17
+ </p>
18
+
19
+ Here's the demo video that shows some of its emergent capabilities (remember to unmute):
20
+
21
+ <video src="https://github.com/user-attachments/assets/0fd73fad-097f-48a9-9f3f-bc2a63b3818d" type="video/mp4" width="80%" controls>
22
+ </video>
23
+
24
+ Here's another demo video that showcases the model's multilingual capability and how it enables live translation (remember to unmute):
25
+
26
+ <video src="https://github.com/user-attachments/assets/2b9b01ff-67fc-4bd9-9714-7c7df09e38d6" type="video/mp4" width="80%" controls>
27
+ </video>
28
+
29
+ ## Installation
30
+
31
+ We recommend using an NVIDIA Deep Learning Container to manage the CUDA environment. The following are two Docker images that we have verified:
32
+ - nvcr.io/nvidia/pytorch:25.02-py3
33
+ - nvcr.io/nvidia/pytorch:25.01-py3
34
+
35
+ Here's an example command for launching a docker container environment. Please also check the [official NVIDIA documentation](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch).
36
+
37
+ ```bash
38
+ docker run --gpus all --ipc=host --net=host --ulimit memlock=-1 --ulimit stack=67108864 -it --rm nvcr.io/nvidia/pytorch:25.02-py3 bash
39
+ ```
40
+
41
+ ### Option 1: Direct installation
42
+
43
+
44
+ ```bash
45
+ git clone https://github.com/boson-ai/higgs-audio.git
46
+ cd higgs-audio
47
+
48
+ pip install -r requirements.txt
49
+ pip install -e .
50
+ ```
51
+
52
+ ### Option 2: Using venv
53
+
54
+ ```bash
55
+ git clone https://github.com/boson-ai/higgs-audio.git
56
+ cd higgs-audio
57
+
58
+ python3 -m venv higgs_audio_env
59
+ source higgs_audio_env/bin/activate
60
+ pip install -r requirements.txt
61
+ pip install -e .
62
+ ```
63
+
64
+
65
+ ### Option 3: Using conda
66
+ ```bash
67
+ git clone https://github.com/boson-ai/higgs-audio.git
68
+ cd higgs-audio
69
+
70
+ conda create -y --prefix ./conda_env --override-channels --strict-channel-priority --channel "conda-forge" "python==3.10.*"
71
+ conda activate ./conda_env
72
+ pip install -r requirements.txt
73
+ pip install -e .
74
+
75
+ # Uninstalling environment:
76
+ conda deactivate
77
+ conda remove -y --prefix ./conda_env --all
78
+ ```
79
+
80
+ ### Option 4: Using uv
81
+ ```bash
82
+ git clone https://github.com/boson-ai/higgs-audio.git
83
+ cd higgs-audio
84
+
85
+ uv venv --python 3.10
86
+ source .venv/bin/activate
87
+ uv pip install -r requirements.txt
88
+ uv pip install -e .
89
+ ```
90
+
91
+ ### Option 5: Using vllm
92
+
93
+ For advanced usage with higher throughput, we also built OpenAI compatible API server backed by vLLM engine for you to use.
94
+ Please refer to [examples/vllm](./examples/vllm) for more details.
95
+
96
+
97
+ ## Usage
98
+
99
+ > [!TIP]
100
+ > For optimal performance, run the generation examples on a machine equipped with GPU with at least 24GB memory!
101
+
102
+ ### Get Started
103
+
104
+ Here's a basic python snippet to help you get started.
105
+
106
+ ```python
107
+ from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse
108
+ from boson_multimodal.data_types import ChatMLSample, Message, AudioContent
109
+
110
+ import torch
111
+ import torchaudio
112
+ import time
113
+ import click
114
+
115
+ MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
116
+ AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"
117
+
118
+ system_prompt = (
119
+ "Generate audio following instruction.\n\n<|scene_desc_start|>\nAudio is recorded from a quiet room.\n<|scene_desc_end|>"
120
+ )
121
+
122
+ messages = [
123
+ Message(
124
+ role="system",
125
+ content=system_prompt,
126
+ ),
127
+ Message(
128
+ role="user",
129
+ content="The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years.",
130
+ ),
131
+ ]
132
+ device = "cuda" if torch.cuda.is_available() else "cpu"
133
+
134
+ serve_engine = HiggsAudioServeEngine(MODEL_PATH, AUDIO_TOKENIZER_PATH, device=device)
135
+
136
+ output: HiggsAudioResponse = serve_engine.generate(
137
+ chat_ml_sample=ChatMLSample(messages=messages),
138
+ max_new_tokens=1024,
139
+ temperature=0.3,
140
+ top_p=0.95,
141
+ top_k=50,
142
+ stop_strings=["<|end_of_text|>", "<|eot_id|>"],
143
+ )
144
+ torchaudio.save("output.wav", torch.from_numpy(output.audio)[None, :], output.sampling_rate)
145
+ ```
146
+
147
+ We also provide a list of examples under [examples](./examples). In the following we highlight a few examples to help you use Higgs Audio v2.
148
+
149
+ ### Zero-Shot Voice Cloning
150
+ Generate audio that sounds similar to the provided [reference audio](./examples/voice_prompts/belinda.wav).
151
+
152
+ ```bash
153
+ python3 examples/generation.py \
154
+ --transcript "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years." \
155
+ --ref_audio belinda \
156
+ --temperature 0.3 \
157
+ --out_path generation.wav
158
+ ```
159
+
160
+ The generation script will automatically use `cuda:0` if it finds CUDA available. To change the device ID, specify `--device_id`:
161
+
162
+ ```bash
163
+ python3 examples/generation.py \
164
+ --transcript "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years." \
165
+ --ref_audio belinda \
166
+ --temperature 0.3 \
167
+ --device_id 0 \
168
+ --out_path generation.wav
169
+ ```
170
+
171
+ You can also try other voices. Check more example voices in [examples/voice_prompts](./examples/voice_prompts). You can also add your own voice to the folder.
172
+
173
+ ```bash
174
+ python3 examples/generation.py \
175
+ --transcript "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years." \
176
+ --ref_audio broom_salesman \
177
+ --temperature 0.3 \
178
+ --out_path generation.wav
179
+ ```
180
+
181
+ ### Single-speaker Generation with Smart Voice
182
+ If you do not specify a reference voice, the model will decide the voice based on the transcript it sees.
183
+
184
+ ```bash
185
+ python3 examples/generation.py \
186
+ --transcript "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years." \
187
+ --temperature 0.3 \
188
+ --out_path generation.wav
189
+ ```
190
+
191
+
192
+ ### Multi-speaker Dialog with Smart Voice
193
+ Generate multi-speaker dialog. The model will decide the voices based on the transcript it sees.
194
+
195
+ ```bash
196
+ python3 examples/generation.py \
197
+ --transcript examples/transcript/multi_speaker/en_argument.txt \
198
+ --seed 12345 \
199
+ --out_path generation.wav
200
+ ```
201
+
202
+ ### Multi-speaker Dialog with Voice Clone
203
+
204
+ Generate multi-speaker dialog with the voices you picked.
205
+
206
+ ```bash
207
+ python3 examples/generation.py \
208
+ --transcript examples/transcript/multi_speaker/en_argument.txt \
209
+ --ref_audio belinda,broom_salesman \
210
+ --ref_audio_in_system_message \
211
+ --chunk_method speaker \
212
+ --seed 12345 \
213
+ --out_path generation.wav
214
+ ```
215
+
216
+
217
+ ## Technical Details
218
+ <img src="figures/higgs_audio_v2_architecture_combined.png" width=900>
219
+
220
+
221
+ Higgs Audio v2 adopts the "generation variant" depicted in the architecture figure above. Its strong performance is driven by three key technical innovations:
222
+ - We developed an automated annotation pipeline that leverages multiple ASR models, sound event classification models, and our in-house audio understanding model. Using this pipeline, we cleaned and annotated 10 million hours of audio data, which we refer to as **AudioVerse**. The in-house understanding model is finetuned on top of [Higgs Audio v1 Understanding](https://www.boson.ai/blog/higgs-audio), which adopts the "understanding variant" shown in the architecture figure.
223
+ - We trained a unified audio tokenizer from scratch that captures both semantic and acoustic features. We also open-sourced our evaluation set on [HuggingFace](https://huggingface.co/datasets/bosonai/AudioTokenBench). Learn more in the [tokenizer blog](./tech_blogs/TOKENIZER_BLOG.md).
224
+ - We proposed the DualFFN architecture, which enhances the LLM’s ability to model acoustics tokens with minimal computational overhead. See the [architecture blog](./tech_blogs/ARCHITECTURE_BLOG.md).
225
+
226
+ ## Evaluation
227
+
228
+ Here's the performance of Higgs Audio v2 on four benchmarks, [Seed-TTS Eval](https://github.com/BytedanceSpeech/seed-tts-eval), [Emotional Speech Dataset (ESD)](https://paperswithcode.com/dataset/esd), [EmergentTTS-Eval](https://arxiv.org/abs/2505.23009), and Multi-speaker Eval:
229
+
230
+ #### Seed-TTS Eval & ESD
231
+
232
+ We prompt Higgs Audio v2 with the reference text, reference audio, and target text for zero-shot TTS. We use the standard evaluation metrics from Seed-TTS Eval and ESD.
233
+
234
+ | | SeedTTS-Eval| | ESD | |
235
+ |------------------------------|--------|--------|---------|-------------------|
236
+ | | WER ↓ | SIM ↑ | WER ↓ | SIM (emo2vec) ↑ |
237
+ | Cosyvoice2 | 2.28 | 65.49 | 2.71 | 80.48 |
238
+ | Qwen2.5-omni† | 2.33 | 64.10 | - | - |
239
+ | ElevenLabs Multilingual V2 | **1.43** | 50.00 | 1.66 | 65.87 |
240
+ | Higgs Audio v1 | 2.18 | 66.27 | **1.49** | 82.84 |
241
+ | Higgs Audio v2 (base) | 2.44 | **67.70** | 1.78 | **86.13** |
242
+
243
+
244
+ #### EmergentTTS-Eval ("Emotions" and "Questions")
245
+
246
+ Following the [EmergentTTS-Eval Paper](https://arxiv.org/abs/2505.23009), we report the win-rate over "gpt-4o-mini-tts" with the "alloy" voice. The judge model is Gemini 2.5 Pro.
247
+
248
+ | Model | Emotions (%) ↑ | Questions (%) ↑ |
249
+ |------------------------------------|--------------|----------------|
250
+ | Higgs Audio v2 (base) | **75.71%** | **55.71%** |
251
+ | [gpt-4o-audio-preview†](https://platform.openai.com/docs/models/gpt-4o-audio-preview) | 61.64% | 47.85% |
252
+ | [Hume.AI](https://www.hume.ai/research) | 61.60% | 43.21% |
253
+ | **BASELINE:** [gpt-4o-mini-tts](https://platform.openai.com/docs/models/gpt-4o-mini-tts) | 50.00% | 50.00% |
254
+ | [Qwen 2.5 Omni†](https://github.com/QwenLM/Qwen2.5-Omni) | 41.60% | 51.78% |
255
+ | [minimax/speech-02-hd](https://replicate.com/minimax/speech-02-hd) | 40.86% | 47.32% |
256
+ | [ElevenLabs Multilingual v2](https://elevenlabs.io/blog/eleven-multilingual-v2) | 30.35% | 39.46% |
257
+ | [DeepGram Aura-2](https://deepgram.com/learn/introducing-aura-2-enterprise-text-to-speech) | 29.28% | 48.21% |
258
+ | [Sesame csm-1B](https://github.com/SesameAILabs/csm) | 15.96% | 31.78% |
259
+
260
+ <sup><sub>'†' means using the strong-prompting method described in the paper.</sub></sup>
261
+
262
+
263
+ #### Multi-speaker Eval
264
+
265
+ We also designed a multi-speaker evaluation benchmark to evaluate the capability of Higgs Audio v2 for multi-speaker dialog generation. The benchmark contains three subsets:
266
+
267
+ - `two-speaker-conversation`: 1000 synthetic dialogues involving two speakers. We fix two reference audio clips to evaluate the model's ability in double voice cloning for utterances ranging from 4 to 10 dialogues between two randomly chosen personas.
268
+ - `small talk (no ref)`: 250 synthetic dialogues curated in the same way as above, but characterized by short utterances and a limited number of turns (4–6). We do not fix reference audios in this case; this set is designed to evaluate the model's ability to automatically assign appropriate voices to speakers.
269
+ - `small talk (ref)`: 250 synthetic dialogues similar to the above, but with even shorter utterances, as this set is meant to include reference clips in its context, similar to `two-speaker-conversation`.
270
+
271
+
272
+ We report the word-error-rate (WER) and the geometric mean between intra-speaker similarity and inter-speaker dis-similarity on these three subsets. Other than Higgs Audio v2, we also evaluated [MoonCast](https://github.com/jzq2000/MoonCast) and [nari-labs/Dia-1.6B-0626](https://huggingface.co/nari-labs/Dia-1.6B-0626), two of the most popular open-source models capable of multi-speaker dialog generation. Results are summarized in the following table. We are not able to run [nari-labs/Dia-1.6B-0626](https://huggingface.co/nari-labs/Dia-1.6B-0626) on our "two-speaker-conversation" subset due to its strict limitation on the length of the utterances and output audio.
273
+
274
+ | | two-speaker-conversation | | small talk (ref) | | small talk (no ref) | |
275
+ | ---------------------------------------------- | -------------- | ------------------ | ---------- | -------------- | ------------------- | -------------- |
276
+ | | WER ↓ | Mean Sim & Dis-sim ↑ | WER ↓ | Mean Sim & Dis-sim ↑ | WER ↓ | Mean Sim & Dis-sim ↑ |
277
+ | [MoonCast](https://github.com/jzq2000/MoonCast) | 38.77 | 46.02 | **8.33** | 63.68 | 24.65 | 53.94 |
278
+ | [nari-labs/Dia-1.6B-0626](https://huggingface.co/nari-labs/Dia-1.6B-0626) | \- | \- | 17.62 | 63.15 | 19.46 | **61.14** |
279
+ | Higgs Audio v2 (base) | **18.88** | **51.95** | 11.89 | **67.92** | **14.65** | 55.28 |
280
+
281
+
282
+ ## Citation
283
+
284
+ If you feel the repository is helpful, please kindly cite as:
285
+
286
+ ```
287
+ @misc{higgsaudio2025,
288
+ author = {{Boson AI}},
289
+ title = {{Higgs Audio V2: Redefining Expressiveness in Audio Generation}},
290
+ year = {2025},
291
+ howpublished = {\url{https://github.com/boson-ai/higgs-audio}},
292
+ note = {GitHub repository. Release blog available at \url{https://www.boson.ai/blog/higgs-audio-v2}},
293
+ }
294
+ ```
295
+
296
+ ## Third-Party Licenses
297
+
298
+ The `boson_multimodal/audio_processing/` directory contains code derived from third-party repositories, primarily from [xcodec](https://github.com/zhenye234/xcodec). Please see the [`LICENSE`](boson_multimodal/audio_processing/LICENSE) in that directory for complete attribution and licensing information.
boson_multimodal/__init__.py ADDED
File without changes
boson_multimodal/audio_processing/LICENSE ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Third-Party License Attribution for Audio Processing Module
2
+ ===========================================================
3
+
4
+ This directory contains code derived from multiple open-source projects.
5
+ The following sections detail the licenses and attributions for third-party code.
6
+
7
+ ## XCodec Repository
8
+ The code in this directory is derived from:
9
+ https://github.com/zhenye234/xcodec
10
+
11
+ ## Individual File Attributions
12
+
13
+ ### Quantization Module (quantization/)
14
+ - Several files contain code derived from Meta Platforms, Inc. and the vector-quantize-pytorch repository
15
+ - Individual files contain their own license headers where applicable
16
+ - The vector-quantize-pytorch portions are licensed under the MIT License
17
+
18
+ ## License Terms
19
+
20
+ ### MIT License (for applicable portions)
21
+ Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ of this software and associated documentation files (the "Software"), to deal
23
+ in the Software without restriction, including without limitation the rights
24
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ copies of the Software, and to permit persons to whom the Software is
26
+ furnished to do so, subject to the following conditions:
27
+
28
+ The above copyright notice and this permission notice shall be included in all
29
+ copies or substantial portions of the Software.
30
+
31
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ SOFTWARE.
38
+
39
+ ## Attribution Requirements
40
+ When using this code, please ensure proper attribution to:
41
+ 1. The original xcodec repository: https://github.com/zhenye234/xcodec
42
+ 2. Any other repositories mentioned in individual file headers
43
+ 3. This derivative work and its modifications
44
+
45
+ ## Disclaimer
46
+ This directory contains modified versions of the original code. Please refer to
47
+ the original repositories for the canonical implementations and their specific
48
+ license terms.
49
+
50
+ For any questions about licensing or attribution, please check the individual
51
+ file headers and the original source repositories.
boson_multimodal/audio_processing/descriptaudiocodec/__init__.py ADDED
File without changes
boson_multimodal/audio_processing/descriptaudiocodec/dac/model/base.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import tqdm
9
+ from audiotools import AudioSignal
10
+ from torch import nn
11
+
12
+ SUPPORTED_VERSIONS = ["1.0.0"]
13
+
14
+
15
+ @dataclass
16
+ class DACFile:
17
+ codes: torch.Tensor
18
+
19
+ # Metadata
20
+ chunk_length: int
21
+ original_length: int
22
+ input_db: float
23
+ channels: int
24
+ sample_rate: int
25
+ padding: bool
26
+ dac_version: str
27
+
28
+ def save(self, path):
29
+ artifacts = {
30
+ "codes": self.codes.numpy().astype(np.uint16),
31
+ "metadata": {
32
+ "input_db": self.input_db.numpy().astype(np.float32),
33
+ "original_length": self.original_length,
34
+ "sample_rate": self.sample_rate,
35
+ "chunk_length": self.chunk_length,
36
+ "channels": self.channels,
37
+ "padding": self.padding,
38
+ "dac_version": SUPPORTED_VERSIONS[-1],
39
+ },
40
+ }
41
+ path = Path(path).with_suffix(".dac")
42
+ with open(path, "wb") as f:
43
+ np.save(f, artifacts)
44
+ return path
45
+
46
+ @classmethod
47
+ def load(cls, path):
48
+ artifacts = np.load(path, allow_pickle=True)[()]
49
+ codes = torch.from_numpy(artifacts["codes"].astype(int))
50
+ if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
51
+ raise RuntimeError(f"Given file {path} can't be loaded with this version of descript-audio-codec.")
52
+ return cls(codes=codes, **artifacts["metadata"])
53
+
54
+
55
+ class CodecMixin:
56
+ @property
57
+ def padding(self):
58
+ if not hasattr(self, "_padding"):
59
+ self._padding = True
60
+ return self._padding
61
+
62
+ @padding.setter
63
+ def padding(self, value):
64
+ assert isinstance(value, bool)
65
+
66
+ layers = [l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))]
67
+
68
+ for layer in layers:
69
+ if value:
70
+ if hasattr(layer, "original_padding"):
71
+ layer.padding = layer.original_padding
72
+ else:
73
+ layer.original_padding = layer.padding
74
+ layer.padding = tuple(0 for _ in range(len(layer.padding)))
75
+
76
+ self._padding = value
77
+
78
+ def get_delay(self):
79
+ # Any number works here, delay is invariant to input length
80
+ l_out = self.get_output_length(0)
81
+ L = l_out
82
+
83
+ layers = []
84
+ for layer in self.modules():
85
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
86
+ layers.append(layer)
87
+
88
+ for layer in reversed(layers):
89
+ d = layer.dilation[0]
90
+ k = layer.kernel_size[0]
91
+ s = layer.stride[0]
92
+
93
+ if isinstance(layer, nn.ConvTranspose1d):
94
+ L = ((L - d * (k - 1) - 1) / s) + 1
95
+ elif isinstance(layer, nn.Conv1d):
96
+ L = (L - 1) * s + d * (k - 1) + 1
97
+
98
+ L = math.ceil(L)
99
+
100
+ l_in = L
101
+
102
+ return (l_in - l_out) // 2
103
+
104
+ def get_output_length(self, input_length):
105
+ L = input_length
106
+ # Calculate output length
107
+ for layer in self.modules():
108
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
109
+ d = layer.dilation[0]
110
+ k = layer.kernel_size[0]
111
+ s = layer.stride[0]
112
+
113
+ if isinstance(layer, nn.Conv1d):
114
+ L = ((L - d * (k - 1) - 1) / s) + 1
115
+ elif isinstance(layer, nn.ConvTranspose1d):
116
+ L = (L - 1) * s + d * (k - 1) + 1
117
+
118
+ L = math.floor(L)
119
+ return L
120
+
121
+ @torch.no_grad()
122
+ def compress(
123
+ self,
124
+ audio_path_or_signal: Union[str, Path, AudioSignal],
125
+ win_duration: float = 1.0,
126
+ verbose: bool = False,
127
+ normalize_db: float = -16,
128
+ n_quantizers: int = None,
129
+ ) -> DACFile:
130
+ """Processes an audio signal from a file or AudioSignal object into
131
+ discrete codes. This function processes the signal in short windows,
132
+ using constant GPU memory.
133
+
134
+ Parameters
135
+ ----------
136
+ audio_path_or_signal : Union[str, Path, AudioSignal]
137
+ audio signal to reconstruct
138
+ win_duration : float, optional
139
+ window duration in seconds, by default 5.0
140
+ verbose : bool, optional
141
+ by default False
142
+ normalize_db : float, optional
143
+ normalize db, by default -16
144
+
145
+ Returns
146
+ -------
147
+ DACFile
148
+ Object containing compressed codes and metadata
149
+ required for decompression
150
+ """
151
+ audio_signal = audio_path_or_signal
152
+ if isinstance(audio_signal, (str, Path)):
153
+ audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
154
+
155
+ self.eval()
156
+ original_padding = self.padding
157
+ original_device = audio_signal.device
158
+
159
+ audio_signal = audio_signal.clone()
160
+ original_sr = audio_signal.sample_rate
161
+
162
+ resample_fn = audio_signal.resample
163
+ loudness_fn = audio_signal.loudness
164
+
165
+ # If audio is > 10 minutes long, use the ffmpeg versions
166
+ if audio_signal.signal_duration >= 10 * 60 * 60:
167
+ resample_fn = audio_signal.ffmpeg_resample
168
+ loudness_fn = audio_signal.ffmpeg_loudness
169
+
170
+ original_length = audio_signal.signal_length
171
+ resample_fn(self.sample_rate)
172
+ input_db = loudness_fn()
173
+
174
+ if normalize_db is not None:
175
+ audio_signal.normalize(normalize_db)
176
+ audio_signal.ensure_max_of_audio()
177
+
178
+ nb, nac, nt = audio_signal.audio_data.shape
179
+ audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
180
+ win_duration = audio_signal.signal_duration if win_duration is None else win_duration
181
+
182
+ if audio_signal.signal_duration <= win_duration:
183
+ # Unchunked compression (used if signal length < win duration)
184
+ self.padding = True
185
+ n_samples = nt
186
+ hop = nt
187
+ else:
188
+ # Chunked inference
189
+ self.padding = False
190
+ # Zero-pad signal on either side by the delay
191
+ audio_signal.zero_pad(self.delay, self.delay)
192
+ n_samples = int(win_duration * self.sample_rate)
193
+ # Round n_samples to nearest hop length multiple
194
+ n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
195
+ hop = self.get_output_length(n_samples)
196
+
197
+ codes = []
198
+ range_fn = range if not verbose else tqdm.trange
199
+
200
+ for i in range_fn(0, nt, hop):
201
+ x = audio_signal[..., i : i + n_samples]
202
+ x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
203
+
204
+ audio_data = x.audio_data.to(self.device)
205
+ audio_data = self.preprocess(audio_data, self.sample_rate)
206
+ _, c, _, _, _ = self.encode(audio_data, n_quantizers)
207
+ codes.append(c.to(original_device))
208
+ chunk_length = c.shape[-1]
209
+
210
+ codes = torch.cat(codes, dim=-1)
211
+
212
+ dac_file = DACFile(
213
+ codes=codes,
214
+ chunk_length=chunk_length,
215
+ original_length=original_length,
216
+ input_db=input_db,
217
+ channels=nac,
218
+ sample_rate=original_sr,
219
+ padding=self.padding,
220
+ dac_version=SUPPORTED_VERSIONS[-1],
221
+ )
222
+
223
+ if n_quantizers is not None:
224
+ codes = codes[:, :n_quantizers, :]
225
+
226
+ self.padding = original_padding
227
+ return dac_file
228
+
229
+ @torch.no_grad()
230
+ def decompress(
231
+ self,
232
+ obj: Union[str, Path, DACFile],
233
+ verbose: bool = False,
234
+ ) -> AudioSignal:
235
+ """Reconstruct audio from a given .dac file
236
+
237
+ Parameters
238
+ ----------
239
+ obj : Union[str, Path, DACFile]
240
+ .dac file location or corresponding DACFile object.
241
+ verbose : bool, optional
242
+ Prints progress if True, by default False
243
+
244
+ Returns
245
+ -------
246
+ AudioSignal
247
+ Object with the reconstructed audio
248
+ """
249
+ self.eval()
250
+ if isinstance(obj, (str, Path)):
251
+ obj = DACFile.load(obj)
252
+
253
+ original_padding = self.padding
254
+ self.padding = obj.padding
255
+
256
+ range_fn = range if not verbose else tqdm.trange
257
+ codes = obj.codes
258
+ original_device = codes.device
259
+ chunk_length = obj.chunk_length
260
+ recons = []
261
+
262
+ for i in range_fn(0, codes.shape[-1], chunk_length):
263
+ c = codes[..., i : i + chunk_length].to(self.device)
264
+ z = self.quantizer.from_codes(c)[0]
265
+ r = self.decode(z)
266
+ recons.append(r.to(original_device))
267
+
268
+ recons = torch.cat(recons, dim=-1)
269
+ recons = AudioSignal(recons, self.sample_rate)
270
+
271
+ resample_fn = recons.resample
272
+ loudness_fn = recons.loudness
273
+
274
+ # If audio is > 10 minutes long, use the ffmpeg versions
275
+ if recons.signal_duration >= 10 * 60 * 60:
276
+ resample_fn = recons.ffmpeg_resample
277
+ loudness_fn = recons.ffmpeg_loudness
278
+
279
+ recons.normalize(obj.input_db)
280
+ resample_fn(obj.sample_rate)
281
+ recons = recons[..., : obj.original_length]
282
+ loudness_fn()
283
+ recons.audio_data = recons.audio_data.reshape(-1, obj.channels, obj.original_length)
284
+
285
+ self.padding = original_padding
286
+ return recons
boson_multimodal/audio_processing/descriptaudiocodec/dac/model/dac.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from audiotools.ml import BaseModel
9
+ from torch import nn
10
+
11
+ from .base import CodecMixin
12
+ from dac.nn.layers import Snake1d
13
+ from dac.nn.layers import WNConv1d
14
+ from dac.nn.layers import WNConvTranspose1d
15
+ from dac.nn.quantize import ResidualVectorQuantize
16
+
17
+
18
+ def init_weights(m):
19
+ if isinstance(m, nn.Conv1d):
20
+ nn.init.trunc_normal_(m.weight, std=0.02)
21
+ nn.init.constant_(m.bias, 0)
22
+
23
+
24
+ class ResidualUnit(nn.Module):
25
+ def __init__(self, dim: int = 16, dilation: int = 1):
26
+ super().__init__()
27
+ pad = ((7 - 1) * dilation) // 2
28
+ self.block = nn.Sequential(
29
+ Snake1d(dim),
30
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
31
+ Snake1d(dim),
32
+ WNConv1d(dim, dim, kernel_size=1),
33
+ )
34
+
35
+ def forward(self, x):
36
+ y = self.block(x)
37
+ pad = (x.shape[-1] - y.shape[-1]) // 2
38
+ if pad > 0:
39
+ x = x[..., pad:-pad]
40
+ return x + y
41
+
42
+
43
+ class EncoderBlock(nn.Module):
44
+ def __init__(self, dim: int = 16, stride: int = 1):
45
+ super().__init__()
46
+ self.block = nn.Sequential(
47
+ ResidualUnit(dim // 2, dilation=1),
48
+ ResidualUnit(dim // 2, dilation=3),
49
+ ResidualUnit(dim // 2, dilation=9),
50
+ Snake1d(dim // 2),
51
+ WNConv1d(
52
+ dim // 2,
53
+ dim,
54
+ kernel_size=2 * stride,
55
+ stride=stride,
56
+ padding=math.ceil(stride / 2),
57
+ ),
58
+ )
59
+
60
+ def forward(self, x):
61
+ return self.block(x)
62
+
63
+
64
+ class Encoder(nn.Module):
65
+ def __init__(
66
+ self,
67
+ d_model: int = 64,
68
+ strides: list = [2, 4, 8, 8],
69
+ d_latent: int = 256,
70
+ ):
71
+ super().__init__()
72
+ # Create first convolution
73
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
74
+
75
+ # Create EncoderBlocks that double channels as they downsample by `stride`
76
+ for stride in strides:
77
+ d_model *= 2
78
+ self.block += [EncoderBlock(d_model, stride=stride)]
79
+
80
+ # Create last convolution
81
+ self.block += [
82
+ Snake1d(d_model),
83
+ WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
84
+ ]
85
+
86
+ # Wrap black into nn.Sequential
87
+ self.block = nn.Sequential(*self.block)
88
+ self.enc_dim = d_model
89
+
90
+ def forward(self, x):
91
+ return self.block(x)
92
+
93
+
94
+ class DecoderBlock(nn.Module):
95
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, out_pad=0):
96
+ super().__init__()
97
+ self.block = nn.Sequential(
98
+ Snake1d(input_dim),
99
+ WNConvTranspose1d(
100
+ input_dim,
101
+ output_dim,
102
+ kernel_size=2 * stride,
103
+ stride=stride,
104
+ padding=math.ceil(stride / 2),
105
+ output_padding=stride % 2, # out_pad,
106
+ ),
107
+ ResidualUnit(output_dim, dilation=1),
108
+ ResidualUnit(output_dim, dilation=3),
109
+ ResidualUnit(output_dim, dilation=9),
110
+ )
111
+
112
+ def forward(self, x):
113
+ return self.block(x)
114
+
115
+
116
+ class Decoder(nn.Module):
117
+ def __init__(
118
+ self,
119
+ input_channel,
120
+ channels,
121
+ rates,
122
+ d_out: int = 1,
123
+ ):
124
+ super().__init__()
125
+
126
+ # Add first conv layer
127
+ layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
128
+
129
+ # Add upsampling + MRF blocks
130
+ for i, stride in enumerate(rates):
131
+ input_dim = channels // 2**i
132
+ output_dim = channels // 2 ** (i + 1)
133
+ if i == 1:
134
+ out_pad = 1
135
+ else:
136
+ out_pad = 0
137
+ layers += [DecoderBlock(input_dim, output_dim, stride, out_pad)]
138
+
139
+ # Add final conv layer
140
+ layers += [
141
+ Snake1d(output_dim),
142
+ WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
143
+ # nn.Tanh(),
144
+ ]
145
+
146
+ self.model = nn.Sequential(*layers)
147
+
148
+ def forward(self, x):
149
+ return self.model(x)
150
+
151
+
152
+ class DAC(BaseModel, CodecMixin):
153
+ def __init__(
154
+ self,
155
+ encoder_dim: int = 64,
156
+ encoder_rates: List[int] = [2, 4, 8, 8],
157
+ latent_dim: int = None,
158
+ decoder_dim: int = 1536,
159
+ decoder_rates: List[int] = [8, 8, 4, 2],
160
+ n_codebooks: int = 9,
161
+ codebook_size: int = 1024,
162
+ codebook_dim: Union[int, list] = 8,
163
+ quantizer_dropout: bool = False,
164
+ sample_rate: int = 44100,
165
+ ):
166
+ super().__init__()
167
+
168
+ self.encoder_dim = encoder_dim
169
+ self.encoder_rates = encoder_rates
170
+ self.decoder_dim = decoder_dim
171
+ self.decoder_rates = decoder_rates
172
+ self.sample_rate = sample_rate
173
+
174
+ if latent_dim is None:
175
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
176
+
177
+ self.latent_dim = latent_dim
178
+
179
+ self.hop_length = np.prod(encoder_rates)
180
+ self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
181
+
182
+ self.n_codebooks = n_codebooks
183
+ self.codebook_size = codebook_size
184
+ self.codebook_dim = codebook_dim
185
+ self.quantizer = ResidualVectorQuantize(
186
+ input_dim=latent_dim,
187
+ n_codebooks=n_codebooks,
188
+ codebook_size=codebook_size,
189
+ codebook_dim=codebook_dim,
190
+ quantizer_dropout=quantizer_dropout,
191
+ )
192
+
193
+ self.decoder = Decoder(
194
+ latent_dim,
195
+ decoder_dim,
196
+ decoder_rates,
197
+ )
198
+ self.sample_rate = sample_rate
199
+ self.apply(init_weights)
200
+
201
+ self.delay = self.get_delay()
202
+
203
+ def preprocess(self, audio_data, sample_rate):
204
+ if sample_rate is None:
205
+ sample_rate = self.sample_rate
206
+ assert sample_rate == self.sample_rate
207
+
208
+ length = audio_data.shape[-1]
209
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
210
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
211
+
212
+ return audio_data
213
+
214
+ def encode(
215
+ self,
216
+ audio_data: torch.Tensor,
217
+ n_quantizers: int = None,
218
+ ):
219
+ """Encode given audio data and return quantized latent codes
220
+
221
+ Parameters
222
+ ----------
223
+ audio_data : Tensor[B x 1 x T]
224
+ Audio data to encode
225
+ n_quantizers : int, optional
226
+ Number of quantizers to use, by default None
227
+ If None, all quantizers are used.
228
+
229
+ Returns
230
+ -------
231
+ dict
232
+ A dictionary with the following keys:
233
+ "z" : Tensor[B x D x T]
234
+ Quantized continuous representation of input
235
+ "codes" : Tensor[B x N x T]
236
+ Codebook indices for each codebook
237
+ (quantized discrete representation of input)
238
+ "latents" : Tensor[B x N*D x T]
239
+ Projected latents (continuous representation of input before quantization)
240
+ "vq/commitment_loss" : Tensor[1]
241
+ Commitment loss to train encoder to predict vectors closer to codebook
242
+ entries
243
+ "vq/codebook_loss" : Tensor[1]
244
+ Codebook loss to update the codebook
245
+ "length" : int
246
+ Number of samples in input audio
247
+ """
248
+ z = self.encoder(audio_data)
249
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
250
+ return z, codes, latents, commitment_loss, codebook_loss
251
+
252
+ def decode(self, z: torch.Tensor):
253
+ """Decode given latent codes and return audio data
254
+
255
+ Parameters
256
+ ----------
257
+ z : Tensor[B x D x T]
258
+ Quantized continuous representation of input
259
+ length : int, optional
260
+ Number of samples in output audio, by default None
261
+
262
+ Returns
263
+ -------
264
+ dict
265
+ A dictionary with the following keys:
266
+ "audio" : Tensor[B x 1 x length]
267
+ Decoded audio data.
268
+ """
269
+ return self.decoder(z)
270
+
271
+ def forward(
272
+ self,
273
+ audio_data: torch.Tensor,
274
+ sample_rate: int = None,
275
+ n_quantizers: int = None,
276
+ ):
277
+ """Model forward pass
278
+
279
+ Parameters
280
+ ----------
281
+ audio_data : Tensor[B x 1 x T]
282
+ Audio data to encode
283
+ sample_rate : int, optional
284
+ Sample rate of audio data in Hz, by default None
285
+ If None, defaults to `self.sample_rate`
286
+ n_quantizers : int, optional
287
+ Number of quantizers to use, by default None.
288
+ If None, all quantizers are used.
289
+
290
+ Returns
291
+ -------
292
+ dict
293
+ A dictionary with the following keys:
294
+ "z" : Tensor[B x D x T]
295
+ Quantized continuous representation of input
296
+ "codes" : Tensor[B x N x T]
297
+ Codebook indices for each codebook
298
+ (quantized discrete representation of input)
299
+ "latents" : Tensor[B x N*D x T]
300
+ Projected latents (continuous representation of input before quantization)
301
+ "vq/commitment_loss" : Tensor[1]
302
+ Commitment loss to train encoder to predict vectors closer to codebook
303
+ entries
304
+ "vq/codebook_loss" : Tensor[1]
305
+ Codebook loss to update the codebook
306
+ "length" : int
307
+ Number of samples in input audio
308
+ "audio" : Tensor[B x 1 x length]
309
+ Decoded audio data.
310
+ """
311
+ length = audio_data.shape[-1]
312
+ audio_data = self.preprocess(audio_data, sample_rate)
313
+ z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
314
+
315
+ x = self.decode(z)
316
+ return {
317
+ "audio": x[..., :length],
318
+ "z": z,
319
+ "codes": codes,
320
+ "latents": latents,
321
+ "vq/commitment_loss": commitment_loss,
322
+ "vq/codebook_loss": codebook_loss,
323
+ }
324
+
325
+
326
+ if __name__ == "__main__":
327
+ import numpy as np
328
+ from functools import partial
329
+
330
+ model = DAC().to("cpu")
331
+
332
+ for n, m in model.named_modules():
333
+ o = m.extra_repr()
334
+ p = sum([np.prod(p.size()) for p in m.parameters()])
335
+ fn = lambda o, p: o + f" {p / 1e6:<.3f}M params."
336
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
337
+ print(model)
338
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
339
+
340
+ length = 88200 * 2
341
+ x = torch.randn(1, 1, length).to(model.device)
342
+ x.requires_grad_(True)
343
+ x.retain_grad()
344
+
345
+ # Make a forward pass
346
+ out = model(x)["audio"]
347
+ print("Input shape:", x.shape)
348
+ print("Output shape:", out.shape)
349
+
350
+ # Create gradient variable
351
+ grad = torch.zeros_like(out)
352
+ grad[:, :, grad.shape[-1] // 2] = 1
353
+
354
+ # Make a backward pass
355
+ out.backward(grad)
356
+
357
+ # Check non-zero values
358
+ gradmap = x.grad.squeeze(0)
359
+ gradmap = (gradmap != 0).sum(0) # sum across features
360
+ rf = (gradmap != 0).sum()
361
+
362
+ print(f"Receptive field: {rf.item()}")
363
+
364
+ x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
365
+ model.decompress(model.compress(x, verbose=True), verbose=True)
boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/layers.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from torch.nn.utils import weight_norm
7
+
8
+
9
+ def WNConv1d(*args, **kwargs):
10
+ return weight_norm(nn.Conv1d(*args, **kwargs))
11
+
12
+
13
+ def WNConvTranspose1d(*args, **kwargs):
14
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15
+
16
+
17
+ # Scripting this brings model speed up 1.4x
18
+ @torch.jit.script
19
+ def snake(x, alpha):
20
+ shape = x.shape
21
+ x = x.reshape(shape[0], shape[1], -1)
22
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
23
+ x = x.reshape(shape)
24
+ return x
25
+
26
+
27
+ class Snake1d(nn.Module):
28
+ def __init__(self, channels):
29
+ super().__init__()
30
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
31
+
32
+ def forward(self, x):
33
+ return snake(x, self.alpha)
boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/quantize.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+ from dac.nn.layers import WNConv1d
11
+
12
+
13
class VectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization
    Additionally uses following tricks from Improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
            for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
            improves training stability
    """

    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # 1x1 convolutions project to/from the low-dimensional code space
        # (the "factorized codes" trick above).
        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    def forward(self, z):
        """Quantized the input tensor using a fixed codebook and returns
        the corresponding codebook vectors

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[B]
            Commitment loss (per batch element) to train encoder to predict
            vectors closer to codebook entries
        Tensor[B]
            Codebook loss (per batch element) to update the codebook
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x codebook_dim x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
        z_e = self.in_proj(z)  # z_e : (B x D x T)
        z_q, indices = self.decode_latents(z_e)

        # Per-sample losses: mean over channel and time dims, one value per batch item.
        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])

        z_q = z_e + (z_q - z_e).detach()  # noop in forward pass, straight-through gradient estimator in backward pass

        z_q = self.out_proj(z_q)

        return z_q, commitment_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        # Look up codebook vectors for the given indices -> (..., codebook_dim).
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        # (B, T, codebook_dim) -> (B, codebook_dim, T) to match the conv layout.
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        """Nearest-neighbour lookup of `latents` (B x codebook_dim x T) in the codebook.

        Returns the quantized vectors (B x codebook_dim x T) and their indices (B x T).
        """
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        # Argmin distance == argmax of negative distance; reshape flat index back to (B, T).
        # NOTE: lookup happens on L2-normalized vectors, but decode_code returns the
        # *unnormalized* codebook entries — intentional per the DAC reference code.
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)
        return z_q, indices
93
+
94
+
95
class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An end2end neural audio codec
    https://arxiv.org/abs/2107.03312

    Applies a cascade of `VectorQuantize` layers, each quantizing the residual
    left over by the previous one.
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
    ):
        super().__init__()
        # A scalar codebook_dim is broadcast to one entry per codebook.
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [VectorQuantize(input_dim, codebook_size, codebook_dim[i]) for i in range(n_codebooks)]
        )
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: int = None):
        """Quantized the input tensor using a fixed set of `n` codebooks and returns
        the corresponding codebook vectors

        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: if `self.quantizer_dropout` is True, this argument is ignored
            when in training mode, and a random number of quantizers is used.

        Returns
        -------
        tuple
            (z_q, codes, latents, commitment_loss, codebook_loss) where:

            z_q : Tensor[B x D x T]
                Quantized continuous representation of input
            codes : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            latents : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            commitment_loss : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            codebook_loss : Tensor[1]
                Codebook loss to update the codebook
        """
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            # Every sample nominally uses all codebooks (n_codebooks + 1 keeps the
            # `i < n_quantizers` mask below always true), then the first `n_dropout`
            # samples are overwritten with a random smaller quantizer count.
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            # At inference time simply stop after n_quantizers stages.
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual)

            # Create mask to apply quantizer dropout
            mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            z_q = z_q + z_q_i * mask[:, None, None]
            residual = residual - z_q_i

            # Sum losses
            commitment_loss += (commitment_loss_i * mask).mean()
            codebook_loss += (codebook_loss_i * mask).mean()

            codebook_indices.append(indices_i)
            latents.append(z_e_i)

        codes = torch.stack(codebook_indices, dim=1)
        latents = torch.cat(latents, dim=1)

        return z_q, codes, latents, commitment_loss, codebook_loss

    def from_codes(self, codes: torch.Tensor):
        """Given the quantized codes, reconstruct the continuous representation

        Parameters
        ----------
        codes : Tensor[B x N x T]
            Quantized discrete representation of input

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[B x sum(codebook_dim) x T]
            Concatenated per-codebook latents
        Tensor[B x N x T]
            The input codes, passed through unchanged
        """
        z_q = 0.0
        z_p = []
        n_codebooks = codes.shape[1]
        for i in range(n_codebooks):
            z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
            z_p.append(z_p_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i
        return z_q, torch.cat(z_p, dim=1), codes

    def from_latents(self, latents: torch.Tensor):
        """Given the unquantized latents, reconstruct the
        continuous representation after quantization.

        Parameters
        ----------
        latents : Tensor[B x N x T]
            Continuous representation of input after projection

        Returns
        -------
        Tensor[B x D x T]
            Quantized representation of full-projected space
        Tensor[B x D x T]
            Quantized representation of latent space
        """
        z_q = 0
        z_p = []
        codes = []
        # Offsets of each codebook's slice inside the concatenated latent dim.
        dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])

        # Number of whole codebooks that fit in the provided latent channels.
        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
        for i in range(n_codebooks):
            j, k = dims[i], dims[i + 1]
            z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
            z_p.append(z_p_i)
            codes.append(codes_i)

            z_q_i = self.quantizers[i].out_proj(z_p_i)
            z_q = z_q + z_q_i

        return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
245
+
246
+
247
if __name__ == "__main__":
    # Smoke test. NOTE: forward() returns a tuple
    # (z_q, codes, latents, commitment_loss, codebook_loss), not a dict,
    # so the previous `y["latents"]` raised TypeError.
    rvq = ResidualVectorQuantize(quantizer_dropout=True)
    x = torch.randn(16, 512, 80)
    z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
    print(latents.shape)
boson_multimodal/audio_processing/higgs_audio_tokenizer.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Based on code from: https://github.com/zhenye234/xcodec
2
+ # Licensed under MIT License
3
+ # Modifications by BosonAI
4
+
5
+ import math
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from typing import Optional, Union, Sequence
11
+ import numpy as np
12
+ from transformers import AutoModel
13
+ import torchaudio
14
+ import json
15
+ import librosa
16
+ from huggingface_hub import snapshot_download
17
+
18
+ from vector_quantize_pytorch import ResidualFSQ
19
+ from .descriptaudiocodec.dac.model import dac as dac2
20
+ from .quantization.vq import ResidualVectorQuantizer
21
+ from .semantic_module import Encoder, Decoder
22
+
23
+
24
class EncodedResult:
    """Thin result wrapper holding the discrete codes produced by encoding."""

    def __init__(self, audio_codes):
        # Discrete code indices from the quantizer (see _xcodec_encode), stored as-is.
        self.audio_codes = audio_codes
27
+
28
+
29
class HiggsAudioFeatureExtractor(nn.Module):
    """Minimal feature extractor: packs raw audio into a (batch, channel, time) tensor.

    The ``sampling_rate`` and ``return_tensors`` forward arguments are accepted
    for HF-style API compatibility but are not used.
    """

    def __init__(self, sampling_rate=16000):
        super().__init__()
        self.sampling_rate = sampling_rate

    def forward(self, raw_audio, sampling_rate=16000, return_tensors="pt"):
        # Convert from librosa to torch and add a leading batch dimension.
        signal = torch.tensor(raw_audio).unsqueeze(0)
        # Mono 1-D input becomes (1, 1, T); multi-channel input is already 3-D here.
        if len(signal.shape) < 3:
            signal = signal.unsqueeze(0)
        return {"input_values": signal}
41
+
42
+
43
class HiggsAudioTokenizer(nn.Module):
    """Neural audio codec pairing a DAC-style acoustic encoder/decoder with a
    frozen pretrained semantic teacher (e.g. HuBERT/WavLM), whose features are
    encoded, concatenated with the acoustic latents, and quantized with either
    residual VQ (when `bins` is an int) or residual FSQ (otherwise).

    NOTE(review): the mutable list defaults (`target_bandwidths`, `ratios`) are
    shared across instances if mutated; they appear to be read-only here.
    """

    def __init__(
        self,
        n_filters: int = 32,
        D: int = 128,
        target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6],
        ratios: Sequence[int] = [8, 5, 4, 2],  # downsampling by 320
        sample_rate: int = 16000,
        bins: int = 1024,
        n_q: int = 8,
        codebook_dim: int = None,
        normalize: bool = False,
        causal: bool = False,
        semantic_techer: str = "hubert_base_general",
        last_layer_semantic: bool = True,
        merge_mode: str = "concat",
        downsample_mode: str = "step_down",
        semantic_mode: str = "classic",
        vq_scale: int = 1,
        semantic_sample_rate: int = None,
        device: str = "cuda",
    ):
        super().__init__()
        # Total encoder downsampling factor (product of the strides).
        self.hop_length = np.prod(ratios)
        self.semantic_techer = semantic_techer

        self.frame_rate = math.ceil(sample_rate / np.prod(ratios))  # 50 Hz

        self.target_bandwidths = target_bandwidths
        self.n_q = n_q
        self.sample_rate = sample_rate
        self.encoder = dac2.Encoder(64, ratios, D)

        self.decoder_2 = dac2.Decoder(D, 1024, ratios)
        self.last_layer_semantic = last_layer_semantic
        self.device = device
        # Select and load the frozen semantic teacher model.
        if semantic_techer == "hubert_base":
            self.semantic_model = AutoModel.from_pretrained("facebook/hubert-base-ls960")
            self.semantic_sample_rate = 16000
            self.semantic_dim = 768
            self.encoder_semantic_dim = 768

        elif semantic_techer == "wavlm_base_plus":
            self.semantic_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus")
            self.semantic_sample_rate = 16000
            self.semantic_dim = 768
            self.encoder_semantic_dim = 768

        elif semantic_techer == "hubert_base_general":
            self.semantic_model = AutoModel.from_pretrained("bosonai/hubert_base", trust_remote_code=True)
            self.semantic_sample_rate = 16000
            self.semantic_dim = 768
            self.encoder_semantic_dim = 768

        # Overwrite semantic model sr to ensure semantic_downsample_factor is an integer
        if semantic_sample_rate is not None:
            self.semantic_sample_rate = semantic_sample_rate

        self.semantic_model.eval()

        # make the semantic model parameters do not need gradient
        for param in self.semantic_model.parameters():
            param.requires_grad = False

        # 320 is the teacher's own hop length at its native sample rate.
        self.semantic_downsample_factor = int(self.hop_length / (self.sample_rate / self.semantic_sample_rate) / 320)

        self.quantizer_dim = int((D + self.encoder_semantic_dim) // vq_scale)
        self.encoder_semantic = Encoder(input_channels=self.semantic_dim, encode_channels=self.encoder_semantic_dim)
        self.decoder_semantic = Decoder(
            code_dim=self.encoder_semantic_dim, output_channels=self.semantic_dim, decode_channels=self.semantic_dim
        )

        # out_D=D+768
        if isinstance(bins, int):  # RVQ
            self.quantizer = ResidualVectorQuantizer(
                dimension=self.quantizer_dim, codebook_dim=codebook_dim, n_q=n_q, bins=bins
            )
            self.quantizer_type = "RVQ"
        else:  # RFSQ
            self.quantizer = ResidualFSQ(dim=self.quantizer_dim, levels=bins, num_quantizers=n_q)
            self.quantizer_type = "RFSQ"

        # Project concatenated (acoustic + semantic) features into the quantizer
        # space, and back out to each decoder's input space.
        self.fc_prior = nn.Linear(D + self.encoder_semantic_dim, self.quantizer_dim)
        self.fc_post1 = nn.Linear(self.quantizer_dim, self.encoder_semantic_dim)
        self.fc_post2 = nn.Linear(self.quantizer_dim, D)

        self.downsample_mode = downsample_mode
        if downsample_mode == "avg":
            self.semantic_pooling = nn.AvgPool1d(
                kernel_size=self.semantic_downsample_factor, stride=self.semantic_downsample_factor
            )

        self.audio_tokenizer_feature_extractor = HiggsAudioFeatureExtractor(sampling_rate=self.sample_rate)

    @property
    def tps(self):
        # Tokens per second == frame rate of the quantized representation.
        return self.frame_rate

    @property
    def sampling_rate(self):
        return self.sample_rate

    @property
    def num_codebooks(self):
        return self.n_q

    @property
    def codebook_size(self):
        # NOTE(review): this returns the quantizer feature dimension, not the
        # number of codebook entries (`bins`) — confirm callers expect this.
        return self.quantizer_dim

    def get_last_layer(self):
        # NOTE(review): `self.decoder` is never assigned in __init__ (only
        # `decoder_2` / `decoder_semantic`), so this raises AttributeError if
        # called — looks like stale code from the upstream xcodec repo.
        return self.decoder.layers[-1].weight

    def calculate_rec_loss(self, rec, target):
        """Cosine-distance reconstruction loss between L2-normalized features."""
        target = target / target.norm(dim=-1, keepdim=True)
        rec = rec / rec.norm(dim=-1, keepdim=True)
        rec_loss = (1 - (target * rec).sum(-1)).mean()

        return rec_loss

    @torch.no_grad()
    def get_regress_target(self, x):
        """Compute the frozen semantic teacher's features for waveform `x`,
        resampled to the teacher's rate and downsampled to the codec frame rate.
        """
        x = torchaudio.functional.resample(x, self.sample_rate, self.semantic_sample_rate)

        if (
            self.semantic_techer == "hubert_base"
            or self.semantic_techer == "hubert_base_general"
            or self.semantic_techer == "wavlm_base_plus"
        ):
            # Take the first channel and pad half a teacher hop (160) on each side.
            x = x[:, 0, :]
            x = F.pad(x, (160, 160))
            target = self.semantic_model(x, output_hidden_states=True).hidden_states
            target = torch.stack(target, dim=1)  # .transpose(-1, -2)#.flatten(start_dim=1, end_dim=2)

            # average for all layers
            target = target.mean(1)
            # target = target[9]
            # if self.hop_length > 320:
            #     target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)

        elif self.semantic_techer == "w2v_bert2":
            target = self.semantic_model(x)

        elif self.semantic_techer.startswith("whisper"):
            if self.last_layer_semantic:
                target = self.semantic_model(x, avg_layers=False)
            else:
                target = self.semantic_model(x, avg_layers=True)

        elif self.semantic_techer.startswith("mert_music"):
            if self.last_layer_semantic:
                target = self.semantic_model(x, avg_layers=False)
            else:
                target = self.semantic_model(x, avg_layers=True)

        elif self.semantic_techer.startswith("qwen_audio_omni"):
            target = self.semantic_model(x)

        if self.downsample_mode == "step_down":
            # Strided subsampling along the time axis.
            if self.semantic_downsample_factor > 1:
                target = target[:, :: self.semantic_downsample_factor, :]

        elif self.downsample_mode == "avg":
            target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
        return target

    def forward(self, x: torch.Tensor, bw: int):
        """Training forward pass.

        Returns (reconstructed_audio, commit_loss, semantic_recon_loss, None).
        """
        e_semantic_input = self.get_regress_target(x).detach()

        e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
        e_acoustic = self.encoder(x)

        # Concatenate acoustic and semantic features along the channel axis.
        e = torch.cat([e_acoustic, e_semantic], dim=1)

        e = self.fc_prior(e.transpose(1, 2))

        if self.quantizer_type == "RVQ":
            e = e.transpose(1, 2)
            quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
            quantized = quantized.transpose(1, 2)
        else:
            quantized, codes = self.quantizer(e)
            commit_loss = torch.tensor(0.0)

        quantized_semantic = self.fc_post1(quantized).transpose(1, 2)
        quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)

        o = self.decoder_2(quantized_acoustic)

        o_semantic = self.decoder_semantic(quantized_semantic)
        semantic_recon_loss = F.mse_loss(e_semantic_input.transpose(1, 2).detach(), o_semantic)

        return o, commit_loss, semantic_recon_loss, None

    def encode(self, audio_path_or_wv, sr=None, loudness_normalize=False, loudness_threshold=-23.0):
        """Encode a file path or waveform array into discrete codes (no grad)."""
        if isinstance(audio_path_or_wv, str):
            wv, sr = librosa.load(audio_path_or_wv, mono=True, sr=None)
        else:
            wv = audio_path_or_wv
            assert sr is not None
        if loudness_normalize:
            import pyloudnorm as pyln

            meter = pyln.Meter(sr)
            l = meter.integrated_loudness(wv)
            wv = pyln.normalize.loudness(wv, l, loudness_threshold)
        if sr != self.sampling_rate:
            wv = librosa.resample(wv, orig_sr=sr, target_sr=self.sampling_rate)
        if self.audio_tokenizer_feature_extractor is not None:
            inputs = self.audio_tokenizer_feature_extractor(
                raw_audio=wv, sampling_rate=self.audio_tokenizer_feature_extractor.sampling_rate, return_tensors="pt"
            )
            input_values = inputs["input_values"].to(self.device)
        else:
            input_values = torch.from_numpy(wv).float().unsqueeze(0)
        with torch.no_grad():
            encoder_outputs = self._xcodec_encode(input_values)
            vq_code = encoder_outputs.audio_codes[0]
        return vq_code

    def _xcodec_encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> torch.Tensor:
        """Encode waveform tensor `x` (B, 1, T) into an EncodedResult of codes."""
        bw = target_bw

        e_semantic_input = self.get_regress_target(x).detach()

        e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
        e_acoustic = self.encoder(x)

        # If acoustic/semantic frame counts disagree, re-encode with padding.
        # NOTE(review): `.unsqueeze(0)` after indexing channel 0 assumes batch
        # size 1; a batch > 1 would produce shape (1, B, T) here — confirm.
        if e_acoustic.shape[2] != e_semantic.shape[2]:
            pad_size = 160 * self.semantic_downsample_factor
            e_acoustic = self.encoder(F.pad(x[:, 0, :], (pad_size, pad_size)).unsqueeze(0))

        # Final fallback: truncate the longer stream to the shorter one.
        if e_acoustic.shape[2] != e_semantic.shape[2]:
            if e_acoustic.shape[2] > e_semantic.shape[2]:
                e_acoustic = e_acoustic[:, :, : e_semantic.shape[2]]
            else:
                e_semantic = e_semantic[:, :, : e_acoustic.shape[2]]

        e = torch.cat([e_acoustic, e_semantic], dim=1)

        e = self.fc_prior(e.transpose(1, 2))

        if self.quantizer_type == "RVQ":
            e = e.transpose(1, 2)
            quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
            codes = codes.permute(1, 0, 2)
        else:
            quantized, codes = self.quantizer(e)
            codes = codes.permute(0, 2, 1)

        # return codes
        return EncodedResult(codes)

    def decode(self, vq_code: torch.Tensor) -> torch.Tensor:
        """Decode discrete codes back to a waveform (numpy array)."""
        vq_code = vq_code.to(self.device)

        if self.quantizer_type == "RVQ":
            vq_code = vq_code.permute(1, 0, 2)
            quantized = self.quantizer.decode(vq_code)
            quantized = quantized.transpose(1, 2)
        else:
            vq_code = vq_code.permute(0, 2, 1)
            quantized = self.quantizer.get_output_from_indices(vq_code)
        quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)

        o = self.decoder_2(quantized_acoustic)
        return o.detach().cpu().numpy()
310
+
311
+
312
def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"):
    """Load a HiggsAudioTokenizer from a local directory or a HF Hub repo id.

    Expects `config.json` (constructor kwargs) and `model.pth` (state dict) in
    the resolved directory. Returns the model in eval mode on `device`.
    """
    is_local = os.path.exists(tokenizer_name_or_path)
    if not is_local:
        tokenizer_path = snapshot_download(tokenizer_name_or_path)
    else:
        tokenizer_path = tokenizer_name_or_path
    config_path = os.path.join(tokenizer_path, "config.json")
    model_path = os.path.join(tokenizer_path, "model.pth")
    # Close the config file deterministically (the previous
    # `json.load(open(...))` leaked the file handle).
    with open(config_path) as f:
        config = json.load(f)
    model = HiggsAudioTokenizer(
        **config,
        device=device,
    )
    parameter_dict = torch.load(model_path, map_location=device)
    # strict=False tolerates checkpoints missing e.g. frozen-teacher weights.
    model.load_state_dict(parameter_dict, strict=False)
    model.to(device)
    model.eval()
    return model
boson_multimodal/audio_processing/quantization/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # flake8: noqa
8
+ from .vq import QuantizedResult, ResidualVectorQuantizer
boson_multimodal/audio_processing/quantization/ac.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Arithmetic coder."""
8
+
9
+ import io
10
+ import math
11
+ import random
12
+ import typing as tp
13
+ import torch
14
+
15
+ from ..binary import BitPacker, BitUnpacker
16
+
17
+
18
def build_stable_quantized_cdf(
    pdf: torch.Tensor, total_range_bits: int, roundoff: float = 1e-8, min_range: int = 2, check: bool = True
) -> torch.Tensor:
    """Turn the given PDF into a quantized CDF that splits
    [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
    to the PDF.

    Args:
        pdf (torch.Tensor): probability distribution, shape should be `[N]`.
        total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
            during the coding process is `[0, 2 ** total_range_bits - 1]`.
        roundoff (float): will round the pdf up to that level to remove difference coming
            from e.g. evaluating the Language Model on different architectures.
        min_range (int): minimum range width. Should always be at least 2 for numerical
            stability. Use this to avoid pathological behavior is a value
            that is expected to be rare actually happens in real life.
        check (bool): if True, checks that nothing bad happened, can be deactivated for speed.

    Raises:
        ValueError: if ``min_range < 2``, or (when ``check``) if a range is too narrow.
    """
    # Fail fast on an invalid min_range instead of validating only after the
    # full cumsum has been computed (as the original ordering did).
    if min_range < 2:
        raise ValueError("min_range must be at least 2.")
    pdf = pdf.detach()
    if roundoff:
        pdf = (pdf / roundoff).floor() * roundoff
    # interpolate with uniform distribution to achieve desired minimum probability.
    total_range = 2**total_range_bits
    cardinality = len(pdf)
    alpha = min_range * cardinality / total_range
    assert alpha <= 1, "you must reduce min_range"
    ranges = (((1 - alpha) * total_range) * pdf).floor().long()
    ranges += min_range
    quantized_cdf = torch.cumsum(ranges, dim=-1)
    if check:
        assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
        if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
            raise ValueError("You must increase your total_range_bits.")
    return quantized_cdf
54
+
55
+
56
class ArithmeticCoder:
    """ArithmeticCoder,
    Let us take a distribution `p` over `N` symbols, and assume we have a stream
    of random variables `s_t` sampled from `p`. Let us assume that we have a budget
    of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
    corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
    sequence `(s_t)` by doing the following:

    1) Initialize the current range to` [0 ** 2 B - 1]`.
    2) For each time step t, split the current range into contiguous chunks,
        one for each possible outcome, with size roughly proportional to `p`.
        For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
        would be `{[0, 2], [3, 3]}`.
    3) Select the chunk corresponding to `s_t`, and replace the current range with this.
    4) When done encoding all the values, just select any value remaining in the range.

    You will notice that this procedure can fail: for instance if at any point in time
    the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
    possible outcome. Intuitively, the more likely a value is, the less the range width
    will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
    coding scheme, likely outcomes would take less bits, and more of them can be coded
    with a fixed budget.

    In practice, we do not know `B` ahead of time, but we have a way to inject new bits
    when the current range decreases below a given limit (given by `total_range_bits`), without
    having to redo all the computations. If we encode mostly likely values, we will seldom
    need to inject new bits, but a single rare value can deplete our stock of entropy!

    In this explanation, we assumed that the distribution `p` was constant. In fact, the present
    code works for any sequence `(p_t)` possibly different for each timestep.
    We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
    the KL between the true distribution and `p_t`, the most efficient the coding will be.

    Args:
        fo (IO[bytes]): file-like object to which the bytes will be written to.
        total_range_bits (int): the range `M` described above is `2 ** total_range_bits.
            Any time the current range width fall under this limit, new bits will
            be injected to rescale the initial range.
    """

    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
        assert total_range_bits <= 30
        self.total_range_bits = total_range_bits
        self.packer = BitPacker(bits=1, fo=fo)  # we push single bits at a time.
        # Current interval [low, high], tracked over max_bit+1 significant bits.
        self.low: int = 0
        self.high: int = 0
        self.max_bit: int = -1
        # Debug traces of (low, high) after each push; not used for coding.
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []

    @property
    def delta(self) -> int:
        """Return the current range width."""
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        # If self.low and self.high start with the sames bits,
        # those won't change anymore as we always just increase the range
        # by powers of 2, and we can flush them out to the bit stream.
        assert self.high >= self.low, (self.low, self.high)
        assert self.high < 2 ** (self.max_bit + 1)
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 == b2:
                # Shared top bit: emit it and drop it from both bounds.
                self.low -= b1 << self.max_bit
                self.high -= b1 << self.max_bit
                assert self.high >= self.low, (self.high, self.low, self.max_bit)
                assert self.low >= 0
                self.max_bit -= 1
                self.packer.push(b1)
            else:
                break

    def push(self, symbol: int, quantized_cdf: torch.Tensor):
        """Push the given symbol on the stream, flushing out bits
        if possible.

        Args:
            symbol (int): symbol to encode with the AC.
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate.
        """
        # Widen the interval until it spans at least 2 ** total_range_bits.
        while self.delta < 2**self.total_range_bits:
            self.low *= 2
            self.high = self.high * 2 + 1
            self.max_bit += 1

        # [range_low, range_high] is the symbol's slice of the quantized CDF,
        # rescaled below onto the current interval width.
        range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
        range_high = quantized_cdf[symbol].item() - 1
        effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
        effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
        assert self.low <= self.high
        self.high = self.low + effective_high
        self.low = self.low + effective_low
        assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
        self._dbg.append((self.low, self.high))
        self._dbg2.append((self.low, self.high))
        # NOTE(review): _flush_common_prefix returns None, so `outs` is always
        # None; kept for compatibility with the upstream code.
        outs = self._flush_common_prefix()
        assert self.low <= self.high
        assert self.max_bit >= -1
        assert self.max_bit <= 61, self.max_bit
        return outs

    def flush(self):
        """Flush the remaining information to the stream."""
        while self.max_bit >= 0:
            b1 = (self.low >> self.max_bit) & 1
            self.packer.push(b1)
            self.max_bit -= 1
        self.packer.flush()
167
+
168
+
169
class ArithmeticDecoder:
    """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.

    Note that this must be called with **exactly** the same parameters and sequence
    of quantized cdf as the arithmetic encoder or the wrong values will be decoded.

    If the AC encoder current range is [L, H], with `L` and `H` having the some common
    prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
    For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
    `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained
    for a specific sequence of symbols and a binary-search allows us to decode those symbols.
    At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
    and we will need to read new bits from the stream and repeat the process.

    """

    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
        self.total_range_bits = total_range_bits
        # Mirror of the encoder's interval state, plus `current`, the value
        # reconstructed so far from the bits read.
        self.low: int = 0
        self.high: int = 0
        self.current: int = 0
        self.max_bit: int = -1
        self.unpacker = BitUnpacker(bits=1, fo=fo)  # we pull single bits at a time.
        # Following is for debugging
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []
        self._last: tp.Any = None

    @property
    def delta(self) -> int:
        # Current range width, kept in sync with the encoder's definition.
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        # Given the current range [L, H], if both have a common prefix,
        # we know we can remove it from our representation to avoid handling large numbers.
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 == b2:
                self.low -= b1 << self.max_bit
                self.high -= b1 << self.max_bit
                self.current -= b1 << self.max_bit
                assert self.high >= self.low
                assert self.low >= 0
                self.max_bit -= 1
            else:
                break

    def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
        """Pull a symbol, reading as many bits from the stream as required.
        This returns `None` when the stream has been exhausted.

        Args:
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate. This must be **exatly**
                the same cdf as the one used at encoding time.
        """
        # Read bits until the interval is wide enough to resolve a symbol.
        while self.delta < 2**self.total_range_bits:
            bit = self.unpacker.pull()
            if bit is None:
                return None
            self.low *= 2
            self.high = self.high * 2 + 1
            self.current = self.current * 2 + bit
            self.max_bit += 1

        def bin_search(low_idx: int, high_idx: int):
            # Binary search is not just for coding interviews :)
            # Finds the symbol whose sub-interval contains self.current,
            # mirroring the rescaling done in ArithmeticCoder.push.
            if high_idx < low_idx:
                raise RuntimeError("Binary search failed")
            mid = (low_idx + high_idx) // 2
            range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
            range_high = quantized_cdf[mid].item() - 1
            effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
            effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
            low = effective_low + self.low
            high = effective_high + self.low
            if self.current >= low:
                if self.current <= high:
                    return (mid, low, high, self.current)
                else:
                    return bin_search(mid + 1, high_idx)
            else:
                return bin_search(low_idx, mid - 1)

        self._last = (self.low, self.high, self.current, self.max_bit)
        sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
        self._dbg.append((self.low, self.high, self.current))
        self._flush_common_prefix()
        self._dbg2.append((self.low, self.high, self.current))

        return sym
261
+
262
+
263
def test():
    """Round-trip self-test: encode random symbol streams, decode, and compare."""
    torch.manual_seed(1234)
    random.seed(1234)
    for _ in range(4):
        cardinality = random.randrange(4000)
        steps = random.randrange(100, 500)
        stream = io.BytesIO()
        encoder = ArithmeticCoder(stream)
        pdfs = []
        symbols = []
        for _ in range(steps):
            pdf = torch.softmax(torch.randn(cardinality), dim=0)
            pdfs.append(pdf)
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            symbol = torch.multinomial(pdf, 1).item()
            symbols.append(symbol)
            encoder.push(symbol, q_cdf)
        encoder.flush()

        stream.seek(0)
        decoder = ArithmeticDecoder(stream)
        for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
            q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
            assert decoder.pull(q_cdf) == symbol, idx
        # An exhausted stream must yield None.
        assert decoder.pull(torch.zeros(1)) is None
289
+
290
+
291
if __name__ == "__main__":
    # Run the encoder/decoder round-trip self-test when executed directly.
    test()
boson_multimodal/audio_processing/quantization/core_vq.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # This implementation is inspired from
8
+ # https://github.com/lucidrains/vector-quantize-pytorch
9
+ # which is released under MIT License. Hereafter, the original license:
10
+ # MIT License
11
+ #
12
+ # Copyright (c) 2020 Phil Wang
13
+ #
14
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ # of this software and associated documentation files (the "Software"), to deal
16
+ # in the Software without restriction, including without limitation the rights
17
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ # copies of the Software, and to permit persons to whom the Software is
19
+ # furnished to do so, subject to the following conditions:
20
+ #
21
+ # The above copyright notice and this permission notice shall be included in all
22
+ # copies or substantial portions of the Software.
23
+ #
24
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
+ # SOFTWARE.
31
+
32
+ """Core vector quantization implementation."""
33
+
34
+ import typing as tp
35
+
36
+ from einops import rearrange, repeat
37
+ import torch
38
+ from torch import nn
39
+ import torch.nn.functional as F
40
+
41
+ from xcodec.quantization.distrib import broadcast_tensors, rank
42
+
43
+
44
def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Return *val* unless it is None, in which case return the fallback *d*."""
    if val is None:
        return d
    return val
46
+
47
+
48
def ema_inplace(moving_avg, new, decay: float):
    """In-place exponential moving average: moving_avg <- decay * moving_avg + (1 - decay) * new."""
    moving_avg.data.copy_(moving_avg.data * decay + new * (1 - decay))
50
+
51
+
52
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Additively smoothed normalisation of counts *x* (Laplace smoothing)."""
    total = x.sum() + n_categories * epsilon
    return (x + epsilon) / total
54
+
55
+
56
def uniform_init(*shape: int):
    """Allocate a tensor of *shape* and fill it with Kaiming-uniform values."""
    out = torch.empty(shape)
    nn.init.kaiming_uniform_(out)
    return out
60
+
61
+
62
def sample_vectors(samples, num: int):
    """Pick *num* rows from *samples*: without replacement when there are
    enough rows, otherwise with replacement."""
    total = samples.shape[0]
    device = samples.device

    if total < num:
        idx = torch.randint(0, total, (num,), device=device)
    else:
        idx = torch.randperm(total, device=device)[:num]

    return samples[idx]
71
+
72
+
73
def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Plain Lloyd k-means over row vectors.

    Args:
        samples (tensor): [N, D] input vectors.
        num_clusters (int): number of centroids.
        num_iters (int): number of Lloyd iterations.

    Returns:
        means (tensor): [num_clusters, D] centroids.
        bins (tensor): [num_clusters] points per cluster from the last iteration.
    """
    dim, dtype = samples.shape[-1], samples.dtype

    # Initialize centroids with randomly chosen samples.
    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        # Negated squared Euclidean distance, so argmax picks the nearest centroid.
        diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
        dists = -(diffs**2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Clamp empty-cluster counts to 1 to avoid division by zero below.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        # Sum the members of each cluster via scatter-add, then average.
        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        # Clusters that received no samples keep their previous centroid.
        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins
94
+
95
+
96
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.
    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: int = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        # Zero-init when k-means init is requested (filled on the first batch),
        # otherwise Kaiming-uniform right away.
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # inited: flag buffer, False until the deferred k-means init has run.
        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        # cluster_size: EMA of per-code assignment counts.
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        # embed: the codebook itself; embed_avg: EMA of assigned-vector sums.
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        """One-shot k-means initialisation of the codebook from `data` ([N, D]); no-op once inited."""
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        """Overwrite the codes selected by `mask` with random vectors from `samples`."""
        modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        """Replace dead codes (EMA usage below threshold) with fresh batch samples."""
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        broadcast_tensors(self.buffers())

    def preprocess(self, x):
        """Flatten leading dimensions: [..., D] -> [N, D]."""
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        """Return the index of the nearest code (Euclidean) for each row of `x`."""
        embed = self.embed.t()
        # Negated squared distance so that max == nearest code; the ||x||^2 term
        # is constant per row and does not change the argmax.
        dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        """Reshape flat indices back to the input's leading dimensions."""
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        """Look up code vectors for the given indices."""
        quantize = F.embedding(embed_ind, self.embed)  # get embedding based on index
        return quantize

    def encode(self, x):
        """Quantize `x` ([..., D]) to code indices (shape `x.shape[:-1]`)."""
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)  # get index based on Euclidean distance
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        """Map code indices back to code vectors."""
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        """Quantize `x` and, in training mode, update the codebook by EMA.

        Returns:
            quantize: code vectors, same shape as `x`.
            embed_ind: code indices, shape `x.shape[:-1]`.
        """
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)

        # Deferred k-means initialisation on the first batch (no-op afterwards).
        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            # Laplace smoothing keeps the per-code denominator away from zero
            # for codes that were never (or rarely) assigned.
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind
223
+
224
+
225
class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.
    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.0,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        # Linear projections are only needed when the codebook lives in a
        # different dimensionality than the input features.
        if _codebook_dim != dim:
            self.project_in = nn.Linear(dim, _codebook_dim)
            self.project_out = nn.Linear(_codebook_dim, dim)
        else:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size

    @property
    def codebook(self):
        """The raw codebook tensor, shape [codebook_size, codebook_dim]."""
        return self._codebook.embed

    def encode(self, x):
        """Quantize `x` ([B, D, T]) to code indices ([B, T])."""
        features = rearrange(x, "b d n -> b n d")
        features = self.project_in(features)
        return self._codebook.encode(features)

    def decode(self, embed_ind):
        """Map code indices ([B, T]) back to features ([B, D, T])."""
        decoded = self._codebook.decode(embed_ind)
        decoded = self.project_out(decoded)
        return rearrange(decoded, "b n d -> b d n")

    def forward(self, x):
        """Quantize `x` ([B, D, T]).

        Returns:
            quantize: quantized features, [B, D, T].
            embed_ind: code indices, [B, T].
            loss: commitment loss (zero tensor in eval mode).
        """
        device = x.device
        features = rearrange(x, "b d n -> b n d")
        features = self.project_in(features)

        quantize, embed_ind = self._codebook(features)

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            # Straight-through estimator: forward uses the codes, gradients
            # flow back to the encoder features unchanged.
            quantize = features + (quantize - features).detach()
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), features)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize, embed_ind, loss
311
+
312
+
313
class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        # One VectorQuantization layer per residual stage; all stages share kwargs.
        self.layers = nn.ModuleList(VectorQuantization(**kwargs) for _ in range(num_quantizers))

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Quantize `x` with the first `n_q` stages (all stages by default).

        Returns:
            quantized_out: sum of per-stage quantized outputs, same shape as `x`.
            out_indices: stacked per-stage code indices, [n_q, ...].
            out_losses: stacked per-stage commitment losses, [n_q, ...].
        """
        if not n_q:
            n_q = len(self.layers)

        residual = x
        quantized_out = 0.0
        losses, indices_per_stage = [], []

        for layer in self.layers[:n_q]:
            quantized, indices, loss = layer(residual)
            # Each stage quantizes what the previous stages left over.
            residual = residual - quantized
            quantized_out = quantized_out + quantized
            indices_per_stage.append(indices)
            losses.append(loss)

        return quantized_out, torch.stack(indices_per_stage), torch.stack(losses)

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Encode `x` to stacked per-stage code indices, [n_q, ...]."""
        if not n_q:
            n_q = len(self.layers)

        residual = x
        collected = []
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            residual = residual - layer.decode(indices)
            collected.append(indices)
        return torch.stack(collected)

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Decode stacked per-stage indices by summing the stage reconstructions."""
        total = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            total = total + self.layers[i].decode(indices)
        return total
boson_multimodal/audio_processing/quantization/core_vq_lsx_version.py ADDED
@@ -0,0 +1,425 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c)
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # This implementation is inspired from
6
+ # https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py and
7
+ # https://github.com/clementchadebec/benchmark_VAE/blob/dfa0dcf6c79172df5d27769c09c860c42008baaa/src/pythae/models/vq_vae/vq_vae_utils.py#L81
8
+ #
9
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
10
+ # All rights reserved.
11
+ #
12
+ # This source code is licensed under the license found in the
13
+ # LICENSE file in the root directory of this source tree.
14
+ #
15
+ # This implementation is inspired from
16
+ # https://github.com/lucidrains/vector-quantize-pytorch
17
+ # which is released under MIT License. Hereafter, the original license:
18
+ # MIT License
19
+ #
20
+ # Copyright (c) 2020 Phil Wang
21
+ #
22
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
23
+ # of this software and associated documentation files (the "Software"), to deal
24
+ # in the Software without restriction, including without limitation the rights
25
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
26
+ # copies of the Software, and to permit persons to whom the Software is
27
+ # furnished to do so, subject to the following conditions:
28
+ #
29
+ # The above copyright notice and this permission notice shall be included in all
30
+ # copies or substantial portions of the Software.
31
+ #
32
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
38
+ # SOFTWARE.
39
+
40
+ """Core vector quantization implementation."""
41
+
42
+ import typing as tp
43
+
44
+ from einops import rearrange
45
+ import torch
46
+ from torch import nn
47
+ import torch.nn.functional as F
48
+ import torch.distributed as dist
49
+
50
+ from .distrib import broadcast_tensors, is_distributed
51
+ from .ddp_utils import SyncFunction
52
+
53
+
54
def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Fallback helper: substitute *d* when *val* is None."""
    return d if val is None else val
56
+
57
+
58
def ema_inplace(moving_avg, new, decay: float):
    """In-place EMA step: moving_avg <- decay * moving_avg + (1 - decay) * new."""
    moving_avg.data.mul_(decay)
    moving_avg.data.add_(new, alpha=1 - decay)
60
+
61
+
62
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Normalise counts *x* with additive (Laplace) smoothing."""
    numerator = x + epsilon
    denominator = x.sum() + n_categories * epsilon
    return numerator / denominator
64
+
65
+
66
def uniform_init(*shape: int):
    """Fresh tensor of *shape* filled by Kaiming-uniform initialisation."""
    tensor = torch.empty(shape)
    return nn.init.kaiming_uniform_(tensor)
70
+
71
+
72
def sample_vectors(samples, num: int):
    """Draw *num* rows from *samples* (without replacement when enough rows exist)."""
    count = samples.shape[0]
    if count >= num:
        chosen = torch.randperm(count, device=samples.device)[:num]
    else:
        chosen = torch.randint(0, count, (num,), device=samples.device)
    return samples[chosen]
81
+
82
+
83
def kmeans(samples, num_clusters: int, num_iters: int = 10, frames_to_use: int = 10_000, batch_size: int = 64):
    """
    Memory-efficient K-means clustering.

    Distances are computed in mini-batches so the full [N, C] distance matrix
    is never materialised at once.

    Args:
        samples (tensor): shape [N, D]
        num_clusters (int): number of centroids.
        num_iters (int): number of iterations.
        frames_to_use (int): subsample size from total samples.
        batch_size (int): batch size used in distance computation.
    Returns:
        means: [num_clusters, D]
        bins: [num_clusters] (number of points per cluster on the last iteration)
    """
    N = samples.shape[0]
    device = samples.device

    # Subsample to bound memory/compute for very long inputs.
    if frames_to_use < N:
        indices = torch.randperm(N, device=device)[:frames_to_use]
        samples = samples[indices]

    # Initialize centroids from randomly chosen data points.
    means = sample_vectors(samples, num_clusters)

    # Pre-initialize so the return is well-defined even when num_iters == 0
    # (previously `bins` was unbound in that case).
    bins = torch.zeros(num_clusters, dtype=torch.long, device=device)

    for _ in range(num_iters):
        # Assign every sample to its nearest centroid, batch by batch.
        all_assignments = []
        for i in range(0, samples.shape[0], batch_size):
            batch = samples[i : i + batch_size]  # [B, D]
            dists = torch.cdist(batch, means, p=2)  # [B, C]
            all_assignments.append(dists.argmin(dim=1))  # [B]

        buckets = torch.cat(all_assignments, dim=0)  # [N]
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0

        # Recompute centroids as the mean of their assigned points.
        new_means = torch.zeros_like(means)
        for c in range(num_clusters):
            mask = buckets == c
            if mask.any():
                new_means[c] = samples[mask].mean(dim=0)

        # Empty clusters keep their previous centroid.
        means = torch.where(zero_mask[:, None], means, new_means)

    return means, bins
130
+
131
+
132
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.
    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: int = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        # Zero-init when k-means init is requested (filled on the first batch),
        # otherwise Kaiming-uniform right away.
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # Flag variable to indicate whether the codebook is initialized
        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        # Running EMA cluster size/count: N_i^t in eq. (6) in vqvae paper
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        # Codebook
        self.register_buffer("embed", embed)
        # EMA codebook: eq. (7) in vqvae paper
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        """Initialize codebook with k-means over the first batch; no-op once inited.
        Args:
            data (tensor): [B * T, D].
        """
        if self.inited:
            return

        ## NOTE (snippet added by Songxiang Liu): gather data from all gpus
        if dist.is_available() and dist.is_initialized():
            # [B * T * world_size, D]
            data = SyncFunction.apply(data)

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        """Overwrite the codes selected by `mask` with random vectors from `samples`."""
        modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        """Replace dead codes (EMA usage below threshold) with fresh batch samples."""
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        ## NOTE (snippet added by Songxiang Liu): gather data from all gpus
        if is_distributed():
            # [B * T * world_size, D]
            batch_samples = SyncFunction.apply(batch_samples)

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        broadcast_tensors(self.buffers())

    def preprocess(self, x):
        """Flatten leading dimensions: [..., D] -> [N, D]."""
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        """Return the index of the nearest code (Euclidean) for each row of `x`."""
        embed = self.embed.t()
        # Negated squared distance, so max == nearest code.
        dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        """Reshape flat indices back to the input's leading dimensions."""
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        """Look up code vectors for the given indices."""
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        """Quantize `x` ([B, T, D]) to code indices ([B, T])."""
        shape = x.shape
        # pre-process
        x = self.preprocess(x)  # [B, T, D] -> [B*T, D]
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        """Map code indices back to code vectors."""
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        """Quantize `x` ([B, T, D]); in training mode also update the codebook by
        EMA with statistics all-reduced across workers.

        Returns:
            quantize: code vectors, [B, T, D].
            embed_ind: code indices, [B, T].
        """
        # shape: [B, T, D]
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)  # [B, T, D] -> [B*T, D]

        # Initialize codebook
        self.init_embed_(x)

        embed_ind = self.quantize(x)  # [B*T,]
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)  # [B*T, cb-size]
        embed_ind = self.postprocess_emb(embed_ind, shape)  # [B, T]
        quantize = self.dequantize(embed_ind)  # [B, T, D]

        if self.training:
            ### Update codebook by EMA
            embed_onehot_sum = embed_onehot.sum(0)  # [cb-size,]
            embed_sum = x.t() @ embed_onehot  # [D, cb-size]
            if is_distributed():
                # Aggregate assignment statistics across workers so every rank
                # applies the same EMA update.
                dist.all_reduce(embed_onehot_sum)
                dist.all_reduce(embed_sum)
            # Update ema cluster count N_i^t, eq. (6) in vqvae paper
            self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
            # Update ema embed: eq. (7) in vqvae paper
            self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay)
            # apply laplace smoothing (keeps the denominator away from zero)
            n = self.cluster_size.sum()
            cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n
            # Update ema embed: eq. (8) in vqvae paper
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)

        return quantize, embed_ind
288
+
289
+
290
class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.
    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.0,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        # Linear projections are only needed when the codebook dimension
        # differs from the feature dimension.
        requires_projection = _codebook_dim != dim
        self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size

    @property
    def codebook(self):
        """The raw codebook tensor, shape [codebook_size, codebook_dim]."""
        return self._codebook.embed

    def encode(self, x):
        """Quantize `x` ([B, D, T]) to code indices ([B, T])."""
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        """Map code indices ([B, T]) back to features ([B, D, T])."""
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def forward(self, x):
        """Quantize `x` ([B, D, T]); returns (quantized [B, D, T], indices [B, T], loss)."""
        device = x.device
        x = x.transpose(1, 2).contiguous()  # [b d n] -> [b n d]
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x)

        if self.training:
            # Straight-through estimator: forward uses the codes, gradients
            # flow back to the encoder features unchanged.
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        quantize = quantize.transpose(1, 2).contiguous()  # [b n d] -> [b d n]
        return quantize, embed_ind, loss
376
+
377
+
378
class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        # One VQ layer per residual stage; kwargs are forwarded to each stage.
        self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Quantize `x` with the first `n_q` stages (all stages by default).

        Returns:
            quantized_out: sum of per-stage quantized outputs, same shape as `x`.
            out_indices: stacked per-stage code indices, [n_q, ...].
            out_losses: stacked per-stage commitment losses, [n_q, ...].
        """
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []

        n_q = n_q or len(self.layers)

        for layer in self.layers[:n_q]:
            quantized, indices, loss = layer(residual)
            # Each stage quantizes what the previous stages left over.
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)

        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Encode `x` to stacked per-stage code indices, [n_q, ...]."""
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Decode stacked per-stage indices by summing the stage reconstructions."""
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out
boson_multimodal/audio_processing/quantization/ddp_utils.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import random
3
+ import subprocess
4
+ from datetime import datetime
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.distributed as dist
9
+ from torch.nn.parallel import DistributedDataParallel
10
+ from torch.nn.parallel.distributed import _find_tensors
11
+ import torch.optim
12
+ import torch.utils.data
13
+ from packaging import version
14
+ from omegaconf import OmegaConf
15
+
16
+
17
def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch (all CUDA devices) RNGs for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
22
+
23
+
24
def is_logging_process():
    """Return True on the rank-0 worker, or whenever torch.distributed is not initialized."""
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True
26
+
27
+
28
def get_logger(cfg, name=None):
    """Configure logging from the config (rank-0 only) and return a named logger.

    Args:
        cfg: OmegaConf config holding a ``job_logging_config`` section in
            ``logging.config.dictConfig`` schema.
        name: Logger name; ``None`` returns the root logger.
    """
    # log_file_path is used when unit testing
    if is_logging_process():
        # Bug fix: `import logging` alone does not make the `logging.config`
        # submodule reachable as an attribute; import it explicitly so
        # dictConfig does not raise AttributeError.
        import logging.config

        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_config, resolve=True))
    return logging.getLogger(name)
33
+
34
+
35
+ # from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
36
class SyncFunction(torch.autograd.Function):
    """All-gather across workers with a backward that routes gradients back to the local shard."""

    @staticmethod
    def forward(ctx, tensor):
        # Remember the local batch size so backward can slice out our shard.
        ctx.batch_size = tensor.shape[0]
        world = torch.distributed.get_world_size()
        buckets = [torch.zeros_like(tensor) for _ in range(world)]
        torch.distributed.all_gather(buckets, tensor)
        return torch.cat(buckets, 0)

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.clone()
        # Sum gradients of the gathered tensor over all workers ...
        torch.distributed.all_reduce(grad, op=torch.distributed.ReduceOp.SUM, async_op=False)
        # ... then keep only the slice corresponding to this rank's batch.
        start = torch.distributed.get_rank() * ctx.batch_size
        return grad[start : start + ctx.batch_size]
57
+
58
+
59
def get_timestamp():
    """Return the current local time formatted as ``yymmdd-HHMMSS``."""
    return f"{datetime.now():%y%m%d-%H%M%S}"
61
+
62
+
63
def get_commit_hash():
    """Return the short git commit hash of the current working tree."""
    raw = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
    return raw.decode("utf-8").strip()
66
+
67
+
68
class DDP(DistributedDataParallel):
    """
    Override the forward call in lightning so it goes to training and validation step respectively

    Instead of calling ``self.module(*inputs)``, this dispatches to
    ``training_step`` / ``test_step`` / ``validation_step`` based on
    ``self.module.training`` and ``self.module.testing``, while keeping DDP's
    gradient-reduction bookkeeping intact. The body is adapted from PyTorch's
    own ``DistributedDataParallel.forward``; the two branches mirror the
    pre-1.11 and post-1.11 internals respectively.

    NOTE(review): assumes the wrapped module defines `training_step`,
    `test_step`, `validation_step` and a `testing` attribute — confirm against
    the trainer that wraps models with this class.
    """

    def forward(self, *inputs, **kwargs):  # pragma: no cover
        # Older PyTorch (< 1.11): the simpler private API is still available.
        if version.parse(torch.__version__[:6]) < version.parse("1.11"):
            self._sync_params()
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            # Single-device-per-process is the only supported layout here.
            assert len(self.device_ids) == 1
            if self.module.training:
                output = self.module.training_step(*inputs[0], **kwargs[0])
            elif self.module.testing:
                output = self.module.test_step(*inputs[0], **kwargs[0])
            else:
                output = self.module.validation_step(*inputs[0], **kwargs[0])
            if torch.is_grad_enabled():
                # We'll return the output object verbatim since it is a freeform
                # object. We need to find any tensors in this object, though,
                # because we need to figure out which parameters were used during
                # this forward pass, to ensure we short circuit reduction for any
                # unused parameters. Only if `find_unused_parameters` is set.
                if self.find_unused_parameters:
                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                else:
                    self.reducer.prepare_for_backward([])
        else:
            # Newer PyTorch: replicate the upstream forward, swapping the module
            # call for the step dispatch. These are private torch internals and
            # may change between releases.
            from torch.nn.parallel.distributed import (
                logging,
                Join,
                _DDPSink,
                _tree_flatten_with_rref,
                _tree_unflatten_with_rref,
            )

            with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.logger.set_runtime_stats_and_log()
                    self.num_iterations += 1
                    self.reducer.prepare_for_forward()

                # Notify the join context that this process has not joined, if
                # needed
                work = Join.notify_join_context(self)
                if work:
                    self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size)

                # Calling _rebuild_buckets before forward compuation,
                # It may allocate new buckets before deallocating old buckets
                # inside _rebuild_buckets. To save peak memory usage,
                # call _rebuild_buckets before the peak memory usage increases
                # during forward computation.
                # This should be called only once during whole training period.
                if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
                    logging.info("Reducer buckets have been rebuilt in this iteration.")
                    self._has_rebuilt_buckets = True

                # sync params according to location (before/after forward) user
                # specified as part of hook, if hook was specified.
                # NOTE(review): `buffer_hook_registered` is unused below — kept
                # verbatim from the upstream copy.
                buffer_hook_registered = hasattr(self, "buffer_hook")
                if self._check_sync_bufs_pre_fwd():
                    self._sync_buffers()

                if self._join_config.enable:
                    # Notify joined ranks whether they should sync in backwards pass or not.
                    self._check_global_requires_backward_grad_sync(is_joined_rank=False)

                # Dispatch to the lightning-style step instead of module.__call__.
                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
                if self.module.training:
                    output = self.module.training_step(*inputs[0], **kwargs[0])
                elif self.module.testing:
                    output = self.module.test_step(*inputs[0], **kwargs[0])
                else:
                    output = self.module.validation_step(*inputs[0], **kwargs[0])

                # sync params according to location (before/after forward) user
                # specified as part of hook, if hook was specified.
                if self._check_sync_bufs_post_fwd():
                    self._sync_buffers()

                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.require_forward_param_sync = True
                    # We'll return the output object verbatim since it is a freeform
                    # object. We need to find any tensors in this object, though,
                    # because we need to figure out which parameters were used during
                    # this forward pass, to ensure we short circuit reduction for any
                    # unused parameters. Only if `find_unused_parameters` is set.
                    if self.find_unused_parameters and not self.static_graph:
                        # Do not need to populate this for static graph.
                        self.reducer.prepare_for_backward(list(_find_tensors(output)))
                    else:
                        self.reducer.prepare_for_backward([])
                else:
                    self.require_forward_param_sync = False

            # TODO: DDPSink is currently enabled for unused parameter detection and
            # static graph training for first iteration.
            if (self.find_unused_parameters and not self.static_graph) or (
                self.static_graph and self.num_iterations == 1
            ):
                state_dict = {
                    "static_graph": self.static_graph,
                    "num_iterations": self.num_iterations,
                }

                output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output)
                output_placeholders = [None for _ in range(len(output_tensor_list))]
                # Do not touch tensors that have no grad_fn, which can cause issues
                # such as https://github.com/pytorch/pytorch/issues/60733
                for i, output in enumerate(output_tensor_list):
                    if torch.is_tensor(output) and output.grad_fn is None:
                        output_placeholders[i] = output

                # When find_unused_parameters=True, makes tensors which require grad
                # run through the DDPSink backward pass. When not all outputs are
                # used in loss, this makes those corresponding tensors receive
                # undefined gradient which the reducer then handles to ensure
                # param.grad field is not touched and we don't error out.
                passthrough_tensor_list = _DDPSink.apply(
                    self.reducer,
                    state_dict,
                    *output_tensor_list,
                )
                for i in range(len(output_placeholders)):
                    if output_placeholders[i] is None:
                        output_placeholders[i] = passthrough_tensor_list[i]

                # Reconstruct output data structure.
                output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref)
        return output
boson_multimodal/audio_processing/quantization/distrib.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Torch distributed utilities."""
8
+
9
+ import typing as tp
10
+
11
+ import torch
12
+
13
+
14
def rank():
    """Rank of the current process; 0 when torch.distributed is not initialized."""
    return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
19
+
20
+
21
def world_size():
    """Number of distributed workers; 1 when torch.distributed is not initialized."""
    return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
26
+
27
+
28
def is_distributed():
    """Whether more than one distributed worker is participating."""
    # world_size() is always >= 1, so != 1 is equivalent to > 1.
    return world_size() != 1
30
+
31
+
32
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    """All-reduce `tensor` in place across workers; no-op when not distributed."""
    if not is_distributed():
        return None
    return torch.distributed.all_reduce(tensor, op)
35
+
36
+
37
+ def _is_complex_or_float(tensor):
38
+ return torch.is_floating_point(tensor) or torch.is_complex(tensor)
39
+
40
+
41
def _check_number_of_params(params: tp.List[torch.Tensor]):
    # utility function to check that the number of params in all workers is the same,
    # and thus avoid a deadlock with distributed all reduce.
    if not is_distributed() or not params:
        return
    count = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(count)
    if count.item() != len(params) * world_size():
        # If not all the workers have the same number, for at least one of them,
        # this inequality will be verified.
        raise RuntimeError(
            f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one."
        )
55
+
56
+
57
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the tensors from the given parameters to all workers.
    This can be used to ensure that all workers have the same model to start with.
    """
    if not is_distributed():
        return
    # Only float/complex tensors are broadcast; integer buffers are skipped.
    eligible = [t for t in tensors if _is_complex_or_float(t)]
    _check_number_of_params(eligible)
    # Launch all broadcasts asynchronously, then wait for completion.
    pending = [torch.distributed.broadcast(t.data, src=src, async_op=True) for t in eligible]
    for handle in pending:
        handle.wait()
71
+
72
+
73
def sync_buffer(buffers, average=True):
    """
    Sync grad for buffers. If average is False, broadcast instead of averaging.

    Args:
        buffers: Iterable of tensors (e.g. module buffers); only floating-point
            ones are synchronized.
        average: When True, all-reduce(SUM) then divide by the worker count;
            when False, broadcast rank 0's value to everyone.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        if torch.is_floating_point(buffer.data):
            if average:
                handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            else:
                handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
            handles.append((buffer, handle))
    for buffer, handle in handles:
        handle.wait()
        if average:
            # Bug fix: the original divided by `world_size` (the function object
            # itself, a TypeError at runtime). Call it to divide by the actual
            # worker count, matching sync_grad below.
            buffer.data /= world_size()
91
+
92
+
93
def sync_grad(params):
    """
    Simpler alternative to DistributedDataParallel, that doesn't rely
    on any black magic. For simple models it can also be as fast.
    Just call this on your model parameters after the call to backward!
    """
    if not is_distributed():
        return
    # Launch an async SUM all-reduce for every parameter that has a gradient.
    pending = [
        (p, torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True))
        for p in params
        if p.grad is not None
    ]
    # Wait for each reduction, then turn the sum into an average.
    for p, handle in pending:
        handle.wait()
        p.grad.data /= world_size()
109
+
110
+
111
def average_metrics(metrics: tp.Dict[str, float], count=1.0):
    """Average a dictionary of metrics across all workers, using the optional
    `count` as unormalized weight.
    """
    if not is_distributed():
        return metrics
    keys, values = zip(*metrics.items())
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Pack the values plus a trailing 1 (the weight slot), scale by count,
    # sum across workers, then normalize by the total weight.
    packed = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) * count
    all_reduce(packed)
    normalized = (packed[:-1] / packed[-1]).cpu().tolist()
    return dict(zip(keys, normalized))
boson_multimodal/audio_processing/quantization/vq.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Residual vector quantizer implementation."""
8
+
9
+ from dataclasses import dataclass, field
10
+ import math
11
+ import typing as tp
12
+
13
+ import torch
14
+ from torch import nn
15
+
16
+ # from .core_vq import ResidualVectorQuantization
17
+ from .core_vq_lsx_version import ResidualVectorQuantization
18
+
19
+
20
@dataclass
class QuantizedResult:
    # Output bundle of a residual vector quantizer pass.
    quantized: torch.Tensor  # quantized (continuous) representation
    codes: torch.Tensor  # discrete codebook indices per quantizer stage
    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
    penalty: tp.Optional[torch.Tensor] = None  # commitment loss term, if computed
    metrics: dict = field(default_factory=dict)  # extra bookkeeping values
27
+
28
+
29
class ResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer.

    Args:
        dimension (int): Dimension of the codebooks.
        codebook_dim (int): Optional projected codebook dimension.
        n_q (int): Number of residual vector quantizers used.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dimension: int = 256,
        codebook_dim: int = None,
        n_q: int = 8,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.dimension = dimension
        self.codebook_dim = codebook_dim
        self.n_q = n_q
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        # The actual multi-stage quantizer; each stage refines the previous residual.
        self.vq = ResidualVectorQuantization(
            dim=dimension,
            codebook_dim=codebook_dim,
            codebook_size=bins,
            num_quantizers=n_q,
            decay=decay,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )

    def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None):
        """Residual vector quantization on the given input tensor.

        Args:
            x (torch.Tensor): Input tensor.
            sample_rate (int): Sample rate of the input tensor.
            bandwidth (float): Target bandwidth.
        Returns:
            Tuple of (quantized, codes, bandwidth_used, mean commitment loss).
        """
        n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
        bw = torch.tensor(n_q * self.get_bandwidth_per_quantizer(sample_rate)).to(x)
        return quantized, codes, bw, torch.mean(commit_loss)

    def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int:
        """Return n_q based on specified target bandwidth."""
        if bandwidth and bandwidth > 0.0:
            per_q = self.get_bandwidth_per_quantizer(sample_rate)
            return int(max(1, math.floor(bandwidth / per_q)))
        return self.n_q

    def get_bandwidth_per_quantizer(self, sample_rate: int):
        """Return bandwidth per quantizer for a given input sample rate."""
        # log2(bins) bits per frame, sample_rate frames per second, /1000 -> kb/s.
        return math.log2(self.bins) * sample_rate / 1000

    def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth.
        The RVQ encode method sets the appropriate number of quantizer to use
        and returns indices for each quantizer.
        """
        n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
        return self.vq.encode(x, n_q=n_q)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        return self.vq.decode(codes)
+ return quantized
boson_multimodal/audio_processing/semantic_module.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Based on code from: https://github.com/zhenye234/xcodec
2
+ # Licensed under MIT License
3
+ # Modifications by BosonAI
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
class Conv1d1x1(nn.Conv1d):
    """1x1 Conv1d (a per-timestep linear projection over channels)."""

    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__(in_channels, out_channels, kernel_size=1, bias=bias)
14
+
15
+
16
class Conv1d(nn.Module):
    """Conv1d wrapper with automatic "same"-style padding when ``padding < 0``."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = -1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        # A negative padding requests symmetric padding that keeps T unchanged
        # for odd kernels at stride 1.
        effective_padding = padding if padding >= 0 else (kernel_size - 1) // 2 * dilation
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=effective_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Float tensor variable with the shape (B, C, T).
        """
        return self.conv(x)
55
+
56
+
57
class ResidualUnit(nn.Module):
    """Pre-activation residual block: act -> conv(k) -> act -> 1x1 conv, plus skip."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size=3,
        dilation=1,
        bias=False,
        nonlinear_activation="ELU",
        nonlinear_activation_params={},
    ):
        super().__init__()
        # Resolve the activation class by name from torch.nn (e.g. "ELU").
        self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
        self.conv1 = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=dilation,
            bias=bias,
        )
        self.conv2 = Conv1d1x1(out_channels, out_channels, bias)

    def forward(self, x):
        skip = x
        h = self.conv1(self.activation(x))
        h = self.conv2(self.activation(h))
        return skip + h
84
+
85
+
86
class ConvTranspose1d(nn.Module):
    """ConvTranspose1d wrapper that, with default padding arguments, upsamples
    T exactly by ``stride``."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        padding=-1,
        output_padding=-1,
        groups=1,
        bias=True,
    ):
        super().__init__()
        # Negative values request defaults that give T' == T * stride for
        # kernel_size == 2 * stride.
        if padding < 0:
            padding = (stride + 1) // 2
        if output_padding < 0:
            output_padding = 1 if stride % 2 else 0
        self.deconv = nn.ConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Float tensor variable with the shape (B, C', T').
        """
        return self.deconv(x)
123
+
124
+
125
class EncoderBlock(nn.Module):
    """Stack of residual units followed by a strided down-sampling convolution."""

    def __init__(
        self, in_channels: int, out_channels: int, stride: int, dilations=(1, 1), unit_kernel_size=3, bias=True
    ):
        super().__init__()
        self.res_units = torch.nn.ModuleList(
            ResidualUnit(in_channels, in_channels, kernel_size=unit_kernel_size, dilation=d) for d in dilations
        )
        self.num_res = len(self.res_units)

        # special case: stride=1, do not use kernel=2
        down_kernel = 3 if stride == 1 else (2 * stride)
        self.conv = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=down_kernel,
            stride=stride,
            bias=bias,
        )

    def forward(self, x):
        for unit in self.res_units:
            x = unit(x)
        return self.conv(x)
148
+
149
+
150
class Encoder(nn.Module):
    """Conv front-end followed by a sequence of down-sampling EncoderBlocks."""

    def __init__(
        self,
        input_channels: int,
        encode_channels: int,
        channel_ratios=(1, 1),
        strides=(1, 1),
        kernel_size=3,
        bias=True,
        block_dilations=(1, 1),
        unit_kernel_size=3,
    ):
        super().__init__()
        assert len(channel_ratios) == len(strides)

        self.conv = Conv1d(
            in_channels=input_channels, out_channels=encode_channels, kernel_size=kernel_size, stride=1, bias=False
        )
        self.conv_blocks = torch.nn.ModuleList()
        in_ch = encode_channels
        for ratio, stride in zip(channel_ratios, strides):
            out_ch = int(encode_channels * ratio)  # ratio could be float
            self.conv_blocks.append(
                EncoderBlock(
                    in_ch,
                    out_ch,
                    stride,
                    dilations=block_dilations,
                    unit_kernel_size=unit_kernel_size,
                    bias=bias,
                )
            )
            in_ch = out_ch
        self.num_blocks = len(self.conv_blocks)
        self.out_channels = out_ch

    def forward(self, x):
        x = self.conv(x)
        for block in self.conv_blocks:
            x = block(x)
        return x
191
+
192
+
193
class DecoderBlock(nn.Module):
    """Decoder block (no up-sampling)"""

    def __init__(
        self, in_channels: int, out_channels: int, stride: int, dilations=(1, 1), unit_kernel_size=3, bias=True
    ):
        super().__init__()

        if stride == 1:
            # fix kernel=3 when stride=1 for unchanged shape
            self.conv = Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=stride,
                bias=bias,
            )
        else:
            # Transposed conv upsamples T by exactly `stride`.
            self.conv = ConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(2 * stride),
                stride=stride,
                bias=bias,
            )

        self.res_units = torch.nn.ModuleList(
            ResidualUnit(out_channels, out_channels, kernel_size=unit_kernel_size, dilation=d) for d in dilations
        )
        self.num_res = len(self.res_units)

    def forward(self, x):
        x = self.conv(x)
        for unit in self.res_units:
            x = unit(x)
        return x
230
+
231
+
232
class Decoder(nn.Module):
    """Mirror of Encoder: conv front-end, up-sampling DecoderBlocks, conv head."""

    def __init__(
        self,
        code_dim: int,
        output_channels: int,
        decode_channels: int,
        channel_ratios=(1, 1),
        strides=(1, 1),
        kernel_size=3,
        bias=True,
        block_dilations=(1, 1),
        unit_kernel_size=3,
    ):
        super().__init__()
        assert len(channel_ratios) == len(strides)

        self.conv1 = Conv1d(
            in_channels=code_dim,
            out_channels=int(decode_channels * channel_ratios[0]),
            kernel_size=kernel_size,
            stride=1,
            bias=False,
        )

        self.conv_blocks = torch.nn.ModuleList()
        num_stages = len(strides)
        for idx, stride in enumerate(strides):
            in_ch = int(decode_channels * channel_ratios[idx])
            # The last stage maps back to the base decode width.
            if idx + 1 < num_stages:
                out_ch = int(decode_channels * channel_ratios[idx + 1])
            else:
                out_ch = decode_channels
            self.conv_blocks.append(
                DecoderBlock(
                    in_ch,
                    out_ch,
                    stride,
                    dilations=block_dilations,
                    unit_kernel_size=unit_kernel_size,
                    bias=bias,
                )
            )
        self.num_blocks = len(self.conv_blocks)

        self.conv2 = Conv1d(out_ch, output_channels, kernel_size, 1, bias=False)

    def forward(self, z):
        x = self.conv1(z)
        for block in self.conv_blocks:
            x = block(x)
        return self.conv2(x)
boson_multimodal/constants.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
# Placeholder text token standing in for an audio segment consumed as input.
AUDIO_IN_TOKEN = "<|AUDIO|>"
# Placeholder text token standing in for an audio segment to be generated.
AUDIO_OUT_TOKEN = "<|AUDIO_OUT|>"
# End-of-text token string; presumably matches the base tokenizer's EOS — TODO confirm.
EOS_TOKEN = "<|end_of_text|>"
boson_multimodal/data_collator/__init__.py ADDED
File without changes
boson_multimodal/data_collator/higgs_audio_collator.py ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import math
5
+ from typing import List, Tuple
6
+
7
+ from dataclasses import dataclass
8
+ from typing import List, Optional
9
+ from transformers.models.whisper.processing_whisper import WhisperProcessor
10
+
11
+ from ..dataset.chatml_dataset import ChatMLDatasetSample
12
+ from ..model.higgs_audio.utils import build_delay_pattern_mask
13
+
14
+
15
+ def _ceil_to_nearest(n, round_to):
16
+ return (n + round_to - 1) // round_to * round_to
17
+
18
+
19
+ def _ceil_to_next_power_of_two(self, x):
20
+ return 1 if x == 0 else 2 ** (x - 1).bit_length()
21
+
22
+
23
@dataclass
class HiggsAudioBatchInput:
    """Collated batch for the Higgs-Audio model: padded text ids plus
    concatenated audio code/feature segments with per-segment start offsets."""

    input_ids: torch.LongTensor  # shape (bsz, seq_len).
    attention_mask: torch.Tensor  # shape (bsz, seq_len).
    audio_features: Optional[torch.Tensor]  # shape (num_audio_in, feature_dim, max_mel_seq_len).
    audio_feature_attention_mask: Optional[torch.Tensor]  # shape (num_audio_in, max_mel_seq_len).
    audio_out_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_out_total_length)
    audio_out_ids_start: Optional[torch.LongTensor]  # shape (num_audio_out,)
    # The audio_out_ids_start_group_loc has the same length as audio_out_ids_start. It is used to recover group location in a batch for an audio segment
    # Currently, we concatenante audio segments along dim 0 to handle variadic audio segment length. However, in the alignment stage, we need the location information
    # For example,
    # audio_out_ids_start = [0, 2, 4, 8]; and the first two audio segments come from the same sample in a batch, and other two come from different samples.
    # This is a batch of 3 samples, then we will have the group location as:
    # audio_out_ids_start_group_loc = [0, 0, 1, 2]
    audio_out_ids_start_group_loc: Optional[
        torch.LongTensor
    ]  # shape (num_audio_out,), specify which a sample's group location in the batch
    audio_in_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_in_total_length)
    audio_in_ids_start: Optional[torch.LongTensor]  # shape (num_audio_in,)
    label_ids: Optional[torch.LongTensor]  # shape (bsz, seq_len)
    label_audio_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_out_total_length)
    # Optional scalar reward (e.g. for preference/RL-style training) — presumably per batch; TODO confirm against trainer usage.
    reward: Optional[float] = None
45
+
46
+
47
+ class HiggsAudioSampleCollator:
48
+ """Sample collator for Higgs-Audio model.
49
+
50
+ Args:
51
+ whisper_processor (WhisperProcessor): The whisper processor.
52
+ audio_in_token_id (int): The token id for audio-in.
53
+ audio_out_token_id (int): The token id for audio-out.
54
+ pad_token_id (int): The token id for padding.
55
+ audio_stream_bos_id (int): The token id for audio-stream beginning of sentence.
56
+ audio_stream_eos_id (int): The token id for audio-stream end of sentence.
57
+ round_to (int): The round-to value.
58
+ pad_left (bool): Whether to pad left.
59
+ return_audio_in_tokens (bool): Whether to return audio-in tokens.
60
+ use_delay_pattern (bool): Whether to use delay pattern.
61
+ disable_audio_codes_transform (bool): Whether to add bos and eos tokens to audio codes.
62
+ chunk_size_seconds (int): The chunk size in seconds.
63
+ add_new_bos_eos_for_long_chunk (bool): Whether to add new bos and eos tokens for long chunks.
64
+ mask_audio_out_token_label (bool): Whether to always mask the label associated with <|AUDIO_OUT|> token. Since we will always have `<|AUDIO_OUT|>` after `<|audio_bos|>`, we can safely mask <|AUDIO_OUT|>.
65
+
66
+ """
67
+
68
    def __init__(
        self,
        whisper_processor: WhisperProcessor,
        audio_in_token_id,
        audio_out_token_id,
        pad_token_id,
        audio_stream_bos_id,
        audio_stream_eos_id,
        round_to=8,
        pad_left=False,
        encode_whisper_embed=True,
        return_audio_in_tokens=True,
        audio_num_codebooks=None,
        use_delay_pattern=False,
        disable_audio_codes_transform=False,
        chunk_size_seconds=30,  # Maximum duration for each chunk
        add_new_bos_eos_for_long_chunk=True,
        mask_audio_out_token_label=True,
    ):
        """Store collator configuration. See the class docstring for the meaning of each flag."""
        self.whisper_processor = whisper_processor
        self.round_to = round_to
        self.pad_left = pad_left
        self.audio_in_token_id = audio_in_token_id
        self.audio_out_token_id = audio_out_token_id
        self.audio_stream_bos_id = audio_stream_bos_id
        self.audio_stream_eos_id = audio_stream_eos_id
        self.pad_token_id = pad_token_id
        self.encode_whisper_embed = encode_whisper_embed
        self.return_audio_in_tokens = return_audio_in_tokens
        self.audio_num_codebooks = audio_num_codebooks
        self.use_delay_pattern = use_delay_pattern
        if encode_whisper_embed:
            # Chunking is only meaningful when we actually extract Whisper features:
            # the chunk length (in samples) is derived from the feature extractor's
            # sampling rate so each chunk fits the Whisper input window.
            self.chunk_size_seconds = chunk_size_seconds
            self.chunk_size_samples = int(chunk_size_seconds * whisper_processor.feature_extractor.sampling_rate)
        else:
            # No whisper embedding -> long-audio chunking is disabled entirely.
            self.chunk_size_seconds = None
            self.chunk_size_samples = None
        self.disable_audio_codes_transform = disable_audio_codes_transform
        self.add_new_bos_eos_for_long_chunk = add_new_bos_eos_for_long_chunk
        self.mask_audio_out_token_label = mask_audio_out_token_label
108
+
109
+ def _process_and_duplicate_audio_tokens(
110
+ self, input_ids: torch.Tensor, audio_idx: int, wv: torch.Tensor, sr: int, labels: Optional[torch.Tensor] = None
111
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
112
+ """Process long audio and duplicate corresponding audio tokens.
113
+
114
+ Args:
115
+ input_ids: Input token ids
116
+ audio_idx: Index of the audio token in the sequence
117
+ wv: Audio waveform
118
+ sr: Sample rate
119
+ labels: Optional label ids to be duplicated alongside input ids
120
+
121
+ Returns:
122
+ Tuple of:
123
+ - New input ids with duplicated audio tokens
124
+ - New label ids (if labels were provided) or None
125
+ - Number of chunks created
126
+ """
127
+ # Calculate number of chunks needed
128
+ total_samples = len(wv)
129
+ num_chunks = math.ceil(total_samples / self.chunk_size_samples)
130
+
131
+ if num_chunks <= 1:
132
+ return input_ids, labels, 1
133
+
134
+ # Get the three tokens: <|audio_bos|><|AUDIO|><|audio_eos|>
135
+ audio_token_seq = input_ids[audio_idx - 1 : audio_idx + 2]
136
+ # Duplicate sequence for each chunk
137
+ duplicated_sequence = audio_token_seq.repeat(num_chunks)
138
+
139
+ # Create new input_ids with duplicated tokens
140
+ new_input_ids = torch.cat([input_ids[: audio_idx - 1], duplicated_sequence, input_ids[audio_idx + 2 :]])
141
+
142
+ # If labels are provided, duplicate them as well
143
+ new_labels = None
144
+ if labels is not None:
145
+ label_seq = labels[audio_idx - 1 : audio_idx + 2]
146
+ duplicated_labels = label_seq.repeat(num_chunks)
147
+ new_labels = torch.cat([labels[: audio_idx - 1], duplicated_labels, labels[audio_idx + 2 :]])
148
+
149
+ return new_input_ids, new_labels, num_chunks
150
+
151
    def __call__(self, batch: List[ChatMLDatasetSample]):
        """Collate the input data with support for long audio processing.

        Pipeline: (1) optionally split long input audio into Whisper-sized chunks and
        duplicate the corresponding audio tokens; (2) gather audio-in/audio-out code
        sequences per sample; (3) extract Whisper features for audio-in waveforms;
        (4) wrap code sequences with stream bos/eos (and optional delay pattern);
        (5) pad text ids/labels/attention mask to a rounded max length.

        Returns a HiggsAudioBatchInput with the collated tensors.
        """

        label_ids = None
        label_audio_ids = None
        # Labels are only produced when at least one sample carries label ids.
        if all([ele.label_ids is None for ele in batch]):
            return_labels = False
        else:
            return_labels = True

        if self.encode_whisper_embed:
            # Process each sample in the batch to handle long audio
            # TODO(?) The implementation here can be optimized.
            processed_batch = []
            for i in range(len(batch)):
                sample = batch[i]
                audio_in_mask = sample.input_ids == self.audio_in_token_id
                audio_in_indices = torch.where(audio_in_mask)[0]
                audio_out_mask = sample.input_ids == self.audio_out_token_id

                # Process each audio token and duplicate if needed
                modified_input_ids = sample.input_ids
                modified_labels = sample.label_ids if return_labels else None
                modified_waveforms_concat = []
                modified_waveforms_start = []
                modified_sample_rate = []
                offset = 0  # Track position changes from duplicating tokens
                curr_wv_offset = 0

                # Process input audio tokens
                for idx, audio_idx in enumerate(audio_in_indices):
                    # Get the audio for this token
                    wv, sr = sample.get_wv(idx)  # Use idx since we want the original audio index
                    # Resample to the Whisper feature extractor's rate if needed.
                    if sr != self.whisper_processor.feature_extractor.sampling_rate:
                        resampled_wv = librosa.resample(
                            wv.cpu().numpy(),
                            orig_sr=sr,
                            target_sr=self.whisper_processor.feature_extractor.sampling_rate,
                        )
                    else:
                        resampled_wv = wv.cpu().numpy()
                    wv = torch.tensor(resampled_wv, device=wv.device)
                    sr = self.whisper_processor.feature_extractor.sampling_rate

                    # Process and duplicate tokens if necessary
                    # `offset` re-maps the original token position to its position in
                    # `modified_input_ids`, which grows as earlier audios are chunked.
                    token_pos = audio_idx + offset
                    modified_input_ids, modified_labels, num_chunks = self._process_and_duplicate_audio_tokens(
                        modified_input_ids, token_pos, wv, sr, modified_labels
                    )

                    # Update audio data
                    for chunk_idx in range(num_chunks):
                        chunk_start = chunk_idx * self.chunk_size_samples
                        chunk_end = min((chunk_idx + 1) * self.chunk_size_samples, len(wv))
                        chunk_wv = wv[chunk_start:chunk_end]
                        modified_waveforms_concat.append(chunk_wv)
                        modified_waveforms_start.append(curr_wv_offset)
                        curr_wv_offset += len(chunk_wv)
                        modified_sample_rate.append(sr)

                    # Update offset for next iteration
                    offset += (num_chunks - 1) * 3  # Each new chunk adds 3 more tokens

                # Create new sample with modified tokens and audio data
                processed_sample = ChatMLDatasetSample(
                    input_ids=modified_input_ids,
                    label_ids=modified_labels if return_labels else sample.label_ids,
                    audio_ids_concat=sample.audio_ids_concat,
                    audio_ids_start=sample.audio_ids_start,
                    audio_waveforms_concat=torch.cat(modified_waveforms_concat)
                    if modified_waveforms_concat
                    else sample.audio_waveforms_concat,
                    audio_waveforms_start=torch.tensor(modified_waveforms_start, dtype=torch.long)
                    if modified_waveforms_start
                    else sample.audio_waveforms_start,
                    audio_sample_rate=torch.tensor(modified_sample_rate)
                    if modified_sample_rate
                    else sample.audio_sample_rate,
                    audio_speaker_indices=torch.tensor([]),
                    # FIXME(sxjscience): The logic here is not correct for audio_label_ids_concat.
                    audio_label_ids_concat=sample.audio_label_ids_concat,
                )
                # audio_in_chunk_len = len(torch.where(modified_input_ids == self.audio_in_token_id)[0])
                # assert audio_in_chunk_len == processed_sample.num_audios(), f"Mismatch: audio_in_chunk_len={audio_in_chunk_len}, processed_sample.num_audios()={processed_sample.num_audios()}"
                processed_batch.append(processed_sample)
        else:
            processed_batch = batch

        # Get the max sequence length based on processed batch
        max_seq_length = _ceil_to_nearest(max([len(sample.input_ids) for sample in processed_batch]), self.round_to)

        # Get the ids for audio-in and audio-out for each batch
        audio_in_wv_l = []
        audio_in_ids_l = []
        audio_out_ids_l = []
        audio_out_ids_group_loc_l = []
        audio_in_label_ids_l = None
        audio_out_label_ids_l = None
        reward_l = []

        if return_labels:
            audio_out_no_train_flag = []  # Whether the audio-out data should be trained on or not.

        # Process the audio inputs and outputs
        for i in range(len(processed_batch)):
            audio_in_mask = processed_batch[i].input_ids == self.audio_in_token_id
            audio_out_mask = processed_batch[i].input_ids == self.audio_out_token_id
            # Number audio placeholders 0..N-1 in order of appearance; positions
            # that are neither audio-in nor audio-out keep the value 1 (unused).
            audio_ids = torch.ones_like(processed_batch[i].input_ids)
            audio_ids[audio_in_mask ^ audio_out_mask] = torch.cumsum(audio_ids[audio_in_mask ^ audio_out_mask], 0) - 1
            audio_in_ids = audio_ids[audio_in_mask]
            audio_out_ids = audio_ids[audio_out_mask]

            if return_labels:
                # A negative label at the audio-out position marks "do not train".
                audio_out_no_train_flag.append(processed_batch[i].label_ids[audio_out_mask] < 0)
                if self.mask_audio_out_token_label:
                    # In-place: mask the <|AUDIO_OUT|> text-label positions.
                    processed_batch[i].label_ids[audio_out_mask] = -100

            # Process audio inputs
            if self.return_audio_in_tokens:
                audio_in_ids_l.extend(
                    [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_in_ids]
                )
                if processed_batch[i].audio_label_ids_concat is not None:
                    if audio_in_label_ids_l is None:
                        audio_in_label_ids_l = []
                    audio_in_label_ids_l.extend(
                        [
                            processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
                            for idx in audio_in_ids
                        ]
                    )

            audio_out_ids_l.extend(
                [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids]
            )
            audio_out_ids_group_loc_l.append(i)
            if processed_batch[i].reward is not None:
                reward_l.append(processed_batch[i].reward)

            if processed_batch[i].audio_label_ids_concat is not None:
                if audio_out_label_ids_l is None:
                    audio_out_label_ids_l = []
                audio_out_label_ids_l.extend(
                    [
                        processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
                        for idx in audio_out_ids
                    ]
                )

            if self.encode_whisper_embed:
                # Collect raw waveforms (already resampled above) for feature extraction.
                for idx in audio_in_ids:
                    wv, sr = processed_batch[i].get_wv(idx)
                    resampled_wv = wv.cpu().numpy()
                    # Split long audio into chunks
                    total_samples = len(resampled_wv)
                    for chunk_start in range(0, total_samples, self.chunk_size_samples):
                        chunk_end = min(chunk_start + self.chunk_size_samples, total_samples)
                        chunk = resampled_wv[chunk_start:chunk_end]
                        audio_in_wv_l.append(chunk)
                # assert len(audio_in_wv_l) == processed_batch[i].num_audios(), \
                #     f"Assertion failed: Mismatch in number of audios. " \
                #     f"Expected {processed_batch[i].num_audios()}, but got {len(audio_in_wv_l)} at index {i}."

        if return_labels:
            audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0)

        # Process all audio features
        if len(audio_in_wv_l) > 0:
            feature_ret = self.whisper_processor.feature_extractor(
                audio_in_wv_l,
                sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
                return_attention_mask=True,
                padding="max_length",
            )
            audio_features = torch.from_numpy(feature_ret["input_features"])
            audio_feature_attention_mask = torch.from_numpy(feature_ret["attention_mask"])
        else:
            if self.encode_whisper_embed:
                # Keep tensor shapes well-defined even with zero audios.
                audio_features = torch.zeros(
                    (
                        0,
                        self.whisper_processor.feature_extractor.feature_size,
                        self.whisper_processor.feature_extractor.nb_max_frames,
                    ),
                    dtype=torch.float32,
                )
                audio_feature_attention_mask = torch.zeros(
                    (0, self.whisper_processor.feature_extractor.nb_max_frames), dtype=torch.int32
                )
            else:
                audio_features = None
                audio_feature_attention_mask = None

        # Process audio input tokens
        if len(audio_in_ids_l) > 0:
            # Append audio-stream-bos and eos tokens
            new_audio_in_ids_l = []
            for ele in audio_in_ids_l:
                if self.disable_audio_codes_transform:
                    # Do not add audio-stream-bos or eos tokens.
                    # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
                    audio_codes = ele
                else:
                    audio_codes = torch.cat(
                        [
                            torch.full((ele.shape[0], 1), self.audio_stream_bos_id, dtype=torch.long),
                            ele,
                            torch.full((ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long),
                        ],
                        dim=1,
                    )
                if self.use_delay_pattern:
                    audio_codes = build_delay_pattern_mask(
                        audio_codes.unsqueeze(0),
                        bos_token_id=self.audio_stream_bos_id,
                        pad_token_id=self.audio_stream_eos_id,
                    )[0].squeeze(0)
                new_audio_in_ids_l.append(audio_codes)
            audio_in_ids = torch.cat(new_audio_in_ids_l, dim=1).long()
            # Start offsets of each audio's codes within the concatenated tensor.
            audio_in_ids_start = torch.cumsum(
                torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_in_ids_l[:-1]]), dim=0
            )
        else:
            audio_in_ids = torch.zeros((0, 0), dtype=torch.long)
            audio_in_ids_start = torch.zeros(0, dtype=torch.long)

        # Process audio output tokens
        audio_out_ids_start_group_loc = None
        if len(audio_out_ids_l) > 0:
            new_audio_out_ids_l = []
            label_audio_ids_l = []
            for idx, ele in enumerate(audio_out_ids_l):
                if self.disable_audio_codes_transform:
                    # Do not add audio-stream-bos or eos tokens.
                    # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
                    audio_codes = ele
                    if return_labels:
                        label_audio_ids = audio_out_label_ids_l[idx]
                else:
                    audio_codes = torch.cat(
                        [
                            torch.full((ele.shape[0], 1), self.audio_stream_bos_id, dtype=torch.long),
                            ele,
                            torch.full((ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long),
                        ],
                        dim=1,
                    )
                    if return_labels:
                        # Labels: bos position is masked (-100); eos is kept as a target.
                        label_audio_ids = torch.cat(
                            [
                                torch.full((ele.shape[0], 1), -100, dtype=torch.long),
                                ele,
                                torch.full((ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long),
                            ],
                            dim=1,
                        )
                if self.use_delay_pattern:
                    audio_codes = build_delay_pattern_mask(
                        audio_codes.unsqueeze(0),
                        bos_token_id=self.audio_stream_bos_id,
                        pad_token_id=self.audio_stream_eos_id,
                    )[0].squeeze(0)
                    if return_labels:
                        label_audio_ids = build_delay_pattern_mask(
                            label_audio_ids.unsqueeze(0),
                            bos_token_id=-100,
                            pad_token_id=-100,
                        )[0].squeeze(0)
                new_audio_out_ids_l.append(audio_codes)

                if return_labels:
                    if audio_out_no_train_flag[idx]:
                        # Entire audio is excluded from training.
                        label_audio_ids[:] = -100
                    label_audio_ids_l.append(label_audio_ids)

            audio_out_ids = torch.cat(new_audio_out_ids_l, dim=1).long()
            if return_labels:
                label_audio_ids = torch.cat(label_audio_ids_l, dim=1).long()
            audio_out_ids_start = torch.cumsum(
                torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_out_ids_l[:-1]]), dim=0
            )
            audio_out_ids_start_group_loc = torch.tensor(audio_out_ids_group_loc_l, dtype=torch.long)
        else:
            audio_out_ids = torch.zeros((0, 0), dtype=torch.long)
            audio_out_ids_start = torch.zeros(0, dtype=torch.long)
            if return_labels:
                label_audio_ids = torch.zeros((0, 0), dtype=torch.long)

        reward = torch.tensor(reward_l, dtype=torch.float32)

        # Handle padding for input ids and attention mask
        if self.pad_left:
            input_ids = torch.stack(
                [
                    F.pad(ele.input_ids, (max_seq_length - len(ele.input_ids), 0), value=self.pad_token_id)
                    for ele in processed_batch
                ]
            )
            if return_labels:
                label_ids = torch.stack(
                    [
                        F.pad(ele.label_ids, (max_seq_length - len(ele.label_ids), 0), value=-100)
                        for ele in processed_batch
                    ]
                )
            attention_mask = torch.stack(
                [
                    F.pad(torch.ones_like(ele.input_ids), (max_seq_length - len(ele.input_ids), 0), value=0)
                    for ele in processed_batch
                ]
            )
        else:
            input_ids = torch.stack(
                [
                    F.pad(ele.input_ids, (0, max_seq_length - len(ele.input_ids)), value=self.pad_token_id)
                    for ele in processed_batch
                ]
            )
            if return_labels:
                label_ids = torch.stack(
                    [
                        F.pad(ele.label_ids, (0, max_seq_length - len(ele.label_ids)), value=-100)
                        for ele in processed_batch
                    ]
                )
            attention_mask = torch.stack(
                [
                    F.pad(torch.ones_like(ele.input_ids), (0, max_seq_length - len(ele.input_ids)), value=0)
                    for ele in processed_batch
                ]
            )

        if not self.return_audio_in_tokens:
            audio_in_ids = None
            audio_in_ids_start = None

        # Apply audio_num_codebooks limit if specified
        if self.audio_num_codebooks is not None:
            if audio_in_ids is not None:
                audio_in_ids = audio_in_ids[: self.audio_num_codebooks]
            if audio_out_ids is not None:
                audio_out_ids = audio_out_ids[: self.audio_num_codebooks]
            if label_audio_ids is not None:
                label_audio_ids = label_audio_ids[: self.audio_num_codebooks]

        return HiggsAudioBatchInput(
            input_ids=input_ids,
            attention_mask=attention_mask,
            audio_features=audio_features,
            audio_feature_attention_mask=audio_feature_attention_mask,
            audio_out_ids=audio_out_ids,
            audio_out_ids_start=audio_out_ids_start,
            audio_out_ids_start_group_loc=audio_out_ids_start_group_loc,
            audio_in_ids=audio_in_ids,
            audio_in_ids_start=audio_in_ids_start,
            label_ids=label_ids,
            label_audio_ids=label_audio_ids,
            reward=reward,
        )
boson_multimodal/data_types.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Basic data types for multimodal ChatML format."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, List, Optional, Union
5
+
6
+
7
@dataclass
class AudioContent:
    """Audio payload of a ChatML message.

    Either `audio_url` points at the audio resource, or `raw_audio` carries the
    audio inline as a base64-encoded string.
    """

    audio_url: str
    # Base64 encoded audio bytes
    raw_audio: Optional[str] = None
    offset: Optional[float] = None  # presumably seconds into the source audio — TODO confirm against callers
    duration: Optional[float] = None  # presumably seconds — TODO confirm against callers
    row_id: Optional[int] = None
    type: str = "audio"  # Discriminator used when round-tripping through dicts.
16
+
17
+
18
@dataclass
class TextContent:
    """Plain-text payload of a ChatML message."""

    text: str
    type: str = "text"  # Discriminator used when round-tripping through dicts.
22
+
23
+
24
@dataclass
class Message:
    """One ChatML turn: a role plus text and/or audio content (single item or list)."""

    role: str
    content: Union[str, AudioContent, TextContent, List[Union[str, AudioContent, TextContent]]]
    recipient: Optional[str] = None
29
+
30
+
31
@dataclass
class ChatMLSample:
    """Dataclass to hold multimodal ChatML data."""

    messages: List[Message]
    start_index: Optional[int] = None  # We will mask the messages[:start_index] when finetuning the LLM.
    misc: Optional[Dict] = None  # Free-form extras; downstream code reads e.g. misc["speaker"].
    speaker: Optional[str] = None
boson_multimodal/dataset/__init__.py ADDED
File without changes
boson_multimodal/dataset/chatml_dataset.py ADDED
@@ -0,0 +1,533 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dacite
2
+ import pandas as pd
3
+ import torch
4
+ import json
5
+
6
+ import numpy as np
7
+ import multiprocessing as mp
8
+
9
+ from dataclasses import dataclass, fields
10
+ from abc import ABC, abstractmethod
11
+ from typing import Union, List, Dict, Optional
12
+
13
+ from ..data_types import ChatMLSample, TextContent, AudioContent
14
+ from ..constants import AUDIO_IN_TOKEN, AUDIO_OUT_TOKEN
15
+
16
+ from loguru import logger
17
+
18
+ # Whisper processor, 30 sec -> 3000 features
19
+ # Then we divide 4 in the audio towker, we decrease 3000 features to 750, which gives 25 Hz
20
+ WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC = 25
21
+
22
+
23
@dataclass
class ChatMLDatasetSample:
    """A tokenized multimodal sample: text token ids plus concatenated audio codes/waveforms.

    Per-audio data is stored as one concatenated tensor plus a start-offset tensor;
    audio `i` spans `start[i]` up to `start[i + 1]` (or the end for the last audio).
    """

    input_ids: torch.LongTensor  # Shape (seq_len,): The input text tokens.
    label_ids: torch.LongTensor  # Shape (seq_len,): The label ids.
    audio_ids_concat: torch.LongTensor  # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated.
    # Here `audio_seq_len` is the length of the concatenated audio tokens.`
    audio_ids_start: (
        torch.LongTensor
    )  # Shape (num_audios,): The start index of each audio token in the concatenated audio tokens.
    audio_waveforms_concat: (
        torch.Tensor
    )  # Shape (total_wv_length,): The concatenated audio waveforms for audio-in features.
    audio_waveforms_start: (
        torch.LongTensor
    )  # Shape (num_audios,): The start index of each audio waveform in the concatenated audio waveforms.
    audio_sample_rate: torch.Tensor  # Shape (num_audios,): The sampling rate of the audio waveforms.
    audio_speaker_indices: (
        torch.LongTensor
    )  # Shape (num_audios,) -1 means unknown speaker: The speaker indices for each audio.
    audio_label_ids_concat: Optional[torch.LongTensor] = (
        None  # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated.
    )
    # Here `audio_seq_len` is the length of the concatenated audio tokens.`
    reward: Optional[float] = None

    def num_audios(self):
        """Number of audios: max of waveform count and code-sequence count (they may differ)."""
        return max(len(self.audio_waveforms_start), len(self.audio_ids_start))

    def get_audio_codes(self, idx):
        """Return the code slice (num_codebooks, len) for audio `idx`."""
        code_start = self.audio_ids_start[idx]
        # The last audio extends to the end of the concatenated tensor.
        if idx < len(self.audio_ids_start) - 1:
            code_end = self.audio_ids_start[idx + 1]
        else:
            code_end = self.audio_ids_concat.shape[-1]

        return self.audio_ids_concat[:, code_start:code_end]

    def get_audio_codes_labels(self, idx):
        """Return the label-code slice for audio `idx`, or None when no audio labels exist."""
        if self.audio_label_ids_concat is None:
            return None
        code_start = self.audio_ids_start[idx]
        if idx < len(self.audio_ids_start) - 1:
            code_end = self.audio_ids_start[idx + 1]
        else:
            code_end = self.audio_ids_concat.shape[-1]

        return self.audio_label_ids_concat[:, code_start:code_end]

    def get_wv(self, idx):
        """Return `(waveform, sample_rate)` for audio `idx`."""
        wv_start = self.audio_waveforms_start[idx]
        sr = self.audio_sample_rate[idx]
        if idx < len(self.audio_waveforms_start) - 1:
            wv_end = self.audio_waveforms_start[idx + 1]
        else:
            wv_end = self.audio_waveforms_concat.shape[-1]
        return self.audio_waveforms_concat[wv_start:wv_end], sr

    def cal_num_tokens(
        self,
        encode_whisper_embed: bool = True,
        encode_audio_in_tokens: bool = False,
        encode_audio_out_tokens: bool = True,
        audio_in_token_id: int = 128015,
        audio_out_token_id: int = 128016,
    ) -> int:
        """Estimate the total token count of this sample after audio expansion.

        Counts text tokens (audio placeholders excluded), plus estimated Whisper
        embedding positions per waveform, plus (optionally) audio-in/out code lengths.
        """
        # we firstly exclude <|AUDIO|> and <|AUDIO_OUT|> because we do late merging and replace those position with actual audio features and audio token ids
        # It's assumed that we always have audio_ids when audio_waveforms are there (but not vice-versa)
        num_tokens = len(self.input_ids) - len(self.audio_ids_start)

        if encode_whisper_embed and len(self.audio_waveforms_concat) > 0:
            # Waveform lengths for all but the last audio come from consecutive offsets.
            audio_lengths = torch.diff(self.audio_waveforms_start)
            if len(audio_lengths):
                # Sum before calling .item()
                num_tokens += (
                    (
                        np.ceil(WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC * audio_lengths / self.audio_sample_rate[:-1])
                    ).sum()
                ).item()
            # add the last audio's token estimation
            num_tokens += (
                np.ceil(
                    WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC
                    * (self.audio_waveforms_concat.shape[0] - self.audio_waveforms_start[-1])
                    / self.audio_sample_rate[-1]
                )
            ).item()

        if self.audio_ids_concat.size(1) > 0:
            # Placeholder tokens in order of appearance; pairs each with its code length.
            audio_io_ids = self.input_ids[
                (self.input_ids == audio_in_token_id) | (self.input_ids == audio_out_token_id)
            ]
            audio_io_id_lengths = torch.concat(
                [
                    torch.diff(self.audio_ids_start),
                    torch.tensor([self.audio_ids_concat.shape[-1] - self.audio_ids_start[-1]]),
                ]
            )
            if encode_audio_in_tokens:
                num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_in_token_id]).item()

            if encode_audio_out_tokens:
                num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_out_token_id]).item()

        return int(num_tokens)

    @classmethod
    def merge(
        cls,
        samples: List["ChatMLDatasetSample"],
        eos_token_id: int,
        ignore_index: int,
        padding_size: Optional[int] = None,
    ) -> "ChatMLDatasetSample":
        """Merges a list of ChatMLDatasetSample instances, inserting eos_token_id and ignore_index between them, and adjusting offsets for audio_ids_start and audio_waveforms_start.

        Args:
            samples (List[ChatMLDatasetSample]): List of samples to merge.
            eos_token_id (int): Tokens to be inserted into input_ids between samples.
            ignore_index (int): Default label for padding.
            padding_size (Optional[int]): If provided, pad the sequence to with this length.

        Returns:
            ChatMLDatasetSample: Merged and potentially padded sample.
        """
        if not samples:
            logger.fatal("The samples list is empty and cannot be merged.")
            raise ValueError("The samples list is empty and cannot be merged.")

        # Initialize empty lists for concatenation
        input_ids_list = []
        label_ids_list = []
        audio_ids_concat_list = []
        audio_ids_start_list = []
        audio_waveforms_concat_list = []
        audio_waveforms_start_list = []
        audio_sample_rate_list = []
        audio_speaker_indices_list = []

        # Track offsets
        audio_ids_offset = 0
        audio_waveforms_offset = 0

        for sample in samples:
            # Add input_ids and label_ids with padding
            # A separator (eos / ignore) goes between consecutive samples, not before the first.
            if input_ids_list:
                input_ids_list.append(torch.tensor([eos_token_id], dtype=torch.long))
                label_ids_list.append(torch.tensor([ignore_index], dtype=torch.long))
            input_ids_list.append(sample.input_ids)
            label_ids_list.append(sample.label_ids)

            # Add audio_ids_concat and handle empty audio ids
            if sample.audio_ids_concat.size(1) > 0:
                audio_ids_concat_list.append(sample.audio_ids_concat)

                # Offset and add audio_ids_start
                audio_ids_start_list.append(sample.audio_ids_start + audio_ids_offset)
                audio_ids_offset += sample.audio_ids_concat.size(
                    1
                )  # (num_codebooks, seq_len): Update offset by audio_seq_len

            # Add audio_waveforms_concat
            if sample.audio_waveforms_concat.size(0) > 0:
                # Check dimensions of the audio waveform to ensure consistency
                # NOTE: `continue` also skips this sample's sample-rate and
                # speaker-index appends below.
                if (
                    audio_waveforms_concat_list
                    and sample.audio_waveforms_concat.dim() != audio_waveforms_concat_list[0].dim()
                ):
                    logger.warning(
                        f"Skipping audio waveform with inconsistent dimensions: expected {audio_waveforms_concat_list[0].dim()}D, got {sample.audio_waveforms_concat.dim()}D"
                    )
                    continue

                audio_waveforms_concat_list.append(sample.audio_waveforms_concat)
                audio_waveforms_start_list.append(sample.audio_waveforms_start + audio_waveforms_offset)
                audio_waveforms_offset += sample.audio_waveforms_concat.size(0)

            # Add audio_sample_rate and audio_speaker_indices
            audio_sample_rate_list.append(sample.audio_sample_rate)

            audio_speaker_indices_list.append(sample.audio_speaker_indices)

        # Concatenate all tensors
        input_ids = torch.cat(input_ids_list, dim=0)
        label_ids = torch.cat(label_ids_list, dim=0)

        # Apply padding if padding_size is specified
        if padding_size is not None and padding_size > 0:
            input_ids = torch.cat([input_ids, torch.full((padding_size,), eos_token_id, dtype=torch.long)], dim=0)
            label_ids = torch.cat([label_ids, torch.full((padding_size,), ignore_index, dtype=torch.long)], dim=0)

        # Safely concatenate audio tensors with proper error handling
        try:
            audio_ids_concat = torch.cat(audio_ids_concat_list, dim=1) if audio_ids_concat_list else torch.tensor([[]])
            audio_ids_start = torch.cat(audio_ids_start_list, dim=0) if audio_ids_start_list else torch.tensor([])

            # Check for dimensional consistency in audio waveforms
            if audio_waveforms_concat_list:
                dims = [t.dim() for t in audio_waveforms_concat_list]
                if not all(d == dims[0] for d in dims):
                    # If dimensions don't match, log warning and filter out the problematic tensors
                    logger.warning(
                        f"Inconsistent dimensions in audio waveforms: {dims}. Filtering to keep only consistent ones."
                    )
                    expected_dim = max(set(dims), key=dims.count)  # Most common dimension
                    audio_waveforms_concat_list = [t for t in audio_waveforms_concat_list if t.dim() == expected_dim]

                    # Recalculate audio_waveforms_start with the filtered list
                    # NOTE(review): this rebuild assumes one start offset per kept
                    # tensor, which may not match the original per-sample starts.
                    if audio_waveforms_concat_list:
                        audio_waveforms_offset = 0
                        audio_waveforms_start_list = []
                        for waveform in audio_waveforms_concat_list:
                            audio_waveforms_start_list.append(torch.tensor([audio_waveforms_offset]))
                            audio_waveforms_offset += waveform.size(0)

            audio_waveforms_concat = (
                torch.cat(audio_waveforms_concat_list, dim=0) if audio_waveforms_concat_list else torch.tensor([])
            )
            audio_waveforms_start = (
                torch.cat(audio_waveforms_start_list, dim=0) if audio_waveforms_start_list else torch.tensor([])
            )
            audio_sample_rate = (
                torch.cat(audio_sample_rate_list, dim=0) if audio_sample_rate_list else torch.tensor([])
            )
            audio_speaker_indices = (
                torch.cat(audio_speaker_indices_list, dim=0) if audio_speaker_indices_list else torch.tensor([])
            )

        except RuntimeError as e:
            logger.error(f"Error during tensor concatenation: {str(e)}")
            logger.warning("Falling back to empty audio tensors")
            # Fall back to empty tensors
            audio_ids_concat = torch.tensor([[]])
            audio_ids_start = torch.tensor([])
            audio_waveforms_concat = torch.tensor([])
            audio_waveforms_start = torch.tensor([])
            audio_sample_rate = torch.tensor([])
            audio_speaker_indices = torch.tensor([])

        # Create the merged sample
        merged_sample = cls(
            input_ids=input_ids,
            label_ids=label_ids,
            audio_ids_concat=audio_ids_concat,
            audio_ids_start=audio_ids_start,
            audio_waveforms_concat=audio_waveforms_concat,
            audio_waveforms_start=audio_waveforms_start,
            audio_sample_rate=audio_sample_rate,
            audio_speaker_indices=audio_speaker_indices,
        )

        return merged_sample
274
+
275
+
276
@dataclass
class RankedChatMLDatasetSampleTuple:
    """A group of candidate samples paired with per-sample preference scores."""

    samples: List[ChatMLDatasetSample]
    scores: List[float]

    def _select(self, pick) -> ChatMLDatasetSample:
        # Find the index chosen by `pick` (min or max) over the scores; ties
        # resolve to the earliest index, matching list.index semantics.
        best = pick(range(len(self.scores)), key=self.scores.__getitem__)
        chosen = self.samples[best]
        chosen.reward = self.scores[best]
        return chosen

    def max_score_sample(self) -> ChatMLDatasetSample:
        """Return the highest-scored sample, stamping its score onto `reward`."""
        return self._select(max)

    def min_score_sample(self) -> ChatMLDatasetSample:
        """Return the lowest-scored sample, stamping its score onto `reward`."""
        return self._select(min)
290
+
291
+
292
@dataclass
class ChatMLDatasetStorageSample:
    """On-disk representation of a tokenized sample with indices into audio caches.

    NOTE(review): field meanings below are inferred from names — confirm against
    the dataset-writing code.
    """

    input_tokens: torch.LongTensor
    label_tokens: torch.LongTensor
    audio_bytes_cache_dir_index: int  # presumably which audio-bytes cache dir holds this sample's audio
    audio_codes_cache_dir_index: int  # presumably which audio-codes cache dir holds this sample's codes
    audio_bytes_indices: torch.LongTensor
    audio_codes_indices: torch.LongTensor
    speaker_indices: torch.LongTensor
    file_index: int
    original_sample_index: int
303
+
304
+
305
+ # TODO(sxjscience): We need to revist the logic about parsing speaker ids.
306
+ # Currently, we assume that the speaker id is stored at the "misc" field in ChatMLSample.
307
def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
    """Preprocess a ChatML sample into token ids for the text part.

    Dict inputs (e.g. DataFrame rows) are first cleaned of NaN values and converted
    to `ChatMLSample` via dacite. `label_tokens` mirrors `input_tokens`, with -100 at
    every position that should not contribute to the loss (non-assistant turns and
    assistant turns before `sample.start_index`).

    Args:
        sample (Union[ChatMLSample, Dict]): The ChatML sample to preprocess.
        tokenizer: The tokenizer used for encoding text via
            ``encode(..., add_special_tokens=False)``.

    Returns:
        Tuple ``(input_tokens, label_tokens, audio_contents, speaker_id)``, or
        ``(None, None, None, None)`` when the sample cannot be parsed.
    """
    try:
        if not isinstance(sample, ChatMLSample):
            # Handle all fields that could be NaN
            if "speaker" in sample and pd.isna(sample["speaker"]):
                sample["speaker"] = None
            if "start_index" in sample and pd.isna(sample["start_index"]):
                sample["start_index"] = None
            if "content" in sample and pd.isna(sample["content"]):
                sample["content"] = ""

            # Convert any other potential NaN values in nested structures
            def convert_nan_to_none(obj):
                import numpy as np

                if isinstance(obj, (pd.Series, np.ndarray)):
                    return obj.tolist()
                elif pd.api.types.is_scalar(obj) and pd.isna(obj):
                    return None
                elif isinstance(obj, dict):
                    return {k: convert_nan_to_none(v) for k, v in obj.items()}
                elif isinstance(obj, (list, tuple)):  # Handle both list and tuple
                    return [convert_nan_to_none(item) for item in obj]
                return obj

            # Clean the sample data and keep only the keys that ChatMLSample declares.
            clean_sample = convert_nan_to_none(sample)
            val_keys = [field.name for field in fields(ChatMLSample) if field.name in clean_sample]
            clean_sample = {k: clean_sample[k] for k in val_keys}

            try:
                sample = dacite.from_dict(
                    data_class=ChatMLSample, data=clean_sample, config=dacite.Config(strict=True, check_types=True)
                )
            except Exception as e:
                print(f"Failed to convert to ChatMLSample: {e}")
                # default=str keeps this diagnostic from raising on values that are not
                # JSON-serializable (which would mask the original conversion error).
                print(f"Clean sample: {json.dumps(clean_sample, indent=2, default=str)}")
                return None, None, None, None

        input_tokens = []
        label_tokens = []
        audio_contents = []

        # Speaker id: prefer the explicit field, then fall back to the "misc" dict.
        speaker_id = None
        if sample.speaker is not None:
            speaker_id = sample.speaker
        elif sample.misc is not None:
            if "speaker" in sample.misc:
                speaker_id = sample.misc["speaker"]

        total_m = len(sample.messages)
        for turn_id, message in enumerate(sample.messages):
            role = message.role
            recipient = message.recipient
            content = message.content

            # Normalize the message content into a flat list of TextContent / AudioContent.
            content_l = []
            if isinstance(content, str):
                content_l.append(TextContent(text=content))
            elif isinstance(content, TextContent):
                content_l.append(content)
            elif isinstance(content, AudioContent):
                content_l.append(content)
            elif isinstance(content, list):
                for ele in content:
                    if isinstance(ele, str):
                        content_l.append(TextContent(text=ele))
                    else:
                        content_l.append(ele)

            # Llama-3 style chat header; <|begin_of_text|> only on the very first turn.
            if turn_id == 0:
                prefix = f"<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>\n\n"
            else:
                prefix = f"<|start_header_id|>{role}<|end_header_id|>\n\n"
            eot_postfix = "<|eot_id|>"
            eom_postfix = "<|eom_id|>"

            prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
            input_tokens.extend(prefix_tokens)
            label_tokens.extend([-100 for _ in prefix_tokens])

            if recipient:
                assert role == "assistant", "Recipient is only available for assistant role."
                recipient_tokens = tokenizer.encode(f"{recipient}<|recipient|>", add_special_tokens=False)
                input_tokens.extend(recipient_tokens)
                label_tokens.extend(recipient_tokens)

            # `piece` (not `content`) so the message-level `content` variable is not shadowed.
            for piece in content_l:
                if piece.type == "text":
                    text_tokens = tokenizer.encode(piece.text, add_special_tokens=False)
                    input_tokens.extend(text_tokens)
                    if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
                        label_tokens.extend(text_tokens)
                    else:
                        label_tokens.extend([-100 for _ in text_tokens])

                elif piece.type == "audio":
                    # Generate the text-part of the audio tokens
                    audio_contents.append(piece)
                    if role == "user" or role == "system":
                        # Audio-in placeholder, filled with whisper features downstream.
                        text_tokens = tokenizer.encode(
                            "<|audio_bos|><|AUDIO|><|audio_eos|>",
                            add_special_tokens=False,
                        )
                        input_tokens.extend(text_tokens)
                        label_tokens.extend([-100 for _ in text_tokens])
                    elif role == "assistant":
                        # Audio-out placeholder, filled with generated audio codes downstream.
                        text_tokens = tokenizer.encode(
                            "<|audio_out_bos|><|AUDIO_OUT|><|audio_eos|>",
                            add_special_tokens=False,
                        )
                        input_tokens.extend(text_tokens)
                        if sample.start_index is None or turn_id >= sample.start_index:
                            label_tokens.extend(text_tokens)
                        else:
                            label_tokens.extend([-100 for _ in text_tokens])

            # <|eom_id|> between consecutive assistant turns, <|eot_id|> otherwise.
            next_id = turn_id + 1
            if role == "assistant" and next_id != total_m and sample.messages[next_id].role == "assistant":
                postfix_tokens = tokenizer.encode(eom_postfix, add_special_tokens=False)
                input_tokens.extend(postfix_tokens)
            else:
                postfix_tokens = tokenizer.encode(eot_postfix, add_special_tokens=False)
                input_tokens.extend(postfix_tokens)
            if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
                label_tokens.extend(postfix_tokens)
            else:
                label_tokens.extend([-100 for _ in postfix_tokens])

        return input_tokens, label_tokens, audio_contents, speaker_id

    except Exception as e:
        print(f"Error in prepare_chatml_sample: {str(e)}")
        # Fix: plain `json.dumps(sample, indent=2)` raised TypeError for ChatMLSample
        # instances / non-serializable fields, crashing the error handler instead of
        # returning the (None, None, None, None) sentinel. `default=str` makes this
        # diagnostic best-effort and non-raising.
        print(f"Sample data: {json.dumps(sample, indent=2, default=str)}")
        return None, None, None, None
453
+
454
+
455
def extract_generation_prompt_from_input_tokens(input_tokens, tokenizer):
    """Split tokenized chat text into a generation prompt and a reference answer.

    The prompt is everything up to and including the *last*
    ``<|start_header_id|>assistant<|end_header_id|>\\n\\n`` header; the reference answer
    is the text that follows, up to the next ``<|eot_id|>``.

    Example::

        '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\\n\\n
         What words do you hear from the provided audio?<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
         <|start_header_id|>assistant<|end_header_id|>\\n\\nAt first they went by quick.<|eot_id|>'

        --> prompt = everything through the assistant header,
            reference = 'At first they went by quick.'

    Args:
        input_tokens: Token ids of the full conversation.
        tokenizer: Tokenizer used to decode and re-encode the text.

    Returns:
        A tuple ``(prompt_tokens, reference_answer, num_audios_in_reference)`` where the
        last element counts the audio placeholder tokens appearing in the reference.
    """
    decoded_text = tokenizer.decode(input_tokens)
    assistant_header = "<|start_header_id|>assistant<|end_header_id|>\n\n"
    turn_end = "<|eot_id|>"
    assert assistant_header in decoded_text

    prompt_end = decoded_text.rfind(assistant_header) + len(assistant_header)
    prompt_text = decoded_text[:prompt_end]
    reference_answer = decoded_text[prompt_end : decoded_text.find(turn_end, prompt_end)]
    num_audios_in_reference = reference_answer.count(AUDIO_IN_TOKEN) + reference_answer.count(AUDIO_OUT_TOKEN)
    return tokenizer.encode(prompt_text, add_special_tokens=False), reference_answer, num_audios_in_reference
491
+
492
+
493
def prepare_chatml_dataframe_single_process(df, tokenizer):
    """Tokenize every row of a ChatML DataFrame in the current process.

    Returns one ``(input_tokens, label_tokens, audio_contents, speaker_id)`` tuple per
    row, in row order (see `prepare_chatml_sample`).
    """
    return [prepare_chatml_sample(row.to_dict(), tokenizer) for _, row in df.iterrows()]
500
+
501
+
502
def prepare_chatml_dataframe(df, tokenizer, num_process=16):
    """Tokenize every row of a ChatML DataFrame, optionally with multiprocessing.

    Args:
        df: DataFrame whose rows are ChatML samples.
        tokenizer: Tokenizer forwarded to `prepare_chatml_sample`.
        num_process: Upper bound on worker processes; ``None`` forces single-process
            mode. The effective worker count is scaled down to roughly one worker per
            1000 rows (always at least 1).

    Returns:
        List of ``(input_tokens, label_tokens, audio_contents, speaker_id)`` tuples,
        in row order.
    """
    if num_process is None:
        return prepare_chatml_dataframe_single_process(df, tokenizer)

    num_process = max(min(len(df) // 1000, num_process), 1)
    if num_process == 1:
        # A single worker would only add process-spawn and pickling overhead.
        return prepare_chatml_dataframe_single_process(df, tokenizer)

    from itertools import chain

    workloads = np.array_split(df, num_process)
    with mp.Pool(num_process) as pool:
        ret = pool.starmap(
            prepare_chatml_dataframe_single_process, [(workload, tokenizer) for workload in workloads]
        )
    # chain.from_iterable is O(total); the previous `sum(ret, [])` was quadratic in
    # the number of chunks because each `+` re-copies the accumulated list.
    return list(chain.from_iterable(ret))
513
+
514
+
515
class DatasetInterface(ABC):
    """Map-style dataset contract: samples are fetched by integer index."""

    @abstractmethod
    def __getitem__(self, idx) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
        """Return the sample stored at position ``idx``."""
        raise NotImplementedError
520
+
521
+
522
class IterableDatasetInterface(ABC):
    """Stream-style dataset contract: samples are produced only by iteration."""

    @abstractmethod
    def __iter__(self) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
        """Yield samples one at a time."""
        raise NotImplementedError
527
+
528
+
529
@dataclass
class DatasetInfo:
    """Per-dataset metadata describing how a dataset should be loaded/treated."""

    # Identifier of the dataset implementation/type to use.
    dataset_type: str
    # Optional grouping key (semantics defined by the consumer — TODO confirm).
    group_type: Optional[str] = None
    mask_text: Optional[bool] = None  # Whether to mask the text tokens for pretraining samples.
boson_multimodal/model/__init__.py ADDED
File without changes
boson_multimodal/model/higgs_audio/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
from transformers import AutoConfig, AutoModel

from .configuration_higgs_audio import HiggsAudioConfig, HiggsAudioEncoderConfig
from .modeling_higgs_audio import HiggsAudioModel


# Register the custom config/model classes with the transformers Auto* factories so
# that AutoConfig.from_pretrained / AutoModel.from_pretrained can resolve the
# "higgs_audio_encoder" and "higgs_audio" model types without extra plumbing.
AutoConfig.register("higgs_audio_encoder", HiggsAudioEncoderConfig)
AutoConfig.register("higgs_audio", HiggsAudioConfig)
AutoModel.register(HiggsAudioConfig, HiggsAudioModel)
boson_multimodal/model/higgs_audio/audio_head.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Projector that maps hidden states from the LLM component to multimodal logits."""
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple
8
+
9
+ from .common import HiggsAudioPreTrainedModel
10
+ from .configuration_higgs_audio import HiggsAudioConfig
11
+
12
+
13
@dataclass
class HiggsAudioDecoderLayerOutput:
    """Output bundle of the decoder projector: text logits plus audio logits."""

    # Text-token logits, (batch_size, seq_len, vocab_size) per the projector docstring.
    logits: torch.FloatTensor
    # Audio-token logits for the positions selected by audio_out_mask.
    audio_logits: torch.FloatTensor
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
19
+
20
+
21
class HiggsAudioDecoderProjector(HiggsAudioPreTrainedModel):
    """Projection layers that map hidden states from the LLM component to audio / text logits.

    We support two type of audio head:
    - Basic Audio Head:
        Directly map the hidden states to audio logits for all the codebooks.
    """

    def __init__(self, config: HiggsAudioConfig, layer_idx: Optional[int] = None):
        super().__init__(config)
        # Text head: hidden_size -> text vocab size.
        self.text_lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        # Audio head emits (audio_codebook_size + 2) logits per codebook; the extra two
        # slots presumably correspond to audio_stream_bos_id / audio_stream_eos_id
        # (1024/1025 for a 1024-entry codebook) — TODO confirm.
        self.audio_lm_head = nn.Linear(
            config.text_config.hidden_size, config.audio_num_codebooks * (config.audio_codebook_size + 2), bias=False
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        hidden_states,
        audio_out_mask,
        label_audio_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        output_audio_hidden_states=False,
        cache_position=None,
    ):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
                Hidden states from the LLM component
            audio_out_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
                Mask for identifying the audio out tokens.
            label_audio_ids (`torch.Tensor` of shape `(num_codebooks, num_audio_out_tokens)`):
                Label tokens for the audio-out part. This is used for calculating the logits if RQ-Transformer is used.
            attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
                Mask to avoid performing attention on padding token indices
            position_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
                Position ids for the input tokens

        Returns:
            logits (`torch.Tensor` of shape `(batch_size, seq_len, vocab_size)`):
                Logits for text tokens
            audio_logits (`torch.Tensor` of shape `(num_audio_out_tokens, audio_num_codebooks * audio_codebook_size)`):
                Logits for audio tokens. We ensure `num_text_tokens + num_audio_tokens == batch_size * seq_len`
        """
        # Text logits are computed for every position, regardless of the audio mask.
        logits = self.text_lm_head(hidden_states)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        if self.config.audio_decoder_proj_num_layers > 0:
            # NOTE(review): this branch reads self.rotary_emb, self.transformer_layers,
            # self.norm and self.gradient_checkpointing, none of which are created in this
            # class's __init__ — presumably attached elsewhere when extra projection layers
            # are configured; confirm before enabling audio_decoder_proj_num_layers > 0.
            # create position embeddings to be shared across the decoder layers
            position_embeddings = self.rotary_emb(hidden_states, position_ids)
            for decoder_layer in self.transformer_layers:
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                if self.gradient_checkpointing and self.training:
                    # Checkpointed call: positional args only; recomputed during backward.
                    layer_outputs = self._gradient_checkpointing_func(
                        decoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        position_ids,
                        past_key_values,
                        output_attentions,
                        use_cache,
                        cache_position,
                        position_embeddings,
                    )
                else:
                    layer_outputs = decoder_layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        past_key_value=past_key_values,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        cache_position=cache_position,
                        position_embeddings=position_embeddings,
                    )
                hidden_states = layer_outputs[0]
            hidden_states = self.norm(hidden_states)

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # NOTE(review): attentions/cache below come from `layer_outputs` after the
            # loop, i.e. from the final decoder layer only — confirm this is intended.
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        next_cache = next_decoder_cache if use_cache else None

        # Audio logits only for positions flagged as audio-out; boolean-mask indexing
        # flattens them to (num_audio_out_tokens, ...).
        audio_logits = self.audio_lm_head(hidden_states[audio_out_mask])

        if output_audio_hidden_states:
            audio_hidden_states = hidden_states[audio_out_mask]
        else:
            audio_hidden_states = None

        return logits, audio_logits, all_self_attns, all_hidden_states, audio_hidden_states, next_cache
boson_multimodal/model/higgs_audio/common.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+ from transformers.modeling_utils import PreTrainedModel
4
+
5
+ from .configuration_higgs_audio import HiggsAudioConfig
6
+
7
+
8
class HiggsAudioPreTrainedModel(PreTrainedModel):
    """Shared transformers base class for all Higgs-Audio modules.

    Declares the config class, checkpointing/attention capabilities, and the
    normal-distribution weight initialization used across the model family.
    """

    config_class = HiggsAudioConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        # Prefer the top-level init_std; fall back to the audio encoder's value.
        if hasattr(self.config, "init_std"):
            std = self.config.init_std
        else:
            std = self.config.audio_encoder_config.init_std

        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
boson_multimodal/model/higgs_audio/configuration_higgs_audio.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ from transformers.models.auto import CONFIG_MAPPING
3
+
4
+
5
class HiggsAudioEncoderConfig(PretrainedConfig):
    """Configuration of the Audio encoder in Higgs-Audio.

    The encoder is a Whisper-style mel-spectrogram encoder (see `HiggsAudioConfig`);
    all constructor arguments are stored verbatim as attributes.
    """

    model_type = "higgs_audio_encoder"

    def __init__(
        self,
        num_mel_bins=128,
        encoder_layers=32,
        encoder_attention_heads=20,
        encoder_ffn_dim=5120,
        encoder_layerdrop=0.0,
        d_model=1280,
        dropout=0.0,
        attention_dropout=0.0,
        activation_function="gelu",
        activation_dropout=0.0,
        scale_embedding=False,
        init_std=0.02,
        max_source_positions=1500,
        pad_token_id=128001,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.num_mel_bins = num_mel_bins
        self.d_model = d_model
        self.encoder_layers = encoder_layers
        self.encoder_attention_heads = encoder_attention_heads
        self.encoder_ffn_dim = encoder_ffn_dim
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_function = activation_function
        self.activation_dropout = activation_dropout
        self.encoder_layerdrop = encoder_layerdrop
        # Mirror encoder_layers under the generic name used by transformers utilities.
        self.num_hidden_layers = encoder_layers
        self.init_std = init_std
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        self.max_source_positions = max_source_positions
        self.pad_token_id = pad_token_id
45
+
46
+
47
class HiggsAudioConfig(PretrainedConfig):
    r"""
    This is the configuration class for the HiggsAudioModel.

    Args:
        text_config (`Union[AutoConfig, dict]`):
            The config object or dictionary of the text backbone.
        audio_encoder_config (`Union[AutoConfig, dict]`):
            The config object or dictionary of the whisper encoder.
            The audio encoder will be bidirectional and will be only available for audio understanding.
        audio_tokenizer_config
            The config object or dictionary of the audio tokenizer.
        audio_adapter_type
            The type of audio adapter to use. We support two types of adapter:
            - stack:
                We stack additional Transformer layers after the main LLM backbone for audio generation.
            - dual_ffn:
                For selected part of the LLM backbone, we replace the text FFN with a dual FFN architecture
                that contains an additional audio FFN. The audio FFN will be triggered when the location is marked for audio tokens.
            - dual_ffn_fast_forward:
                We pick a few layers in the LLM backbone to plug-in the audio FFN. For the remaining layers,
                the audio hidden states will be directly fast-forward to the next layer.
                This reduces the computational cost for audio generation.
        audio_embed_avg (`bool`, *optional*, defaults to False):
            Whether to average the audio embeddings before sending them to the text attention layer.
        audio_ffn_hidden_size
            The hidden size of the audio feedforward network in dual-path FFN
        audio_ffn_intermediate_size
            The intermediate size of the audio feedforward network in dual-path FFN
        audio_dual_ffn_layers
            The layers in the LLM backbone to plug-in the dual FFN layer (mixture of audio FFN and text FFN).
        audio_decoder_proj_num_attention (`int`, *optional*, defaults to 0):
            The number of attention heads in the audio decoder projection layer.
        use_delay_pattern (`bool`, *optional*, defaults to False):
            Whether to use delay pattern in the audio decoder.
        skip_audio_tower (`bool`, *optional*, defaults to False):
            Whether to skip the audio tower in the audio encoder.
        use_audio_out_embed_projector (`bool`, *optional*, defaults to False):
            Whether to use an embedding projector to map audio out embeddings.
        use_audio_out_self_attention (`bool`, *optional*, defaults to False):
            Whether to use self-attention to aggregate information from audio-tokens before sending to the text attention layer.
        audio_num_codebooks (`int`, *optional*, defaults to 12):
            The number of codebooks in RVQGAN.
        audio_codebook_size (`int`, *optional*, defaults to 1024):
            The size of each codebook in RVQGAN.
        audio_stream_bos_id
            The id of the bos in the audio stream
        audio_stream_eos_id
            The id of the eos in the audio stream
        audio_bos_token (`str`, *optional*, defaults to "<|audio_bos|>"):
            The special `<|audio_bos|>` token. In Higgs-Audio, it is mapped to 128011,
            which is the index of `<|reserved_special_token_3|>` in Llama-3.1-8B-Instruct's tokenizer.
        audio_eos_token (`str`, *optional*, defaults to "<|audio_eos|>"):
            The special `<|audio_eos|>` token. We use 128012 as the default value,
            which is the index of `<|reserved_special_token_4|>` in Llama-3.1-8B-Instruct's tokenizer.
        audio_out_bos_token (`str`, *optional*, defaults to "<|audio_out_bos|>"):
            The special `<|audio_out_bos|>` token. We use 128013 as the default value,
            which is the index of `<|reserved_special_token_5|>` in Llama-3.1-8B-Instruct's tokenizer.
        audio_token (`str`, *optional*, defaults to "<|AUDIO|>"):
            The special `<|AUDIO|>` token. We use 128015 as the default value,
            which is the index of `<|reserved_special_token_7|>` in Llama-3.1-8B-Instruct's tokenizer.
            This token indicates that the location should be filled in with whisper features.
        audio_out_token (`str`, *optional*, defaults to "<|AUDIO_OUT|>"):
            The special `<|AUDIO_OUT|>` token. We use 128016 as the default value,
            which is the index of `<|reserved_special_token_8|>` in Llama-3.1-8B-Instruct's tokenizer.
            This token indicates that the location should be filled in with audio tokens extracted via audio tokenizer.
    """

    model_type = "higgs_audio"
    is_composition = True

    def __init__(
        self,
        text_config=None,
        audio_encoder_config=None,
        audio_tokenizer_config=None,
        audio_adapter_type="stack",
        audio_embed_avg=False,
        audio_ffn_hidden_size=4096,
        audio_ffn_intermediate_size=14336,
        audio_dual_ffn_layers=None,
        audio_decoder_proj_num_layers=0,
        encode_whisper_embed=True,
        encode_audio_in_tokens=False,
        use_delay_pattern=False,
        skip_audio_tower=False,
        use_audio_out_embed_projector=False,
        use_audio_out_self_attention=False,
        use_rq_transformer=False,
        rq_transformer_hidden_size=None,
        rq_transformer_intermediate_size=None,
        rq_transformer_num_attention_heads=None,
        rq_transformer_num_key_value_heads=None,
        rq_transformer_num_hidden_layers=3,
        audio_num_codebooks=12,
        audio_codebook_size=1024,
        audio_stream_bos_id=1024,
        audio_stream_eos_id=1025,
        audio_bos_token="<|audio_bos|>",
        audio_eos_token="<|audio_eos|>",
        audio_out_bos_token="<|audio_out_bos|>",
        audio_in_token="<|AUDIO|>",
        audio_out_token="<|AUDIO_OUT|>",
        audio_in_token_idx=128015,
        audio_out_token_idx=128016,
        pad_token_id=128001,
        audio_out_bos_token_id=128013,
        audio_eos_token_id=128012,
        **kwargs,
    ):
        # Hydrate dict-style sub-configs into real config objects; missing model_type
        # falls back to the Higgs-Audio encoder / llama defaults respectively.
        if isinstance(audio_encoder_config, dict):
            audio_encoder_config["model_type"] = (
                audio_encoder_config["model_type"] if "model_type" in audio_encoder_config else "higgs_audio_encoder"
            )
            audio_encoder_config = CONFIG_MAPPING[audio_encoder_config["model_type"]](**audio_encoder_config)
        elif audio_encoder_config is None:
            audio_encoder_config = HiggsAudioEncoderConfig()

        if isinstance(text_config, dict):
            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["llama"]()

        assert audio_adapter_type in [
            "stack",
            "dual_ffn",
            "dual_ffn_fast_forward",
        ], f"Invalid audio adapter type: {audio_adapter_type}"
        # Both dual_ffn variants need an explicit list of layers to augment.
        if audio_adapter_type.startswith("dual_ffn"):
            assert audio_dual_ffn_layers is not None, (
                "audio_dual_ffn_layers must be specified when using dual_ffn adapter."
            )
        self.text_config = text_config
        self.audio_encoder_config = audio_encoder_config
        self.audio_tokenizer_config = audio_tokenizer_config
        self.audio_adapter_type = audio_adapter_type
        self.audio_embed_avg = audio_embed_avg
        self.audio_ffn_hidden_size = audio_ffn_hidden_size
        self.audio_ffn_intermediate_size = audio_ffn_intermediate_size
        self.audio_dual_ffn_layers = audio_dual_ffn_layers
        self.audio_decoder_proj_num_layers = audio_decoder_proj_num_layers
        self.encode_whisper_embed = encode_whisper_embed
        self.encode_audio_in_tokens = encode_audio_in_tokens
        self.use_delay_pattern = use_delay_pattern
        self.skip_audio_tower = skip_audio_tower
        self.use_audio_out_embed_projector = use_audio_out_embed_projector
        self.use_audio_out_self_attention = use_audio_out_self_attention

        self.use_rq_transformer = use_rq_transformer

        if self.use_rq_transformer:
            assert not self.use_delay_pattern, "Delay pattern is not supported if you turned on RQ-Transformer!"
        self.rq_transformer_hidden_size = rq_transformer_hidden_size
        self.rq_transformer_intermediate_size = rq_transformer_intermediate_size
        self.rq_transformer_num_attention_heads = rq_transformer_num_attention_heads
        self.rq_transformer_num_key_value_heads = rq_transformer_num_key_value_heads
        self.rq_transformer_num_hidden_layers = rq_transformer_num_hidden_layers

        if use_rq_transformer:
            # For RQ-Transformer, we set the hidden_size to the same as the text model's hidden size if it is not specified.
            if self.rq_transformer_hidden_size is None:
                self.rq_transformer_hidden_size = text_config.hidden_size
                # Head dim is fixed at 128, so hidden size must be a multiple of it.
                assert self.rq_transformer_hidden_size % 128 == 0
            if self.rq_transformer_intermediate_size is None:
                self.rq_transformer_intermediate_size = text_config.intermediate_size
            if self.rq_transformer_num_attention_heads is None:
                self.rq_transformer_num_attention_heads = self.rq_transformer_hidden_size // 128
            if self.rq_transformer_num_key_value_heads is None:
                # Default to 1 KV head per 4 attention heads (grouped-query attention).
                self.rq_transformer_num_key_value_heads = self.rq_transformer_hidden_size // 128 // 4
            assert self.rq_transformer_hidden_size % self.rq_transformer_num_attention_heads == 0
            assert self.rq_transformer_hidden_size % self.rq_transformer_num_key_value_heads == 0

        self.audio_num_codebooks = audio_num_codebooks
        self.audio_codebook_size = audio_codebook_size
        self.audio_bos_token = audio_bos_token
        self.audio_eos_token = audio_eos_token
        self.audio_out_bos_token = audio_out_bos_token
        self.audio_in_token = audio_in_token
        self.audio_out_token = audio_out_token
        self.audio_in_token_idx = audio_in_token_idx
        self.audio_out_token_idx = audio_out_token_idx
        self.audio_stream_bos_id = audio_stream_bos_id
        self.audio_stream_eos_id = audio_stream_eos_id
        self.audio_out_bos_token_id = audio_out_bos_token_id
        self.audio_eos_token_id = audio_eos_token_id

        super().__init__(**kwargs)
        # Assigned after super().__init__ so the explicit `pad_token_id` argument takes
        # precedence over any value carried in kwargs — presumably intentional; confirm.
        self.pad_token_id = pad_token_id
boson_multimodal/model/higgs_audio/cuda_graph_runner.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Optional, List, Dict, Tuple, Union
4
+ import gc
5
+
6
+ from transformers.cache_utils import Cache
7
+
8
+
9
+ _NUM_WARMUP_ITERS = 2
10
+
11
+
12
class CUDAGraphRunner(nn.Module):
    """Wraps a model so a fixed-shape forward step can be captured into a CUDA graph
    and replayed with minimal kernel-launch overhead.

    Usage: call :meth:`capture` once with representative (shape-stable) tensors, then
    call the runner like the model; :meth:`forward` copies inputs into the captured
    buffers and replays the graph.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

        # Static tensors the graph was captured against; forward() copies into/out of these.
        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

        self._graph: Optional[torch.cuda.CUDAGraph] = None

    @property
    def graph(self):
        # Only valid after capture() has been called.
        assert self._graph is not None
        return self._graph

    def capture(
        self,
        hidden_states: torch.Tensor,
        causal_mask: torch.Tensor,
        position_ids: torch.Tensor,
        audio_discrete_codes_mask: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Union[Cache, List[torch.FloatTensor]],
        use_cache: bool,
        audio_attention_mask: torch.Tensor,
        fast_forward_attention_mask: torch.Tensor,
        output_attentions: bool,
        output_hidden_states: bool,
        is_decoding_audio_token: Optional[bool] = None,
        is_using_cuda_graph: Optional[bool] = False,
        stream: torch.cuda.Stream = None,
        memory_pool: Optional[Tuple[int, int]] = None,
    ):
        """Run warmup passes, then capture one forward pass of `self.model` into a CUDA graph.

        Must be called exactly once; the argument tensors become the graph's static buffers.
        """
        assert self._graph is None
        # Run warmup iterations
        # (presumably to trigger lazy kernel/allocator initialization outside of
        # capture, where such work is not allowed — TODO confirm.)
        for _ in range(_NUM_WARMUP_ITERS):
            self.model(
                hidden_states=hidden_states,
                causal_mask=causal_mask,
                position_ids=position_ids,
                audio_discrete_codes_mask=audio_discrete_codes_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                use_cache=use_cache,
                audio_attention_mask=audio_attention_mask,
                fast_forward_attention_mask=fast_forward_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                is_decoding_audio_token=is_decoding_audio_token,
                is_using_cuda_graph=is_using_cuda_graph,
            )

        torch.cuda.synchronize()

        # Capture the graph
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
            out_hidden_states, all_hidden_states, all_self_attns = self.model(
                hidden_states=hidden_states,
                causal_mask=causal_mask,
                position_ids=position_ids,
                audio_discrete_codes_mask=audio_discrete_codes_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                use_cache=use_cache,
                audio_attention_mask=audio_attention_mask,
                fast_forward_attention_mask=fast_forward_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                is_decoding_audio_token=is_decoding_audio_token,
                is_using_cuda_graph=is_using_cuda_graph,
            )
            # hidden_states_out = torch.ops._C.weak_ref_tensor(outputs[0])
            # del outputs
        gc.collect()
        torch.cuda.synchronize()

        # Save input and output buffers
        # NOTE(review): past_key_values is stored here but never re-copied in forward();
        # presumably the cache tensors are captured into the graph and updated in place
        # during replay — confirm.
        self.input_buffers = {
            "hidden_states": hidden_states,
            "causal_mask": causal_mask,
            "position_ids": position_ids,
            "audio_discrete_codes_mask": audio_discrete_codes_mask,
            "cache_position": cache_position,
            "past_key_values": past_key_values,
            "audio_attention_mask": audio_attention_mask,
            "fast_forward_attention_mask": fast_forward_attention_mask,
        }
        self.output_buffers = {
            "hidden_states": out_hidden_states,
            "all_hidden_states": all_hidden_states,
            "all_self_attns": all_self_attns,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
        causal_mask: torch.Tensor,
        position_ids: torch.Tensor,
        audio_discrete_codes_mask: torch.Tensor,
        cache_position: torch.Tensor,
        audio_attention_mask: torch.Tensor,
        fast_forward_attention_mask: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Replay the captured graph on new inputs (shapes must match the captured ones).

        Returns a 3-tuple matching the model's output signature; the last two elements
        are always None on replay.
        """
        # Copy input tensors to buffers
        self.input_buffers["hidden_states"].copy_(hidden_states, non_blocking=True)
        self.input_buffers["causal_mask"].copy_(causal_mask, non_blocking=True)
        self.input_buffers["position_ids"].copy_(position_ids, non_blocking=True)
        self.input_buffers["audio_discrete_codes_mask"].copy_(audio_discrete_codes_mask, non_blocking=True)
        self.input_buffers["cache_position"].copy_(cache_position, non_blocking=True)
        self.input_buffers["audio_attention_mask"].copy_(audio_attention_mask, non_blocking=True)
        self.input_buffers["fast_forward_attention_mask"].copy_(fast_forward_attention_mask, non_blocking=True)

        # Run the captured graph
        self.graph.replay()

        return self.output_buffers["hidden_states"], None, None
boson_multimodal/model/higgs_audio/custom_modules.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class PartiallyFrozenEmbedding(nn.Module):
    """Wrap an existing ``nn.Embedding`` and split it into two tables:

    - A frozen embedding for indices ``[0..freeze_until_idx-1]``.
    - A trainable embedding for indices ``[freeze_until_idx..vocab_size-1]``.

    This should work with both Zero-2 and Zero-3 seamlessly
    """

    def __init__(self, original_embedding: nn.Embedding, freeze_until_idx: int):
        """
        :param original_embedding: An instance of nn.Embedding (the original embedding layer).
        :param freeze_until_idx: The index up to which the embedding is frozen (exclusive);
            index ``freeze_until_idx`` itself remains trainable.
        """
        super().__init__()
        self.freeze_until_idx = freeze_until_idx
        self.original_vocab_size = original_embedding.num_embeddings
        self.embedding_dim = original_embedding.embedding_dim

        # Split the original embedding into frozen and trainable parts
        self.embedding_frozen = nn.Embedding(
            freeze_until_idx,
            self.embedding_dim,
            dtype=original_embedding.weight.dtype,
            device=original_embedding.weight.device,
        )
        self.embedding_trainable = nn.Embedding(
            self.original_vocab_size - freeze_until_idx,
            self.embedding_dim,
            dtype=original_embedding.weight.dtype,
            device=original_embedding.weight.device,
        )

        # Copy weights from the original embedding into the frozen and trainable parts
        with torch.no_grad():
            self.embedding_frozen.weight.copy_(original_embedding.weight[:freeze_until_idx])
            self.embedding_trainable.weight.copy_(original_embedding.weight[freeze_until_idx:])

        # Freeze the frozen embedding
        self.embedding_frozen.weight.requires_grad = False

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up embeddings, routing each index to the frozen or trainable table.

        Generalized to accept an index tensor of any shape (like ``nn.Embedding``),
        not only ``[batch_size, seq_len]``.

        :param input_ids: Integer tensor with values in ``[0..original_vocab_size-1]``.
        :return: Tensor of shape ``(*input_ids.shape, embedding_dim)``.
        """
        # Masks to separate frozen and trainable indices
        mask_frozen = input_ids < self.freeze_until_idx
        mask_trainable = ~mask_frozen

        # Output tensor for embedding results; gradients flow into it through
        # the masked assignments below.
        embeddings = torch.zeros(
            *input_ids.shape,
            self.embedding_dim,
            device=input_ids.device,
            dtype=self.embedding_frozen.weight.dtype,
        )

        # Handle frozen embedding
        if mask_frozen.any():
            embeddings[mask_frozen] = self.embedding_frozen(input_ids[mask_frozen])

        # Handle trainable embedding
        if mask_trainable.any():
            # Shift trainable IDs into the local index space of the trainable table
            embeddings[mask_trainable] = self.embedding_trainable(
                input_ids[mask_trainable] - self.freeze_until_idx
            )

        return embeddings

    def to_unsplit(self) -> nn.Embedding:
        """Reassemble a single ``nn.Embedding`` holding both frozen and trainable rows."""
        unsplit_embedding = nn.Embedding(
            self.original_vocab_size,
            self.embedding_dim,
            dtype=self.embedding_frozen.weight.dtype,
            device=self.embedding_frozen.weight.device,
        )

        with torch.no_grad():
            unsplit_embedding.weight[: self.freeze_until_idx].copy_(self.embedding_frozen.weight)
            unsplit_embedding.weight[self.freeze_until_idx :].copy_(self.embedding_trainable.weight)

        return unsplit_embedding
94
+
95
+
96
class PartiallyFrozenLinear(nn.Module):
    """A wrapper around nn.Linear to partially freeze part of the weight matrix."""

    def __init__(self, original_linear: nn.Linear, freeze_until_idx: int):
        """
        :param original_linear: The original nn.Linear layer.
        :param freeze_until_idx: The index up to which the rows of the weight matrix are frozen.
        """
        super().__init__()
        assert original_linear.bias is None, "Currently only support linear module without bias"

        self.freeze_until_idx = freeze_until_idx
        self.input_dim = original_linear.in_features
        self.output_dim = original_linear.out_features

        weight = original_linear.weight
        factory = dict(bias=False, dtype=weight.dtype, device=weight.device)

        # Rows [0, freeze_until_idx) stay fixed; the remaining rows keep training.
        self.linear_frozen = nn.Linear(self.input_dim, freeze_until_idx, **factory)
        self.linear_trainable = nn.Linear(self.input_dim, self.output_dim - freeze_until_idx, **factory)

        with torch.no_grad():
            self.linear_frozen.weight.copy_(weight[:freeze_until_idx])
            self.linear_trainable.weight.copy_(weight[freeze_until_idx:])

        self.linear_frozen.weight.requires_grad = False

    def forward(self, input_tensor):
        """Apply both halves and concatenate along the output dimension."""
        # input_tensor: (bsz, seq_len, hidden_state_dim)
        return torch.cat(
            (self.linear_frozen(input_tensor), self.linear_trainable(input_tensor)),
            dim=-1,
        )

    def to_unsplit(self) -> nn.Linear:
        """Merge the two halves back into a single bias-free nn.Linear."""
        merged = nn.Linear(
            self.input_dim,
            self.output_dim,
            bias=False,
            dtype=self.linear_frozen.weight.dtype,
            device=self.linear_frozen.weight.device,
        )

        with torch.no_grad():
            merged.weight[: self.freeze_until_idx].copy_(self.linear_frozen.weight)
            merged.weight[self.freeze_until_idx :].copy_(self.linear_trainable.weight)

        return merged
boson_multimodal/model/higgs_audio/modeling_higgs_audio.py ADDED
The diff for this file is too large to render. See raw diff
 
boson_multimodal/model/higgs_audio/utils.py ADDED
@@ -0,0 +1,756 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ from contextlib import contextmanager
3
+ from functools import wraps
4
+ import torch
5
+ from transformers.integrations import is_deepspeed_available
6
+
7
+ if is_deepspeed_available():
8
+ from deepspeed.utils import groups as deepspeed_groups
9
+ from deepspeed.sequence.layer import _SeqAllToAll
10
+ else:
11
+ deepspeed_groups = None
12
+ _SeqAllToAll = None
13
+
14
+
15
+ def _ceil_to_nearest(n, round_to):
16
+ return (n + round_to - 1) // round_to * round_to
17
+
18
+
19
def count_parameters(model, trainable_only=True):
    """Return the total number of scalar parameters in ``model``.

    When ``trainable_only`` is True (default), only parameters with
    ``requires_grad`` set are counted.
    """
    params = model.parameters()
    if trainable_only:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
24
+
25
+
26
def build_delay_pattern_mask(
    input_ids: torch.LongTensor,
    bos_token_id: int,
    pad_token_id: int,
):
    """Apply the delay pattern from "Simple and Controllable Music Generation",
    https://arxiv.org/pdf/2306.05284, to a stack of codebook sequences.

    Codebook ``k`` is shifted right by ``k`` positions: its first ``k`` slots are
    filled with the delay token ``bos_token_id`` and its trailing slots with
    ``pad_token_id``, so the output length is ``seq_len + num_codebooks - 1``.
    For 4 codebooks and seq_len=5 (``B`` = delay token, ``P`` = pad, ``*`` = audio token):

    - [ *, *, *, *, *, P, P, P]
    - [ B, *, *, *, *, *, P, P]
    - [ B, B, *, *, *, *, *, P]
    - [ B, B, B, *, *, *, *, *]

    When conditioning on existing audio tokens, the second returned tensor marks
    the trailing (to-be-generated) slots with ``-1`` instead of pad, so the
    auto-regressive loop knows which positions to overwrite.

    Args:
        input_ids (:obj:`torch.LongTensor`):
            The input ids of the prompt, shape ``(bsz, num_codebooks, seq_len)``.
        bos_token_id (:obj:`int`):
            The id of the special delay token.
        pad_token_id (:obj:`int`):
            The id of the padding token. Should be the same as eos_token_id.

    Returns:
        input_ids (:obj:`torch.LongTensor`):
            Delayed ids, shape ``(bsz, num_codebooks, seq_len + num_codebooks - 1)``.
        input_ids_with_gen_mask (:obj:`torch.LongTensor`):
            Same, except positions to be generated hold ``-1``.
    """
    bsz, num_codebooks, seq_len = input_ids.shape
    new_seq_len = seq_len + num_codebooks - 1

    # Position j in codebook i is: delay (j < i), payload (i <= j < i + seq_len),
    # or trailing pad (j >= i + seq_len). Build the three regions by broadcasting.
    codebook_idx = torch.arange(num_codebooks, device=input_ids.device).view(1, num_codebooks, 1)
    position_idx = torch.arange(new_seq_len, device=input_ids.device).view(1, 1, new_seq_len)
    bos_mask = (position_idx < codebook_idx).expand(bsz, num_codebooks, new_seq_len)
    eos_mask = (position_idx - codebook_idx >= seq_len).expand(bsz, num_codebooks, new_seq_len)

    input_ids_with_gen_mask = torch.empty(
        (bsz, num_codebooks, new_seq_len), dtype=torch.long, device=input_ids.device
    )
    input_ids_with_gen_mask[bos_mask] = bos_token_id
    # The payload region per row is contiguous, so row-major masked assignment
    # lines up exactly with the flattened original tokens.
    input_ids_with_gen_mask[(~bos_mask) & (~eos_mask)] = input_ids.reshape(-1)
    input_ids = input_ids_with_gen_mask.clone()
    input_ids[eos_mask] = pad_token_id
    input_ids_with_gen_mask[eos_mask] = -1
    return input_ids, input_ids_with_gen_mask
89
+
90
+
91
def revert_delay_pattern(data):
    """Undo the delay pattern and recover the time-aligned codebook sequences.

    Args:
        data (:obj:`torch.Tensor`):
            Delay-encoded codes of shape ``(num_codebooks, seq_len + num_codebooks - 1)``.

    Returns:
        :obj:`torch.Tensor`: Codes with the delays removed, shape ``(num_codebooks, seq_len)``.
    """
    assert len(data.shape) == 2
    num_codebooks = data.shape[0]
    seq_len = data.shape[1] - num_codebooks + 1
    # Codebook i was shifted right by i slots; take its window starting at i.
    rows = [data[i, i : i + seq_len] for i in range(num_codebooks)]
    return torch.stack(rows, dim=0)
108
+
109
+
110
def merge_input_ids_with_audio_features(
    audio_features_embed,
    audio_features_length,
    audio_in_embed,
    audio_in_ids_start,
    audio_out_embed,
    audio_out_ids_start,
    audio_in_token_idx,
    audio_out_token_idx,
    inputs_embeds,
    input_ids,
    attention_mask,
    label_ids,
    pad_token_id,
    ignore_index=-100,
    round_to=8,
    left_padding=True,
):
    """
    Merge input_ids with audio features into final embeddings.

    Args:
        audio_features_embed (`torch.Tensor` of shape `(num_audios, max_audio_tokens, embed_dim)`):
            Encoded vectors of all audios in the batch (obtained from the semantic encoder)
        audio_features_length (`torch.LongTensor` of shape `(num_audios,)`):
            The length of audio embeddings of each audio as stacked in `audio_features_embed`
        audio_in_embed (`torch.Tensor` of shape `(total_num_audio_in_tokens, embed_dim)`):
            The embeddings of audio-in tokens
        audio_in_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
            The start index of the audio-in tokens for each audio
        audio_out_embed (`torch.Tensor` of shape `(total_num_audio_out_tokens, embed_dim)`):
            The embeddings of audio-out tokens
        audio_out_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
            The start index of the audio-out tokens for each audio
        audio_in_token_idx
            The index of the audio-in token in the vocabulary
        audio_out_token_idx
            The index of the audio-out token in the vocabulary
        inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
            Token embeddings before merging with audio embeddings
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Input_ids of tokens, possibly filled with audio token
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Mask to avoid performing attention on padding token indices.
        label_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
            labels need to be recalculated to support training (if provided)
        pad_token_id (`int`):
            The index of the pad token in the vocabulary
        ignore_index
            The index to ignore in the loss calculation
        round_to
            The number to round to for padding
        left_padding
            Whether to apply left padding

    Returns:
        final_embedding
            The final embeddings after merging audio embeddings with text embeddings.
        final_attention_mask
            The final attention mask after merging audio embeddings with text embeddings.
        final_labels
            The labels for the text stream
        position_ids
            Positional ids for the merged data
        final_input_ids
            The final input_ids after merging audio embeddings with text embeddings.
        final_audio_in_mask
            Mask for audio-in embeddings
        final_audio_in_discrete_codes_mask
            Mask for audio-in discrete tokens
        final_audio_out_mask
            Mask for audio-out embeddings

    Explanation:
        each audio has variable length embeddings, with length specified by
        - audio_features_length
        - audio_in_ids_start
        - audio_out_ids_start

        Task:
        - fill each <|AUDIO|> with audio embeddings (it can be the combination of embeddings extracted by WhisperEncoder and embeddings from audio codebooks)
        - fill each <|AUDIO_OUT|> with the audio-out embeddings

        Example:
            <|AUDIO_OUT|>: X (5 tokens), Y (3 tokens)
            <|AUDIO|>: Z (8 tokens)

        X, Y are in the same sequence (in-context voice-clone). Z is in a different sequence (audio understanding).
        if right padding
            input_ids: [
                a b c d e f X g h i j k Y l m
                o p q r Z s t u v _ _ _ _ _ _
            ]
            input_ids should be: [
                a b c d e f X X X X X g h i j k Y Y Y l m
                o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
            ]
            labels should be: [
                a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
                o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
            ]
        elif left padding
            input_ids: [
                a b c d e f X g h i j k Y l m
                _ _ _ _ _ _ o p q r Z s t u v
            ]
            input_ids should be: [
                a b c d e f X X X X X g h i j k Y Y Y l m
                _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
            ]
            labels should be: [
                a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
                _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
            ]

    """
    if label_ids is None:
        skip_labels = True
    else:
        skip_labels = False
    # Treat empty (zero-row) embeddings the same as "not provided".
    if audio_features_embed is not None and audio_features_embed.shape[0] == 0:
        audio_features_embed = None
    if audio_in_embed is not None and audio_in_embed.shape[0] == 0:
        audio_in_embed = None
    if audio_out_embed is not None and audio_out_embed.shape[0] == 0:
        audio_out_embed = None

    batch_size, sequence_length, embed_dim = inputs_embeds.shape

    target_device = inputs_embeds.device
    # NOTE(review): the default is True, so this auto-detection from the
    # attention mask only runs when callers explicitly pass left_padding=None.
    if left_padding is None:
        left_padding = torch.any(attention_mask[:, 0] == 0)

    audio_in_token_mask = input_ids == audio_in_token_idx
    audio_out_token_mask = input_ids == audio_out_token_idx
    text_token_mask = (input_ids != audio_in_token_idx) & (input_ids != audio_out_token_idx)

    # 1. Calculate the number of tokens for each placeholder (like [<|AUDIO|>, <|AUDIO_OUT|>]).
    token_placeholder_num = torch.ones_like(input_ids)

    if audio_features_embed is not None:
        num_audios, max_audio_tokens, _ = audio_features_embed.shape
        # Valid (non-padded) feature positions for each audio clip.
        audio_in_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to(
            audio_features_length.device
        ) < audio_features_length.unsqueeze(1)
        masked_audio_in_features = audio_features_embed[audio_in_features_mask].view(-1, embed_dim)
        token_placeholder_num[audio_in_token_mask] = audio_features_length.long()

    if audio_in_embed is not None:
        # Per-audio code counts, recovered from consecutive start offsets
        # (the last entry uses the total number of rows).
        audio_in_codes_length = torch.concat(
            [
                audio_in_ids_start[1:] - audio_in_ids_start[:-1],
                torch.tensor(
                    [audio_in_embed.shape[0] - audio_in_ids_start[-1]],
                    device=audio_in_ids_start.device,
                    dtype=torch.long,
                ),
            ],
            dim=0,
        )
        # When Whisper-style features are also present, each <|AUDIO|> expands
        # to features followed by discrete codes; otherwise codes only.
        if audio_features_embed is not None:
            token_placeholder_num[audio_in_token_mask] += audio_in_codes_length.long()
        else:
            token_placeholder_num[audio_in_token_mask] = audio_in_codes_length.long()

    if audio_out_embed is not None:
        audio_out_codes_length = torch.concat(
            [
                audio_out_ids_start[1:] - audio_out_ids_start[:-1],
                torch.tensor(
                    [audio_out_embed.shape[0] - audio_out_ids_start[-1]],
                    device=audio_out_ids_start.device,
                    dtype=torch.long,
                ),
            ],
            dim=0,
        )
        token_placeholder_num[audio_out_token_mask] = audio_out_codes_length.long()

    # new_token_positions[b, t] = index in the expanded sequence where source
    # token t of batch b *ends* after placeholder expansion.
    new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
    max_token_num = _ceil_to_nearest(token_placeholder_num.sum(-1).max(), round_to)
    nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1]

    if left_padding:
        new_token_positions += nb_audio_pad[:, None]  # offset for left padding

    # 2. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        (batch_size, max_token_num, embed_dim), dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
    final_attention_mask = torch.zeros(
        (batch_size, max_token_num), dtype=attention_mask.dtype, device=inputs_embeds.device
    )
    final_input_ids = torch.full(
        (batch_size, max_token_num), pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device
    )
    if skip_labels:
        final_labels = None
    else:
        final_labels = torch.full(
            (batch_size, max_token_num), ignore_index, dtype=label_ids.dtype, device=inputs_embeds.device
        )

    final_audio_in_mask = torch.full((batch_size, max_token_num), False, dtype=torch.bool, device=inputs_embeds.device)
    final_audio_in_discrete_codes_mask = torch.full(
        (batch_size, max_token_num), False, dtype=torch.bool, device=inputs_embeds.device
    )
    final_audio_out_mask = torch.full(
        (batch_size, max_token_num), False, dtype=torch.bool, device=inputs_embeds.device
    )
    # 3. Get the audio-in token positions and audio-out token positions
    batch_id = torch.arange(batch_size, device=target_device).unsqueeze(1).expand(batch_size, sequence_length)
    audio_in_batch_id = batch_id[audio_in_token_mask]  # Shape (num_audio_in,)
    audio_out_batch_id = batch_id[audio_out_token_mask]  # Shape (num_audio_out,)
    audio_features_token_ends = new_token_positions[audio_in_token_mask]  # Shape (num_audio_in,)
    audio_out_embed_ends = new_token_positions[audio_out_token_mask]  # Shape (num_audio_out,)

    if audio_in_embed is not None:
        # Fill in the audio-in embeddings (discrete codes occupy the tail of
        # each expanded <|AUDIO|> span).
        seq_indices = (
            torch.arange(max_token_num, device=target_device)
            .unsqueeze(0)
            .expand(audio_in_ids_start.shape[0], max_token_num)
        )
        audio_in_embed_token_starts = audio_features_token_ends - audio_in_codes_length + 1
        batch_indices, col_indices = torch.where(
            (seq_indices >= audio_in_embed_token_starts.unsqueeze(1))
            & (seq_indices <= audio_features_token_ends.unsqueeze(1))
        )
        batch_indices = audio_in_batch_id[batch_indices]
        final_embedding[batch_indices, col_indices] = audio_in_embed
        final_input_ids[batch_indices, col_indices] = audio_in_token_idx
        if not skip_labels:
            final_labels[batch_indices, col_indices] = ignore_index
        final_audio_in_mask[batch_indices, col_indices] = True
        final_audio_in_discrete_codes_mask[batch_indices, col_indices] = True
        # Shift the "end" markers left so the feature pass below fills the head
        # of the span.
        audio_features_token_ends = audio_features_token_ends - audio_in_codes_length

    if audio_features_embed is not None:
        # Fill in the audio features
        seq_indices = (
            torch.arange(max_token_num, device=target_device)
            .unsqueeze(0)
            .expand(audio_features_embed.shape[0], max_token_num)
        )
        audio_features_token_starts = audio_features_token_ends - audio_features_length + 1
        batch_indices, col_indices = torch.where(
            (seq_indices >= audio_features_token_starts.unsqueeze(1))
            & (seq_indices <= audio_features_token_ends.unsqueeze(1))
        )
        batch_indices = audio_in_batch_id[batch_indices]
        final_embedding[batch_indices, col_indices] = masked_audio_in_features
        final_input_ids[batch_indices, col_indices] = audio_in_token_idx
        if not skip_labels:
            final_labels[batch_indices, col_indices] = ignore_index
        final_audio_in_mask[batch_indices, col_indices] = True

    if audio_out_embed is not None:
        # Fill in the audio-out embeddings
        seq_indices = (
            torch.arange(max_token_num, device=target_device)
            .unsqueeze(0)
            .expand(audio_out_ids_start.shape[0], max_token_num)
        )
        audio_out_embed_token_starts = audio_out_embed_ends - audio_out_codes_length + 1
        batch_indices, col_indices = torch.where(
            (seq_indices >= audio_out_embed_token_starts.unsqueeze(1))
            & (seq_indices <= audio_out_embed_ends.unsqueeze(1))
        )
        batch_indices = audio_out_batch_id[batch_indices]
        final_embedding[batch_indices, col_indices] = audio_out_embed
        final_input_ids[batch_indices, col_indices] = audio_out_token_idx
        if not skip_labels:
            final_labels[batch_indices, col_indices] = ignore_index
        final_audio_out_mask[batch_indices, col_indices] = True

    # Fill in the original text embeddings and labels
    batch_indices, non_audio_indices = torch.where(text_token_mask)
    text_to_overwrite = new_token_positions[batch_indices, non_audio_indices]
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_audio_indices]
    if not skip_labels:
        final_labels[batch_indices, text_to_overwrite] = label_ids[batch_indices, non_audio_indices]
    final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_audio_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_audio_indices]
    final_attention_mask = final_attention_mask | final_audio_in_mask | final_audio_out_mask

    # Trim the tensor if there are redundant padding tokens
    if left_padding:
        # Keep the length a multiple of round_to by rounding the cut point down.
        first_non_zero_loc = final_attention_mask.sum(0).nonzero()[0]
        first_non_zero_loc = (first_non_zero_loc // round_to) * round_to
        if first_non_zero_loc > 0:
            final_attention_mask = final_attention_mask[:, first_non_zero_loc:]
            final_embedding = final_embedding[:, first_non_zero_loc:]
            if not skip_labels:
                final_labels = final_labels[:, first_non_zero_loc:]
            final_input_ids = final_input_ids[:, first_non_zero_loc:]
            final_audio_in_mask = final_audio_in_mask[:, first_non_zero_loc:]
            final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, first_non_zero_loc:]
            final_audio_out_mask = final_audio_out_mask[:, first_non_zero_loc:]
    else:
        # We have done right padding, so we need to trim the mask
        last_non_zero_loc = final_attention_mask.sum(0).nonzero()[-1] + 1
        last_non_zero_loc = ((last_non_zero_loc + round_to - 1) // round_to) * round_to
        if last_non_zero_loc < max_token_num:
            final_attention_mask = final_attention_mask[:, :last_non_zero_loc]
            final_embedding = final_embedding[:, :last_non_zero_loc]
            if not skip_labels:
                final_labels = final_labels[:, :last_non_zero_loc]
            final_input_ids = final_input_ids[:, :last_non_zero_loc]
            final_audio_in_mask = final_audio_in_mask[:, :last_non_zero_loc]
            final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, :last_non_zero_loc]
            final_audio_out_mask = final_audio_out_mask[:, :last_non_zero_loc]

    # Padded positions get a dummy position id of 1 (they are masked out anyway).
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
    return (
        final_embedding,
        final_attention_mask,
        final_labels,
        position_ids,
        final_input_ids,
        final_audio_in_mask,
        final_audio_in_discrete_codes_mask,
        final_audio_out_mask,
    )
434
+
435
+
436
def is_deepspeed_ulysses_enabled():
    """Check if sequence parallelism is enabled.

    Returns False when DeepSpeed is not installed (``deepspeed_groups`` is None)
    or when the sequence-parallel world size is 1.
    """
    # Fix: the docstring was previously a dead string statement placed after
    # the first return, so it was never exposed as __doc__.
    if deepspeed_groups is None:
        return False

    return deepspeed_groups._get_sequence_parallel_world_size() > 1
442
+
443
+
444
def support_deepspeed_ulysses(module):
    """A decorator around Pytorch module. It is needed for the module that needs access to sequence parallel info.

    Attaches cached ``sp_size``/``sp_rank``/``sp_group`` properties plus their
    backing ``_sp_*`` attributes. NOTE(review): this must be applied to a
    *class*, not an instance — properties only work as class-level descriptors.
    The cached values are resolved lazily on first access and never refreshed,
    so the sequence-parallel configuration is assumed fixed after that point.
    """
    module._sp_size = None
    module._sp_rank = None
    module._sp_group = None

    @property
    def sp_size(self):
        # Lazily resolved; defaults to 1 when Ulysses is disabled.
        if self._sp_size is None:
            self._sp_size = 1
            if is_deepspeed_ulysses_enabled():
                self._sp_size = deepspeed_groups._get_sequence_parallel_group().size()
        return self._sp_size

    @property
    def sp_rank(self):
        # Lazily resolved; defaults to rank 0 when Ulysses is disabled.
        if self._sp_rank is None:
            self._sp_rank = 0
            if is_deepspeed_ulysses_enabled():
                self._sp_rank = deepspeed_groups._get_sequence_parallel_rank()
        return self._sp_rank

    @property
    def sp_group(self):
        # Unlike sp_size/sp_rank, the disabled state is not cached here:
        # the value stays None until Ulysses is enabled.
        if self._sp_group is None and is_deepspeed_ulysses_enabled():
            self._sp_group = deepspeed_groups._get_sequence_parallel_group()
        return self._sp_group

    module.sp_size = sp_size
    module.sp_rank = sp_rank
    module.sp_group = sp_group

    return module
477
+
478
+
479
def deepspeed_ulysses_attention(seq_dim=1, head_dim=2):
    """Perform all-to-all before and after the attention function.

    Decorator factory for DeepSpeed Ulysses sequence parallelism: the first
    three positional arguments of the wrapped function (q, k, v) are exchanged
    so each rank holds the full sequence for a subset of heads, attention runs,
    and the output is exchanged back. No-op when Ulysses is disabled.

    Args:
        seq_dim: Index of the sequence dimension in q/k/v (default 1).
        head_dim: Index of the head dimension in q/k/v (default 2).
    """

    def attention_decorator(attn_func=None):
        @wraps(attn_func)  # preserve the wrapped function's metadata
        def wrapped(*args, **kwargs):
            # Evaluate once so the pre- and post-attention exchanges always
            # agree, and `sp_group` can never be referenced unbound.
            sp_enabled = is_deepspeed_ulysses_enabled()
            sp_group = deepspeed_groups._get_sequence_parallel_group() if sp_enabled else None
            batch_dim_idx = 0
            if sp_enabled:
                # Scatter on num_heads dimension, gather on seq_len dimension.
                args = list(args)
                args[0] = _SeqAllToAll.apply(sp_group, args[0], head_dim, seq_dim, batch_dim_idx)
                args[1] = _SeqAllToAll.apply(sp_group, args[1], head_dim, seq_dim, batch_dim_idx)
                args[2] = _SeqAllToAll.apply(sp_group, args[2], head_dim, seq_dim, batch_dim_idx)
                args = tuple(args)

            attn_output = attn_func(*args, **kwargs)

            if sp_enabled:
                # Reverse exchange: scatter back on seq_len, gather on num_heads.
                attn_output = _SeqAllToAll.apply(sp_group, attn_output, seq_dim, head_dim, batch_dim_idx)

            return attn_output

        return wrapped

    return attention_decorator
508
+
509
+
510
def deepspeed_ulysses_rope(state_seq_dim=2, trig_seq_dim=1):
    """Slice the corresponding cos and sin chunks for rope.

    Decorator factory: under DeepSpeed Ulysses each rank only holds a chunk of
    the sequence, so the cos/sin tables (positional args 2 and 3 of the wrapped
    function — TODO confirm against callers) are narrowed to this rank's chunk,
    whose size is taken from ``args[0].size(state_seq_dim)``.

    Args:
        state_seq_dim: Sequence dimension of the hidden-state argument (default 2).
        trig_seq_dim: Sequence dimension of the cos/sin tables (default 1).
    """

    def rope_decorator(rope_func=None):
        @wraps(rope_func)  # preserve the wrapped function's metadata
        def wrapped(*args, **kwargs):
            if is_deepspeed_ulysses_enabled():
                sp_rank = deepspeed_groups._get_sequence_parallel_rank()
                args = list(args)
                seq_chunk_size = args[0].size(state_seq_dim)
                # Each rank keeps only its own contiguous chunk of the tables.
                args[2] = torch.narrow(args[2], trig_seq_dim, sp_rank * seq_chunk_size, seq_chunk_size)
                args[3] = torch.narrow(args[3], trig_seq_dim, sp_rank * seq_chunk_size, seq_chunk_size)
                args = tuple(args)

            return rope_func(*args, **kwargs)

        return wrapped

    return rope_decorator
528
+
529
+
530
def _gather_tensors(input_, group=None):
    """All-gather ``input_`` from every rank in ``group``.

    Shapes may differ across ranks, so shapes are exchanged first and each
    receive buffer is sized per rank.

    NOTE(review): the return type is inconsistent — the single-rank fast path
    returns the tensor itself, otherwise a *list* of per-rank tensors is
    returned (it does not concatenate). Callers visible in this file guard the
    world_size==1 case themselves; verify before relying on the fast path.
    """
    input_ = input_.contiguous()
    world_size = torch.distributed.get_world_size(group)
    if world_size == 1:
        return input_
    # First exchange each rank's tensor shape so receive buffers can be sized.
    tensor_shapes = [
        torch.empty(len(input_.size()), dtype=torch.int64, device=input_.device) for _ in range(world_size)
    ]
    input_size = torch.tensor(input_.size(), dtype=torch.int64, device=input_.device)
    torch.distributed.all_gather(tensor_shapes, input_size, group=group)
    gathered_buffers = [
        torch.empty(tensor_shapes[i].tolist(), dtype=input_.dtype, device=input_.device) for i in range(world_size)
    ]
    torch.distributed.all_gather(gathered_buffers, input_, group=group)
    return gathered_buffers
546
+
547
+
548
def _scatter_tensors(input_, group=None):
    """Select this rank's entry from a per-rank collection (inverse of _gather_tensors).

    With a single-rank group, ``input_`` is returned unchanged; otherwise
    ``input_`` is expected to be indexable by rank (e.g. the list produced by
    ``_gather_tensors``) and only this rank's element is returned.
    """
    world_size = torch.distributed.get_world_size(group)
    if world_size == 1:
        return input_
    rank = torch.distributed.get_rank(group)
    return input_[rank]
555
+
556
+
557
class _GatherTensors(torch.autograd.Function):
    """All gather tensors among the ranks.

    Forward all-gathers the (possibly differently shaped) per-rank tensors and
    packs them into a jagged nested tensor; backward routes each rank's
    gradient slice back to its owner via ``_scatter_tensors``.
    """

    @staticmethod
    def symbolic(graph, input_, group):
        # Tracing/export path: mirrors forward without the nested-tensor wrap.
        return _gather_tensors(input_, group)

    @staticmethod
    def forward(ctx, input_, group):
        ctx.group = group
        # Jagged layout accommodates per-rank tensors with different lengths.
        return torch.nested.as_nested_tensor(_gather_tensors(input_, group), layout=torch.jagged)

    @staticmethod
    def backward(ctx, grad_output):
        # Keep only this rank's gradient; None for the non-tensor `group` arg.
        return _scatter_tensors(grad_output, ctx.group), None
572
+
573
+
574
def all_gather_tensors(input_, size=None, dim=0, group=None):
    """All-gather ``input_`` across ``group`` and concatenate along ``dim``.

    :param input_: This rank's local tensor contribution.
    :param size: Optional per-rank split sizes. When given, each gathered
        tensor is split accordingly and the pieces are interleaved across
        ranks before concatenation.
        NOTE(review): each element ``s`` must expose ``.tolist()`` (i.e. a
        tensor of split sizes) — confirm with callers.
    :param dim: Dimension to concatenate on.
    :param group: Process group (default: the world group).
    """
    if torch.distributed.get_world_size(group) == 1:
        # no sequence parallelism
        return input_
    gathered_tensors = _GatherTensors.apply(input_, group)

    if size:
        split_gathered_tensors = []
        for s, gathered_tensor in zip(size, gathered_tensors):
            split_gathered_tensor = torch.split(gathered_tensor, s.tolist())
            split_gathered_tensors.append(split_gathered_tensor)

        # Interleave: piece j of every rank, for each j in order.
        gathered_tensors = [y for x in zip(*split_gathered_tensors) for y in x]

    return torch.cat(gathered_tensors, dim).contiguous()
589
+
590
+
591
def get_sequence_data_parallel_world_size():
    """Sequence-data-parallel world size: here simply the global world size."""
    return torch.distributed.get_world_size()
593
+
594
+
595
def get_sequence_data_parallel_rank():
    """Sequence-data-parallel rank: here simply the global rank."""
    return torch.distributed.get_rank()
597
+
598
+
599
def get_sequence_data_parallel_group():
    """Sequence-data-parallel group: the global (WORLD) process group."""
    return torch.distributed.group.WORLD
601
+
602
+
603
if is_deepspeed_available():
    # Monkey-patch DeepSpeed's sequence-data-parallel accessors so the whole
    # global process group is treated as the sequence-data-parallel group.
    deepspeed_groups._get_sequence_data_parallel_world_size = get_sequence_data_parallel_world_size
    deepspeed_groups._get_sequence_data_parallel_rank = get_sequence_data_parallel_rank
    deepspeed_groups._get_sequence_data_parallel_group = get_sequence_data_parallel_group
607
+
608
+
609
def _gather_tokens(input_, dim=0, group=None):
    """Gather tensors and concatenate them along a dimension.

    NOTE(review): uses ``all_gather_into_tensor``, which requires every rank
    to contribute a tensor of identical shape — confirm callers guarantee this.
    """
    input_ = input_.contiguous()
    world_size = torch.distributed.get_world_size(group)
    if world_size == 1:
        return input_

    # Flat buffer holding every rank's contribution back-to-back.
    gather_buffer = torch.empty(world_size * input_.numel(), dtype=input_.dtype, device=input_.device)
    torch.distributed.all_gather_into_tensor(gather_buffer, input_, group=group)
    if dim == 0:
        # Fast path: per-rank shards are already laid out contiguously along
        # dim 0, so a reshape view suffices — no copy.
        shape = list(input_.size())
        shape[0] = shape[0] * world_size
        output = gather_buffer.view(shape)
    else:
        # Re-slice the flat buffer into per-rank tensors, then concat on `dim`.
        tensor_list = [
            gather_buffer.narrow(0, input_.numel() * i, input_.numel()).view_as(input_) for i in range(world_size)
        ]
        # Note: torch.cat already creates a contiguous tensor.
        output = torch.cat(tensor_list, dim=dim).contiguous()

    return output
630
+
631
+
632
def _drop_tokens(input_, dim=0, group=None):
    """Divide a tensor among the sequence parallel ranks.

    Returns this rank's contiguous chunk along `dim`; the chunk is a
    zero-copy view (``torch.narrow``) of the input.
    """
    world_size = torch.distributed.get_world_size(group)
    if world_size == 1:
        return input_
    this_rank = torch.distributed.get_rank(group)
    assert input_.shape[dim] % world_size == 0, (
        f"input dimension {dim} ({input_.shape[dim]}) is not divisible by sequence parallel world size ({world_size})"
    )
    chunk_size = input_.shape[dim] // world_size

    return torch.narrow(input_, dim, this_rank * chunk_size, chunk_size)
644
+
645
+
646
class _DropTokens(torch.autograd.Function):
    "Divide tokens equally among the sequence parallel ranks"

    @staticmethod
    def symbolic(graph, input_, dim, group, grad_scale):
        # Tracing path mirrors forward without autograd bookkeeping.
        return _drop_tokens(input_, dim, group)

    @staticmethod
    def forward(ctx, input_, dim, group, grad_scale):
        ctx.dim = dim
        ctx.group = group
        ctx.grad_scale = grad_scale
        return _drop_tokens(input_, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        # Backward of a drop is a gather: reassemble the full-sequence gradient.
        grad_input = _gather_tokens(grad_output, ctx.dim, ctx.group)
        if ctx.grad_scale != 1:
            # In-place divide to rescale the gathered gradient.
            grad_input /= ctx.grad_scale
        # One gradient per forward input; None for dim/group/grad_scale.
        return grad_input, None, None, None
666
+
667
+
668
class _GatherTokens(torch.autograd.Function):
    "Gather tokens among the sequence parallel ranks"

    @staticmethod
    def symbolic(graph, input_, dim, group, grad_scale):
        # Tracing path mirrors forward without autograd bookkeeping.
        return _gather_tokens(input_, dim, group)

    @staticmethod
    def forward(ctx, input_, dim, group, grad_scale):
        ctx.dim = dim
        ctx.group = group
        ctx.grad_scale = grad_scale
        return _gather_tokens(input_, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        # Backward of a gather is a drop: keep only this rank's gradient chunk.
        grad_input = _drop_tokens(grad_output, ctx.dim, ctx.group)
        if ctx.grad_scale != 1:
            # In-place multiply to rescale the local gradient chunk.
            grad_input *= ctx.grad_scale
        # One gradient per forward input; None for dim/group/grad_scale.
        return grad_input, None, None, None
688
+
689
+
690
def drop_tokens(input_, dim=0, group=None, grad_scale=1):
    """Autograd-aware shard of `input_` along `dim` for this rank.

    `grad_scale` divides the gathered gradient in the backward pass.
    """
    if torch.distributed.get_world_size(group) == 1:
        # no sequence parallelism
        return input_
    return _DropTokens.apply(input_, dim, group, grad_scale)
695
+
696
+
697
def gather_tokens(input_, dim=0, group=None, grad_scale=1):
    """Autograd-aware all-gather of `input_` along `dim` across the group.

    `grad_scale` multiplies the dropped gradient chunk in the backward pass.
    """
    if torch.distributed.get_world_size(group) == 1:
        # no sequence parallelism
        return input_
    return _GatherTokens.apply(input_, dim, group, grad_scale)
702
+
703
+
704
def sequence_chunking_per_rank(sp_size, sp_rank, *args, dim=1):
    """
    Slice the inputs to create chuncks per the sequence parallel rank. This is used for the context parallel training.

    Args:
        sp_size (`int`):
            Sequence parallel size.
        sp_rank (`int`):
            Sequence parallel rank for the current process.
        dim (`int`):
            The dimension to slice
    """
    # Nothing to chunk when sequence parallelism is off.
    if sp_size == 1:
        return args[0] if len(args) == 1 else args

    seq_length = args[0].size(dim)
    # Every tensor must agree on the sequence length along `dim`.
    for tensor in args[1:]:
        assert tensor.size(dim) == seq_length, (
            f"arg={tensor} ({tensor.shape[dim]}) does not have the same size as args[0] ({seq_length}) in dimension {dim}"
        )
    assert seq_length % sp_size == 0, (
        f"dimension {dim} ({args[0].shape[dim]}) is not divisible by sequence parallel world size ({sp_size})"
    )

    chunk_len = seq_length // sp_size
    offset = sp_rank * chunk_len

    # torch.narrow returns a zero-copy view of each rank-local chunk.
    output = [torch.narrow(tensor, dim, offset, chunk_len) for tensor in args]
    return output[0] if len(output) == 1 else tuple(output)
737
+
738
+
739
@contextmanager
def disable_deepspeed_ulysses():
    """Disable deepspeed ulysses (sequence parallelism) if it is enabled.

    Temporarily patches DeepSpeed's ``_get_sequence_parallel_world_size`` to
    report 1 inside the ``with`` block and restores the original accessor on
    exit (even on exception, via try/finally).
    """
    if is_deepspeed_ulysses_enabled():
        _old_get_sequence_parallel_world_size = deepspeed_groups._get_sequence_parallel_world_size

        def _get_sequence_parallel_world_size():
            return 1

        deepspeed_groups._get_sequence_parallel_world_size = _get_sequence_parallel_world_size
        try:
            yield
        finally:
            # Always restore the original accessor.
            deepspeed_groups._get_sequence_parallel_world_size = _old_get_sequence_parallel_world_size
    else:
        # Nothing to disable; still yield exactly once as a context manager.
        context = contextlib.nullcontext
        with context():
            yield
boson_multimodal/serve/serve_engine.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import base64
3
+ import torch
4
+ import numpy as np
5
+ from io import BytesIO
6
+ from dataclasses import dataclass
7
+ from typing import List, Optional, Union
8
+ from copy import deepcopy
9
+ from transformers import AutoTokenizer, AutoProcessor
10
+ from transformers.cache_utils import StaticCache
11
+ from transformers.generation.streamers import BaseStreamer
12
+ from transformers.generation.stopping_criteria import StoppingCriteria
13
+ from dataclasses import asdict
14
+ from loguru import logger
15
+ import threading
16
+ import librosa
17
+
18
+
19
+ from ..dataset.chatml_dataset import ChatMLSample, ChatMLDatasetSample, prepare_chatml_sample
20
+ from ..model.higgs_audio import HiggsAudioModel
21
+ from ..model.higgs_audio.utils import revert_delay_pattern
22
+ from ..data_collator.higgs_audio_collator import HiggsAudioSampleCollator
23
+ from ..audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer
24
+
25
+
26
@dataclass
class HiggsAudioStreamerDelta:
    """Represents a chunk of generated content, either text or audio tokens."""

    # Decoded text fragment, if this delta carries text.
    text: Optional[str] = None
    # Raw text token ids corresponding to `text`.
    text_tokens: Optional[torch.Tensor] = None
    # Audio token ids (one entry per codebook), if this delta carries audio.
    audio_tokens: Optional[torch.Tensor] = None
    # Populated on the final delta with the reason generation stopped.
    finish_reason: Optional[str] = None
34
+
35
+
36
class AsyncHiggsAudioStreamer(BaseStreamer):
    """
    Async streamer that handles both text and audio token generation from Higgs-Audio model.
    Stores chunks in a queue to be consumed by downstream applications.

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode text tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt tokens in generation.
        timeout (`float`, *optional*):
            The timeout for the queue. If `None`, the queue will block indefinitely.
        audio_num_codebooks (`int`, *optional*, defaults to 1):
            Number of audio codebooks; used in `put` to recognize audio-token frames.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:
        ```python
        >>> from transformers import AutoTokenizer
        >>> from threading import Thread
        >>> import asyncio

        >>> tokenizer = AutoTokenizer.from_pretrained("path/to/higgs/tokenizer")
        >>> model = HiggsAudioModel.from_pretrained("path/to/higgs/model")
        >>> inputs = tokenizer(["Generate some text and audio:"], return_tensors="pt")

        >>> async def main():
        ...     streamer = AsyncHiggsAudioStreamer(tokenizer)
        ...     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
        ...     thread = Thread(target=model.generate, kwargs=generation_kwargs)
        ...     thread.start()
        ...
        ...     async for delta in streamer:
        ...         if delta.text is not None:
        ...             print("Text:", delta.text)
        ...         if delta.audio_tokens is not None:
        ...             print("Audio tokens shape:", delta.audio_tokens.shape)
        >>> asyncio.run(main())
        ```
    """

    def __init__(
        self,
        tokenizer: "AutoTokenizer",
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        audio_num_codebooks: int = 1,
        **decode_kwargs,
    ):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.timeout = timeout
        self.decode_kwargs = decode_kwargs
        self.audio_num_codebooks = audio_num_codebooks
        # Queue to store generated chunks
        self.queue = asyncio.Queue()
        # Sentinel placed on the queue by `end()`; None terminates iteration.
        self.stop_signal = None

        # Get running event loop — must be constructed from within a running
        # loop, since `put`/`end` are called from the generation thread and
        # hand results back via call_soon_threadsafe.
        self.loop = asyncio.get_running_loop()
        # asyncio.timeout exists on Python >= 3.11; fall back to wait_for otherwise.
        self.has_asyncio_timeout = hasattr(asyncio, "timeout")

        # State tracking
        self.next_tokens_are_prompt = True

    def put(self, value: torch.Tensor):
        """
        Receives tokens and processes them as either text or audio tokens.
        For text tokens, decodes and caches them until complete words are formed.
        For audio tokens, directly queues them.
        """
        # Frames with more than one entry past the prompt are treated as audio
        # (one token per codebook). NOTE(review): a batched text frame would
        # also satisfy shape[0] > 1 — assumes batch size 1; confirm upstream.
        if value.shape[0] > 1 and not self.next_tokens_are_prompt:
            # This is likely audio tokens (shape: [audio_num_codebooks])
            assert value.shape[0] == self.audio_num_codebooks, "Number of codebooks mismatch"
            delta = HiggsAudioStreamerDelta(audio_tokens=value)
            if self.loop.is_running():
                # Thread-safe hand-off into the consumer's event loop.
                self.loop.call_soon_threadsafe(self.queue.put_nowait, delta)
            return

        # Skip prompt tokens if configured
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Process as text tokens
        if len(value.shape) > 1:
            value = value[0]

        text = self.tokenizer.decode(value, **self.decode_kwargs)
        delta = HiggsAudioStreamerDelta(text=text, text_tokens=value)
        if self.loop.is_running():
            self.loop.call_soon_threadsafe(self.queue.put_nowait, delta)

    def end(self):
        """Flushes any remaining text tokens and signals the end of generation."""
        # Reset so the streamer can be reused for another generate() call.
        self.next_tokens_are_prompt = True
        if self.loop.is_running():
            self.loop.call_soon_threadsafe(self.queue.put_nowait, self.stop_signal)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            if self.has_asyncio_timeout:
                async with asyncio.timeout(self.timeout):
                    value = await self.queue.get()
            else:
                value = await asyncio.wait_for(self.queue.get(), timeout=self.timeout)
        except asyncio.TimeoutError:
            # Surface as the builtin TimeoutError for callers.
            raise TimeoutError()
        else:
            # stop_signal is None, so this terminates iteration on the sentinel.
            if value == self.stop_signal:
                raise StopAsyncIteration()
            else:
                return value
151
+
152
+
153
class AsyncStoppingCriteria(StoppingCriteria):
    """
    Stopping criteria that checks for stop signal from a threading event.

    Args:
        stop_signal (threading.Event): Event that will receive stop signals
    """

    def __init__(self, stop_signal: threading.Event):
        self.stop_signal = stop_signal

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        """Return True (stop generation) once the event has been set."""
        if self.stop_signal.is_set():
            # Plain string literal: the original used an f-string with no
            # placeholders (lint F541).
            logger.info("Stop signal received. Can be caused by client disconnection.")
            return True
        return False
169
+
170
+
171
@dataclass
class HiggsAudioResponse:
    """Result of a (non-streaming) HiggsAudioServeEngine.generate call."""

    # Decoded waveform as a float array, or None if no audio was generated.
    audio: Optional[np.ndarray] = None
    # Raw generated audio token ids (codebooks x frames).
    generated_audio_tokens: Optional[np.ndarray] = None
    # Sample rate of `audio` in Hz.
    sampling_rate: Optional[int] = None
    # Decoded text portion of the generation.
    generated_text: str = ""
    # Token ids behind `generated_text`.
    generated_text_tokens: Optional[np.ndarray] = None
    # Token accounting (prompt/completion/total/cached), OpenAI-style.
    usage: Optional[dict] = None
179
+
180
+
181
class HiggsAudioServeEngine:
    """Serving wrapper around HiggsAudioModel: loads model/tokenizers, prepares
    KV caches (and CUDA graphs on GPU), and exposes blocking and streaming
    generation entry points."""

    def __init__(
        self,
        model_name_or_path: str,
        audio_tokenizer_name_or_path: str,
        tokenizer_name_or_path: Optional[str] = None,
        device: str = "cuda",
        torch_dtype: Union[torch.dtype, str] = "auto",
        kv_cache_lengths: List[int] = [1024, 4096, 8192],  # Multiple KV cache sizes
        # NOTE(review): mutable default argument above — harmless while only
        # read, but consider a tuple default.
    ):
        """
        Initialize the HiggsAudioServeEngine, a serving wrapper for the HiggsAudioModel.
        The model, tokenizer, and audio tokenizer will be downloaded from the Hugging Face Hub if they are not local.

        Args:
            model_name_or_path (str):
                The name or path of the model to load.
            audio_tokenizer_name_or_path (str):
                The name or path of the audio tokenizer to load.
            tokenizer_name_or_path (str):
                The name or path of the tokenizer to load.
            device (str):
                The device to use for the model.
            kv_cache_lengths (List[int]):
                The lengths of the KV caches to use for the model. Used for cuda graph capture when device is cuda.
            torch_dtype (Union[torch.dtype, str]):
                The dtype to use for the model.
        """
        self.device = device
        self.model_name_or_path = model_name_or_path
        self.torch_dtype = torch_dtype

        # Initialize model and tokenizer
        self.model = HiggsAudioModel.from_pretrained(model_name_or_path, torch_dtype=torch_dtype).to(device)
        logger.info(f"Loaded model from {model_name_or_path}, dtype: {self.model.dtype}")

        if tokenizer_name_or_path is None:
            tokenizer_name_or_path = model_name_or_path
        logger.info(f"Loading tokenizer from {tokenizer_name_or_path}")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

        logger.info(f"Initializing Higgs Audio Tokenizer")
        self.audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer_name_or_path, device=device)

        self.audio_num_codebooks = self.model.config.audio_num_codebooks
        self.audio_codebook_size = self.model.config.audio_codebook_size
        self.audio_tokenizer_tps = self.audio_tokenizer.tps
        # Waveform samples represented by one audio token.
        self.samples_per_token = int(self.audio_tokenizer.sampling_rate // self.audio_tokenizer_tps)
        self.hamming_window_len = 2 * self.audio_num_codebooks * self.samples_per_token
        # Set the audio special tokens
        self.model.set_audio_special_tokens(self.tokenizer)

        # Prepare KV caches for different lengths
        cache_config = deepcopy(self.model.config.text_config)
        cache_config.num_hidden_layers = self.model.config.text_config.num_hidden_layers
        # Dual-FFN audio layers carry their own KV entries, so grow the cache.
        if self.model.config.audio_dual_ffn_layers:
            cache_config.num_hidden_layers += len(self.model.config.audio_dual_ffn_layers)
        # A list of KV caches for different lengths
        self.kv_caches = {
            length: StaticCache(
                config=cache_config,
                max_batch_size=1,
                max_cache_len=length,
                device=self.model.device,
                dtype=self.model.dtype,
            )
            for length in sorted(kv_cache_lengths)
        }

        if self.model.config.encode_whisper_embed:
            logger.info(f"Loading whisper processor")
            whisper_processor = AutoProcessor.from_pretrained(
                "openai/whisper-large-v3-turbo",
                trust_remote=True,
                device=self.device,
            )
        else:
            whisper_processor = None

        # Reuse collator to prepare inference samples
        self.collator = HiggsAudioSampleCollator(
            whisper_processor=whisper_processor,
            encode_whisper_embed=self.model.config.encode_whisper_embed,
            audio_in_token_id=self.model.config.audio_in_token_idx,
            audio_out_token_id=self.model.config.audio_out_token_idx,
            audio_stream_bos_id=self.model.config.audio_stream_bos_id,
            audio_stream_eos_id=self.model.config.audio_stream_eos_id,
            pad_token_id=self.model.config.pad_token_id,
            return_audio_in_tokens=False,
            use_delay_pattern=self.model.config.use_delay_pattern,
            audio_num_codebooks=self.model.config.audio_num_codebooks,
            round_to=1,
        )

        # Capture CUDA graphs for each KV cache length
        if device == "cuda":
            logger.info(f"Capturing CUDA graphs for each KV cache length")
            self.model.capture_model(self.kv_caches.values())

    def _prepare_inputs(self, chat_ml_sample: ChatMLSample, force_audio_gen: bool = False):
        """Tokenize a ChatML sample (text + referenced audio) into collated
        model inputs on the model device."""
        input_tokens, _, audio_contents, _ = prepare_chatml_sample(
            chat_ml_sample,
            self.tokenizer,
        )

        # Append the assistant header; optionally force the model straight
        # into audio generation with <|audio_out_bos|>.
        postfix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
        if force_audio_gen:
            postfix += "<|audio_out_bos|>"
        postfix = self.tokenizer.encode(postfix, add_special_tokens=False)
        input_tokens.extend(postfix)

        # Configure the audio inputs
        audio_ids_l = []
        for audio_content in audio_contents:
            if audio_content.audio_url not in ["placeholder", ""]:
                raw_audio, _ = librosa.load(audio_content.audio_url, sr=self.audio_tokenizer.sampling_rate)
            elif audio_content.raw_audio is not None:
                # raw_audio is base64-encoded audio bytes.
                raw_audio, _ = librosa.load(
                    BytesIO(base64.b64decode(audio_content.raw_audio)), sr=self.audio_tokenizer.sampling_rate
                )
            else:
                raw_audio = None

            if raw_audio is not None:
                audio_ids = self.audio_tokenizer.encode(raw_audio, self.audio_tokenizer.sampling_rate)
                audio_ids_l.append(audio_ids.squeeze(0).cpu())

        if len(audio_ids_l) > 0:
            # Offsets of each clip within the concatenated token stream.
            audio_ids_start = torch.tensor(
                np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])),
                dtype=torch.long,
                device=self.device,
            )[0:-1]
            audio_ids_concat = torch.cat(audio_ids_l, dim=1)
        else:
            audio_ids_start = None
            audio_ids_concat = None

        sample = ChatMLDatasetSample(
            input_ids=torch.LongTensor(input_tokens),
            label_ids=None,
            audio_ids_concat=audio_ids_concat,
            audio_ids_start=audio_ids_start,
            audio_waveforms_concat=None,
            audio_waveforms_start=None,
            audio_sample_rate=None,
            audio_speaker_indices=None,
        )
        data = self.collator([sample])
        inputs = asdict(data)
        # Move every tensor field of the collated batch onto the model device.
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.model.device)

        return inputs

    def _prepare_kv_caches(self):
        # Clear all static caches between requests.
        for kv_cache in self.kv_caches.values():
            kv_cache.reset()

    def generate(
        self,
        chat_ml_sample: ChatMLSample,
        max_new_tokens: int,
        temperature: float = 0.7,
        top_k: Optional[int] = None,
        top_p: float = 0.95,
        stop_strings: Optional[List[str]] = None,
        force_audio_gen: bool = False,
        ras_win_len: Optional[int] = 7,
        ras_win_max_num_repeat: int = 2,
        seed: Optional[int] = None,
    ):
        """
        Generate audio from a chatml sample.
        Args:
            chat_ml_sample: A chatml sample.
            max_new_tokens: The maximum number of new tokens to generate.
            temperature: The temperature to use for the generation.
            top_p: The top p to use for the generation.
            stop_strings: A list of strings to stop the generation.
            force_audio_gen: Whether to force audio generation. This ensures the model generates audio tokens rather than text tokens.
            ras_win_len: The length of the RAS window. We use 7 by default. You can disable it by setting it to None or <=0.
            ras_win_max_num_repeat: The maximum number of times to repeat the RAS window.
        Returns:
            A dictionary with the following keys:
                audio: The generated audio.
                sampling_rate: The sampling rate of the generated audio.
        """
        # Default stop strings
        if stop_strings is None:
            stop_strings = ["<|end_of_text|>", "<|eot_id|>"]
        if ras_win_len is not None and ras_win_len <= 0:
            ras_win_len = None

        with torch.no_grad():
            inputs = self._prepare_inputs(chat_ml_sample, force_audio_gen=force_audio_gen)
            prompt_token_ids = inputs["input_ids"][0].cpu().numpy()

            self._prepare_kv_caches()

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                stop_strings=stop_strings,
                tokenizer=self.tokenizer,
                do_sample=False if temperature == 0.0 else True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                past_key_values_buckets=self.kv_caches,
                ras_win_len=ras_win_len,
                ras_win_max_num_repeat=ras_win_max_num_repeat,
                seed=seed,
            )

            # outputs[0]: text token ids, outputs[1]: list of audio token blocks.
            if len(outputs[1]) > 0:
                wv_list = []
                for output_audio in outputs[1]:
                    # Undo the delay pattern, clamp to the codebook range and
                    # strip the stream BOS/EOS frames before decoding.
                    vq_code = revert_delay_pattern(output_audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1]
                    wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
                    wv_list.append(wv_numpy)
                wv_numpy = np.concatenate(wv_list)
            else:
                wv_numpy = None

            # We only support one request at a time now
            generated_text_tokens = outputs[0][0].cpu().numpy()[len(prompt_token_ids) :]
            generated_text = self.tokenizer.decode(generated_text_tokens)
            # NOTE(review): outputs[1][0] raises IndexError when no audio was
            # generated (the wv_numpy=None branch above) — confirm text-only
            # generations cannot reach this line.
            generated_audio_tokens = outputs[1][0].cpu().numpy()
            return HiggsAudioResponse(
                audio=wv_numpy,
                generated_audio_tokens=generated_audio_tokens,
                sampling_rate=self.audio_tokenizer.sampling_rate,
                generated_text=generated_text,
                generated_text_tokens=generated_text_tokens,
                usage={
                    "prompt_tokens": prompt_token_ids.shape[0],
                    "completion_tokens": generated_text_tokens.shape[0] + generated_audio_tokens.shape[1],
                    "total_tokens": (
                        prompt_token_ids.shape[0] + generated_text_tokens.shape[0] + generated_audio_tokens.shape[1]
                    ),
                    "cached_tokens": 0,
                },
            )

    async def generate_delta_stream(
        self,
        chat_ml_sample: ChatMLSample,
        max_new_tokens: int,
        temperature: float = 0.7,
        top_k: Optional[int] = None,
        top_p: float = 0.95,
        stop_strings: Optional[List[str]] = None,
        force_audio_gen: bool = False,
        ras_win_len: Optional[int] = 7,
        ras_win_max_num_repeat: int = 2,
        seed: Optional[int] = None,
    ):
        """
        Generate audio from a chatml sample.
        Args:
            chat_ml_sample: A chatml sample.
            max_new_tokens: The maximum number of new tokens to generate.
            temperature: The temperature to use for the generation.
            top_p: The top p to use for the generation.
            stop_strings: A list of strings to stop the generation.
            force_audio_gen: Whether to force audio generation. This ensures the model generates audio tokens rather than text tokens.
            ras_win_len: The length of the RAS window. We use 7 by default. You can disable it by setting it to None or <=0.
            ras_win_max_num_repeat: The maximum number of times to repeat the RAS window.
        Returns:
            Delta AsyncGenerator
        """
        # Default stop strings
        if stop_strings is None:
            stop_strings = ["<|end_of_text|>", "<|eot_id|>"]
        if ras_win_len is not None and ras_win_len <= 0:
            ras_win_len = None

        with torch.no_grad():
            inputs = self._prepare_inputs(chat_ml_sample, force_audio_gen=force_audio_gen)

            self._prepare_kv_caches()

            streamer = AsyncHiggsAudioStreamer(
                self.tokenizer,
                audio_num_codebooks=self.model.config.audio_num_codebooks,
                skip_prompt=True,
            )
            generation_kwargs = dict(
                **inputs,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                stop_strings=stop_strings,
                tokenizer=self.tokenizer,
                do_sample=False if temperature == 0.0 else True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                past_key_values_buckets=self.kv_caches,
                ras_win_len=ras_win_len,
                ras_win_max_num_repeat=ras_win_max_num_repeat,
                seed=seed,
                streamer=streamer,
            )
            # Generation runs on a worker thread; the streamer pushes deltas
            # back into this coroutine's event loop.
            thread = threading.Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

            async for delta in streamer:
                yield delta
boson_multimodal/serve/utils.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ import base64
3
+ import re
4
+ import regex
5
+ from typing import AsyncGenerator, Union
6
+ import io
7
+ from pydub import AudioSegment
8
+ import torch
9
+ import numpy as np
10
+ from functools import lru_cache
11
+
12
+ from ..audio_processing.higgs_audio_tokenizer import HiggsAudioTokenizer
13
+
14
+
15
def random_uuid() -> str:
    """Return a fresh random identifier as a 32-character hex string."""
    token = uuid.uuid4()
    return str(token.hex)
17
+
18
+
19
async def async_generator_wrap(first_element, gen: AsyncGenerator):
    """Re-attach an already-consumed first element to the front of an async generator."""
    yield first_element
    async for chunk in gen:
        yield chunk
24
+
25
+
26
@lru_cache(maxsize=50)
def encode_base64_content_from_file(file_path: str) -> str:
    """Encode a content from a local file to base64 format.

    Results are memoized per path (up to 50 entries), so repeated prompt
    audio files are only read and encoded once.
    """
    with open(file_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode("utf-8")
33
+
34
+
35
def pcm16_to_target_format(
    np_audio: np.ndarray,
    sample_rate: int,
    bit_depth: int,
    channels: int,
    format: str,
    target_rate: int,
):
    """Wrap raw PCM samples into an encoded audio container via pydub.

    Args:
        np_audio: Interleaved PCM samples; dtype presumably matches
            `bit_depth` (e.g. int16 for 16-bit) — TODO confirm at call sites.
        sample_rate: Sample rate of the input PCM in Hz.
        bit_depth: Bits per sample; pydub expects bytes, hence // 8.
        channels: Number of interleaved channels.
        format: Output container name understood by pydub/ffmpeg ("mp3", "wav", ...).
        target_rate: Optional resample rate; skipped when None or equal to input.

    Returns:
        BytesIO positioned at offset 0 containing the encoded audio.
    """
    wav_audio = AudioSegment(
        np_audio.tobytes(),
        frame_rate=sample_rate,
        sample_width=bit_depth // 8,
        channels=channels,
    )
    if target_rate is not None and target_rate != sample_rate:
        wav_audio = wav_audio.set_frame_rate(target_rate)

    # Convert WAV to MP3
    target_io = io.BytesIO()
    wav_audio.export(target_io, format=format)
    target_io.seek(0)

    return target_io
58
+
59
+
60
# Matches runs of CJK unified ideographs (U+4E00..U+9FFF).
chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")


def contains_chinese(text: str):
    """Return True when *text* contains at least one CJK unified ideograph."""
    return chinese_char_pattern.search(text) is not None
65
+
66
+
67
+ # remove blank between chinese character
68
+ def replace_blank(text: str):
69
+ out_str = []
70
+ for i, c in enumerate(text):
71
+ if c == " ":
72
+ if (text[i + 1].isascii() and text[i + 1] != " ") and (text[i - 1].isascii() and text[i - 1] != " "):
73
+ out_str.append(c)
74
+ else:
75
+ out_str.append(c)
76
+ return "".join(out_str)
77
+
78
+
79
def replace_corner_mark(text: str):
    """Spell out superscript square/cube marks using their Chinese words."""
    replacements = {"²": "平方", "³": "立方"}
    for mark, word in replacements.items():
        text = text.replace(mark, word)
    return text
83
+
84
+
85
+ # remove meaningless symbol
86
+ def remove_bracket(text: str):
87
+ text = text.replace("(", "").replace(")", "")
88
+ text = text.replace("【", "").replace("】", "")
89
+ text = text.replace("`", "").replace("`", "")
90
+ text = text.replace("——", " ")
91
+ return text
92
+
93
+
94
+ # split paragrah logic:
95
+ # 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
96
+ # 2. cal sentence len according to lang
97
+ # 3. split sentence according to puncatation
98
+ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
99
+ def calc_utt_length(_text: str):
100
+ if lang == "zh":
101
+ return len(_text)
102
+ else:
103
+ return len(tokenize(_text))
104
+
105
+ def should_merge(_text: str):
106
+ if lang == "zh":
107
+ return len(_text) < merge_len
108
+ else:
109
+ return len(tokenize(_text)) < merge_len
110
+
111
+ if lang == "zh":
112
+ pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"]
113
+ else:
114
+ pounc = [".", "?", "!", ";", ":"]
115
+ if comma_split:
116
+ pounc.extend([",", ","])
117
+
118
+ if text[-1] not in pounc:
119
+ if lang == "zh":
120
+ text += "。"
121
+ else:
122
+ text += "."
123
+
124
+ st = 0
125
+ utts = []
126
+ for i, c in enumerate(text):
127
+ if c in pounc:
128
+ if len(text[st:i]) > 0:
129
+ utts.append(text[st:i] + c)
130
+ if i + 1 < len(text) and text[i + 1] in ['"', "”"]:
131
+ tmp = utts.pop(-1)
132
+ utts.append(tmp + text[i + 1])
133
+ st = i + 2
134
+ else:
135
+ st = i + 1
136
+
137
+ final_utts = []
138
+ cur_utt = ""
139
+ for utt in utts:
140
+ if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
141
+ final_utts.append(cur_utt)
142
+ cur_utt = ""
143
+ cur_utt = cur_utt + utt
144
+ if len(cur_utt) > 0:
145
+ if should_merge(cur_utt) and len(final_utts) != 0:
146
+ final_utts[-1] = final_utts[-1] + cur_utt
147
+ else:
148
+ final_utts.append(cur_utt)
149
+
150
+ return final_utts
151
+
152
+
153
def is_only_punctuation(text: str):
    """True when *text* is empty or consists solely of Unicode punctuation/symbols."""
    # \p{P} = punctuation, \p{S} = symbols; zero-or-more so "" also matches.
    return regex.fullmatch(r"^[\p{P}\p{S}]*$", text) is not None
157
+
158
+
159
# spell Arabic numerals
def spell_out_number(text: str, inflect_parser):
    """Replace each maximal digit run in *text* with its spelled-out words."""
    pieces = []
    run_start = None  # index where the current digit run began, or None
    for idx, ch in enumerate(text):
        if ch.isdigit():
            if run_start is None:
                run_start = idx
        else:
            if run_start is not None:
                pieces.append(inflect_parser.number_to_words(text[run_start:idx]))
                run_start = None
            pieces.append(ch)
    # Flush a digit run that reaches the end of the string.
    if run_start is not None:
        pieces.append(inflect_parser.number_to_words(text[run_start:]))
    return "".join(pieces)
177
+
178
+
179
def remove_emoji(text: str):
    """Strip emoji codepoints together with their joiners and modifiers.

    Removes: supplementary-plane emoji, zero-width joiners (U+200D),
    variation selectors (U+FE0F/U+FE0E), and skin-tone modifiers
    (U+1F3FB..U+1F3FF).
    """
    emoji_re = re.compile(
        r"["
        r"\U00010000-\U0010FFFF"  # Standard emoji range
        r"\u200D"  # Zero-width joiner
        r"\uFE0F\uFE0E"  # Variation selectors
        r"\U0001F3FB-\U0001F3FF"  # Skin tone modifiers
        r"]+",
        flags=re.UNICODE,
    )
    return emoji_re.sub(r"", text)
195
+
196
+
197
def remove_repeated_punctuations(text, punctuations):
    """Collapse consecutive repeats of any listed punctuation char into one."""
    if not punctuations:
        return text
    # Character class of the given punctuation marks, escaped for regex use.
    char_class = f"[{re.escape(''.join(punctuations))}]"
    return re.sub(rf"({char_class})\1+", r"\1", text)
202
+
203
+
204
+ def full_to_half_width(text: str) -> str:
205
+ """Convert full-width punctuation to half-width in a given string."""
206
+ full_width = "!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"
207
+ half_width = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
208
+ trans_table = str.maketrans(full_width, half_width)
209
+ return text.translate(trans_table)
210
+
211
+
212
def split_interleaved_delayed_audios(
    audio_data: Union[list[list[int]], torch.Tensor],
    audio_tokenizer: HiggsAudioTokenizer,
    audio_stream_eos_id: int,
) -> list[tuple[list[list[int]], torch.Tensor]]:
    """Split a delay-pattern token stream into per-utterance groups.

    A frame whose token equals `audio_stream_eos_id` in every codebook is
    treated as the separator between consecutive audio segments.

    NOTE(review): the tensor branch excludes the separator frame from each
    group, while the list branch keeps it as the group's last row — confirm
    whether this asymmetry is intentional for downstream consumers.
    """
    separator = [audio_stream_eos_id] * audio_tokenizer.num_codebooks

    # Convert separator to numpy array if audio_data is numpy array
    if isinstance(audio_data, torch.Tensor):
        # (codebooks, frames) -> (frames, codebooks) so each row is one frame.
        audio_data = audio_data.transpose(1, 0)
        separator = torch.tensor(separator)
        # Find the indices where the rows equal the separator
        split_indices = torch.where(torch.all(audio_data == separator, dim=1))[0]
        start = 0
        groups = []
        for idx in split_indices:
            # Transpose back to (codebooks, frames) for each group.
            groups.append(audio_data[start:idx].transpose(1, 0))
            start = idx + 1
        if start < len(audio_data):
            groups.append(audio_data[start:].transpose(1, 0))
    else:
        groups = []
        current = []
        for row in audio_data:
            current.append(row)

            if row == separator:
                groups.append(current)
                current = []

        # Don't forget the last group if there's no trailing separator
        if current:
            groups.append(current)

    return groups
cmd.sh ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 需要clone的声音放在examples/voice_prompts下
2
+
3
+
4
+ 中国的二零二七年技术路线图强调脑机接口(BCI)是与美国技术竞争的关键领域。工业和信息化部和其他六个中国机构发布了脑机接口的政策蓝图,包括芯片,电极和集成产品的突破目标。值得注意的研究包括用于改善中风患者抓握能力的微创血管介入电极,帮助四肢瘫痪患者的植入式处理器,以及具有快速稳定神经解码的超柔性硬币大小电极,用于人机交互。
5
+
6
+ MYHHHHH
7
+ ./conda_env/bin/python3 examples/generation.py \
8
+ --transcript "$(cat input)" \
9
+ --ref_audio xiaohei \
10
+ --temperature 0.3 \
11
+ --out_path generation.wav \
12
+ --model_path higgs-audio-v2-generation-3B-base \
13
+ --audio_tokenizer higgs-audio-v2-tokenizer
14
+
15
+
16
+ ./conda_env/bin/python3 examples/generation.py \
17
+ --transcript examples/transcript/multi_speaker/en_argument.txt \
18
+ --ref_audio belinda,MYHHHHH \
19
+ --seed 12345 \
20
+ --out_path generation.wav
21
+
22
+
23
+
24
+
25
+ ./conda_env/bin/python3 examples/interactive_generation.py \
26
+ --model_path ./higgs-audio-v2-generation-3B-base \
27
+ --audio_tokenizer ./higgs-audio-v2-tokenizer \
28
+ --temperature 0.8 \
29
+ --ref_audio MYHHHHH \
30
+ --output_dir ./my_outputs
31
+
32
+
33
+
34
+
35
+ 我们还在 “示例” 部分提供了一系列示例。下面我们将重点介绍几个示例,以帮助您使用 Higgs Audio v2。
examples/README.md ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Examples
2
+
3
+ > [!NOTE]
4
+ > If you do not like the audio you get, you can generate multiple times with different seeds. In addition, you may need to apply text normalization to get the best performance, e.g. converting 70 °F to "seventy degrees Fahrenheit", and converting "1 2 3 4" to "one two three four". The model also performs better in longer sentences. Right now, the model has not been post-trained, we will release the post-trained model in the future.
5
+
6
+ ## Single-speaker Audio Generation
7
+
8
+ ### Voice clone
9
+
10
+ ```bash
11
+ python3 generation.py \
12
+ --transcript transcript/single_speaker/en_dl.txt \
13
+ --ref_audio broom_salesman \
14
+ --seed 12345 \
15
+ --out_path generation.wav
16
+ ```
17
+
18
+ The model will read the transcript with the same voice as in the [reference audio](./voice_prompts/broom_salesman.wav). The technique is also called shallow voice clone.
19
+
20
+ We have some example audio prompts stored in [voice_prompts](./voice_prompts/). Feel free to pick one in the folder and try out the model. Here's another example that uses the voice of `belinda`. You can also add new own favorite voice in the folder and clone the voice.
21
+
22
+ ```bash
23
+ python3 generation.py \
24
+ --transcript transcript/single_speaker/en_dl.txt \
25
+ --ref_audio belinda \
26
+ --seed 12345 \
27
+ --out_path generation.wav
28
+ ```
29
+
30
+ #### (Experimental) Cross-lingual voice clone
31
+
32
+ This example demonstrates voice cloning with a Chinese prompt, where the synthesized speech is in English.
33
+
34
+ ```bash
35
+ python3 generation.py \
36
+ --transcript transcript/single_speaker/en_dl.txt \
37
+ --scene_prompt empty \
38
+ --ref_audio zh_man_sichuan \
39
+ --temperature 0.3 \
40
+ --seed 12345 \
41
+ --out_path generation.wav
42
+ ```
43
+
44
+ ### Smart voice
45
+
46
+ The model supports reading the transcript with a random voice.
47
+
48
+ ```bash
49
+ python3 generation.py \
50
+ --transcript transcript/single_speaker/en_dl.txt \
51
+ --seed 12345 \
52
+ --out_path generation.wav
53
+ ```
54
+
55
+ It also works for other languages like Chinese.
56
+
57
+ ```bash
58
+ python3 generation.py \
59
+ --transcript transcript/single_speaker/zh_ai.txt \
60
+ --seed 12345 \
61
+ --out_path generation.wav
62
+ ```
63
+
64
+ ### Describe speaker characteristics with text
65
+
66
+ The model allows you to describe the speaker via text. See [voice_prompts/profile.yaml](voice_prompts/profile.yaml) for examples. You can run the following two examples that try to specify male / female British accent for the speakers. Also, try to remove the `--seed 12345` flag to see how the model is generating different voices.
67
+
68
+ ```bash
69
+ # Male British Accent
70
+ python3 generation.py \
71
+ --transcript transcript/single_speaker/en_dl.txt \
72
+ --ref_audio profile:male_en_british \
73
+ --seed 12345 \
74
+ --out_path generation.wav
75
+
76
+ # Female British Accent
77
+ python3 generation.py \
78
+ --transcript transcript/single_speaker/en_dl.txt \
79
+ --ref_audio profile:female_en_british \
80
+ --seed 12345 \
81
+ --out_path generation.wav
82
+ ```
83
+
84
+ ### Chunking for long-form audio generation
85
+
86
+ To generate long-form audios, you can chunk the text and render each chunk one by one while putting the previous generated audio and the reference audio in the prompt. Here's an example that generates the first five paragraphs of Higgs Audio v1 release blog. See [text](./transcript/single_speaker/en_higgs_audio_blog.md).
87
+
88
+ ```bash
89
+ python3 generation.py \
90
+ --scene_prompt scene_prompts/reading_blog.txt \
91
+ --transcript transcript/single_speaker/en_higgs_audio_blog.md \
92
+ --ref_audio en_man \
93
+ --chunk_method word \
94
+ --temperature 0.3 \
95
+ --generation_chunk_buffer_size 2 \
96
+ --seed 12345 \
97
+ --out_path generation.wav
98
+ ```
99
+
100
+ ### Experimental and Emergent Capabilities
101
+
102
+ As shown in our demo, the pretrained model is demonstrating emergent features. We prepared some samples to help you explore these experimental prompts. We will enhance the stability of these experimental prompts in the future version of HiggsAudio.
103
+
104
+ #### (Experimental) Hum a tune with the cloned voice
105
+ The model is able to hum a tune with the cloned voice.
106
+
107
+ ```bash
108
+ python3 generation.py \
109
+ --transcript transcript/single_speaker/experimental/en_humming.txt \
110
+ --ref_audio en_woman \
111
+ --ras_win_len 0 \
112
+ --seed 12345 \
113
+ --out_path generation.wav
114
+ ```
115
+
116
+ #### (Experimental) Read the sentence while adding background music (BGM)
117
+
118
+ ```bash
119
+ python3 generation.py \
120
+ --transcript transcript/single_speaker/experimental/en_bgm.txt \
121
+ --ref_audio en_woman \
122
+ --ras_win_len 0 \
123
+ --ref_audio_in_system_message \
124
+ --seed 123456 \
125
+ --out_path generation.wav
126
+ ```
127
+
128
+ ## Multi-speaker Audio Generation
129
+
130
+
131
+ ### Smart voice
132
+
133
+ To get started to explore HiggsAudio's capability in generating multi-speaker audios. Let's try to generate a multi-speaker dialog from transcript in the zero-shot fashion. See the transcript in [transcript/multi_speaker/en_argument.txt](transcript/multi_speaker/en_argument.txt). The speakers are annotated with `[SPEAKER0]` and `[SPEAKER1]`.
134
+
135
+ ```bash
136
+ python3 generation.py \
137
+ --transcript transcript/multi_speaker/en_argument.txt \
138
+ --seed 12345 \
139
+ --out_path generation.wav
140
+ ```
141
+
142
+ ### Multi-voice clone
143
+ You can also try to clone the voices of multiple people simultaneously and generate audio for the transcript. Here's an example that puts reference audios in the system message and prompts the model iteratively. You can hear "Belinda" arguing with "Broom Salesman".
144
+
145
+ ```bash
146
+ python3 generation.py \
147
+ --transcript transcript/multi_speaker/en_argument.txt \
148
+ --ref_audio belinda,broom_salesman \
149
+ --ref_audio_in_system_message \
150
+ --chunk_method speaker \
151
+ --seed 12345 \
152
+ --out_path generation.wav
153
+ ```
154
+
155
+ You can also let "Broom Salesman" talking to "Belinda", who recently trained HiggsAudio.
156
+
157
+ ```bash
158
+ python3 generation.py \
159
+ --transcript transcript/multi_speaker/en_higgs.txt \
160
+ --ref_audio broom_salesman,belinda \
161
+ --ref_audio_in_system_message \
162
+ --chunk_method speaker \
163
+ --chunk_max_num_turns 2 \
164
+ --seed 12345 \
165
+ --out_path generation.wav
166
+ ```
examples/generation.py ADDED
@@ -0,0 +1,768 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Example script for generating audio using HiggsAudio."""
2
+
3
+ import click
4
+ import soundfile as sf
5
+ import langid
6
+ import jieba
7
+ import os
8
+ import re
9
+ import copy
10
+ import torchaudio
11
+ import tqdm
12
+ import yaml
13
+
14
+ from loguru import logger
15
+ from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse
16
+ from boson_multimodal.data_types import Message, ChatMLSample, AudioContent, TextContent
17
+
18
+ from boson_multimodal.model.higgs_audio import HiggsAudioConfig, HiggsAudioModel
19
+ from boson_multimodal.data_collator.higgs_audio_collator import HiggsAudioSampleCollator
20
+ from boson_multimodal.audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer
21
+ from boson_multimodal.dataset.chatml_dataset import (
22
+ ChatMLDatasetSample,
23
+ prepare_chatml_sample,
24
+ )
25
+ from boson_multimodal.model.higgs_audio.utils import revert_delay_pattern
26
+ from typing import List
27
+ from transformers import AutoConfig, AutoTokenizer
28
+ from transformers.cache_utils import StaticCache
29
+ from typing import Optional
30
+ from dataclasses import asdict
31
+ import torch
32
+
33
+ CURR_DIR = os.path.dirname(os.path.abspath(__file__))
34
+
35
+
36
+ AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>"
37
+
38
+
39
+ MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE = """You are an AI assistant designed to convert text into speech.
40
+ If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.
41
+ If no speaker tag is present, select a suitable voice on your own."""
42
+
43
+
44
def normalize_chinese_punctuation(text):
    """
    Convert Chinese (full-width) punctuation marks to English (half-width) equivalents.

    Every key is a single character, so the whole mapping is applied in one
    C-level pass with ``str.translate`` instead of 25 chained ``str.replace``
    calls. The replacement values are ASCII, so no mapping output can ever be
    re-matched by another key — the result is identical to sequential replaces.
    """
    # Mapping of Chinese punctuation to English punctuation
    chinese_to_english_punct = {
        ",": ", ",  # comma
        "。": ".",  # period
        ":": ":",  # colon
        ";": ";",  # semicolon
        "?": "?",  # question mark
        "!": "!",  # exclamation mark
        "(": "(",  # left parenthesis
        ")": ")",  # right parenthesis
        "【": "[",  # left square bracket
        "】": "]",  # right square bracket
        "《": "<",  # left angle quote
        "》": ">",  # right angle quote
        "“": '"',  # left double quotation
        "”": '"',  # right double quotation
        "‘": "'",  # left single quotation
        "’": "'",  # right single quotation
        "、": ",",  # enumeration comma
        "—": "-",  # em dash
        "…": "...",  # ellipsis
        "·": ".",  # middle dot
        "「": '"',  # left corner bracket
        "」": '"',  # right corner bracket
        "『": '"',  # left double corner bracket
        "』": '"',  # right double corner bracket
    }

    # str.maketrans accepts a {char: replacement-string} dict, including
    # multi-character replacements such as "…" -> "...".
    return text.translate(str.maketrans(chinese_to_english_punct))
81
+
82
+
83
def prepare_chunk_text(
    text, chunk_method: Optional[str] = None, chunk_max_word_num: int = 100, chunk_max_num_turns: int = 1
):
    """Chunk the text into smaller pieces. We will later feed the chunks one by one to the model.

    Parameters
    ----------
    text : str
        The text to be chunked.
    chunk_method : str, optional
        The method to use for chunking. Options are "speaker", "word", or None. By default, we won't use any
        chunking and will feed the whole text to the model.
    chunk_max_word_num : int, optional
        The maximum number of words for each chunk when "word" chunking method is used, by default 100.
    chunk_max_num_turns : int, optional
        The maximum number of speaker turns merged into each chunk when the "speaker" chunking method is used,
        by default 1.

    Returns
    -------
    List[str]
        The list of text chunks.

    """
    if chunk_method is None:
        # No chunking: the whole text is a single chunk.
        return [text]
    elif chunk_method == "speaker":
        # One chunk per speaker turn; a new turn starts on a line beginning
        # with a speaker tag such as "[SPEAKER0]" or "<|speaker_id_start|>".
        lines = text.split("\n")
        speaker_chunks = []
        speaker_utterance = ""
        for line in lines:
            line = line.strip()
            if line.startswith("[SPEAKER") or line.startswith("<|speaker_id_start|>"):
                if speaker_utterance:
                    speaker_chunks.append(speaker_utterance.strip())
                speaker_utterance = line
            else:
                # Continuation line: belongs to the current speaker turn.
                if speaker_utterance:
                    speaker_utterance += "\n" + line
                else:
                    speaker_utterance = line
        if speaker_utterance:
            speaker_chunks.append(speaker_utterance.strip())
        if chunk_max_num_turns > 1:
            # Merge consecutive turns so each chunk carries up to
            # chunk_max_num_turns turns of dialog context.
            merged_chunks = []
            for i in range(0, len(speaker_chunks), chunk_max_num_turns):
                merged_chunk = "\n".join(speaker_chunks[i : i + chunk_max_num_turns])
                merged_chunks.append(merged_chunk)
            return merged_chunks
        return speaker_chunks
    elif chunk_method == "word":
        # TODO: We may improve the logic in the future
        # For long-form generation, we will first divide the text into multiple paragraphs by splitting with "\n\n"
        # After that, we will chunk each paragraph based on word count
        language = langid.classify(text)[0]
        paragraphs = text.split("\n\n")
        chunks = []
        for idx, paragraph in enumerate(paragraphs):
            if language == "zh":
                # For Chinese, segment with jieba and chunk by segmented-word count.
                words = list(jieba.cut(paragraph, cut_all=False))
                for i in range(0, len(words), chunk_max_word_num):
                    chunk = "".join(words[i : i + chunk_max_word_num])
                    chunks.append(chunk)
            else:
                # For other languages, split on spaces and chunk by word count.
                words = paragraph.split(" ")
                for i in range(0, len(words), chunk_max_word_num):
                    chunk = " ".join(words[i : i + chunk_max_word_num])
                    chunks.append(chunk)
            # Restore the paragraph break on the last chunk of each paragraph.
            chunks[-1] += "\n\n"
        return chunks
    else:
        raise ValueError(f"Unknown chunk method: {chunk_method}")
158
+
159
+
160
def _build_system_message_with_audio_prompt(system_message):
    """Build a system ``Message`` from text containing ``AUDIO_PLACEHOLDER_TOKEN`` markers.

    Each placeholder becomes an (empty-url) ``AudioContent`` slot, with the
    surrounding text kept as ``TextContent`` pieces in their original order.
    """
    pieces = system_message.split(AUDIO_PLACEHOLDER_TOKEN)
    contents = []
    # Every piece except the last was immediately followed by a placeholder
    # in the original string, so each gets an audio slot after it.
    for piece in pieces[:-1]:
        contents.append(TextContent(piece))
        contents.append(AudioContent(audio_url=""))
    # Trailing text is only kept when non-empty.
    if pieces[-1]:
        contents.append(TextContent(pieces[-1]))
    return Message(
        role="system",
        content=contents,
    )
176
+
177
+
178
class HiggsAudioModelClient:
    """Thin client around ``HiggsAudioModel`` for chunked text-to-audio generation.

    Owns the model, text tokenizer, audio tokenizer, and collator, and (optionally)
    a set of static KV caches of different lengths for faster generation on CUDA.
    """

    def __init__(
        self,
        model_path,
        audio_tokenizer,
        device=None,
        device_id=None,
        max_new_tokens=2048,
        kv_cache_lengths: List[int] = [1024, 4096, 8192],  # Multiple KV cache sizes,
        use_static_kv_cache=False,
    ):
        # Use explicit device if provided, otherwise try CUDA/MPS/CPU
        if device_id is not None:
            device = f"cuda:{device_id}"
            self._device = device
        else:
            if device is not None:
                self._device = device
            else:  # We get to choose the device
                # Prefer CUDA over MPS (Apple Silicon GPU) over CPU if available
                if torch.cuda.is_available():
                    self._device = "cuda:0"
                elif torch.backends.mps.is_available():
                    self._device = "mps"
                else:
                    self._device = "cpu"

        logger.info(f"Using device: {self._device}")
        if isinstance(audio_tokenizer, str):
            # For MPS, use CPU due to embedding operation limitations in quantization layers
            audio_tokenizer_device = "cpu" if self._device == "mps" else self._device
            self._audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer, device=audio_tokenizer_device)
        else:
            # Caller passed an already-loaded tokenizer object; use it as-is.
            self._audio_tokenizer = audio_tokenizer

        self._model = HiggsAudioModel.from_pretrained(
            model_path,
            device_map=self._device,
            torch_dtype=torch.bfloat16,
        )
        self._model.eval()
        self._kv_cache_lengths = kv_cache_lengths
        self._use_static_kv_cache = use_static_kv_cache

        self._tokenizer = AutoTokenizer.from_pretrained(model_path)
        self._config = AutoConfig.from_pretrained(model_path)
        self._max_new_tokens = max_new_tokens
        # Collator settings mirror the model config so generation-time batches
        # match the layout the model was trained with.
        self._collator = HiggsAudioSampleCollator(
            whisper_processor=None,
            audio_in_token_id=self._config.audio_in_token_idx,
            audio_out_token_id=self._config.audio_out_token_idx,
            audio_stream_bos_id=self._config.audio_stream_bos_id,
            audio_stream_eos_id=self._config.audio_stream_eos_id,
            encode_whisper_embed=self._config.encode_whisper_embed,
            pad_token_id=self._config.pad_token_id,
            return_audio_in_tokens=self._config.encode_audio_in_tokens,
            use_delay_pattern=self._config.use_delay_pattern,
            round_to=1,
            audio_num_codebooks=self._config.audio_num_codebooks,
        )
        self.kv_caches = None
        if use_static_kv_cache:
            self._init_static_kv_cache()

    def _init_static_kv_cache(self):
        """Pre-allocate one ``StaticCache`` per configured length (and capture CUDA graphs on CUDA)."""
        cache_config = copy.deepcopy(self._model.config.text_config)
        cache_config.num_hidden_layers = self._model.config.text_config.num_hidden_layers
        # Dual-FFN audio layers carry their own KV entries, so the cache needs
        # extra layers beyond the text stack.
        if self._model.config.audio_dual_ffn_layers:
            cache_config.num_hidden_layers += len(self._model.config.audio_dual_ffn_layers)
        # A list of KV caches for different lengths
        self.kv_caches = {
            length: StaticCache(
                config=cache_config,
                max_batch_size=1,
                max_cache_len=length,
                device=self._model.device,
                dtype=self._model.dtype,
            )
            for length in sorted(self._kv_cache_lengths)
        }
        # Capture CUDA graphs for each KV cache length
        if "cuda" in self._device:
            logger.info(f"Capturing CUDA graphs for each KV cache length")
            self._model.capture_model(self.kv_caches.values())

    def _prepare_kv_caches(self):
        """Reset every static KV cache before a new generation step."""
        for kv_cache in self.kv_caches.values():
            kv_cache.reset()

    @torch.inference_mode()
    def generate(
        self,
        messages,
        audio_ids,
        chunked_text,
        generation_chunk_buffer_size,
        temperature=1.0,
        top_k=50,
        top_p=0.95,
        ras_win_len=7,
        ras_win_max_num_repeat=2,
        seed=123,
        *args,
        **kwargs,
    ):
        """Generate audio chunk by chunk, feeding previous chunks back as context.

        Returns a tuple ``(waveform, sample_rate, decoded_text)`` where the
        waveform is the concatenation of all decoded chunks.

        NOTE(review): ``outputs`` is referenced after the loop, so an empty
        ``chunked_text`` would raise NameError — confirm callers always pass
        at least one chunk.
        """
        # Non-positive window disables RAS sampling entirely.
        if ras_win_len is not None and ras_win_len <= 0:
            ras_win_len = None
        sr = 24000  # output sample rate of the audio tokenizer's decoder
        audio_out_ids_l = []
        generated_audio_ids = []
        generation_messages = []
        for idx, chunk_text in tqdm.tqdm(
            enumerate(chunked_text), desc="Generating audio chunks", total=len(chunked_text)
        ):
            generation_messages.append(
                Message(
                    role="user",
                    content=chunk_text,
                )
            )
            chatml_sample = ChatMLSample(messages=messages + generation_messages)
            input_tokens, _, _, _ = prepare_chatml_sample(chatml_sample, self._tokenizer)
            # Append the assistant header so the model continues as the assistant.
            postfix = self._tokenizer.encode(
                "<|start_header_id|>assistant<|end_header_id|>\n\n", add_special_tokens=False
            )
            input_tokens.extend(postfix)

            logger.info(f"========= Chunk {idx} Input =========")
            logger.info(self._tokenizer.decode(input_tokens))
            # Context = reference audio prompts + audio generated so far.
            context_audio_ids = audio_ids + generated_audio_ids

            curr_sample = ChatMLDatasetSample(
                input_ids=torch.LongTensor(input_tokens),
                label_ids=None,
                audio_ids_concat=torch.concat([ele.cpu() for ele in context_audio_ids], dim=1)
                if context_audio_ids
                else None,
                audio_ids_start=torch.cumsum(
                    torch.tensor([0] + [ele.shape[1] for ele in context_audio_ids], dtype=torch.long), dim=0
                )
                if context_audio_ids
                else None,
                audio_waveforms_concat=None,
                audio_waveforms_start=None,
                audio_sample_rate=None,
                audio_speaker_indices=None,
            )

            batch_data = self._collator([curr_sample])
            batch = asdict(batch_data)
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.contiguous().to(self._device)

            if self._use_static_kv_cache:
                self._prepare_kv_caches()

            # Generate audio
            outputs = self._model.generate(
                **batch,
                max_new_tokens=self._max_new_tokens,
                use_cache=True,
                do_sample=True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                past_key_values_buckets=self.kv_caches,
                ras_win_len=ras_win_len,
                ras_win_max_num_repeat=ras_win_max_num_repeat,
                stop_strings=["<|end_of_text|>", "<|eot_id|>"],
                tokenizer=self._tokenizer,
                seed=seed,
            )

            step_audio_out_ids_l = []
            for ele in outputs[1]:
                audio_out_ids = ele
                if self._config.use_delay_pattern:
                    audio_out_ids = revert_delay_pattern(audio_out_ids)
                # Clip to the valid codebook range and strip the bos/eos frames.
                step_audio_out_ids_l.append(audio_out_ids.clip(0, self._audio_tokenizer.codebook_size - 1)[:, 1:-1])
            audio_out_ids = torch.concat(step_audio_out_ids_l, dim=1)
            audio_out_ids_l.append(audio_out_ids)
            generated_audio_ids.append(audio_out_ids)

            generation_messages.append(
                Message(
                    role="assistant",
                    content=AudioContent(audio_url=""),
                )
            )
            # Keep only the most recent chunks as context to bound prompt length;
            # each kept audio chunk corresponds to a user+assistant message pair.
            if generation_chunk_buffer_size is not None and len(generated_audio_ids) > generation_chunk_buffer_size:
                generated_audio_ids = generated_audio_ids[-generation_chunk_buffer_size:]
                generation_messages = generation_messages[(-2 * generation_chunk_buffer_size) :]

        logger.info(f"========= Final Text output =========")
        logger.info(self._tokenizer.decode(outputs[0][0]))
        concat_audio_out_ids = torch.concat(audio_out_ids_l, dim=1)

        # Fix MPS compatibility: detach and move to CPU before decoding
        if concat_audio_out_ids.device.type in ["mps", "cuda"]:
            concat_audio_out_ids_cpu = concat_audio_out_ids.detach().cpu()
        else:
            concat_audio_out_ids_cpu = concat_audio_out_ids

        concat_wv = self._audio_tokenizer.decode(concat_audio_out_ids_cpu.unsqueeze(0))[0, 0]
        text_result = self._tokenizer.decode(outputs[0][0])
        return concat_wv, sr, text_result
385
+
386
+
387
def prepare_generation_context(scene_prompt, ref_audio, ref_audio_in_system_message, audio_tokenizer, speaker_tags):
    """Prepare the context for generation.

    The context contains the system message, user message, assistant message, and audio prompt if any.

    Returns
    -------
    messages : list[Message]
        Chat history (system message first when one was built) to prepend to
        the generation prompt.
    audio_ids : list
        Encoded audio tokens for each voice-prompt wav, in speaker order.
    """
    system_message = None
    messages = []
    audio_ids = []
    if ref_audio is not None:
        # Voice clone mode: ref_audio is a comma-separated list of voice names
        # and/or "profile:<name>" text descriptions.
        num_speakers = len(ref_audio.split(","))
        speaker_info_l = ref_audio.split(",")
        voice_profile = None
        # Text-only profiles can only be conveyed via the system message.
        if any([speaker_info.startswith("profile:") for speaker_info in ref_audio.split(",")]):
            ref_audio_in_system_message = True
        if ref_audio_in_system_message:
            # Describe every speaker inside the scene description; audio-based
            # speakers get a placeholder slot filled with their prompt audio.
            speaker_desc = []
            for spk_id, character_name in enumerate(speaker_info_l):
                if character_name.startswith("profile:"):
                    # Lazily load profile.yaml the first time a profile is referenced.
                    if voice_profile is None:
                        with open(f"{CURR_DIR}/voice_prompts/profile.yaml", "r", encoding="utf-8") as f:
                            voice_profile = yaml.safe_load(f)
                    character_desc = voice_profile["profiles"][character_name[len("profile:") :].strip()]
                    speaker_desc.append(f"SPEAKER{spk_id}: {character_desc}")
                else:
                    speaker_desc.append(f"SPEAKER{spk_id}: {AUDIO_PLACEHOLDER_TOKEN}")
            if scene_prompt:
                system_message = (
                    "Generate audio following instruction."
                    "\n\n"
                    f"<|scene_desc_start|>\n{scene_prompt}\n\n" + "\n".join(speaker_desc) + "\n<|scene_desc_end|>"
                )
            else:
                system_message = (
                    "Generate audio following instruction.\n\n"
                    + f"<|scene_desc_start|>\n"
                    + "\n".join(speaker_desc)
                    + "\n<|scene_desc_end|>"
                )
            # Convert placeholder tokens into AudioContent slots.
            system_message = _build_system_message_with_audio_prompt(system_message)
        else:
            if scene_prompt:
                system_message = Message(
                    role="system",
                    content=f"Generate audio following instruction.\n\n<|scene_desc_start|>\n{scene_prompt}\n<|scene_desc_end|>",
                )
        voice_profile = None
        # Encode each audio-based voice prompt; profiles have no wav to encode.
        for spk_id, character_name in enumerate(ref_audio.split(",")):
            if not character_name.startswith("profile:"):
                prompt_audio_path = os.path.join(f"{CURR_DIR}/voice_prompts", f"{character_name}.wav")
                prompt_text_path = os.path.join(f"{CURR_DIR}/voice_prompts", f"{character_name}.txt")
                assert os.path.exists(prompt_audio_path), (
                    f"Voice prompt audio file {prompt_audio_path} does not exist."
                )
                assert os.path.exists(prompt_text_path), f"Voice prompt text file {prompt_text_path} does not exist."
                with open(prompt_text_path, "r", encoding="utf-8") as f:
                    prompt_text = f.read().strip()
                audio_tokens = audio_tokenizer.encode(prompt_audio_path)
                audio_ids.append(audio_tokens)

                if not ref_audio_in_system_message:
                    # In-context clone: present each voice prompt as a
                    # user transcript / assistant audio exchange.
                    messages.append(
                        Message(
                            role="user",
                            content=f"[SPEAKER{spk_id}] {prompt_text}" if num_speakers > 1 else prompt_text,
                        )
                    )
                    messages.append(
                        Message(
                            role="assistant",
                            content=AudioContent(
                                audio_url=prompt_audio_path,
                            ),
                        )
                    )
    else:
        # No reference audio: let the model pick voices ("smart voice" mode).
        if len(speaker_tags) > 1:
            # By default, we just alternate between male and female voices
            speaker_desc_l = []

            for idx, tag in enumerate(speaker_tags):
                if idx % 2 == 0:
                    speaker_desc = f"feminine"
                else:
                    speaker_desc = f"masculine"
                speaker_desc_l.append(f"{tag}: {speaker_desc}")

            speaker_desc = "\n".join(speaker_desc_l)
            scene_desc_l = []
            if scene_prompt:
                scene_desc_l.append(scene_prompt)
            scene_desc_l.append(speaker_desc)
            scene_desc = "\n\n".join(scene_desc_l)

            system_message = Message(
                role="system",
                content=f"{MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE}\n\n<|scene_desc_start|>\n{scene_desc}\n<|scene_desc_end|>",
            )
        else:
            # Single unnamed speaker: minimal instruction plus optional scene.
            system_message_l = ["Generate audio following instruction."]
            if scene_prompt:
                system_message_l.append(f"<|scene_desc_start|>\n{scene_prompt}\n<|scene_desc_end|>")
            system_message = Message(
                role="system",
                content="\n\n".join(system_message_l),
            )
    if system_message:
        messages.insert(0, system_message)
    return messages, audio_ids
495
+
496
+
497
+ @click.command()
498
+ @click.option(
499
+ "--model_path",
500
+ type=str,
501
+ default="bosonai/higgs-audio-v2-generation-3B-base",
502
+ help="Output wav file path.",
503
+ )
504
+ @click.option(
505
+ "--audio_tokenizer",
506
+ type=str,
507
+ default="bosonai/higgs-audio-v2-tokenizer",
508
+ help="Audio tokenizer path, if not set, use the default one.",
509
+ )
510
+ @click.option(
511
+ "--max_new_tokens",
512
+ type=int,
513
+ default=2048,
514
+ help="The maximum number of new tokens to generate.",
515
+ )
516
+ @click.option(
517
+ "--transcript",
518
+ type=str,
519
+ default="transcript/single_speaker/en_dl.txt",
520
+ help="The prompt to use for generation. If not set, we will use a default prompt.",
521
+ )
522
+ @click.option(
523
+ "--scene_prompt",
524
+ type=str,
525
+ default=f"{CURR_DIR}/scene_prompts/quiet_indoor.txt",
526
+ help="The scene description prompt to use for generation. If not set, or set to `empty`, we will leave it to empty.",
527
+ )
528
+ @click.option(
529
+ "--temperature",
530
+ type=float,
531
+ default=1.0,
532
+ help="The value used to module the next token probabilities.",
533
+ )
534
+ @click.option(
535
+ "--top_k",
536
+ type=int,
537
+ default=50,
538
+ help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
539
+ )
540
+ @click.option(
541
+ "--top_p",
542
+ type=float,
543
+ default=0.95,
544
+ help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
545
+ )
546
+ @click.option(
547
+ "--ras_win_len",
548
+ type=int,
549
+ default=7,
550
+ help="The window length for RAS sampling. If set to 0 or a negative value, we won't use RAS sampling.",
551
+ )
552
+ @click.option(
553
+ "--ras_win_max_num_repeat",
554
+ type=int,
555
+ default=2,
556
+ help="The maximum number of times to repeat the RAS window. Only used when --ras_win_len is set.",
557
+ )
558
+ @click.option(
559
+ "--ref_audio",
560
+ type=str,
561
+ default=None,
562
+ help="The voice prompt to use for generation. If not set, we will let the model randomly pick a voice. "
563
+ "For multi-speaker generation, you can specify the prompts as `belinda,chadwick` and we will use the voice of belinda as SPEAKER0 and the voice of chadwick as SPEAKER1.",
564
+ )
565
+ @click.option(
566
+ "--ref_audio_in_system_message",
567
+ is_flag=True,
568
+ default=False,
569
+ help="Whether to include the voice prompt description in the system message.",
570
+ show_default=True,
571
+ )
572
+ @click.option(
573
+ "--chunk_method",
574
+ default=None,
575
+ type=click.Choice([None, "speaker", "word"]),
576
+ help="The method to use for chunking the prompt text. Options are 'speaker', 'word', or None. By default, we won't use any chunking and will feed the whole text to the model.",
577
+ )
578
+ @click.option(
579
+ "--chunk_max_word_num",
580
+ default=200,
581
+ type=int,
582
+ help="The maximum number of words for each chunk when 'word' chunking method is used. Only used when --chunk_method is set to 'word'.",
583
+ )
584
+ @click.option(
585
+ "--chunk_max_num_turns",
586
+ default=1,
587
+ type=int,
588
+ help="The maximum number of turns for each chunk when 'speaker' chunking method is used. Only used when --chunk_method is set to 'speaker'.",
589
+ )
590
+ @click.option(
591
+ "--generation_chunk_buffer_size",
592
+ default=None,
593
+ type=int,
594
+ help="The maximal number of chunks to keep in the buffer. We will always keep the reference audios, and keep `max_chunk_buffer` chunks of generated audio.",
595
+ )
596
+ @click.option(
597
+ "--seed",
598
+ default=None,
599
+ type=int,
600
+ help="Random seed for generation.",
601
+ )
602
+ @click.option(
603
+ "--device_id",
604
+ type=int,
605
+ default=None,
606
+ help="The device to run the model on.",
607
+ )
608
+ @click.option(
609
+ "--out_path",
610
+ type=str,
611
+ default="generation.wav",
612
+ )
613
+ @click.option(
614
+ "--use_static_kv_cache",
615
+ type=int,
616
+ default=1,
617
+ help="Whether to use static KV cache for faster generation. Only works when using GPU.",
618
+ )
619
+ @click.option(
620
+ "--device",
621
+ type=click.Choice(["auto", "cuda", "mps", "none"]),
622
+ default="auto",
623
+ help="Device to use: 'auto' (pick best available), 'cuda', 'mps', or 'none' (CPU only).",
624
+ )
625
def main(
    model_path,
    audio_tokenizer,
    max_new_tokens,
    transcript,
    scene_prompt,
    temperature,
    top_k,
    top_p,
    ras_win_len,
    ras_win_max_num_repeat,
    ref_audio,
    ref_audio_in_system_message,
    chunk_method,
    chunk_max_word_num,
    chunk_max_num_turns,
    generation_chunk_buffer_size,
    seed,
    device_id,
    out_path,
    use_static_kv_cache,
    device,
):
    """CLI entry point: synthesize speech for a transcript and write a WAV file.

    Resolves the compute device, loads the audio tokenizer and model, normalizes
    the transcript text, optionally chunks it, generates audio chunk-by-chunk,
    and saves the concatenated waveform to ``out_path``.
    All parameters are supplied by the click option decorators above.
    """
    # specifying a device_id implies CUDA
    if device_id is None:
        if device == "auto":
            # Prefer CUDA, then Apple MPS, then CPU.
            if torch.cuda.is_available():
                device_id = 0
                device = "cuda:0"
            elif torch.backends.mps.is_available():
                device_id = None  # MPS doesn't use device IDs like CUDA
                device = "mps"
            else:
                device_id = None
                device = "cpu"
        elif device == "cuda":
            device_id = 0
            device = "cuda:0"
        elif device == "mps":
            device_id = None
            device = "mps"
        else:
            device_id = None
            device = "cpu"
    else:
        device = f"cuda:{device_id}"
    # For MPS, use CPU for audio tokenizer due to embedding operation limitations
    audio_tokenizer_device = "cpu" if device == "mps" else device
    # NOTE: rebinds the CLI string argument to the loaded tokenizer object.
    audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer, device=audio_tokenizer_device)

    # Disable static KV cache on MPS since it relies on CUDA graphs
    if device == "mps" and use_static_kv_cache:
        use_static_kv_cache = False
    model_client = HiggsAudioModelClient(
        model_path=model_path,
        audio_tokenizer=audio_tokenizer,
        device=device,
        device_id=device_id,
        max_new_tokens=max_new_tokens,
        use_static_kv_cache=use_static_kv_cache,
    )

    # Matches speaker tags such as "[SPEAKER0]" in the transcript.
    pattern = re.compile(r"\[(SPEAKER\d+)\]")

    # `transcript` may be either literal text or a path to a text file.
    if os.path.exists(transcript):
        logger.info(f"Loading transcript from {transcript}")
        with open(transcript, "r", encoding="utf-8") as f:
            transcript = f.read().strip()

    # `scene_prompt` may be a file path; "empty" (or a missing file) disables it.
    if scene_prompt is not None and scene_prompt != "empty" and os.path.exists(scene_prompt):
        with open(scene_prompt, "r", encoding="utf-8") as f:
            scene_prompt = f.read().strip()
    else:
        scene_prompt = None

    speaker_tags = sorted(set(pattern.findall(transcript)))
    # Perform some basic normalization
    transcript = normalize_chinese_punctuation(transcript)
    # Other normalizations (e.g., parentheses and other symbols. Will be improved in the future)
    transcript = transcript.replace("(", " ")
    transcript = transcript.replace(")", " ")
    transcript = transcript.replace("°F", " degrees Fahrenheit")
    transcript = transcript.replace("°C", " degrees Celsius")

    # Map bracketed sound-event cues to the model's <SE> tag vocabulary.
    for tag, replacement in [
        ("[laugh]", "<SE>[Laughter]</SE>"),
        ("[humming start]", "<SE>[Humming]</SE>"),
        ("[humming end]", "<SE_e>[Humming]</SE_e>"),
        ("[music start]", "<SE_s>[Music]</SE_s>"),
        ("[music end]", "<SE_e>[Music]</SE_e>"),
        ("[music]", "<SE>[Music]</SE>"),
        ("[sing start]", "<SE_s>[Singing]</SE_s>"),
        ("[sing end]", "<SE_e>[Singing]</SE_e>"),
        ("[applause]", "<SE>[Applause]</SE>"),
        ("[cheering]", "<SE>[Cheering]</SE>"),
        ("[cough]", "<SE>[Cough]</SE>"),
    ]:
        transcript = transcript.replace(tag, replacement)
    # Collapse intra-line whitespace and drop blank lines.
    lines = transcript.split("\n")
    transcript = "\n".join([" ".join(line.split()) for line in lines if line.strip()])
    transcript = transcript.strip()

    # Ensure the transcript ends with terminal punctuation or a closing tag.
    if not any([transcript.endswith(c) for c in [".", "!", "?", ",", ";", '"', "'", "</SE_e>", "</SE>"]]):
        transcript += "."

    messages, audio_ids = prepare_generation_context(
        scene_prompt=scene_prompt,
        ref_audio=ref_audio,
        ref_audio_in_system_message=ref_audio_in_system_message,
        audio_tokenizer=audio_tokenizer,
        speaker_tags=speaker_tags,
    )
    chunked_text = prepare_chunk_text(
        transcript,
        chunk_method=chunk_method,
        chunk_max_word_num=chunk_max_word_num,
        chunk_max_num_turns=chunk_max_num_turns,
    )

    logger.info("Chunks used for generation:")
    for idx, chunk_text in enumerate(chunked_text):
        logger.info(f"Chunk {idx}:")
        logger.info(chunk_text)
        logger.info("-----")

    concat_wv, sr, text_output = model_client.generate(
        messages=messages,
        audio_ids=audio_ids,
        chunked_text=chunked_text,
        generation_chunk_buffer_size=generation_chunk_buffer_size,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        ras_win_len=ras_win_len,
        ras_win_max_num_repeat=ras_win_max_num_repeat,
        seed=seed,
    )

    sf.write(out_path, concat_wv, sr)
    logger.info(f"Wav file is saved to '{out_path}' with sample rate {sr}")


if __name__ == "__main__":
    main()
examples/interactive_generation.py ADDED
@@ -0,0 +1,800 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Interactive script for generating audio using HiggsAudio with single model load."""
3
+
4
+ import click
5
+ import soundfile as sf
6
+ import langid
7
+ import jieba
8
+ import os
9
+ import re
10
+ import copy
11
+ import torchaudio
12
+ import tqdm
13
+ import yaml
14
+
15
+ from loguru import logger
16
+ from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse
17
+ from boson_multimodal.data_types import Message, ChatMLSample, AudioContent, TextContent
18
+
19
+ from boson_multimodal.model.higgs_audio import HiggsAudioConfig, HiggsAudioModel
20
+ from boson_multimodal.data_collator.higgs_audio_collator import HiggsAudioSampleCollator
21
+ from boson_multimodal.audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer
22
+ from boson_multimodal.dataset.chatml_dataset import (
23
+ ChatMLDatasetSample,
24
+ prepare_chatml_sample,
25
+ )
26
+ from boson_multimodal.model.higgs_audio.utils import revert_delay_pattern
27
+ from typing import List
28
+ from transformers import AutoConfig, AutoTokenizer
29
+ from transformers.cache_utils import StaticCache
30
+ from typing import Optional
31
+ from dataclasses import asdict
32
+ import torch
33
+
34
# Directory containing this script; used to locate voice_prompts/ and scene_prompts/.
CURR_DIR = os.path.dirname(os.path.abspath(__file__))

# Sentinel inserted into system-message text where a reference audio clip should go;
# _build_system_message_with_audio_prompt replaces each occurrence with an AudioContent slot.
AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>"

# Default system prompt used when the transcript contains multiple [SPEAKER*] tags.
MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE = """You are an AI assistant designed to convert text into speech.
If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.
If no speaker tag is present, select a suitable voice on your own."""
41
+
42
def normalize_chinese_punctuation(text):
    """
    Convert Chinese (full-width) punctuation marks to English (half-width) equivalents.

    Args:
        text: Input string that may contain full-width punctuation.

    Returns:
        The string with each mapped full-width character replaced; all other
        characters are left untouched.
    """
    chinese_to_english_punct = {
        ",": ", ",  # comma (note: also inserts a space, matching English style)
        "。": ".",  # period
        ":": ":",  # colon
        ";": ";",  # semicolon
        "?": "?",  # question mark
        "!": "!",  # exclamation mark
        "(": "(",  # left parenthesis
        ")": ")",  # right parenthesis
        "【": "[",  # left square bracket
        "】": "]",  # right square bracket
        "《": "<",  # left angle quote
        "》": ">",  # right angle quote
        "“": '"',  # left double quotation
        "”": '"',  # right double quotation
        "‘": "'",  # left single quotation
        "’": "'",  # right single quotation
        "、": ",",  # enumeration comma
        "—": "-",  # em dash
        "…": "...",  # ellipsis
        "·": ".",  # middle dot
        "「": '"',  # left corner bracket
        "」": '"',  # right corner bracket
        "『": '"',  # left double corner bracket
        "』": '"',  # right double corner bracket
    }

    # One C-level pass with str.translate instead of ~24 chained str.replace
    # calls. Equivalent because every key is a single character; str.maketrans
    # supports multi-character replacement values (e.g. "…" -> "...").
    return text.translate(str.maketrans(chinese_to_english_punct))
77
+
78
def prepare_chunk_text(
    text, chunk_method: Optional[str] = None, chunk_max_word_num: int = 100, chunk_max_num_turns: int = 1
):
    """Chunk the text into smaller pieces. We will later feed the chunks one by one to the model.

    Args:
        text: Full transcript to split.
        chunk_method: ``None`` (no chunking), ``"speaker"`` (split per speaker
            turn), or ``"word"`` (split per ``chunk_max_word_num`` words).
        chunk_max_word_num: Max words per chunk for the ``"word"`` method.
        chunk_max_num_turns: Number of speaker turns to merge per chunk for
            the ``"speaker"`` method.

    Returns:
        List of text chunks (a single-element list when ``chunk_method`` is None).

    Raises:
        ValueError: If ``chunk_method`` is not one of the supported options.
    """
    if chunk_method is None:
        return [text]
    elif chunk_method == "speaker":
        # Group lines into per-speaker utterances: a new chunk starts at each
        # "[SPEAKER..." / "<|speaker_id_start|>" line; other lines continue the
        # current utterance.
        lines = text.split("\n")
        speaker_chunks = []
        speaker_utterance = ""
        for line in lines:
            line = line.strip()
            if line.startswith("[SPEAKER") or line.startswith("<|speaker_id_start|>"):
                if speaker_utterance:
                    speaker_chunks.append(speaker_utterance.strip())
                speaker_utterance = line
            else:
                if speaker_utterance:
                    speaker_utterance += "\n" + line
                else:
                    speaker_utterance = line
        if speaker_utterance:
            speaker_chunks.append(speaker_utterance.strip())
        if chunk_max_num_turns > 1:
            # Merge consecutive turns so each chunk holds up to
            # `chunk_max_num_turns` utterances.
            merged_chunks = []
            for i in range(0, len(speaker_chunks), chunk_max_num_turns):
                merged_chunk = "\n".join(speaker_chunks[i : i + chunk_max_num_turns])
                merged_chunks.append(merged_chunk)
            return merged_chunks
        return speaker_chunks
    elif chunk_method == "word":
        # Chinese text is segmented with jieba (no spaces between words);
        # everything else splits on spaces.
        language = langid.classify(text)[0]
        paragraphs = text.split("\n\n")
        chunks = []
        for paragraph in paragraphs:
            if language == "zh":
                words = list(jieba.cut(paragraph, cut_all=False))
                for i in range(0, len(words), chunk_max_word_num):
                    chunks.append("".join(words[i : i + chunk_max_word_num]))
            else:
                words = paragraph.split(" ")
                for i in range(0, len(words), chunk_max_word_num):
                    chunks.append(" ".join(words[i : i + chunk_max_word_num]))
            # Re-append the paragraph separator to the last chunk of this
            # paragraph. Fix: guard against an empty chunk list — an empty
            # Chinese paragraph yields no words from jieba, which previously
            # raised IndexError here.
            if chunks:
                chunks[-1] += "\n\n"
        return chunks
    else:
        raise ValueError(f"Unknown chunk method: {chunk_method}")
127
+
128
def _build_system_message_with_audio_prompt(system_message):
    """Turn a system-message string into a system ``Message`` with audio slots.

    Each occurrence of ``AUDIO_PLACEHOLDER_TOKEN`` becomes an empty
    ``AudioContent`` entry, preceded by a ``TextContent`` holding the text in
    front of it (possibly empty). Any non-empty trailing text becomes a final
    ``TextContent``; an empty tail is dropped.
    """
    segments = system_message.split(AUDIO_PLACEHOLDER_TOKEN)
    contents = []
    # Every segment except the last was followed by a placeholder in the
    # original string, so emit text + audio for each of them.
    for segment in segments[:-1]:
        contents.append(TextContent(segment))
        contents.append(AudioContent(audio_url=""))
    if segments[-1]:
        contents.append(TextContent(segments[-1]))
    return Message(role="system", content=contents)
144
+
145
class HiggsAudioModelClient:
    """Thin client that loads a HiggsAudio model once and generates audio chunks.

    Owns the model, text tokenizer, audio tokenizer, sample collator, and
    (optionally) pre-allocated static KV caches for faster CUDA generation.
    """

    def __init__(
        self,
        model_path,
        audio_tokenizer,
        device=None,
        device_id=None,
        max_new_tokens=2048,
        kv_cache_lengths: List[int] = [1024, 4096, 8192],
        use_static_kv_cache=False,
    ):
        """Load the model, tokenizers, and collator.

        Args:
            model_path: HF model directory/name for ``HiggsAudioModel``.
            audio_tokenizer: Either an already-loaded audio tokenizer object or
                a path/name string (loaded here if a string).
            device: Explicit device string (e.g. "cuda:0", "mps", "cpu");
                auto-detected when None.
            device_id: CUDA device index; when given it overrides ``device``.
            max_new_tokens: Generation budget per chunk.
            kv_cache_lengths: Bucket sizes for the static KV caches.
                NOTE(review): mutable default argument; safe only because the
                list is never mutated (only read via sorted()) — keep it that way.
            use_static_kv_cache: Pre-allocate StaticCache buckets (and capture
                CUDA graphs on CUDA devices).
        """
        # A concrete device_id always wins and implies CUDA.
        if device_id is not None:
            device = f"cuda:{device_id}"
            self._device = device
        else:
            if device is not None:
                self._device = device
            else:
                # Auto-detect: CUDA > MPS > CPU.
                if torch.cuda.is_available():
                    self._device = "cuda:0"
                elif torch.backends.mps.is_available():
                    self._device = "mps"
                else:
                    self._device = "cpu"

        logger.info(f"Using device: {self._device}")
        if isinstance(audio_tokenizer, str):
            # MPS cannot host the audio tokenizer (embedding-op limitations),
            # so fall back to CPU for it.
            audio_tokenizer_device = "cpu" if self._device == "mps" else self._device
            self._audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer, device=audio_tokenizer_device)
        else:
            self._audio_tokenizer = audio_tokenizer

        self._model = HiggsAudioModel.from_pretrained(
            model_path,
            device_map=self._device,
            torch_dtype=torch.bfloat16,
        )
        self._model.eval()
        self._kv_cache_lengths = kv_cache_lengths
        self._use_static_kv_cache = use_static_kv_cache

        self._tokenizer = AutoTokenizer.from_pretrained(model_path)
        self._config = AutoConfig.from_pretrained(model_path)
        self._max_new_tokens = max_new_tokens
        # Collator turns ChatMLDatasetSample objects into model-ready batches;
        # token ids and flags mirror the model config.
        self._collator = HiggsAudioSampleCollator(
            whisper_processor=None,
            audio_in_token_id=self._config.audio_in_token_idx,
            audio_out_token_id=self._config.audio_out_token_idx,
            audio_stream_bos_id=self._config.audio_stream_bos_id,
            audio_stream_eos_id=self._config.audio_stream_eos_id,
            encode_whisper_embed=self._config.encode_whisper_embed,
            pad_token_id=self._config.pad_token_id,
            return_audio_in_tokens=self._config.encode_audio_in_tokens,
            use_delay_pattern=self._config.use_delay_pattern,
            round_to=1,
            audio_num_codebooks=self._config.audio_num_codebooks,
        )
        self.kv_caches = None
        if use_static_kv_cache:
            self._init_static_kv_cache()

    def _init_static_kv_cache(self):
        """Allocate one StaticCache per configured length bucket.

        The cache config adds extra layers when the model uses audio dual-FFN
        layers, since those contribute their own KV entries.
        """
        cache_config = copy.deepcopy(self._model.config.text_config)
        cache_config.num_hidden_layers = self._model.config.text_config.num_hidden_layers
        if self._model.config.audio_dual_ffn_layers:
            cache_config.num_hidden_layers += len(self._model.config.audio_dual_ffn_layers)
        self.kv_caches = {
            length: StaticCache(
                config=cache_config,
                max_batch_size=1,
                max_cache_len=length,
                device=self._model.device,
                dtype=self._model.dtype,
            )
            for length in sorted(self._kv_cache_lengths)
        }
        # CUDA-graph capture only makes sense on CUDA devices.
        if "cuda" in self._device:
            logger.info(f"Capturing CUDA graphs for each KV cache length")
            self._model.capture_model(self.kv_caches.values())

    def _prepare_kv_caches(self):
        """Reset every static KV cache before a fresh generation call."""
        for kv_cache in self.kv_caches.values():
            kv_cache.reset()

    @torch.inference_mode()
    def generate(
        self,
        messages,
        audio_ids,
        chunked_text,
        generation_chunk_buffer_size,
        temperature=1.0,
        top_k=50,
        top_p=0.95,
        ras_win_len=7,
        ras_win_max_num_repeat=2,
        seed=123,
        *args,
        **kwargs,
    ):
        """Generate audio for each text chunk, feeding prior chunks back as context.

        Args:
            messages: Context messages (system prompt, voice-prompt turns).
            audio_ids: Reference-audio token tensors kept in context for every chunk.
            chunked_text: List of text chunks to synthesize in order.
            generation_chunk_buffer_size: If set, only the last N generated
                chunks (and their message pairs) are kept as rolling context;
                reference audios in ``audio_ids`` are always kept.
            temperature / top_k / top_p: Sampling parameters.
            ras_win_len / ras_win_max_num_repeat: RAS sampling controls;
                ``ras_win_len <= 0`` disables RAS.
            seed: Generation seed forwarded to the model.

        Returns:
            Tuple ``(waveform, sample_rate, decoded_text)`` where the waveform
            is the concatenation of all chunks at 24 kHz.
        """
        if ras_win_len is not None and ras_win_len <= 0:
            ras_win_len = None
        sr = 24000  # output sample rate of the audio tokenizer decoder
        audio_out_ids_l = []
        generated_audio_ids = []
        generation_messages = []
        for idx, chunk_text in tqdm.tqdm(
            enumerate(chunked_text), desc="Generating audio chunks", total=len(chunked_text)
        ):
            generation_messages.append(
                Message(
                    role="user",
                    content=chunk_text,
                )
            )
            chatml_sample = ChatMLSample(messages=messages + generation_messages)
            input_tokens, _, _, _ = prepare_chatml_sample(chatml_sample, self._tokenizer)
            # Open the assistant turn so the model continues from it.
            postfix = self._tokenizer.encode(
                "<|start_header_id|>assistant<|end_header_id|>\n\n", add_special_tokens=False
            )
            input_tokens.extend(postfix)

            logger.info(f"========= Chunk {idx} Input =========")
            logger.info(self._tokenizer.decode(input_tokens))
            # Context = reference audios + previously generated chunks.
            context_audio_ids = audio_ids + generated_audio_ids

            curr_sample = ChatMLDatasetSample(
                input_ids=torch.LongTensor(input_tokens),
                label_ids=None,
                # Concatenate context audio tokens along the time axis and
                # record each clip's start offset via a cumulative sum.
                audio_ids_concat=torch.concat([ele.cpu() for ele in context_audio_ids], dim=1)
                if context_audio_ids
                else None,
                audio_ids_start=torch.cumsum(
                    torch.tensor([0] + [ele.shape[1] for ele in context_audio_ids], dtype=torch.long), dim=0
                )
                if context_audio_ids
                else None,
                audio_waveforms_concat=None,
                audio_waveforms_start=None,
                audio_sample_rate=None,
                audio_speaker_indices=None,
            )

            batch_data = self._collator([curr_sample])
            batch = asdict(batch_data)
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.contiguous().to(self._device)

            if self._use_static_kv_cache:
                self._prepare_kv_caches()

            outputs = self._model.generate(
                **batch,
                max_new_tokens=self._max_new_tokens,
                use_cache=True,
                do_sample=True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                past_key_values_buckets=self.kv_caches,
                ras_win_len=ras_win_len,
                ras_win_max_num_repeat=ras_win_max_num_repeat,
                stop_strings=["<|end_of_text|>", "<|eot_id|>"],
                tokenizer=self._tokenizer,
                seed=seed,
            )

            # outputs[0]: text token ids; outputs[1]: per-stream audio token ids.
            step_audio_out_ids_l = []
            for ele in outputs[1]:
                audio_out_ids = ele
                if self._config.use_delay_pattern:
                    audio_out_ids = revert_delay_pattern(audio_out_ids)
                # Clip to the codebook range and drop the BOS/EOS frames ([:, 1:-1]).
                step_audio_out_ids_l.append(audio_out_ids.clip(0, self._audio_tokenizer.codebook_size - 1)[:, 1:-1])
            audio_out_ids = torch.concat(step_audio_out_ids_l, dim=1)
            audio_out_ids_l.append(audio_out_ids)
            generated_audio_ids.append(audio_out_ids)

            # Record the assistant's audio reply so later chunks see it as context.
            generation_messages.append(
                Message(
                    role="assistant",
                    content=AudioContent(audio_url=""),
                )
            )
            if generation_chunk_buffer_size is not None and len(generated_audio_ids) > generation_chunk_buffer_size:
                # Keep only the newest chunks; each chunk contributes a
                # user+assistant message pair, hence the factor of 2.
                generated_audio_ids = generated_audio_ids[-generation_chunk_buffer_size:]
                generation_messages = generation_messages[(-2 * generation_chunk_buffer_size) :]

        logger.info(f"========= Final Text output =========")
        logger.info(self._tokenizer.decode(outputs[0][0]))
        concat_audio_out_ids = torch.concat(audio_out_ids_l, dim=1)

        # Decode on CPU when the tokens live on an accelerator.
        if concat_audio_out_ids.device.type in ["mps", "cuda"]:
            concat_audio_out_ids_cpu = concat_audio_out_ids.detach().cpu()
        else:
            concat_audio_out_ids_cpu = concat_audio_out_ids

        concat_wv = self._audio_tokenizer.decode(concat_audio_out_ids_cpu.unsqueeze(0))[0, 0]
        text_result = self._tokenizer.decode(outputs[0][0])
        return concat_wv, sr, text_result
345
+
346
def prepare_generation_context(scene_prompt, ref_audio, ref_audio_in_system_message, audio_tokenizer, speaker_tags):
    """Prepare the context for generation.

    Builds the system message and (optionally) voice-prompt conversation turns.

    Args:
        scene_prompt: Scene description text, or None.
        ref_audio: Comma-separated voice names (files under voice_prompts/) or
            "profile:<name>" entries resolved from voice_prompts/profile.yaml;
            None lets the model pick voices.
        ref_audio_in_system_message: Describe voices in the system message
            instead of as user/assistant turns. Forced to True when any entry
            is a "profile:" entry.
        audio_tokenizer: Audio tokenizer used to encode voice-prompt WAVs.
        speaker_tags: Sorted [SPEAKER*] tags found in the transcript.

    Returns:
        Tuple ``(messages, audio_ids)``: context messages (system message
        first, if any) and the encoded reference-audio token tensors.
    """
    system_message = None
    messages = []
    audio_ids = []
    if ref_audio is not None:
        num_speakers = len(ref_audio.split(","))
        speaker_info_l = ref_audio.split(",")
        voice_profile = None
        # "profile:" voices have no audio file, so their description must live
        # in the system message.
        if any([speaker_info.startswith("profile:") for speaker_info in ref_audio.split(",")]):
            ref_audio_in_system_message = True
        if ref_audio_in_system_message:
            speaker_desc = []
            for spk_id, character_name in enumerate(speaker_info_l):
                if character_name.startswith("profile:"):
                    # Lazily load the profile YAML once.
                    if voice_profile is None:
                        with open(f"{CURR_DIR}/voice_prompts/profile.yaml", "r", encoding="utf-8") as f:
                            voice_profile = yaml.safe_load(f)
                    character_desc = voice_profile["profiles"][character_name[len("profile:") :].strip()]
                    speaker_desc.append(f"SPEAKER{spk_id}: {character_desc}")
                else:
                    # Placeholder later replaced by an AudioContent slot.
                    speaker_desc.append(f"SPEAKER{spk_id}: {AUDIO_PLACEHOLDER_TOKEN}")
            if scene_prompt:
                system_message = (
                    "Generate audio following instruction."
                    "\n\n"
                    f"<|scene_desc_start|>\n{scene_prompt}\n\n" + "\n".join(speaker_desc) + "\n<|scene_desc_end|>"
                )
            else:
                system_message = (
                    "Generate audio following instruction.\n\n"
                    + f"<|scene_desc_start|>\n"
                    + "\n".join(speaker_desc)
                    + "\n<|scene_desc_end|>"
                )
            system_message = _build_system_message_with_audio_prompt(system_message)
        else:
            if scene_prompt:
                system_message = Message(
                    role="system",
                    content=f"Generate audio following instruction.\n\n<|scene_desc_start|>\n{scene_prompt}\n<|scene_desc_end|>",
                )
        voice_profile = None
        # Encode the actual voice-prompt audio for every non-profile voice.
        for spk_id, character_name in enumerate(ref_audio.split(",")):
            if not character_name.startswith("profile:"):
                prompt_audio_path = os.path.join(f"{CURR_DIR}/voice_prompts", f"{character_name}.wav")
                prompt_text_path = os.path.join(f"{CURR_DIR}/voice_prompts", f"{character_name}.txt")
                # NOTE(review): assert-based validation is stripped under
                # `python -O`; consider raising FileNotFoundError instead.
                assert os.path.exists(prompt_audio_path), (
                    f"Voice prompt audio file {prompt_audio_path} does not exist."
                )
                assert os.path.exists(prompt_text_path), f"Voice prompt text file {prompt_text_path} does not exist."
                with open(prompt_text_path, "r", encoding="utf-8") as f:
                    prompt_text = f.read().strip()
                audio_tokens = audio_tokenizer.encode(prompt_audio_path)
                audio_ids.append(audio_tokens)

                if not ref_audio_in_system_message:
                    # Present the voice prompt as a user text turn followed by
                    # the assistant's reference audio.
                    messages.append(
                        Message(
                            role="user",
                            content=f"[SPEAKER{spk_id}] {prompt_text}" if num_speakers > 1 else prompt_text,
                        )
                    )
                    messages.append(
                        Message(
                            role="assistant",
                            content=AudioContent(
                                audio_url=prompt_audio_path,
                            ),
                        )
                    )
    else:
        # No reference audio: describe voices in the system prompt only.
        if len(speaker_tags) > 1:
            speaker_desc_l = []

            # Alternate generic voice descriptions across the speaker tags.
            for idx, tag in enumerate(speaker_tags):
                if idx % 2 == 0:
                    speaker_desc = f"feminine"
                else:
                    speaker_desc = f"masculine"
                speaker_desc_l.append(f"{tag}: {speaker_desc}")

            speaker_desc = "\n".join(speaker_desc_l)
            scene_desc_l = []
            if scene_prompt:
                scene_desc_l.append(scene_prompt)
            scene_desc_l.append(speaker_desc)
            scene_desc = "\n\n".join(scene_desc_l)

            system_message = Message(
                role="system",
                content=f"{MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE}\n\n<|scene_desc_start|>\n{scene_desc}\n<|scene_desc_end|>",
            )
        else:
            system_message_l = ["Generate audio following instruction."]
            if scene_prompt:
                system_message_l.append(f"<|scene_desc_start|>\n{scene_prompt}\n<|scene_desc_end|>")
            system_message = Message(
                role="system",
                content="\n\n".join(system_message_l),
            )
    if system_message:
        messages.insert(0, system_message)
    return messages, audio_ids
450
+
451
def interactive_generation_loop(
    model_client,
    audio_tokenizer,
    scene_prompt,
    ref_audio,
    ref_audio_in_system_message,
    chunk_method,
    chunk_max_word_num,
    chunk_max_num_turns,
    generation_chunk_buffer_size,
    temperature,
    top_k,
    top_p,
    ras_win_len,
    ras_win_max_num_repeat,
    seed,
    output_dir,
):
    """Main interactive loop for audio generation.

    Repeatedly reads a transcript from stdin, normalizes and chunks it,
    generates audio with ``model_client`` (loaded once by the caller), and
    writes each result as a numbered WAV file into ``output_dir``. Exits on
    'quit'/'exit', EOF-equivalent empty handling, or Ctrl-C; per-iteration
    errors are logged and the loop continues.
    """
    logger.info("Starting interactive generation mode. Enter 'quit' or 'exit' to stop.")
    logger.info("Enter your transcript and press Enter to generate audio.")

    generation_count = 0  # used to number output files generation_001.wav, ...

    while True:
        try:
            # Get user input
            print("\n" + "="*50)
            print("Enter transcript (or 'quit'/'exit' to stop):")
            user_input = input("> ").strip()

            if not user_input:
                continue

            if user_input.lower() in ['quit', 'exit']:
                logger.info("Exiting interactive generation mode.")
                break

            transcript = user_input

            # Process transcript
            pattern = re.compile(r"\[(SPEAKER\d+)\]")
            speaker_tags = sorted(set(pattern.findall(transcript)))

            # Normalize transcript (same pipeline as the batch script).
            transcript = normalize_chinese_punctuation(transcript)
            transcript = transcript.replace("(", " ")
            transcript = transcript.replace(")", " ")
            transcript = transcript.replace("°F", " degrees Fahrenheit")
            transcript = transcript.replace("°C", " degrees Celsius")

            # Map bracketed sound-event cues to the model's <SE> tag vocabulary.
            for tag, replacement in [
                ("[laugh]", "<SE>[Laughter]</SE>"),
                ("[humming start]", "<SE>[Humming]</SE>"),
                ("[humming end]", "<SE_e>[Humming]</SE_e>"),
                ("[music start]", "<SE_s>[Music]</SE_s>"),
                ("[music end]", "<SE_e>[Music]</SE_e>"),
                ("[music]", "<SE>[Music]</SE>"),
                ("[sing start]", "<SE_s>[Singing]</SE_s>"),
                ("[sing end]", "<SE_e>[Singing]</SE_e>"),
                ("[applause]", "<SE>[Applause]</SE>"),
                ("[cheering]", "<SE>[Cheering]</SE>"),
                ("[cough]", "<SE>[Cough]</SE>"),
            ]:
                transcript = transcript.replace(tag, replacement)

            # Collapse intra-line whitespace and drop blank lines.
            lines = transcript.split("\n")
            transcript = "\n".join([" ".join(line.split()) for line in lines if line.strip()])
            transcript = transcript.strip()

            # Ensure the transcript ends with terminal punctuation or a closing tag.
            if not any([transcript.endswith(c) for c in [".", "!", "?", ",", ";", '"', "'", "</SE_e>", "</SE>"]]):
                transcript += "."

            # Prepare generation context
            messages, audio_ids = prepare_generation_context(
                scene_prompt=scene_prompt,
                ref_audio=ref_audio,
                ref_audio_in_system_message=ref_audio_in_system_message,
                audio_tokenizer=audio_tokenizer,
                speaker_tags=speaker_tags,
            )

            # Chunk text
            chunked_text = prepare_chunk_text(
                transcript,
                chunk_method=chunk_method,
                chunk_max_word_num=chunk_max_word_num,
                chunk_max_num_turns=chunk_max_num_turns,
            )

            logger.info("Chunks used for generation:")
            for idx, chunk_text in enumerate(chunked_text):
                logger.info(f"Chunk {idx}:")
                logger.info(chunk_text)
                logger.info("-----")

            # Generate audio
            logger.info(f"Generating audio for input: {transcript[:50]}...")
            concat_wv, sr, text_output = model_client.generate(
                messages=messages,
                audio_ids=audio_ids,
                chunked_text=chunked_text,
                generation_chunk_buffer_size=generation_chunk_buffer_size,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                ras_win_len=ras_win_len,
                ras_win_max_num_repeat=ras_win_max_num_repeat,
                seed=seed,
            )

            # Save audio file
            generation_count += 1
            output_filename = f"generation_{generation_count:03d}.wav"
            output_path = os.path.join(output_dir, output_filename)
            sf.write(output_path, concat_wv, sr)
            logger.info(f"Audio saved to: {output_path}")
            print(f"✓ Audio generated and saved to: {output_filename}")

        except KeyboardInterrupt:
            logger.info("\nInterrupted by user. Exiting...")
            break
        except Exception as e:
            # Keep the interactive session alive on per-generation failures.
            logger.error(f"Error during generation: {e}")
            print(f"✗ Error: {e}")
            continue
577
+
578
@click.command()
@click.option(
    "--model_path",
    type=str,
    default="./higgs-audio-v2-generation-3B-base",
    help="Path to the model directory.",
)
@click.option(
    "--audio_tokenizer",
    type=str,
    default="./higgs-audio-v2-tokenizer",
    help="Path to the audio tokenizer directory.",
)
@click.option(
    "--max_new_tokens",
    type=int,
    default=2048,
    help="The maximum number of new tokens to generate.",
)
@click.option(
    "--scene_prompt",
    type=str,
    default=f"{CURR_DIR}/scene_prompts/quiet_indoor.txt",
    help="The scene description prompt to use for generation. If not set, or set to `empty`, we will leave it to empty.",
)
@click.option(
    "--temperature",
    type=float,
    default=1.0,
    help="The value used to module the next token probabilities.",
)
@click.option(
    "--top_k",
    type=int,
    default=50,
    help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
)
@click.option(
    "--top_p",
    type=float,
    default=0.95,
    help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.",
)
@click.option(
    "--ras_win_len",
    type=int,
    default=7,
    help="The window length for RAS sampling. If set to 0 or a negative value, we won't use RAS sampling.",
)
@click.option(
    "--ras_win_max_num_repeat",
    type=int,
    default=2,
    help="The maximum number of times to repeat the RAS window. Only used when --ras_win_len is set.",
)
@click.option(
    "--ref_audio",
    type=str,
    default=None,
    help="The voice prompt to use for generation. If not set, we will let the model randomly pick a voice.",
)
@click.option(
    "--ref_audio_in_system_message",
    is_flag=True,
    default=False,
    help="Whether to include the voice prompt description in the system message.",
    show_default=True,
)
@click.option(
    "--chunk_method",
    default=None,
    type=click.Choice([None, "speaker", "word"]),
    help="The method to use for chunking the prompt text.",
)
@click.option(
    "--chunk_max_word_num",
    default=200,
    type=int,
    help="The maximum number of words for each chunk when 'word' chunking method is used.",
)
@click.option(
    "--chunk_max_num_turns",
    default=1,
    type=int,
    help="The maximum number of turns for each chunk when 'speaker' chunking method is used.",
)
@click.option(
    "--generation_chunk_buffer_size",
    default=None,
    type=int,
    help="The maximal number of chunks to keep in the buffer.",
)
@click.option(
    "--seed",
    default=None,
    type=int,
    help="Random seed for generation.",
)
@click.option(
    "--device_id",
    type=int,
    default=None,
    help="The device to run the model on.",
)
@click.option(
    "--output_dir",
    type=str,
    default="./interactive_outputs",
    help="Directory to save generated audio files.",
)
@click.option(
    "--use_static_kv_cache",
    type=int,
    default=1,
    help="Whether to use static KV cache for faster generation. Only works when using GPU.",
)
@click.option(
    "--device",
    type=click.Choice(["auto", "cuda", "mps", "none"]),
    default="auto",
    help="Device to use: 'auto' (pick best available), 'cuda', 'mps', or 'none' (CPU only).",
)
def main(
    model_path,
    audio_tokenizer,
    max_new_tokens,
    scene_prompt,
    temperature,
    top_k,
    top_p,
    ras_win_len,
    ras_win_max_num_repeat,
    ref_audio,
    ref_audio_in_system_message,
    chunk_method,
    chunk_max_word_num,
    chunk_max_num_turns,
    generation_chunk_buffer_size,
    seed,
    device_id,
    output_dir,
    use_static_kv_cache,
    device,
):
    """Interactive audio generation - model loads once, generates multiple times."""

    # Setup device: an explicit --device_id implies CUDA; otherwise resolve
    # --device ("auto" prefers CUDA, then MPS, then CPU).
    if device_id is None:
        if device == "auto":
            if torch.cuda.is_available():
                device_id = 0
                device = "cuda:0"
            elif torch.backends.mps.is_available():
                device_id = None
                device = "mps"
            else:
                device_id = None
                device = "cpu"
        elif device == "cuda":
            device_id = 0
            device = "cuda:0"
        elif device == "mps":
            device_id = None
            device = "mps"
        else:
            device_id = None
            device = "cpu"
    else:
        device = f"cuda:{device_id}"

    # For MPS, use CPU for audio tokenizer
    audio_tokenizer_device = "cpu" if device == "mps" else device
    audio_tokenizer_obj = load_higgs_audio_tokenizer(audio_tokenizer, device=audio_tokenizer_device)

    # Disable static KV cache on MPS
    if device == "mps" and use_static_kv_cache:
        use_static_kv_cache = False

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    logger.info(f"Output directory: {output_dir}")

    # Load scene prompt if file exists; "empty" or a missing file disables it.
    if scene_prompt is not None and scene_prompt != "empty" and os.path.exists(scene_prompt):
        with open(scene_prompt, "r", encoding="utf-8") as f:
            scene_prompt = f.read().strip()
    else:
        scene_prompt = None

    # Initialize model client (loads model once)
    logger.info("Loading model... This may take a while.")
    model_client = HiggsAudioModelClient(
        model_path=model_path,
        audio_tokenizer=audio_tokenizer_obj,
        device=device,
        device_id=device_id,
        max_new_tokens=max_new_tokens,
        use_static_kv_cache=use_static_kv_cache,
    )
    logger.info("Model loaded successfully!")

    # Start interactive generation loop
    interactive_generation_loop(
        model_client=model_client,
        audio_tokenizer=audio_tokenizer_obj,
        scene_prompt=scene_prompt,
        ref_audio=ref_audio,
        ref_audio_in_system_message=ref_audio_in_system_message,
        chunk_method=chunk_method,
        chunk_max_word_num=chunk_max_word_num,
        chunk_max_num_turns=chunk_max_num_turns,
        generation_chunk_buffer_size=generation_chunk_buffer_size,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        ras_win_len=ras_win_len,
        ras_win_max_num_repeat=ras_win_max_num_repeat,
        seed=seed,
        output_dir=output_dir,
    )


if __name__ == "__main__":
    main()
examples/scene_prompts/quiet_indoor.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ Audio is recorded from a quiet room.
examples/scene_prompts/reading_blog.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ In this audio, the person is reading a blog post aloud. The content is informative and engaging, with the speaker using a clear, conversational tone to make the material feel more approachable. The pacing is moderate, allowing listeners to absorb the information, and the tone shifts slightly to emphasize key points. The speaker occasionally pauses for effect, ensuring each section flows smoothly, as they guide the listener through the post's main ideas.
examples/serve_engine/README.md ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Examples to use HiggsAudioServeEngine
2
+
3
+ The `run_hf_example.py` script provides three different examples for using the `HiggsAudioServeEngine`.
4
+ Each example will generate an audio file (`output_{example}.wav`) in the current directory.
5
+
6
+ ### Zero-Shot Voice Generation
7
+ Generate audio with specific voice characteristics (e.g., accents).
8
+
9
+ ```bash
10
+ python run_hf_example.py zero_shot
11
+ ```
12
+
13
+ ### Voice Cloning
14
+ Clone a voice from a reference audio sample.
15
+
16
+ ```bash
17
+ python run_hf_example.py voice_clone
18
+ ```
19
+
20
+ ### (Experimental) Interleaved Dialogue Generation
21
+ Higgs Audio v2 is also able to generate text. Here's an example that shows it is able to generate multi-speaker conversations with interleaved transcript and audio from scene descriptions.
22
+
23
+ ```bash
24
+ python run_hf_example.py interleaved_dialogue
25
+ ```
examples/serve_engine/input_samples.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import os
3
+ from boson_multimodal.data_types import ChatMLSample, Message, AudioContent
4
+
5
+
6
+ def encode_base64_content_from_file(file_path: str) -> str:
7
+ """Encode a content from a local file to base64 format."""
8
+ # Read the audio file as binary and encode it directly to Base64
9
+ with open(file_path, "rb") as audio_file:
10
+ audio_base64 = base64.b64encode(audio_file.read()).decode("utf-8")
11
+ return audio_base64
12
+
13
+
14
+ def get_interleaved_dialogue_input_sample():
15
+ system_prompt = (
16
+ "Generate audio following instruction.\n\n"
17
+ "<|scene_desc_start|>\n"
18
+ "SPEAKER0: vocal fry;moderate pitch;monotone;masculine;young adult;slightly fast\n"
19
+ "SPEAKER1: masculine;moderate;moderate pitch;monotone;mature\n\n"
20
+ "In this scene, a group of adventurers is debating whether to investigate a potentially dangerous situation.\n"
21
+ "<|scene_desc_end|>"
22
+ )
23
+
24
+ messages = [
25
+ Message(
26
+ role="system",
27
+ content=system_prompt,
28
+ ),
29
+ Message(
30
+ role="user",
31
+ content="<|generation_instruction_start|>\nGenerate interleaved transcript and audio that lasts for around 20 seconds.\n<|generation_instruction_end|>",
32
+ ),
33
+ ]
34
+ chat_ml_sample = ChatMLSample(messages=messages)
35
+ return chat_ml_sample
36
+
37
+
38
+ def get_zero_shot_input_sample():
39
+ system_prompt = (
40
+ "Generate audio following instruction.\n\n<|scene_desc_start|>\nSPEAKER0: british accent\n<|scene_desc_end|>"
41
+ )
42
+
43
+ messages = [
44
+ Message(
45
+ role="system",
46
+ content=system_prompt,
47
+ ),
48
+ Message(
49
+ role="user",
50
+ content="Hey, everyone! Welcome back to Tech Talk Tuesdays.\n"
51
+ "It's your host, Alex, and today, we're diving into a topic that's become absolutely crucial in the tech world — deep learning.\n"
52
+ "And let's be honest, if you've been even remotely connected to tech, AI, or machine learning lately, you know that deep learning is everywhere.",
53
+ ),
54
+ ]
55
+ chat_ml_sample = ChatMLSample(messages=messages)
56
+ return chat_ml_sample
57
+
58
+
59
+ def get_voice_clone_input_sample():
60
+ reference_text = "I would imagine so. A wand with a dragon heartstring core is capable of dazzling magic."
61
+ reference_audio = encode_base64_content_from_file(
62
+ os.path.join(os.path.dirname(__file__), "voice_examples/old_man.wav")
63
+ )
64
+ messages = [
65
+ Message(
66
+ role="user",
67
+ content=reference_text,
68
+ ),
69
+ Message(
70
+ role="assistant",
71
+ content=AudioContent(raw_audio=reference_audio, audio_url="placeholder"),
72
+ ),
73
+ Message(
74
+ role="user",
75
+ content="Hey, everyone! Welcome back to Tech Talk Tuesdays.\n"
76
+ "It's your host, Alex, and today, we're diving into a topic that's become absolutely crucial in the tech world — deep learning.\n"
77
+ "And let's be honest, if you've been even remotely connected to tech, AI, or machine learning lately, you know that deep learning is everywhere.",
78
+ ),
79
+ ]
80
+ return ChatMLSample(messages=messages)
81
+
82
+
83
+ INPUT_SAMPLES = {
84
+ "interleaved_dialogue": get_interleaved_dialogue_input_sample,
85
+ "zero_shot": get_zero_shot_input_sample,
86
+ "voice_clone": get_voice_clone_input_sample,
87
+ }
examples/serve_engine/run_hf_example.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Example for using HiggsAudio for generating both the transcript and audio in an interleaved manner."""
2
+
3
+ from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse
4
+ import torch
5
+ import torchaudio
6
+ import time
7
+ from loguru import logger
8
+ import click
9
+
10
+ from input_samples import INPUT_SAMPLES
11
+
12
+ MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
13
+ AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"
14
+
15
+
16
+ @click.command()
17
+ @click.argument("example", type=click.Choice(list(INPUT_SAMPLES.keys())))
18
+ def main(example: str):
19
+ input_sample = INPUT_SAMPLES[example]()
20
+ device = "cuda" if torch.cuda.is_available() else "cpu"
21
+ logger.info(f"Using device: {device}")
22
+
23
+ serve_engine = HiggsAudioServeEngine(
24
+ MODEL_PATH,
25
+ AUDIO_TOKENIZER_PATH,
26
+ device=device,
27
+ )
28
+
29
+ logger.info("Starting generation...")
30
+ start_time = time.time()
31
+ output: HiggsAudioResponse = serve_engine.generate(
32
+ chat_ml_sample=input_sample,
33
+ max_new_tokens=1024,
34
+ temperature=1.0,
35
+ top_p=0.95,
36
+ top_k=50,
37
+ stop_strings=["<|end_of_text|>", "<|eot_id|>"],
38
+ )
39
+ elapsed_time = time.time() - start_time
40
+ logger.info(f"Generation time: {elapsed_time:.2f} seconds")
41
+
42
+ torchaudio.save(f"output_{example}.wav", torch.from_numpy(output.audio)[None, :], output.sampling_rate)
43
+ logger.info(f"Generated text:\n{output.generated_text}")
44
+ logger.info(f"Saved audio to output_{example}.wav")
45
+
46
+
47
+ if __name__ == "__main__":
48
+ main()
examples/serve_engine/voice_examples/old_man.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83bda9cd63be92366ef40dbe15c33e67b78766fb7069609f10dfc05cc626deba
3
+ size 1246508
examples/transcript/multi_speaker/en_argument.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [SPEAKER0] I can't believe you did that without even asking me first!
2
+ [SPEAKER1] Oh, come on! It wasn't a big deal, and I knew you would overreact like this.
3
+ [SPEAKER0] Overreact? You made a decision that affects both of us without even considering my opinion!
4
+ [SPEAKER1] Because I didn't have time to sit around waiting for you to make up your mind! Someone had to act.