diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..6f4011d986cd8895bbcc4cee0b071036ea7b285d 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,24 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +boson_multimodal/model/higgs_audio/__pycache__/modeling_higgs_audio.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +examples/serve_engine/voice_examples/old_man.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_prompts/belinda.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_prompts/bigbang_amy.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_prompts/bigbang_sheldon.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_prompts/broom_salesman.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_prompts/chadwick.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_prompts/en_man.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_prompts/en_woman.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_prompts/fiftyshades_anna.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_prompts/mabaoguo.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_prompts/mabel.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_prompts/shrek_donkey_es.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_prompts/shrek_donkey.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_prompts/shrek_fiona.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_prompts/shrek_shrek.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_prompts/vex.wav filter=lfs diff=lfs merge=lfs -text +examples/voice_prompts/zh_man_sichuan.wav filter=lfs diff=lfs merge=lfs -text +figures/emergent-tts-emotions-win-rate.png filter=lfs diff=lfs merge=lfs -text +figures/higgs_audio_tokenizer_architecture.png filter=lfs diff=lfs merge=lfs -text +figures/higgs_audio_v2_architecture_combined.png filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..19dc35b2433851a0e8fd866a5d323b2ba18c12ed --- /dev/null +++ b/LICENSE @@ -0,0 +1,175 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..acf3eec3fd68aee7a918315624a58b03dec8f8b1 --- /dev/null +++ b/README.md @@ -0,0 +1,325 @@ +

Higgs Audio V2: Redefining Expressiveness in Audio Generation

+ +
+ + + + + +
+ + +We are open-sourcing Higgs Audio v2, a powerful audio foundation model pretrained on over 10 million hours of audio data and a diverse set of text data. Despite having no post-training or fine-tuning, Higgs Audio v2 excels in expressive audio generation, thanks to its deep language and acoustic understanding. + +On [EmergentTTS-Eval](https://github.com/boson-ai/emergenttts-eval-public), it achieves win rates of **75.7%** and **55.7%** over "gpt-4o-mini-tts" on the "Emotions" and "Questions" categories, respectively. It also obtains state-of-the-art performance on traditional TTS benchmarks like Seed-TTS Eval and Emotional Speech Dataset (ESD). Moreover, the model demonstrates capabilities rarely seen in previous systems, including generating natural multi-speaker dialogues in multiple languages, automatic prosody adaptation during narration, melodic humming with the cloned voice, and simultaneous generation of speech and background music. + +

+ +

+ +Here's the demo video that shows some of its emergent capabilities (remember to unmute): + + + +Here's another demo video that show-cases the model's multilingual capability and how it enabled live translation (remember to unmute): + + + +## Installation + +We recommend to use NVIDIA Deep Learning Container to manage the CUDA environment. Following are two docker images that we have verified: +- nvcr.io/nvidia/pytorch:25.02-py3 +- nvcr.io/nvidia/pytorch:25.01-py3 + +Here's an example command for launching a docker container environment. Please also check the [official NVIDIA documentations](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch). + +```bash +docker run --gpus all --ipc=host --net=host --ulimit memlock=-1 --ulimit stack=67108864 -it --rm nvcr.io/nvidia/pytorch:25.02-py3 bash +``` + +### Option 1: Direct installation + + +```bash +git clone https://github.com/boson-ai/higgs-audio.git +cd higgs-audio + +pip install -r requirements.txt +pip install -e . +``` + +### Option 2: Using venv + +```bash +git clone https://github.com/boson-ai/higgs-audio.git +cd higgs-audio + +python3 -m venv higgs_audio_env +source higgs_audio_env/bin/activate +pip install -r requirements.txt +pip install -e . +``` + + +### Option 3: Using conda +```bash +git clone https://github.com/boson-ai/higgs-audio.git +cd higgs-audio + +conda create -n higgs_audio_env python=3.10 +conda activate higgs_audio_env +pip install -r requirements.txt +pip install -e . +``` + +### Option 4: Using uv +```bash +git clone https://github.com/boson-ai/higgs-audio.git +cd higgs-audio + +uv venv --python 3.10 +source .venv/bin/activate +uv pip install -r requirements.txt +uv pip install -e . +``` + +### Option 5: Using vllm + +For advanced usage with higher throughput, we also built OpenAI compatible API server backed by vLLM engine for you to use. +Please refer to [examples/vllm](./examples/vllm) for more details. + + +## Usage + +> [!TIP] +> For optimal performance, run the generation examples on a machine equipped with GPU with at least 24GB memory! + +### Get Started + +Here's a basic python snippet to help you get started. + +```python +from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse +from boson_multimodal.data_types import ChatMLSample, Message, AudioContent + +import torch +import torchaudio +import time +import click + +MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base" +AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer" + +system_prompt = ( + "Generate audio following instruction.\n\n<|scene_desc_start|>\nAudio is recorded from a quiet room.\n<|scene_desc_end|>" +) + +messages = [ + Message( + role="system", + content=system_prompt, + ), + Message( + role="user", + content="The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years.", + ), +] +device = "cuda" if torch.cuda.is_available() else "cpu" + +serve_engine = HiggsAudioServeEngine(MODEL_PATH, AUDIO_TOKENIZER_PATH, device=device) + +output: HiggsAudioResponse = serve_engine.generate( + chat_ml_sample=ChatMLSample(messages=messages), + max_new_tokens=1024, + temperature=0.3, + top_p=0.95, + top_k=50, + stop_strings=["<|end_of_text|>", "<|eot_id|>"], +) +torchaudio.save(f"output.wav", torch.from_numpy(output.audio)[None, :], output.sampling_rate) +``` + +We also provide a list of examples under [examples](./examples). In the following we highlight a few examples to help you use Higgs Audio v2. + +### Zero-Shot Voice Cloning +Generate audio that sounds similar as the provided [reference audio](./examples/voice_prompts/belinda.wav). + +```bash +python3 examples/generation.py \ +--transcript "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years." \ +--ref_audio belinda \ +--temperature 0.3 \ +--out_path generation.wav +``` + +The generation script will automatically use `cuda:0` if it founds cuda is available. To change the device id, specify `--device_id`: + +```bash +python3 examples/generation.py \ +--transcript "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years." \ +--ref_audio belinda \ +--temperature 0.3 \ +--device_id 0 \ +--out_path generation.wav +``` + +You can also try other voices. Check more example voices in [examples/voice_prompts](./examples/voice_prompts). You can also add your own voice to the folder. + +```bash +python3 examples/generation.py \ +--transcript "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years." \ +--ref_audio broom_salesman \ +--temperature 0.3 \ +--out_path generation.wav +``` + +### Voice Cloning via Cog (Replicate) + +You can also run Higgs Audio v2 using [Cog](https://cog.run), which packages the model for reproducible inference. This is useful for deploying on Replicate or other platforms. + +#### Prerequisites +- [Install Cog](https://cog.run/getting-started) +- GPU with at least 24GB VRAM (e.g., A100, RTX 4090) + +#### Basic Text-to-Speech +```bash +cog predict -i text="The sun rises in the east and sets in the west." +``` + +#### Voice Cloning with Reference Audio +To clone a voice, provide a reference audio file: + +```bash +cog predict -i text="The sun rises in the east and sets in the west." \ + -i ref_audio=@/path/to/reference_audio.wav +``` + +#### Customization Parameters +- `text` (str): Text to convert to speech +- `ref_audio` (Path, optional): Reference audio file for voice cloning (WAV, MP3, etc.) +- `scene_description` (str): Scene context for audio generation (default: "Audio is recorded from a quiet room.") +- `temperature` (float): Controls randomness, 0.1-1.0 (default: 0.3, lower = more deterministic) +- `top_p` (float): Nucleus sampling parameter, 0.1-1.0 (default: 0.95) +- `top_k` (int): Top-k sampling, 1-100 (default: 50) +- `max_new_tokens` (int): Maximum audio tokens to generate, 256-2048 (default: 1024) +- `system_message` (str): Custom system prompt (optional) + +#### Example: Generate with Custom Scene +```bash +cog predict -i text="Generate a whisper voice in a noisy cafe." \ + -i scene_description="Audio is recorded from a busy cafe with background chatter." \ + -i temperature=0.5 +``` + +#### Example: Clone Multiple Voices +```bash +cog predict -i text="Speaker one talks here." -i ref_audio=@voice1.wav +cog predict -i text="Speaker two talks here." -i ref_audio=@voice2.wav +``` + +### Single-speaker Generation with Smart Voice +If you do not specify reference voice, the model will decide the voice based on the transcript it sees. + +```bash +python3 examples/generation.py \ +--transcript "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years." \ +--temperature 0.3 \ +--out_path generation.wav +``` + + +### Multi-speaker Dialog with Smart Voice +Generate multi-speaker dialog. The model will decide the voices based on the transcript it sees. + +```bash +python3 examples/generation.py \ +--transcript examples/transcript/multi_speaker/en_argument.txt \ +--seed 12345 \ +--out_path generation.wav +``` + +### Multi-speaker Dialog with Voice Clone + +Generate multi-speaker dialog with the voices you picked. + +```bash +python3 examples/generation.py \ +--transcript examples/transcript/multi_speaker/en_argument.txt \ +--ref_audio belinda,broom_salesman \ +--ref_audio_in_system_message \ +--chunk_method speaker \ +--seed 12345 \ +--out_path generation.wav +``` + + +## Technical Details + + + +Higgs Audio v2 adopts the "generation variant" depicted in the architecture figure above. Its strong performance is driven by three key technical innovations: +- We developed an automated annotation pipeline that leverages multiple ASR models, sound event classification models, and our in-house audio understanding model. Using this pipeline, we cleaned and annotated 10 million hours audio data, which we refer to as **AudioVerse**. The in-house understanding model is finetuned on top of [Higgs Audio v1 Understanding](https://www.boson.ai/blog/higgs-audio), which adopts the "understanding variant" shown in the architecture figure. +- We trained a unified audio tokenizer from scratch that captures both semantic and acoustic features. Learn more in the [tokenizer blog](./tech_blogs/TOKENIZER_BLOG.md). +- We proposed the DualFFN architecture, which enhances the LLM’s ability to model acoustics tokens with minimal computational overhead. See the [architecture blog](./tech_blogs/ARCHITECTURE_BLOG.md). + +## Evaluation + +Here's the performance of Higgs Audio v2 on four benchmarks, [Seed-TTS Eval](https://github.com/BytedanceSpeech/seed-tts-eval), [Emotional Speech Dataset (ESD)](https://paperswithcode.com/dataset/esd), [EmergentTTS-Eval](https://arxiv.org/abs/2505.23009), and Multi-speaker Eval: + +#### Seed-TTS Eval & ESD + +We prompt Higgs Audio v2 with the reference text, reference audio, and target text for zero-shot TTS. We use the standard evaluation metrics from Seed-TTS Eval and ESD. + +| | SeedTTS-Eval| | ESD | | +|------------------------------|--------|--------|---------|-------------------| +| | WER ↓ | SIM ↑ | WER ↓ | SIM (emo2vec) ↑ | +| Cosyvoice2 | 2.28 | 65.49 | 2.71 | 80.48 | +| Qwen2.5-omni† | 2.33 | 64.10 | - | - | +| ElevenLabs Multilingual V2 | **1.43** | 50.00 | 1.66 | 65.87 | +| Higgs Audio v1 | 2.18 | 66.27 | **1.49** | 82.84 | +| Higgs Audio v2 (base) | 2.44 | **67.70** | 1.78 | **86.13** | + + +#### EmergentTTS-Eval ("Emotions" and "Questions") + +Following the [EmergentTTS-Eval Paper](https://arxiv.org/abs/2505.23009), we report the win-rate over "gpt-4o-mini-tts" with the "alloy" voice. The judge model is Gemini 2.5 Pro. + +| Model | Emotions (%) ↑ | Questions (%) ↑ | +|------------------------------------|--------------|----------------| +| Higgs Audio v2 (base) | **75.71%** | **55.71%** | +| [gpt-4o-audio-preview†](https://platform.openai.com/docs/models/gpt-4o-audio-preview) | 61.64% | 47.85% | +| [Hume.AI](https://www.hume.ai/research) | 61.60% | 43.21% | +| **BASELINE:** [gpt-4o-mini-tts](https://platform.openai.com/docs/models/gpt-4o-mini-tts) | 50.00% | 50.00% | +| [Qwen 2.5 Omni†](https://github.com/QwenLM/Qwen2.5-Omni) | 41.60% | 51.78% | +| [minimax/speech-02-hd](https://replicate.com/minimax/speech-02-hd) | 40.86% | 47.32% | +| [ElevenLabs Multilingual v2](https://elevenlabs.io/blog/eleven-multilingual-v2) | 30.35% | 39.46% | +| [DeepGram Aura-2](https://deepgram.com/learn/introducing-aura-2-enterprise-text-to-speech) | 29.28% | 48.21% | +| [Sesame csm-1B](https://github.com/SesameAILabs/csm) | 15.96% | 31.78% | + +'†' means using the strong-prompting method described in the paper. + + +#### Multi-speaker Eval + +We also designed a multi-speaker evaluation benchmark to evaluate the capability of Higgs Audio v2 for multi-speaker dialog generation. The benchmark contains three subsets + +- `two-speaker-conversation`: 1000 synthetic dialogues involving two speakers. We fix two reference audio clips to evaluate the model's ability in double voice cloning for utterances ranging from 4 to 10 dialogues between two randomly chosen persona. +- `small talk (no ref)`: 250 synthetic dialogues curated in the same way as above, but are characterized by short utterances and a limited number of turns (4–6), we do not fix reference audios in this case and this set is designed to evaluate the model's ability to automatically assign appropriate voices to speakers. +- `small talk (ref)`: 250 synthetic dialogues similar to above, but contains even shorter utterances as this set is meant to include reference clips in it's context, similar to `two-speaker-conversation`. + + +We report the word-error-rate (WER) and the geometric mean between intra-speaker similarity and inter-speaker dis-similarity on these three subsets. Other than Higgs Audio v2, we also evaluated [MoonCast](https://github.com/jzq2000/MoonCast) and [nari-labs/Dia-1.6B-0626](https://huggingface.co/nari-labs/Dia-1.6B-0626), two of the most popular open-source models capable of multi-speaker dialog generation. Results are summarized in the following table. We are not able to run [nari-labs/Dia-1.6B-0626](https://huggingface.co/nari-labs/Dia-1.6B-0626) on our "two-speaker-conversation" subset due to its strict limitation on the length of the utterances and output audio. + +| | two-speaker-conversation | |small talk | | small talk (no ref) | | +| ---------------------------------------------- | -------------- | ------------------ | ---------- | -------------- | ------------------- | -------------- | +| | WER ↓ | Mean Sim & Dis-sim ↑ | WER ↓ | Mean Sim & Dis-sim ↑ | WER ↓ | Mean Sim & Dis-sim ↑ | +| [MoonCast](https://github.com/jzq2000/MoonCast) | 38.77 | 46.02 | **8.33** | 63.68 | 24.65 | 53.94 | +| [nari-labs/Dia-1.6B-0626](https://huggingface.co/nari-labs/Dia-1.6B-0626) | \- | \- | 17.62 | 63.15 | 19.46 | **61.14** | +| Higgs Audio v2 (base) | **18.88** | **51.95** | 11.89 | **67.92** | **14.65** | 55.28 | + + +## Third-Party Licenses + +The `boson_multimodal/audio_processing/` directory contains code derived from third-party repositories, primarily from [xcodec](https://github.com/zhenye234/xcodec). Please see the [`LICENSE`](boson_multimodal/audio_processing/LICENSE) in that directory for complete attribution and licensing information. diff --git a/__pycache__/predict.cpython-311.pyc b/__pycache__/predict.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8ca64aae125d2e0c2ef9280a97cccb0c3b56f14 Binary files /dev/null and b/__pycache__/predict.cpython-311.pyc differ diff --git a/boson_multimodal/.DS_Store b/boson_multimodal/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..9fa70a008673b913184575e79991ce96188ab58b Binary files /dev/null and b/boson_multimodal/.DS_Store differ diff --git a/boson_multimodal/__init__.py b/boson_multimodal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f470c51f27bb03dc78f81b85808a0cf461631762 --- /dev/null +++ b/boson_multimodal/__init__.py @@ -0,0 +1 @@ +from .model.higgs_audio import HiggsAudioConfig, HiggsAudioModel diff --git a/boson_multimodal/__pycache__/__init__.cpython-311.pyc b/boson_multimodal/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..873efc200af5425a9bb7edfbb26ad634d9dda386 Binary files /dev/null and b/boson_multimodal/__pycache__/__init__.cpython-311.pyc differ diff --git a/boson_multimodal/__pycache__/constants.cpython-311.pyc b/boson_multimodal/__pycache__/constants.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e243cff9fdb148970004867f77aa975a1b4d01e Binary files /dev/null and b/boson_multimodal/__pycache__/constants.cpython-311.pyc differ diff --git a/boson_multimodal/__pycache__/data_types.cpython-311.pyc b/boson_multimodal/__pycache__/data_types.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c992abe0a9dc2de88557b130a6d8942d1ce1c817 Binary files /dev/null and b/boson_multimodal/__pycache__/data_types.cpython-311.pyc differ diff --git a/boson_multimodal/audio_processing/LICENSE b/boson_multimodal/audio_processing/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..c27989d900614d1a53e5bdcc1f9c2bee23ef3969 --- /dev/null +++ b/boson_multimodal/audio_processing/LICENSE @@ -0,0 +1,51 @@ +Third-Party License Attribution for Audio Processing Module +=========================================================== + +This directory contains code derived from multiple open-source projects. +The following sections detail the licenses and attributions for third-party code. + +## XCodec Repository +The code in this directory is derived from: +https://github.com/zhenye234/xcodec + +## Individual File Attributions + +### Quantization Module (quantization/) +- Several files contain code derived from Meta Platforms, Inc. and the vector-quantize-pytorch repository +- Individual files contain their own license headers where applicable +- The vector-quantize-pytorch portions are licensed under the MIT License + +## License Terms + +### MIT License (for applicable portions) +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +## Attribution Requirements +When using this code, please ensure proper attribution to: +1. The original xcodec repository: https://github.com/zhenye234/xcodec +2. Any other repositories mentioned in individual file headers +3. This derivative work and its modifications + +## Disclaimer +This directory contains modified versions of the original code. Please refer to +the original repositories for the canonical implementations and their specific +license terms. + +For any questions about licensing or attribution, please check the individual +file headers and the original source repositories. \ No newline at end of file diff --git a/boson_multimodal/audio_processing/__pycache__/higgs_audio_tokenizer.cpython-311.pyc b/boson_multimodal/audio_processing/__pycache__/higgs_audio_tokenizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e15d38623d2729447beb9974cc20be9e118d3847 Binary files /dev/null and b/boson_multimodal/audio_processing/__pycache__/higgs_audio_tokenizer.cpython-311.pyc differ diff --git a/boson_multimodal/audio_processing/__pycache__/semantic_module.cpython-311.pyc b/boson_multimodal/audio_processing/__pycache__/semantic_module.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dac34a9e282d4fc087a1899798f41cb0ca578764 Binary files /dev/null and b/boson_multimodal/audio_processing/__pycache__/semantic_module.cpython-311.pyc differ diff --git a/boson_multimodal/audio_processing/descriptaudiocodec/__init__.py b/boson_multimodal/audio_processing/descriptaudiocodec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/boson_multimodal/audio_processing/descriptaudiocodec/__pycache__/__init__.cpython-311.pyc b/boson_multimodal/audio_processing/descriptaudiocodec/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bcd9b391796c2bf6ed577785079c9a752a71ccf7 Binary files /dev/null and b/boson_multimodal/audio_processing/descriptaudiocodec/__pycache__/__init__.cpython-311.pyc differ diff --git a/boson_multimodal/audio_processing/descriptaudiocodec/dac/model/__pycache__/base.cpython-311.pyc b/boson_multimodal/audio_processing/descriptaudiocodec/dac/model/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83f65e7c627899f02af04489897941cdf60c770b Binary files /dev/null and b/boson_multimodal/audio_processing/descriptaudiocodec/dac/model/__pycache__/base.cpython-311.pyc differ diff --git a/boson_multimodal/audio_processing/descriptaudiocodec/dac/model/__pycache__/dac.cpython-311.pyc b/boson_multimodal/audio_processing/descriptaudiocodec/dac/model/__pycache__/dac.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..064f2a719cfc07a640be48988b18633ab9c39a6e Binary files /dev/null and b/boson_multimodal/audio_processing/descriptaudiocodec/dac/model/__pycache__/dac.cpython-311.pyc differ diff --git a/boson_multimodal/audio_processing/descriptaudiocodec/dac/model/base.py b/boson_multimodal/audio_processing/descriptaudiocodec/dac/model/base.py new file mode 100644 index 0000000000000000000000000000000000000000..08e39a2d9016c6ddc2491d0e2644b80c8efe3986 --- /dev/null +++ b/boson_multimodal/audio_processing/descriptaudiocodec/dac/model/base.py @@ -0,0 +1,286 @@ +import math +from dataclasses import dataclass +from pathlib import Path +from typing import Union + +import numpy as np +import torch +import tqdm +from audiotools import AudioSignal +from torch import nn + +SUPPORTED_VERSIONS = ["1.0.0"] + + +@dataclass +class DACFile: + codes: torch.Tensor + + # Metadata + chunk_length: int + original_length: int + input_db: float + channels: int + sample_rate: int + padding: bool + dac_version: str + + def save(self, path): + artifacts = { + "codes": self.codes.numpy().astype(np.uint16), + "metadata": { + "input_db": self.input_db.numpy().astype(np.float32), + "original_length": self.original_length, + "sample_rate": self.sample_rate, + "chunk_length": self.chunk_length, + "channels": self.channels, + "padding": self.padding, + "dac_version": SUPPORTED_VERSIONS[-1], + }, + } + path = Path(path).with_suffix(".dac") + with open(path, "wb") as f: + np.save(f, artifacts) + return path + + @classmethod + def load(cls, path): + artifacts = np.load(path, allow_pickle=True)[()] + codes = torch.from_numpy(artifacts["codes"].astype(int)) + if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS: + raise RuntimeError(f"Given file {path} can't be loaded with this version of descript-audio-codec.") + return cls(codes=codes, **artifacts["metadata"]) + + +class CodecMixin: + @property + def padding(self): + if not hasattr(self, "_padding"): + self._padding = True + return self._padding + + @padding.setter + def padding(self, value): + assert isinstance(value, bool) + + layers = [l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))] + + for layer in layers: + if value: + if hasattr(layer, "original_padding"): + layer.padding = layer.original_padding + else: + layer.original_padding = layer.padding + layer.padding = tuple(0 for _ in range(len(layer.padding))) + + self._padding = value + + def get_delay(self): + # Any number works here, delay is invariant to input length + l_out = self.get_output_length(0) + L = l_out + + layers = [] + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + layers.append(layer) + + for layer in reversed(layers): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.ConvTranspose1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.Conv1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.ceil(L) + + l_in = L + + return (l_in - l_out) // 2 + + def get_output_length(self, input_length): + L = input_length + # Calculate output length + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.Conv1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.ConvTranspose1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.floor(L) + return L + + @torch.no_grad() + def compress( + self, + audio_path_or_signal: Union[str, Path, AudioSignal], + win_duration: float = 1.0, + verbose: bool = False, + normalize_db: float = -16, + n_quantizers: int = None, + ) -> DACFile: + """Processes an audio signal from a file or AudioSignal object into + discrete codes. This function processes the signal in short windows, + using constant GPU memory. + + Parameters + ---------- + audio_path_or_signal : Union[str, Path, AudioSignal] + audio signal to reconstruct + win_duration : float, optional + window duration in seconds, by default 5.0 + verbose : bool, optional + by default False + normalize_db : float, optional + normalize db, by default -16 + + Returns + ------- + DACFile + Object containing compressed codes and metadata + required for decompression + """ + audio_signal = audio_path_or_signal + if isinstance(audio_signal, (str, Path)): + audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal)) + + self.eval() + original_padding = self.padding + original_device = audio_signal.device + + audio_signal = audio_signal.clone() + original_sr = audio_signal.sample_rate + + resample_fn = audio_signal.resample + loudness_fn = audio_signal.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if audio_signal.signal_duration >= 10 * 60 * 60: + resample_fn = audio_signal.ffmpeg_resample + loudness_fn = audio_signal.ffmpeg_loudness + + original_length = audio_signal.signal_length + resample_fn(self.sample_rate) + input_db = loudness_fn() + + if normalize_db is not None: + audio_signal.normalize(normalize_db) + audio_signal.ensure_max_of_audio() + + nb, nac, nt = audio_signal.audio_data.shape + audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt) + win_duration = audio_signal.signal_duration if win_duration is None else win_duration + + if audio_signal.signal_duration <= win_duration: + # Unchunked compression (used if signal length < win duration) + self.padding = True + n_samples = nt + hop = nt + else: + # Chunked inference + self.padding = False + # Zero-pad signal on either side by the delay + audio_signal.zero_pad(self.delay, self.delay) + n_samples = int(win_duration * self.sample_rate) + # Round n_samples to nearest hop length multiple + n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length) + hop = self.get_output_length(n_samples) + + codes = [] + range_fn = range if not verbose else tqdm.trange + + for i in range_fn(0, nt, hop): + x = audio_signal[..., i : i + n_samples] + x = x.zero_pad(0, max(0, n_samples - x.shape[-1])) + + audio_data = x.audio_data.to(self.device) + audio_data = self.preprocess(audio_data, self.sample_rate) + _, c, _, _, _ = self.encode(audio_data, n_quantizers) + codes.append(c.to(original_device)) + chunk_length = c.shape[-1] + + codes = torch.cat(codes, dim=-1) + + dac_file = DACFile( + codes=codes, + chunk_length=chunk_length, + original_length=original_length, + input_db=input_db, + channels=nac, + sample_rate=original_sr, + padding=self.padding, + dac_version=SUPPORTED_VERSIONS[-1], + ) + + if n_quantizers is not None: + codes = codes[:, :n_quantizers, :] + + self.padding = original_padding + return dac_file + + @torch.no_grad() + def decompress( + self, + obj: Union[str, Path, DACFile], + verbose: bool = False, + ) -> AudioSignal: + """Reconstruct audio from a given .dac file + + Parameters + ---------- + obj : Union[str, Path, DACFile] + .dac file location or corresponding DACFile object. + verbose : bool, optional + Prints progress if True, by default False + + Returns + ------- + AudioSignal + Object with the reconstructed audio + """ + self.eval() + if isinstance(obj, (str, Path)): + obj = DACFile.load(obj) + + original_padding = self.padding + self.padding = obj.padding + + range_fn = range if not verbose else tqdm.trange + codes = obj.codes + original_device = codes.device + chunk_length = obj.chunk_length + recons = [] + + for i in range_fn(0, codes.shape[-1], chunk_length): + c = codes[..., i : i + chunk_length].to(self.device) + z = self.quantizer.from_codes(c)[0] + r = self.decode(z) + recons.append(r.to(original_device)) + + recons = torch.cat(recons, dim=-1) + recons = AudioSignal(recons, self.sample_rate) + + resample_fn = recons.resample + loudness_fn = recons.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if recons.signal_duration >= 10 * 60 * 60: + resample_fn = recons.ffmpeg_resample + loudness_fn = recons.ffmpeg_loudness + + recons.normalize(obj.input_db) + resample_fn(obj.sample_rate) + recons = recons[..., : obj.original_length] + loudness_fn() + recons.audio_data = recons.audio_data.reshape(-1, obj.channels, obj.original_length) + + self.padding = original_padding + return recons diff --git a/boson_multimodal/audio_processing/descriptaudiocodec/dac/model/dac.py b/boson_multimodal/audio_processing/descriptaudiocodec/dac/model/dac.py new file mode 100644 index 0000000000000000000000000000000000000000..efaed1c25eee7cbb55a96b4f12376b9d26d4a685 --- /dev/null +++ b/boson_multimodal/audio_processing/descriptaudiocodec/dac/model/dac.py @@ -0,0 +1,365 @@ +import math +from typing import List +from typing import Union + +import numpy as np +import torch +from audiotools import AudioSignal +from audiotools.ml import BaseModel +from torch import nn + +from .base import CodecMixin +from dac.nn.layers import Snake1d +from dac.nn.layers import WNConv1d +from dac.nn.layers import WNConvTranspose1d +from dac.nn.quantize import ResidualVectorQuantize + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + return x + y + + +class EncoderBlock(nn.Module): + def __init__(self, dim: int = 16, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + ResidualUnit(dim // 2, dilation=1), + ResidualUnit(dim // 2, dilation=3), + ResidualUnit(dim // 2, dilation=9), + Snake1d(dim // 2), + WNConv1d( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), + ) + + def forward(self, x): + return self.block(x) + + +class Encoder(nn.Module): + def __init__( + self, + d_model: int = 64, + strides: list = [2, 4, 8, 8], + d_latent: int = 256, + ): + super().__init__() + # Create first convolution + self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride in strides: + d_model *= 2 + self.block += [EncoderBlock(d_model, stride=stride)] + + # Create last convolution + self.block += [ + Snake1d(d_model), + WNConv1d(d_model, d_latent, kernel_size=3, padding=1), + ] + + # Wrap black into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + def forward(self, x): + return self.block(x) + + +class DecoderBlock(nn.Module): + def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, out_pad=0): + super().__init__() + self.block = nn.Sequential( + Snake1d(input_dim), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + output_padding=stride % 2, # out_pad, + ), + ResidualUnit(output_dim, dilation=1), + ResidualUnit(output_dim, dilation=3), + ResidualUnit(output_dim, dilation=9), + ) + + def forward(self, x): + return self.block(x) + + +class Decoder(nn.Module): + def __init__( + self, + input_channel, + channels, + rates, + d_out: int = 1, + ): + super().__init__() + + # Add first conv layer + layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] + + # Add upsampling + MRF blocks + for i, stride in enumerate(rates): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + if i == 1: + out_pad = 1 + else: + out_pad = 0 + layers += [DecoderBlock(input_dim, output_dim, stride, out_pad)] + + # Add final conv layer + layers += [ + Snake1d(output_dim), + WNConv1d(output_dim, d_out, kernel_size=7, padding=3), + # nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class DAC(BaseModel, CodecMixin): + def __init__( + self, + encoder_dim: int = 64, + encoder_rates: List[int] = [2, 4, 8, 8], + latent_dim: int = None, + decoder_dim: int = 1536, + decoder_rates: List[int] = [8, 8, 4, 2], + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: bool = False, + sample_rate: int = 44100, + ): + super().__init__() + + self.encoder_dim = encoder_dim + self.encoder_rates = encoder_rates + self.decoder_dim = decoder_dim + self.decoder_rates = decoder_rates + self.sample_rate = sample_rate + + if latent_dim is None: + latent_dim = encoder_dim * (2 ** len(encoder_rates)) + + self.latent_dim = latent_dim + + self.hop_length = np.prod(encoder_rates) + self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim) + + self.n_codebooks = n_codebooks + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.quantizer = ResidualVectorQuantize( + input_dim=latent_dim, + n_codebooks=n_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + + self.decoder = Decoder( + latent_dim, + decoder_dim, + decoder_rates, + ) + self.sample_rate = sample_rate + self.apply(init_weights) + + self.delay = self.get_delay() + + def preprocess(self, audio_data, sample_rate): + if sample_rate is None: + sample_rate = self.sample_rate + assert sample_rate == self.sample_rate + + length = audio_data.shape[-1] + right_pad = math.ceil(length / self.hop_length) * self.hop_length - length + audio_data = nn.functional.pad(audio_data, (0, right_pad)) + + return audio_data + + def encode( + self, + audio_data: torch.Tensor, + n_quantizers: int = None, + ): + """Encode given audio data and return quantized latent codes + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + n_quantizers : int, optional + Number of quantizers to use, by default None + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + """ + z = self.encoder(audio_data) + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers) + return z, codes, latents, commitment_loss, codebook_loss + + def decode(self, z: torch.Tensor): + """Decode given latent codes and return audio data + + Parameters + ---------- + z : Tensor[B x D x T] + Quantized continuous representation of input + length : int, optional + Number of samples in output audio, by default None + + Returns + ------- + dict + A dictionary with the following keys: + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + return self.decoder(z) + + def forward( + self, + audio_data: torch.Tensor, + sample_rate: int = None, + n_quantizers: int = None, + ): + """Model forward pass + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + sample_rate : int, optional + Sample rate of audio data in Hz, by default None + If None, defaults to `self.sample_rate` + n_quantizers : int, optional + Number of quantizers to use, by default None. + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + length = audio_data.shape[-1] + audio_data = self.preprocess(audio_data, sample_rate) + z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers) + + x = self.decode(z) + return { + "audio": x[..., :length], + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + +if __name__ == "__main__": + import numpy as np + from functools import partial + + model = DAC().to("cpu") + + for n, m in model.named_modules(): + o = m.extra_repr() + p = sum([np.prod(p.size()) for p in m.parameters()]) + fn = lambda o, p: o + f" {p / 1e6:<.3f}M params." + setattr(m, "extra_repr", partial(fn, o=o, p=p)) + print(model) + print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) + + length = 88200 * 2 + x = torch.randn(1, 1, length).to(model.device) + x.requires_grad_(True) + x.retain_grad() + + # Make a forward pass + out = model(x)["audio"] + print("Input shape:", x.shape) + print("Output shape:", out.shape) + + # Create gradient variable + grad = torch.zeros_like(out) + grad[:, :, grad.shape[-1] // 2] = 1 + + # Make a backward pass + out.backward(grad) + + # Check non-zero values + gradmap = x.grad.squeeze(0) + gradmap = (gradmap != 0).sum(0) # sum across features + rf = (gradmap != 0).sum() + + print(f"Receptive field: {rf.item()}") + + x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100) + model.decompress(model.compress(x, verbose=True), verbose=True) diff --git a/boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/layers.py b/boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..44fbc2929715e11d843b24195d7042a528969a94 --- /dev/null +++ b/boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/layers.py @@ -0,0 +1,33 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) diff --git a/boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/quantize.py b/boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..8861224cbb49813816dc41b63059faa13d246cc7 --- /dev/null +++ b/boson_multimodal/audio_processing/descriptaudiocodec/dac/nn/quantize.py @@ -0,0 +1,251 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +from dac.nn.layers import WNConv1d + + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = z_e + (z_q - z_e).detach() # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [VectorQuantize(input_dim, codebook_size, codebook_dim[i]) for i in range(n_codebooks)] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual) + + # Create mask to apply quantizer dropout + mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(quantizer_dropout=True) + x = torch.randn(16, 512, 80) + y = rvq(x) + print(y["latents"].shape) diff --git a/boson_multimodal/audio_processing/higgs_audio_tokenizer.py b/boson_multimodal/audio_processing/higgs_audio_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..517cb6ddd089e98eb9d89605265d6ca00dafefea --- /dev/null +++ b/boson_multimodal/audio_processing/higgs_audio_tokenizer.py @@ -0,0 +1,327 @@ +# Based on code from: https://github.com/zhenye234/xcodec +# Licensed under MIT License +# Modifications by BosonAI + +import math +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Union, Sequence +import numpy as np +from transformers import AutoModel +import torchaudio +import json +import librosa +from huggingface_hub import snapshot_download + +from vector_quantize_pytorch import ResidualFSQ +from .descriptaudiocodec.dac.model import dac as dac2 +from .quantization.vq import ResidualVectorQuantizer +from .semantic_module import Encoder, Decoder + + +class EncodedResult: + def __init__(self, audio_codes): + self.audio_codes = audio_codes + + +class HiggsAudioFeatureExtractor(nn.Module): + def __init__(self, sampling_rate=16000): + super().__init__() + self.sampling_rate = sampling_rate + + def forward(self, raw_audio, sampling_rate=16000, return_tensors="pt"): + # Convert from librosa to torch + audio_signal = torch.tensor(raw_audio) + audio_signal = audio_signal.unsqueeze(0) + if len(audio_signal.shape) < 3: + audio_signal = audio_signal.unsqueeze(0) + return {"input_values": audio_signal} + + +class HiggsAudioTokenizer(nn.Module): + def __init__( + self, + n_filters: int = 32, + D: int = 128, + target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6], + ratios: Sequence[int] = [8, 5, 4, 2], # downsampling by 320 + sample_rate: int = 16000, + bins: int = 1024, + n_q: int = 8, + codebook_dim: int = None, + normalize: bool = False, + causal: bool = False, + semantic_techer: str = "hubert_base_general", + last_layer_semantic: bool = True, + merge_mode: str = "concat", + downsample_mode: str = "step_down", + semantic_mode: str = "classic", + vq_scale: int = 1, + semantic_sample_rate: int = None, + device: str = "cuda", + ): + super().__init__() + self.hop_length = np.prod(ratios) + self.semantic_techer = semantic_techer + + self.frame_rate = math.ceil(sample_rate / np.prod(ratios)) # 50 Hz + + self.target_bandwidths = target_bandwidths + self.n_q = n_q + self.sample_rate = sample_rate + self.encoder = dac2.Encoder(64, ratios, D) + + self.decoder_2 = dac2.Decoder(D, 1024, ratios) + self.last_layer_semantic = last_layer_semantic + self.device = device + if semantic_techer == "hubert_base": + self.semantic_model = AutoModel.from_pretrained("facebook/hubert-base-ls960") + self.semantic_sample_rate = 16000 + self.semantic_dim = 768 + self.encoder_semantic_dim = 768 + + elif semantic_techer == "wavlm_base_plus": + self.semantic_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus") + self.semantic_sample_rate = 16000 + self.semantic_dim = 768 + self.encoder_semantic_dim = 768 + + elif semantic_techer == "hubert_base_general": + self.semantic_model = AutoModel.from_pretrained("bosonai/hubert_base", trust_remote_code=True) + self.semantic_sample_rate = 16000 + self.semantic_dim = 768 + self.encoder_semantic_dim = 768 + + # Overwrite semantic model sr to ensure semantic_downsample_factor is an integer + if semantic_sample_rate is not None: + self.semantic_sample_rate = semantic_sample_rate + + self.semantic_model.eval() + + # make the semantic model parameters do not need gradient + for param in self.semantic_model.parameters(): + param.requires_grad = False + + self.semantic_downsample_factor = int(self.hop_length / (self.sample_rate / self.semantic_sample_rate) / 320) + + self.quantizer_dim = int((D + self.encoder_semantic_dim) // vq_scale) + self.encoder_semantic = Encoder(input_channels=self.semantic_dim, encode_channels=self.encoder_semantic_dim) + self.decoder_semantic = Decoder( + code_dim=self.encoder_semantic_dim, output_channels=self.semantic_dim, decode_channels=self.semantic_dim + ) + + # out_D=D+768 + if isinstance(bins, int): # RVQ + self.quantizer = ResidualVectorQuantizer( + dimension=self.quantizer_dim, codebook_dim=codebook_dim, n_q=n_q, bins=bins + ) + self.quantizer_type = "RVQ" + else: # RFSQ + self.quantizer = ResidualFSQ(dim=self.quantizer_dim, levels=bins, num_quantizers=n_q) + self.quantizer_type = "RFSQ" + + self.fc_prior = nn.Linear(D + self.encoder_semantic_dim, self.quantizer_dim) + self.fc_post1 = nn.Linear(self.quantizer_dim, self.encoder_semantic_dim) + self.fc_post2 = nn.Linear(self.quantizer_dim, D) + + self.downsample_mode = downsample_mode + if downsample_mode == "avg": + self.semantic_pooling = nn.AvgPool1d( + kernel_size=self.semantic_downsample_factor, stride=self.semantic_downsample_factor + ) + + self.audio_tokenizer_feature_extractor = HiggsAudioFeatureExtractor(sampling_rate=self.sample_rate) + + @property + def tps(self): + return self.frame_rate + + @property + def sampling_rate(self): + return self.sample_rate + + @property + def num_codebooks(self): + return self.n_q + + @property + def codebook_size(self): + return self.quantizer_dim + + def get_last_layer(self): + return self.decoder.layers[-1].weight + + def calculate_rec_loss(self, rec, target): + target = target / target.norm(dim=-1, keepdim=True) + rec = rec / rec.norm(dim=-1, keepdim=True) + rec_loss = (1 - (target * rec).sum(-1)).mean() + + return rec_loss + + @torch.no_grad() + def get_regress_target(self, x): + x = torchaudio.functional.resample(x, self.sample_rate, self.semantic_sample_rate) + + if ( + self.semantic_techer == "hubert_base" + or self.semantic_techer == "hubert_base_general" + or self.semantic_techer == "wavlm_base_plus" + ): + x = x[:, 0, :] + x = F.pad(x, (160, 160)) + target = self.semantic_model(x, output_hidden_states=True).hidden_states + target = torch.stack(target, dim=1) # .transpose(-1, -2)#.flatten(start_dim=1, end_dim=2) + + # average for all layers + target = target.mean(1) + # target = target[9] + # if self.hop_length > 320: + # target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2) + + elif self.semantic_techer == "w2v_bert2": + target = self.semantic_model(x) + + elif self.semantic_techer.startswith("whisper"): + if self.last_layer_semantic: + target = self.semantic_model(x, avg_layers=False) + else: + target = self.semantic_model(x, avg_layers=True) + + elif self.semantic_techer.startswith("mert_music"): + if self.last_layer_semantic: + target = self.semantic_model(x, avg_layers=False) + else: + target = self.semantic_model(x, avg_layers=True) + + elif self.semantic_techer.startswith("qwen_audio_omni"): + target = self.semantic_model(x) + + if self.downsample_mode == "step_down": + if self.semantic_downsample_factor > 1: + target = target[:, :: self.semantic_downsample_factor, :] + + elif self.downsample_mode == "avg": + target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2) + return target + + def forward(self, x: torch.Tensor, bw: int): + e_semantic_input = self.get_regress_target(x).detach() + + e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2)) + e_acoustic = self.encoder(x) + + e = torch.cat([e_acoustic, e_semantic], dim=1) + + e = self.fc_prior(e.transpose(1, 2)) + + if self.quantizer_type == "RVQ": + e = e.transpose(1, 2) + quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw) + quantized = quantized.transpose(1, 2) + else: + quantized, codes = self.quantizer(e) + commit_loss = torch.tensor(0.0) + + quantized_semantic = self.fc_post1(quantized).transpose(1, 2) + quantized_acoustic = self.fc_post2(quantized).transpose(1, 2) + + o = self.decoder_2(quantized_acoustic) + + o_semantic = self.decoder_semantic(quantized_semantic) + semantic_recon_loss = F.mse_loss(e_semantic_input.transpose(1, 2).detach(), o_semantic) + + return o, commit_loss, semantic_recon_loss, None + + def encode(self, audio_path_or_wv, sr=None, loudness_normalize=False, loudness_threshold=-23.0): + if isinstance(audio_path_or_wv, str): + wv, sr = librosa.load(audio_path_or_wv, mono=True, sr=None) + else: + wv = audio_path_or_wv + assert sr is not None + if loudness_normalize: + import pyloudnorm as pyln + + meter = pyln.Meter(sr) + l = meter.integrated_loudness(wv) + wv = pyln.normalize.loudness(wv, l, loudness_threshold) + if sr != self.sampling_rate: + wv = librosa.resample(wv, orig_sr=sr, target_sr=self.sampling_rate) + if self.audio_tokenizer_feature_extractor is not None: + inputs = self.audio_tokenizer_feature_extractor( + raw_audio=wv, sampling_rate=self.audio_tokenizer_feature_extractor.sampling_rate, return_tensors="pt" + ) + input_values = inputs["input_values"].to(self.device) + else: + input_values = torch.from_numpy(wv).float().unsqueeze(0) + with torch.no_grad(): + encoder_outputs = self._xcodec_encode(input_values) + vq_code = encoder_outputs.audio_codes[0] + return vq_code + + def _xcodec_encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> torch.Tensor: + bw = target_bw + + e_semantic_input = self.get_regress_target(x).detach() + + e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2)) + e_acoustic = self.encoder(x) + + if e_acoustic.shape[2] != e_semantic.shape[2]: + pad_size = 160 * self.semantic_downsample_factor + e_acoustic = self.encoder(F.pad(x[:, 0, :], (pad_size, pad_size)).unsqueeze(0)) + + if e_acoustic.shape[2] != e_semantic.shape[2]: + if e_acoustic.shape[2] > e_semantic.shape[2]: + e_acoustic = e_acoustic[:, :, : e_semantic.shape[2]] + else: + e_semantic = e_semantic[:, :, : e_acoustic.shape[2]] + + e = torch.cat([e_acoustic, e_semantic], dim=1) + + e = self.fc_prior(e.transpose(1, 2)) + + if self.quantizer_type == "RVQ": + e = e.transpose(1, 2) + quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw) + codes = codes.permute(1, 0, 2) + else: + quantized, codes = self.quantizer(e) + codes = codes.permute(0, 2, 1) + + # return codes + return EncodedResult(codes) + + def decode(self, vq_code: torch.Tensor) -> torch.Tensor: + if self.quantizer_type == "RVQ": + vq_code = vq_code.permute(1, 0, 2) + quantized = self.quantizer.decode(vq_code) + quantized = quantized.transpose(1, 2) + else: + vq_code = vq_code.permute(0, 2, 1) + quantized = self.quantizer.get_output_from_indices(vq_code) + quantized_acoustic = self.fc_post2(quantized).transpose(1, 2) + + o = self.decoder_2(quantized_acoustic) + return o.cpu().numpy() + + +def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"): + is_local = os.path.exists(tokenizer_name_or_path) + if not is_local: + tokenizer_path = snapshot_download(tokenizer_name_or_path) + else: + tokenizer_path = tokenizer_name_or_path + config_path = os.path.join(tokenizer_path, "config.json") + model_path = os.path.join(tokenizer_path, "model.pth") + config = json.load(open(config_path)) + model = HiggsAudioTokenizer( + **config, + device=device, + ) + parameter_dict = torch.load(model_path, map_location=device) + model.load_state_dict(parameter_dict, strict=False) + model.to(device) + model.eval() + return model diff --git a/boson_multimodal/audio_processing/quantization/__init__.py b/boson_multimodal/audio_processing/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bfabe52b8cb6f260cdda6137b34df2f4736bd02f --- /dev/null +++ b/boson_multimodal/audio_processing/quantization/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# flake8: noqa +from .vq import QuantizedResult, ResidualVectorQuantizer diff --git a/boson_multimodal/audio_processing/quantization/__pycache__/__init__.cpython-311.pyc b/boson_multimodal/audio_processing/quantization/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e00345232ec121d460272b6a7bd42b8dd8d4aae Binary files /dev/null and b/boson_multimodal/audio_processing/quantization/__pycache__/__init__.cpython-311.pyc differ diff --git a/boson_multimodal/audio_processing/quantization/__pycache__/core_vq_lsx_version.cpython-311.pyc b/boson_multimodal/audio_processing/quantization/__pycache__/core_vq_lsx_version.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3594f12c1880eb80073bcc2511131c31a843113c Binary files /dev/null and b/boson_multimodal/audio_processing/quantization/__pycache__/core_vq_lsx_version.cpython-311.pyc differ diff --git a/boson_multimodal/audio_processing/quantization/__pycache__/ddp_utils.cpython-311.pyc b/boson_multimodal/audio_processing/quantization/__pycache__/ddp_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f894a5715e925f59daaf19e6dc928fd43e77f62b Binary files /dev/null and b/boson_multimodal/audio_processing/quantization/__pycache__/ddp_utils.cpython-311.pyc differ diff --git a/boson_multimodal/audio_processing/quantization/__pycache__/distrib.cpython-311.pyc b/boson_multimodal/audio_processing/quantization/__pycache__/distrib.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b96a3b9951cca55981c2bb43c2c55be4212808d3 Binary files /dev/null and b/boson_multimodal/audio_processing/quantization/__pycache__/distrib.cpython-311.pyc differ diff --git a/boson_multimodal/audio_processing/quantization/__pycache__/vq.cpython-311.pyc b/boson_multimodal/audio_processing/quantization/__pycache__/vq.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64365100f9e07de9dd70e8649d73345141e7ef5a Binary files /dev/null and b/boson_multimodal/audio_processing/quantization/__pycache__/vq.cpython-311.pyc differ diff --git a/boson_multimodal/audio_processing/quantization/ac.py b/boson_multimodal/audio_processing/quantization/ac.py new file mode 100644 index 0000000000000000000000000000000000000000..318d993b610c78a46f3d605e7e2ccbdde4b915ec --- /dev/null +++ b/boson_multimodal/audio_processing/quantization/ac.py @@ -0,0 +1,292 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Arithmetic coder.""" + +import io +import math +import random +import typing as tp +import torch + +from ..binary import BitPacker, BitUnpacker + + +def build_stable_quantized_cdf( + pdf: torch.Tensor, total_range_bits: int, roundoff: float = 1e-8, min_range: int = 2, check: bool = True +) -> torch.Tensor: + """Turn the given PDF into a quantized CDF that splits + [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional + to the PDF. + + Args: + pdf (torch.Tensor): probability distribution, shape should be `[N]`. + total_range_bits (int): see `ArithmeticCoder`, the typical range we expect + during the coding process is `[0, 2 ** total_range_bits - 1]`. + roundoff (float): will round the pdf up to that level to remove difference coming + from e.g. evaluating the Language Model on different architectures. + min_range (int): minimum range width. Should always be at least 2 for numerical + stability. Use this to avoid pathological behavior is a value + that is expected to be rare actually happens in real life. + check (bool): if True, checks that nothing bad happened, can be deactivated for speed. + """ + pdf = pdf.detach() + if roundoff: + pdf = (pdf / roundoff).floor() * roundoff + # interpolate with uniform distribution to achieve desired minimum probability. + total_range = 2**total_range_bits + cardinality = len(pdf) + alpha = min_range * cardinality / total_range + assert alpha <= 1, "you must reduce min_range" + ranges = (((1 - alpha) * total_range) * pdf).floor().long() + ranges += min_range + quantized_cdf = torch.cumsum(ranges, dim=-1) + if min_range < 2: + raise ValueError("min_range must be at least 2.") + if check: + assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1] + if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range: + raise ValueError("You must increase your total_range_bits.") + return quantized_cdf + + +class ArithmeticCoder: + """ArithmeticCoder, + Let us take a distribution `p` over `N` symbols, and assume we have a stream + of random variables `s_t` sampled from `p`. Let us assume that we have a budget + of `B` bits that we can afford to write on device. There are `2**B` possible numbers, + corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single + sequence `(s_t)` by doing the following: + + 1) Initialize the current range to` [0 ** 2 B - 1]`. + 2) For each time step t, split the current range into contiguous chunks, + one for each possible outcome, with size roughly proportional to `p`. + For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks + would be `{[0, 2], [3, 3]}`. + 3) Select the chunk corresponding to `s_t`, and replace the current range with this. + 4) When done encoding all the values, just select any value remaining in the range. + + You will notice that this procedure can fail: for instance if at any point in time + the range is smaller than `N`, then we can no longer assign a non-empty chunk to each + possible outcome. Intuitively, the more likely a value is, the less the range width + will reduce, and the longer we can go on encoding values. This makes sense: for any efficient + coding scheme, likely outcomes would take less bits, and more of them can be coded + with a fixed budget. + + In practice, we do not know `B` ahead of time, but we have a way to inject new bits + when the current range decreases below a given limit (given by `total_range_bits`), without + having to redo all the computations. If we encode mostly likely values, we will seldom + need to inject new bits, but a single rare value can deplete our stock of entropy! + + In this explanation, we assumed that the distribution `p` was constant. In fact, the present + code works for any sequence `(p_t)` possibly different for each timestep. + We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller + the KL between the true distribution and `p_t`, the most efficient the coding will be. + + Args: + fo (IO[bytes]): file-like object to which the bytes will be written to. + total_range_bits (int): the range `M` described above is `2 ** total_range_bits. + Any time the current range width fall under this limit, new bits will + be injected to rescale the initial range. + """ + + def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): + assert total_range_bits <= 30 + self.total_range_bits = total_range_bits + self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time. + self.low: int = 0 + self.high: int = 0 + self.max_bit: int = -1 + self._dbg: tp.List[tp.Any] = [] + self._dbg2: tp.List[tp.Any] = [] + + @property + def delta(self) -> int: + """Return the current range width.""" + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # If self.low and self.high start with the sames bits, + # those won't change anymore as we always just increase the range + # by powers of 2, and we can flush them out to the bit stream. + assert self.high >= self.low, (self.low, self.high) + assert self.high < 2 ** (self.max_bit + 1) + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= b1 << self.max_bit + self.high -= b1 << self.max_bit + assert self.high >= self.low, (self.high, self.low, self.max_bit) + assert self.low >= 0 + self.max_bit -= 1 + self.packer.push(b1) + else: + break + + def push(self, symbol: int, quantized_cdf: torch.Tensor): + """Push the given symbol on the stream, flushing out bits + if possible. + + Args: + symbol (int): symbol to encode with the AC. + quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. + """ + while self.delta < 2**self.total_range_bits: + self.low *= 2 + self.high = self.high * 2 + 1 + self.max_bit += 1 + + range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item() + range_high = quantized_cdf[symbol].item() - 1 + effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits)))) + effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits)))) + assert self.low <= self.high + self.high = self.low + effective_high + self.low = self.low + effective_low + assert self.low <= self.high, (effective_low, effective_high, range_low, range_high) + self._dbg.append((self.low, self.high)) + self._dbg2.append((self.low, self.high)) + outs = self._flush_common_prefix() + assert self.low <= self.high + assert self.max_bit >= -1 + assert self.max_bit <= 61, self.max_bit + return outs + + def flush(self): + """Flush the remaining information to the stream.""" + while self.max_bit >= 0: + b1 = (self.low >> self.max_bit) & 1 + self.packer.push(b1) + self.max_bit -= 1 + self.packer.flush() + + +class ArithmeticDecoder: + """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation. + + Note that this must be called with **exactly** the same parameters and sequence + of quantized cdf as the arithmetic encoder or the wrong values will be decoded. + + If the AC encoder current range is [L, H], with `L` and `H` having the some common + prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream. + For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside + `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained + for a specific sequence of symbols and a binary-search allows us to decode those symbols. + At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols, + and we will need to read new bits from the stream and repeat the process. + + """ + + def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): + self.total_range_bits = total_range_bits + self.low: int = 0 + self.high: int = 0 + self.current: int = 0 + self.max_bit: int = -1 + self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time. + # Following is for debugging + self._dbg: tp.List[tp.Any] = [] + self._dbg2: tp.List[tp.Any] = [] + self._last: tp.Any = None + + @property + def delta(self) -> int: + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # Given the current range [L, H], if both have a common prefix, + # we know we can remove it from our representation to avoid handling large numbers. + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= b1 << self.max_bit + self.high -= b1 << self.max_bit + self.current -= b1 << self.max_bit + assert self.high >= self.low + assert self.low >= 0 + self.max_bit -= 1 + else: + break + + def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]: + """Pull a symbol, reading as many bits from the stream as required. + This returns `None` when the stream has been exhausted. + + Args: + quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. This must be **exatly** + the same cdf as the one used at encoding time. + """ + while self.delta < 2**self.total_range_bits: + bit = self.unpacker.pull() + if bit is None: + return None + self.low *= 2 + self.high = self.high * 2 + 1 + self.current = self.current * 2 + bit + self.max_bit += 1 + + def bin_search(low_idx: int, high_idx: int): + # Binary search is not just for coding interviews :) + if high_idx < low_idx: + raise RuntimeError("Binary search failed") + mid = (low_idx + high_idx) // 2 + range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0 + range_high = quantized_cdf[mid].item() - 1 + effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits)))) + effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits)))) + low = effective_low + self.low + high = effective_high + self.low + if self.current >= low: + if self.current <= high: + return (mid, low, high, self.current) + else: + return bin_search(mid + 1, high_idx) + else: + return bin_search(low_idx, mid - 1) + + self._last = (self.low, self.high, self.current, self.max_bit) + sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1) + self._dbg.append((self.low, self.high, self.current)) + self._flush_common_prefix() + self._dbg2.append((self.low, self.high, self.current)) + + return sym + + +def test(): + torch.manual_seed(1234) + random.seed(1234) + for _ in range(4): + pdfs = [] + cardinality = random.randrange(4000) + steps = random.randrange(100, 500) + fo = io.BytesIO() + encoder = ArithmeticCoder(fo) + symbols = [] + for step in range(steps): + pdf = torch.softmax(torch.randn(cardinality), dim=0) + pdfs.append(pdf) + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + symbol = torch.multinomial(pdf, 1).item() + symbols.append(symbol) + encoder.push(symbol, q_cdf) + encoder.flush() + + fo.seek(0) + decoder = ArithmeticDecoder(fo) + for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)): + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + decoded_symbol = decoder.pull(q_cdf) + assert decoded_symbol == symbol, idx + assert decoder.pull(torch.zeros(1)) is None + + +if __name__ == "__main__": + test() diff --git a/boson_multimodal/audio_processing/quantization/core_vq.py b/boson_multimodal/audio_processing/quantization/core_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..ad368a980582bcbba901f28b568c4bfb8f4099e6 --- /dev/null +++ b/boson_multimodal/audio_processing/quantization/core_vq.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This implementation is inspired from +# https://github.com/lucidrains/vector-quantize-pytorch +# which is released under MIT License. Hereafter, the original license: +# MIT License +# +# Copyright (c) 2020 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Core vector quantization implementation.""" + +import typing as tp + +from einops import rearrange, repeat +import torch +from torch import nn +import torch.nn.functional as F + +from xcodec.quantization.distrib import broadcast_tensors, rank + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if val is not None else d + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10): + dim, dtype = samples.shape[-1], samples.dtype + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d") + dists = -(diffs**2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + @torch.jit.ignore + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True)) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) # get embedding based on index + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + embed_ind = self.quantize(x) # get index based on Euclidean distance + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) + + self.init_embed_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = self.postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + + if self.training: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. + self.expire_codes_(x) + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = x.t() @ embed_onehot + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently supports only euclidean distance. + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + commitment_weight (float): Weight for commitment loss. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: tp.Optional[int] = None, + decay: float = 0.99, + epsilon: float = 1e-5, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + commitment_weight: float = 1.0, + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() + self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self._codebook = EuclideanCodebook( + dim=_codebook_dim, + codebook_size=codebook_size, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + ) + self.codebook_size = codebook_size + + @property + def codebook(self): + return self._codebook.embed + + def encode(self, x): + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): + quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def forward(self, x): + device = x.device + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + + quantize, embed_ind = self._codebook(x) + + if self.training: + quantize = x + (quantize - x).detach() + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. + Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, *, num_quantizers, **kwargs): + super().__init__() + self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)]) + + def forward(self, x, n_q: tp.Optional[int] = None): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + + n_q = n_q or len(self.layers) + + for layer in self.layers[:n_q]: + quantized, indices, loss = layer(residual) + residual = residual - quantized + quantized_out = quantized_out + quantized + + all_indices.append(indices) + all_losses.append(loss) + + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, out_indices, out_losses + + def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: + residual = x + all_indices = [] + n_q = n_q or len(self.layers) + for layer in self.layers[:n_q]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, q_indices: torch.Tensor) -> torch.Tensor: + quantized_out = torch.tensor(0.0, device=q_indices.device) + for i, indices in enumerate(q_indices): + layer = self.layers[i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out diff --git a/boson_multimodal/audio_processing/quantization/core_vq_lsx_version.py b/boson_multimodal/audio_processing/quantization/core_vq_lsx_version.py new file mode 100644 index 0000000000000000000000000000000000000000..d9add3f3016093a744804ae089d525c45d24ad16 --- /dev/null +++ b/boson_multimodal/audio_processing/quantization/core_vq_lsx_version.py @@ -0,0 +1,425 @@ +# Copyright (c) +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# This implementation is inspired from +# https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py and +# https://github.com/clementchadebec/benchmark_VAE/blob/dfa0dcf6c79172df5d27769c09c860c42008baaa/src/pythae/models/vq_vae/vq_vae_utils.py#L81 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This implementation is inspired from +# https://github.com/lucidrains/vector-quantize-pytorch +# which is released under MIT License. Hereafter, the original license: +# MIT License +# +# Copyright (c) 2020 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Core vector quantization implementation.""" + +import typing as tp + +from einops import rearrange +import torch +from torch import nn +import torch.nn.functional as F +import torch.distributed as dist + +from .distrib import broadcast_tensors, is_distributed +from .ddp_utils import SyncFunction + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if val is not None else d + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10, frames_to_use: int = 10_000, batch_size: int = 64): + """ + Memory-efficient K-means clustering. + Args: + samples (tensor): shape [N, D] + num_clusters (int): number of centroids. + num_iters (int): number of iterations. + frames_to_use (int): subsample size from total samples. + batch_size (int): batch size used in distance computation. + Returns: + means: [num_clusters, D] + bins: [num_clusters] (number of points per cluster) + """ + N, D = samples.shape + dtype, device = samples.dtype, samples.device + + if frames_to_use < N: + indices = torch.randperm(N, device=device)[:frames_to_use] + samples = samples[indices] + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + # Store cluster assignments + all_assignments = [] + + for i in range(0, samples.shape[0], batch_size): + batch = samples[i : i + batch_size] # [B, D] + dists = torch.cdist(batch, means, p=2) # [B, C] + assignments = dists.argmin(dim=1) # [B] + all_assignments.append(assignments) + + buckets = torch.cat(all_assignments, dim=0) # [N] + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + # Compute new means + new_means = torch.zeros_like(means) + for i in range(num_clusters): + mask = buckets == i + if mask.any(): + new_means[i] = samples[mask].mean(dim=0) + + means = torch.where(zero_mask[:, None], means, new_means) + + return means, bins + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + # Flag variable to indicate whether the codebook is initialized + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + # Runing EMA cluster size/count: N_i^t in eq. (6) in vqvae paper + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + # Codebook + self.register_buffer("embed", embed) + # EMA codebook: eq. (7) in vqvae paper + self.register_buffer("embed_avg", embed.clone()) + + @torch.jit.ignore + def init_embed_(self, data): + """Initialize codebook. + Args: + data (tensor): [B * T, D]. + """ + if self.inited: + return + + ## NOTE (snippet added by Songxiang Liu): gather data from all gpus + if dist.is_available() and dist.is_initialized(): + # [B * T * world_size, D] + data = SyncFunction.apply(data) + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + ## NOTE (snippet added by Songxiang Liu): gather data from all gpus + if is_distributed(): + # [B * T * world_size, D] + batch_samples = SyncFunction.apply(batch_samples) + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True)) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) # [B, T, D] -> [B*T, D] + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + # shape: [B, T, D] + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) # [B, T, D] -> [B*T, D] + + # Initialize codebook + self.init_embed_(x) + + embed_ind = self.quantize(x) # [B*T,] + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) # [B*T, cb-size] + embed_ind = self.postprocess_emb(embed_ind, shape) # [B, T] + quantize = self.dequantize(embed_ind) # [B, T, D] + + if self.training: + ### Update codebook by EMA + embed_onehot_sum = embed_onehot.sum(0) # [cb-size,] + embed_sum = x.t() @ embed_onehot # [D, cb-size] + if is_distributed(): + dist.all_reduce(embed_onehot_sum) + dist.all_reduce(embed_sum) + # Update ema cluster count N_i^t, eq. (6) in vqvae paper + self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay) + # Update ema embed: eq. (7) in vqvae paper + self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay) + # apply laplace smoothing + n = self.cluster_size.sum() + cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n + # Update ema embed: eq. (8) in vqvae paper + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. + self.expire_codes_(x) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently supports only euclidean distance. + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + commitment_weight (float): Weight for commitment loss. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: tp.Optional[int] = None, + decay: float = 0.99, + epsilon: float = 1e-5, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + commitment_weight: float = 1.0, + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() + self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self._codebook = EuclideanCodebook( + dim=_codebook_dim, + codebook_size=codebook_size, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + ) + self.codebook_size = codebook_size + + @property + def codebook(self): + return self._codebook.embed + + def encode(self, x): + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): + quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def forward(self, x): + device = x.device + x = x.transpose(1, 2).contiguous() # [b d n] -> [b n d] + x = self.project_in(x) + + quantize, embed_ind = self._codebook(x) + + if self.training: + quantize = x + (quantize - x).detach() + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + quantize = self.project_out(quantize) + quantize = quantize.transpose(1, 2).contiguous() # [b n d] -> [b d n] + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. + Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, *, num_quantizers, **kwargs): + super().__init__() + self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)]) + + def forward(self, x, n_q: tp.Optional[int] = None): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + + n_q = n_q or len(self.layers) + + for layer in self.layers[:n_q]: + quantized, indices, loss = layer(residual) + residual = residual - quantized + quantized_out = quantized_out + quantized + + all_indices.append(indices) + all_losses.append(loss) + + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, out_indices, out_losses + + def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: + residual = x + all_indices = [] + n_q = n_q or len(self.layers) + for layer in self.layers[:n_q]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, q_indices: torch.Tensor) -> torch.Tensor: + quantized_out = torch.tensor(0.0, device=q_indices.device) + for i, indices in enumerate(q_indices): + layer = self.layers[i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out diff --git a/boson_multimodal/audio_processing/quantization/ddp_utils.py b/boson_multimodal/audio_processing/quantization/ddp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..990dca85fd518f09e2fcd528e28e7d256f64a15a --- /dev/null +++ b/boson_multimodal/audio_processing/quantization/ddp_utils.py @@ -0,0 +1,197 @@ +import logging +import random +import subprocess +from datetime import datetime + +import numpy as np +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel +from torch.nn.parallel.distributed import _find_tensors +import torch.optim +import torch.utils.data +from packaging import version +from omegaconf import OmegaConf + + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def is_logging_process(): + return not dist.is_initialized() or dist.get_rank() == 0 + + +def get_logger(cfg, name=None): + # log_file_path is used when unit testing + if is_logging_process(): + logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_config, resolve=True)) + return logging.getLogger(name) + + +# from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20 +class SyncFunction(torch.autograd.Function): + @staticmethod + # @torch.no_grad() + def forward(ctx, tensor): + ctx.batch_size = tensor.shape[0] + + gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] + + torch.distributed.all_gather(gathered_tensor, tensor) + gathered_tensor = torch.cat(gathered_tensor, 0) + + return gathered_tensor + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output.clone() + torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False) + + idx_from = torch.distributed.get_rank() * ctx.batch_size + idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size + return grad_input[idx_from:idx_to] + + +def get_timestamp(): + return datetime.now().strftime("%y%m%d-%H%M%S") + + +def get_commit_hash(): + message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]) + return message.strip().decode("utf-8") + + +class DDP(DistributedDataParallel): + """ + Override the forward call in lightning so it goes to training and validation step respectively + """ + + def forward(self, *inputs, **kwargs): # pragma: no cover + if version.parse(torch.__version__[:6]) < version.parse("1.11"): + self._sync_params() + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + assert len(self.device_ids) == 1 + if self.module.training: + output = self.module.training_step(*inputs[0], **kwargs[0]) + elif self.module.testing: + output = self.module.test_step(*inputs[0], **kwargs[0]) + else: + output = self.module.validation_step(*inputs[0], **kwargs[0]) + if torch.is_grad_enabled(): + # We'll return the output object verbatim since it is a freeform + # object. We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. + if self.find_unused_parameters: + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + from torch.nn.parallel.distributed import ( + logging, + Join, + _DDPSink, + _tree_flatten_with_rref, + _tree_unflatten_with_rref, + ) + + with torch.autograd.profiler.record_function("DistributedDataParallel.forward"): + if torch.is_grad_enabled() and self.require_backward_grad_sync: + self.logger.set_runtime_stats_and_log() + self.num_iterations += 1 + self.reducer.prepare_for_forward() + + # Notify the join context that this process has not joined, if + # needed + work = Join.notify_join_context(self) + if work: + self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size) + + # Calling _rebuild_buckets before forward compuation, + # It may allocate new buckets before deallocating old buckets + # inside _rebuild_buckets. To save peak memory usage, + # call _rebuild_buckets before the peak memory usage increases + # during forward computation. + # This should be called only once during whole training period. + if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): + logging.info("Reducer buckets have been rebuilt in this iteration.") + self._has_rebuilt_buckets = True + + # sync params according to location (before/after forward) user + # specified as part of hook, if hook was specified. + buffer_hook_registered = hasattr(self, "buffer_hook") + if self._check_sync_bufs_pre_fwd(): + self._sync_buffers() + + if self._join_config.enable: + # Notify joined ranks whether they should sync in backwards pass or not. + self._check_global_requires_backward_grad_sync(is_joined_rank=False) + + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if self.module.training: + output = self.module.training_step(*inputs[0], **kwargs[0]) + elif self.module.testing: + output = self.module.test_step(*inputs[0], **kwargs[0]) + else: + output = self.module.validation_step(*inputs[0], **kwargs[0]) + + # sync params according to location (before/after forward) user + # specified as part of hook, if hook was specified. + if self._check_sync_bufs_post_fwd(): + self._sync_buffers() + + if torch.is_grad_enabled() and self.require_backward_grad_sync: + self.require_forward_param_sync = True + # We'll return the output object verbatim since it is a freeform + # object. We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. + if self.find_unused_parameters and not self.static_graph: + # Do not need to populate this for static graph. + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + self.require_forward_param_sync = False + + # TODO: DDPSink is currently enabled for unused parameter detection and + # static graph training for first iteration. + if (self.find_unused_parameters and not self.static_graph) or ( + self.static_graph and self.num_iterations == 1 + ): + state_dict = { + "static_graph": self.static_graph, + "num_iterations": self.num_iterations, + } + + output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output) + output_placeholders = [None for _ in range(len(output_tensor_list))] + # Do not touch tensors that have no grad_fn, which can cause issues + # such as https://github.com/pytorch/pytorch/issues/60733 + for i, output in enumerate(output_tensor_list): + if torch.is_tensor(output) and output.grad_fn is None: + output_placeholders[i] = output + + # When find_unused_parameters=True, makes tensors which require grad + # run through the DDPSink backward pass. When not all outputs are + # used in loss, this makes those corresponding tensors receive + # undefined gradient which the reducer then handles to ensure + # param.grad field is not touched and we don't error out. + passthrough_tensor_list = _DDPSink.apply( + self.reducer, + state_dict, + *output_tensor_list, + ) + for i in range(len(output_placeholders)): + if output_placeholders[i] is None: + output_placeholders[i] = passthrough_tensor_list[i] + + # Reconstruct output data structure. + output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref) + return output diff --git a/boson_multimodal/audio_processing/quantization/distrib.py b/boson_multimodal/audio_processing/quantization/distrib.py new file mode 100644 index 0000000000000000000000000000000000000000..cabf8f8a24eb710ab0eb83ce29ba054b7c11ccf3 --- /dev/null +++ b/boson_multimodal/audio_processing/quantization/distrib.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Torch distributed utilities.""" + +import typing as tp + +import torch + + +def rank(): + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + else: + return 0 + + +def world_size(): + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + else: + return 1 + + +def is_distributed(): + return world_size() > 1 + + +def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): + if is_distributed(): + return torch.distributed.all_reduce(tensor, op) + + +def _is_complex_or_float(tensor): + return torch.is_floating_point(tensor) or torch.is_complex(tensor) + + +def _check_number_of_params(params: tp.List[torch.Tensor]): + # utility function to check that the number of params in all workers is the same, + # and thus avoid a deadlock with distributed all reduce. + if not is_distributed() or not params: + return + # print('params[0].device ', params[0].device) + tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long) + all_reduce(tensor) + if tensor.item() != len(params) * world_size(): + # If not all the workers have the same number, for at least one of them, + # this inequality will be verified. + raise RuntimeError( + f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one." + ) + + +def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0): + """Broadcast the tensors from the given parameters to all workers. + This can be used to ensure that all workers have the same model to start with. + """ + if not is_distributed(): + return + tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] + _check_number_of_params(tensors) + handles = [] + for tensor in tensors: + handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True) + handles.append(handle) + for handle in handles: + handle.wait() + + +def sync_buffer(buffers, average=True): + """ + Sync grad for buffers. If average is False, broadcast instead of averaging. + """ + if not is_distributed(): + return + handles = [] + for buffer in buffers: + if torch.is_floating_point(buffer.data): + if average: + handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True) + else: + handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True) + handles.append((buffer, handle)) + for buffer, handle in handles: + handle.wait() + if average: + buffer.data /= world_size + + +def sync_grad(params): + """ + Simpler alternative to DistributedDataParallel, that doesn't rely + on any black magic. For simple models it can also be as fast. + Just call this on your model parameters after the call to backward! + """ + if not is_distributed(): + return + handles = [] + for p in params: + if p.grad is not None: + handle = torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True) + handles.append((p, handle)) + for p, handle in handles: + handle.wait() + p.grad.data /= world_size() + + +def average_metrics(metrics: tp.Dict[str, float], count=1.0): + """Average a dictionary of metrics across all workers, using the optional + `count` as unormalized weight. + """ + if not is_distributed(): + return metrics + keys, values = zip(*metrics.items()) + device = "cuda" if torch.cuda.is_available() else "cpu" + tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) + tensor *= count + all_reduce(tensor) + averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() + return dict(zip(keys, averaged)) diff --git a/boson_multimodal/audio_processing/quantization/vq.py b/boson_multimodal/audio_processing/quantization/vq.py new file mode 100644 index 0000000000000000000000000000000000000000..dac26ba2a3bc2c97d6178fa33c629f324980d5a0 --- /dev/null +++ b/boson_multimodal/audio_processing/quantization/vq.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Residual vector quantizer implementation.""" + +from dataclasses import dataclass, field +import math +import typing as tp + +import torch +from torch import nn + +# from .core_vq import ResidualVectorQuantization +from .core_vq_lsx_version import ResidualVectorQuantization + + +@dataclass +class QuantizedResult: + quantized: torch.Tensor + codes: torch.Tensor + bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. + penalty: tp.Optional[torch.Tensor] = None + metrics: dict = field(default_factory=dict) + + +class ResidualVectorQuantizer(nn.Module): + """Residual Vector Quantizer. + Args: + dimension (int): Dimension of the codebooks. + n_q (int): Number of residual vector quantizers used. + bins (int): Codebook size. + decay (float): Decay for exponential moving average over the codebooks. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + dimension: int = 256, + codebook_dim: int = None, + n_q: int = 8, + bins: int = 1024, + decay: float = 0.99, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.n_q = n_q + self.dimension = dimension + self.codebook_dim = codebook_dim + self.bins = bins + self.decay = decay + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.threshold_ema_dead_code = threshold_ema_dead_code + self.vq = ResidualVectorQuantization( + dim=self.dimension, + codebook_dim=self.codebook_dim, + codebook_size=self.bins, + num_quantizers=self.n_q, + decay=self.decay, + kmeans_init=self.kmeans_init, + kmeans_iters=self.kmeans_iters, + threshold_ema_dead_code=self.threshold_ema_dead_code, + ) + + def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None): # -> QuantizedResult: + """Residual vector quantization on the given input tensor. + Args: + x (torch.Tensor): Input tensor. + sample_rate (int): Sample rate of the input tensor. + bandwidth (float): Target bandwidth. + Returns: + QuantizedResult: + The quantized (or approximately quantized) representation with + the associated bandwidth and any penalty term for the loss. + """ + bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) + n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) + quantized, codes, commit_loss = self.vq(x, n_q=n_q) + bw = torch.tensor(n_q * bw_per_q).to(x) + return quantized, codes, bw, torch.mean(commit_loss) + # return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) + + def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int: + """Return n_q based on specified target bandwidth.""" + bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) + n_q = self.n_q + if bandwidth and bandwidth > 0.0: + n_q = int(max(1, math.floor(bandwidth / bw_per_q))) + return n_q + + def get_bandwidth_per_quantizer(self, sample_rate: int): + """Return bandwidth per quantizer for a given input sample rate.""" + return math.log2(self.bins) * sample_rate / 1000 + + def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor: + """Encode a given input tensor with the specified sample rate at the given bandwidth. + The RVQ encode method sets the appropriate number of quantizer to use + and returns indices for each quantizer. + """ + n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) + codes = self.vq.encode(x, n_q=n_q) + return codes + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to the quantized representation.""" + quantized = self.vq.decode(codes) + return quantized diff --git a/boson_multimodal/audio_processing/semantic_module.py b/boson_multimodal/audio_processing/semantic_module.py new file mode 100644 index 0000000000000000000000000000000000000000..f75efaf74e748cd733782133a65c2d0b6638b3a6 --- /dev/null +++ b/boson_multimodal/audio_processing/semantic_module.py @@ -0,0 +1,282 @@ +# Based on code from: https://github.com/zhenye234/xcodec +# Licensed under MIT License +# Modifications by BosonAI + +import torch +import torch.nn as nn + + +class Conv1d1x1(nn.Conv1d): + """1x1 Conv1d.""" + + def __init__(self, in_channels, out_channels, bias=True): + super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, bias=bias) + + +class Conv1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = -1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + if padding < 0: + padding = (kernel_size - 1) // 2 * dilation + self.dilation = dilation + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x): + """ + Args: + x (Tensor): Float tensor variable with the shape (B, C, T). + Returns: + Tensor: Float tensor variable with the shape (B, C, T). + """ + x = self.conv(x) + return x + + +class ResidualUnit(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + dilation=1, + bias=False, + nonlinear_activation="ELU", + nonlinear_activation_params={}, + ): + super().__init__() + self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params) + self.conv1 = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + dilation=dilation, + bias=bias, + ) + self.conv2 = Conv1d1x1(out_channels, out_channels, bias) + + def forward(self, x): + y = self.conv1(self.activation(x)) + y = self.conv2(self.activation(y)) + return x + y + + +class ConvTranspose1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding=-1, + output_padding=-1, + groups=1, + bias=True, + ): + super().__init__() + if padding < 0: + padding = (stride + 1) // 2 + if output_padding < 0: + output_padding = 1 if stride % 2 else 0 + self.deconv = nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + ) + + def forward(self, x): + """ + Args: + x (Tensor): Float tensor variable with the shape (B, C, T). + Returns: + Tensor: Float tensor variable with the shape (B, C', T'). + """ + x = self.deconv(x) + return x + + +class EncoderBlock(nn.Module): + def __init__( + self, in_channels: int, out_channels: int, stride: int, dilations=(1, 1), unit_kernel_size=3, bias=True + ): + super().__init__() + self.res_units = torch.nn.ModuleList() + for dilation in dilations: + self.res_units += [ResidualUnit(in_channels, in_channels, kernel_size=unit_kernel_size, dilation=dilation)] + self.num_res = len(self.res_units) + + self.conv = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3 if stride == 1 else (2 * stride), # special case: stride=1, do not use kernel=2 + stride=stride, + bias=bias, + ) + + def forward(self, x): + for idx in range(self.num_res): + x = self.res_units[idx](x) + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + input_channels: int, + encode_channels: int, + channel_ratios=(1, 1), + strides=(1, 1), + kernel_size=3, + bias=True, + block_dilations=(1, 1), + unit_kernel_size=3, + ): + super().__init__() + assert len(channel_ratios) == len(strides) + + self.conv = Conv1d( + in_channels=input_channels, out_channels=encode_channels, kernel_size=kernel_size, stride=1, bias=False + ) + self.conv_blocks = torch.nn.ModuleList() + in_channels = encode_channels + for idx, stride in enumerate(strides): + out_channels = int(encode_channels * channel_ratios[idx]) # could be float + self.conv_blocks += [ + EncoderBlock( + in_channels, + out_channels, + stride, + dilations=block_dilations, + unit_kernel_size=unit_kernel_size, + bias=bias, + ) + ] + in_channels = out_channels + self.num_blocks = len(self.conv_blocks) + self.out_channels = out_channels + + def forward(self, x): + x = self.conv(x) + for i in range(self.num_blocks): + x = self.conv_blocks[i](x) + return x + + +class DecoderBlock(nn.Module): + """Decoder block (no up-sampling)""" + + def __init__( + self, in_channels: int, out_channels: int, stride: int, dilations=(1, 1), unit_kernel_size=3, bias=True + ): + super().__init__() + + if stride == 1: + self.conv = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, # fix kernel=3 when stride=1 for unchanged shape + stride=stride, + bias=bias, + ) + else: + self.conv = ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(2 * stride), + stride=stride, + bias=bias, + ) + + self.res_units = torch.nn.ModuleList() + for idx, dilation in enumerate(dilations): + self.res_units += [ + ResidualUnit(out_channels, out_channels, kernel_size=unit_kernel_size, dilation=dilation) + ] + self.num_res = len(self.res_units) + + def forward(self, x): + x = self.conv(x) + for idx in range(self.num_res): + x = self.res_units[idx](x) + return x + + +class Decoder(nn.Module): + def __init__( + self, + code_dim: int, + output_channels: int, + decode_channels: int, + channel_ratios=(1, 1), + strides=(1, 1), + kernel_size=3, + bias=True, + block_dilations=(1, 1), + unit_kernel_size=3, + ): + super().__init__() + assert len(channel_ratios) == len(strides) + + self.conv1 = Conv1d( + in_channels=code_dim, + out_channels=int(decode_channels * channel_ratios[0]), + kernel_size=kernel_size, + stride=1, + bias=False, + ) + + self.conv_blocks = torch.nn.ModuleList() + for idx, stride in enumerate(strides): + in_channels = int(decode_channels * channel_ratios[idx]) + if idx < (len(channel_ratios) - 1): + out_channels = int(decode_channels * channel_ratios[idx + 1]) + else: + out_channels = decode_channels + self.conv_blocks += [ + DecoderBlock( + in_channels, + out_channels, + stride, + dilations=block_dilations, + unit_kernel_size=unit_kernel_size, + bias=bias, + ) + ] + self.num_blocks = len(self.conv_blocks) + + self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False) + + def forward(self, z): + x = self.conv1(z) + for i in range(self.num_blocks): + x = self.conv_blocks[i](x) + x = self.conv2(x) + return x diff --git a/boson_multimodal/constants.py b/boson_multimodal/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..addf77d2512980bfbc84b389ba830924f1f2ba33 --- /dev/null +++ b/boson_multimodal/constants.py @@ -0,0 +1,3 @@ +AUDIO_IN_TOKEN = "<|AUDIO|>" +AUDIO_OUT_TOKEN = "<|AUDIO_OUT|>" +EOS_TOKEN = "<|end_of_text|>" diff --git a/boson_multimodal/data_collator/__init__.py b/boson_multimodal/data_collator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/boson_multimodal/data_collator/__pycache__/__init__.cpython-311.pyc b/boson_multimodal/data_collator/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d68a71ed7ad09d44b9ee9cd83af5a1fc6c104bd5 Binary files /dev/null and b/boson_multimodal/data_collator/__pycache__/__init__.cpython-311.pyc differ diff --git a/boson_multimodal/data_collator/__pycache__/higgs_audio_collator.cpython-311.pyc b/boson_multimodal/data_collator/__pycache__/higgs_audio_collator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1980fd6c2a1b37b649f78d82979f27c95534770 Binary files /dev/null and b/boson_multimodal/data_collator/__pycache__/higgs_audio_collator.cpython-311.pyc differ diff --git a/boson_multimodal/data_collator/higgs_audio_collator.py b/boson_multimodal/data_collator/higgs_audio_collator.py new file mode 100644 index 0000000000000000000000000000000000000000..b3228bbabbe4d63044522191ff12d16bad437aec --- /dev/null +++ b/boson_multimodal/data_collator/higgs_audio_collator.py @@ -0,0 +1,509 @@ +import librosa +import torch +import torch.nn.functional as F +import math +from typing import List, Tuple + +from dataclasses import dataclass +from typing import List, Optional +from transformers.models.whisper.processing_whisper import WhisperProcessor + +from ..dataset.chatml_dataset import ChatMLDatasetSample +from ..model.higgs_audio.utils import build_delay_pattern_mask + + +def _ceil_to_nearest(n, round_to): + return (n + round_to - 1) // round_to * round_to + + +def _ceil_to_next_power_of_two(self, x): + return 1 if x == 0 else 2 ** (x - 1).bit_length() + + +@dataclass +class HiggsAudioBatchInput: + input_ids: torch.LongTensor # shape (bsz, seq_len). + attention_mask: torch.Tensor # shape (bsz, seq_len). + audio_features: Optional[torch.Tensor] # shape (num_audio_in, feature_dim, max_mel_seq_len). + audio_feature_attention_mask: Optional[torch.Tensor] # shape (num_audio_in, max_mel_seq_len). + audio_out_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_out_total_length) + audio_out_ids_start: Optional[torch.LongTensor] # shape (num_audio_out,) + # The audio_out_ids_start_group_loc has the same length as audio_out_ids_start. It is used to recover group location in a batch for an audio segment + # Currently, we concatenante audio segments along dim 0 to handle variadic audio segment length. However, in the alignment stage, we need the location information + # For example, + # audio_out_ids_start = [0, 2, 4, 8]; and the first two audio segments come from the same sample in a batch, and other two come from different samples. + # This is a batch of 3 samples, then we will have the group location as: + # audio_out_ids_start_group_loc = [0, 0, 1, 2] + audio_out_ids_start_group_loc: Optional[ + torch.LongTensor + ] # shape (num_audio_out,), specify which a sample's group location in the batch + audio_in_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_in_total_length) + audio_in_ids_start: Optional[torch.LongTensor] # shape (num_audio_in,) + label_ids: Optional[torch.LongTensor] # shape (bsz, seq_len) + label_audio_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_out_total_length) + reward: Optional[float] = None + + +class HiggsAudioSampleCollator: + """Sample collator for Higgs-Audio model. + + Args: + whisper_processor (WhisperProcessor): The whisper processor. + audio_in_token_id (int): The token id for audio-in. + audio_out_token_id (int): The token id for audio-out. + pad_token_id (int): The token id for padding. + audio_stream_bos_id (int): The token id for audio-stream beginning of sentence. + audio_stream_eos_id (int): The token id for audio-stream end of sentence. + round_to (int): The round-to value. + pad_left (bool): Whether to pad left. + return_audio_in_tokens (bool): Whether to return audio-in tokens. + use_delay_pattern (bool): Whether to use delay pattern. + disable_audio_codes_transform (bool): Whether to add bos and eos tokens to audio codes. + chunk_size_seconds (int): The chunk size in seconds. + add_new_bos_eos_for_long_chunk (bool): Whether to add new bos and eos tokens for long chunks. + mask_audio_out_token_label (bool): Whether to always mask the label associated with <|AUDIO_OUT|> token. Since we will always have `<|AUDIO_OUT|>` after `<|audio_bos|>`, we can safely mask <|AUDIO_OUT|>. + + """ + + def __init__( + self, + whisper_processor: WhisperProcessor, + audio_in_token_id, + audio_out_token_id, + pad_token_id, + audio_stream_bos_id, + audio_stream_eos_id, + round_to=8, + pad_left=False, + encode_whisper_embed=True, + return_audio_in_tokens=True, + audio_num_codebooks=None, + use_delay_pattern=False, + disable_audio_codes_transform=False, + chunk_size_seconds=30, # Maximum duration for each chunk + add_new_bos_eos_for_long_chunk=True, + mask_audio_out_token_label=True, + ): + self.whisper_processor = whisper_processor + self.round_to = round_to + self.pad_left = pad_left + self.audio_in_token_id = audio_in_token_id + self.audio_out_token_id = audio_out_token_id + self.audio_stream_bos_id = audio_stream_bos_id + self.audio_stream_eos_id = audio_stream_eos_id + self.pad_token_id = pad_token_id + self.encode_whisper_embed = encode_whisper_embed + self.return_audio_in_tokens = return_audio_in_tokens + self.audio_num_codebooks = audio_num_codebooks + self.use_delay_pattern = use_delay_pattern + if encode_whisper_embed: + self.chunk_size_seconds = chunk_size_seconds + self.chunk_size_samples = int(chunk_size_seconds * whisper_processor.feature_extractor.sampling_rate) + else: + self.chunk_size_seconds = None + self.chunk_size_samples = None + self.disable_audio_codes_transform = disable_audio_codes_transform + self.add_new_bos_eos_for_long_chunk = add_new_bos_eos_for_long_chunk + self.mask_audio_out_token_label = mask_audio_out_token_label + + def _process_and_duplicate_audio_tokens( + self, input_ids: torch.Tensor, audio_idx: int, wv: torch.Tensor, sr: int, labels: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor, int]: + """Process long audio and duplicate corresponding audio tokens. + + Args: + input_ids: Input token ids + audio_idx: Index of the audio token in the sequence + wv: Audio waveform + sr: Sample rate + labels: Optional label ids to be duplicated alongside input ids + + Returns: + Tuple of: + - New input ids with duplicated audio tokens + - New label ids (if labels were provided) or None + - Number of chunks created + """ + # Calculate number of chunks needed + total_samples = len(wv) + num_chunks = math.ceil(total_samples / self.chunk_size_samples) + + if num_chunks <= 1: + return input_ids, labels, 1 + + # Get the three tokens: <|audio_bos|><|AUDIO|><|audio_eos|> + audio_token_seq = input_ids[audio_idx - 1 : audio_idx + 2] + # Duplicate sequence for each chunk + duplicated_sequence = audio_token_seq.repeat(num_chunks) + + # Create new input_ids with duplicated tokens + new_input_ids = torch.cat([input_ids[: audio_idx - 1], duplicated_sequence, input_ids[audio_idx + 2 :]]) + + # If labels are provided, duplicate them as well + new_labels = None + if labels is not None: + label_seq = labels[audio_idx - 1 : audio_idx + 2] + duplicated_labels = label_seq.repeat(num_chunks) + new_labels = torch.cat([labels[: audio_idx - 1], duplicated_labels, labels[audio_idx + 2 :]]) + + return new_input_ids, new_labels, num_chunks + + def __call__(self, batch: List[ChatMLDatasetSample]): + """Collate the input data with support for long audio processing.""" + + label_ids = None + label_audio_ids = None + if all([ele.label_ids is None for ele in batch]): + return_labels = False + else: + return_labels = True + + if self.encode_whisper_embed: + # Process each sample in the batch to handle long audio + # TODO(?) The implementation here can be optimized. + processed_batch = [] + for i in range(len(batch)): + sample = batch[i] + audio_in_mask = sample.input_ids == self.audio_in_token_id + audio_in_indices = torch.where(audio_in_mask)[0] + audio_out_mask = sample.input_ids == self.audio_out_token_id + + # Process each audio token and duplicate if needed + modified_input_ids = sample.input_ids + modified_labels = sample.label_ids if return_labels else None + modified_waveforms_concat = [] + modified_waveforms_start = [] + modified_sample_rate = [] + offset = 0 # Track position changes from duplicating tokens + curr_wv_offset = 0 + + # Process input audio tokens + for idx, audio_idx in enumerate(audio_in_indices): + # Get the audio for this token + wv, sr = sample.get_wv(idx) # Use idx since we want the original audio index + if sr != self.whisper_processor.feature_extractor.sampling_rate: + resampled_wv = librosa.resample( + wv.cpu().numpy(), + orig_sr=sr, + target_sr=self.whisper_processor.feature_extractor.sampling_rate, + ) + else: + resampled_wv = wv.cpu().numpy() + wv = torch.tensor(resampled_wv, device=wv.device) + sr = self.whisper_processor.feature_extractor.sampling_rate + + # Process and duplicate tokens if necessary + token_pos = audio_idx + offset + modified_input_ids, modified_labels, num_chunks = self._process_and_duplicate_audio_tokens( + modified_input_ids, token_pos, wv, sr, modified_labels + ) + + # Update audio data + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * self.chunk_size_samples + chunk_end = min((chunk_idx + 1) * self.chunk_size_samples, len(wv)) + chunk_wv = wv[chunk_start:chunk_end] + modified_waveforms_concat.append(chunk_wv) + modified_waveforms_start.append(curr_wv_offset) + curr_wv_offset += len(chunk_wv) + modified_sample_rate.append(sr) + + # Update offset for next iteration + offset += (num_chunks - 1) * 3 # Each new chunk adds 3 more tokens + + # Create new sample with modified tokens and audio data + processed_sample = ChatMLDatasetSample( + input_ids=modified_input_ids, + label_ids=modified_labels if return_labels else sample.label_ids, + audio_ids_concat=sample.audio_ids_concat, + audio_ids_start=sample.audio_ids_start, + audio_waveforms_concat=torch.cat(modified_waveforms_concat) + if modified_waveforms_concat + else sample.audio_waveforms_concat, + audio_waveforms_start=torch.tensor(modified_waveforms_start, dtype=torch.long) + if modified_waveforms_start + else sample.audio_waveforms_start, + audio_sample_rate=torch.tensor(modified_sample_rate) + if modified_sample_rate + else sample.audio_sample_rate, + audio_speaker_indices=torch.tensor([]), + # FIXME(sxjscience): The logic here is not correct for audio_label_ids_concat. + audio_label_ids_concat=sample.audio_label_ids_concat, + ) + # audio_in_chunk_len = len(torch.where(modified_input_ids == self.audio_in_token_id)[0]) + # assert audio_in_chunk_len == processed_sample.num_audios(), f"Mismatch: audio_in_chunk_len={audio_in_chunk_len}, processed_sample.num_audios()={processed_sample.num_audios()}" + processed_batch.append(processed_sample) + else: + processed_batch = batch + + # Get the max sequence length based on processed batch + max_seq_length = _ceil_to_nearest(max([len(sample.input_ids) for sample in processed_batch]), self.round_to) + + # Get the ids for audio-in and audio-out for each batch + audio_in_wv_l = [] + audio_in_ids_l = [] + audio_out_ids_l = [] + audio_out_ids_group_loc_l = [] + audio_in_label_ids_l = None + audio_out_label_ids_l = None + reward_l = [] + + if return_labels: + audio_out_no_train_flag = [] # Whether the audio-out data should be trained on or not. + + # Process the audio inputs and outputs + for i in range(len(processed_batch)): + audio_in_mask = processed_batch[i].input_ids == self.audio_in_token_id + audio_out_mask = processed_batch[i].input_ids == self.audio_out_token_id + audio_ids = torch.ones_like(processed_batch[i].input_ids) + audio_ids[audio_in_mask ^ audio_out_mask] = torch.cumsum(audio_ids[audio_in_mask ^ audio_out_mask], 0) - 1 + audio_in_ids = audio_ids[audio_in_mask] + audio_out_ids = audio_ids[audio_out_mask] + + if return_labels: + audio_out_no_train_flag.append(processed_batch[i].label_ids[audio_out_mask] < 0) + if self.mask_audio_out_token_label: + processed_batch[i].label_ids[audio_out_mask] = -100 + + # Process audio inputs + if self.return_audio_in_tokens: + audio_in_ids_l.extend( + [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_in_ids] + ) + if processed_batch[i].audio_label_ids_concat is not None: + if audio_in_label_ids_l is None: + audio_in_label_ids_l = [] + audio_in_label_ids_l.extend( + [ + processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :] + for idx in audio_in_ids + ] + ) + + audio_out_ids_l.extend( + [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids] + ) + audio_out_ids_group_loc_l.append(i) + if processed_batch[i].reward is not None: + reward_l.append(processed_batch[i].reward) + + if processed_batch[i].audio_label_ids_concat is not None: + if audio_out_label_ids_l is None: + audio_out_label_ids_l = [] + audio_out_label_ids_l.extend( + [ + processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :] + for idx in audio_out_ids + ] + ) + + if self.encode_whisper_embed: + for idx in audio_in_ids: + wv, sr = processed_batch[i].get_wv(idx) + resampled_wv = wv.cpu().numpy() + # Split long audio into chunks + total_samples = len(resampled_wv) + for chunk_start in range(0, total_samples, self.chunk_size_samples): + chunk_end = min(chunk_start + self.chunk_size_samples, total_samples) + chunk = resampled_wv[chunk_start:chunk_end] + audio_in_wv_l.append(chunk) + # assert len(audio_in_wv_l) == processed_batch[i].num_audios(), \ + # f"Assertion failed: Mismatch in number of audios. " \ + # f"Expected {processed_batch[i].num_audios()}, but got {len(audio_in_wv_l)} at index {i}." + + if return_labels: + audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0) + + # Process all audio features + if len(audio_in_wv_l) > 0: + feature_ret = self.whisper_processor.feature_extractor( + audio_in_wv_l, + sampling_rate=self.whisper_processor.feature_extractor.sampling_rate, + return_attention_mask=True, + padding="max_length", + ) + audio_features = torch.from_numpy(feature_ret["input_features"]) + audio_feature_attention_mask = torch.from_numpy(feature_ret["attention_mask"]) + else: + if self.encode_whisper_embed: + audio_features = torch.zeros( + ( + 0, + self.whisper_processor.feature_extractor.feature_size, + self.whisper_processor.feature_extractor.nb_max_frames, + ), + dtype=torch.float32, + ) + audio_feature_attention_mask = torch.zeros( + (0, self.whisper_processor.feature_extractor.nb_max_frames), dtype=torch.int32 + ) + else: + audio_features = None + audio_feature_attention_mask = None + + # Process audio input tokens + if len(audio_in_ids_l) > 0: + # Append audio-stream-bos and eos tokens + new_audio_in_ids_l = [] + for ele in audio_in_ids_l: + if self.disable_audio_codes_transform: + # Do not add audio-stream-bos or eos tokens. + # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer. + audio_codes = ele + else: + audio_codes = torch.cat( + [ + torch.full((ele.shape[0], 1), self.audio_stream_bos_id, dtype=torch.long), + ele, + torch.full((ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long), + ], + dim=1, + ) + if self.use_delay_pattern: + audio_codes = build_delay_pattern_mask( + audio_codes.unsqueeze(0), + bos_token_id=self.audio_stream_bos_id, + pad_token_id=self.audio_stream_eos_id, + )[0].squeeze(0) + new_audio_in_ids_l.append(audio_codes) + audio_in_ids = torch.cat(new_audio_in_ids_l, dim=1).long() + audio_in_ids_start = torch.cumsum( + torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_in_ids_l[:-1]]), dim=0 + ) + else: + audio_in_ids = torch.zeros((0, 0), dtype=torch.long) + audio_in_ids_start = torch.zeros(0, dtype=torch.long) + + # Process audio output tokens + audio_out_ids_start_group_loc = None + if len(audio_out_ids_l) > 0: + new_audio_out_ids_l = [] + label_audio_ids_l = [] + for idx, ele in enumerate(audio_out_ids_l): + if self.disable_audio_codes_transform: + # Do not add audio-stream-bos or eos tokens. + # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer. + audio_codes = ele + if return_labels: + label_audio_ids = audio_out_label_ids_l[idx] + else: + audio_codes = torch.cat( + [ + torch.full((ele.shape[0], 1), self.audio_stream_bos_id, dtype=torch.long), + ele, + torch.full((ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long), + ], + dim=1, + ) + if return_labels: + label_audio_ids = torch.cat( + [ + torch.full((ele.shape[0], 1), -100, dtype=torch.long), + ele, + torch.full((ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long), + ], + dim=1, + ) + if self.use_delay_pattern: + audio_codes = build_delay_pattern_mask( + audio_codes.unsqueeze(0), + bos_token_id=self.audio_stream_bos_id, + pad_token_id=self.audio_stream_eos_id, + )[0].squeeze(0) + if return_labels: + label_audio_ids = build_delay_pattern_mask( + label_audio_ids.unsqueeze(0), + bos_token_id=-100, + pad_token_id=-100, + )[0].squeeze(0) + new_audio_out_ids_l.append(audio_codes) + + if return_labels: + if audio_out_no_train_flag[idx]: + label_audio_ids[:] = -100 + label_audio_ids_l.append(label_audio_ids) + + audio_out_ids = torch.cat(new_audio_out_ids_l, dim=1).long() + if return_labels: + label_audio_ids = torch.cat(label_audio_ids_l, dim=1).long() + audio_out_ids_start = torch.cumsum( + torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_out_ids_l[:-1]]), dim=0 + ) + audio_out_ids_start_group_loc = torch.tensor(audio_out_ids_group_loc_l, dtype=torch.long) + else: + audio_out_ids = torch.zeros((0, 0), dtype=torch.long) + audio_out_ids_start = torch.zeros(0, dtype=torch.long) + if return_labels: + label_audio_ids = torch.zeros((0, 0), dtype=torch.long) + + reward = torch.tensor(reward_l, dtype=torch.float32) + + # Handle padding for input ids and attention mask + if self.pad_left: + input_ids = torch.stack( + [ + F.pad(ele.input_ids, (max_seq_length - len(ele.input_ids), 0), value=self.pad_token_id) + for ele in processed_batch + ] + ) + if return_labels: + label_ids = torch.stack( + [ + F.pad(ele.label_ids, (max_seq_length - len(ele.label_ids), 0), value=-100) + for ele in processed_batch + ] + ) + attention_mask = torch.stack( + [ + F.pad(torch.ones_like(ele.input_ids), (max_seq_length - len(ele.input_ids), 0), value=0) + for ele in processed_batch + ] + ) + else: + input_ids = torch.stack( + [ + F.pad(ele.input_ids, (0, max_seq_length - len(ele.input_ids)), value=self.pad_token_id) + for ele in processed_batch + ] + ) + if return_labels: + label_ids = torch.stack( + [ + F.pad(ele.label_ids, (0, max_seq_length - len(ele.label_ids)), value=-100) + for ele in processed_batch + ] + ) + attention_mask = torch.stack( + [ + F.pad(torch.ones_like(ele.input_ids), (0, max_seq_length - len(ele.input_ids)), value=0) + for ele in processed_batch + ] + ) + + if not self.return_audio_in_tokens: + audio_in_ids = None + audio_in_ids_start = None + + # Apply audio_num_codebooks limit if specified + if self.audio_num_codebooks is not None: + if audio_in_ids is not None: + audio_in_ids = audio_in_ids[: self.audio_num_codebooks] + if audio_out_ids is not None: + audio_out_ids = audio_out_ids[: self.audio_num_codebooks] + if label_audio_ids is not None: + label_audio_ids = label_audio_ids[: self.audio_num_codebooks] + + return HiggsAudioBatchInput( + input_ids=input_ids, + attention_mask=attention_mask, + audio_features=audio_features, + audio_feature_attention_mask=audio_feature_attention_mask, + audio_out_ids=audio_out_ids, + audio_out_ids_start=audio_out_ids_start, + audio_out_ids_start_group_loc=audio_out_ids_start_group_loc, + audio_in_ids=audio_in_ids, + audio_in_ids_start=audio_in_ids_start, + label_ids=label_ids, + label_audio_ids=label_audio_ids, + reward=reward, + ) diff --git a/boson_multimodal/data_types.py b/boson_multimodal/data_types.py new file mode 100644 index 0000000000000000000000000000000000000000..2b86089d48d6d3e25e307575e44a2c287ccae6fa --- /dev/null +++ b/boson_multimodal/data_types.py @@ -0,0 +1,38 @@ +"""Basic data types for multimodal ChatML format.""" + +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + + +@dataclass +class AudioContent: + audio_url: str + # Base64 encoded audio bytes + raw_audio: Optional[str] = None + offset: Optional[float] = None + duration: Optional[float] = None + row_id: Optional[int] = None + type: str = "audio" + + +@dataclass +class TextContent: + text: str + type: str = "text" + + +@dataclass +class Message: + role: str + content: Union[str, AudioContent, TextContent, List[Union[str, AudioContent, TextContent]]] + recipient: Optional[str] = None + + +@dataclass +class ChatMLSample: + """Dataclass to hold multimodal ChatML data.""" + + messages: List[Message] + start_index: Optional[int] = None # We will mask the messages[:start_index] when finetuning the LLM. + misc: Optional[Dict] = None + speaker: Optional[str] = None diff --git a/boson_multimodal/dataset/__init__.py b/boson_multimodal/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/boson_multimodal/dataset/__pycache__/__init__.cpython-311.pyc b/boson_multimodal/dataset/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b211c8a1006a121bb09dcc542e3996a4dbf5a33 Binary files /dev/null and b/boson_multimodal/dataset/__pycache__/__init__.cpython-311.pyc differ diff --git a/boson_multimodal/dataset/__pycache__/chatml_dataset.cpython-311.pyc b/boson_multimodal/dataset/__pycache__/chatml_dataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ab73f5b84879c67b8f8201cb7fa72d756da2f6c Binary files /dev/null and b/boson_multimodal/dataset/__pycache__/chatml_dataset.cpython-311.pyc differ diff --git a/boson_multimodal/dataset/chatml_dataset.py b/boson_multimodal/dataset/chatml_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..574034743b8ee17cf751314569d51472d33067fd --- /dev/null +++ b/boson_multimodal/dataset/chatml_dataset.py @@ -0,0 +1,533 @@ +import dacite +import pandas as pd +import torch +import json + +import numpy as np +import multiprocessing as mp + +from dataclasses import dataclass, fields +from abc import ABC, abstractmethod +from typing import Union, List, Dict, Optional + +from ..data_types import ChatMLSample, TextContent, AudioContent +from ..constants import AUDIO_IN_TOKEN, AUDIO_OUT_TOKEN + +from loguru import logger + +# Whisper processor, 30 sec -> 3000 features +# Then we divide 4 in the audio towker, we decrease 3000 features to 750, which gives 25 Hz +WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC = 25 + + +@dataclass +class ChatMLDatasetSample: + input_ids: torch.LongTensor # Shape (seq_len,): The input text tokens. + label_ids: torch.LongTensor # Shape (seq_len,): The label ids. + audio_ids_concat: torch.LongTensor # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated. + # Here `audio_seq_len` is the length of the concatenated audio tokens.` + audio_ids_start: ( + torch.LongTensor + ) # Shape (num_audios,): The start index of each audio token in the concatenated audio tokens. + audio_waveforms_concat: ( + torch.Tensor + ) # Shape (total_wv_length,): The concatenated audio waveforms for audio-in features. + audio_waveforms_start: ( + torch.LongTensor + ) # Shape (num_audios,): The start index of each audio waveform in the concatenated audio waveforms. + audio_sample_rate: torch.Tensor # Shape (num_audios,): The sampling rate of the audio waveforms. + audio_speaker_indices: ( + torch.LongTensor + ) # Shape (num_audios,) -1 means unknown speaker: The speaker indices for each audio. + audio_label_ids_concat: Optional[torch.LongTensor] = ( + None # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated. + ) + # Here `audio_seq_len` is the length of the concatenated audio tokens.` + reward: Optional[float] = None + + def num_audios(self): + return max(len(self.audio_waveforms_start), len(self.audio_ids_start)) + + def get_audio_codes(self, idx): + code_start = self.audio_ids_start[idx] + if idx < len(self.audio_ids_start) - 1: + code_end = self.audio_ids_start[idx + 1] + else: + code_end = self.audio_ids_concat.shape[-1] + + return self.audio_ids_concat[:, code_start:code_end] + + def get_audio_codes_labels(self, idx): + if self.audio_label_ids_concat is None: + return None + code_start = self.audio_ids_start[idx] + if idx < len(self.audio_ids_start) - 1: + code_end = self.audio_ids_start[idx + 1] + else: + code_end = self.audio_ids_concat.shape[-1] + + return self.audio_label_ids_concat[:, code_start:code_end] + + def get_wv(self, idx): + wv_start = self.audio_waveforms_start[idx] + sr = self.audio_sample_rate[idx] + if idx < len(self.audio_waveforms_start) - 1: + wv_end = self.audio_waveforms_start[idx + 1] + else: + wv_end = self.audio_waveforms_concat.shape[-1] + return self.audio_waveforms_concat[wv_start:wv_end], sr + + def cal_num_tokens( + self, + encode_whisper_embed: bool = True, + encode_audio_in_tokens: bool = False, + encode_audio_out_tokens: bool = True, + audio_in_token_id: int = 128015, + audio_out_token_id: int = 128016, + ) -> int: + # we firstly exclude <|AUDIO|> and <|AUDIO_OUT|> because we do late merging and replace those position with actual audio features and audio token ids + # It's assumed that we always have audio_ids when audio_waveforms are there (but not vice-versa) + num_tokens = len(self.input_ids) - len(self.audio_ids_start) + + if encode_whisper_embed and len(self.audio_waveforms_concat) > 0: + audio_lengths = torch.diff(self.audio_waveforms_start) + if len(audio_lengths): + # Sum before calling .item() + num_tokens += ( + ( + np.ceil(WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC * audio_lengths / self.audio_sample_rate[:-1]) + ).sum() + ).item() + # add the last audio's token estimation + num_tokens += ( + np.ceil( + WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC + * (self.audio_waveforms_concat.shape[0] - self.audio_waveforms_start[-1]) + / self.audio_sample_rate[-1] + ) + ).item() + + if self.audio_ids_concat.size(1) > 0: + audio_io_ids = self.input_ids[ + (self.input_ids == audio_in_token_id) | (self.input_ids == audio_out_token_id) + ] + audio_io_id_lengths = torch.concat( + [ + torch.diff(self.audio_ids_start), + torch.tensor([self.audio_ids_concat.shape[-1] - self.audio_ids_start[-1]]), + ] + ) + if encode_audio_in_tokens: + num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_in_token_id]).item() + + if encode_audio_out_tokens: + num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_out_token_id]).item() + + return int(num_tokens) + + @classmethod + def merge( + cls, + samples: List["ChatMLDatasetSample"], + eos_token_id: int, + ignore_index: int, + padding_size: Optional[int] = None, + ) -> "ChatMLDatasetSample": + """Merges a list of ChatMLDatasetSample instances, inserting eos_token_id and ignore_index between them, and adjusting offsets for audio_ids_start and audio_waveforms_start. + + Args: + samples (List[ChatMLDatasetSample]): List of samples to merge. + eos_token_id (int): Tokens to be inserted into input_ids between samples. + ignore_index (int): Default label for padding. + padding_size (Optional[int]): If provided, pad the sequence to with this length. + + Returns: + ChatMLDatasetSample: Merged and potentially padded sample. + """ + if not samples: + logger.fatal("The samples list is empty and cannot be merged.") + raise ValueError("The samples list is empty and cannot be merged.") + + # Initialize empty lists for concatenation + input_ids_list = [] + label_ids_list = [] + audio_ids_concat_list = [] + audio_ids_start_list = [] + audio_waveforms_concat_list = [] + audio_waveforms_start_list = [] + audio_sample_rate_list = [] + audio_speaker_indices_list = [] + + # Track offsets + audio_ids_offset = 0 + audio_waveforms_offset = 0 + + for sample in samples: + # Add input_ids and label_ids with padding + if input_ids_list: + input_ids_list.append(torch.tensor([eos_token_id], dtype=torch.long)) + label_ids_list.append(torch.tensor([ignore_index], dtype=torch.long)) + input_ids_list.append(sample.input_ids) + label_ids_list.append(sample.label_ids) + + # Add audio_ids_concat and handle empty audio ids + if sample.audio_ids_concat.size(1) > 0: + audio_ids_concat_list.append(sample.audio_ids_concat) + + # Offset and add audio_ids_start + audio_ids_start_list.append(sample.audio_ids_start + audio_ids_offset) + audio_ids_offset += sample.audio_ids_concat.size( + 1 + ) # (num_codebooks, seq_len): Update offset by audio_seq_len + + # Add audio_waveforms_concat + if sample.audio_waveforms_concat.size(0) > 0: + # Check dimensions of the audio waveform to ensure consistency + if ( + audio_waveforms_concat_list + and sample.audio_waveforms_concat.dim() != audio_waveforms_concat_list[0].dim() + ): + logger.warning( + f"Skipping audio waveform with inconsistent dimensions: expected {audio_waveforms_concat_list[0].dim()}D, got {sample.audio_waveforms_concat.dim()}D" + ) + continue + + audio_waveforms_concat_list.append(sample.audio_waveforms_concat) + audio_waveforms_start_list.append(sample.audio_waveforms_start + audio_waveforms_offset) + audio_waveforms_offset += sample.audio_waveforms_concat.size(0) + + # Add audio_sample_rate and audio_speaker_indices + audio_sample_rate_list.append(sample.audio_sample_rate) + + audio_speaker_indices_list.append(sample.audio_speaker_indices) + + # Concatenate all tensors + input_ids = torch.cat(input_ids_list, dim=0) + label_ids = torch.cat(label_ids_list, dim=0) + + # Apply padding if padding_size is specified + if padding_size is not None and padding_size > 0: + input_ids = torch.cat([input_ids, torch.full((padding_size,), eos_token_id, dtype=torch.long)], dim=0) + label_ids = torch.cat([label_ids, torch.full((padding_size,), ignore_index, dtype=torch.long)], dim=0) + + # Safely concatenate audio tensors with proper error handling + try: + audio_ids_concat = torch.cat(audio_ids_concat_list, dim=1) if audio_ids_concat_list else torch.tensor([[]]) + audio_ids_start = torch.cat(audio_ids_start_list, dim=0) if audio_ids_start_list else torch.tensor([]) + + # Check for dimensional consistency in audio waveforms + if audio_waveforms_concat_list: + dims = [t.dim() for t in audio_waveforms_concat_list] + if not all(d == dims[0] for d in dims): + # If dimensions don't match, log warning and filter out the problematic tensors + logger.warning( + f"Inconsistent dimensions in audio waveforms: {dims}. Filtering to keep only consistent ones." + ) + expected_dim = max(set(dims), key=dims.count) # Most common dimension + audio_waveforms_concat_list = [t for t in audio_waveforms_concat_list if t.dim() == expected_dim] + + # Recalculate audio_waveforms_start with the filtered list + if audio_waveforms_concat_list: + audio_waveforms_offset = 0 + audio_waveforms_start_list = [] + for waveform in audio_waveforms_concat_list: + audio_waveforms_start_list.append(torch.tensor([audio_waveforms_offset])) + audio_waveforms_offset += waveform.size(0) + + audio_waveforms_concat = ( + torch.cat(audio_waveforms_concat_list, dim=0) if audio_waveforms_concat_list else torch.tensor([]) + ) + audio_waveforms_start = ( + torch.cat(audio_waveforms_start_list, dim=0) if audio_waveforms_start_list else torch.tensor([]) + ) + audio_sample_rate = ( + torch.cat(audio_sample_rate_list, dim=0) if audio_sample_rate_list else torch.tensor([]) + ) + audio_speaker_indices = ( + torch.cat(audio_speaker_indices_list, dim=0) if audio_speaker_indices_list else torch.tensor([]) + ) + + except RuntimeError as e: + logger.error(f"Error during tensor concatenation: {str(e)}") + logger.warning("Falling back to empty audio tensors") + # Fall back to empty tensors + audio_ids_concat = torch.tensor([[]]) + audio_ids_start = torch.tensor([]) + audio_waveforms_concat = torch.tensor([]) + audio_waveforms_start = torch.tensor([]) + audio_sample_rate = torch.tensor([]) + audio_speaker_indices = torch.tensor([]) + + # Create the merged sample + merged_sample = cls( + input_ids=input_ids, + label_ids=label_ids, + audio_ids_concat=audio_ids_concat, + audio_ids_start=audio_ids_start, + audio_waveforms_concat=audio_waveforms_concat, + audio_waveforms_start=audio_waveforms_start, + audio_sample_rate=audio_sample_rate, + audio_speaker_indices=audio_speaker_indices, + ) + + return merged_sample + + +@dataclass +class RankedChatMLDatasetSampleTuple: + samples: List[ChatMLDatasetSample] + scores: List[float] + + def max_score_sample(self) -> ChatMLDatasetSample: + idx = self.scores.index(max(self.scores)) + self.samples[idx].reward = self.scores[idx] + return self.samples[idx] + + def min_score_sample(self) -> ChatMLDatasetSample: + idx = self.scores.index(min(self.scores)) + self.samples[idx].reward = self.scores[idx] + return self.samples[idx] + + +@dataclass +class ChatMLDatasetStorageSample: + input_tokens: torch.LongTensor + label_tokens: torch.LongTensor + audio_bytes_cache_dir_index: int + audio_codes_cache_dir_index: int + audio_bytes_indices: torch.LongTensor + audio_codes_indices: torch.LongTensor + speaker_indices: torch.LongTensor + file_index: int + original_sample_index: int + + +# TODO(sxjscience): We need to revist the logic about parsing speaker ids. +# Currently, we assume that the speaker id is stored at the "misc" field in ChatMLSample. +def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer): + """Preprocess the ChatML sample to get the tokens for the text part. + + Args: + sample (ChatMLSample): The ChatML sample to preprocess. + tokenizer: The tokenizer to use for encoding the text. + + """ + + try: + if not isinstance(sample, ChatMLSample): + # Handle all fields that could be NaN + if "speaker" in sample and pd.isna(sample["speaker"]): + sample["speaker"] = None + if "start_index" in sample and pd.isna(sample["start_index"]): + sample["start_index"] = None + if "content" in sample and pd.isna(sample["content"]): + sample["content"] = "" + + # Convert any other potential NaN values in nested structures + def convert_nan_to_none(obj): + import numpy as np + + if isinstance(obj, (pd.Series, np.ndarray)): + return obj.tolist() + elif pd.api.types.is_scalar(obj) and pd.isna(obj): + return None + elif isinstance(obj, dict): + return {k: convert_nan_to_none(v) for k, v in obj.items()} + elif isinstance(obj, (list, tuple)): # Fixed: Handle both list and tuple + return [convert_nan_to_none(item) for item in obj] + return obj + + # Clean the sample data + clean_sample = convert_nan_to_none(sample) + + val_keys = [] + for field in fields(ChatMLSample): + if field.name in clean_sample: + val_keys.append(field.name) + clean_sample = {k: clean_sample[k] for k in val_keys} + + try: + sample = dacite.from_dict( + data_class=ChatMLSample, data=clean_sample, config=dacite.Config(strict=True, check_types=True) + ) + except Exception as e: + print(f"Failed to convert to ChatMLSample: {e}") + print(f"Clean sample: {json.dumps(clean_sample, indent=2)}") + return None, None, None, None + + input_tokens = [] + label_tokens = [] + audio_contents = [] + speaker_id = None + if sample.speaker is not None: + speaker_id = sample.speaker + elif sample.misc is not None: + if "speaker" in sample.misc: + speaker_id = sample.misc["speaker"] + + total_m = len(sample.messages) + for turn_id, message in enumerate(sample.messages): + role = message.role + recipient = message.recipient + content = message.content + content_l = [] + + if isinstance(content, str): + content_l.append(TextContent(text=content)) + elif isinstance(content, TextContent): + content_l.append(content) + elif isinstance(content, AudioContent): + content_l.append(content) + elif isinstance(content, list): + for ele in content: + if isinstance(ele, str): + content_l.append(TextContent(text=ele)) + else: + content_l.append(ele) + if turn_id == 0: + prefix = f"<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>\n\n" + else: + prefix = f"<|start_header_id|>{role}<|end_header_id|>\n\n" + eot_postfix = "<|eot_id|>" + eom_postfix = "<|eom_id|>" + + prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False) + input_tokens.extend(prefix_tokens) + label_tokens.extend([-100 for _ in prefix_tokens]) + + if recipient: + assert role == "assistant", "Recipient is only available for assistant role." + recipient_tokens = tokenizer.encode(f"{recipient}<|recipient|>", add_special_tokens=False) + input_tokens.extend(recipient_tokens) + label_tokens.extend(recipient_tokens) + + for content in content_l: + if content.type == "text": + text_tokens = tokenizer.encode(content.text, add_special_tokens=False) + input_tokens.extend(text_tokens) + if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index): + label_tokens.extend(text_tokens) + else: + label_tokens.extend([-100 for _ in text_tokens]) + + elif content.type == "audio": + # Generate the text-part of the audio tokens + audio_contents.append(content) + if role == "user" or role == "system": + # Add the text tokens + text_tokens = tokenizer.encode( + f"<|audio_bos|><|AUDIO|><|audio_eos|>", + add_special_tokens=False, + ) + input_tokens.extend(text_tokens) + label_tokens.extend([-100 for _ in text_tokens]) + elif role == "assistant": + # Add the text tokens for audio-out part. + text_tokens = tokenizer.encode( + f"<|audio_out_bos|><|AUDIO_OUT|><|audio_eos|>", + add_special_tokens=False, + ) + input_tokens.extend(text_tokens) + if sample.start_index is None or turn_id >= sample.start_index: + label_tokens.extend(text_tokens) + else: + label_tokens.extend([-100 for _ in text_tokens]) + next_id = turn_id + 1 + if role == "assistant" and next_id != total_m and sample.messages[next_id].role == "assistant": + postfix_tokens = tokenizer.encode(eom_postfix, add_special_tokens=False) + input_tokens.extend(postfix_tokens) + else: + postfix_tokens = tokenizer.encode(eot_postfix, add_special_tokens=False) + input_tokens.extend(postfix_tokens) + if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index): + label_tokens.extend(postfix_tokens) + else: + label_tokens.extend([-100 for _ in postfix_tokens]) + + return input_tokens, label_tokens, audio_contents, speaker_id + + except Exception as e: + print(f"Error in prepare_chatml_sample: {str(e)}") + print(f"Sample data: {json.dumps(sample, indent=2)}") + return None, None, None, None + + +def extract_generation_prompt_from_input_tokens(input_tokens, tokenizer): + """Extract the generation prompt and reference answer from the input tokens. + + For example: + + Input Text = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n + What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|> + <|start_header_id|>assistant<|end_header_id|>\n\nAt first they went by quick, too quick to even get.<|eot_id|>' + + --> + + Prompt = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n + What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|> + <|start_header_id|>assistant<|end_header_id|>\n\n', + Reference = 'At first they went by quick, too quick to even get.' + + Args: + input_tokens: The input tokens. + audio_contents: The audio contents. + tokenizer: The tokenizer to use for decoding the text. + + Returns: + prompt_tokens: The tokens for the prompt. + reference_answer: The reference answer. + num_audios_in_reference: The number of audios in the reference answer. + + """ + input_text = tokenizer.decode(input_tokens) + generation_prefix = "<|start_header_id|>assistant<|end_header_id|>\n\n" + postfix = "<|eot_id|>" + assert generation_prefix in input_text + generation_prompt_end_loc = input_text.rfind(generation_prefix) + len(generation_prefix) + generation_prompt = input_text[:generation_prompt_end_loc] + reference_answer = input_text[generation_prompt_end_loc : input_text.find(postfix, generation_prompt_end_loc)] + num_audios_in_reference = reference_answer.count(AUDIO_IN_TOKEN) + reference_answer.count(AUDIO_OUT_TOKEN) + return tokenizer.encode(generation_prompt, add_special_tokens=False), reference_answer, num_audios_in_reference + + +def prepare_chatml_dataframe_single_process(df, tokenizer): + """Prepare the ChatML DataFrame.""" + ret = [] + for _, row in df.iterrows(): + input_tokens, label_tokens, audio_contents, speaker_id = prepare_chatml_sample(row.to_dict(), tokenizer) + ret.append((input_tokens, label_tokens, audio_contents, speaker_id)) + return ret + + +def prepare_chatml_dataframe(df, tokenizer, num_process=16): + if num_process is None: + return prepare_chatml_dataframe_single_process(df, tokenizer) + else: + num_process = max(min(len(df) // 1000, num_process), 1) + workloads = np.array_split(df, num_process) + with mp.Pool(num_process) as pool: + ret = pool.starmap( + prepare_chatml_dataframe_single_process, [(workload, tokenizer) for workload in workloads] + ) + return sum(ret, []) + + +class DatasetInterface(ABC): + @abstractmethod + def __getitem__(self, idx) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]: + """Retrieve a dataset sample by index.""" + raise NotImplementedError + + +class IterableDatasetInterface(ABC): + @abstractmethod + def __iter__(self) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]: + """Retrieve a sample by iterating through the dataset.""" + raise NotImplementedError + + +@dataclass +class DatasetInfo: + dataset_type: str + group_type: Optional[str] = None + mask_text: Optional[bool] = None # Whether to mask the text tokens for pretraining samples. diff --git a/boson_multimodal/model/__init__.py b/boson_multimodal/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/boson_multimodal/model/__pycache__/__init__.cpython-311.pyc b/boson_multimodal/model/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f77b2d137a58980d0a3cd73184d3a75edf872ef Binary files /dev/null and b/boson_multimodal/model/__pycache__/__init__.cpython-311.pyc differ diff --git a/boson_multimodal/model/higgs_audio/__init__.py b/boson_multimodal/model/higgs_audio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ad77c28104c694c79e3588903eba7fdb2051a8e --- /dev/null +++ b/boson_multimodal/model/higgs_audio/__init__.py @@ -0,0 +1,9 @@ +from transformers import AutoConfig, AutoModel + +from .configuration_higgs_audio import HiggsAudioConfig, HiggsAudioEncoderConfig +from .modeling_higgs_audio import HiggsAudioModel + + +AutoConfig.register("higgs_audio_encoder", HiggsAudioEncoderConfig) +AutoConfig.register("higgs_audio", HiggsAudioConfig) +AutoModel.register(HiggsAudioConfig, HiggsAudioModel) diff --git a/boson_multimodal/model/higgs_audio/__pycache__/__init__.cpython-311.pyc b/boson_multimodal/model/higgs_audio/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36ae9def4b8ffbfc96994624e5a0d292d8d1b6aa Binary files /dev/null and b/boson_multimodal/model/higgs_audio/__pycache__/__init__.cpython-311.pyc differ diff --git a/boson_multimodal/model/higgs_audio/__pycache__/audio_head.cpython-311.pyc b/boson_multimodal/model/higgs_audio/__pycache__/audio_head.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99cf13cc76234be71b0c47d4fdb028309040794f Binary files /dev/null and b/boson_multimodal/model/higgs_audio/__pycache__/audio_head.cpython-311.pyc differ diff --git a/boson_multimodal/model/higgs_audio/__pycache__/common.cpython-311.pyc b/boson_multimodal/model/higgs_audio/__pycache__/common.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8c703c3089c93a2e7bb0ada6cbc045b7ad53fa7 Binary files /dev/null and b/boson_multimodal/model/higgs_audio/__pycache__/common.cpython-311.pyc differ diff --git a/boson_multimodal/model/higgs_audio/__pycache__/configuration_higgs_audio.cpython-311.pyc b/boson_multimodal/model/higgs_audio/__pycache__/configuration_higgs_audio.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecdba798e05835b32cc3257f916e5560d35f3d3e Binary files /dev/null and b/boson_multimodal/model/higgs_audio/__pycache__/configuration_higgs_audio.cpython-311.pyc differ diff --git a/boson_multimodal/model/higgs_audio/__pycache__/cuda_graph_runner.cpython-311.pyc b/boson_multimodal/model/higgs_audio/__pycache__/cuda_graph_runner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..365f7ccc498193af633bf67b746e209b09bdfced Binary files /dev/null and b/boson_multimodal/model/higgs_audio/__pycache__/cuda_graph_runner.cpython-311.pyc differ diff --git a/boson_multimodal/model/higgs_audio/__pycache__/custom_modules.cpython-311.pyc b/boson_multimodal/model/higgs_audio/__pycache__/custom_modules.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f986d3a4f72fadd4848433209cf2ab2877e1bff9 Binary files /dev/null and b/boson_multimodal/model/higgs_audio/__pycache__/custom_modules.cpython-311.pyc differ diff --git a/boson_multimodal/model/higgs_audio/__pycache__/modeling_higgs_audio.cpython-311.pyc b/boson_multimodal/model/higgs_audio/__pycache__/modeling_higgs_audio.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12b9e5944ccbb5c5b137d8479880fd20732353d1 --- /dev/null +++ b/boson_multimodal/model/higgs_audio/__pycache__/modeling_higgs_audio.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c15b0901bc8776d4c224efe999eec4addf97f302770963172d2df6f8961cfc94 +size 101267 diff --git a/boson_multimodal/model/higgs_audio/__pycache__/utils.cpython-311.pyc b/boson_multimodal/model/higgs_audio/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..721d058a184934498914f4f1852b10cb6a71201e Binary files /dev/null and b/boson_multimodal/model/higgs_audio/__pycache__/utils.cpython-311.pyc differ diff --git a/boson_multimodal/model/higgs_audio/audio_head.py b/boson_multimodal/model/higgs_audio/audio_head.py new file mode 100644 index 0000000000000000000000000000000000000000..06f05f5a5f00f489623abc7365fe1b778d66af6d --- /dev/null +++ b/boson_multimodal/model/higgs_audio/audio_head.py @@ -0,0 +1,129 @@ +"""Projector that maps hidden states from the LLM component to multimodal logits.""" + +import torch +from torch import nn + +from dataclasses import dataclass +from typing import Optional, Tuple + +from .common import HiggsAudioPreTrainedModel +from .configuration_higgs_audio import HiggsAudioConfig + + +@dataclass +class HiggsAudioDecoderLayerOutput: + logits: torch.FloatTensor + audio_logits: torch.FloatTensor + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +class HiggsAudioDecoderProjector(HiggsAudioPreTrainedModel): + """Projection layers that map hidden states from the LLM component to audio / text logits. + + We support two type of audio head: + - Basic Audio Head: + Directly map the hidden states to audio logits for all the codebooks. + """ + + def __init__(self, config: HiggsAudioConfig, layer_idx: Optional[int] = None): + super().__init__(config) + self.text_lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.audio_lm_head = nn.Linear( + config.text_config.hidden_size, config.audio_num_codebooks * (config.audio_codebook_size + 2), bias=False + ) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + hidden_states, + audio_out_mask, + label_audio_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + output_audio_hidden_states=False, + cache_position=None, + ): + """ + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`): + Hidden states from the LLM component + audio_out_mask (`torch.Tensor` of shape `(batch_size, seq_len)`): + Mask for identifying the audio out tokens. + label_audio_ids (`torch.Tensor` of shape `(num_codebooks, num_audio_out_tokens)`): + Label tokens for the audio-out part. This is used for calculating the logits if RQ-Transformer is used. + attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`): + Mask to avoid performing attention on padding token indices + position_ids (`torch.Tensor` of shape `(batch_size, seq_len)`): + Position ids for the input tokens + + Returns: + logits (`torch.Tensor` of shape `(batch_size, seq_len, vocab_size)`): + Logits for text tokens + audio_logits (`torch.Tensor` of shape `(num_audio_out_tokens, audio_num_codebooks * audio_codebook_size)`): + Logits for audio tokens. We ensure `num_text_tokens + num_audio_tokens == batch_size * seq_len` + """ + logits = self.text_lm_head(hidden_states) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + if self.config.audio_decoder_proj_num_layers > 0: + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for decoder_layer in self.transformer_layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = layer_outputs[0] + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + next_cache = next_decoder_cache if use_cache else None + + audio_logits = self.audio_lm_head(hidden_states[audio_out_mask]) + + if output_audio_hidden_states: + audio_hidden_states = hidden_states[audio_out_mask] + else: + audio_hidden_states = None + + return logits, audio_logits, all_self_attns, all_hidden_states, audio_hidden_states, next_cache diff --git a/boson_multimodal/model/higgs_audio/common.py b/boson_multimodal/model/higgs_audio/common.py new file mode 100644 index 0000000000000000000000000000000000000000..e01ba869e2a5a10ab942730411e54bc0f55f8e2e --- /dev/null +++ b/boson_multimodal/model/higgs_audio/common.py @@ -0,0 +1,27 @@ +from torch import nn + +from transformers.modeling_utils import PreTrainedModel + +from .configuration_higgs_audio import HiggsAudioConfig + + +class HiggsAudioPreTrainedModel(PreTrainedModel): + config_class = HiggsAudioConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = [] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + std = self.config.init_std if hasattr(self.config, "init_std") else self.config.audio_encoder_config.init_std + + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() diff --git a/boson_multimodal/model/higgs_audio/configuration_higgs_audio.py b/boson_multimodal/model/higgs_audio/configuration_higgs_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..1783d029d77a9599b0a7e75a2f4dbaf192431da1 --- /dev/null +++ b/boson_multimodal/model/higgs_audio/configuration_higgs_audio.py @@ -0,0 +1,235 @@ +from transformers.configuration_utils import PretrainedConfig +from transformers.models.auto import CONFIG_MAPPING + + +class HiggsAudioEncoderConfig(PretrainedConfig): + """Configuration of the Audio encoder in Higgs-Audio.""" + + model_type = "higgs_audio_encoder" + + def __init__( + self, + num_mel_bins=128, + encoder_layers=32, + encoder_attention_heads=20, + encoder_ffn_dim=5120, + encoder_layerdrop=0.0, + d_model=1280, + dropout=0.0, + attention_dropout=0.0, + activation_function="gelu", + activation_dropout=0.0, + scale_embedding=False, + init_std=0.02, + max_source_positions=1500, + pad_token_id=128001, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_mel_bins = num_mel_bins + self.d_model = d_model + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.encoder_ffn_dim = encoder_ffn_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_function = activation_function + self.activation_dropout = activation_dropout + self.encoder_layerdrop = encoder_layerdrop + self.num_hidden_layers = encoder_layers + self.init_std = init_std + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + self.max_source_positions = max_source_positions + self.pad_token_id = pad_token_id + + +class HiggsAudioConfig(PretrainedConfig): + r""" + This is the configuration class for the HiggsAudioModel. + + Args: + text_config (`Union[AutoConfig, dict]`): + The config object or dictionary of the text backbone. + audio_encoder_config (`Union[AutoConfig, dict]`): + The config object or dictionary of the whisper encoder. + The audio encoder will be bidirectional and will be only available for audio understanding. + audio_tokenizer_config + The config object or dictionary of the audio tokenizer. + audio_adapter_type + The type of audio adapter to use. We support two types of adapter: + - stack: + We stack additional Transformer layers after the main LLM backbone for audio generation. + - dual_ffn: + For selected part of the LLM backbone, we replace the text FFN with a dual FFN architecture + that contains an additional audio FFN. The audio FFN will be triggered when the location is marked for audio tokens. + - dual_ffn_fast_forward: + We pick a few layers in the LLM backbone to plug-in the audio FFN. For the remaining layers, + the audio hidden states will be directly fast-forward to the next layer. + This reduces the computational cost for audio generation. + audio_embed_avg (`bool`, *optional*, defaults to False): + Whether to average the audio embeddings before sending them to the text attention layer. + audio_ffn_hidden_size + The hidden size of the audio feedforward network in dual-path FFN + audio_ffn_intermediate_size + The intermediate size of the audio feedforward network in dual-path FFN + audio_dual_ffn_layers + The layers in the LLM backbone to plug-in the dual FFN layer (mixture of audio FFN and text FFN). + audio_decoder_proj_num_attention (`int`, *optional*, defaults to 0): + The number of attention heads in the audio decoder projection layer. + use_delay_pattern (`bool`, *optional*, defaults to False): + Whether to use delay pattern in the audio decoder. + skip_audio_tower (`bool`, *optional*, defaults to False): + Whether to skip the audio tower in the audio encoder. + use_audio_out_embed_projector (`bool`, *optional*, defaults to False): + Whether to use an embedding projector to map audio out embeddings. + use_audio_out_self_attention (`bool`, *optional*, defaults to False): + Whether to use self-attention to aggregate information from audio-tokens before sending to the text attention layer. + audio_num_codebooks (`int`, *optional*, defaults to 12): + The number of codebooks in RVQGAN. + audio_codebook_size (`int`, *optional*, defaults to 1024): + The size of each codebook in RVQGAN. + audio_stream_bos_id + The id of the bos in the audio stream + audio_stream_eos_id + The id of the eos in the audio stream + audio_bos_token (`str`, *optional*, defaults to "<|audio_bos|>"): + The special `<|audio_bos|>` token. In Higgs-Audio, it is mapped to 128011, + which is the index of `<|reserved_special_token_3|>` in Llama-3.1-8B-Instruct's tokenizer. + audio_eos_token (`str`, *optional*, defaults to "<|audio_eos|>"): + The special `<|audio_eos|>` token. We use 128012 as the default value, + which is the index of `<|reserved_special_token_4|>` in Llama-3.1-8B-Instruct's tokenizer. + audio_out_bos_token (`str`, *optional*, defaults to "<|audio_out_bos|>"): + The special `<|audio_out_bos|>` token. We use 128013 as the default value, + which is the index of `<|reserved_special_token_5|>` in Llama-3.1-8B-Instruct's tokenizer. + audio_token (`str`, *optional*, defaults to "<|AUDIO|>"): + The special `<|AUDIO|>` token. We use 128015 as the default value, + which is the index of `<|reserved_special_token_7|>` in Llama-3.1-8B-Instruct's tokenizer. + This token indicates that the location should be filled in with whisper features. + audio_out_token (`str`, *optional*, defaults to "<|AUDIO_OUT|>"): + The special `<|AUDIO_OUT|>` token. We use 128016 as the default value, + which is the index of `<|reserved_special_token_8|>` in Llama-3.1-8B-Instruct's tokenizer. + This token indicates that the location should be filled in with audio tokens extracted via audio tokenizer. + """ + + model_type = "higgs_audio" + is_composition = True + + def __init__( + self, + text_config=None, + audio_encoder_config=None, + audio_tokenizer_config=None, + audio_adapter_type="stack", + audio_embed_avg=False, + audio_ffn_hidden_size=4096, + audio_ffn_intermediate_size=14336, + audio_dual_ffn_layers=None, + audio_decoder_proj_num_layers=0, + encode_whisper_embed=True, + encode_audio_in_tokens=False, + use_delay_pattern=False, + skip_audio_tower=False, + use_audio_out_embed_projector=False, + use_audio_out_self_attention=False, + use_rq_transformer=False, + rq_transformer_hidden_size=None, + rq_transformer_intermediate_size=None, + rq_transformer_num_attention_heads=None, + rq_transformer_num_key_value_heads=None, + rq_transformer_num_hidden_layers=3, + audio_num_codebooks=12, + audio_codebook_size=1024, + audio_stream_bos_id=1024, + audio_stream_eos_id=1025, + audio_bos_token="<|audio_bos|>", + audio_eos_token="<|audio_eos|>", + audio_out_bos_token="<|audio_out_bos|>", + audio_in_token="<|AUDIO|>", + audio_out_token="<|AUDIO_OUT|>", + audio_in_token_idx=128015, + audio_out_token_idx=128016, + pad_token_id=128001, + audio_out_bos_token_id=128013, + audio_eos_token_id=128012, + **kwargs, + ): + if isinstance(audio_encoder_config, dict): + audio_encoder_config["model_type"] = ( + audio_encoder_config["model_type"] if "model_type" in audio_encoder_config else "higgs_audio_encoder" + ) + audio_encoder_config = CONFIG_MAPPING[audio_encoder_config["model_type"]](**audio_encoder_config) + elif audio_encoder_config is None: + audio_encoder_config = HiggsAudioEncoderConfig() + + if isinstance(text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["llama"]() + + assert audio_adapter_type in [ + "stack", + "dual_ffn", + "dual_ffn_fast_forward", + ], f"Invalid audio adapter type: {audio_adapter_type}" + if audio_adapter_type.startswith("dual_ffn"): + assert audio_dual_ffn_layers is not None, ( + "audio_dual_ffn_layers must be specified when using dual_ffn adapter." + ) + self.text_config = text_config + self.audio_encoder_config = audio_encoder_config + self.audio_tokenizer_config = audio_tokenizer_config + self.audio_adapter_type = audio_adapter_type + self.audio_embed_avg = audio_embed_avg + self.audio_ffn_hidden_size = audio_ffn_hidden_size + self.audio_ffn_intermediate_size = audio_ffn_intermediate_size + self.audio_dual_ffn_layers = audio_dual_ffn_layers + self.audio_decoder_proj_num_layers = audio_decoder_proj_num_layers + self.encode_whisper_embed = encode_whisper_embed + self.encode_audio_in_tokens = encode_audio_in_tokens + self.use_delay_pattern = use_delay_pattern + self.skip_audio_tower = skip_audio_tower + self.use_audio_out_embed_projector = use_audio_out_embed_projector + self.use_audio_out_self_attention = use_audio_out_self_attention + + self.use_rq_transformer = use_rq_transformer + + if self.use_rq_transformer: + assert not self.use_delay_pattern, "Delay pattern is not supported if you turned on RQ-Transformer!" + self.rq_transformer_hidden_size = rq_transformer_hidden_size + self.rq_transformer_intermediate_size = rq_transformer_intermediate_size + self.rq_transformer_num_attention_heads = rq_transformer_num_attention_heads + self.rq_transformer_num_key_value_heads = rq_transformer_num_key_value_heads + self.rq_transformer_num_hidden_layers = rq_transformer_num_hidden_layers + + if use_rq_transformer: + # For RQ-Transformer, we set the hidden_size to the same as the text model's hidden size if it is not specified. + if self.rq_transformer_hidden_size is None: + self.rq_transformer_hidden_size = text_config.hidden_size + assert self.rq_transformer_hidden_size % 128 == 0 + if self.rq_transformer_intermediate_size is None: + self.rq_transformer_intermediate_size = text_config.intermediate_size + if self.rq_transformer_num_attention_heads is None: + self.rq_transformer_num_attention_heads = self.rq_transformer_hidden_size // 128 + if self.rq_transformer_num_key_value_heads is None: + self.rq_transformer_num_key_value_heads = self.rq_transformer_hidden_size // 128 // 4 + assert self.rq_transformer_hidden_size % self.rq_transformer_num_attention_heads == 0 + assert self.rq_transformer_hidden_size % self.rq_transformer_num_key_value_heads == 0 + + self.audio_num_codebooks = audio_num_codebooks + self.audio_codebook_size = audio_codebook_size + self.audio_bos_token = audio_bos_token + self.audio_eos_token = audio_eos_token + self.audio_out_bos_token = audio_out_bos_token + self.audio_in_token = audio_in_token + self.audio_out_token = audio_out_token + self.audio_in_token_idx = audio_in_token_idx + self.audio_out_token_idx = audio_out_token_idx + self.audio_stream_bos_id = audio_stream_bos_id + self.audio_stream_eos_id = audio_stream_eos_id + self.audio_out_bos_token_id = audio_out_bos_token_id + self.audio_eos_token_id = audio_eos_token_id + + super().__init__(**kwargs) + self.pad_token_id = pad_token_id diff --git a/boson_multimodal/model/higgs_audio/cuda_graph_runner.py b/boson_multimodal/model/higgs_audio/cuda_graph_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..a99507cb6c3a17c5414c46e09398b942af1f4004 --- /dev/null +++ b/boson_multimodal/model/higgs_audio/cuda_graph_runner.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn +from typing import Optional, List, Dict, Tuple, Union +import gc + +from transformers.cache_utils import Cache + + +_NUM_WARMUP_ITERS = 2 + + +class CUDAGraphRunner(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + self.input_buffers: Dict[str, torch.Tensor] = {} + self.output_buffers: Dict[str, torch.Tensor] = {} + + self._graph: Optional[torch.cuda.CUDAGraph] = None + + @property + def graph(self): + assert self._graph is not None + return self._graph + + def capture( + self, + hidden_states: torch.Tensor, + causal_mask: torch.Tensor, + position_ids: torch.Tensor, + audio_discrete_codes_mask: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Union[Cache, List[torch.FloatTensor]], + use_cache: bool, + audio_attention_mask: torch.Tensor, + fast_forward_attention_mask: torch.Tensor, + output_attentions: bool, + output_hidden_states: bool, + is_decoding_audio_token: Optional[bool] = None, + is_using_cuda_graph: Optional[bool] = False, + stream: torch.cuda.Stream = None, + memory_pool: Optional[Tuple[int, int]] = None, + ): + assert self._graph is None + # Run warmup iterations + for _ in range(_NUM_WARMUP_ITERS): + self.model( + hidden_states=hidden_states, + causal_mask=causal_mask, + position_ids=position_ids, + audio_discrete_codes_mask=audio_discrete_codes_mask, + cache_position=cache_position, + past_key_values=past_key_values, + use_cache=use_cache, + audio_attention_mask=audio_attention_mask, + fast_forward_attention_mask=fast_forward_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + is_decoding_audio_token=is_decoding_audio_token, + is_using_cuda_graph=is_using_cuda_graph, + ) + + torch.cuda.synchronize() + + # Capture the graph + self._graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): + out_hidden_states, all_hidden_states, all_self_attns = self.model( + hidden_states=hidden_states, + causal_mask=causal_mask, + position_ids=position_ids, + audio_discrete_codes_mask=audio_discrete_codes_mask, + cache_position=cache_position, + past_key_values=past_key_values, + use_cache=use_cache, + audio_attention_mask=audio_attention_mask, + fast_forward_attention_mask=fast_forward_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + is_decoding_audio_token=is_decoding_audio_token, + is_using_cuda_graph=is_using_cuda_graph, + ) + # hidden_states_out = torch.ops._C.weak_ref_tensor(outputs[0]) + # del outputs + gc.collect() + torch.cuda.synchronize() + + # Save input and output buffers + self.input_buffers = { + "hidden_states": hidden_states, + "causal_mask": causal_mask, + "position_ids": position_ids, + "audio_discrete_codes_mask": audio_discrete_codes_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "audio_attention_mask": audio_attention_mask, + "fast_forward_attention_mask": fast_forward_attention_mask, + } + self.output_buffers = { + "hidden_states": out_hidden_states, + "all_hidden_states": all_hidden_states, + "all_self_attns": all_self_attns, + } + + def forward( + self, + hidden_states: torch.Tensor, + causal_mask: torch.Tensor, + position_ids: torch.Tensor, + audio_discrete_codes_mask: torch.Tensor, + cache_position: torch.Tensor, + audio_attention_mask: torch.Tensor, + fast_forward_attention_mask: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + # Copy input tensors to buffers + self.input_buffers["hidden_states"].copy_(hidden_states, non_blocking=True) + self.input_buffers["causal_mask"].copy_(causal_mask, non_blocking=True) + self.input_buffers["position_ids"].copy_(position_ids, non_blocking=True) + self.input_buffers["audio_discrete_codes_mask"].copy_(audio_discrete_codes_mask, non_blocking=True) + self.input_buffers["cache_position"].copy_(cache_position, non_blocking=True) + self.input_buffers["audio_attention_mask"].copy_(audio_attention_mask, non_blocking=True) + self.input_buffers["fast_forward_attention_mask"].copy_(fast_forward_attention_mask, non_blocking=True) + + # Run the captured graph + self.graph.replay() + + return self.output_buffers["hidden_states"], None, None diff --git a/boson_multimodal/model/higgs_audio/custom_modules.py b/boson_multimodal/model/higgs_audio/custom_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..eb585c8cc8edb6be7762cbc5ccd149e77079ecb3 --- /dev/null +++ b/boson_multimodal/model/higgs_audio/custom_modules.py @@ -0,0 +1,155 @@ +import torch +import torch.nn as nn + + +class PartiallyFrozenEmbedding(nn.Module): + """Split an existing `nn.Embedding` module that splits the embedding into: + + - A frozen embedding for indices [0..freeze_until_idx]. + - A trainable embedding for indices [freeze_until_idx+1..vocab_size-1]. + + This should work with both Zero-2 and Zero-3 seamlessly + """ + + def __init__(self, original_embedding: nn.Embedding, freeze_until_idx: int): + """ + :param original_embedding: An instance of nn.Embedding (the original embedding layer). + :param freeze_until_idx: The index up to which the embedding is frozen (excluding). The freeze_until_idx is not frozen. + """ + super().__init__() + self.freeze_until_idx = freeze_until_idx + self.original_vocab_size = original_embedding.num_embeddings + self.embedding_dim = original_embedding.embedding_dim + + # Split the original embedding into frozen and trainable parts + self.embedding_frozen = nn.Embedding( + freeze_until_idx, + self.embedding_dim, + dtype=original_embedding.weight.dtype, + device=original_embedding.weight.device, + ) + self.embedding_trainable = nn.Embedding( + self.original_vocab_size - freeze_until_idx, + self.embedding_dim, + dtype=original_embedding.weight.dtype, + device=original_embedding.weight.device, + ) + + # Copy weights from the original embedding into the frozen and trainable parts + with torch.no_grad(): + self.embedding_frozen.weight.copy_(original_embedding.weight[:freeze_until_idx]) + self.embedding_trainable.weight.copy_(original_embedding.weight[freeze_until_idx:]) + + # Freeze the frozen embedding + self.embedding_frozen.weight.requires_grad = False + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the split embedding wrapper. + :param input_ids: Tensor of shape [batch_size, seq_len] with indices in [0..original_vocab_size-1]. + """ + # Masks to separate frozen and trainable indices + # (bsz, seq_len) + mask_frozen = input_ids < self.freeze_until_idx + mask_trainable = ~mask_frozen + + # Output tensor for embedding results + batch_size, seq_len = input_ids.shape + embeddings = torch.zeros( + batch_size, + seq_len, + self.embedding_dim, + device=input_ids.device, + dtype=self.embedding_frozen.weight.dtype, + ) + + # Handle frozen embedding + if mask_frozen.any(): + frozen_ids = input_ids[mask_frozen] + frozen_emb = self.embedding_frozen(frozen_ids) + embeddings[mask_frozen] = frozen_emb + + # Handle trainable embedding + if mask_trainable.any(): + # Adjust trainable IDs to the local index space of the trainable embedding + trainable_ids = input_ids[mask_trainable] - (self.freeze_until_idx) + trainable_emb = self.embedding_trainable(trainable_ids) + embeddings[mask_trainable] = trainable_emb + + return embeddings + + def to_unsplit(self) -> nn.Embedding: + unsplit_embedding = nn.Embedding( + self.original_vocab_size, + self.embedding_dim, + dtype=self.embedding_frozen.weight.dtype, + device=self.embedding_frozen.weight.device, + ) + + with torch.no_grad(): + unsplit_embedding.weight[: self.freeze_until_idx].copy_(self.embedding_frozen.weight) + unsplit_embedding.weight[self.freeze_until_idx :].copy_(self.embedding_trainable.weight) + + return unsplit_embedding + + +class PartiallyFrozenLinear(nn.Module): + """A wrapper around nn.Linear to partially freeze part of the weight matrix.""" + + def __init__(self, original_linear: nn.Linear, freeze_until_idx: int): + """ + :param original_linear: The original nn.Linear layer. + :param freeze_until_idx: The index up to which the rows of the weight matrix are frozen. + """ + super().__init__() + assert original_linear.bias is None, "Currently only support linear module without bias" + + self.freeze_until_idx = freeze_until_idx + self.input_dim = original_linear.in_features + self.output_dim = original_linear.out_features + + # Create frozen and trainable linear layers + self.linear_frozen = nn.Linear( + self.input_dim, + freeze_until_idx, + bias=False, + dtype=original_linear.weight.dtype, + device=original_linear.weight.device, + ) + self.linear_trainable = nn.Linear( + self.input_dim, + self.output_dim - freeze_until_idx, + bias=False, + dtype=original_linear.weight.dtype, + device=original_linear.weight.device, + ) + + # Copy weights from the original linear layer + with torch.no_grad(): + self.linear_frozen.weight.copy_(original_linear.weight[:freeze_until_idx]) + self.linear_trainable.weight.copy_(original_linear.weight[freeze_until_idx:]) + + # Freeze the frozen linear layer + self.linear_frozen.weight.requires_grad = False + + def forward(self, input_tensor): + # input_tensor: (bsz, seq_len, hidden_state_dim) + frozen_output = self.linear_frozen(input_tensor) + trainable_output = self.linear_trainable(input_tensor) + return torch.cat((frozen_output, trainable_output), dim=-1) + + def to_unsplit(self) -> nn.Linear: + unsplit_linear = nn.Linear( + self.input_dim, + self.output_dim, + bias=False, + dtype=self.linear_frozen.weight.dtype, + device=self.linear_frozen.weight.device, + ) + + # Copy weights from the frozen and trainable layers into the unsplit linear layer + with torch.no_grad(): + unsplit_linear.weight[: self.freeze_until_idx].copy_(self.linear_frozen.weight) + unsplit_linear.weight[self.freeze_until_idx :].copy_(self.linear_trainable.weight) + + return unsplit_linear diff --git a/boson_multimodal/model/higgs_audio/modeling_higgs_audio.py b/boson_multimodal/model/higgs_audio/modeling_higgs_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..354fc32d3d41797901cb6af4bc2b25dc8bedcbf6 --- /dev/null +++ b/boson_multimodal/model/higgs_audio/modeling_higgs_audio.py @@ -0,0 +1,2289 @@ +"""Higgs-Audio is an end-to-end multimodal model with the capability to understand and generate text / audio.""" + +import torch +import torch.nn as nn +import math +import glob +import functools +import os +from collections import defaultdict, OrderedDict +from dataclasses import dataclass +from enum import Enum +from safetensors.torch import load_file +from typing import Optional, Tuple, Union, List, Dict, Any + +from transformers import AutoTokenizer +from transformers.modeling_outputs import BaseModelOutput +from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer +from transformers.models.llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaRMSNorm, + LlamaRotaryEmbedding, + LLAMA_ATTENTION_CLASSES, + LlamaMLP, + LlamaRMSNorm, +) +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessorList, StoppingCriteriaList +from transformers.generation.utils import GenerateNonBeamOutput +from transformers.utils import logging, ModelOutput + +from .common import HiggsAudioPreTrainedModel +from .utils import ( + merge_input_ids_with_audio_features, + count_parameters, +) +from .configuration_higgs_audio import HiggsAudioConfig, HiggsAudioEncoderConfig +from .custom_modules import PartiallyFrozenLinear, PartiallyFrozenEmbedding +from .cuda_graph_runner import CUDAGraphRunner +from .audio_head import HiggsAudioDecoderProjector + +logger = logging.get_logger(__name__) + + +class GenerationMode(Enum): + """Enum for different generation modes in HiggsAudio model.""" + + TEXT = 0 # Text generation mode + AUDIO_INIT = 1 # Audio generation mode initialization + AUDIO_IN_PROGRESS = 2 # Audio generation mode in progress + + +def _whisper_encoder_zero_shape_forward(whisper_encoder, *args, **kwargs): + """The whisper encoder does not support zero-shape tensor by default due to the following implementations + + key_states = self._shape(self.k_proj(current_states), -1, bsz) + + If `bsz` is 0, the "-1" dimension will be ambiguous and triggers error in the shape inference pass. + + See also: https://github.com/huggingface/transformers/blob/30335093276212ce74938bdfd85bfd5df31a668a/src/transformers/models/whisper/modeling_whisper.py#L306-L307 + + This function monkey-patches all `_shape` functions in the whisper encoder's self-attention layers to ensure function supports zero-shape tensor. + + #FIXME!!!! This is a temporary workaround and should be removed once the upstream issue is resolved. + + """ + + global _higgs_flash_attention_forward + + def _patched_shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int): + if seq_len == -1: + return tensor.view(bsz, tensor.shape[1], num_heads, head_dim).transpose(1, 2).contiguous() + else: + return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous() + + def _patched_scaled_dot_product_attention( + query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False + ) -> torch.Tensor: + # IMPORTANT! Implementation here is wrong and is only for the purpose of obtaining the correct attn_weight shape + if enable_gqa: + key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) + value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) + + attn_weight = query @ key.transpose(-2, -1) + return attn_weight @ value + + # Apply monkey-patch + if whisper_encoder.config._attn_implementation != "flash_attention_2": + old_shape_functions = [] + for layer in whisper_encoder.layers: + old_shape_functions.append(getattr(layer.self_attn, "_shape")) + layer.self_attn._shape = functools.partial( + _patched_shape, num_heads=layer.self_attn.num_heads, head_dim=layer.self_attn.head_dim + ) + + original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention + torch.nn.functional.scaled_dot_product_attention = _patched_scaled_dot_product_attention + + out = whisper_encoder(*args, **kwargs) + torch.nn.functional.scaled_dot_product_attention = original_scaled_dot_product_attention + + # Restore the original shape functions + if whisper_encoder.config._attn_implementation != "flash_attention_2": + for layer, old_shape_function in zip(whisper_encoder.layers, old_shape_functions): + layer.self_attn._shape = old_shape_function + + return out + + +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class HiggsAudioFeatureProjector(nn.Module): + """Projector that maps audio features extracted by Whisper to hidden state of the text model.""" + + def __init__(self, config: HiggsAudioConfig): + super().__init__() + self.linear = nn.Linear(config.audio_encoder_config.d_model, config.text_config.hidden_size, bias=True) + + def forward(self, audio_features): + hidden_states = self.linear(audio_features) + return hidden_states + + +# Revised on top of transformers.models.qwen2_audio.modeling_qwen2_audio with Qwen2AudioEncoder --> HiggsAudioEncoder +# The code was originally borrowed from WhisperEncoder +class HiggsAudioEncoder(HiggsAudioPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`WhisperEncoderLayer`]. + + Args: + config: HiggsAudioEncoderConfig + """ + + # Ignore copy + config_class = HiggsAudioEncoderConfig + main_input_name = "input_features" + _no_split_modules = ["WhisperEncoderLayer"] + + def __init__(self, config: HiggsAudioEncoderConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.num_mel_bins = config.num_mel_bins + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) + + self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim) + self.embed_positions.requires_grad_(False) + + # Flash Attention 2 does not support zero shape tensor, so we have to use sdpa implementation for the Whisper component. + self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layer_norm = nn.LayerNorm(config.d_model) + # Ignore copy + self.avg_pooler = nn.AvgPool1d(2, stride=2) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def get_input_embeddings(self) -> nn.Module: + return self.conv1 + + def set_input_embeddings(self, value: nn.Module): + self.conv1 = value + + def forward( + self, + input_features, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + check_seq_length=True, + ): + r""" + Args: + input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`): + Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be + obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a + `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into + `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding + and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + attention_mask (`torch.Tensor`)`, *optional*): + HiggsAudio does not support masking of the `input_features`, this argument is preserved for compatibility, + but it is not used. By default the silence in the input log mel spectrogram are ignored. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + + expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0] + if check_seq_length and (input_features.shape[-1] != expected_seq_length): + raise ValueError( + f"HiggsAudio expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}." + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Ignore copy + input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device) + + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + embed_pos = self.embed_positions.weight + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == (len(self.layers)), ( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + # Ignore copy + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Ignore copy + hidden_states = hidden_states.permute(0, 2, 1) + # If the sequence length after average pooling is not divisible by the sequence parallel size, we would duplicate it across the sequence parallel ranks. + # In this case, gradients need to be scaled up because the subsequent scaling up in the function _apply_audio_tower is skipped. + hidden_states = self.avg_pooler(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 1) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + # Ignore copy + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + + +class HiggsAudioDualFFNDecoderLayer(nn.Module): + """We implement a dual-path FFN decoder layer where the audio tokens and text tokens go through separate FFN layers. + + The audio and text tokens share the text-attention layer, but will be encoded with separate feedforward layers. + In addition, the audio tokens can be configured to go through separate attention layer. + + Following is an illustration: + + t t t a a a t t t + | + | (shared attention layer) + v + h_t h_t h_t h_a h_a h_a h_t h_t h_t + | + | (separate text/audio hidden states) + v + [h_t h_t h_t h_t h_t h_t], [h_a, h_a, h_a] + | | + | (separate FFNs) | + v v + [o_t o_t o_t o_t o_t o_t], [o_a, o_a, o_a] + | + | (reorder) + v + o_t o_t o_t o_a o_a o_a o_t o_t o_t + + This has a few advantages: + 1) We are able to use a smaller FFN, or even bypass the FFN for audio tokens. This accelerates the inference speed. + 2) The Audio-FFN introduces more trainable parameters to the model. + This should have the same effect as the mixture-of-expert layer and we may expect better performance due to parameter scaling. + 3) We can replace the original FFN in LLMs with the dual-path FFN without changing the number of FLOPs. + + + """ + + def __init__( + self, config: HiggsAudioConfig, layer_idx: int, fast_forward: bool = False, use_audio_attention: bool = False + ): + super().__init__() + text_config = config.text_config + self.hidden_size = text_config.hidden_size + self.layer_idx = layer_idx + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=text_config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(text_config) + + if not fast_forward: + if use_audio_attention: + self.audio_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation]( + config=text_config, layer_idx=layer_idx + 1 + ) + self.audio_post_audio_attn_layer_norm = LlamaRMSNorm( + text_config.hidden_size, eps=text_config.rms_norm_eps + ) + + self.audio_mlp = LlamaMLP(text_config) + self.audio_input_layernorm = LlamaRMSNorm(text_config.hidden_size, eps=text_config.rms_norm_eps) + self.audio_post_attention_layernorm = LlamaRMSNorm(text_config.hidden_size, eps=text_config.rms_norm_eps) + + self.use_audio_attention = use_audio_attention + self.fast_forward = fast_forward + if self.fast_forward: + assert not self.use_audio_attention, ( + "We cannot use audio_attention if the layer is marked as fast-forward." + ) + self.input_layernorm = LlamaRMSNorm(text_config.hidden_size, eps=text_config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(text_config.hidden_size, eps=text_config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + audio_attention_mask: Optional[torch.Tensor] = None, + fast_forward_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + audio_out_mask: Optional[torch.BoolTensor] = None, + is_decoding_audio_token: Optional[bool] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + is_using_cuda_graph: Optional[bool] = False, + **kwargs, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + position_ids + IDs of positions in the input sequence + audio_out_mask + Mask for identifying the audio tokens. Size (batch_size, sequence_length) + 1 --> location contains audio_out + 0 --> location does not contain audio_out + + When use_cache is True and not in torch compile mode, the audio_out_mask contains audio_out masks for + all tokens up to the current token. That means, it has size (batch_size, sequence_length) while + hidden_states will have size (batch_size, 1). In the torch compile mode, the audio_out_mask will have + size (batch_size, 1). + is_decoding_audio_token + Used in the torch compile mode to determine if the current token is an audio token or not. + past_key_value (`Cache`, *optional*): cached past key and value projection states. We fetch the corresponding cached key/value via the layer_idx. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + is_using_cuda_graph (`bool`, *optional*): + Indicates whether the model is running by cuda graph. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + target_length = hidden_states.shape[1] + use_static_cache = isinstance(past_key_value, StaticCache) + decode_stage = hidden_states.shape[1] == 1 + if is_using_cuda_graph: + assert decode_stage and use_static_cache, ( + "The CUDA graph mode should only be used in the decoding stage with static cache." + ) + + # If we are decoding an audio token and the layer is marked as fast-forward, + # we can skip it. + if is_decoding_audio_token and self.fast_forward: + return (hidden_states,) + + has_audio_out = audio_out_mask is not None and audio_out_mask.shape[0] > 0 + + audio_out_mask_sq = audio_out_mask + + if self.fast_forward and has_audio_out: + original_hidden_states = hidden_states.clone() + min_dtype = torch.finfo(hidden_states.dtype).min + if attention_mask is None: + attention_mask = ~audio_out_mask + + if self.self_attn.config._attn_implementation != "flash_attention_2": + sequence_length = audio_out_mask.shape[1] + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask=attention_mask, + sequence_length=sequence_length, + target_length=sequence_length, + dtype=hidden_states.dtype, + min_dtype=min_dtype, + device=hidden_states.device, + cache_position=cache_position, + batch_size=hidden_states.shape[0], + ) + if use_cache: + attention_mask = attention_mask[:, :, -target_length:, :] + elif len(attention_mask.shape) == 2: + # Attention mask has shape (batch_size, sequence_length) + # We should be using flash attention 2 + attention_mask = attention_mask * ~audio_out_mask + elif len(attention_mask.shape) == 4: + # When using static cache, the attention mask was already preprocessed in the previous layer + if use_static_cache: + attention_mask = fast_forward_attention_mask + else: + if use_cache: + # Attention mask has shape (batch_size, 1, query_length, key_length) + # In addition, the attention mask should be inverted, that means "1" (attend_to) --> "0", and "0" --> minimal dtype value. + attention_mask = attention_mask.masked_fill( + audio_out_mask[:, -target_length:].reshape(audio_out_mask.shape[0], 1, target_length, 1) + | audio_out_mask.reshape(audio_out_mask.shape[0], 1, 1, audio_out_mask.shape[1]), + min_dtype, + ) + else: + attention_mask = attention_mask.masked_fill( + audio_out_mask.reshape(audio_out_mask.shape[0], 1, audio_out_mask.shape[1], 1) + | audio_out_mask.reshape(audio_out_mask.shape[0], 1, 1, audio_out_mask.shape[1]), + min_dtype, + ) + else: + raise NotImplementedError(f"Unsupported attention_mask format, attention_mask={attention_mask}") + + if ( + self.self_attn.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype) + + if has_audio_out and not self.fast_forward: + # Apply separate layernorm layers for audio tokens and text tokens + if use_cache: + hidden_states = torch.where( + audio_out_mask_sq[:, -target_length:].unsqueeze(-1), + self.audio_input_layernorm(hidden_states), + self.input_layernorm(hidden_states), + ) + else: + hidden_states = torch.where( + audio_out_mask_sq.unsqueeze(-1), + self.audio_input_layernorm(hidden_states), + self.input_layernorm(hidden_states), + ) + else: + hidden_states = self.input_layernorm(hidden_states) + + # Audio Attention + if self.use_audio_attention and has_audio_out: + if use_static_cache: + assert audio_attention_mask is not None, ( + "audio_attention_mask should not be None when using static cache." + ) + + if audio_attention_mask is None: + no_audio_out_mask = (~audio_out_mask)[:, -target_length:].reshape( + audio_out_mask.shape[0], 1, target_length, 1 + ) | (~audio_out_mask).reshape(audio_out_mask.shape[0], 1, 1, audio_out_mask.shape[1]) + min_dtype = torch.finfo(hidden_states.dtype).min + + if attention_mask is None: + audio_attention_mask = audio_out_mask + + if self.audio_attn.config._attn_implementation != "flash_attention_2": + sequence_length = audio_out_mask.shape[1] + audio_attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask=audio_attention_mask, + sequence_length=sequence_length, + target_length=sequence_length, + dtype=hidden_states.dtype, + min_dtype=min_dtype, + device=hidden_states.device, + cache_position=cache_position, + batch_size=hidden_states.shape[0], + ) + if use_cache: + audio_attention_mask = audio_attention_mask[:, :, -target_length:, :] + audio_attention_mask = audio_attention_mask.masked_fill(no_audio_out_mask, min_dtype) + elif len(attention_mask.shape) == 2: + # Attention mask has shape (batch_size, sequence_length) + audio_attention_mask = attention_mask * audio_out_mask + elif len(attention_mask.shape) == 4: + # Attention mask has shape (batch_size, 1, query_length, key_length) + # In addition, the attention mask should be inverted. This means "1" (attend_to) --> "0", and "0" --> minimal dtype value. + audio_attention_mask = attention_mask.masked_fill(no_audio_out_mask, min_dtype) + else: + raise NotImplementedError(f"Unsupported attention_mask format, attention_mask={attention_mask}") + + if ( + self.audio_attn.config._attn_implementation == "sdpa" + and audio_attention_mask is not None + and audio_attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + audio_attention_mask = AttentionMaskConverter._unmask_unattended(audio_attention_mask, min_dtype) + + audio_attention_mask = audio_attention_mask.contiguous() + + audio_hidden_states, audio_self_attn_weights, audio_present_key_value = self.audio_attn( + hidden_states=hidden_states, + attention_mask=audio_attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + audio_hidden_states = residual + audio_hidden_states + if use_cache: + residual = torch.where( + audio_out_mask_sq[:, -target_length:].unsqueeze(-1), audio_hidden_states, residual + ) + else: + residual = torch.where(audio_out_mask_sq.unsqueeze(-1), audio_hidden_states, residual) + audio_hidden_states = self.audio_post_audio_attn_layer_norm(audio_hidden_states) + if use_cache: + hidden_states = torch.where( + audio_out_mask_sq[:, -target_length:].unsqueeze(-1), audio_hidden_states, hidden_states + ) + else: + hidden_states = torch.where(audio_out_mask_sq.unsqueeze(-1), audio_hidden_states, hidden_states) + + # Text Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Apply Dual-path FFN + residual = hidden_states + + if has_audio_out and not self.fast_forward: + if use_cache: + real_audio_out_mask = audio_out_mask_sq[:, -target_length:] + else: + real_audio_out_mask = audio_out_mask_sq + + # Make whole graph in decode stage + if decode_stage and is_using_cuda_graph: + assert is_decoding_audio_token is not None, ( + "is_decoding_audio_token should be present in the decoding stage." + ) + if is_decoding_audio_token: + hidden_states = self.audio_post_attention_layernorm(hidden_states) + hidden_states = self.audio_mlp(hidden_states) + else: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + residual = residual + hidden_states + else: + text_hidden_states = self.post_attention_layernorm(hidden_states[~real_audio_out_mask]) + audio_hidden_states = self.audio_post_attention_layernorm(hidden_states[real_audio_out_mask]) + + text_hidden_states = self.mlp(text_hidden_states) + residual[~real_audio_out_mask] += text_hidden_states + + audio_hidden_states = self.audio_mlp(audio_hidden_states) + residual[real_audio_out_mask] += audio_hidden_states + + hidden_states = residual + else: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + if self.fast_forward and has_audio_out: + if use_cache: + hidden_states = torch.where( + audio_out_mask_sq[:, -target_length:].unsqueeze(-1), original_hidden_states, hidden_states + ) + else: + hidden_states = torch.where(audio_out_mask_sq.unsqueeze(-1), original_hidden_states, hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + if self.use_audio_attention: + # The returned attn weights have shape (batch_size, num_heads + num_audio_attn_heads, seq_length, seq_length) + outputs += (torch.concat([self_attn_weights, audio_self_attn_weights], dim=1),) + else: + # The returned attn weights have shape (batch_size, num_heads, seq_length, seq_length) + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +@dataclass +class HiggsAudioModelOutputWithPast(ModelOutput): + loss: Optional[torch.FloatTensor] = None + llm_loss: Optional[torch.FloatTensor] = None + audio_loss: Optional[torch.FloatTensor] = None + codebook_losses: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + expanded_input_ids: Optional[torch.LongTensor] = None + expanded_labels: Optional[torch.LongTensor] = None + audio_in_mask: Optional[torch.BoolTensor] = None + audio_in_discrete_codes_mask: Optional[torch.BoolTensor] = None + audio_out_mask: Optional[torch.BoolTensor] = None + attention_mask: Optional[torch.BoolTensor] = None + audio_logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + audio_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class HiggsAudioGenerationOutput(ModelOutput): + """ + Outputs of HiggsAudio generation models, when using non-beam methods. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + audio_sequences (`tuple(torch.LongTensor)` *optional*): + The generated discrete audio codes. These codes can be used to fill-in related locations of <|AUDIO_OUT|> at input sequences. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token). + If the generated token is a text token, the tensor will have shape `(batch_size, config.vocab_size)`. + If the generated token is an audio token, the tensor will have shape `(config.audio_num_codebooks, self.audio_codebook_size)` + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): + Unprocessed prediction scores of the language modeling head or the audio head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token). + If the generated token is a text token, the tensor will have shape `(batch_size, config.vocab_size)`. + If the generated token is an audio token, the tensor will have shape `(config.audio_num_codebooks, self.audio_codebook_size)` + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. + """ + + sequences: torch.LongTensor = None + audio_sequences: Optional[List[torch.LongTensor]] = None + scores: Optional[Tuple[torch.FloatTensor]] = None + logits: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None + + +class HiggsAudioModel(HiggsAudioPreTrainedModel, GenerationMixin): + """Higgs-Audio is an end-to-end multimodal model with the capability to understand and generate text / audio. + + Consider the following example for mixed text/audio understanding / generation: + + - input_tokens: <|audio_bos|>[AUDIO]<|audio_eos|><|audio_bos|>[AUDIO]<|audio_eos|> + - input_tokens: <|audio_bos|>[AUDIO]<|audio_eos|><|audio_out_bos|>[AUDIO_OUT]<|audio_eos|> + + We will fill [AUDIO] with the audio features extracted by Whisper and fill [AUDIO_OUT] with the audio tokens. + + Consider the following example for mixed text/audio generation: + + text: <|audio_out_bos|> MASK MASK MASK MASK MASK <|audio_eos|> [text_token1] + audio: MASK <|audio_stream_bos|> [audio_token1] [audio_token2] [audio_token3] <|audio_stream_eos|> MASK MASK + token_type: 0 1 1 1 1 1 0 0 + + """ + + _supports_cache_class = True + _supports_static_cache = True + + def __init__(self, config: HiggsAudioConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.audio_in_token_idx = config.audio_in_token_idx + self.audio_out_token_idx = config.audio_out_token_idx + self.audio_out_bos_token_id = config.audio_out_bos_token_id if "audio_out_bos_token_id" in config else None + self.audio_eos_token_id = config.audio_eos_token_id if "audio_eos_token_id" in config else None + self.vocab_size = config.text_config.vocab_size + self.audio_num_codebooks = config.audio_num_codebooks + self.use_delay_pattern = config.use_delay_pattern + self.use_audio_out_embed_projector = config.use_audio_out_embed_projector + self.use_audio_out_self_attention = config.use_audio_out_self_attention + + self.embed_tokens = nn.Embedding(self.vocab_size, config.text_config.hidden_size, self.padding_idx) + + if config.audio_adapter_type == "dual_ffn": + layer_idx = 0 + layers = [] + for j in range(config.text_config.num_hidden_layers): + if j in config.audio_dual_ffn_layers: + layers.append( + HiggsAudioDualFFNDecoderLayer( + config, layer_idx, use_audio_attention=self.use_audio_out_self_attention + ) + ) + layer_idx += 2 if self.use_audio_out_self_attention else 1 + else: + layers.append(LlamaDecoderLayer(config.text_config, layer_idx)) + layer_idx += 1 + self.layers = nn.ModuleList(layers) + elif config.audio_adapter_type == "dual_ffn_fast_forward": + layer_idx = 0 + layers = [] + for j in range(config.text_config.num_hidden_layers): + if j in config.audio_dual_ffn_layers: + layers.append( + HiggsAudioDualFFNDecoderLayer( + config, + layer_idx, + fast_forward=False, + use_audio_attention=self.use_audio_out_self_attention, + ) + ) + layer_idx += 2 if self.use_audio_out_self_attention else 1 + else: + layers.append( + HiggsAudioDualFFNDecoderLayer(config, layer_idx, fast_forward=True, use_audio_attention=False) + ) + layer_idx += 1 + self.layers = nn.ModuleList(layers) + elif config.audio_adapter_type == "stack": + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer(config.text_config, layer_idx) + for layer_idx in range(config.text_config.num_hidden_layers) + ] + ) + layer_idx = config.text_config.num_hidden_layers + else: + raise NotImplementedError(f"Audio adapter type {config.audio_adapter_type} not implemented.") + + self.num_activation_checkpointing_layers = len(self.layers) + + self.decode_graph_runners = defaultdict(dict[bool, CUDAGraphRunner]) + self.norm = LlamaRMSNorm(config.text_config.hidden_size, eps=config.text_config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config.text_config) + + if not config.skip_audio_tower: + self.audio_tower = HiggsAudioEncoder(config.audio_encoder_config) + self.audio_encoder_proj = HiggsAudioFeatureProjector(config) + else: + self.audio_tower = None + self.audio_encoder_proj = None + self.audio_decoder_proj = HiggsAudioDecoderProjector(config, layer_idx=layer_idx) + self.audio_codebook_size = ( + config.audio_codebook_size + 2 + ) # We add 1 for the audio_stream_bos token and 1 for the audio_stream_eos token + + if config.use_audio_out_embed_projector: + self.audio_out_embed_projector = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=False + ) + + self.audio_codebook_embeddings = nn.Embedding( + config.audio_num_codebooks * self.audio_codebook_size, config.text_config.hidden_size + ) + + self.audio_codebook_weights = ( + torch.ones(config.audio_num_codebooks) / config.audio_num_codebooks + ) # default to equal weights + self.post_init() + + def set_num_activation_checkpointing_layers(self, num_layers): + self.num_activation_checkpointing_layers = num_layers + + def set_delay_pattern(self): + self.config.use_delay_pattern = True + self.use_delay_pattern = True + + def set_audio_special_tokens(self, tokenizer: AutoTokenizer): + self.audio_out_bos_token_id = tokenizer.convert_tokens_to_ids("<|audio_out_bos|>") + self.audio_eos_token_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>") + + def _embed_audio_ids(self, audio_ids): + """Embed the audio ids + + Args: + audio_ids: torch.LongTensor of shape (num_codebooks, audio_in_total_length) + + Returns: + audio_embed: torch.LongTensor of shape (audio_in_total_length, hidden_size) + """ + codebook_shift = ( + torch.arange(self.config.audio_num_codebooks, device=audio_ids.device) * self.audio_codebook_size + ) + audio_embed = self.audio_codebook_embeddings(audio_ids + codebook_shift.unsqueeze(-1)) + if self.config.audio_embed_avg: + audio_embed = torch.mean(audio_embed, dim=0) + else: + audio_embed = torch.sum(audio_embed, dim=0) + if self.use_audio_out_embed_projector: + audio_embed = self.audio_out_embed_projector(audio_embed) + return audio_embed + + def _apply_audio_tower(self, audio_features, audio_feature_attention_mask): + """Apply the audio tower to the audio features""" + + if audio_features.shape[0] == 0: + if torch.is_grad_enabled(): + # FIXME!!!!!!!! + # This is a hack to ensure that the forward+backward pass of audio_tower and audio_encoder_proj get triggered. + # The monkey patch won't overwrite the backward pass of nn.Module. + audio_outputs = _whisper_encoder_zero_shape_forward( + self.audio_tower, audio_features, attention_mask=None, check_seq_length=False + ) + selected_audio_feature = audio_outputs.last_hidden_state + audio_features_embed = self.audio_encoder_proj(selected_audio_feature) + audio_feat_out_lengths = None + return audio_features_embed, audio_feat_out_lengths + else: + return None, None + + audio_feat_lengths, audio_feat_out_lengths = self.audio_tower._get_feat_extract_output_lengths( + audio_feature_attention_mask.sum(-1) + ) + batch_size, _, max_mel_seq_len = audio_features.shape + max_seq_len = (max_mel_seq_len - 1) // 2 + 1 + # Create a sequence tensor of shape (batch_size, max_seq_len) + seq_range = ( + torch.arange(0, max_seq_len, dtype=audio_feat_lengths.dtype, device=audio_feat_lengths.device) + .unsqueeze(0) + .expand(batch_size, max_seq_len) + ) + lengths_expand = audio_feat_lengths.unsqueeze(1).expand(batch_size, max_seq_len) + # Create mask + padding_mask = seq_range < lengths_expand + + if self.config._attn_implementation != "flash_attention_2": + audio_attention_mask = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( + batch_size, 1, max_seq_len, max_seq_len + ) + else: + audio_attention_mask = padding_mask + + audio_outputs = self.audio_tower(audio_features, attention_mask=audio_attention_mask) + selected_audio_feature = audio_outputs.last_hidden_state + audio_features_embed = self.audio_encoder_proj(selected_audio_feature) + + return audio_features_embed, audio_feat_out_lengths + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + def _prepare_all_static_kv_cache_masks(self, hidden_states, attention_mask, audio_out_mask, past_key_values): + target_length = hidden_states.shape[1] + cur_pos = audio_out_mask.shape[1] + min_dtype = torch.finfo(hidden_states.dtype).min + assert len(attention_mask.shape) == 4, "Only support SDPA for now" + kv_cache_len = past_key_values.get_max_cache_shape() + audio_out_mask_padded = torch.nn.functional.pad(audio_out_mask, (0, kv_cache_len - cur_pos), value=True) + fast_forward_attention_mask = attention_mask.masked_fill( + audio_out_mask_padded[:, audio_out_mask.shape[1] - target_length : audio_out_mask.shape[1]].reshape( + audio_out_mask_padded.shape[0], 1, target_length, 1 + ) + | audio_out_mask_padded.reshape(audio_out_mask_padded.shape[0], 1, 1, audio_out_mask_padded.shape[1]), + min_dtype, + ) + + no_audio_out_mask = ~audio_out_mask + no_audio_out_mask = torch.nn.functional.pad( + no_audio_out_mask, (0, kv_cache_len - audio_out_mask.shape[1]), value=False + ) + no_audio_out_mask = no_audio_out_mask[ + :, audio_out_mask.shape[1] - target_length : audio_out_mask.shape[1] + ].reshape(audio_out_mask.shape[0], 1, target_length, 1) | no_audio_out_mask.reshape( + audio_out_mask.shape[0], 1, 1, kv_cache_len + ) + audio_attention_mask = attention_mask.masked_fill(no_audio_out_mask, min_dtype) + return fast_forward_attention_mask, audio_attention_mask + + def _forward_core( + self, + hidden_states: torch.Tensor, + causal_mask: torch.Tensor, + position_ids: torch.Tensor, + audio_discrete_codes_mask: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]], + use_cache: bool, + audio_attention_mask: torch.Tensor, + fast_forward_attention_mask: torch.Tensor, + output_attentions: bool, + output_hidden_states: bool, + is_decoding_audio_token: Optional[bool] = None, + is_using_cuda_graph: Optional[bool] = False, + ): + # create position embeddings to be shared across the decoder layers + # When past_key_values is passed in, we need to offset the position ids when calculating the position embeddings. + # Therefore, cache_position is used. + position_id_offset = cache_position[0] if use_cache else 0 + position_embeddings = self.rotary_emb(hidden_states, position_ids + position_id_offset) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + if isinstance(decoder_layer, HiggsAudioDualFFNDecoderLayer): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + audio_attention_mask=audio_attention_mask, + fast_forward_attention_mask=fast_forward_attention_mask, + position_ids=position_ids, + audio_out_mask=audio_discrete_codes_mask, + is_decoding_audio_token=is_decoding_audio_token, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + is_using_cuda_graph=is_using_cuda_graph, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + return hidden_states, all_hidden_states, all_self_attns + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + audio_features: Optional[torch.FloatTensor] = None, + audio_feature_attention_mask: Optional[torch.BoolTensor] = None, + audio_in_ids: Optional[torch.LongTensor] = None, + audio_in_ids_start: Optional[torch.LongTensor] = None, + audio_out_ids: Optional[torch.LongTensor] = None, + audio_out_ids_start: Optional[torch.LongTensor] = None, + audio_out_ids_start_group_loc: Optional[torch.LongTensor] = None, + label_ids: Optional[torch.LongTensor] = None, + label_audio_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_audio_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + cache_audio_discrete_codes_mask: Optional[torch.LongTensor] = None, + past_key_values_buckets: Optional[OrderedDict[int, Cache]] = None, + reward: Optional[torch.FloatTensor] = None, + ): + """Forward pass for the Higgs-Audio model. + + Args: + input_ids (:obj:`torch.LongTensor`): + The input ids of the prompt. It will have shape (bsz, seq_len). + When use_cache is enabled, the input_ids will have + shape (bsz, 1) for incremental decode or None + inputs_embeds: + Input embeddings. This flag won't be used. + attention_mask (:obj:`torch.LongTensor`): + The attention mask of the prompt. It will have shape (bsz, seq_len). + audio_features (:obj:`torch.FloatTensor`): + The audio features extracted by Whisper. It will have shape (num_audio_in, feature_dim, max_mel_seq_len). + audio_feature_attention_mask (:obj:`torch.LongTensor`): + The attention mask of the audio features. It will have shape (num_audio_in, max_mel_seq_len). + audio_in_ids (:obj:`torch.LongTensor`): + The discretized audio tokens. It will have shape (num_codebooks, audio_in_total_length). + audio_in_ids_start (:obj:`torch.LongTensor`): + The start indices for each audio in audio_in_ids. It will have shape (num_audio_in,) + audio_out_ids (:obj:`torch.LongTensor`): + The discretized audio tokens. It will have shape (num_codebooks, audio_out_total_length). + audio_out_ids_start (:obj:`torch.LongTensor`): + The start indices for each audio in audio_out_ids. It will have shape (num_audio_out,) + audio_out_ids_start_group_loc (:obj:`torch.LongTensor`): + The sample indices in a batch that map to each element in the audio_out_ids_start. It will have shape (num_audio_out,) + label_text_ids (:obj:`torch.LongTensor`): + The labels of the prompt. It will have shape (bsz, seq_len). + label_audio_ids (:obj:`torch.LongTensor`): + The labels of the audio tokens. It will have the same shape as audio_out_ids, i.e., (num_codebooks, audio_out_total_length) + past_key_values (:obj:`Tuple`): + Tuple of past key values. + use_cache (:obj:`bool`): + Whether to use cache. + output_attentions (:obj:`bool`): + Whether to output attentions. + output_hidden_states (:obj:`bool`): + Whether to output hidden states. + output_audio_hidden_states (:obj:`bool`): + Whether to output audio hidden states. + return_dict (:obj:`bool`): + Whether to return a dictionary. + cache_position (:obj:`torch.LongTensor`): + The position of the cache. + cache_audio_discrete_codes_mask (:obj:`torch.LongTensor`): + The cached audio discrete codes mask. It will only be used when use_cache is turned on. + past_key_values_buckets (:obj:`OrderedDict`): + The buckets of past key values. + """ + target_device = input_ids.device + + # not used + del inputs_embeds + + if audio_features is not None: + audio_features = audio_features.to(target_device) + audio_feature_attention_mask = audio_feature_attention_mask.to(target_device) + + # 1. Extract the input embeddings + inputs_embeds = self.embed_tokens(input_ids) + + # 2. Extract audio embeddings + if self.config.skip_audio_tower: + audio_features_embed = audio_features_length = None + else: + audio_features_embed, audio_features_length = self._apply_audio_tower( + audio_features, audio_feature_attention_mask + ) + + if self.config.encode_audio_in_tokens: + if audio_in_ids is not None and audio_in_ids.shape[-1] > 0: + audio_in_ids = audio_in_ids.to(target_device) + else: + audio_in_ids = torch.zeros((self.audio_num_codebooks, 0), device=target_device, dtype=torch.long) + audio_in_embed = self._embed_audio_ids(audio_in_ids) + else: + audio_in_embed = None + + if audio_out_ids is not None and audio_out_ids.shape[-1] > 0: + audio_out_ids = audio_out_ids.to(target_device) + else: + audio_out_ids = torch.zeros((self.audio_num_codebooks, 0), device=target_device, dtype=torch.long) + audio_out_embed = self._embed_audio_ids(audio_out_ids) + + # 3. Merge text, audio-in embeddings, and audio-out embeddings + + # use_cache is turned on during inference time, we should set round_to to 1 to avoid extra padding in the end. + round_to = 1 if use_cache else 8 + left_padding = True if use_cache or input_ids.shape[0] == 1 else False + ( + inputs_embeds, + attention_mask, + labels, + position_ids, + input_ids, + audio_in_mask, + audio_in_discrete_codes_mask, + audio_out_mask, + ) = merge_input_ids_with_audio_features( + audio_features_embed, + audio_features_length, + audio_in_embed, + audio_in_ids_start, + audio_out_embed, + audio_out_ids_start, + self.audio_in_token_idx, + self.audio_out_token_idx, + inputs_embeds, + input_ids, + attention_mask, + label_ids, + pad_token_id=self.padding_idx, + round_to=round_to, + left_padding=left_padding, + ) + + # re-check if we use the correct kv cache bucket after + # the input_embeds has been merged with audio features + if past_key_values_buckets is not None and inputs_embeds.shape[1] > past_key_values.get_max_cache_shape(): + past_key_values, self.current_past_key_values_bucket = self._prepare_kv_cache( + inputs_embeds.shape[1], None, past_key_values_buckets + ) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if isinstance(past_key_values, StaticCache) and past_seen_tokens >= past_key_values.get_max_cache_shape(): + raise ValueError( + f"The current sequence length ({past_seen_tokens}) exceeds " + f"the maximum cache shape. " + f"Please consider increasing the cache size." + ) + + # Use torch compile + use_static_cache = isinstance(past_key_values, StaticCache) + + # Apply the LLM component + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + audio_discrete_codes_mask = audio_in_discrete_codes_mask | audio_out_mask + if cache_audio_discrete_codes_mask is not None and use_cache: + audio_discrete_codes_mask = torch.concat( + [cache_audio_discrete_codes_mask, audio_discrete_codes_mask], dim=1 + ) + + # Generate the audio attention mask outside the layer to avoid recompilation + if use_static_cache: + fast_forward_attention_mask, audio_attention_mask = self._prepare_all_static_kv_cache_masks( + hidden_states, causal_mask, audio_discrete_codes_mask, past_key_values + ) + # Set the audio out mask to the last token + if hidden_states.shape[1] == 1: + audio_discrete_codes_mask = audio_discrete_codes_mask[:, -1:] + audio_discrete_codes_mask = audio_discrete_codes_mask.reshape((-1, 1)).contiguous() + is_decoding_audio_token = audio_discrete_codes_mask.item() + else: + is_decoding_audio_token = False + + # Use the captured cuda graph runner for decoding + # if it exists, otherwise use the normal forward pass + if ( + past_key_values is not None + and past_key_values.get_max_cache_shape() in self.decode_graph_runners + and (input_ids.shape[-1] == 1) + ): + _forward_core = self.decode_graph_runners[past_key_values.get_max_cache_shape()][is_decoding_audio_token] + is_using_cuda_graph = True + else: + _forward_core = self._forward_core + is_using_cuda_graph = False + + hidden_states, all_hidden_states, all_self_attns = _forward_core( + hidden_states=hidden_states, + causal_mask=causal_mask, + position_ids=position_ids, + audio_discrete_codes_mask=audio_discrete_codes_mask, + is_decoding_audio_token=is_decoding_audio_token if use_static_cache else None, + cache_position=cache_position, + past_key_values=past_key_values, + use_cache=use_cache, + audio_attention_mask=audio_attention_mask if use_static_cache else None, + fast_forward_attention_mask=fast_forward_attention_mask if use_static_cache else None, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + is_using_cuda_graph=is_using_cuda_graph, + ) + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # Apply the audio decoder projector + logits, audio_logits, decoder_all_self_attns, decoder_all_hidden_states, audio_hidden_states, _ = ( + self.audio_decoder_proj( + hidden_states, + audio_out_mask, + label_audio_ids=label_audio_ids, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_audio_hidden_states=output_audio_hidden_states, + cache_position=cache_position, + ) + ) + + if audio_logits is not None: + audio_logits = audio_logits.view( + audio_logits.shape[0], self.audio_num_codebooks, self.audio_codebook_size + ).float() + + if output_hidden_states: + if decoder_all_hidden_states is not None and len(decoder_all_hidden_states) > 1: + all_hidden_states += decoder_all_hidden_states[1:] + + if output_attentions: + all_self_attns += decoder_all_self_attns + + next_cache = past_key_values if use_cache else None + + ret = HiggsAudioModelOutputWithPast( + logits=logits, + audio_logits=audio_logits, + expanded_input_ids=input_ids, + expanded_labels=labels, + audio_in_mask=audio_in_mask, + audio_in_discrete_codes_mask=audio_in_discrete_codes_mask, + audio_out_mask=audio_out_mask, + attention_mask=attention_mask, + past_key_values=next_cache, + hidden_states=all_hidden_states, + audio_hidden_states=audio_hidden_states, + attentions=all_self_attns, + ) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if not return_dict: + outputs = ret.to_tuple() + return outputs + + return ret + + # Overwrite GenerationMixin._update_model_kwargs_for_generation + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + extend_attention_mask: bool = True, + ) -> Dict[str, Any]: + """Update the model kwargs for each step.""" + model_kwargs["past_key_values"] = outputs.past_key_values + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + if extend_attention_mask: + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + if "cache_audio_discrete_codes_mask" in model_kwargs: + if model_kwargs["cache_audio_discrete_codes_mask"] is None: + model_kwargs["cache_audio_discrete_codes_mask"] = ( + outputs.audio_in_discrete_codes_mask | outputs.audio_out_mask + ) + else: + model_kwargs["cache_audio_discrete_codes_mask"] = torch.concat( + [ + model_kwargs["cache_audio_discrete_codes_mask"], + outputs.audio_in_discrete_codes_mask | outputs.audio_out_mask, + ], + 1, + ) + + return model_kwargs + + def _copy_kv_cache(self, from_cache: Cache, to_cache: Cache): + num_layers = self.config.text_config.num_hidden_layers + if self.config.audio_dual_ffn_layers is not None: + num_layers += len(self.config.audio_dual_ffn_layers) + """ Copy the key-value pairs from one cache to another. """ + for layer_idx in range(num_layers): + from_cache_size = from_cache.get_max_cache_shape() + assert to_cache.get_max_cache_shape() >= from_cache_size, ( + f"The target cache size {to_cache.get_max_cache_shape()} is smaller than the source cache size {from_cache_size}." + ) + to_cache.key_cache[layer_idx][:, :, :from_cache_size, :] = from_cache.key_cache[layer_idx] + to_cache.value_cache[layer_idx][:, :, :from_cache_size, :] = from_cache.value_cache[layer_idx] + + def _prepare_kv_cache( + self, + current_sequence_length: int, + current_past_key_values_bucket: Optional[int], + past_key_values_buckets: OrderedDict[int, Cache], + ) -> Tuple[Optional[Cache], Optional[int]]: + """Prepare the KV cache for the current sequence length.""" + for cache_length in past_key_values_buckets.keys(): + if cache_length >= current_sequence_length: + # Promote to the next KV cache bucket, copy the current KV cache bucket + # to the new one. + if current_past_key_values_bucket is not None and cache_length != current_past_key_values_bucket: + self._copy_kv_cache( + past_key_values_buckets[current_past_key_values_bucket], past_key_values_buckets[cache_length] + ) + + return past_key_values_buckets[cache_length], cache_length + + raise ValueError( + f"The current sequence length {current_sequence_length} is larger than " + f"all past key values buckets {past_key_values_buckets.keys()}." + ) + + def _sample_audio_tokens( + self, + hidden_states: torch.Tensor, + audio_logits: torch.Tensor, + audio_out_ids: torch.Tensor, + do_sample: bool, + logits_processor: LogitsProcessorList, + device: torch.device, + torch_generator: Optional[torch.Generator], + generation_config: GenerationConfig, + num_delay: int, + num_remaining_delays: Optional[int], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[int]]: + """Sample audio tokens and its corresponding text tokens from the logits""" + + # parameters related to repetition aware sampling + ras_win_len = generation_config.generation_kwargs.get("ras_win_len", None) + ras_win_max_num_repeat = generation_config.generation_kwargs.get("ras_win_max_num_repeat", 2) + audio_eos_token_id = generation_config.generation_kwargs.get("audio_eos_token_id", None) + # In the audio generation mode, we sample from audio_logits and keep updating audio_out_ids. + next_audio_token_logits = audio_logits.clone()[-1, :, :].float().to(device) + # TopP, TopK logits processor supports empty input_ids + next_audio_token_scores = logits_processor(None, next_audio_token_logits) + + # token selection + if do_sample: + # next_audio_token_scores has been applied top_p, top_k, and temperature. + probs = nn.functional.softmax(next_audio_token_scores, dim=-1) + # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution + next_audio_tokens = torch.multinomial(probs, num_samples=1, generator=torch_generator).squeeze(1) + else: + next_audio_tokens = torch.argmax(next_audio_token_scores, dim=-1) + + # next_tokens: (num_codebooks, ) + if ras_win_len is not None: + # check if there are repetitions over a window of tokens. + rep_num = (audio_out_ids[:, -ras_win_len:] == next_audio_tokens.unsqueeze(1)).sum(dim=1) + + # if we saw repeated tokens in the most recent window of tokens, resample without temperature. + row_indices = torch.nonzero(rep_num >= ras_win_max_num_repeat).squeeze(1) + resampled_next_tokens = ( + next_audio_token_logits[row_indices] + .softmax(dim=-1) + .multinomial(1, replacement=True, generator=torch_generator) + .squeeze(1) + ) + next_audio_tokens[row_indices] = resampled_next_tokens + + # Force the next text tokens to be <|AUDIO_OUT|> in audio generation mode + next_tokens = torch.full( + (audio_logits.shape[0],), + self.config.audio_out_token_idx, + dtype=torch.long, + device=device, + ) + + # Handle delay_pattern + if self.use_delay_pattern: + if num_delay + 1 < next_audio_tokens.shape[0]: + next_audio_tokens[(num_delay + 1) :] = self.config.audio_stream_bos_id + num_delay += 1 + if num_remaining_delays is not None: + next_audio_tokens[: (self.audio_num_codebooks - num_remaining_delays)] = ( + self.config.audio_stream_eos_id + ) + num_remaining_delays -= 1 + else: + all_eos_indices = (next_audio_tokens == self.config.audio_stream_eos_id).nonzero() + if torch.numel(all_eos_indices) > 0: + all_eos_indices = all_eos_indices[0] + last_eos_idx = all_eos_indices[-1] + next_audio_tokens[:last_eos_idx] = self.config.audio_stream_eos_id + num_remaining_delays = self.audio_num_codebooks - last_eos_idx - 1 + if num_remaining_delays is not None and num_remaining_delays <= 0: + next_tokens[...] = audio_eos_token_id + num_delay = 0 + num_remaining_delays = None + + return ( + next_tokens, + next_audio_tokens, + next_audio_token_logits, + next_audio_token_scores, + num_delay, + num_remaining_delays, + ) + + def _sample_text_tokens( + self, + logits: torch.Tensor, + input_ids: torch.Tensor, + do_sample: bool, + logits_processor: LogitsProcessorList, + device: torch.device, + generation_mode: GenerationMode, + torch_generator: Optional[torch.Generator], + ) -> torch.Tensor: + """Sample text tokens from the logits""" + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_logits = logits.clone()[:, -1, :].float() + next_token_logits = next_token_logits.to(input_ids.device) + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + + if generation_mode == GenerationMode.AUDIO_INIT: + # See the audio bos token, we should start generating audio tokens + next_tokens = torch.full( + (input_ids.shape[0],), + self.audio_out_token_idx, + dtype=torch.long, + device=device, + ) + next_audio_tokens = torch.full( + (self.config.audio_num_codebooks,), + self.config.audio_stream_bos_id, + dtype=torch.long, + device=device, + ) + else: + if do_sample: + probs = nn.functional.softmax(next_token_scores, dim=-1) + # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution + next_tokens = torch.multinomial(probs, num_samples=1, generator=torch_generator).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + next_audio_tokens = None + + return next_tokens, next_audio_tokens, next_token_logits, next_token_scores + + # Built on top of GenerationMixin._sample. + # We revise the implementation to support generating both audio / text. + def _sample( + self, + input_ids: torch.LongTensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + past_key_values_buckets: Optional[OrderedDict[int, Cache]], + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for joint text/audio models using **multinomial sampling**. + + This function may also be revised to support generating samples from HiggsAudio-like end-to-end text/audio models built on top of LLMs. + If the input_ids ends with <|audio_out_bos|>, we will switch to the audio-generation mode. + + ``` + ...<|start_header_id|>assistant<|end_header_id|>\n\n<|audio_out_bos|> + ``` + + Otherwise, we will keep generating the text tokens. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + """ + assert input_ids.shape[0] == 1, "Only support batch_size=1 in _sample()" + audio_out_bos_token_id = generation_config.generation_kwargs.get("audio_out_bos_token_id", None) + + # torch generator for sampling + seed = generation_config.generation_kwargs.get("seed", None) + if seed is not None: + torch_generator = torch.Generator(device=input_ids.device).manual_seed(seed) + else: + torch_generator = None + + # init values + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + max_length = generation_config.max_length + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + do_sample = generation_config.do_sample + # Used to track which past_key_va + self.current_past_key_values_bucket = None + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # keep track of which sequences are already finished + batch_size, cur_len = input_ids.shape + this_peer_finished = False + unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) + if generation_config.use_cache: + model_kwargs["cache_audio_discrete_codes_mask"] = None + + init_model_input = True + num_delay = 0 + num_remaining_delays = None + audio_sequences = [] + # A tensor to keep track of all the audio placeholder tokens. + input_ids_full = input_ids.clone() + + # Initialize the audio variables based on the input prompt. + if input_ids[0][-1] == self.config.audio_out_token_idx: + audio_sequences = [model_kwargs["audio_out_ids"][:, model_kwargs["audio_out_ids_start"][-1] :]] + if self.use_delay_pattern: + num_delay = ( + self.audio_num_codebooks + - (model_kwargs["audio_out_ids"][:, -1] == self.config.audio_stream_bos_id).sum() + ) + all_eos_indices = (model_kwargs["audio_out_ids"][:, -1] == self.config.audio_stream_eos_id).nonzero() + if torch.numel(all_eos_indices) > 0: + all_eos_indices = all_eos_indices[0] + last_eos_idx = all_eos_indices[-1] + num_remaining_delays = self.audio_num_codebooks - last_eos_idx - 1 + + while self._has_unfinished_sequences( + this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length + ): + # Check which multimodal stage we are in + # FIXME: Assume single input generation + if input_ids[0][-1] == audio_out_bos_token_id: + generation_mode = GenerationMode.AUDIO_INIT + elif input_ids[0][-1] == self.audio_out_token_idx: + generation_mode = GenerationMode.AUDIO_IN_PROGRESS + else: + generation_mode = GenerationMode.TEXT + + is_audio_generation_mode = generation_mode == GenerationMode.AUDIO_IN_PROGRESS + + if init_model_input or not generation_config.use_cache: + model_inputs = {"input_ids": input_ids, **model_kwargs} + else: + model_inputs = {"input_ids": input_ids[:, -1:], **model_kwargs} + + if is_audio_generation_mode and generation_config.use_cache: + model_inputs["audio_out_ids"] = model_kwargs["audio_out_ids"][:, -1:] + model_inputs["audio_out_ids_start"] = torch.tensor([0], dtype=torch.long, device=input_ids.device) + elif not is_audio_generation_mode: + del model_inputs["audio_out_ids"] + del model_inputs["audio_out_ids_start"] + + if generation_config.use_cache: + if "audio_features" in model_inputs and model_inputs["audio_features"] is not None: + model_inputs["audio_features"] = model_inputs["audio_features"][:0, ...] + model_inputs["audio_feature_attention_mask"] = model_inputs["audio_feature_attention_mask"][ + :0, ... + ] + + if "audio_in_ids" in model_inputs and model_inputs["audio_in_ids"] is not None: + model_inputs["audio_in_ids"] = None + model_inputs["audio_in_ids_start"] = None + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + if past_key_values_buckets is not None: + past_key_values, self.current_past_key_values_bucket = self._prepare_kv_cache( + cur_len, self.current_past_key_values_bucket, past_key_values_buckets + ) + if past_key_values is not None: + model_inputs.update({"past_key_values": past_key_values}) + model_inputs["past_key_values_buckets"] = past_key_values_buckets + + # forward pass to get next token + outputs = self(**model_inputs, return_dict=True) + + # Update the actual sequence length after the first forward pass + if init_model_input and past_key_values_buckets is not None: + cur_len = past_key_values_buckets[self.current_past_key_values_bucket].get_seq_length().item() + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + extend_attention_mask=True, + ) + + # After the first forward pass, we can set init_model_input to False. + init_model_input = False + + if synced_gpus and this_peer_finished: + continue + + if is_audio_generation_mode: + # In audio generation mode, we sample the audio tokens from audio logits. + # It might also generate the audio eos token to end the audio generation. + ( + next_tokens, + next_audio_tokens, + next_audio_token_logits, + next_audio_token_scores, + num_delay, + num_remaining_delays, + ) = self._sample_audio_tokens( + hidden_states=outputs.audio_hidden_states, + audio_logits=outputs.audio_logits, + audio_out_ids=model_kwargs["audio_out_ids"], + do_sample=do_sample, + logits_processor=logits_processor, + device=input_ids.device, + torch_generator=torch_generator, + generation_config=generation_config, + num_delay=num_delay, + num_remaining_delays=num_remaining_delays, + ) + + # update generated ids, model inputs, and length for next step + model_kwargs["audio_out_ids"] = torch.cat( + [model_kwargs["audio_out_ids"], next_audio_tokens[:, None]], dim=-1 + ) + audio_sequences[-1] = torch.cat([audio_sequences[-1], next_audio_tokens[:, None]], dim=-1) + + if streamer is not None: + streamer.put(next_audio_tokens.cpu()) + else: + # In text generation mode, we sample the text tokens from text logits. + # It might also generate the audio placeholder token to start the audio generation. + next_tokens, next_audio_tokens, next_token_logits, next_token_scores = self._sample_text_tokens( + input_ids=input_ids, + logits=outputs.logits, + do_sample=do_sample, + logits_processor=logits_processor, + device=input_ids.device, + generation_mode=generation_mode, + torch_generator=torch_generator, + ) + + if streamer is not None: + streamer.put(next_tokens.cpu()) + + if next_audio_tokens is not None: + # If the token is audio bos token, we will generate the audio placeholder token + # and the corrensponding audio stream bos token to start the audio generation. + audio_sequences.append(next_audio_tokens[:, None]) + if streamer is not None: + streamer.put(next_audio_tokens.cpu()) + if model_kwargs["audio_out_ids"] is None or model_kwargs["audio_out_ids"].shape[0] == 0: + # Initialize audio_out_ids + model_kwargs["audio_out_ids"] = next_audio_tokens[:, None] + model_kwargs["audio_out_ids_start"] = torch.tensor( + [0], dtype=torch.long, device=input_ids.device + ) + else: + model_kwargs["audio_out_ids_start"] = torch.concat( + [ + model_kwargs["audio_out_ids_start"], + torch.tensor( + [model_kwargs["audio_out_ids"].shape[1]], dtype=torch.long, device=input_ids.device + ), + ], + dim=0, + ) + model_kwargs["audio_out_ids"] = torch.concat( + [model_kwargs["audio_out_ids"], next_audio_tokens[:, None]], dim=1 + ) + + if return_dict_in_generate: + if output_scores: + if is_audio_generation_mode: + scores += (next_audio_token_scores,) + else: + scores += (next_token_scores,) + if output_logits: + if is_audio_generation_mode: + raw_logits += (next_audio_token_logits,) + else: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += (outputs.attentions,) + if output_hidden_states: + decoder_hidden_states += (outputs.hidden_states,) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + if "tokenizer_length" in generation_config.generation_kwargs: + tokenizer_length = generation_config.generation_kwargs["tokenizer_length"] + if torch.max(next_tokens) >= tokenizer_length: + raise ValueError( + f"Next generated token has max value {torch.max(next_tokens)} which is greater than the tokenizer's vocabulary size {tokenizer_length}, this is undesired behavior." + ) + + # update generated ids, model inputs, and length for next step + if not is_audio_generation_mode or next_tokens[0] != self.audio_out_token_idx: + # We only add one <|AUDIO_OUT|> token to the input_ids for simplicity. + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + input_ids_full = torch.cat([input_ids_full, next_tokens[:, None]], dim=-1) + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids_full, scores) + this_peer_finished = unfinished_sequences.max() == 0 + cur_len += 1 + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del outputs + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + return HiggsAudioGenerationOutput( + sequences=input_ids, + audio_sequences=audio_sequences, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids, audio_sequences + + @torch.inference_mode() + def generate( + self, + input_ids: Optional[torch.LongTensor] = None, + audio_features: Optional[torch.FloatTensor] = None, + audio_feature_attention_mask: Optional[torch.BoolTensor] = None, + audio_in_ids: Optional[torch.LongTensor] = None, + audio_in_ids_start: Optional[torch.LongTensor] = None, + audio_out_ids: Optional[torch.LongTensor] = None, + audio_out_ids_start: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + audio_out_bos_token_id: int = None, + audio_eos_token_id: int = None, + past_key_values_buckets: Optional[OrderedDict[int, Cache]] = None, + seed: Optional[int] = None, + **kwargs, + ): + """ + The generate function in huggingface generally follows these steps: + + for sample_step in 1, 2, 3, 4, 5, ... + ... + + """ + # Right now, it's a very simplified version of generate, we should revisit this after our model architecture stabilizes. + assert input_ids.shape[0] == 1, ( + "Currently HiggsAudioModel.generate() only supports batch_size=1. See the implementation of " + ) + generation_config, kwargs = self._prepare_generation_config(kwargs.pop("generation_config", None), **kwargs) + if audio_out_bos_token_id is not None: + generation_config.generation_kwargs["audio_out_bos_token_id"] = audio_out_bos_token_id + else: + try: + generation_config.generation_kwargs["audio_out_bos_token_id"] = self.audio_out_bos_token_id + except: + generation_config.generation_kwargs["audio_out_bos_token_id"] = None + + if audio_eos_token_id is not None: + generation_config.generation_kwargs["audio_eos_token_id"] = audio_eos_token_id + else: + try: + generation_config.generation_kwargs["audio_eos_token_id"] = self.audio_eos_token_id + except: + generation_config.generation_kwargs["audio_eos_token_id"] = None + + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + + generation_config.generation_kwargs["ras_win_len"] = kwargs.pop("ras_win_len", None) + generation_config.generation_kwargs["ras_win_max_num_repeat"] = kwargs.pop("ras_win_max_num_repeat", 2) + # Set generation seed if determinstic generation is required + if seed is not None: + generation_config.generation_kwargs["seed"] = seed + + # Store tokenizer in generation config if it is in kwargs without popping it + if "tokenizer" in kwargs: + generation_config.generation_kwargs["tokenizer_length"] = len(kwargs["tokenizer"]) + + # input_ids: [bsz, seq_len] + # The merging of audio features happens inside the forward path. The input_ids does not need to change. + # TODO: prepare the final input embeddings to improve generation performance + input_ids_length = input_ids.shape[-1] + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=None, + inputs_tensor=None, + input_ids_length=input_ids_length, + ) + assert generation_config.num_beams == 1, "Currently, we only support beam search with num_beams=1" + return_dict_in_generate = generation_config.return_dict_in_generate + output_scores = generation_config.output_scores + + # When attn_implement is spda or flash-attention, it will create causal mask automatically. + attention_mask = kwargs.pop("attention_mask", None) + return super().generate( + input_ids=input_ids, + attention_mask=attention_mask, + audio_features=audio_features, + audio_feature_attention_mask=audio_feature_attention_mask, + audio_in_ids=audio_in_ids, + audio_in_ids_start=audio_in_ids_start, + audio_out_ids=audio_out_ids, + audio_out_ids_start=audio_out_ids_start, + past_key_values=past_key_values, + generation_config=generation_config, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + past_key_values_buckets=past_key_values_buckets, + **kwargs, + ) + + def parameter_count_per_component(self): + """Count the number of parameters per component in the model. + + HiggsAudio has the following main components: + audio_tower: For mapping audio features to hidden states), + llm_embed: The size of embedding layer of the LLM + llm_non_embed: The size of non-embedding layer of the LLM + audio_adapter: The overall size of additional layers for audio generation + + """ + trainable_stats = { + "audio_tower": 0, + "llm_embed": 0, + "llm_non_embed": 0, + "audio_embed": 0, + "audio_adapter": 0, + "overall": 0, + } + total_stats = { + "audio_tower": 0, + "llm_embed": 0, + "llm_non_embed": 0, + "audio_embed": 0, + "audio_adapter": 0, + "overall": 0, + } + + total_stats["overall"] = count_parameters(self, trainable_only=False) + trainable_stats["overall"] = count_parameters(self, trainable_only=True) + + for mod in [self.audio_tower]: + if mod is not None: + total_stats["audio_tower"] += count_parameters(mod, trainable_only=False) + trainable_stats["audio_tower"] += count_parameters(mod, trainable_only=True) + + total_stats["llm_embed"] = count_parameters(self.embed_tokens, trainable_only=False) + trainable_stats["llm_embed"] = count_parameters(self.embed_tokens, trainable_only=True) + + total_stats["audio_embed"] = count_parameters(self.audio_codebook_embeddings, trainable_only=False) + trainable_stats["audio_embed"] = count_parameters(self.audio_codebook_embeddings, trainable_only=True) + + # Calculate number of parameters for LLM + for layer in self.layers: + if isinstance(layer, HiggsAudioDualFFNDecoderLayer): + total_param_count = count_parameters(layer, trainable_only=False) + total_trainable_param_count = count_parameters(layer, trainable_only=True) + total_stats["llm_non_embed"] += total_param_count + trainable_stats["llm_non_embed"] += total_trainable_param_count + if not layer.fast_forward: + audio_mlp_param_count = count_parameters(layer.audio_mlp, trainable_only=False) + audio_mlp_trainable_param_count = count_parameters(layer.audio_mlp, trainable_only=True) + + audio_norm_param_count = count_parameters( + layer.audio_post_attention_layernorm, trainable_only=False + ) + count_parameters(layer.audio_input_layernorm, trainable_only=False) + audio_norm_trainable_param_count = count_parameters( + layer.audio_post_attention_layernorm, trainable_only=True + ) + count_parameters(layer.audio_input_layernorm, trainable_only=True) + total_stats["llm_non_embed"] -= audio_mlp_param_count + audio_norm_param_count + trainable_stats["llm_non_embed"] -= ( + audio_mlp_trainable_param_count + audio_norm_trainable_param_count + ) + total_stats["audio_adapter"] += audio_mlp_param_count + audio_norm_param_count + trainable_stats["audio_adapter"] += ( + audio_mlp_trainable_param_count + audio_norm_trainable_param_count + ) + + if layer.use_audio_attention: + audio_attn_param_count = count_parameters( + layer.audio_attn, trainable_only=False + ) + count_parameters(layer.audio_post_audio_attn_layer_norm, trainable_only=False) + audio_attn_trainable_param_count = count_parameters( + layer.audio_attn, trainable_only=True + ) + count_parameters(layer.audio_post_audio_attn_layer_norm, trainable_only=True) + total_stats["llm_non_embed"] -= audio_attn_param_count + trainable_stats["llm_non_embed"] -= audio_attn_trainable_param_count + total_stats["audio_adapter"] += audio_attn_param_count + trainable_stats["audio_adapter"] += audio_attn_trainable_param_count + else: + total_stats["llm_non_embed"] += count_parameters(layer, trainable_only=False) + trainable_stats["llm_non_embed"] += count_parameters(layer, trainable_only=True) + total_stats["llm_non_embed"] += count_parameters(self.norm, trainable_only=False) + trainable_stats["llm_non_embed"] += count_parameters(self.norm, trainable_only=True) + + total_stats["audio_adapter"] += count_parameters(self.audio_decoder_proj.audio_lm_head, trainable_only=False) + trainable_stats["audio_adapter"] += count_parameters( + self.audio_decoder_proj.audio_lm_head, trainable_only=True + ) + total_stats["llm_embed"] += count_parameters(self.audio_decoder_proj.text_lm_head, trainable_only=False) + trainable_stats["llm_embed"] += count_parameters(self.audio_decoder_proj.text_lm_head, trainable_only=True) + + other_audio_modules = [self.audio_encoder_proj] + if self.use_audio_out_embed_projector: + other_audio_modules.append(self.audio_out_embed_projector) + + for mod in other_audio_modules: + if mod is not None: + total_stats["audio_adapter"] += count_parameters(mod, trainable_only=False) + trainable_stats["audio_adapter"] += count_parameters(mod, trainable_only=True) + return {"trainable": trainable_stats, "total": total_stats} + + def set_skip_audio_tower(self): + self.config.skip_audio_tower = True + self.config.encode_whisper_embed = False + + def set_encode_audio_in_tokens(self): + self.config.encode_audio_in_tokens = True + + def freeze_audio_tower(self): + if self.audio_tower is not None: + for param in self.audio_tower.parameters(): + param.requires_grad = False + + def freeze_audio_encoder_proj(self): + if self.audio_encoder_proj is not None: + for param in self.audio_encoder_proj.parameters(): + param.requires_grad = False + + def freeze_llm(self, freeze_embed=True, freeze_embed_until_idx: Optional[int] = None): + for layer in self.layers: + if isinstance(layer, HiggsAudioDualFFNDecoderLayer): + for param in layer.self_attn.parameters(): + param.requires_grad = False + for param in layer.mlp.parameters(): + param.requires_grad = False + + for param in layer.post_attention_layernorm.parameters(): + param.requires_grad = False + + for param in layer.input_layernorm.parameters(): + param.requires_grad = False + else: + for param in layer.parameters(): + param.requires_grad = False + + for param in self.norm.parameters(): + param.requires_grad = False + + if freeze_embed: + if freeze_embed_until_idx is None: + for param in self.embed_tokens.parameters(): + param.requires_grad = False + else: + assert isinstance(self.embed_tokens, nn.Embedding) + self.embed_tokens = PartiallyFrozenEmbedding( + original_embedding=self.embed_tokens, freeze_until_idx=freeze_embed_until_idx + ) + + def freeze_text_head(self, freeze_text_head_until_idx: Optional[int] = None): + """Freeze the final text head""" + if freeze_text_head_until_idx is None: + for param in self.audio_decoder_proj.text_lm_head.parameters(): + param.requires_grad = False + + else: + assert isinstance(self.audio_decoder_proj.text_lm_head, nn.Linear) + self.audio_decoder_proj.text_lm_head = PartiallyFrozenLinear( + original_linear=self.audio_decoder_proj.text_lm_head, freeze_until_idx=freeze_text_head_until_idx + ) + + @classmethod + def merge_weights_from_checkpoint(cls, checkpoint_dir: str, merged_output_dir: str, *model_args, **kwargs): + # For users' convenience, we merge back embedding and text_lm_head if they are splitted + splitted_model = super().from_pretrained( + checkpoint_dir, + *model_args, + torch_dtype=torch.bfloat16, + device_map="cpu", + **{**kwargs, "state_dict": None}, # Prevent auto-loading state_dict + ) + + # Load all safetensor shards + state_dict = {} + shard_paths = sorted(glob.glob(os.path.join(checkpoint_dir, "*.safetensors"))) + + for shard_path in shard_paths: + shard_dict = load_file(shard_path) # Load each shard + state_dict.update(shard_dict) # Merge into a single dict + + # Merge weights + if ( + "audio_decoder_proj.text_lm_head.linear_frozen.weight" in state_dict + and "audio_decoder_proj.text_lm_head.linear_trainable.weight" in state_dict + ): + state_dict["audio_decoder_proj.text_lm_head.weight"] = torch.cat( + [ + state_dict["audio_decoder_proj.text_lm_head.linear_frozen.weight"], + state_dict["audio_decoder_proj.text_lm_head.linear_trainable.weight"], + ], + dim=0, + ) + + del state_dict["audio_decoder_proj.text_lm_head.linear_frozen.weight"] + del state_dict["audio_decoder_proj.text_lm_head.linear_trainable.weight"] + + if ( + "embed_tokens.embedding_frozen.weight" in state_dict + and "embed_tokens.embedding_trainable.weight" in state_dict + ): + state_dict["embed_tokens.weight"] = torch.cat( + [ + state_dict["embed_tokens.embedding_frozen.weight"], + state_dict["embed_tokens.embedding_trainable.weight"], + ], + dim=0, + ) + + del state_dict["embed_tokens.embedding_frozen.weight"] + del state_dict["embed_tokens.embedding_trainable.weight"] + + # Load the final state_dict + splitted_model.load_state_dict(state_dict, strict=True) + + if merged_output_dir: + splitted_model.save_pretrained(merged_output_dir, is_main_process=True, state_dict=state_dict) + + @torch.inference_mode() + def capture_model(self, past_key_values: list[Union[Cache, List[torch.FloatTensor]]]) -> None: + """Capture CUDA graphs for the model's forward pass with different KV cache lengths. + + Args: + past_key_values: List of KV caches to capture graphs for + """ + for past_key_value in past_key_values: + kv_cache_length = past_key_value.get_max_cache_shape() + # We capture two graphs, one for decoding audio tokens and one for decoding text tokens + for is_decoding_audio_token in [True, False]: + runner = CUDAGraphRunner(self._forward_core) + + # Create dummy inputs for graph capture + batch_size = 1 + hidden_dim = self.config.hidden_size + + hidden_states = torch.zeros( + (batch_size, 1, hidden_dim), dtype=self.config.torch_dtype, device=self.device + ) + causal_mask = torch.ones( + (batch_size, 1, 1, kv_cache_length), dtype=self.config.torch_dtype, device=self.device + ) + position_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=self.device) + audio_discrete_codes_mask = torch.tensor( + [[is_decoding_audio_token]], dtype=torch.bool, device=self.device + ) + cache_position = torch.tensor([kv_cache_length - 1], dtype=torch.long, device=self.device) + audio_attention_mask = torch.ones_like(causal_mask) + fast_forward_attention_mask = torch.ones_like(causal_mask) + + runner.capture( + hidden_states=hidden_states, + causal_mask=causal_mask, + position_ids=position_ids, + audio_discrete_codes_mask=audio_discrete_codes_mask, + cache_position=cache_position, + past_key_values=past_key_value, + use_cache=True, + audio_attention_mask=audio_attention_mask, + fast_forward_attention_mask=fast_forward_attention_mask, + output_attentions=False, + output_hidden_states=False, + is_decoding_audio_token=is_decoding_audio_token, + is_using_cuda_graph=True, + stream=torch.cuda.Stream(device=self.device), + ) + + self.decode_graph_runners[kv_cache_length][is_decoding_audio_token] = runner diff --git a/boson_multimodal/model/higgs_audio/utils.py b/boson_multimodal/model/higgs_audio/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d1ee633162b309d5985cf5bd40157b42f65a4082 --- /dev/null +++ b/boson_multimodal/model/higgs_audio/utils.py @@ -0,0 +1,756 @@ +import contextlib +from contextlib import contextmanager +from functools import wraps +import torch +from transformers.integrations import is_deepspeed_available + +if is_deepspeed_available(): + from deepspeed.utils import groups as deepspeed_groups + from deepspeed.sequence.layer import _SeqAllToAll +else: + deepspeed_groups = None + _SeqAllToAll = None + + +def _ceil_to_nearest(n, round_to): + return (n + round_to - 1) // round_to * round_to + + +def count_parameters(model, trainable_only=True): + if trainable_only: + return sum(p.numel() for p in model.parameters() if p.requires_grad) + else: + return sum(p.numel() for p in model.parameters()) + + +def build_delay_pattern_mask( + input_ids: torch.LongTensor, + bos_token_id: int, + pad_token_id: int, +): + """Implement the delay pattern proposed in "Simple and Controllable Music Generation", https://arxiv.org/pdf/2306.05284 + + In the delay pattern, each codebook is offset by the previous codebook by + one. We insert a special delay token at the start of the sequence if its delayed, and append pad token once the sequence finishes. + + Take the example where there are 4 codebooks and audio sequence length=5. After shifting, the output should have length seq_len + num_codebooks - 1 + + - [ *, *, *, *, *, P, P, P] + - [ B, *, *, *, *, *, P, P] + - [ B, B, *, *, *, *, *, P] + - [ B, B, B, *, *, *, *, *] + + where B indicates the delay token id, P is the special padding token id and `*` indicates that the original audio token. + + Now let's consider the case where we have a sequence of audio tokens to condition on. + The audio tokens were originally in the following non-delayed form: + + - [a, b] + - [c, d] + - [e, f] + - [g, h] + + After conversion, we get the following delayed form: + - [a, b, -1, -1, -1] + - [B, c, d, -1, -1] + - [B, B, e, f, -1] + - [B, B, B, g, h] + + Note that we have a special token `-1` that indicates it should be replaced by a new token we see in the generation phase. + In that case, we should override the `-1` tokens in auto-regressive generation. + + Args: + input_ids (:obj:`torch.LongTensor`): + The input ids of the prompt. It will have shape (bsz, num_codebooks, seq_len). + bos_token_id (:obj:`int`): + The id of the special delay token + pad_token_id (:obj:`int`): + The id of the padding token. Should be the same as eos_token_id. + + Returns: + input_ids (:obj:`torch.LongTensor`): + The transformed input ids with delay pattern applied. It will have shape (bsz, num_codebooks, seq_len + num_codebooks - 1). + input_ids_with_gen_mask (:obj:`torch.LongTensor`): + The transformed input ids with delay pattern applied. The -1 in the output indicates new tokens that should be generated. + + """ + bsz, num_codebooks, seq_len = input_ids.shape + + new_seq_len = seq_len + num_codebooks - 1 + input_ids_with_gen_mask = torch.ones((bsz, num_codebooks, new_seq_len), dtype=torch.long, device=input_ids.device) + bos_mask = torch.tril(input_ids_with_gen_mask, -1) > 0 + eos_mask = torch.triu(input_ids_with_gen_mask, seq_len) > 0 + input_ids_with_gen_mask[bos_mask] = bos_token_id + input_ids_with_gen_mask[(~bos_mask) & (~eos_mask)] = input_ids.reshape(-1) + input_ids = input_ids_with_gen_mask.clone() + input_ids[eos_mask] = pad_token_id + input_ids_with_gen_mask[eos_mask] = -1 + return input_ids, input_ids_with_gen_mask + + +def revert_delay_pattern(data): + """Convert samples encoded with delay pattern back to the original form. + + Args: + data (:obj:`torch.Tensor`): + The data with delay pattern applied. It will have shape (num_codebooks, seq_len + num_codebooks - 1). + + Returns: + ret (:obj:`torch.Tensor`): + Recovered data with delay pattern removed. It will have shape (num_codebooks, seq_len). + """ + assert len(data.shape) == 2 + out_l = [] + num_codebooks = data.shape[0] + for i in range(num_codebooks): + out_l.append(data[i : (i + 1), i : (data.shape[1] - num_codebooks + 1 + i)]) + return torch.cat(out_l, dim=0) + + +def merge_input_ids_with_audio_features( + audio_features_embed, + audio_features_length, + audio_in_embed, + audio_in_ids_start, + audio_out_embed, + audio_out_ids_start, + audio_in_token_idx, + audio_out_token_idx, + inputs_embeds, + input_ids, + attention_mask, + label_ids, + pad_token_id, + ignore_index=-100, + round_to=8, + left_padding=True, +): + """ + Merge input_ids with audio features into final embeddings. + + Args: + audio_features_embed (`torch.Tensor` of shape `(num_audios, max_audio_tokens, embed_dim)`): + Encoded vectors of all audios in the batch (obtained from the semantic encoder) + audio_features_length (`torch.LongTensor` of shape `(num_audios,)`): + The length of audio embeddings of each audio as stacked in `audio_features_embed` + audio_in_embed (`torch.Tensor` of shape `(total_num_audio_in_tokens, embed_dim)`): + The embeddings of audio-in tokens + audio_in_ids_start (`torch.LongTensor` of shape `(num_audios,)`): + The start index of the audio-in tokens for each audio + audio_out_embed (`torch.Tensor` of shape `(total_num_audio_out_tokens, embed_dim)`): + The embeddings of audio-out tokens + audio_out_ids_start (`torch.LongTensor` of shape `(num_audios,)`): + The start index of the audio-out tokens for each audio + audio_in_token_idx + The index of the audio-in token in the vocabulary + audio_out_token_idx + The index of the audio-out token in the vocabulary + inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`): + Token embeddings before merging with audio embeddings + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Input_ids of tokens, possibly filled with audio token + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Mask to avoid performing attention on padding token indices. + label_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) + labels need to be recalculated to support training (if provided) + pad_token_id (`int`): + The index of the pad token in the vocabulary + ignore_index + The index to ignore in the loss calculation + round_to + The number to round to for padding + left_padding + Whether to apply left padding + + Returns: + final_embedding + The final embeddings after merging audio embeddings with text embeddings. + final_attention_mask + The final attention mask after merging audio embeddings with text embeddings. + final_labels + The labels for the text stream + position_ids + Positional ids for the merged data + final_input_ids + The final input_ids after merging audio embeddings with text embeddings. + final_audio_in_mask + Mask for audio-in embeddings + final_audio_in_discrete_codes_mask + Mask for audio-in discrete tokens + final_audio_out_mask + Mask for audio-out embeddings + + Explanation: + each audio has variable length embeddings, with length specified by + - audio_features_length + - audio_in_ids_start + - audio_out_ids_start + + Task: + - fill each <|AUDIO|> with audio embeddings (it can be the combination of embeddings extracted by WhisperEncoder and embeddings from audio codebooks) + - fill each <|AUDIO_OUT|> with the audio-out embeddings + + Example: + <|AUDIO_OUT|>: X (5 tokens), Y (3 tokens) + <|AUDIO|>: Z (8 tokens) + + X, Y are in the same sequence (in-context voice-clone). Z is in a different sequence (audio understanding). + if right padding + input_ids: [ + a b c d e f X g h i j k Y l m + o p q r Z s t u v _ _ _ _ _ _ + ] + input_ids should be: [ + a b c d e f X X X X X g h i j k Y Y Y l m + o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _ + ] + labels should be: [ + a b c d e f _ _ _ _ _ g h i j k _ _ _ l m + o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _ + ] + elif left padding + input_ids: [ + a b c d e f X g h i j k Y l m + _ _ _ _ _ _ o p q r Z s t u v + ] + input_ids should be: [ + a b c d e f X X X X X g h i j k Y Y Y l m + _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v + ] + labels should be: [ + a b c d e f _ _ _ _ _ g h i j k _ _ _ l m + _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v + ] + + """ + if label_ids is None: + skip_labels = True + else: + skip_labels = False + if audio_features_embed is not None and audio_features_embed.shape[0] == 0: + audio_features_embed = None + if audio_in_embed is not None and audio_in_embed.shape[0] == 0: + audio_in_embed = None + if audio_out_embed is not None and audio_out_embed.shape[0] == 0: + audio_out_embed = None + + batch_size, sequence_length, embed_dim = inputs_embeds.shape + + target_device = inputs_embeds.device + if left_padding is None: + left_padding = torch.any(attention_mask[:, 0] == 0) + + audio_in_token_mask = input_ids == audio_in_token_idx + audio_out_token_mask = input_ids == audio_out_token_idx + text_token_mask = (input_ids != audio_in_token_idx) & (input_ids != audio_out_token_idx) + + # 1. Calculate the number of tokens for each placeholder (like [<|AUDIO|>, <|AUDIO_OUT|>]). + token_placeholder_num = torch.ones_like(input_ids) + + if audio_features_embed is not None: + num_audios, max_audio_tokens, _ = audio_features_embed.shape + audio_in_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to( + audio_features_length.device + ) < audio_features_length.unsqueeze(1) + masked_audio_in_features = audio_features_embed[audio_in_features_mask].view(-1, embed_dim) + token_placeholder_num[audio_in_token_mask] = audio_features_length.long() + + if audio_in_embed is not None: + audio_in_codes_length = torch.concat( + [ + audio_in_ids_start[1:] - audio_in_ids_start[:-1], + torch.tensor( + [audio_in_embed.shape[0] - audio_in_ids_start[-1]], + device=audio_in_ids_start.device, + dtype=torch.long, + ), + ], + dim=0, + ) + if audio_features_embed is not None: + token_placeholder_num[audio_in_token_mask] += audio_in_codes_length.long() + else: + token_placeholder_num[audio_in_token_mask] = audio_in_codes_length.long() + + if audio_out_embed is not None: + audio_out_codes_length = torch.concat( + [ + audio_out_ids_start[1:] - audio_out_ids_start[:-1], + torch.tensor( + [audio_out_embed.shape[0] - audio_out_ids_start[-1]], + device=audio_out_ids_start.device, + dtype=torch.long, + ), + ], + dim=0, + ) + token_placeholder_num[audio_out_token_mask] = audio_out_codes_length.long() + + new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1 + max_token_num = _ceil_to_nearest(token_placeholder_num.sum(-1).max(), round_to) + nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1] + + if left_padding: + new_token_positions += nb_audio_pad[:, None] # offset for left padding + + # 2. Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros( + (batch_size, max_token_num, embed_dim), dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + final_attention_mask = torch.zeros( + (batch_size, max_token_num), dtype=attention_mask.dtype, device=inputs_embeds.device + ) + final_input_ids = torch.full( + (batch_size, max_token_num), pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device + ) + if skip_labels: + final_labels = None + else: + final_labels = torch.full( + (batch_size, max_token_num), ignore_index, dtype=label_ids.dtype, device=inputs_embeds.device + ) + + final_audio_in_mask = torch.full((batch_size, max_token_num), False, dtype=torch.bool, device=inputs_embeds.device) + final_audio_in_discrete_codes_mask = torch.full( + (batch_size, max_token_num), False, dtype=torch.bool, device=inputs_embeds.device + ) + final_audio_out_mask = torch.full( + (batch_size, max_token_num), False, dtype=torch.bool, device=inputs_embeds.device + ) + # 3. Get the audio-in token positions and audio-out token positions + batch_id = torch.arange(batch_size, device=target_device).unsqueeze(1).expand(batch_size, sequence_length) + audio_in_batch_id = batch_id[audio_in_token_mask] # Shape (num_audio_in,) + audio_out_batch_id = batch_id[audio_out_token_mask] # Shape (num_audio_out,) + audio_features_token_ends = new_token_positions[audio_in_token_mask] # Shape (num_audio_in,) + audio_out_embed_ends = new_token_positions[audio_out_token_mask] # Shape (num_audio_out,) + + if audio_in_embed is not None: + # Fill in the audio-in embeddings + seq_indices = ( + torch.arange(max_token_num, device=target_device) + .unsqueeze(0) + .expand(audio_in_ids_start.shape[0], max_token_num) + ) + audio_in_embed_token_starts = audio_features_token_ends - audio_in_codes_length + 1 + batch_indices, col_indices = torch.where( + (seq_indices >= audio_in_embed_token_starts.unsqueeze(1)) + & (seq_indices <= audio_features_token_ends.unsqueeze(1)) + ) + batch_indices = audio_in_batch_id[batch_indices] + final_embedding[batch_indices, col_indices] = audio_in_embed + final_input_ids[batch_indices, col_indices] = audio_in_token_idx + if not skip_labels: + final_labels[batch_indices, col_indices] = ignore_index + final_audio_in_mask[batch_indices, col_indices] = True + final_audio_in_discrete_codes_mask[batch_indices, col_indices] = True + audio_features_token_ends = audio_features_token_ends - audio_in_codes_length + + if audio_features_embed is not None: + # Fill in the audio features + seq_indices = ( + torch.arange(max_token_num, device=target_device) + .unsqueeze(0) + .expand(audio_features_embed.shape[0], max_token_num) + ) + audio_features_token_starts = audio_features_token_ends - audio_features_length + 1 + batch_indices, col_indices = torch.where( + (seq_indices >= audio_features_token_starts.unsqueeze(1)) + & (seq_indices <= audio_features_token_ends.unsqueeze(1)) + ) + batch_indices = audio_in_batch_id[batch_indices] + final_embedding[batch_indices, col_indices] = masked_audio_in_features + final_input_ids[batch_indices, col_indices] = audio_in_token_idx + if not skip_labels: + final_labels[batch_indices, col_indices] = ignore_index + final_audio_in_mask[batch_indices, col_indices] = True + + if audio_out_embed is not None: + # Fill in the audio-out embeddings + seq_indices = ( + torch.arange(max_token_num, device=target_device) + .unsqueeze(0) + .expand(audio_out_ids_start.shape[0], max_token_num) + ) + audio_out_embed_token_starts = audio_out_embed_ends - audio_out_codes_length + 1 + batch_indices, col_indices = torch.where( + (seq_indices >= audio_out_embed_token_starts.unsqueeze(1)) + & (seq_indices <= audio_out_embed_ends.unsqueeze(1)) + ) + batch_indices = audio_out_batch_id[batch_indices] + final_embedding[batch_indices, col_indices] = audio_out_embed + final_input_ids[batch_indices, col_indices] = audio_out_token_idx + if not skip_labels: + final_labels[batch_indices, col_indices] = ignore_index + final_audio_out_mask[batch_indices, col_indices] = True + + # Fill in the original text embeddings and labels + batch_indices, non_audio_indices = torch.where(text_token_mask) + text_to_overwrite = new_token_positions[batch_indices, non_audio_indices] + final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_audio_indices] + if not skip_labels: + final_labels[batch_indices, text_to_overwrite] = label_ids[batch_indices, non_audio_indices] + final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_audio_indices] + final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_audio_indices] + final_attention_mask = final_attention_mask | final_audio_in_mask | final_audio_out_mask + + # Trim the tensor if there are redundant padding tokens + if left_padding: + first_non_zero_loc = final_attention_mask.sum(0).nonzero()[0] + first_non_zero_loc = (first_non_zero_loc // round_to) * round_to + if first_non_zero_loc > 0: + final_attention_mask = final_attention_mask[:, first_non_zero_loc:] + final_embedding = final_embedding[:, first_non_zero_loc:] + if not skip_labels: + final_labels = final_labels[:, first_non_zero_loc:] + final_input_ids = final_input_ids[:, first_non_zero_loc:] + final_audio_in_mask = final_audio_in_mask[:, first_non_zero_loc:] + final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, first_non_zero_loc:] + final_audio_out_mask = final_audio_out_mask[:, first_non_zero_loc:] + else: + # We have done right padding, so we need to trim the mask + last_non_zero_loc = final_attention_mask.sum(0).nonzero()[-1] + 1 + last_non_zero_loc = ((last_non_zero_loc + round_to - 1) // round_to) * round_to + if last_non_zero_loc < max_token_num: + final_attention_mask = final_attention_mask[:, :last_non_zero_loc] + final_embedding = final_embedding[:, :last_non_zero_loc] + if not skip_labels: + final_labels = final_labels[:, :last_non_zero_loc] + final_input_ids = final_input_ids[:, :last_non_zero_loc] + final_audio_in_mask = final_audio_in_mask[:, :last_non_zero_loc] + final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, :last_non_zero_loc] + final_audio_out_mask = final_audio_out_mask[:, :last_non_zero_loc] + + position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) + return ( + final_embedding, + final_attention_mask, + final_labels, + position_ids, + final_input_ids, + final_audio_in_mask, + final_audio_in_discrete_codes_mask, + final_audio_out_mask, + ) + + +def is_deepspeed_ulysses_enabled(): + if deepspeed_groups is None: + return False + + """Check if sequence parallelism is enabled.""" + return deepspeed_groups._get_sequence_parallel_world_size() > 1 + + +def support_deepspeed_ulysses(module): + """A decorator around Pytorch module. It is needed for the module that needs access to sequence parallel info.""" + module._sp_size = None + module._sp_rank = None + module._sp_group = None + + @property + def sp_size(self): + if self._sp_size is None: + self._sp_size = 1 + if is_deepspeed_ulysses_enabled(): + self._sp_size = deepspeed_groups._get_sequence_parallel_group().size() + return self._sp_size + + @property + def sp_rank(self): + if self._sp_rank is None: + self._sp_rank = 0 + if is_deepspeed_ulysses_enabled(): + self._sp_rank = deepspeed_groups._get_sequence_parallel_rank() + return self._sp_rank + + @property + def sp_group(self): + if self._sp_group is None and is_deepspeed_ulysses_enabled(): + self._sp_group = deepspeed_groups._get_sequence_parallel_group() + return self._sp_group + + module.sp_size = sp_size + module.sp_rank = sp_rank + module.sp_group = sp_group + + return module + + +def deepspeed_ulysses_attention(seq_dim=1, head_dim=2): + """Perform all-to-all before and after the attention function.""" + + def attention_decorator(attn_func=None): + def wrapped(*args, **kwargs): + if is_deepspeed_ulysses_enabled(): + sp_group = deepspeed_groups._get_sequence_parallel_group() + scatter_idx = head_dim # Scatter on num_heads dimension + gather_idx = seq_dim # Gather on seq_len dimension + batch_dim_idx = 0 + args = list(args) + args[0] = _SeqAllToAll.apply(sp_group, args[0], scatter_idx, gather_idx, batch_dim_idx) + args[1] = _SeqAllToAll.apply(sp_group, args[1], scatter_idx, gather_idx, batch_dim_idx) + args[2] = _SeqAllToAll.apply(sp_group, args[2], scatter_idx, gather_idx, batch_dim_idx) + args = tuple(args) + + attn_output = attn_func(*args, **kwargs) + + if is_deepspeed_ulysses_enabled(): + scatter_idx = seq_dim # Scatter back on seq_len dimension + gather_idx = head_dim # Gather on num_heads dimension + batch_dim_idx = 0 + attn_output = _SeqAllToAll.apply(sp_group, attn_output, scatter_idx, gather_idx, batch_dim_idx) + + return attn_output + + return wrapped + + return attention_decorator + + +def deepspeed_ulysses_rope(state_seq_dim=2, trig_seq_dim=1): + """Slice the corresponding cos and sin chunks for rope.""" + + def rope_decorator(rope_func=None): + def wrapped(*args, **kwargs): + if is_deepspeed_ulysses_enabled(): + sp_rank = deepspeed_groups._get_sequence_parallel_rank() + args = list(args) + seq_chunk_size = args[0].size(state_seq_dim) + args[2] = torch.narrow(args[2], trig_seq_dim, sp_rank * seq_chunk_size, seq_chunk_size) + args[3] = torch.narrow(args[3], trig_seq_dim, sp_rank * seq_chunk_size, seq_chunk_size) + args = tuple(args) + + return rope_func(*args, **kwargs) + + return wrapped + + return rope_decorator + + +def _gather_tensors(input_, group=None): + """Gather tensors and concatenate them along a dimension.""" + input_ = input_.contiguous() + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input_ + tensor_shapes = [ + torch.empty(len(input_.size()), dtype=torch.int64, device=input_.device) for _ in range(world_size) + ] + input_size = torch.tensor(input_.size(), dtype=torch.int64, device=input_.device) + torch.distributed.all_gather(tensor_shapes, input_size, group=group) + gathered_buffers = [ + torch.empty(tensor_shapes[i].tolist(), dtype=input_.dtype, device=input_.device) for i in range(world_size) + ] + torch.distributed.all_gather(gathered_buffers, input_, group=group) + return gathered_buffers + + +def _scatter_tensors(input_, group=None): + """Scatter tensors.""" + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input_ + rank = torch.distributed.get_rank(group) + return input_[rank] + + +class _GatherTensors(torch.autograd.Function): + """All gather tensors among the ranks.""" + + @staticmethod + def symbolic(graph, input_, group): + return _gather_tensors(input_, group) + + @staticmethod + def forward(ctx, input_, group): + ctx.group = group + return torch.nested.as_nested_tensor(_gather_tensors(input_, group), layout=torch.jagged) + + @staticmethod + def backward(ctx, grad_output): + return _scatter_tensors(grad_output, ctx.group), None + + +def all_gather_tensors(input_, size=None, dim=0, group=None): + if torch.distributed.get_world_size(group) == 1: + # no sequence parallelism + return input_ + gathered_tensors = _GatherTensors.apply(input_, group) + + if size: + split_gathered_tensors = [] + for s, gathered_tensor in zip(size, gathered_tensors): + split_gathered_tensor = torch.split(gathered_tensor, s.tolist()) + split_gathered_tensors.append(split_gathered_tensor) + + gathered_tensors = [y for x in zip(*split_gathered_tensors) for y in x] + + return torch.cat(gathered_tensors, dim).contiguous() + + +def get_sequence_data_parallel_world_size(): + return torch.distributed.get_world_size() + + +def get_sequence_data_parallel_rank(): + return torch.distributed.get_rank() + + +def get_sequence_data_parallel_group(): + return torch.distributed.group.WORLD + + +if is_deepspeed_available(): + deepspeed_groups._get_sequence_data_parallel_world_size = get_sequence_data_parallel_world_size + deepspeed_groups._get_sequence_data_parallel_rank = get_sequence_data_parallel_rank + deepspeed_groups._get_sequence_data_parallel_group = get_sequence_data_parallel_group + + +def _gather_tokens(input_, dim=0, group=None): + """Gather tensors and concatenate them along a dimension""" + input_ = input_.contiguous() + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input_ + + gather_buffer = torch.empty(world_size * input_.numel(), dtype=input_.dtype, device=input_.device) + torch.distributed.all_gather_into_tensor(gather_buffer, input_, group=group) + if dim == 0: + shape = list(input_.size()) + shape[0] = shape[0] * world_size + output = gather_buffer.view(shape) + else: + tensor_list = [ + gather_buffer.narrow(0, input_.numel() * i, input_.numel()).view_as(input_) for i in range(world_size) + ] + # Note: torch.cat already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +def _drop_tokens(input_, dim=0, group=None): + """Divide a tensor among the sequence parallel ranks""" + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input_ + this_rank = torch.distributed.get_rank(group) + assert input_.shape[dim] % world_size == 0, ( + f"input dimension {dim} ({input_.shape[dim]}) is not divisible by sequence parallel world size ({world_size})" + ) + chunk_size = input_.shape[dim] // world_size + + return torch.narrow(input_, dim, this_rank * chunk_size, chunk_size) + + +class _DropTokens(torch.autograd.Function): + "Divide tokens equally among the sequence parallel ranks" + + @staticmethod + def symbolic(graph, input_, dim, group, grad_scale): + return _drop_tokens(input_, dim, group) + + @staticmethod + def forward(ctx, input_, dim, group, grad_scale): + ctx.dim = dim + ctx.group = group + ctx.grad_scale = grad_scale + return _drop_tokens(input_, dim, group) + + @staticmethod + def backward(ctx, grad_output): + grad_input = _gather_tokens(grad_output, ctx.dim, ctx.group) + if ctx.grad_scale != 1: + grad_input /= ctx.grad_scale + return grad_input, None, None, None + + +class _GatherTokens(torch.autograd.Function): + "Gather tokens among the sequence parallel ranks" + + @staticmethod + def symbolic(graph, input_, dim, group, grad_scale): + return _gather_tokens(input_, dim, group) + + @staticmethod + def forward(ctx, input_, dim, group, grad_scale): + ctx.dim = dim + ctx.group = group + ctx.grad_scale = grad_scale + return _gather_tokens(input_, dim, group) + + @staticmethod + def backward(ctx, grad_output): + grad_input = _drop_tokens(grad_output, ctx.dim, ctx.group) + if ctx.grad_scale != 1: + grad_input *= ctx.grad_scale + return grad_input, None, None, None + + +def drop_tokens(input_, dim=0, group=None, grad_scale=1): + if torch.distributed.get_world_size(group) == 1: + # no sequence parallelism + return input_ + return _DropTokens.apply(input_, dim, group, grad_scale) + + +def gather_tokens(input_, dim=0, group=None, grad_scale=1): + if torch.distributed.get_world_size(group) == 1: + # no sequence parallelism + return input_ + return _GatherTokens.apply(input_, dim, group, grad_scale) + + +def sequence_chunking_per_rank(sp_size, sp_rank, *args, dim=1): + """ + Slice the inputs to create chuncks per the sequence parallel rank. This is used for the context parallel training. + + Args: + sp_size (`int`): + Sequence parallel size. + sp_rank (`int`): + Sequence parallel rank for the current process. + dim (`int`): + The dimension to slice + """ + if sp_size == 1: + return args[0] if len(args) == 1 else args + + seq_length = args[0].size(dim) + for arg in args[1:]: + assert arg.size(dim) == seq_length, ( + f"arg={arg} ({arg.shape[dim]}) does not have the same size as args[0] ({seq_length}) in dimension {dim}" + ) + assert seq_length % sp_size == 0, ( + f"dimension {dim} ({args[0].shape[dim]}) is not divisible by sequence parallel world size ({sp_size})" + ) + + sub_seq_length = seq_length // sp_size + sub_seq_start = sp_rank * sub_seq_length + + output = [] + for ind in args: + ind = torch.narrow(ind, dim, sub_seq_start, sub_seq_length) + output.append(ind) + + return tuple(output) if len(output) > 1 else output[0] + + +@contextmanager +def disable_deepspeed_ulysses(): + """Disable deepspeed ulysses (sequence parallelism) if it is enabled""" + if is_deepspeed_ulysses_enabled(): + _old_get_sequence_parallel_world_size = deepspeed_groups._get_sequence_parallel_world_size + + def _get_sequence_parallel_world_size(): + return 1 + + deepspeed_groups._get_sequence_parallel_world_size = _get_sequence_parallel_world_size + try: + yield + finally: + deepspeed_groups._get_sequence_parallel_world_size = _old_get_sequence_parallel_world_size + else: + context = contextlib.nullcontext + with context(): + yield diff --git a/boson_multimodal/serve/__pycache__/serve_engine.cpython-311.pyc b/boson_multimodal/serve/__pycache__/serve_engine.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc36b91e875eef1842aa25093cb3cd725ef762a9 Binary files /dev/null and b/boson_multimodal/serve/__pycache__/serve_engine.cpython-311.pyc differ diff --git a/boson_multimodal/serve/serve_engine.py b/boson_multimodal/serve/serve_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..a4a32070f8c7b4b932ffcf97b11647b7ce793a09 --- /dev/null +++ b/boson_multimodal/serve/serve_engine.py @@ -0,0 +1,423 @@ +import asyncio +import base64 +import torch +import numpy as np +from io import BytesIO +from dataclasses import dataclass +from typing import List, Optional, Union +from copy import deepcopy +from transformers import AutoTokenizer, AutoProcessor +from transformers.cache_utils import StaticCache +from transformers.generation.streamers import BaseStreamer +from transformers.generation.stopping_criteria import StoppingCriteria +from dataclasses import asdict +from loguru import logger +import threading +import librosa + + +from ..dataset.chatml_dataset import ChatMLSample, ChatMLDatasetSample, prepare_chatml_sample +from ..model.higgs_audio import HiggsAudioModel +from ..model.higgs_audio.utils import revert_delay_pattern +from ..data_collator.higgs_audio_collator import HiggsAudioSampleCollator +from ..audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer + + +@dataclass +class HiggsAudioStreamerDelta: + """Represents a chunk of generated content, either text or audio tokens.""" + + text: Optional[str] = None + text_tokens: Optional[torch.Tensor] = None + audio_tokens: Optional[torch.Tensor] = None + finish_reason: Optional[str] = None + + +class AsyncHiggsAudioStreamer(BaseStreamer): + """ + Async streamer that handles both text and audio token generation from Higgs-Audio model. + Stores chunks in a queue to be consumed by downstream applications. + + Parameters: + tokenizer (`AutoTokenizer`): + The tokenizer used to decode text tokens. + skip_prompt (`bool`, *optional*, defaults to `False`): + Whether to skip the prompt tokens in generation. + timeout (`float`, *optional*): + The timeout for the queue. If `None`, the queue will block indefinitely. + decode_kwargs (`dict`, *optional*): + Additional keyword arguments to pass to the tokenizer's `decode` method. + + Examples: + ```python + >>> from transformers import AutoTokenizer + >>> from threading import Thread + >>> import asyncio + + >>> tokenizer = AutoTokenizer.from_pretrained("path/to/higgs/tokenizer") + >>> model = HiggsAudioModel.from_pretrained("path/to/higgs/model") + >>> inputs = tokenizer(["Generate some text and audio:"], return_tensors="pt") + + >>> async def main(): + ... streamer = AsyncHiggsAudioStreamer(tokenizer) + ... generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20) + ... thread = Thread(target=model.generate, kwargs=generation_kwargs) + ... thread.start() + ... + ... async for delta in streamer: + ... if delta.text is not None: + ... print("Text:", delta.text) + ... if delta.audio_tokens is not None: + ... print("Audio tokens shape:", delta.audio_tokens.shape) + >>> asyncio.run(main()) + ``` + """ + + def __init__( + self, + tokenizer: "AutoTokenizer", + skip_prompt: bool = False, + timeout: Optional[float] = None, + audio_num_codebooks: int = 1, + **decode_kwargs, + ): + self.tokenizer = tokenizer + self.skip_prompt = skip_prompt + self.timeout = timeout + self.decode_kwargs = decode_kwargs + self.audio_num_codebooks = audio_num_codebooks + # Queue to store generated chunks + self.queue = asyncio.Queue() + self.stop_signal = None + + # Get running event loop + self.loop = asyncio.get_running_loop() + self.has_asyncio_timeout = hasattr(asyncio, "timeout") + + # State tracking + self.next_tokens_are_prompt = True + + def put(self, value: torch.Tensor): + """ + Receives tokens and processes them as either text or audio tokens. + For text tokens, decodes and caches them until complete words are formed. + For audio tokens, directly queues them. + """ + if value.shape[0] > 1 and not self.next_tokens_are_prompt: + # This is likely audio tokens (shape: [audio_num_codebooks]) + assert value.shape[0] == self.audio_num_codebooks, "Number of codebooks mismatch" + delta = HiggsAudioStreamerDelta(audio_tokens=value) + self.loop.call_soon_threadsafe(self.queue.put_nowait, delta) + return + + # Skip prompt tokens if configured + if self.skip_prompt and self.next_tokens_are_prompt: + self.next_tokens_are_prompt = False + return + + # Process as text tokens + if len(value.shape) > 1: + value = value[0] + + text = self.tokenizer.decode(value, **self.decode_kwargs) + delta = HiggsAudioStreamerDelta(text=text, text_tokens=value) + self.loop.call_soon_threadsafe(self.queue.put_nowait, delta) + + def end(self): + """Flushes any remaining text tokens and signals the end of generation.""" + self.next_tokens_are_prompt = True + self.loop.call_soon_threadsafe(self.queue.put_nowait, self.stop_signal) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + if self.has_asyncio_timeout: + async with asyncio.timeout(self.timeout): + value = await self.queue.get() + else: + value = await asyncio.wait_for(self.queue.get(), timeout=self.timeout) + except asyncio.TimeoutError: + raise TimeoutError() + else: + if value == self.stop_signal: + raise StopAsyncIteration() + else: + return value + + +class AsyncStoppingCriteria(StoppingCriteria): + """ + Stopping criteria that checks for stop signal from a threading event. + + Args: + stop_signal (threading.Event): Event that will receive stop signals + """ + + def __init__(self, stop_signal: threading.Event): + self.stop_signal = stop_signal + + def __call__(self, input_ids, scores, **kwargs) -> bool: + if self.stop_signal.is_set(): + logger.info(f"Stop signal received. Can be caused by client disconnection.") + return True + return False + + +@dataclass +class HiggsAudioResponse: + audio: Optional[np.ndarray] = None + generated_audio_tokens: Optional[np.ndarray] = None + sampling_rate: Optional[int] = None + generated_text: str = "" + generated_text_tokens: Optional[np.ndarray] = None + usage: Optional[dict] = None + + +class HiggsAudioServeEngine: + def __init__( + self, + model_name_or_path: str, + audio_tokenizer_name_or_path: str, + tokenizer_name_or_path: Optional[str] = None, + device: str = "cuda", + torch_dtype: Union[torch.dtype, str] = "auto", + kv_cache_lengths: List[int] = [1024, 4096, 8192], # Multiple KV cache sizes + ): + """ + Initialize the HiggsAudioServeEngine, a serving wrapper for the HiggsAudioModel. + The model, tokenizer, and audio tokenizer will be downloaded from the Hugging Face Hub if they are not local. + + Args: + model_name_or_path (str): + The name or path of the model to load. + audio_tokenizer_name_or_path (str): + The name or path of the audio tokenizer to load. + tokenizer_name_or_path (str): + The name or path of the tokenizer to load. + device (str): + The device to use for the model. + kv_cache_lengths (List[int]): + The lengths of the KV caches to use for the model. Used for cuda graph capture when device is cuda. + torch_dtype (Union[torch.dtype, str]): + The dtype to use for the model. + """ + self.device = device + self.model_name_or_path = model_name_or_path + self.torch_dtype = torch_dtype + + # Initialize model and tokenizer + self.model = HiggsAudioModel.from_pretrained(model_name_or_path, torch_dtype=torch_dtype).to(device) + logger.info(f"Loaded model from {model_name_or_path}, dtype: {self.model.dtype}") + + if tokenizer_name_or_path is None: + tokenizer_name_or_path = model_name_or_path + logger.info(f"Loading tokenizer from {tokenizer_name_or_path}") + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + + logger.info(f"Initializing Higgs Audio Tokenizer") + self.audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer_name_or_path, device=device) + + self.audio_num_codebooks = self.model.config.audio_num_codebooks + self.audio_codebook_size = self.model.config.audio_codebook_size + self.audio_tokenizer_tps = self.audio_tokenizer.tps + self.samples_per_token = int(self.audio_tokenizer.sampling_rate // self.audio_tokenizer_tps) + self.hamming_window_len = 2 * self.audio_num_codebooks * self.samples_per_token + # Set the audio special tokens + self.model.set_audio_special_tokens(self.tokenizer) + + # Prepare KV caches for different lengths + cache_config = deepcopy(self.model.config.text_config) + cache_config.num_hidden_layers = self.model.config.text_config.num_hidden_layers + if self.model.config.audio_dual_ffn_layers: + cache_config.num_hidden_layers += len(self.model.config.audio_dual_ffn_layers) + # A list of KV caches for different lengths + self.kv_caches = { + length: StaticCache( + config=cache_config, + max_batch_size=1, + max_cache_len=length, + device=self.model.device, + dtype=self.model.dtype, + ) + for length in sorted(kv_cache_lengths) + } + + if self.model.config.encode_whisper_embed: + logger.info(f"Loading whisper processor") + whisper_processor = AutoProcessor.from_pretrained( + "openai/whisper-large-v3-turbo", + trust_remote=True, + device=self.device, + ) + else: + whisper_processor = None + + # Reuse collator to prepare inference samples + self.collator = HiggsAudioSampleCollator( + whisper_processor=whisper_processor, + encode_whisper_embed=self.model.config.encode_whisper_embed, + audio_in_token_id=self.model.config.audio_in_token_idx, + audio_out_token_id=self.model.config.audio_out_token_idx, + audio_stream_bos_id=self.model.config.audio_stream_bos_id, + audio_stream_eos_id=self.model.config.audio_stream_eos_id, + pad_token_id=self.model.config.pad_token_id, + return_audio_in_tokens=False, + use_delay_pattern=self.model.config.use_delay_pattern, + audio_num_codebooks=self.model.config.audio_num_codebooks, + round_to=1, + ) + + # Capture CUDA graphs for each KV cache length + if device == "cuda": + logger.info(f"Capturing CUDA graphs for each KV cache length") + self.model.capture_model(self.kv_caches.values()) + + def _prepare_inputs(self, chat_ml_sample: ChatMLSample, force_audio_gen: bool = False): + input_tokens, _, audio_contents, _ = prepare_chatml_sample( + chat_ml_sample, + self.tokenizer, + ) + + postfix = "<|start_header_id|>assistant<|end_header_id|>\n\n" + if force_audio_gen: + postfix += "<|audio_out_bos|>" + postfix = self.tokenizer.encode(postfix, add_special_tokens=False) + input_tokens.extend(postfix) + + # Configure the audio inputs + audio_ids_l = [] + for audio_content in audio_contents: + if audio_content.audio_url not in ["placeholder", ""]: + raw_audio, _ = librosa.load(audio_content.audio_url, sr=self.audio_tokenizer.sampling_rate) + elif audio_content.raw_audio is not None: + raw_audio, _ = librosa.load( + BytesIO(base64.b64decode(audio_content.raw_audio)), sr=self.audio_tokenizer.sampling_rate + ) + else: + raw_audio = None + + if raw_audio is not None: + audio_ids = self.audio_tokenizer.encode(raw_audio, self.audio_tokenizer.sampling_rate) + audio_ids_l.append(audio_ids.squeeze(0).cpu()) + + if len(audio_ids_l) > 0: + audio_ids_start = torch.tensor( + np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])), + dtype=torch.long, + device=self.device, + )[0:-1] + audio_ids_concat = torch.cat(audio_ids_l, dim=1) + else: + audio_ids_start = None + audio_ids_concat = None + + sample = ChatMLDatasetSample( + input_ids=torch.LongTensor(input_tokens), + label_ids=None, + audio_ids_concat=audio_ids_concat, + audio_ids_start=audio_ids_start, + audio_waveforms_concat=None, + audio_waveforms_start=None, + audio_sample_rate=None, + audio_speaker_indices=None, + ) + data = self.collator([sample]) + inputs = asdict(data) + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + inputs[k] = v.to(self.model.device) + + return inputs + + def _prepare_kv_caches(self): + for kv_cache in self.kv_caches.values(): + kv_cache.reset() + + def generate( + self, + chat_ml_sample: ChatMLSample, + max_new_tokens: int, + temperature: float = 0.7, + top_k: Optional[int] = None, + top_p: float = 0.95, + stop_strings: Optional[List[str]] = None, + force_audio_gen: bool = False, + ras_win_len: Optional[int] = 7, + ras_win_max_num_repeat: int = 2, + seed: Optional[int] = None, + ): + """ + Generate audio from a chatml sample. + Args: + chat_ml_sample: A chatml sample. + max_new_tokens: The maximum number of new tokens to generate. + temperature: The temperature to use for the generation. + top_p: The top p to use for the generation. + stop_strings: A list of strings to stop the generation. + force_audio_gen: Whether to force audio generation. This ensures the model generates audio tokens rather than text tokens. + ras_win_len: The length of the RAS window. We use 7 by default. You can disable it by setting it to None or <=0. + ras_win_max_num_repeat: The maximum number of times to repeat the RAS window. + Returns: + A dictionary with the following keys: + audio: The generated audio. + sampling_rate: The sampling rate of the generated audio. + """ + # Default stop strings + if stop_strings is None: + stop_strings = ["<|end_of_text|>", "<|eot_id|>"] + if ras_win_len is not None and ras_win_len <= 0: + ras_win_len = None + + with torch.no_grad(): + inputs = self._prepare_inputs(chat_ml_sample, force_audio_gen=force_audio_gen) + prompt_token_ids = inputs["input_ids"][0].cpu().numpy() + + self._prepare_kv_caches() + + outputs = self.model.generate( + **inputs, + max_new_tokens=max_new_tokens, + use_cache=True, + stop_strings=stop_strings, + tokenizer=self.tokenizer, + do_sample=False if temperature == 0.0 else True, + temperature=temperature, + top_k=top_k, + top_p=top_p, + past_key_values_buckets=self.kv_caches, + ras_win_len=ras_win_len, + ras_win_max_num_repeat=ras_win_max_num_repeat, + seed=seed, + ) + + if len(outputs[1]) > 0: + wv_list = [] + for output_audio in outputs[1]: + vq_code = revert_delay_pattern(output_audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1] + wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0] + wv_list.append(wv_numpy) + wv_numpy = np.concatenate(wv_list) + else: + wv_numpy = None + + # We only support one request at a time now + generated_text_tokens = outputs[0][0].cpu().numpy()[len(prompt_token_ids) :] + generated_text = self.tokenizer.decode(generated_text_tokens) + generated_audio_tokens = outputs[1][0].cpu().numpy() + return HiggsAudioResponse( + audio=wv_numpy, + generated_audio_tokens=generated_audio_tokens, + sampling_rate=self.audio_tokenizer.sampling_rate, + generated_text=generated_text, + generated_text_tokens=generated_text_tokens, + usage={ + "prompt_tokens": prompt_token_ids.shape[0], + "completion_tokens": generated_text_tokens.shape[0] + generated_audio_tokens.shape[1], + "total_tokens": ( + prompt_token_ids.shape[0] + generated_text_tokens.shape[0] + generated_audio_tokens.shape[1] + ), + "cached_tokens": 0, + }, + ) diff --git a/boson_multimodal/serve/utils.py b/boson_multimodal/serve/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f25e8993f6bd53b6454140516c303279302cd2f3 --- /dev/null +++ b/boson_multimodal/serve/utils.py @@ -0,0 +1,246 @@ +import uuid +import base64 +import re +import regex +from typing import AsyncGenerator, Union +import io +from pydub import AudioSegment +import torch +import numpy as np +from functools import lru_cache + +from ..audio_processing.higgs_audio_tokenizer import HiggsAudioTokenizer + + +def random_uuid() -> str: + return str(uuid.uuid4().hex) + + +async def async_generator_wrap(first_element, gen: AsyncGenerator): + """Wrap an async generator with the first element.""" + yield first_element + async for item in gen: + yield item + + +@lru_cache(maxsize=50) +def encode_base64_content_from_file(file_path: str) -> str: + """Encode a content from a local file to base64 format.""" + # Read the MP3 file as binary and encode it directly to Base64 + with open(file_path, "rb") as audio_file: + audio_base64 = base64.b64encode(audio_file.read()).decode("utf-8") + return audio_base64 + + +def pcm16_to_target_format( + np_audio: np.ndarray, + sample_rate: int, + bit_depth: int, + channels: int, + format: str, + target_rate: int, +): + wav_audio = AudioSegment( + np_audio.tobytes(), + frame_rate=sample_rate, + sample_width=bit_depth // 8, + channels=channels, + ) + if target_rate is not None and target_rate != sample_rate: + wav_audio = wav_audio.set_frame_rate(target_rate) + + # Convert WAV to MP3 + target_io = io.BytesIO() + wav_audio.export(target_io, format=format) + target_io.seek(0) + + return target_io + + +chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+") + + +def contains_chinese(text: str): + return bool(chinese_char_pattern.search(text)) + + +# remove blank between chinese character +def replace_blank(text: str): + out_str = [] + for i, c in enumerate(text): + if c == " ": + if (text[i + 1].isascii() and text[i + 1] != " ") and (text[i - 1].isascii() and text[i - 1] != " "): + out_str.append(c) + else: + out_str.append(c) + return "".join(out_str) + + +def replace_corner_mark(text: str): + text = text.replace("²", "平方") + text = text.replace("³", "立方") + return text + + +# remove meaningless symbol +def remove_bracket(text: str): + text = text.replace("(", "").replace(")", "") + text = text.replace("【", "").replace("】", "") + text = text.replace("`", "").replace("`", "") + text = text.replace("——", " ") + return text + + +# split paragrah logic: +# 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len +# 2. cal sentence len according to lang +# 3. split sentence according to puncatation +def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False): + def calc_utt_length(_text: str): + if lang == "zh": + return len(_text) + else: + return len(tokenize(_text)) + + def should_merge(_text: str): + if lang == "zh": + return len(_text) < merge_len + else: + return len(tokenize(_text)) < merge_len + + if lang == "zh": + pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"] + else: + pounc = [".", "?", "!", ";", ":"] + if comma_split: + pounc.extend([",", ","]) + + if text[-1] not in pounc: + if lang == "zh": + text += "。" + else: + text += "." + + st = 0 + utts = [] + for i, c in enumerate(text): + if c in pounc: + if len(text[st:i]) > 0: + utts.append(text[st:i] + c) + if i + 1 < len(text) and text[i + 1] in ['"', "”"]: + tmp = utts.pop(-1) + utts.append(tmp + text[i + 1]) + st = i + 2 + else: + st = i + 1 + + final_utts = [] + cur_utt = "" + for utt in utts: + if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n: + final_utts.append(cur_utt) + cur_utt = "" + cur_utt = cur_utt + utt + if len(cur_utt) > 0: + if should_merge(cur_utt) and len(final_utts) != 0: + final_utts[-1] = final_utts[-1] + cur_utt + else: + final_utts.append(cur_utt) + + return final_utts + + +def is_only_punctuation(text: str): + # Regular expression: Match strings that consist only of punctuation marks or are empty. + punctuation_pattern = r"^[\p{P}\p{S}]*$" + return bool(regex.fullmatch(punctuation_pattern, text)) + + +# spell Arabic numerals +def spell_out_number(text: str, inflect_parser): + new_text = [] + st = None + for i, c in enumerate(text): + if not c.isdigit(): + if st is not None: + num_str = inflect_parser.number_to_words(text[st:i]) + new_text.append(num_str) + st = None + new_text.append(c) + else: + if st is None: + st = i + if st is not None and st < len(text): + num_str = inflect_parser.number_to_words(text[st:]) + new_text.append(num_str) + return "".join(new_text) + + +def remove_emoji(text: str): + # Pattern to match emojis and their modifiers + # - Standard emoji range + # - Zero-width joiners (U+200D) + # - Variation selectors (U+FE0F, U+FE0E) + # - Skin tone modifiers (U+1F3FB to U+1F3FF) + emoji_pattern = re.compile( + r"[" + r"\U00010000-\U0010FFFF" # Standard emoji range + r"\u200D" # Zero-width joiner + r"\uFE0F\uFE0E" # Variation selectors + r"\U0001F3FB-\U0001F3FF" # Skin tone modifiers + r"]+", + flags=re.UNICODE, + ) + return emoji_pattern.sub(r"", text) + + +def remove_repeated_punctuations(text, punctuations): + if len(punctuations) == 0: + return text + pattern = f"[{re.escape(''.join(punctuations))}]" # Create regex pattern for given punctuations + return re.sub(rf"({pattern})\1+", r"\1", text) + + +def full_to_half_width(text: str) -> str: + """Convert full-width punctuation to half-width in a given string.""" + full_width = "!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~" + half_width = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" + trans_table = str.maketrans(full_width, half_width) + return text.translate(trans_table) + + +def split_interleaved_delayed_audios( + audio_data: Union[list[list[int]], torch.Tensor], + audio_tokenizer: HiggsAudioTokenizer, + audio_stream_eos_id: int, +) -> list[tuple[list[list[int]], torch.Tensor]]: + separator = [audio_stream_eos_id] * audio_tokenizer.num_codebooks + + # Convert separator to numpy array if audio_data is numpy array + if isinstance(audio_data, torch.Tensor): + audio_data = audio_data.transpose(1, 0) + separator = torch.tensor(separator) + # Find the indices where the rows equal the separator + split_indices = torch.where(torch.all(audio_data == separator, dim=1))[0] + start = 0 + groups = [] + for idx in split_indices: + groups.append(audio_data[start:idx].transpose(1, 0)) + start = idx + 1 + if start < len(audio_data): + groups.append(audio_data[start:].transpose(1, 0)) + else: + groups = [] + current = [] + for row in audio_data: + current.append(row) + + if row == separator: + groups.append(current) + current = [] + + # Don't forget the last group if there's no trailing separator + if current: + groups.append(current) + + return groups diff --git a/cog.yaml b/cog.yaml new file mode 100644 index 0000000000000000000000000000000000000000..34e408bc3ee6810269a403a4f31e991ad22626f5 --- /dev/null +++ b/cog.yaml @@ -0,0 +1,25 @@ +# Configuration for Cog ⚙️ +# Reference: https://cog.run/yaml + +build: + # set to true if your model requires a GPU + gpu: true + + # a list of ubuntu apt packages to install + system_packages: + - "libsndfile1" + - "ffmpeg" + + # python version in the form '3.11' or '3.11.4' + python_version: "3.11" + + # path to a Python requirements.txt file + python_requirements: requirements.txt + + # cog_runtime: true + + run: + - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget + +# predict.py defines how predictions are run on your model +predict: "predict.py:Predictor" diff --git a/examples/.DS_Store b/examples/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..ca9153fc7294035e3c0e09306468964c9a6d2b0a Binary files /dev/null and b/examples/.DS_Store differ diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000000000000000000000000000000000000..276a8ef7b1210591e03c33de6ed53bdc05c363ac --- /dev/null +++ b/examples/README.md @@ -0,0 +1,166 @@ +# Examples + +> [!NOTE] +> If you do not like the audio you get, you can generate multiple times with different seeds. In addition, you may need to apply text normalization to get the best performance, e.g. converting 70 °F to "seventy degrees Fahrenheit", and converting "1 2 3 4" to "one two three four". The model also performs better in longer sentences. Right now, the model has not been post-trained, we will release the post-trained model in the future. + +## Single-speaker Audio Generation + +### Voice clone + +```bash +python3 generation.py \ +--transcript transcript/single_speaker/en_dl.txt \ +--ref_audio broom_salesman \ +--seed 12345 \ +--out_path generation.wav +``` + +The model will read the transcript with the same voice as in the [reference audio](./voice_prompts/broom_salesman.wav). The technique is also called shallow voice clone. + +We have some example audio prompts stored in [voice_prompts](./voice_prompts/). Feel free to pick one in the folder and try out the model. Here's another example that uses the voice of `belinda`. You can also add new own favorite voice in the folder and clone the voice. + +```bash +python3 generation.py \ +--transcript transcript/single_speaker/en_dl.txt \ +--ref_audio belinda \ +--seed 12345 \ +--out_path generation.wav +``` + +#### (Experimental) Cross-lingual voice clone + +This example demonstrates voice cloning with a Chinese prompt, where the synthesized speech is in English. + +```bash +python3 generation.py \ +--transcript transcript/single_speaker/en_dl.txt \ +--scene_prompt empty \ +--ref_audio zh_man_sichuan \ +--temperature 0.3 \ +--seed 12345 \ +--out_path generation.wav +``` + +### Smart voice + +The model supports reading the transcript with a random voice. + +```bash +python3 generation.py \ +--transcript transcript/single_speaker/en_dl.txt \ +--seed 12345 \ +--out_path generation.wav +``` + +It also works for other languages like Chinese. + +```bash +python3 generation.py \ +--transcript transcript/single_speaker/zh_ai.txt \ +--seed 12345 \ +--out_path generation.wav +``` + +### Describe speaker characteristics with text + +The model allows you to describe the speaker via text. See [voice_prompts/profile.yaml](voice_prompts/profile.yaml) for examples. You can run the following two examples that try to specify male / female British accent for the speakers. Also, try to remove the `--seed 12345` flag to see how the model is generating different voices. + +```bash +# Male British Accent +python3 generation.py \ +--transcript transcript/single_speaker/en_dl.txt \ +--ref_audio profile:male_en_british \ +--seed 12345 \ +--out_path generation.wav + +# Female British Accent +python3 generation.py \ +--transcript transcript/single_speaker/en_dl.txt \ +--ref_audio profile:female_en_british \ +--seed 12345 \ +--out_path generation.wav +``` + +### Chunking for long-form audio generation + +To generate long-form audios, you can chunk the text and render each chunk one by one while putting the previous generated audio and the reference audio in the prompt. Here's an example that generates the first five paragraphs of Higgs Audio v1 release blog. See [text](./transcript/single_speaker/en_higgs_audio_blog.md). + +```bash +python3 generation.py \ +--scene_prompt scene_prompts/reading_blog.txt \ +--transcript transcript/single_speaker/en_higgs_audio_blog.md \ +--ref_audio en_man \ +--chunk_method word \ +--temperature 0.3 \ +--generation_chunk_buffer_size 2 \ +--seed 12345 \ +--out_path generation.wav +``` + +### Experimental and Emergent Capabilities + +As shown in our demo, the pretrained model is demonstrating emergent features. We prepared some samples to help you explore these experimental prompts. We will enhance the stability of these experimental prompts in the future version of HiggsAudio. + +#### (Experimental) Hum a tune with the cloned voice +The model is able to hum a tune with the cloned voice. + +```bash +python3 generation.py \ +--transcript transcript/single_speaker/experimental/en_humming.txt \ +--ref_audio en_woman \ +--ras_win_len 0 \ +--seed 12345 \ +--out_path generation.wav +``` + +#### (Experimental) Read the sentence while adding background music (BGM) + +```bash +python3 generation.py \ +--transcript transcript/single_speaker/experimental/en_bgm.txt \ +--ref_audio en_woman \ +--ras_win_len 0 \ +--ref_audio_in_system_message \ +--seed 123456 \ +--out_path generation.wav +``` + +## Multi-speaker Audio Generation + + +### Smart voice + +To get started to explore HiggsAudio's capability in generating multi-speaker audios. Let's try to generate a multi-speaker dialog from transcript in the zero-shot fashion. See the transcript in [transcript/multi_speaker/en_argument.txt](transcript/multi_speaker/en_argument.txt). The speakers are annotated with `[SPEAKER0]` and `[SPEAKER1]`. + +```bash +python3 generation.py \ +--transcript transcript/multi_speaker/en_argument.txt \ +--seed 12345 \ +--out_path generation.wav +``` + +### Multi-voice clone +You can also try to clone the voices from multiple people simultaneously and generate audio about the transcript. Here's an example that puts reference audios in the system message and prompt the model iteratively. You can hear "Belinda" arguing with "Broom Salesman". + +```bash +python3 generation.py \ +--transcript transcript/multi_speaker/en_argument.txt \ +--ref_audio belinda,broom_salesman \ +--ref_audio_in_system_message \ +--chunk_method speaker \ +--seed 12345 \ +--out_path generation.wav +``` + +You can also let "Broom Salesman" talking to "Belinda", who recently trained HiggsAudio. + +```bash +python3 generation.py \ +--transcript transcript/multi_speaker/en_higgs.txt \ +--ref_audio broom_salesman,belinda \ +--ref_audio_in_system_message \ +--chunk_method speaker \ +--chunk_max_num_turns 2 \ +--seed 12345 \ +--out_path generation.wav +``` diff --git a/examples/generation.py b/examples/generation.py new file mode 100644 index 0000000000000000000000000000000000000000..a765ca5773252646674af6cc7332a77c8a404049 --- /dev/null +++ b/examples/generation.py @@ -0,0 +1,718 @@ +"""Example script for generating audio using HiggsAudio.""" + +import click +import soundfile as sf +import langid +import jieba +import os +import re +import copy +import torchaudio +import tqdm +import yaml + +from loguru import logger +from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse +from boson_multimodal.data_types import Message, ChatMLSample, AudioContent, TextContent + +from boson_multimodal.model.higgs_audio import HiggsAudioConfig, HiggsAudioModel +from boson_multimodal.data_collator.higgs_audio_collator import HiggsAudioSampleCollator +from boson_multimodal.audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer +from boson_multimodal.dataset.chatml_dataset import ( + ChatMLDatasetSample, + prepare_chatml_sample, +) +from boson_multimodal.model.higgs_audio.utils import revert_delay_pattern +from typing import List +from transformers import AutoConfig, AutoTokenizer +from transformers.cache_utils import StaticCache +from typing import Optional +from dataclasses import asdict +import torch + +CURR_DIR = os.path.dirname(os.path.abspath(__file__)) + + +AUDIO_PLACEHOLDER_TOKEN = "<|__AUDIO_PLACEHOLDER__|>" + + +MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE = """You are an AI assistant designed to convert text into speech. +If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice. +If no speaker tag is present, select a suitable voice on your own.""" + + +def normalize_chinese_punctuation(text): + """ + Convert Chinese (full-width) punctuation marks to English (half-width) equivalents. + """ + # Mapping of Chinese punctuation to English punctuation + chinese_to_english_punct = { + ",": ", ", # comma + "。": ".", # period + ":": ":", # colon + ";": ";", # semicolon + "?": "?", # question mark + "!": "!", # exclamation mark + "(": "(", # left parenthesis + ")": ")", # right parenthesis + "【": "[", # left square bracket + "】": "]", # right square bracket + "《": "<", # left angle quote + "》": ">", # right angle quote + "“": '"', # left double quotation + "”": '"', # right double quotation + "‘": "'", # left single quotation + "’": "'", # right single quotation + "、": ",", # enumeration comma + "—": "-", # em dash + "…": "...", # ellipsis + "·": ".", # middle dot + "「": '"', # left corner bracket + "」": '"', # right corner bracket + "『": '"', # left double corner bracket + "』": '"', # right double corner bracket + } + + # Replace each Chinese punctuation with its English counterpart + for zh_punct, en_punct in chinese_to_english_punct.items(): + text = text.replace(zh_punct, en_punct) + + return text + + +def prepare_chunk_text( + text, chunk_method: Optional[str] = None, chunk_max_word_num: int = 100, chunk_max_num_turns: int = 1 +): + """Chunk the text into smaller pieces. We will later feed the chunks one by one to the model. + + Parameters + ---------- + text : str + The text to be chunked. + chunk_method : str, optional + The method to use for chunking. Options are "speaker", "word", or None. By default, we won't use any chunking and + will feed the whole text to the model. + replace_speaker_tag_with_special_tags : bool, optional + Whether to replace speaker tags with special tokens, by default False + If the flag is set to True, we will replace [SPEAKER0] with <|speaker_id_start|>SPEAKER0<|speaker_id_end|> + chunk_max_word_num : int, optional + The maximum number of words for each chunk when "word" chunking method is used, by default 100 + chunk_max_num_turns : int, optional + The maximum number of turns for each chunk when "speaker" chunking method is used, + + Returns + ------- + List[str] + The list of text chunks. + + """ + if chunk_method is None: + return [text] + elif chunk_method == "speaker": + lines = text.split("\n") + speaker_chunks = [] + speaker_utterance = "" + for line in lines: + line = line.strip() + if line.startswith("[SPEAKER") or line.startswith("<|speaker_id_start|>"): + if speaker_utterance: + speaker_chunks.append(speaker_utterance.strip()) + speaker_utterance = line + else: + if speaker_utterance: + speaker_utterance += "\n" + line + else: + speaker_utterance = line + if speaker_utterance: + speaker_chunks.append(speaker_utterance.strip()) + if chunk_max_num_turns > 1: + merged_chunks = [] + for i in range(0, len(speaker_chunks), chunk_max_num_turns): + merged_chunk = "\n".join(speaker_chunks[i : i + chunk_max_num_turns]) + merged_chunks.append(merged_chunk) + return merged_chunks + return speaker_chunks + elif chunk_method == "word": + # TODO: We may improve the logic in the future + # For long-form generation, we will first divide the text into multiple paragraphs by splitting with "\n\n" + # After that, we will chunk each paragraph based on word count + language = langid.classify(text)[0] + paragraphs = text.split("\n\n") + chunks = [] + for idx, paragraph in enumerate(paragraphs): + if language == "zh": + # For Chinese, we will chunk based on character count + words = list(jieba.cut(paragraph, cut_all=False)) + for i in range(0, len(words), chunk_max_word_num): + chunk = "".join(words[i : i + chunk_max_word_num]) + chunks.append(chunk) + else: + words = paragraph.split(" ") + for i in range(0, len(words), chunk_max_word_num): + chunk = " ".join(words[i : i + chunk_max_word_num]) + chunks.append(chunk) + chunks[-1] += "\n\n" + return chunks + else: + raise ValueError(f"Unknown chunk method: {chunk_method}") + + +def _build_system_message_with_audio_prompt(system_message): + contents = [] + + while AUDIO_PLACEHOLDER_TOKEN in system_message: + loc = system_message.find(AUDIO_PLACEHOLDER_TOKEN) + contents.append(TextContent(system_message[:loc])) + contents.append(AudioContent(audio_url="")) + system_message = system_message[loc + len(AUDIO_PLACEHOLDER_TOKEN) :] + + if len(system_message) > 0: + contents.append(TextContent(system_message)) + ret = Message( + role="system", + content=contents, + ) + return ret + + +class HiggsAudioModelClient: + def __init__( + self, + model_path, + audio_tokenizer, + device_id=None, + max_new_tokens=2048, + kv_cache_lengths: List[int] = [1024, 4096, 8192], # Multiple KV cache sizes, + use_static_kv_cache=False, + ): + if device_id is None: + self._device = "cuda" if torch.cuda.is_available() else "cpu" + else: + self._device = f"cuda:{device_id}" + self._audio_tokenizer = ( + load_higgs_audio_tokenizer(audio_tokenizer, device=self._device) + if isinstance(audio_tokenizer, str) + else audio_tokenizer + ) + self._model = HiggsAudioModel.from_pretrained( + model_path, + device_map=self._device, + torch_dtype=torch.bfloat16, + ) + self._model.eval() + self._kv_cache_lengths = kv_cache_lengths + self._use_static_kv_cache = use_static_kv_cache + + self._tokenizer = AutoTokenizer.from_pretrained(model_path) + self._config = AutoConfig.from_pretrained(model_path) + self._max_new_tokens = max_new_tokens + self._collator = HiggsAudioSampleCollator( + whisper_processor=None, + audio_in_token_id=self._config.audio_in_token_idx, + audio_out_token_id=self._config.audio_out_token_idx, + audio_stream_bos_id=self._config.audio_stream_bos_id, + audio_stream_eos_id=self._config.audio_stream_eos_id, + encode_whisper_embed=self._config.encode_whisper_embed, + pad_token_id=self._config.pad_token_id, + return_audio_in_tokens=self._config.encode_audio_in_tokens, + use_delay_pattern=self._config.use_delay_pattern, + round_to=1, + audio_num_codebooks=self._config.audio_num_codebooks, + ) + self.kv_caches = None + if use_static_kv_cache: + self._init_static_kv_cache() + + def _init_static_kv_cache(self): + cache_config = copy.deepcopy(self._model.config.text_config) + cache_config.num_hidden_layers = self._model.config.text_config.num_hidden_layers + if self._model.config.audio_dual_ffn_layers: + cache_config.num_hidden_layers += len(self._model.config.audio_dual_ffn_layers) + # A list of KV caches for different lengths + self.kv_caches = { + length: StaticCache( + config=cache_config, + max_batch_size=1, + max_cache_len=length, + device=self._model.device, + dtype=self._model.dtype, + ) + for length in sorted(self._kv_cache_lengths) + } + # Capture CUDA graphs for each KV cache length + if "cuda" in self._device: + logger.info(f"Capturing CUDA graphs for each KV cache length") + self._model.capture_model(self.kv_caches.values()) + + def _prepare_kv_caches(self): + for kv_cache in self.kv_caches.values(): + kv_cache.reset() + + @torch.inference_mode() + def generate( + self, + messages, + audio_ids, + chunked_text, + generation_chunk_buffer_size, + temperature=1.0, + top_k=50, + top_p=0.95, + ras_win_len=7, + ras_win_max_num_repeat=2, + seed=123, + *args, + **kwargs, + ): + if ras_win_len is not None and ras_win_len <= 0: + ras_win_len = None + sr = 24000 + audio_out_ids_l = [] + generated_audio_ids = [] + generation_messages = [] + for idx, chunk_text in tqdm.tqdm( + enumerate(chunked_text), desc="Generating audio chunks", total=len(chunked_text) + ): + generation_messages.append( + Message( + role="user", + content=chunk_text, + ) + ) + chatml_sample = ChatMLSample(messages=messages + generation_messages) + input_tokens, _, _, _ = prepare_chatml_sample(chatml_sample, self._tokenizer) + postfix = self._tokenizer.encode( + "<|start_header_id|>assistant<|end_header_id|>\n\n", add_special_tokens=False + ) + input_tokens.extend(postfix) + + logger.info(f"========= Chunk {idx} Input =========") + logger.info(self._tokenizer.decode(input_tokens)) + context_audio_ids = audio_ids + generated_audio_ids + + curr_sample = ChatMLDatasetSample( + input_ids=torch.LongTensor(input_tokens), + label_ids=None, + audio_ids_concat=torch.concat([ele.cpu() for ele in context_audio_ids], dim=1) + if context_audio_ids + else None, + audio_ids_start=torch.cumsum( + torch.tensor([0] + [ele.shape[1] for ele in context_audio_ids], dtype=torch.long), dim=0 + ) + if context_audio_ids + else None, + audio_waveforms_concat=None, + audio_waveforms_start=None, + audio_sample_rate=None, + audio_speaker_indices=None, + ) + + batch_data = self._collator([curr_sample]) + batch = asdict(batch_data) + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.contiguous().to(self._device) + + if self._use_static_kv_cache: + self._prepare_kv_caches() + + # Generate audio + outputs = self._model.generate( + **batch, + max_new_tokens=self._max_new_tokens, + use_cache=True, + do_sample=True, + temperature=temperature, + top_k=top_k, + top_p=top_p, + past_key_values_buckets=self.kv_caches, + ras_win_len=ras_win_len, + ras_win_max_num_repeat=ras_win_max_num_repeat, + stop_strings=["<|end_of_text|>", "<|eot_id|>"], + tokenizer=self._tokenizer, + seed=seed, + ) + + step_audio_out_ids_l = [] + for ele in outputs[1]: + audio_out_ids = ele + if self._config.use_delay_pattern: + audio_out_ids = revert_delay_pattern(audio_out_ids) + step_audio_out_ids_l.append(audio_out_ids.clip(0, self._audio_tokenizer.codebook_size - 1)[:, 1:-1]) + audio_out_ids = torch.concat(step_audio_out_ids_l, dim=1) + audio_out_ids_l.append(audio_out_ids) + generated_audio_ids.append(audio_out_ids) + + generation_messages.append( + Message( + role="assistant", + content=AudioContent(audio_url=""), + ) + ) + if generation_chunk_buffer_size is not None and len(generated_audio_ids) > generation_chunk_buffer_size: + generated_audio_ids = generated_audio_ids[-generation_chunk_buffer_size:] + generation_messages = generation_messages[(-2 * generation_chunk_buffer_size) :] + + logger.info(f"========= Final Text output =========") + logger.info(self._tokenizer.decode(outputs[0][0])) + concat_audio_out_ids = torch.concat(audio_out_ids_l, dim=1) + concat_wv = self._audio_tokenizer.decode(concat_audio_out_ids.unsqueeze(0))[0, 0] + text_result = self._tokenizer.decode(outputs[0][0]) + return concat_wv, sr, text_result + + +def prepare_generation_context(scene_prompt, ref_audio, ref_audio_in_system_message, audio_tokenizer, speaker_tags): + """Prepare the context for generation. + + The context contains the system message, user message, assistant message, and audio prompt if any. + """ + system_message = None + messages = [] + audio_ids = [] + if ref_audio is not None: + num_speakers = len(ref_audio.split(",")) + speaker_info_l = ref_audio.split(",") + voice_profile = None + if any([speaker_info.startswith("profile:") for speaker_info in ref_audio.split(",")]): + ref_audio_in_system_message = True + if ref_audio_in_system_message: + speaker_desc = [] + for spk_id, character_name in enumerate(speaker_info_l): + if character_name.startswith("profile:"): + if voice_profile is None: + with open(f"{CURR_DIR}/voice_prompts/profile.yaml", "r", encoding="utf-8") as f: + voice_profile = yaml.safe_load(f) + character_desc = voice_profile["profiles"][character_name[len("profile:") :].strip()] + speaker_desc.append(f"SPEAKER{spk_id}: {character_desc}") + else: + speaker_desc.append(f"SPEAKER{spk_id}: {AUDIO_PLACEHOLDER_TOKEN}") + if scene_prompt: + system_message = ( + "Generate audio following instruction." + "\n\n" + f"<|scene_desc_start|>\n{scene_prompt}\n\n" + "\n".join(speaker_desc) + "\n<|scene_desc_end|>" + ) + else: + system_message = ( + "Generate audio following instruction.\n\n" + + f"<|scene_desc_start|>\n" + + "\n".join(speaker_desc) + + "\n<|scene_desc_end|>" + ) + system_message = _build_system_message_with_audio_prompt(system_message) + else: + if scene_prompt: + system_message = Message( + role="system", + content=f"Generate audio following instruction.\n\n<|scene_desc_start|>\n{scene_prompt}\n<|scene_desc_end|>", + ) + voice_profile = None + for spk_id, character_name in enumerate(ref_audio.split(",")): + if not character_name.startswith("profile:"): + prompt_audio_path = os.path.join(f"{CURR_DIR}/voice_prompts", f"{character_name}.wav") + prompt_text_path = os.path.join(f"{CURR_DIR}/voice_prompts", f"{character_name}.txt") + assert os.path.exists(prompt_audio_path), ( + f"Voice prompt audio file {prompt_audio_path} does not exist." + ) + assert os.path.exists(prompt_text_path), f"Voice prompt text file {prompt_text_path} does not exist." + with open(prompt_text_path, "r", encoding="utf-8") as f: + prompt_text = f.read().strip() + audio_tokens = audio_tokenizer.encode(prompt_audio_path) + audio_ids.append(audio_tokens) + + if not ref_audio_in_system_message: + messages.append( + Message( + role="user", + content=f"[SPEAKER{spk_id}] {prompt_text}" if num_speakers > 1 else prompt_text, + ) + ) + messages.append( + Message( + role="assistant", + content=AudioContent( + audio_url=prompt_audio_path, + ), + ) + ) + else: + if len(speaker_tags) > 1: + # By default, we just alternate between male and female voices + speaker_desc_l = [] + + for idx, tag in enumerate(speaker_tags): + if idx % 2 == 0: + speaker_desc = f"feminine" + else: + speaker_desc = f"masculine" + speaker_desc_l.append(f"{tag}: {speaker_desc}") + + speaker_desc = "\n".join(speaker_desc_l) + scene_desc_l = [] + if scene_prompt: + scene_desc_l.append(scene_prompt) + scene_desc_l.append(speaker_desc) + scene_desc = "\n\n".join(scene_desc_l) + + system_message = Message( + role="system", + content=f"{MULTISPEAKER_DEFAULT_SYSTEM_MESSAGE}\n\n<|scene_desc_start|>\n{scene_desc}\n<|scene_desc_end|>", + ) + else: + system_message_l = ["Generate audio following instruction."] + if scene_prompt: + system_message_l.append(f"<|scene_desc_start|>\n{scene_prompt}\n<|scene_desc_end|>") + system_message = Message( + role="system", + content="\n\n".join(system_message_l), + ) + if system_message: + messages.insert(0, system_message) + return messages, audio_ids + + +@click.command() +@click.option( + "--model_path", + type=str, + default="bosonai/higgs-audio-v2-generation-3B-base", + help="Output wav file path.", +) +@click.option( + "--audio_tokenizer", + type=str, + default="bosonai/higgs-audio-v2-tokenizer", + help="Audio tokenizer path, if not set, use the default one.", +) +@click.option( + "--max_new_tokens", + type=int, + default=2048, + help="The maximum number of new tokens to generate.", +) +@click.option( + "--transcript", + type=str, + default="transcript/single_speaker/en_dl.txt", + help="The prompt to use for generation. If not set, we will use a default prompt.", +) +@click.option( + "--scene_prompt", + type=str, + default=f"{CURR_DIR}/scene_prompts/quiet_indoor.txt", + help="The scene description prompt to use for generation. If not set, or set to `empty`, we will leave it to empty.", +) +@click.option( + "--temperature", + type=float, + default=1.0, + help="The value used to module the next token probabilities.", +) +@click.option( + "--top_k", + type=int, + default=50, + help="The number of highest probability vocabulary tokens to keep for top-k-filtering.", +) +@click.option( + "--top_p", + type=float, + default=0.95, + help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.", +) +@click.option( + "--ras_win_len", + type=int, + default=7, + help="The window length for RAS sampling. If set to 0 or a negative value, we won't use RAS sampling.", +) +@click.option( + "--ras_win_max_num_repeat", + type=int, + default=2, + help="The maximum number of times to repeat the RAS window. Only used when --ras_win_len is set.", +) +@click.option( + "--ref_audio", + type=str, + default=None, + help="The voice prompt to use for generation. If not set, we will let the model randomly pick a voice. " + "For multi-speaker generation, you can specify the prompts as `belinda,chadwick` and we will use the voice of belinda as SPEAKER0 and the voice of chadwick as SPEAKER1.", +) +@click.option( + "--ref_audio_in_system_message", + is_flag=True, + default=False, + help="Whether to include the voice prompt description in the system message.", + show_default=True, +) +@click.option( + "--chunk_method", + default=None, + type=click.Choice([None, "speaker", "word"]), + help="The method to use for chunking the prompt text. Options are 'speaker', 'word', or None. By default, we won't use any chunking and will feed the whole text to the model.", +) +@click.option( + "--chunk_max_word_num", + default=200, + type=int, + help="The maximum number of words for each chunk when 'word' chunking method is used. Only used when --chunk_method is set to 'word'.", +) +@click.option( + "--chunk_max_num_turns", + default=1, + type=int, + help="The maximum number of turns for each chunk when 'speaker' chunking method is used. Only used when --chunk_method is set to 'speaker'.", +) +@click.option( + "--generation_chunk_buffer_size", + default=None, + type=int, + help="The maximal number of chunks to keep in the buffer. We will always keep the reference audios, and keep `max_chunk_buffer` chunks of generated audio.", +) +@click.option( + "--seed", + default=None, + type=int, + help="Random seed for generation.", +) +@click.option( + "--device_id", + type=int, + default=None, + help="The device to run the model on.", +) +@click.option( + "--out_path", + type=str, + default="generation.wav", +) +@click.option( + "--use_static_kv_cache", + type=int, + default=1, + help="Whether to use static KV cache for faster generation. Only works when using GPU.", +) +def main( + model_path, + audio_tokenizer, + max_new_tokens, + transcript, + scene_prompt, + temperature, + top_k, + top_p, + ras_win_len, + ras_win_max_num_repeat, + ref_audio, + ref_audio_in_system_message, + chunk_method, + chunk_max_word_num, + chunk_max_num_turns, + generation_chunk_buffer_size, + seed, + device_id, + out_path, + use_static_kv_cache, +): + if device_id is None: + if torch.cuda.is_available(): + device_id = 0 + device = "cuda:0" + else: + device_id = None + device = "cpu" + else: + device = f"cuda:{device_id}" + audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer, device=device) + + model_client = HiggsAudioModelClient( + model_path=model_path, + audio_tokenizer=audio_tokenizer, + device_id=device_id, + max_new_tokens=max_new_tokens, + use_static_kv_cache=use_static_kv_cache, + ) + pattern = re.compile(r"\[(SPEAKER\d+)\]") + + if os.path.exists(transcript): + logger.info(f"Loading transcript from {transcript}") + with open(transcript, "r", encoding="utf-8") as f: + transcript = f.read().strip() + + if scene_prompt is not None and scene_prompt != "empty" and os.path.exists(scene_prompt): + with open(scene_prompt, "r", encoding="utf-8") as f: + scene_prompt = f.read().strip() + else: + scene_prompt = None + + speaker_tags = sorted(set(pattern.findall(transcript))) + + # Perform some basic normalization + transcript = normalize_chinese_punctuation(transcript) + # Other normalizations (e.g., parentheses and other symbols. Will be improved in the future) + transcript = transcript.replace("(", " ") + transcript = transcript.replace(")", " ") + transcript = transcript.replace("°F", " degrees Fahrenheit") + transcript = transcript.replace("°C", " degrees Celsius") + + for tag, replacement in [ + ("[laugh]", "[Laughter]"), + ("[humming start]", "[Humming]"), + ("[humming end]", "[Humming]"), + ("[music start]", "[Music]"), + ("[music end]", "[Music]"), + ("[music]", "[Music]"), + ("[sing start]", "[Singing]"), + ("[sing end]", "[Singing]"), + ("[applause]", "[Applause]"), + ("[cheering]", "[Cheering]"), + ("[cough]", "[Cough]"), + ]: + transcript = transcript.replace(tag, replacement) + lines = transcript.split("\n") + transcript = "\n".join([" ".join(line.split()) for line in lines if line.strip()]) + transcript = transcript.strip() + + if not any([transcript.endswith(c) for c in [".", "!", "?", ",", ";", '"', "'", "", ""]]): + transcript += "." + + messages, audio_ids = prepare_generation_context( + scene_prompt=scene_prompt, + ref_audio=ref_audio, + ref_audio_in_system_message=ref_audio_in_system_message, + audio_tokenizer=audio_tokenizer, + speaker_tags=speaker_tags, + ) + chunked_text = prepare_chunk_text( + transcript, + chunk_method=chunk_method, + chunk_max_word_num=chunk_max_word_num, + chunk_max_num_turns=chunk_max_num_turns, + ) + + logger.info("Chunks used for generation:") + for idx, chunk_text in enumerate(chunked_text): + logger.info(f"Chunk {idx}:") + logger.info(chunk_text) + logger.info("-----") + + concat_wv, sr, text_output = model_client.generate( + messages=messages, + audio_ids=audio_ids, + chunked_text=chunked_text, + generation_chunk_buffer_size=generation_chunk_buffer_size, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ras_win_len=ras_win_len, + ras_win_max_num_repeat=ras_win_max_num_repeat, + seed=seed, + ) + + sf.write(out_path, concat_wv, sr) + logger.info(f"Wav file is saved to '{out_path}' with sample rate {sr}") + + +if __name__ == "__main__": + main() diff --git a/examples/scene_prompts/quiet_indoor.txt b/examples/scene_prompts/quiet_indoor.txt new file mode 100644 index 0000000000000000000000000000000000000000..cc1caec23509492c31b2d26f963d263d57471184 --- /dev/null +++ b/examples/scene_prompts/quiet_indoor.txt @@ -0,0 +1 @@ +Audio is recorded from a quiet room. diff --git a/examples/scene_prompts/reading_blog.txt b/examples/scene_prompts/reading_blog.txt new file mode 100644 index 0000000000000000000000000000000000000000..bb38f87a939d7adfe0b154a242af45b653b0add2 --- /dev/null +++ b/examples/scene_prompts/reading_blog.txt @@ -0,0 +1 @@ +In this audio, the person is reading a blog post aloud. The content is informative and engaging, with the speaker using a clear, conversational tone to make the material feel more approachable. The pacing is moderate, allowing listeners to absorb the information, and the tone shifts slightly to emphasize key points. The speaker occasionally pauses for effect, ensuring each section flows smoothly, as they guide the listener through the post's main ideas. diff --git a/examples/serve_engine/README.md b/examples/serve_engine/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f45e8b54f803b59232b051848098beb5c7a02a9c --- /dev/null +++ b/examples/serve_engine/README.md @@ -0,0 +1,25 @@ +# Examples to use HiggsAudioServeEngine + +The `run_hf_example.py` script provides three different examples for using the `HiggsAudioServeEngine`. +Each example will generate an audio file (`output_{example}.wav`) in the current directory. + +### Zero-Shot Voice Generation +Generate audio with specific voice characteristics (e.g., accents). + +```bash +python run_hf_example.py zero_shot +``` + +### Voice Cloning +Clone a voice from a reference audio sample. + +```bash +python run_hf_example.py voice_clone +``` + +### (Experimental) Interleaved Dialogue Generation +Higgs Audio v2 is also able to generate text. Here's an example that shows it is able to generate multi-speaker conversations with interleaved transcript and audio from scene descriptions. + +```bash +python run_hf_example.py interleaved_dialogue +``` diff --git a/examples/serve_engine/input_samples.py b/examples/serve_engine/input_samples.py new file mode 100644 index 0000000000000000000000000000000000000000..344dd5b78388267ca09c423f911f88a2f488c211 --- /dev/null +++ b/examples/serve_engine/input_samples.py @@ -0,0 +1,87 @@ +import base64 +import os +from boson_multimodal.data_types import ChatMLSample, Message, AudioContent + + +def encode_base64_content_from_file(file_path: str) -> str: + """Encode a content from a local file to base64 format.""" + # Read the audio file as binary and encode it directly to Base64 + with open(file_path, "rb") as audio_file: + audio_base64 = base64.b64encode(audio_file.read()).decode("utf-8") + return audio_base64 + + +def get_interleaved_dialogue_input_sample(): + system_prompt = ( + "Generate audio following instruction.\n\n" + "<|scene_desc_start|>\n" + "SPEAKER0: vocal fry;moderate pitch;monotone;masculine;young adult;slightly fast\n" + "SPEAKER1: masculine;moderate;moderate pitch;monotone;mature\n\n" + "In this scene, a group of adventurers is debating whether to investigate a potentially dangerous situation.\n" + "<|scene_desc_end|>" + ) + + messages = [ + Message( + role="system", + content=system_prompt, + ), + Message( + role="user", + content="<|generation_instruction_start|>\nGenerate interleaved transcript and audio that lasts for around 20 seconds.\n<|generation_instruction_end|>", + ), + ] + chat_ml_sample = ChatMLSample(messages=messages) + return chat_ml_sample + + +def get_zero_shot_input_sample(): + system_prompt = ( + "Generate audio following instruction.\n\n<|scene_desc_start|>\nSPEAKER0: british accent\n<|scene_desc_end|>" + ) + + messages = [ + Message( + role="system", + content=system_prompt, + ), + Message( + role="user", + content="Hey, everyone! Welcome back to Tech Talk Tuesdays.\n" + "It's your host, Alex, and today, we're diving into a topic that's become absolutely crucial in the tech world — deep learning.\n" + "And let's be honest, if you've been even remotely connected to tech, AI, or machine learning lately, you know that deep learning is everywhere.", + ), + ] + chat_ml_sample = ChatMLSample(messages=messages) + return chat_ml_sample + + +def get_voice_clone_input_sample(): + reference_text = "I would imagine so. A wand with a dragon heartstring core is capable of dazzling magic." + reference_audio = encode_base64_content_from_file( + os.path.join(os.path.dirname(__file__), "voice_examples/old_man.wav") + ) + messages = [ + Message( + role="user", + content=reference_text, + ), + Message( + role="assistant", + content=AudioContent(raw_audio=reference_audio, audio_url="placeholder"), + ), + Message( + role="user", + content="Hey, everyone! Welcome back to Tech Talk Tuesdays.\n" + "It's your host, Alex, and today, we're diving into a topic that's become absolutely crucial in the tech world — deep learning.\n" + "And let's be honest, if you've been even remotely connected to tech, AI, or machine learning lately, you know that deep learning is everywhere.", + ), + ] + return ChatMLSample(messages=messages) + + +INPUT_SAMPLES = { + "interleaved_dialogue": get_interleaved_dialogue_input_sample, + "zero_shot": get_zero_shot_input_sample, + "voice_clone": get_voice_clone_input_sample, +} diff --git a/examples/serve_engine/run_hf_example.py b/examples/serve_engine/run_hf_example.py new file mode 100644 index 0000000000000000000000000000000000000000..72ebd868794396c49568d2231945379d91a401d2 --- /dev/null +++ b/examples/serve_engine/run_hf_example.py @@ -0,0 +1,48 @@ +"""Example for using HiggsAudio for generating both the transcript and audio in an interleaved manner.""" + +from boson_multimodal.serve.serve_engine import HiggsAudioServeEngine, HiggsAudioResponse +import torch +import torchaudio +import time +from loguru import logger +import click + +from input_samples import INPUT_SAMPLES + +MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base" +AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer" + + +@click.command() +@click.argument("example", type=click.Choice(list(INPUT_SAMPLES.keys()))) +def main(example: str): + input_sample = INPUT_SAMPLES[example]() + device = "cuda" if torch.cuda.is_available() else "cpu" + logger.info(f"Using device: {device}") + + serve_engine = HiggsAudioServeEngine( + MODEL_PATH, + AUDIO_TOKENIZER_PATH, + device=device, + ) + + logger.info("Starting generation...") + start_time = time.time() + output: HiggsAudioResponse = serve_engine.generate( + chat_ml_sample=input_sample, + max_new_tokens=1024, + temperature=1.0, + top_p=0.95, + top_k=50, + stop_strings=["<|end_of_text|>", "<|eot_id|>"], + ) + elapsed_time = time.time() - start_time + logger.info(f"Generation time: {elapsed_time:.2f} seconds") + + torchaudio.save(f"output_{example}.wav", torch.from_numpy(output.audio)[None, :], output.sampling_rate) + logger.info(f"Generated text:\n{output.generated_text}") + logger.info(f"Saved audio to output_{example}.wav") + + +if __name__ == "__main__": + main() diff --git a/examples/serve_engine/voice_examples/old_man.wav b/examples/serve_engine/voice_examples/old_man.wav new file mode 100644 index 0000000000000000000000000000000000000000..0ab490768699e4ac6b337efea4a8f5a3d64bfbab --- /dev/null +++ b/examples/serve_engine/voice_examples/old_man.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:83bda9cd63be92366ef40dbe15c33e67b78766fb7069609f10dfc05cc626deba +size 1246508 diff --git a/examples/transcript/multi_speaker/en_argument.txt b/examples/transcript/multi_speaker/en_argument.txt new file mode 100644 index 0000000000000000000000000000000000000000..016cd77d0e0947e7b2125c8c5eac818289423144 --- /dev/null +++ b/examples/transcript/multi_speaker/en_argument.txt @@ -0,0 +1,4 @@ +[SPEAKER0] I can't believe you did that without even asking me first! +[SPEAKER1] Oh, come on! It wasn't a big deal, and I knew you would overreact like this. +[SPEAKER0] Overreact? You made a decision that affects both of us without even considering my opinion! +[SPEAKER1] Because I didn't have time to sit around waiting for you to make up your mind! Someone had to act. diff --git a/examples/transcript/multi_speaker/en_higgs.txt b/examples/transcript/multi_speaker/en_higgs.txt new file mode 100644 index 0000000000000000000000000000000000000000..6ff82cde60198041e89838e4d8f981b8dfe64e9e --- /dev/null +++ b/examples/transcript/multi_speaker/en_higgs.txt @@ -0,0 +1,6 @@ +[SPEAKER0] You're training HiggsAudio again? Aren't you tired of staring at it all day? +[SPEAKER1] Ha! This time, I'm trying to get it to generate multi-speaker dialogues. +[SPEAKER0] Oh, so you want it to sound like a real conversation with multiple people? That sounds… tricky. +[SPEAKER1] It is. The biggest challenge is making sure it understands who's speaking and when. We need a solid dataset with real conversations, including interruptions and natural flow. +[SPEAKER0] Right, because real conversations aren't just people taking turns like robots. There are overlaps, hesitations, and sudden topic changes. +[SPEAKER1] Exactly! That's why we need speaker diarization — so the model knows when one speaker stops and another starts, even if they overlap. diff --git a/examples/transcript/single_speaker/en_basic.txt b/examples/transcript/single_speaker/en_basic.txt new file mode 100644 index 0000000000000000000000000000000000000000..ba1afad6bc776a1d364328653d99b2566cede73c --- /dev/null +++ b/examples/transcript/single_speaker/en_basic.txt @@ -0,0 +1 @@ +The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years. diff --git a/examples/transcript/single_speaker/en_dl.txt b/examples/transcript/single_speaker/en_dl.txt new file mode 100644 index 0000000000000000000000000000000000000000..54b1a2b7e2c92a75187737aef4d3b299bf899332 --- /dev/null +++ b/examples/transcript/single_speaker/en_dl.txt @@ -0,0 +1,10 @@ +Hey, everyone! Welcome back to Tech Talk Tuesdays. +It’s your host, Alex, and today, we’re diving into a topic that’s become absolutely crucial in the tech world — deep learning. +And let’s be honest, if you’ve been even remotely connected to tech, AI, or machine learning lately, you know that deep learning is everywhere. + +So here’s the big question: Do you want to understand how deep learning works? +How to use it to build powerful models that can predict, automate, and transform industries? +Well, today, I’ve got some exciting news for you. + +We’re going to talk about a course that I highly recommend: Dive into Deep Learning. +It’s not just another course; it’s an entire experience that will take you from a beginner to someone who is well-versed in deep learning techniques. diff --git a/examples/transcript/single_speaker/en_higgs_audio_blog.md b/examples/transcript/single_speaker/en_higgs_audio_blog.md new file mode 100644 index 0000000000000000000000000000000000000000..d5e5e8ce768e759a18ce7c37fdaee93d673c37a8 --- /dev/null +++ b/examples/transcript/single_speaker/en_higgs_audio_blog.md @@ -0,0 +1,10 @@ +At Boson AI, we work on making communication with AI as easy, natural and fun as talking to a human. Today, we are excited to introduce Higgs Audio Understanding and Higgs Audio Generation — two powerful tools designed to build customized AI agents tailored for diverse audio understanding and generation needs. + +# Higgs Audio Generation +To communicate with humans in a delightful and natural manner, we need to be able to generate realistic, emotionally competent and well-accentuated speech. We need a system that is capable of pronouncing words correctly, even if they derive from a foreign language, particularly for people’s names and places. We need a system that can generate conversations between multiple speakers, particularly when multiple characters in games are involved, or when reading books or screenplays. + +Pure TTS (text to speech) systems struggle at these tasks, since they typically do not understand the meaning of what they’re generating, or any sense of urgency, hesitation, or other intonations that would be plainly obvious to a human speaker. They also struggle to adopt the natural character of a speaker, for example, whether they’re naturally enthusiastic or more deliberate and thoughtful. + +The way to address this problem is to build a TTS system using a Large Language Model (LLM) as a backbone. This endows the TTS system with the understanding needed to generate competent speech. Higgs Audio Generation enhances the underlying LLM to process audio by treating raw audio as tokens. This approach enables the model to be trained end-to-end on extensive text-audio datasets. + +The base model we are introducing today demonstrates impressive performance on benchmark tests. Additionally, it showcases emerging capabilities, including generating speech with emotional tone based on text semantics and producing multi-speaker dialogues from written transcripts, all due to the improved understanding. Before diving into technical details, let’s listen to two examples of audio generated by our model. diff --git a/examples/transcript/single_speaker/experimental/en_bgm.txt b/examples/transcript/single_speaker/experimental/en_bgm.txt new file mode 100644 index 0000000000000000000000000000000000000000..36dfcc098fb19f97079da3fca89f85baf61022af --- /dev/null +++ b/examples/transcript/single_speaker/experimental/en_bgm.txt @@ -0,0 +1 @@ +[music start] I will remember this, thought Ender, when I am defeated. To keep dignity, and give honor where it’s due, so that defeat is not disgrace. And I hope I don’t have to do it often. [music end] diff --git a/examples/transcript/single_speaker/experimental/en_humming.txt b/examples/transcript/single_speaker/experimental/en_humming.txt new file mode 100644 index 0000000000000000000000000000000000000000..97025e950804ca2ef8f21e18f2ce3b6532afb962 --- /dev/null +++ b/examples/transcript/single_speaker/experimental/en_humming.txt @@ -0,0 +1 @@ +Are you asking if I can hum a tune? Of course I can! [humming start] la la la la la [humming end] See? diff --git a/examples/transcript/single_speaker/zh_ai.txt b/examples/transcript/single_speaker/zh_ai.txt new file mode 100644 index 0000000000000000000000000000000000000000..04756fab3259b3be610c79034e9d74093a89c72d --- /dev/null +++ b/examples/transcript/single_speaker/zh_ai.txt @@ -0,0 +1,4 @@ +大家好,欢迎收听本期的跟李沐学AI。今天沐哥在忙着洗数据,所以由我,希格斯主播代替他讲这期视频。 +今天我们要聊的是一个你绝对不能忽视的话题"多模态学习"。 +无论你是开发者,数据科学爱好者,还是只是对人工智能感兴趣的人都一定听说过这个词。它已经成为AI时代的一个研究热点。 +那么,问题来了,你真的了解多模态吗 你知道如何自己动手构建多模态大模型吗。 \ No newline at end of file diff --git a/examples/vllm/README.md b/examples/vllm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..192aa07fb0f098ce67e379ca33cf2e878f6fb171 --- /dev/null +++ b/examples/vllm/README.md @@ -0,0 +1,73 @@ +# Serve Higgs Audio with vLLM + +We provided both OpenAI compatible chat completion and audio speech server backed by vLLM engine. To start the server, you can use the following command + +```bash +docker run --gpus all --ipc=host --shm-size=20gb --network=host \ +bosonai/higgs-audio-vllm:latest \ +--served-model-name "higgs-audio-v2-generation-3B-base" \ +--model "bosonai/higgs-audio-v2-generation-3B-base" \ +--audio-tokenizer-type "bosonai/higgs-audio-v2-tokenizer" \ +--limit-mm-per-prompt audio=50 \ +--max-model-len 8192 \ +--port 8000 \ +--gpu-memory-utilization 0.8 \ +--disable-mm-preprocessor-cache +``` + +In audio speech API, we provided the same voices as the [voice_prompts](../voice_prompts) folder. In addition, if you want to use your custom voices, you can add the voice presets in the docker run command + +```bash +--voice-presets-dir YOUR_VOICE_PRESETS_PATH +``` + +And in the voice presets directory, you need to add `config.json` file for each voice in the following format: +```json +{ + "belinda": { + "transcript": "Twas the night before my birthday. Hooray! It's almost here! It may not be a holiday, but it's the best day of the year.", + "audio_file": "belinda.wav" + }, + "broom_salesman": { + "transcript": "I would imagine so. A wand with a dragon heartstring core is capable of dazzling magic. And the bond between you and your wand should only grow stronger. Do not be surprised at your new wand's ability to perceive your intentions - particularly in a moment of need.", + "audio_file": "broom_salesman.wav" + } +} +``` + +We tested on A100 GPU with 40GB memory, which can achieve about 1500 tokens/s throughput for audio generation, which translate to 60 seconds audio generation per second with higgs-audio-tokenizer. +We also tested on RTX 4090 GPU with 24GB memory, which can achieve about 600 tokens/s throughput for audio generation, which translate to 24 seconds audio generation per second. + +### cURL Example +To quickly test the server with curl, you can use the following command to generate audio with the audio speech API. + +```bash +curl -X POST "http://localhost:8000/v1/audio/speech" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "higgs-audio-v2-generation-3B-base", + "voice": "en_woman", + "input": "Today is a wonderful day to build something people love!", + "response_format": "pcm" + }' \ + --output - | ffmpeg -f s16le -ar 24000 -ac 1 -i - speech.wav +``` + + +### Python example +You can also use the python client code to achieve more complex use cases with the chat completion API. + +Voice clone +```bash +python run_chat_completion.py --api-base http://localhost:8000/v1 --task voice_clone +``` + +Smart voice +```bash +python run_chat_completion.py --api-base http://localhost:8000/v1 --task smart_voice +``` + +Multispeaker +```bash +python run_chat_completion.py --api-base http://localhost:8000/v1 --task multispeaker +``` diff --git a/examples/vllm/run_chat_completion.py b/examples/vllm/run_chat_completion.py new file mode 100644 index 0000000000000000000000000000000000000000..916668c25ec099c6161d6a7ef93f02af8153fd53 --- /dev/null +++ b/examples/vllm/run_chat_completion.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: Apache-2.0 +"""An example showing how to use vLLM to serve multimodal models +and run online inference with OpenAI client. +""" + +import argparse +import base64 +import os +import time +from io import BytesIO + +import numpy as np +import requests +import soundfile as sf +from openai import OpenAI + +OPENAI_AUDIO_SAMPLE_RATE = 24000 +DEFAULT_SYSTEM_PROMPT = ( + "Generate audio following instruction.\n\n" + "<|scene_desc_start|>\n" + "Audio is recorded from a quiet room.\n" + "<|scene_desc_end|>" +) + + +def encode_base64_content_from_file(file_path: str) -> str: + """Encode a content from a local file to base64 format.""" + # Read the MP3 file as binary and encode it directly to Base64 + with open(file_path, "rb") as audio_file: + audio_base64 = base64.b64encode(audio_file.read()).decode("utf-8") + return audio_base64 + + +def run_smart_voice() -> None: + chat_completion = client.chat.completions.create( + messages=[ + {"role": "system", "content": DEFAULT_SYSTEM_PROMPT}, + { + "role": "user", + "content": ( + "The sun rises in the east and sets in the west. This simple fact has been observed by humans for thousands of years." + ), + }, + ], + model=model, + modalities=["text", "audio"], + audio={"format": "wav"}, + ) + + text = chat_completion.choices[0].message.content + audio = chat_completion.choices[0].message.audio.data + # Decode base64 audio string to bytes + audio_bytes = base64.b64decode(audio) + print("Chat completion text output:", text) + print("Saving the audio to file") + with open("output_smart_voice.wav", "wb") as f: + f.write(audio_bytes) + + +def run_voice_clone(stream: bool = False) -> None: + data_dir = os.path.join(os.path.dirname(__file__), "..", "voice_prompts") + audio_path = os.path.join(data_dir, "belinda.wav") + audio_text_path = os.path.join(data_dir, "belinda.txt") + with open(audio_text_path, "r") as f: + audio_text = f.read() + audio_base64 = encode_base64_content_from_file(audio_path) + messages = [ + {"role": "user", "content": audio_text}, + { + "role": "assistant", + "content": [ + { + "type": "input_audio", + "input_audio": { + "data": audio_base64, + "format": "wav", + }, + } + ], + }, + { + "role": "user", + "content": ( + "Hey there! I'm your friendly voice twin in the making. Pick a voice preset below or upload your own audio - let's clone some vocals and bring your voice to life!" + ), + }, + ] + start_time = time.time() + chat_completion = client.chat.completions.create( + messages=messages, + model=model, + max_completion_tokens=500, + stream=stream, + modalities=["text", "audio"], + temperature=1.0, + top_p=0.95, + extra_body={"top_k": 50}, + stop=["<|eot_id|>", "<|end_of_text|>", "<|audio_eos|>"], + ) + if stream: + audio_bytes_io = BytesIO() + i = 0 + first_audio_latency = None + for chunk in chat_completion: + if chunk.choices and hasattr(chunk.choices[0].delta, "audio") and chunk.choices[0].delta.audio: + if first_audio_latency is None: + first_audio_latency = time.time() - start_time + audio_bytes = base64.b64decode(chunk.choices[0].delta.audio["data"]) + audio_bytes_io.write(audio_bytes) + audio_data = np.frombuffer(audio_bytes, dtype=np.int16) + i += 1 + audio_bytes_io.seek(0) + audio_data = np.frombuffer(audio_bytes_io.getvalue(), dtype=np.int16) + print("Saving the audio to file") + print(f"First audio latency: {first_audio_latency * 1000} ms") + print(f"Total audio latency: {(time.time() - start_time) * 1000} ms") + sf.write("output_voice_clone.wav", audio_data, OPENAI_AUDIO_SAMPLE_RATE) + else: + text = chat_completion.choices[0].message.content + audio = chat_completion.choices[0].message.audio.data + audio_bytes = base64.b64decode(audio) + print("Chat completion text output:", text) + print("Saving the audio to file") + with open("output_voice_clone.wav", "wb") as f: + f.write(audio_bytes) + + +def run_generate_multispeaker(stream: bool = False) -> None: + MULTI_SPEAKER_SYSTEM_PROMPT = ( + "You are an AI assistant designed to convert text into speech.\n" + "If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.\n" + "If no speaker tag is present, select a suitable voice on your own.\n\n" + "<|scene_desc_start|>\n" + "SPEAKER0: feminine\n" + "SPEAKER1: masculine\n" + "<|scene_desc_end|>" + ) + transcript_path = os.path.join(os.path.dirname(__file__), "..", "transcript", "multi_speaker", "en_argument.txt") + with open(transcript_path, "r") as f: + transcript = f.read() + + messages = [{"role": "system", "content": MULTI_SPEAKER_SYSTEM_PROMPT}, {"role": "user", "content": transcript}] + chat_completion = client.chat.completions.create( + messages=messages, + model=model, + stream=stream, + stream_options={"include_usage": True}, + stop=["<|end_of_text|>", "<|eot_id|>", "<|audio_eos|>"], + modalities=["text", "audio"], + temperature=1.0, + top_p=0.95, + extra_body={"top_k": 50}, + ) + + if stream: + audio_bytes_io = BytesIO() + i = 0 + for chunk in chat_completion: + if chunk.choices and hasattr(chunk.choices[0].delta, "audio") and chunk.choices[0].delta.audio: + audio_bytes = base64.b64decode(chunk.choices[0].delta.audio["data"]) + audio_bytes_io.write(audio_bytes) + audio_data = np.frombuffer(audio_bytes, dtype=np.int16) + # sf.write(f"output_tts_{i}.wav", audio_data, target_rate) + i += 1 + else: + print(chunk) + audio_bytes_io.seek(0) + audio_data = np.frombuffer(audio_bytes_io.getvalue(), dtype=np.int16) + print("Saving the audio to file") + sf.write("output_multispeaker.wav", audio_data, OPENAI_AUDIO_SAMPLE_RATE) + else: + text = chat_completion.choices[0].message.content + audio = chat_completion.choices[0].message.audio.data + audio_bytes = base64.b64decode(audio) + print("Chat completion text output:", text) + print("Saving the audio to file") + with open("output_multispeaker.wav", "wb") as f: + f.write(audio_bytes) + + +def main(args) -> None: + if args.task == "voice_clone": + run_voice_clone(args.stream) + elif args.task == "smart_voice": + run_smart_voice() + elif args.task == "multispeaker": + run_generate_multispeaker(args.stream) + else: + raise ValueError(f"Task {args.task} not supported") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--api-base", + type=str, + default="http://localhost:8000/v1", + help="API base URL for OpenAI client.", + ) + parser.add_argument("--api-key", type=str, default="EMPTY", help="API key for OpenAI client.") + parser.add_argument("--stream", action="store_true", help="Stream the audio.") + parser.add_argument( + "--task", + type=str, + default="voice_clone", + help="Task to run.", + choices=["voice_clone", "smart_voice", "multispeaker"], + ) + parser.add_argument("--model", type=str, default=None, help="Model to use.") + args = parser.parse_args() + + client = OpenAI( + api_key=args.api_key, + base_url=args.api_base, + ) + + if args.model is None: + models = client.models.list() + model = models.data[0].id + else: + model = args.model + + main(args) diff --git a/examples/voice_prompts/belinda.txt b/examples/voice_prompts/belinda.txt new file mode 100644 index 0000000000000000000000000000000000000000..6faca438399e0e61c7e8aaddf5604241c2661210 --- /dev/null +++ b/examples/voice_prompts/belinda.txt @@ -0,0 +1 @@ +Twas the night before my birthday. Hooray! It's almost here! It may not be a holiday, but it's the best day of the year. diff --git a/examples/voice_prompts/belinda.wav b/examples/voice_prompts/belinda.wav new file mode 100644 index 0000000000000000000000000000000000000000..3bed40ff62455023813c22f3332122921a8e49c5 --- /dev/null +++ b/examples/voice_prompts/belinda.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e663310bfe539efac3350fd6b277214dcddd65d5a46949180f11c719c8b9b769 +size 896776 diff --git a/examples/voice_prompts/bigbang_amy.txt b/examples/voice_prompts/bigbang_amy.txt new file mode 100644 index 0000000000000000000000000000000000000000..f08d0723146e40e93e1f2ff460dc57249c0d09bc --- /dev/null +++ b/examples/voice_prompts/bigbang_amy.txt @@ -0,0 +1 @@ +If that was slang, I'm unfamiliar with it. [Laughter] If it was literal, I share your aversion to soiled hosiery. [Laughter] In any case, I'm here because my mother and I have agreed that I will date at least once a year.''' diff --git a/examples/voice_prompts/bigbang_amy.wav b/examples/voice_prompts/bigbang_amy.wav new file mode 100644 index 0000000000000000000000000000000000000000..fdd050dc0f51b15778043bb666dda00941f79801 --- /dev/null +++ b/examples/voice_prompts/bigbang_amy.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a3c638d14a6f626cafb443c51aa0275b96a510829a477fce1f8d0cb02d9717ec +size 1177260 diff --git a/examples/voice_prompts/bigbang_sheldon.txt b/examples/voice_prompts/bigbang_sheldon.txt new file mode 100644 index 0000000000000000000000000000000000000000..b0fdbfd40394dfd707a5f206ccaad5d1b7a67b66 --- /dev/null +++ b/examples/voice_prompts/bigbang_sheldon.txt @@ -0,0 +1 @@ +Hello, Amy Farrah Fowler. I'm sorry to inform you that you have been taken in by unsupportable mathematics designed to prey on the gullible and the lonely. Additionally, I'm being blackmailed with a hidden dirty sock. [Laughter] diff --git a/examples/voice_prompts/bigbang_sheldon.wav b/examples/voice_prompts/bigbang_sheldon.wav new file mode 100644 index 0000000000000000000000000000000000000000..7a81a774a853a718f0182e42da7ea2cd5bf91358 --- /dev/null +++ b/examples/voice_prompts/bigbang_sheldon.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e713f5f72e2f6e0295244626a095bbf347598d610dbd485ad2211de71ff0bd7 +size 1276844 diff --git a/examples/voice_prompts/broom_salesman.txt b/examples/voice_prompts/broom_salesman.txt new file mode 100644 index 0000000000000000000000000000000000000000..1eba026f01b0836917755fbea17668a9d4e6c708 --- /dev/null +++ b/examples/voice_prompts/broom_salesman.txt @@ -0,0 +1 @@ +I would imagine so. A wand with a dragon heartstring core is capable of dazzling magic. And the bond between you and your wand should only grow stronger. Do not be surprised at your new wand's ability to perceive your intentions - particularly in a moment of need. diff --git a/examples/voice_prompts/broom_salesman.wav b/examples/voice_prompts/broom_salesman.wav new file mode 100644 index 0000000000000000000000000000000000000000..845ec38a4bf2bf1a7ae722d7ef78d19cb0948336 --- /dev/null +++ b/examples/voice_prompts/broom_salesman.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c9cb4f37dcac12227045845c07c8aef823519cbf7b62bcbc6223158f9d282e1a +size 3383338 diff --git a/examples/voice_prompts/chadwick.txt b/examples/voice_prompts/chadwick.txt new file mode 100644 index 0000000000000000000000000000000000000000..27eb9100852af802e996fd3cc2eb72b35b67bcb9 --- /dev/null +++ b/examples/voice_prompts/chadwick.txt @@ -0,0 +1 @@ +Oh dear, who left all this junk lying around? Whoops, there it goes! Mind your pointed little pink head, starfish man. diff --git a/examples/voice_prompts/chadwick.wav b/examples/voice_prompts/chadwick.wav new file mode 100644 index 0000000000000000000000000000000000000000..82c36273f54a1484cfc7dec6325fafb9c903bc18 --- /dev/null +++ b/examples/voice_prompts/chadwick.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:166acd9a8d8bf3e205bf8217dfd47f8232437c0ea128c326bd1a9060c099e003 +size 458796 diff --git a/examples/voice_prompts/en_man.txt b/examples/voice_prompts/en_man.txt new file mode 100644 index 0000000000000000000000000000000000000000..fa1838bd257e3914d3048d908b64700895166642 --- /dev/null +++ b/examples/voice_prompts/en_man.txt @@ -0,0 +1 @@ +Maintaining your ability to learn translates into increased marketability, improved career options and higher salaries. diff --git a/examples/voice_prompts/en_man.wav b/examples/voice_prompts/en_man.wav new file mode 100644 index 0000000000000000000000000000000000000000..ea2eea69681d3a285ea6970a4d7ac31d252c609d --- /dev/null +++ b/examples/voice_prompts/en_man.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1ca3df71ad1b6968765e69870220d34c6b2c2550a499cf59560d9d764d10b94e +size 375566 diff --git a/examples/voice_prompts/en_woman.txt b/examples/voice_prompts/en_woman.txt new file mode 100644 index 0000000000000000000000000000000000000000..3fc8671efdc4f5642d211dade2cb3d93fcd04e7b --- /dev/null +++ b/examples/voice_prompts/en_woman.txt @@ -0,0 +1 @@ +The device would work during the day as well, if you took steps to either block direct sunlight or point it away from the sun. diff --git a/examples/voice_prompts/en_woman.wav b/examples/voice_prompts/en_woman.wav new file mode 100644 index 0000000000000000000000000000000000000000..047f016d6f1126e51dee08c4b591cf3f75b98e3e --- /dev/null +++ b/examples/voice_prompts/en_woman.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e1d49dc69f3b0731ed7b10ddf51dfc8f73465d4323f45841d93583d8b1e4d3e6 +size 313272 diff --git a/examples/voice_prompts/fiftyshades_anna.txt b/examples/voice_prompts/fiftyshades_anna.txt new file mode 100644 index 0000000000000000000000000000000000000000..4a2bc08d55b3f10c9143fa7d2f72b9552e32659d --- /dev/null +++ b/examples/voice_prompts/fiftyshades_anna.txt @@ -0,0 +1 @@ +I'm working at the hardware store till 7. I think I'd like that too. What? diff --git a/examples/voice_prompts/fiftyshades_anna.wav b/examples/voice_prompts/fiftyshades_anna.wav new file mode 100644 index 0000000000000000000000000000000000000000..a035339117d87eecfd034a76042954c272285dcb --- /dev/null +++ b/examples/voice_prompts/fiftyshades_anna.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3354d6c28724e6e8ad6692e05481aa850d121073c5f3d2e5ae7668bd0ed3243d +size 599804 diff --git a/examples/voice_prompts/mabaoguo.txt b/examples/voice_prompts/mabaoguo.txt new file mode 100644 index 0000000000000000000000000000000000000000..f24c5aa41f02d7882a756ca17947a105b67f2499 --- /dev/null +++ b/examples/voice_prompts/mabaoguo.txt @@ -0,0 +1 @@ +我是浑元形意太极门掌门人马保国,刚才有个朋友问我:马老师发生什么事啦.我说怎么回事,给我发了几张截图,我一看,哦,原来是昨天,有两个年轻人,三十多岁,一个体重九十多公斤,一个体重八十多公斤.他们说,哎,有一个说是:我在健身房练功,颈椎练坏了,马老师你能不能教教我浑元功法 diff --git a/examples/voice_prompts/mabaoguo.wav b/examples/voice_prompts/mabaoguo.wav new file mode 100644 index 0000000000000000000000000000000000000000..5593ecc5d0dd99d5c196e331102e0aa959166640 --- /dev/null +++ b/examples/voice_prompts/mabaoguo.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36dad5fe5a7017b1020997170ecc7345371f499c6a780ec419a5b76d794c1f3c +size 2443184 diff --git a/examples/voice_prompts/mabel.txt b/examples/voice_prompts/mabel.txt new file mode 100644 index 0000000000000000000000000000000000000000..0faaf40430a77ce24d1b19811dc76a7bcd1be090 --- /dev/null +++ b/examples/voice_prompts/mabel.txt @@ -0,0 +1 @@ +You do talk an awful lot about weather, did you know that? Sometimes I wonder if you're actually content to be a wizard or if you're secretly harbouring a desire to become a seer of the clouds. diff --git a/examples/voice_prompts/mabel.wav b/examples/voice_prompts/mabel.wav new file mode 100644 index 0000000000000000000000000000000000000000..2dc3a8fa306e07d5ce0c90dfee0141b9de31bb81 --- /dev/null +++ b/examples/voice_prompts/mabel.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e6c5e522c662c5d6b862d8b17e1618546666ce993dcd560f3bdd34a48bacd9f +size 1054730 diff --git a/examples/voice_prompts/profile.yaml b/examples/voice_prompts/profile.yaml new file mode 100644 index 0000000000000000000000000000000000000000..faca3f34cd5cd13e79b7eb5eecfc8dbe32e23dbf --- /dev/null +++ b/examples/voice_prompts/profile.yaml @@ -0,0 +1,5 @@ +profiles: + male_en: Male, American accent, modern speaking rate, moderate-pitch, friendly tone, and very clear audio. + female_en_story: She speaks with a calm, gentle, and informative tone at a measured pace, with excellent articulation and very clear audio. She naturally brings storytelling to life with an articulate, genuine, and personable vocal style. + male_en_british: He speaks with a clear British accent and a conversational, inquisitive tone. His delivery is articulate and at a moderate pace, and very clear audio. + female_en_british: A female voice with a clear British accent speaking at a modern rate with a moderate-pitch in an expressive and friendly tone and very clear audio. diff --git a/examples/voice_prompts/shrek_donkey.txt b/examples/voice_prompts/shrek_donkey.txt new file mode 100644 index 0000000000000000000000000000000000000000..a5426b7a9af6dfd113e021e1c8e18f8621e007a8 --- /dev/null +++ b/examples/voice_prompts/shrek_donkey.txt @@ -0,0 +1,2 @@ +And I've got a great idea, I'll stick with you. You're a mean green fighting machine, together we'll scare the spit out of anybody that crosses us. +Oh, Wow, that was really scary. And if you don't mind me saying, if that don't work, your breath certainly will get the job done, 'cause you definitely need some Tic Tacs or something, 'cause your breath stinks! \ No newline at end of file diff --git a/examples/voice_prompts/shrek_donkey.wav b/examples/voice_prompts/shrek_donkey.wav new file mode 100644 index 0000000000000000000000000000000000000000..40f3c5602abd2a8e9e16252df8e0d667eef56e07 --- /dev/null +++ b/examples/voice_prompts/shrek_donkey.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3856fce1962a3ce35c4bdeb9f92ea280d5e01163529fccd097f3bbeb3c6ad1f4 +size 1552364 diff --git a/examples/voice_prompts/shrek_donkey_es.txt b/examples/voice_prompts/shrek_donkey_es.txt new file mode 100644 index 0000000000000000000000000000000000000000..449af46f6f17d4ca76d8cfe843f62c3a5e7a6304 --- /dev/null +++ b/examples/voice_prompts/shrek_donkey_es.txt @@ -0,0 +1 @@ +¡Uy, guau! Eso sí que asusta. Y si el rugido no funciona, tu mal aliento seguro los desmaya. Necesitas unas pastillitas de menta porque el hocico te apesta. \ No newline at end of file diff --git a/examples/voice_prompts/shrek_donkey_es.wav b/examples/voice_prompts/shrek_donkey_es.wav new file mode 100644 index 0000000000000000000000000000000000000000..4d6a77b468114570e3f83e1e55251c7123a68174 --- /dev/null +++ b/examples/voice_prompts/shrek_donkey_es.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:438daf675e9a8d2a618e40679c4d7260147f576c87682eca763191d90351fed8 +size 970244 diff --git a/examples/voice_prompts/shrek_fiona.txt b/examples/voice_prompts/shrek_fiona.txt new file mode 100644 index 0000000000000000000000000000000000000000..ee71d46acf8c0d8a622b9f25fa09976755ff55e3 --- /dev/null +++ b/examples/voice_prompts/shrek_fiona.txt @@ -0,0 +1,2 @@ +Well, when one lives alone, one has to learn these things in case there's a... There's an arrow in your butt! +Calm down. If you want to help Shrek, run into the woods and find me a blue flower with red thorns. \ No newline at end of file diff --git a/examples/voice_prompts/shrek_fiona.wav b/examples/voice_prompts/shrek_fiona.wav new file mode 100644 index 0000000000000000000000000000000000000000..31f7e5677505cc55c1b795a0116ff3cc6bcae9bc --- /dev/null +++ b/examples/voice_prompts/shrek_fiona.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:26176d3244f6cd2aad44977a00251fa03163f6f6269a7eedeba4132f1320e8bb +size 1252484 diff --git a/examples/voice_prompts/shrek_shrek.txt b/examples/voice_prompts/shrek_shrek.txt new file mode 100644 index 0000000000000000000000000000000000000000..bca9f9d9f298687276346cc63d7916783993824c --- /dev/null +++ b/examples/voice_prompts/shrek_shrek.txt @@ -0,0 +1,2 @@ +Well, it's no wonder you don't have any friends. Listen, little donkey, take a look at me. What am I? +No! I'm an ogre! You know, with a torch and pitchfork. Doesn't that bother you? \ No newline at end of file diff --git a/examples/voice_prompts/shrek_shrek.wav b/examples/voice_prompts/shrek_shrek.wav new file mode 100644 index 0000000000000000000000000000000000000000..2029f0fc09c7563bd0c01e5e640e875748d73d95 --- /dev/null +++ b/examples/voice_prompts/shrek_shrek.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f59b81a33f72dd1eeb61ffdc8fa20183281fb4ea5fe4302142de1dd021c1e370 +size 1412126 diff --git a/examples/voice_prompts/vex.txt b/examples/voice_prompts/vex.txt new file mode 100644 index 0000000000000000000000000000000000000000..2d75fdbe2378a4e2adff60b2ace7f4adefed7598 --- /dev/null +++ b/examples/voice_prompts/vex.txt @@ -0,0 +1 @@ +Uhh, this is going to take forever. Why is everything so far? diff --git a/examples/voice_prompts/vex.wav b/examples/voice_prompts/vex.wav new file mode 100644 index 0000000000000000000000000000000000000000..0cff153cfb26ed66e536747513ccf2ddaac63006 --- /dev/null +++ b/examples/voice_prompts/vex.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d95c6dcf7265847edd76989ffb2d3f5a92aa3e2bbd3718317010b49842c98954 +size 523086 diff --git a/examples/voice_prompts/zh_man_sichuan.txt b/examples/voice_prompts/zh_man_sichuan.txt new file mode 100644 index 0000000000000000000000000000000000000000..df9185baef6186de89a28f4b73ec28771a5260ae --- /dev/null +++ b/examples/voice_prompts/zh_man_sichuan.txt @@ -0,0 +1 @@ +对,这就是我,万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。 \ No newline at end of file diff --git a/examples/voice_prompts/zh_man_sichuan.wav b/examples/voice_prompts/zh_man_sichuan.wav new file mode 100644 index 0000000000000000000000000000000000000000..67c3fe1cd5d24002567f3e6bacd8d7862f629249 --- /dev/null +++ b/examples/voice_prompts/zh_man_sichuan.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:53892ece071342958403bc5643f84169a30b89cc0fc79eb69508bfa11dd85e68 +size 618528 diff --git a/figures/dual_ffn_comparison_seed_tts_en_sim.png b/figures/dual_ffn_comparison_seed_tts_en_sim.png new file mode 100644 index 0000000000000000000000000000000000000000..13ab15ebfd6965067d14195c82e265857655ff58 Binary files /dev/null and b/figures/dual_ffn_comparison_seed_tts_en_sim.png differ diff --git a/figures/dual_ffn_comparison_seed_tts_en_wer.png b/figures/dual_ffn_comparison_seed_tts_en_wer.png new file mode 100644 index 0000000000000000000000000000000000000000..b066da1ad091277dbd048342a180482ff24745ad Binary files /dev/null and b/figures/dual_ffn_comparison_seed_tts_en_wer.png differ diff --git a/figures/dual_ffn_comparison_seed_tts_zh_sim.png b/figures/dual_ffn_comparison_seed_tts_zh_sim.png new file mode 100644 index 0000000000000000000000000000000000000000..bcb59f354bfa3ec40337e114a195d4161127043f Binary files /dev/null and b/figures/dual_ffn_comparison_seed_tts_zh_sim.png differ diff --git a/figures/dual_ffn_comparison_seed_tts_zh_wer.png b/figures/dual_ffn_comparison_seed_tts_zh_wer.png new file mode 100644 index 0000000000000000000000000000000000000000..618840d6161593398017e773075e70b85c8f7b14 Binary files /dev/null and b/figures/dual_ffn_comparison_seed_tts_zh_wer.png differ diff --git a/figures/emergent-tts-emotions-win-rate.png b/figures/emergent-tts-emotions-win-rate.png new file mode 100644 index 0000000000000000000000000000000000000000..46ed24950a2a125bbe827ba98217d5e0e8134a0b --- /dev/null +++ b/figures/emergent-tts-emotions-win-rate.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:63bc6a63f3e3217ff05b5e5e0adb8ce89cdbb9da086e74d0c469c6465e611221 +size 838024 diff --git a/figures/higgs_audio_tokenizer_architecture.png b/figures/higgs_audio_tokenizer_architecture.png new file mode 100644 index 0000000000000000000000000000000000000000..cdbc6f67fffa2c1b52dcee1eb387dfe28d3b1bd9 --- /dev/null +++ b/figures/higgs_audio_tokenizer_architecture.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f7c0177e38dd9c873acf8ac55c159ce65ba50970cbeba9663582da4698037447 +size 117208 diff --git a/figures/higgs_audio_v2_architecture_combined.png b/figures/higgs_audio_v2_architecture_combined.png new file mode 100644 index 0000000000000000000000000000000000000000..ca5721fa55fb2f141f77284ad5f8402ea3291b32 --- /dev/null +++ b/figures/higgs_audio_v2_architecture_combined.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6260cd2e98235c9e181316db9fd6f716fbca1e314ef367ff338b988dcb54a76c +size 438316 diff --git a/figures/higgs_audio_v2_open_source_delay_pattern.png b/figures/higgs_audio_v2_open_source_delay_pattern.png new file mode 100644 index 0000000000000000000000000000000000000000..8cdabdf1566578547b587cab1ce60cde7bb6fb2c Binary files /dev/null and b/figures/higgs_audio_v2_open_source_delay_pattern.png differ diff --git a/higgs-audio-v2-tokenizer/.gitattributes b/higgs-audio-v2-tokenizer/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..ed1d7cd8b843c6942c7d9a5c13dc9ccc51aa6a08 --- /dev/null +++ b/higgs-audio-v2-tokenizer/.gitattributes @@ -0,0 +1,37 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +model.pth filter=lfs diff=lfs merge=lfs -text +higgs_audio_tokenizer_architecture.png filter=lfs diff=lfs merge=lfs -text diff --git a/higgs-audio-v2-tokenizer/LICENSE b/higgs-audio-v2-tokenizer/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..3275de3797b6932a91d4a7b3db41494f933305e6 --- /dev/null +++ b/higgs-audio-v2-tokenizer/LICENSE @@ -0,0 +1,43 @@ +BOSON HIGGS AUDIO 2 COMMUNITY LICENSE AGREEMENT + +Boson Higgs Audio 2 Version Release Date: June 20, 2025 + +This License Agreement (the “Agreement”) is entered into by and between Licensee (as defined below) and Boson AI USA, Inc. (“Boson”) and is based upon the Meta Llama 3 Community License Agreement as of April 18, 2024 (the “Meta License Agreement”), which can be found at https://llama.meta.com/llama3/license/. The terms and conditions of the Meta License Agreement are hereby incorporated herein by reference and Unless stated otherwise below, its terms apply. The Higgs Audio 2 model developed by Boson AI USA, Inc. (“Higgs Materials”) is an audio model derived from Meta Llama 3 software and algorithms. + +“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Higgs Materials set forth herein and the Meta License Agreement. + +“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering into this Agreement on their behalf. + +“Higgs Audio 2” means the foundational large audio language models and software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing developed by Boson AI distributed at https://github.com/boson-ai/boson-multimodal or otherwise. +“Higgs Materials” means, collectively, Boson’s proprietary modification of Meta Llama 3 and Documentation (and any portion thereof) made available under this Agreement. + +“Boson” or “we” means Boson AI USA, Inc. + +By clicking “I Accept” below or by using or distributing any portion or element of the Higgs Materials, you agree to be bound by this Agreement. + +1. License Rights and Redistribution. +a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Boson’s intellectual property or other rights owned by Boson embodied in the Higgs Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Higgs Materials. +b. Redistribution and Use. +i. If you distribute or make available the Higgs Materials (or any derivative works thereof), or a product or service that uses any of them, including another AI model, you shall (A) provide a copy of this Agreement and the of Meta License ’s Llama 3 agreement with any such Higgs Materials; and (B) prominently display “Built with Higgs Materials licensed from Boson AI USA, Inc., Copyright Boson AI USA, Inc., All Rights Reserved and Meta Llama 3 licensed under the Meta Llama 3 Community License, Copyright Meta Platforms, Inc., All Right Reserved". based on Meta Llama 3” on a related website, user interface, blogpost, about page, or product documentation. If you use the Higgs Materials to create, modify, enhance, train, fine tune, or otherwise improve an AI model or similar software, which is distributed or made available, you shall also include “Higgs Audio 2” at the beginning of any such AI model or software name. +ii. Even if you receive Higgs Materials, or any modifications, enhancements or derivative works thereof, from a Licensee as part of an integrated end user product, then Section 2 of this Agreement will apply to you. +iii. You must retain in all copies of the Llama Materials that you distribute and as set forth above, include the following attribution notice within a “Notice” text file distributed as a part of such copies: +“Meta Llama 3 is licensed under the Meta Llama 3 Community License, Copyright © Meta Platforms, Inc. All Rights Reserved.” +“Boson Higgs Audio 2 is licensed under the Boson Community License, Copyright © Boson AI USA, Inc. All Rights Reserved.” +iv. Your use of the Higgs Materials must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Llama Materials (available at https://llama.meta.com/llama3/use-policy), which is hereby incorporated by reference into this Agreement. +v. You will not use the Higgs Materials or any output or results of the Higgs Materials to improve any other large language model (excluding Boson Higgs Audio 2 or derivative works thereof). +vi. You hereby acknowledge that Boson is the owner of the Higgs Materials and under no circumstance shall you bring any legal action, claim, charge, demand challenging such ownership rights of Boson. + +2. Additional Commercial Terms. If the annual active users of the products or services made available by or for Licensee, or Licensee’s affiliates, is greater than 100,000 annual active users in the preceding calendar year, you must request an expanded license from Boson AI, which Boson AI may grant to you in its sole discretion, and you are not authorized to exercise any of the rights under this Agreement unless or until Boson AI otherwise expressly grants you such rights. + +3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE Higgs Materials AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITH ALL FAULTS, WITHOUT WARRANTIES OF ANY KIND EXPRESS, IMPLIED, BASED UPON CUSTOM AND USAGE OR COURSE OF DEALING, AND BOSON AI DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE HIGGS MATERIALS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR USE OF THE HIGGS MATERIALS AND ANY OUTPUT AND RESULTS. + +4. Limitation of Liability. IN NO EVENT WILL BOSON AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF BOSON, META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. + +5. Intellectual Property. +a. No trademark licenses are granted under this Agreement, or in connection with the Higgs Materials., nNeither Boson nor Licensee may use any name or mark owned by, or associated with, the other party hereto or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Higgs Materials or as set forth in this Section 5(a). Boson hereby grants you a license to use “Higgs Audio 2” (the “Mark”) solely as required to comply with the last sentence of Section 1.b.i. All goodwill arising out of your use of the Mark will inure to the benefit of Meta and Boson AI. +b. Subject to Boson’s ownership of the Higgs Materials and derivatives made by or for Boson AI, with respect to any derivative works and modifications of the Higgs Materials that are made by you, as between you and Boson AI, you are and will be the owner of such derivative works and modifications. +c. If you institute litigation or other proceedings against Boson AI, Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Higgs Materials or Boson Higgs Audio 2 outputs or results, or any portion thereof any of the foregoing, constitutes infringement of the intellectual property or other rights owned or licensable by you, then any licenses granted to you hereunder this Agreement shall immediately terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Boson AI from and against any claim, charge, demand, cause of action by any third party arising out of or related to your use or distribution of the Higgs Materials. + +6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Higgs Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Boson AI may terminate this Agreement if you are in breach of any term or condition of this Agreement by providing you with written notice. Upon your receipt of written notice of termination of this Agreement, you shall delete the Higgs Materials from any computer, server or IT device and cease use of the Higgs Materials in all respects. Sections 1(b)(vi), 3, 4 and 7 shall survive the termination of this Agreement. + +7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The federal courts in the Northern District of California and the state courts in Santa Clara County, California shall have exclusive jurisdiction of any dispute arising out of this Agreement. diff --git a/higgs-audio-v2-tokenizer/README.md b/higgs-audio-v2-tokenizer/README.md new file mode 100644 index 0000000000000000000000000000000000000000..06fcbd4faee874463a3c94e0e2d9f7e57ebf8c14 --- /dev/null +++ b/higgs-audio-v2-tokenizer/README.md @@ -0,0 +1,185 @@ +--- +license: other +--- + +# Higgs Audio Tokenizer + +
+ + +
+ +Check our open-source repository https://github.com/boson-ai/higgs-audio for more details! + + +We introduce a new discretized audio tokenizer that runs at just **25 frames per second** while keeping—or even improving—audio quality compared to tokenizers with twice the bitrate. Our model is the first to train on **24 kHz data** covering speech, music, and sound events in one unified system. It also uses a simple non-diffusion encoder/decoder for fast, batch inference. + +

+ Architecture diagram of the Higgs Audio Tokenizer +

+ +## Basics of Audio Quantization + +An audio signal sampled at \\(f_s\\) Hz is first split into frames by an encoder with hop size \\(M\\), giving a frame rate \\(f_r = \frac{f_s}{M}\quad\text{(frames/s)}.\\) +Two common quantizers are: + +- **Residual Vector Quantization (RVQ)**: \\(N_q\\) cascaded vector‑quantizer layers, each with codebook size \\(N_{cb}\\). When \\(N_{q}=1\\), it degenerates to ordinary vector quantization. +- **Finite Scalar Quantization (FSQ)**: A single-layer scalar quantizer in which + every scalar coefficient is independently mapped to one of \\(N_{cb}\\) discrete levels. + + +If every combination of codewords is a token, the vocabulary size is \\(N_{cb}^{N_q}\\), and each token needs \\(N_q\log_2 N_{cb}\\) bits. The overall bitrate (bits/s, BPS) is simply \\(f_r \times N_q \log_2 N_{cb}\\). +We aim to push this bitrate as low as possible without hurting audio fidelity. + +## What Makes Ours Better + +- **Low Frame Rate**: Runs at just 25 fps, halving the frame rate of many baselines while preserving high audio quality. +- **Unified 24 kHz Training**: A single model jointly trained on speech, music, and sound‑event data, capturing both semantic and acoustic nuances and greatly simplifying downstream audio‑language‑model training. +- **Fast Inference**: A non‑diffusion encoder/decoder that processes batches quickly, making it practical for real-time or large-scale tasks. + + +## Evaluation Data and Metrics + +We test on four subsets: + +- **Speech, Music and Sound Event**: Include 1,000 clips per category, with each clip lasting 10 seconds. Clips are randomly sampled from [DAPS](https://ccrma.stanford.edu/~gautham/Site/daps.html) (Speech), [MUSDB](https://sigsep.github.io/datasets/musdb.html) (Music), and [AudioSet](https://research.google.com/audioset/index.html) (Sound Event). + +- **Audiophile**: Contains 150 clips, each 30 seconds long, curated from eleven high-fidelity test discs that were designed for perceptual listening tests. The clips feature both high-quality music and sound events. + +We measure: + +- **Acoustic Quality**: Acoustic reconstruction error between the original and reconstructed audio. +- **Semantic Integrity**: Degree of semantic preservation, evaluated on the English and Chinese subsets of [SeedTTS](https://arxiv.org/abs/2406.02430)[15]. +- **Aesthetics**: SOTA unified model-based quality metrics computed with [Meta Audiobox Aesthetics](https://github.com/facebookresearch/audiobox-aesthetics)[8]. + + +We compare our tokenizer with a wide range of baselines, from tokenizers mainly built for better acoustic reconstruction and compression rate, to those focused on semantic integrity, and to tokenizers used in existing large audio language models. We also compare with tokenizers that are pretrained specifically on speech or on music. + + +The tables below summarize the tokenizers evaluated. As shown, our tokenizer achieves a well-rounded balance of efficiency, semantic fidelity, and acoustic quality. + +### Acoustic Evaluation + +This table reports the Short‑Time Fourier Transform (STFT) distance between the original and reconstructed audio. Baselines are listed chronologically and grouped by whether semantic distillation (SD) is applied. Despite DAC’s top acoustic quality at 12× the bitrate, our tokenizer leads all other baselines. + + +| Tokenizer | 💬 | 🎵 | 🥁 | SD | \\(f_s\\) | \\(f_r\\) | BPS* (k) ↓ | Speech ↓ | Sound Event ↓ | Music ↓ | Audiophile ↓ | +|-----------|----|----|----|----|-------|-------|--------------------------|----------|----------------|--------|--------------| +| [Encodec](https://huggingface.co/facebook/encodec_24khz)[3] | ✓ | ✓ | ✓ | | 24 | 75 | 24 | 1.96 | 2.65 | 2.52 | 2.30 | +| [DAC](https://huggingface.co/hance-ai/descript-audio-codec-24khz)[2] | ✓ | ✓ | ✓ | | 24 | 75 | 24 | **1.13** | **1.45** | **1.34** | **1.62** | +| [SNAC-24k](https://huggingface.co/hubertsiuzdak/snac_24khz)[6] | ✓ | | | | 24 | (12, 23, 47) | 0.98 | 1.92 | 2.69 | 2.54 | 2.52 | +| [SNAC-44k](https://huggingface.co/hubertsiuzdak/snac_44khz)[6] | | ✓ | ✓ | | 44.1 | (14, 29, 57, 115) | 2.6 | 1.83 | 2.25 | 2.05 | 2.00 | +| [WavTokenizer](https://huggingface.co/novateur/WavTokenizer-medium-music-audio-75token/blob/main/wavtokenizer_medium_music_audio_320_24k_v2.ckpt)[7] | | ✓ | ✓ | | 24 | 75 | 0.9 | 1.93 | 2.44 | 2.17 | 2.15 | +| [WavTokenizer (Speech)](https://huggingface.co/novateur/WavTokenizer-large-speech-75token/tree/main)[7] | ✓ | | | | 24 | 75 | 0.9 | 1.78 | 2.47 | 2.42 | 2.47 | +| [MuCodec](https://huggingface.co/haoheliu/audioldm_48k/tree/main)[11] | | ✓ | | | 48 | 25 | 0.35 | 2.87 | 3.69 | 3.36 | 2.97 | +| [FlowDec-75m](https://github.com/facebookresearch/FlowDec?tab=readme-ov-file)[12] | ✓ | ✓ | ✓ | | 48 | 75 | 7.5 | 1.73 | 2.14 | 2.01 | 2.03 | +| [FlowDec-25s](https://github.com/facebookresearch/FlowDec?tab=readme-ov-file)[12] | ✓ | ✓ | ✓ | | 48 | 25 | 4 | 1.94 | 2.42 | 2.25 | 2.33 | +| [SpeechTokenizer](https://huggingface.co/fnlp/SpeechTokenizer/tree/main/speechtokenizer_hubert_avg)[14] | ✓ | | | ✓ | 16 | 50 | 4 | 3.21 | 3.58 | 3.65 | 3.69 | +| [SemantiCodec](https://huggingface.co/haoheliu/SemantiCodec/tree/main/semanticodec_tokenrate_100)[5] | ✓ | ✓ | ✓ | ✓ | 16 | 100 | 1.35 | 3.05 | 3.28 | 3.24 | 3.18 | +| [Mimi](https://huggingface.co/docs/transformers/en/model_doc/mimi)[13] | ✓ | | | ✓ | 24 | 12.5 | 4.4 | 1.77 | 2.40 | 2.30 | 2.15 | +| [XCodec](https://huggingface.co/ZhenYe234/xcodec/blob/main/config_hubert_general.yaml)[1] | ✓ | ✓ | ✓ | ✓ | 16 | 50 | 4 | 2.95 | 3.16 | 3.00 | 3.03 | +| [CosyVoice 2](https://huggingface.co/FunAudioLLM/CosyVoice2-0.5B)[13] | ✓ | | | ✓ | 16 | 25 | -**| 2.30 | 3.30 | 3.14 | 3.25 | +| [XCodec2](https://huggingface.co/HKUST-Audio/xcodec2/blob/main/ckpt/epoch%3D4-step%3D1400000.ckpt)[9] | ✓ | | | ✓ | 16 | 50 | 0.8 | 3.06 | 3.72 | 3.62 | 3.64 | +| [XY](https://huggingface.co/fnlp/XY_Tokenizer_TTSD_V0/tree/main)[10] | ✓ | | | ✓ | 24 | 12.5 | 1 | 1.89 | 2.51 | 2.40 | 2.26 | +| Ours | ✓ | ✓ | ✓ | ✓ | 24 | 25 | 2 | **1.62** | **2.03** | **1.85** | **1.80** | + + + +* Bits-per-second is calculated according to the checkpoint the author provided. + +** CosyVoice 2 uses the continuous feature as the conditioning; we include it for completeness. + + +### Semantic Evaluation +[SeedTTS](https://github.com/BytedanceSpeech/seed-tts-eval) is a dataset that includes prompt/target audio and texts. We reconstruct the target audio, and use the word error rate (WER) and speaker similarity (SIM) metrics to evaluate the semantic integrity. SIM is calculated by the similarity between the prompt audio and reconstructed target audio with [WavLM-large](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view) as the embedding model. + +The following table compares our tokenizer with semantic-distillation-trained baselines and shows that it delivers performance comparable to tokenizers operating at 2.2× our model’s bitrate. + +| Model | BPS (k) | en WER ↓ | en SIM ↑ | zh WER ↓ | zh SIM ↑ | +|------------------|---------|------------|------------|------------|------------| +| [SpeechTokenizer](https://huggingface.co/fnlp/SpeechTokenizer/tree/main/speechtokenizer_hubert_avg) | 4 | 2.82 | 0.63 | 2.04 | 0.65 | +| [SemantiCodec](https://huggingface.co/haoheliu/SemantiCodec/tree/main/semanticodec_tokenrate_100) | 1.35 | 3.46 | 0.56 | 2.18 | 0.60 | +| [Mimi](https://huggingface.co/docs/transformers/en/model_doc/mimi) | 4.4 | **2.35** | **0.70** | **1.48** | **0.72** | +| [XCodec](https://huggingface.co/ZhenYe234/xcodec/blob/main/config_hubert_general.yaml) | 4.0 | 2.68 | 0.63 | 1.66 | 0.66 | +| [CosyVoice 2](https://huggingface.co/FunAudioLLM/CosyVoice2-0.5B) | - | 3.17 | 0.65 | 2.11 | 0.70 | +| [XCodec2](https://huggingface.co/HKUST-Audio/xcodec2/blob/main/ckpt/epoch%3D4-step%3D1400000.ckpt) | 0.8 | 2.74 | 0.62 | 1.91 | 0.67 | +| [XY-MOSS-TTSD](https://huggingface.co/fnlp/XY_Tokenizer_TTSD_V0/tree/main) | 1.0 | 2.72 | 0.61 | 1.58 | 0.67 | +| Ours | 2.0 | 2.52 | 0.67 | **1.48** | 0.71 | + + + +### Audiobox Aesthetics Evaluation + +This model-based evaluation[8] further demonstrates the superiority of our tokenizer. CU denotes the Content Usefulness and CE denotes the Content Enjoyment; both are rated on a 1-10 scale. Notably, our tokenizer performs best on the Audiophile set, demonstrating a clear advantage when the original audio quality is high. + + +| Model | BPS (k) | Music CE ↑ | Music CU ↑ | Sound Event CE ↑ | Sound Event CU ↑ | Speech CE ↑ | Speech CU ↑ | Audiophile CE ↑ | Audiophile CU ↑ | +|------------------|---------|--------------|--------------|--------------------|--------------------|---------------|---------------|--------------------|--------------------| +| Origin | - | 6.20 | 7.10 | 4.47 | 5.64 | 5.03 | 4.87 | 7.17 | 7.65 | +| [SpeechTokenizer](https://huggingface.co/fnlp/SpeechTokenizer/tree/main/speechtokenizer_hubert_avg) | 4.0 | 3.55 | 5.22 | 3.03 | 4.50 | 4.68 | 4.58 | 3.59 | 5.07 | +| [SemantiCodec](https://huggingface.co/haoheliu/SemantiCodec/tree/main/semanticodec_tokenrate_100) | 1.35 | 6.01 | 6.83 | 4.22 | 5.30 | 4.28 | 4.12 | 6.97 | 7.43 | +| [Mimi](https://huggingface.co/docs/transformers/en/model_doc/mimi) | 4.4 | 6.01 | 6.83 | 4.26 | 5.35 | 4.87 | 4.72 | 6.80 | 7.29 | +| [XCodec](https://huggingface.co/ZhenYe234/xcodec/blob/main/config_hubert_general.yaml) | 4.0 | **6.30** | **7.10** | **4.43** | 5.45 | **4.96** | **4.79** | 7.06 | 7.49 | +| [CosyVoice 2](https://huggingface.co/FunAudioLLM/CosyVoice2-0.5B) | - | 5.21 | 6.14 | 4.08 | 4.73 | **4.91** | **4.75** | 5.97 | 6.56 | +| [XCodec2](https://huggingface.co/HKUST-Audio/xcodec2/blob/main/ckpt/epoch%3D4-step%3D1400000.ckpt) | 0.8 | 4.38 | 5.66 | 3.43 | 4.63 | **4.93** | **4.78** | 4.56 | 5.46 | +| [XY-MOSS-TTSD](https://huggingface.co/fnlp/XY_Tokenizer_TTSD_V0/tree/main) | 1.0 | 5.77 | 6.80 | 4.23 | 5.34 | 4.88 | 4.72 | 6.95 | 7.48 | +| Ours | 2.0 | **6.35** | **7.15** | **4.47** | **5.51** | 4.90 | 4.70 | **7.21** | **7.66** | + + + +Note that since some tokenizers are trained on 16 kHz data, we upsample their audio outputs to 24 kHz before computing metrics. Different upsampling methods may cause slight variations (e.g., 4.36 vs. 4.43 for XCodec Sound Event CE). We report the best results we could obtain and highlight any results within 0.05 of the best one. + + + + + + + + + + + + +## Reference +[1] [Ye, Zhen, et al. "Codec does matter: Exploring the semantic shortcoming of codec for audio language model." Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 39. No. 24. 2025.](https://arxiv.org/abs/2408.17175) + +[2] [Kumar, Rithesh, et al. "High-fidelity audio compression with improved rvqgan." Advances in Neural Information Processing Systems 36 (2023): 27980-27993.](https://dl.acm.org/doi/10.5555/3666122.3667336) + +[3] [Défossez, Alexandre, et al. "High fidelity neural audio compression." arXiv preprint arXiv:2210.13438 (2022).](https://arxiv.org/abs/2210.13438) + +[4] [Défossez, Alexandre, et al. "Moshi: a speech-text foundation model for real-time dialogue." arXiv preprint arXiv:2410.00037 (2024).](https://arxiv.org/abs/2410.00037) + +[5] [Liu, Haohe, et al. "Semanticodec: An ultra low bitrate semantic audio codec for general sound." IEEE Journal of Selected Topics in Signal Processing (2024).](https://ieeexplore.ieee.org/document/10768970) + +[6] [Siuzdak, Hubert, Florian Grötschla, and Luca A. Lanzendörfer. "Snac: Multi-scale neural audio codec." arXiv preprint arXiv:2410.14411 (2024).](https://arxiv.org/abs/2410.14411) + +[7] [Ji, Shengpeng, et al. "Wavtokenizer: an efficient acoustic discrete codec tokenizer for audio language modeling." arXiv preprint arXiv:2408.16532 (2024).](https://arxiv.org/abs/2408.16532) + +[8] [Tjandra, Andros, et al. "Meta audiobox aesthetics: Unified automatic quality assessment for speech, music, and sound." arXiv preprint arXiv:2502.05139 (2025).](https://arxiv.org/abs/2502.05139) + +[9] [Ye, Zhen, et al. "Llasa: Scaling Train-Time and Inference-Time Compute for Llama-based Speech Synthesis." arXiv preprint arXiv:2502.04128 (2025).](https://arxiv.org/abs/2502.04128) + +[10] [Gong, Yitian, et al. "XY-Tokenizer: Mitigating the Semantic-Acoustic Conflict in Low-Bitrate Speech Codecs." arXiv preprint arXiv:2506.23325 (2025).](https://arxiv.org/abs/2506.23325) + +[11] [Xu, Yaoxun, et al. "MuCodec: Ultra Low-Bitrate Music Codec." arXiv preprint arXiv:2409.13216 (2024).](https://arxiv.org/abs/2409.13216) + +[12] [Welker, Simon, et al. "FlowDec: A flow-based full-band general audio codec with high perceptual quality." arXiv preprint arXiv:2503.01485 (2025).](https://arxiv.org/abs/2503.01485) + +[13] [Du, Zhihao, et al. "Cosyvoice 2: Scalable streaming speech synthesis with large language models." arXiv preprint arXiv:2412.10117 (2024).](https://arxiv.org/abs/2412.10117) + +[14] [Zhang, Xin, et al. "Speechtokenizer: Unified speech tokenizer for speech large language models." arXiv preprint arXiv:2308.16692 (2023).](https://arxiv.org/abs/2308.16692) + +[15] [Anastassiou, Philip, et al. "Seed-tts: A family of high-quality versatile speech generation models." arXiv preprint arXiv:2406.02430 (2024).](https://arxiv.org/abs/2406.02430) \ No newline at end of file diff --git a/higgs-audio-v2-tokenizer/config.json b/higgs-audio-v2-tokenizer/config.json new file mode 100644 index 0000000000000000000000000000000000000000..916afcabbc44e2bd733f21ead967baf745bb7d3c --- /dev/null +++ b/higgs-audio-v2-tokenizer/config.json @@ -0,0 +1,12 @@ +{ + "n_filters": 32, + "D": 256, + "codebook_dim": 64, + "target_bandwidths": [0.5, 1, 1.5, 2, 4], + "ratios": [8, 5, 4, 2, 3], + "sample_rate": 24000, + "bins": 1024, + "n_q": 8, + "semantic_techer": "hubert_base_general" +} + diff --git a/higgs-audio-v2-tokenizer/model.pth b/higgs-audio-v2-tokenizer/model.pth new file mode 100644 index 0000000000000000000000000000000000000000..597eddd66e2b56ff78c9dc37e0cebcd61ac680cd --- /dev/null +++ b/higgs-audio-v2-tokenizer/model.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e64349910124c5e615bae9e373f45d0ed346056910844ae1a5ac44c92dc24c92 +size 805947058 diff --git a/predict.py b/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..79d9fd40b30b2eb4e56e860db5b53b379e6f1e3e --- /dev/null +++ b/predict.py @@ -0,0 +1,132 @@ +# Prediction interface for Cog ⚙️ +# https://cog.run/python +import os +import subprocess +import time + +import torch +import torchaudio +from cog import BasePredictor, Input, Path + +from boson_multimodal.data_types import ChatMLSample, Message, AudioContent +from boson_multimodal.serve.serve_engine import HiggsAudioResponse, HiggsAudioServeEngine + + +MODEL_PATH = "higgs-audio-v2-generation-3B-base" +AUDIO_TOKENIZER_PATH = "higgs-audio-v2-tokenizer" +MODEL_URL = "https://weights.replicate.delivery/default/bosonai/higgs-audio-v2-generation-3B-base/model.tar" +TOKENIZER_URL = "https://weights.replicate.delivery/default/bosonai/higgs-audio-v2-tokenizer/model.tar" + + +def download_weights(url, dest): + start = time.time() + print("downloading url: ", url) + print("downloading to: ", dest) + subprocess.check_call(["pget", "-xf", url, dest], close_fds=False) + print("downloading took: ", time.time() - start) + + +class Predictor(BasePredictor): + def setup(self) -> None: + """Load the model into memory to make running multiple predictions efficient""" + # Download weights + if not os.path.exists(MODEL_PATH): + download_weights(MODEL_URL, MODEL_PATH) + if not os.path.exists(AUDIO_TOKENIZER_PATH): + download_weights(TOKENIZER_URL, AUDIO_TOKENIZER_PATH) + + # Set device + self.device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"Using device: {self.device}") + + # Initialize the serve engine + self.serve_engine = HiggsAudioServeEngine( + MODEL_PATH, + AUDIO_TOKENIZER_PATH, + device=self.device) + print("Higgs Audio V2 model loaded successfully") + + def predict( + self, + text: str = Input( + description="Text to convert to speech", + default="The sun rises in the east and sets in the west", + ), + temperature: float = Input( + description="Controls randomness in generation. Lower values are more deterministic.", + ge=0.1, + le=1.0, + default=0.3, + ), + top_p: float = Input( + description="Nucleus sampling parameter. Controls diversity of generated audio.", + ge=0.1, + le=1.0, + default=0.95, + ), + top_k: int = Input( + description="Top-k sampling parameter. Limits vocabulary to top k tokens.", ge=1, le=100, default=50 + ), + max_new_tokens: int = Input( + description="Maximum number of audio tokens to generate", ge=256, le=2048, default=1024 + ), + scene_description: str = Input( + description="Scene description for audio context", default="Audio is recorded from a quiet room." + ), + system_message: str = Input(description="Custom system message (optional)", default=""), + ref_audio: Path = Input( + description="Reference audio file for voice cloning (optional). Supports WAV, MP3, etc.", + default=None, + ), + ) -> Path: + """Run a single prediction on the model""" + try: + # Construct system prompt + if system_message: + system_prompt = system_message + else: + system_prompt = f"Generate audio following instruction.\n\n<|scene_desc_start|>\n{scene_description}\n<|scene_desc_end|>" + + # Prepare messages + messages = [ + Message( + role="system", + content=system_prompt, + ), + ] + + # Add reference audio message if provided (voice cloning) + if ref_audio is not None: + messages.append( + Message( + role="assistant", + content=AudioContent(audio_url=str(ref_audio)), + ) + ) + + # Add user text message + messages.append( + Message( + role="user", + content=text, + ) + ) + + # Generate audio + output: HiggsAudioResponse = self.serve_engine.generate( + chat_ml_sample=ChatMLSample(messages=messages), + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stop_strings=["<|end_of_text|>", "<|eot_id|>"], + ) + # Save output audio to a temporary file with a clear filename + output_path = "/tmp/audio_output.wav" + # Convert output audio to tensor and save + audio_tensor = torch.from_numpy(output.audio)[None, :] + torchaudio.save(output_path, audio_tensor, output.sampling_rate, format="wav") + return Path(output_path) + + except Exception as e: + raise RuntimeError(f"Audio generation failed: {str(e)}") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..28586cfc54c0e6d28dd9434c4ad568c5b9d597e1 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,101 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.ruff] +line-length = 119 +target-version = "py310" +indent-width = 4 +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", + "external", + "third_party", +] + +[tool.ruff.lint] +preview = true +ignore-init-module-imports = true +extend-select = [ + "B009", # static getattr + "B010", # static setattr + "CPY", # Copyright + "E", # PEP8 errors + "F", # PEP8 formatting + "I", # Import sorting + "TID251", # Banned API + "UP", # Pyupgrade + "W", # PEP8 warnings +] +ignore = [ + "E501", # Line length (handled by ruff-format) + "E741", # Ambiguous variable name + "W605", # Invalid escape sequence + "UP007", # X | Y type annotations +] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = [ + "F401", # Ignore seemingly unused imports (they're meant for re-export) +] + +[tool.ruff.lint.isort] +lines-after-imports = 2 +known-first-party = ["character_tuning"] + +[tool.ruff.format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" + +# Enable auto-formatting of code examples in docstrings. Markdown, +# reStructuredText code/literal blocks and doctests are all supported. +# +# This is currently disabled by default, but it is planned for this +# to be opt-out in the future. +docstring-code-format = false + +# Set the line length limit used when formatting code snippets in +# docstrings. +# +# This only has an effect when the `docstring-code-format` setting is +# enabled. +docstring-code-line-length = "dynamic" + +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"os.getenv".msg = "Use os.environ instead" +"os.putenv".msg = "Use os.environ instead" +"os.unsetenv".msg = "Use os.environ instead" + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3f8a9057daa58aaf4ca49958f957e82f02c7cc0e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,22 @@ +descript-audio-codec +torch +transformers>=4.45.1,<4.47.0 +librosa +dacite +boto3==1.35.36 +s3fs +torchvision +torchaudio +torchcodec +json_repair +pandas +pydantic +vector_quantize_pytorch +loguru +pydub +ruff==0.12.2 +omegaconf +click +langid +jieba +accelerate>=0.26.0 diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..9a32f0261164c78e784621e229b6b40c2ea3b135 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,16 @@ +[metadata] +name = boson_multimodal +author = Boson AI +version = 0.1.0 +url = https://github.com/boson-ai/higgs-audio +description = Higgs Audio +long_description = file: README.md +long_description_content_type = text/markdown + +[options] +packages = find: + +[options.packages.find] +exclude = + tests* + training* diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..b024da80e9c1c8c800cc1b46e90c1e783cb446cc --- /dev/null +++ b/setup.py @@ -0,0 +1,4 @@ +from setuptools import setup + + +setup() diff --git a/tech_blogs/ARCHITECTURE_BLOG.md b/tech_blogs/ARCHITECTURE_BLOG.md new file mode 100644 index 0000000000000000000000000000000000000000..8fddd98213749639beca03c3fcbba21b17df0862 --- /dev/null +++ b/tech_blogs/ARCHITECTURE_BLOG.md @@ -0,0 +1,24 @@ +# HiggsAudio-V2 Model Architecture + + + +Our model is built on top of [Llama-3.2-3B](https://huggingface.co/meta-llama/Llama-3.2-3B). To enhance the model’s ability to process audio tokens, we incorporate the "DualFFN" architecture as an audio adapter. DualFFN acts as an audio-specific expert, boosting the LLM's performance with minimal computational overhead. Our implementation preserves 91% of the original LLM’s training speed with the inclusion of DualFFN. + +Since our audio tokenizer is based on Residual Vector-Quantization (RVQ) and contains multiple codebooks, we adopt the [delay pattern](https://proceedings.neurips.cc/paper_files/paper/2023/file/94b472a1842cd7c56dcb125fb2765fbd-Paper-Conference.pdf) to enable simultaneous code generation across codebooks while supporting streaming. + + + + +## DualFFN Performance Ablation Study + +To assess the effectiveness of DualFFN, we trained two smaller models based on LLaMA-3.1-1B: one incorporating DualFFN and one without. Both models were trained for 250K steps with a learning rate of 5e-4 on a subset of the AudioVerse dataset. We evaluated their performance on SeedTTS-Eval, with the results presented in the figures below. The model equipped with DualFFN consistently outperforms its counterpart in terms of word error rate (WER) and speaker similarity. + +- SeedTTS-EN + + + +- SeedTTS-ZH + + + +We may notice that the model with DualFFN consistently outperforms the model without DualFFN in terms of word-error-rate (WER) and speaker similarity. diff --git a/tech_blogs/TOKENIZER_BLOG.md b/tech_blogs/TOKENIZER_BLOG.md new file mode 100644 index 0000000000000000000000000000000000000000..e6285f111db3096779f13c689edfd8197f97be5d --- /dev/null +++ b/tech_blogs/TOKENIZER_BLOG.md @@ -0,0 +1,170 @@ +# Higgs Audio Tokenizer + +In this work, we introduce a new discretized audio tokenizer that runs at just **25 frames per second** while keeping—or even improving—audio quality compared to tokenizers with twice the bitrate. Our model is the first to train on **24 kHz data** covering speech, music, and sound events in one unified system. It also uses a simple non-diffusion encoder/decoder for fast, batch inference. + +![XCodec Architecture](../figures/higgs_audio_tokenizer_architecture.png) + +## Basics of Audio Quantization + +An audio signal sampled at $f_s$ Hz is first split into frames by an encoder with hop size $M$, giving a frame rate $f_r = \frac{f_s}{M}\quad\text{(frames/s)}.$ +Two common quantizers are: + +- **Residual Vector Quantization (RVQ)**: $N_q$ cascaded layers with codebook size $N_{cb}$ each. When $N_{cb}=1$, it reduces to single-vector quantization. +- **Finite Scalar Quantization (FSQ)**: A single layer ($N_q=1$) with codebook size $N_{cb}$. + +If every combination of codewords is a token, the vocabulary size is $N_{cb}^{N_q}$, and each token needs $N_q\log_2 N_{cb}$ bits. The overall bitrate (bits/s, BPS) is simply $f_r \times N_q \log_2 N_{cb}.$ +We aim to push this bitrate as low as possible without hurting audio fidelity. + +## What Makes Ours Better + +- **Low Frame Rate**: At 25 fps, our tokenizer halves the frame rate of many baselines when still maintaining high audio quality. +- **Unified 24 kHz Training**: We mix speech, music, and sound-event clips in one model, capturing both semantic and acoustic details, hugely facilitating the training of audio language models. +- **Fast Inference**: By avoiding diffusion steps, our encoder/decoder processes batches quickly, making it practical for real-time or large-scale tasks. + + +## Data and Evaluation Metrics + +We test on four subsets: + +- **Speech, Music, Sound Event**: Includes 1,000 clips for each category, with each clip lasting 10 seconds. Clips are randomly sampled from [DAPS](https://ccrma.stanford.edu/~gautham/Site/daps.html) (Speech), [MUSDB](https://sigsep.github.io/datasets/musdb.html) (Music), and [AudioSet](https://research.google.com/audioset/index.html) (Sound Event). + +- **Audiophile**: Contains 150 clips, each 30 seconds long, curated from eleven high-fidelity test discs. The clips feature both music and sound events, selected for audio quality evaluation. + +We measure: + +- **Acoustic Quality**: STFT distance between the original and reconstructed audio. +- **Semantic Integrity**: Semantic preservation of the original audio using [SeedTTS](https://arxiv.org/abs/2406.02430)[15] dataset on English and Chinese. +- **Aesthetics**: SOTA unified model-based quality assessment, [Meta Audiobox Aesthetics](https://github.com/facebookresearch/audiobox-aesthetics)[8], for Content Enjoyment (CE), Content Usefulness (CU) . + + +We compare our tokenizer with a wide range of baselines, from tokenizers mainly built for better acoustic reconstruction and compression rate, to those focused on semantic integrity, and to tokenizers used in existing large audio language models. We also compare with tokenizers that are pretrained specifically on speech or on music. + + +The tables below summarize the tokenizers evaluated. As shown, our tokenizer achieves a well-rounded balance of efficiency, semantic fidelity, and acoustic quality. + +### Accoustic Evaluation + +We use the STFT metric here for simplicity. The baselines are ordered chronologically, grouped by whether semantic distillation (SD) is applied.Despite DAC’s top acoustic quality at 12× the bitrate, our tokenizer leads all other baselines. + + +| Tokenizer | 💬 | 🎵 | 🥁 | SD | $f_s$ | $f_r$ | BPS* (k) ↓ | Speech ↓ | Sound Event ↓ | Music ↓ | Audiophile ↓ | +|-----------|----|----|----|----|-------|-------|--------------------------|----------|----------------|--------|--------------| +| [Encodec](https://huggingface.co/facebook/encodec_24khz)[3] | ✓ | ✓ | ✓ | | 24 | 75 | 24 | 1.96 | 2.65 | 2.52 | 2.30 | +| [DAC](https://huggingface.co/hance-ai/descript-audio-codec-24khz)[2] | ✓ | ✓ | ✓ | | 24 | 75 | 24 | **1.13** | **1.45** | **1.34** | **1.62** | +| [SNAC-24k](https://huggingface.co/hubertsiuzdak/snac_24khz)[6] | ✓ | | | | 24 | (12, 23, 47) | 0.98 | 1.92 | 2.69 | 2.54 | 2.52 | +| [SNAC-44.1k](https://huggingface.co/hubertsiuzdak/snac_44khz)[6] | | ✓ | ✓ | | 44.1 | (14, 29, 57, 115) | 2.6 | 1.83 | 2.25 | 2.05 | 2.00 | +| [WavTokenizer](https://huggingface.co/novateur/WavTokenizer-medium-music-audio-75token/blob/main/wavtokenizer_medium_music_audio_320_24k_v2.ckpt)[7] | | ✓ | ✓ | | 24 | 75 | 0.9 | 1.93 | 2.44 | 2.17 | 2.15 | +| [WavTokenizer (Speech)](https://huggingface.co/novateur/WavTokenizer-large-speech-75token/tree/main)[7] | ✓ | | | | 24 | 75 | 0.9 | 1.78 | 2.47 | 2.42 | 2.47 | +| [MuCodec](https://huggingface.co/haoheliu/audioldm_48k/tree/main)[11] | | ✓ | | | 48 | 25 | 0.35 | 2.87 | 3.69 | 3.36 | 2.97 | +| [FlowDec-75m](https://github.com/facebookresearch/FlowDec?tab=readme-ov-file)[12] | ✓ | ✓ | ✓ | | 48 | 75 | 7.5 | 1.73 | 2.14 | 2.01 | 2.03 | +| [FlowDec-25s](https://github.com/facebookresearch/FlowDec?tab=readme-ov-file)[12] | ✓ | ✓ | ✓ | | 48 | 25 | 4 | 1.94 | 2.42 | 2.25 | 2.33 | +| [SpeechTokenizer](https://huggingface.co/fnlp/SpeechTokenizer/tree/main/speechtokenizer_hubert_avg)[14] | ✓ | | | ✓ | 16 | 50 | 4 | 3.21 | 3.58 | 3.65 | 3.69 | +| [SemantiCodec](https://huggingface.co/haoheliu/SemantiCodec/tree/main/semanticodec_tokenrate_100)[5] | ✓ | ✓ | ✓ | ✓ | 16 | 50 | 1.4 | 3.05 | 3.28 | 3.24 | 3.18 | +| [Mimi](https://huggingface.co/docs/transformers/en/model_doc/mimi)[13] | ✓ | | | ✓ | 24 | 12.5 | 4.4 | 1.77 | 2.40 | 2.30 | 2.15 | +| [XCodec](https://huggingface.co/ZhenYe234/xcodec/blob/main/config_hubert_general.yaml)[1] | ✓ | ✓ | ✓ | ✓ | 16 | 50 | 4 | 2.95 | 3.16 | 3.00 | 3.03 | +| [CosyVoice 2](https://huggingface.co/FunAudioLLM/CosyVoice2-0.5B)[13] | ✓ | | | ✓ | 16 | 25 | -** | 2.30 | 3.30 | 3.14 | 3.25 | +| [XCodec2](https://huggingface.co/HKUST-Audio/xcodec2/blob/main/ckpt/epoch%3D4-step%3D1400000.ckpt)[9] | ✓ | | | ✓ | 16 | 50 | 0.8 | 3.06 | 3.72 | 3.62 | 3.64 | +| [XY](https://huggingface.co/fnlp/XY_Tokenizer_TTSD_V0/tree/main)[10] | ✓ | | | ✓ | 24 | 12.5 | 1 | 1.89 | 2.51 | 2.40 | 2.26 | +| Ours | ✓ | ✓ | ✓ | ✓ | 24 | 25 | 2 | **1.62** | **2.03** | **1.85** | **1.80** | + + + +\* Bits-per-second is calculated according to the checkpoint the author provided. + +\*\* CosyVoice 2 uses the continuous feature as the conditioning, we include it for completeness. + + +### Semantic Evaluation +Here we only compare with tokenizers that are trained with semantic distillation. +[SeedTTS](https://github.com/BytedanceSpeech/seed-tts-eval) is a dataset includes prompt/target audio and texts. We reconstructed the target audio, and use the word error rate (WER) and speaker similarity (SIM) metrics to evaluate the semantic integrity. SIM is calculated by the similarity between the prompt audio and reconstructed targeted audio with [WavLM-large](https://drive.google.com/file/d/1-aE1NfzpRCLxA4GUxX9ITI3F9LlbtEGP/view) as the embedding model. + +The following table shows that our tokenizer achieves comparable performance to tokenizers that 2.2x the bitrate of our model. + +| Model | BPS (k) | en WER ↓ | en SIM ↑ | zh WER ↓ | zh SIM ↑ | +|------------------|---------|------------|------------|------------|------------| +| [SpeechTokenizer](https://huggingface.co/fnlp/SpeechTokenizer/tree/main/speechtokenizer_hubert_avg) | 4 | 2.82 | 0.63 | 2.04 | 0.65 | +| [SemantiCodec](https://huggingface.co/haoheliu/SemantiCodec/tree/main/semanticodec_tokenrate_100) | 1.4 | 3.46 | 0.56 | 2.18 | 0.60 | +| [Mimi](https://huggingface.co/docs/transformers/en/model_doc/mimi) | 4.4 | **2.35** | **0.70** | **1.48** | **0.72** | +| [XCodec](https://huggingface.co/ZhenYe234/xcodec/blob/main/config_hubert_general.yaml) | 4.0 | 2.68 | 0.63 | 1.66 | 0.66 | +| [CosyVoice 2](https://huggingface.co/FunAudioLLM/CosyVoice2-0.5B) | - | 3.17 | 0.65 | 2.11 | 0.70 | +| [XCodec2](https://huggingface.co/HKUST-Audio/xcodec2/blob/main/ckpt/epoch%3D4-step%3D1400000.ckpt) | 0.8 | 2.74 | 0.62 | 1.91 | 0.67 | +| [XY-MOSS-TTSD](https://huggingface.co/fnlp/XY_Tokenizer_TTSD_V0/tree/main) | 1.0 | 2.72 | 0.61 | 1.58 | 0.67 | +| Ours | 2.0 | 2.52 | 0.67 | **1.48** | 0.71 | + + + +### Audiobox Aesthetics Evaluation + +This model based evaluation[8] further demonstrates the superiority of our tokenizer. CU is the Content Usefulness and CE is the Content Enjoyment. Each term is rated on a scale of 1-10. Notably, our tokenizer performs best on the Audiophile set—demonstrating a clear advantage when the original audio quality is high. + + +| Model | BPS (k) | Music CE ↑ | Music CU ↑ | Sound Event CE ↑ | Sound Event CU ↑ | Speech CE ↑ | Speech CU ↑ | Audiophile CE ↑ | Audiophile CU ↑ | +|------------------|---------|--------------|--------------|--------------------|--------------------|---------------|---------------|--------------------|--------------------| +| Origin | - | 6.20 | 7.10 | 4.47 | 5.64 | 5.03 | 4.87 | 7.17 | 7.65 | +| [SpeechTokenizer](https://huggingface.co/fnlp/SpeechTokenizer/tree/main/speechtokenizer_hubert_avg) | 4.0 | 3.55 | 5.22 | 3.03 | 4.50 | 4.68 | 4.58 | 3.59 | 5.07 | +| [SemantiCodec](https://huggingface.co/haoheliu/SemantiCodec/tree/main/semanticodec_tokenrate_100) | 1.4 | 6.01 | 6.83 | 4.22 | 5.30 | 4.28 | 4.12 | 6.97 | 7.43 | +| [Mimi](https://huggingface.co/docs/transformers/en/model_doc/mimi) | 4.4 | 6.01 | 6.83 | 4.26 | 5.35 | 4.87 | 4.72 | 6.80 | 7.29 | +| [XCodec](https://huggingface.co/ZhenYe234/xcodec/blob/main/config_hubert_general.yaml) | 4.0 | **6.30** | **7.10** | **4.43** | 5.45 | **4.96** | **4.79** | 7.06 | 7.49 | +| [CosyVoice 2](https://huggingface.co/FunAudioLLM/CosyVoice2-0.5B) | - | 5.21 | 6.14 | 4.08 | 4.73 | **4.91** | **4.75** | 5.97 | 6.56 | +| [XCodec2](https://huggingface.co/HKUST-Audio/xcodec2/blob/main/ckpt/epoch%3D4-step%3D1400000.ckpt) | 0.8 | 4.38 | 5.66 | 3.43 | 4.63 | **4.93** | **4.78** | 4.56 | 5.46 | +| [XY-MOSS-TTSD](https://huggingface.co/fnlp/XY_Tokenizer_TTSD_V0/tree/main) | 1.0 | 5.77 | 6.80 | 4.23 | 5.34 | 4.88 | 4.72 | 6.95 | 7.48 | +| Ours | 2.0 | **6.35** | **7.15** | **4.47** | **5.51** | 4.90 | 4.70 | **7.21** | **7.66** | + + + +Note that since some tokenizers are trained on 16 kHz data, we upsample their audio outputs to 24 kHz before computing metrics. Different upsampling methods may cause slight variations (e.g., 4.36 vs. 4.43 for XCodec Sound Event CE). We report the best results we could obtain and highlight any results within 0.05 of the best one. + + + + + + + + + + + + +## Reference +[1] [Ye, Zhen, et al. "Codec does matter: Exploring the semantic shortcoming of codec for audio language model." Proceedings of the AAAI Conference on Artificial Intelligence. Vol. 39. No. 24. 2025.](https://arxiv.org/abs/2408.17175) + +[2] [Kumar, Rithesh, et al. "High-fidelity audio compression with improved rvqgan." Advances in Neural Information Processing Systems 36 (2023): 27980-27993.](https://dl.acm.org/doi/10.5555/3666122.3667336) + +[3] [Défossez, Alexandre, et al. "High fidelity neural audio compression." arXiv preprint arXiv:2210.13438 (2022).](https://arxiv.org/abs/2210.13438) + +[4] [Défossez, Alexandre, et al. "Moshi: a speech-text foundation model for real-time dialogue." arXiv preprint arXiv:2410.00037 (2024).](https://arxiv.org/abs/2410.00037) + +[5] [Liu, Haohe, et al. "Semanticodec: An ultra low bitrate semantic audio codec for general sound." IEEE Journal of Selected Topics in Signal Processing (2024).](https://ieeexplore.ieee.org/document/10768970) + +[6] [Siuzdak, Hubert, Florian Grötschla, and Luca A. Lanzendörfer. "Snac: Multi-scale neural audio codec." arXiv preprint arXiv:2410.14411 (2024).](https://arxiv.org/abs/2410.14411) + +[7] [Ji, Shengpeng, et al. "Wavtokenizer: an efficient acoustic discrete codec tokenizer for audio language modeling." arXiv preprint arXiv:2408.16532 (2024).](https://arxiv.org/abs/2408.16532) + +[8] [Tjandra, Andros, et al. "Meta audiobox aesthetics: Unified automatic quality assessment for speech, music, and sound." arXiv preprint arXiv:2502.05139 (2025).](https://arxiv.org/abs/2502.05139) + +[9] [Ye, Zhen, et al. "Llasa: Scaling Train-Time and Inference-Time Compute for Llama-based Speech Synthesis." arXiv preprint arXiv:2502.04128 (2025).](https://arxiv.org/abs/2502.04128) + +[10] [Gong, Yitian, et al. "XY-Tokenizer: Mitigating the Semantic-Acoustic Conflict in Low-Bitrate Speech Codecs." arXiv preprint arXiv:2506.23325 (2025).](https://arxiv.org/abs/2506.23325) + +[11] [Xu, Yaoxun, et al. "MuCodec: Ultra Low-Bitrate Music Codec." arXiv preprint arXiv:2409.13216 (2024).](https://arxiv.org/abs/2409.13216) + +[12] [Welker, Simon, et al. "FlowDec: A flow-based full-band general audio codec with high perceptual quality." arXiv preprint arXiv:2503.01485 (2025).](https://arxiv.org/abs/2503.01485) + +[13] [Du, Zhihao, et al. "Cosyvoice 2: Scalable streaming speech synthesis with large language models." arXiv preprint arXiv:2412.10117 (2024).](https://arxiv.org/abs/2412.10117) + +[14] [Zhang, Xin, et al. "Speechtokenizer: Unified speech tokenizer for speech large language models." arXiv preprint arXiv:2308.16692 (2023).](https://arxiv.org/abs/2308.16692) + +[15] [Anastassiou, Philip, et al. "Seed-tts: A family of high-quality versatile speech generation models." arXiv preprint arXiv:2406.02430 (2024).](https://arxiv.org/abs/2406.02430) diff --git a/voice.mp3 b/voice.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..0858b017a1808f3f9fbaffeb41556a389328418b Binary files /dev/null and b/voice.mp3 differ