codeStriker committed
Commit b1175d1 · 1 Parent(s): bb739d5

Initial commit
Files changed (41)
  1. .gitattributes +8 -0
  2. .gitignore +10 -0
  3. .gradio/certificate.pem +31 -0
  4. README.md +7 -7
  5. app.py +773 -0
  6. higgs_audio/__init__.py +1 -0
  7. higgs_audio/audio_processing/LICENSE +51 -0
  8. higgs_audio/audio_processing/descriptaudiocodec/__init__.py +0 -0
  9. higgs_audio/audio_processing/descriptaudiocodec/dac/model/base.py +286 -0
  10. higgs_audio/audio_processing/descriptaudiocodec/dac/model/dac.py +365 -0
  11. higgs_audio/audio_processing/descriptaudiocodec/dac/nn/layers.py +33 -0
  12. higgs_audio/audio_processing/descriptaudiocodec/dac/nn/quantize.py +251 -0
  13. higgs_audio/audio_processing/higgs_audio_tokenizer.py +341 -0
  14. higgs_audio/audio_processing/quantization/__init__.py +8 -0
  15. higgs_audio/audio_processing/quantization/ac.py +301 -0
  16. higgs_audio/audio_processing/quantization/core_vq.py +360 -0
  17. higgs_audio/audio_processing/quantization/core_vq_lsx_version.py +431 -0
  18. higgs_audio/audio_processing/quantization/ddp_utils.py +197 -0
  19. higgs_audio/audio_processing/quantization/distrib.py +123 -0
  20. higgs_audio/audio_processing/quantization/vq.py +116 -0
  21. higgs_audio/audio_processing/semantic_module.py +310 -0
  22. higgs_audio/constants.py +3 -0
  23. higgs_audio/data_collator/__init__.py +0 -0
  24. higgs_audio/data_collator/higgs_audio_collator.py +583 -0
  25. higgs_audio/data_types.py +38 -0
  26. higgs_audio/dataset/__init__.py +0 -0
  27. higgs_audio/dataset/chatml_dataset.py +554 -0
  28. higgs_audio/model/__init__.py +9 -0
  29. higgs_audio/model/audio_head.py +139 -0
  30. higgs_audio/model/common.py +27 -0
  31. higgs_audio/model/configuration_higgs_audio.py +235 -0
  32. higgs_audio/model/cuda_graph_runner.py +129 -0
  33. higgs_audio/model/custom_modules.py +155 -0
  34. higgs_audio/model/modeling_higgs_audio.py +0 -0
  35. higgs_audio/model/utils.py +778 -0
  36. higgs_audio/serve/serve_engine.py +474 -0
  37. higgs_audio/serve/utils.py +254 -0
  38. pyproject.toml +100 -0
  39. requirements.txt +17 -0
  40. theme.json +285 -0
  41. voice_examples/config.json +30 -0
.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ voice_examples/en_woman.wav filter=lfs diff=lfs merge=lfs -text
+ voice_examples/mabel.wav filter=lfs diff=lfs merge=lfs -text
+ voice_examples/vex.wav filter=lfs diff=lfs merge=lfs -text
+ voice_examples/zh_man_sichuan.wav filter=lfs diff=lfs merge=lfs -text
+ voice_examples/belinda.wav filter=lfs diff=lfs merge=lfs -text
+ voice_examples/broom_salesman.wav filter=lfs diff=lfs merge=lfs -text
+ voice_examples/chadwick.wav filter=lfs diff=lfs merge=lfs -text
+ voice_examples/en_man.wav filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,10 @@
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ *.pyw
+ *.pyz
+ *.pywz
+ *.pyzw
+ *.pyzwz
+ .ruff_cache/
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
README.md CHANGED
@@ -1,13 +1,13 @@
  ---
- title: OnlyAgencies Audio TTS
- emoji: 📈
- colorFrom: indigo
- colorTo: indigo
+ title: Banana Voice
+ emoji: 🎤
+ colorFrom: red
+ colorTo: yellow
  sdk: gradio
- sdk_version: 6.6.0
+ sdk_version: 5.36.2
  app_file: app.py
  pinned: false
- license: mit
+ short_description: Banana Voice
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,773 @@
1
+ """
2
+ Xentrik Audio Text-to-Speech - TTS for Chatters
3
+ """
4
+
5
+ import argparse
6
+ import base64
7
+ import os
8
+ import uuid
9
+ import json
10
+ from typing import Optional
11
+ import gradio as gr
12
+ from loguru import logger
13
+ import numpy as np
14
+ import time
15
+ from functools import lru_cache
16
+ import re
17
+ import spaces
18
+ import torch
19
+
20
+ # Import HiggsAudio components
21
+ from higgs_audio.serve.serve_engine import HiggsAudioServeEngine
22
+ from higgs_audio.data_types import ChatMLSample, AudioContent, Message
23
+
24
+ # Global engine instance
25
+ engine = None
26
+
27
+ # Default model configuration
28
+ DEFAULT_MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
29
+ DEFAULT_AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"
30
+ SAMPLE_RATE = 24000
31
+
32
+ DEFAULT_SYSTEM_PROMPT = (
33
+ "Generate audio following instruction.\n\n"
34
+ "<|scene_desc_start|>\n"
35
+ "Audio is recorded from a Quiet Bedroom.\n"
36
+ "<|scene_desc_end|>"
37
+ )
38
+
39
+ DEFAULT_STOP_STRINGS = ["<|end_of_text|>", "<|eot_id|>"]
40
+
41
+ # Predefined examples for system and input messages - OnlyFans themed
42
+ PREDEFINED_EXAMPLES = {
43
+ "voice-clone": {
44
+ "system_prompt": "",
45
+ "input_text": "Hey there! I'm your hottest, sexiest and sweetest Voice Cloner. See below for Custom Reference, Click on it and upload a voice file. - let's clone some vocals and bring your voice to life! ",
46
+ "description": "Voice clone to clone the reference audio. Leave the system prompt empty.",
47
+ },
48
+ "smart-voice": {
49
+ "system_prompt": DEFAULT_SYSTEM_PROMPT,
50
+ "input_text": "The OnlyFans models are becoming richer than movie stars, I am doing nothing but OFM for sure. this is about to go super crazy!",
51
+ "description": "Smart voice to generate speech based on the context",
52
+ },
53
+ "multispeaker-voice-description": {
54
+ "system_prompt": "You are an AI assistant designed to convert text into speech.\n"
55
+ "If the user's message includes a [SPEAKER*] tag, do not read out the tag and generate speech for the following text, using the specified voice.\n"
56
+ "If no speaker tag is present, select a suitable voice on your own.\n\n"
57
+ "<|scene_desc_start|>\n"
58
+ "SPEAKER0: feminine\n"
59
+ "SPEAKER1: masculine\n"
60
+ "<|scene_desc_end|>",
61
+ "input_text": "[SPEAKER0] I can't believe you did that without even asking me first!\n"
62
+ "[SPEAKER1] Oh, come on! It wasn't a big deal, and I knew you would overreact like this.\n"
63
+ "[SPEAKER0] Overreact? You made a decision that affects both of us without even considering my opinion!\n"
64
+ "[SPEAKER1] Because I didn't have time to sit around waiting for you to make up your mind! Someone had to act.",
65
+ "description": "Multispeaker with different voice descriptions in the system prompt",
66
+ },
67
+ "single-speaker-voice-description": {
68
+ "system_prompt": "Generate audio following instruction.\n\n"
69
+ "<|scene_desc_start|>\n"
70
+ "SPEAKER0: He speaks with a clear British accent and a conversational, inquisitive tone. His delivery is articulate and at a moderate pace, and very clear audio.\n"
71
+ "<|scene_desc_end|>",
72
+ "input_text": "Hey, everyone! Welcome back to Tech Talk Tuesdays.\n"
73
+ "It's your host, Alex, and today, we're diving into a topic that's become absolutely crucial in the tech world — deep learning.\n"
74
+ "And let's be honest, if you've been even remotely connected to tech, AI, or machine learning lately, you know that deep learning is everywhere.\n"
75
+ "\n"
76
+ "So here's the big question: Do you want to understand how deep learning works?\n",
77
+ "description": "Single speaker with voice description in the system prompt",
78
+ },
79
+ "single-speaker-bgm": {
80
+ "system_prompt": DEFAULT_SYSTEM_PROMPT,
81
+ "input_text": "[music start] I will remember this, thought Ender, when I am defeated. To keep dignity, and give honor where it's due, so that defeat is not disgrace. And I hope I don't have to do it often. [music end]",
82
+ "description": "Single speaker with BGM using music tag. This is an experimental feature and you may need to try multiple times to get the best result.",
83
+ },
84
+ }
85
+
86
+
87
+ @lru_cache(maxsize=20)
88
+ def encode_audio_file(file_path):
89
+ """Encode an audio file to base64."""
90
+ with open(file_path, "rb") as audio_file:
91
+ return base64.b64encode(audio_file.read()).decode("utf-8")
92
+
93
+
94
+ def get_current_device():
95
+ """Get the current device."""
96
+ return "cuda" if torch.cuda.is_available() else "cpu"
97
+
98
+
99
+ def load_voice_presets():
100
+ """Load the voice presets from both config.json and voice files directory."""
101
+ voice_presets = {}
102
+ voice_dir = os.path.join(os.path.dirname(__file__), "voice_examples")
103
+ config_path = os.path.join(voice_dir, "config.json")
104
+
105
+ # First try to load from config.json if it exists
106
+ if os.path.exists(config_path):
107
+ try:
108
+ with open(config_path, "r", encoding='utf-8') as f:
109
+ voice_dict = json.load(f)
110
+
111
+ for voice_id, voice_info in voice_dict.items():
112
+ if isinstance(voice_info, dict) and "transcript" in voice_info:
113
+ voice_presets[voice_id] = voice_info["transcript"]
114
+ else:
115
+ voice_presets[voice_id] = f"Voice sample {voice_id}"
116
+
117
+ logger.info(f"Loaded {len(voice_presets)} voice presets from config.json")
118
+
119
+ except json.JSONDecodeError as e:
120
+ logger.warning(f"Config.json has JSON syntax error: {e}. Will scan folder instead.")
121
+ except Exception as e:
122
+ logger.warning(f"Error loading config.json: {e}. Will scan folder instead.")
123
+
124
+ # Then scan the voice_examples folder for any .wav files
125
+ if os.path.exists(voice_dir):
126
+ wav_files = [f for f in os.listdir(voice_dir) if f.endswith('.wav')]
127
+
128
+ for wav_file in wav_files:
129
+ voice_id = os.path.splitext(wav_file)[0] # Remove .wav extension
130
+
131
+ # Only add if not already in config or if we need to override
132
+ if voice_id not in voice_presets:
133
+ # Create a friendly name from the filename
134
+ friendly_name = voice_id.replace('_', ' ').title()
135
+ voice_presets[voice_id] = f"{friendly_name} Voice"
136
+ logger.info(f"Added voice preset from file: {voice_id}")
137
+
138
+ # Always include EMPTY option
139
+ voice_presets["EMPTY"] = "No reference voice"
140
+
141
+ logger.info(f"Total voice presets available: {list(voice_presets.keys())}")
142
+ return voice_presets
143
+
144
+
145
+ def get_voice_preset(voice_preset):
146
+ """Get the voice path and text for a given voice preset."""
147
+ if voice_preset == "EMPTY":
148
+ return None, ""
149
+
150
+ voice_path = os.path.join(os.path.dirname(__file__), "voice_examples", f"{voice_preset}.wav")
151
+
152
+ if not os.path.exists(voice_path):
153
+ logger.warning(f"Voice preset file not found: {voice_path}")
154
+ return None, "Voice preset not found"
155
+
156
+ # Get the transcript from loaded presets or create a default
157
+ voice_presets = load_voice_presets()
158
+ text = voice_presets.get(voice_preset, f"Voice sample: {voice_preset}")
159
+
160
+ return voice_path, text
161
+
162
+
163
+ def normalize_chinese_punctuation(text):
164
+ """
165
+ Convert Chinese (full-width) punctuation marks to English (half-width) equivalents.
166
+ """
167
+ # Mapping of Chinese punctuation to English punctuation
168
+ chinese_to_english_punct = {
169
+ ",": ", ", # comma
170
+ "。": ".", # period
171
+ ":": ":", # colon
172
+ ";": ";", # semicolon
173
+ "?": "?", # question mark
174
+ "!": "!", # exclamation mark
175
+ "极": "(", # left parenthesis
176
+ ")": ")", # right parenthesis
177
+ "【": "[", # left square bracket
178
+ "】": "]", # right square bracket
179
+ "《": "<", # left angle quote
180
+ "》": ">", # right angle quote
181
+ "“": '"', # left double quotation
182
+ "”": '"', # right double quotation
183
+ "‘": "'", # left single quotation
184
+ "’": "'", # right single quotation
185
+ "、": ",", # enumeration comma
186
+ "—": "-", # em dash
187
+ "…": "...", # ellipsis
188
+ "·": ".", # middle dot
189
+ "「": '"', # left corner bracket
190
+ "」": '"', # right corner bracket
191
+ "『": '"', # left double corner bracket
192
+ "』": '"', # right double corner bracket
193
+ }
194
+
195
+ # Replace each Chinese punctuation with its English counterpart
196
+ for zh_punct, en_punct in chinese_to_english_punct.items():
197
+ text = text.replace(zh_punct, en_punct)
198
+
199
+ return text
200
+
201
+
202
+ def normalize_text(transcript: str):
203
+ transcript = normalize_chinese_punctuation(transcript)
204
+ # Other normalizations (e.g., parentheses and other symbols. Will be improved in the future)
205
+ transcript = transcript.replace("(", " ")
206
+ transcript = transcript.replace(")", " ")
207
+ transcript = transcript.replace("°F", " degrees Fahrenheit")
208
+ transcript = transcript.replace("°C", " degrees Celsius")
209
+
210
+ for tag, replacement in [
211
+ ("[laugh]", "<SE>[Laughter]</SE>"),
212
+ ("[humming start]", "<SE>[Humming]</SE>"),
213
+ ("[humming end]", "<SE_e>[Humming]</SE_e>"),
214
+ ("[music start]", "<SE_s>[Music]</SE_s>"),
215
+ ("[music end]", "<SE_e>[Music]</SE_e>"),
216
+ ("[music]", "<SE>[Music]</SE>"),
217
+ ("[sing start]", "<SE_s>[Singing]</SE_s>"),
218
+ ("[sing end]", "<SE_e>[Singing]</SE_e>"),
219
+ ("[applause]", "<SE>[Applause]</SE>"),
220
+ ("[cheering]", "<SE>[Cheering]</SE>"),
221
+ ("[cough]", "<SE>[Cough]</SE>"),
222
+ ]:
223
+ transcript = transcript.replace(tag, replacement)
224
+
225
+ lines = transcript.split("\n")
226
+ transcript = "\n".join([" ".join(line.split()) for line in lines if line.strip()])
227
+ transcript = transcript.strip()
228
+
229
+ if not any([transcript.endswith(c) for c in [".", "!", "?", ",", ";", '"', "'", "</SE_e>", "</SE>"]]):
230
+ transcript += "."
231
+
232
+ return transcript
233
+
234
+
235
+ @spaces.GPU
236
+ def initialize_engine(model_path, audio_tokenizer_path) -> bool:
237
+ """
238
+ Initialize the HiggsAudioServeEngine with the specified model and tokenizer.
239
+
240
+ Args:
241
+ model_path: Path to the model to load
242
+ audio_tokenizer_path: Path to the audio tokenizer to load
243
+
244
+ Returns:
245
+ True if initialization was successful, False otherwise
246
+ """
247
+ global engine
248
+ try:
249
+ logger.info(f"Initializing engine with model: {model_path} and audio tokenizer: {audio_tokenizer_path}")
250
+ engine = HiggsAudioServeEngine(
251
+ model_name_or_path=model_path,
252
+ audio_tokenizer_name_or_path=audio_tokenizer_path,
253
+ device=get_current_device(),
254
+ )
255
+ logger.info(f"Successfully initialized HiggsAudioServeEngine with model: {model_path}")
256
+ return True
257
+ except Exception as e:
258
+ logger.error(f"Failed to initialize engine: {e}")
259
+ return False
260
+
261
+
262
+ def check_return_audio(audio_wv: np.ndarray):
263
+ # check if the audio returned is all silent
264
+ if np.all(audio_wv == 0):
265
+ logger.warning("Audio is silent, returning None")
266
+
267
+
268
+ def process_text_output(text_output: str):
269
+ # remove all the continuous <|AUDIO_OUT|> tokens with a single <|AUDIO_OUT|>
270
+ text_output = re.sub(r"(<\|AUDIO_OUT\|>)+", r"<|AUDIO_OUT|>", text_output)
271
+ return text_output
272
+
273
+
274
+ def prepare_chatml_sample(
275
+ voice_preset: str,
276
+ text: str,
277
+ reference_audio: Optional[str] = None,
278
+ reference_text: Optional[str] = None,
279
+ system_prompt: str = DEFAULT_SYSTEM_PROMPT,
280
+ ):
281
+ """Prepare a ChatMLSample for the HiggsAudioServeEngine."""
282
+ messages = []
283
+
284
+ # Add system message if provided
285
+ if len(system_prompt) > 0:
286
+ messages.append(Message(role="system", content=system_prompt))
287
+
288
+ # Add reference audio if provided
289
+ audio_base64 = None
290
+ ref_text = ""
291
+
292
+ if reference_audio:
293
+ # Custom reference audio
294
+ audio_base64 = encode_audio_file(reference_audio)
295
+ ref_text = reference_text or ""
296
+ elif voice_preset != "EMPTY":
297
+ # Voice preset
298
+ voice_path, ref_text = get_voice_preset(voice_preset)
299
+ if voice_path is None:
300
+ logger.warning(f"Voice preset {voice_preset} not found, skipping reference audio")
301
+ else:
302
+ audio_base64 = encode_audio_file(voice_path)
303
+
304
+ # Only add reference audio if we have it
305
+ if audio_base64 is not None:
306
+ # Add user message with reference text
307
+ messages.append(Message(role="user", content=ref_text))
308
+
309
+ # Add assistant message with audio content
310
+ audio_content = AudioContent(raw_audio=audio_base64, audio_url="")
311
+ messages.append(Message(role="assistant", content=[audio_content]))
312
+
313
+ # Add the main user message
314
+ text = normalize_text(text)
315
+ messages.append(Message(role="user", content=text))
316
+
317
+ return ChatMLSample(messages=messages)
318
+
319
+
320
+ @spaces.GPU(duration=120)
321
+ def text_to_speech(
322
+ text,
323
+ voice_preset,
324
+ reference_audio=None,
325
+ reference_text=None,
326
+ max_completion_tokens=1024,
327
+ temperature=1.0,
328
+ top_p=0.95,
329
+ top_k=50,
330
+ system_prompt=DEFAULT_SYSTEM_PROMPT,
331
+ stop_strings=None,
332
+ ras_win_len=7,
333
+ ras_win_max_num_repeat=2,
334
+ ):
335
+ """
336
+ Convert text to speech using HiggsAudioServeEngine.
337
+
338
+ Args:
339
+ text: The text to convert to speech
340
+ voice_preset: The voice preset to use (or "EMPTY" for no preset)
341
+ reference_audio: Optional path to reference audio file
342
+ reference_text: Optional transcript of the reference audio
343
+ max_completion_tokens: Maximum number of tokens to generate
344
+ temperature: Sampling temperature for generation
345
+ top_p: Top-p sampling parameter
346
+ top_k: Top-k sampling parameter
347
+ system_prompt: System prompt to guide the model
348
+ stop_strings: Dataframe containing stop strings
349
+ ras_win_len: Window length for repetition avoidance sampling
350
+ ras_win_max_num_repeat: Maximum number of repetitions allowed in the window
351
+
352
+ Returns:
353
+ Tuple of (generated_text, (sample_rate, audio_data)) where audio_data is int16 numpy array
354
+ """
355
+ global engine
356
+
357
+ if engine is None:
358
+ initialize_engine(DEFAULT_MODEL_PATH, DEFAULT_AUDIO_TOKENIZER_PATH)
359
+
360
+ try:
361
+ # Prepare ChatML sample
362
+ chatml_sample = prepare_chatml_sample(voice_preset, text, reference_audio, reference_text, system_prompt)
363
+
364
+ # Convert stop strings format
365
+ if stop_strings is None:
366
+ stop_list = DEFAULT_STOP_STRINGS
367
+ else:
368
+ stop_list = [s for s in stop_strings["stops"] if s.strip()]
369
+
370
+ request_id = f"tts-playground-{str(uuid.uuid4())}"
371
+ logger.info(
372
+ f"{request_id}: Generating speech for text: {text[:100]}..., \n"
373
+ f"with parameters: temperature={temperature}, top_p={top_p}, top_k={top_k}, stop_list={stop_list}, "
374
+ f"ras_win_len={ras_win_len}, ras_win_max_num_repeat={ras_win_max_num_repeat}"
375
+ )
376
+ start_time = time.time()
377
+
378
+ # Generate using the engine
379
+ response = engine.generate(
380
+ chat_ml_sample=chatml_sample,
381
+ max_new_tokens=max_completion_tokens,
382
+ temperature=temperature,
383
+ top_k=top_k if top_k > 0 else None,
384
+ top_p=top_p,
385
+ stop_strings=stop_list,
386
+ ras_win_len=ras_win_len if ras_win_len > 0 else None,
387
+ ras_win_max_num_repeat=max(ras_win_len, ras_win_max_num_repeat),
388
+ )
389
+
390
+ generation_time = time.time() - start_time
391
+ logger.info(f"{request_id}: Generated audio in {generation_time:.3f} seconds")
392
+ gr.Info(f"Generated audio in {generation_time:.3f} seconds")
393
+
394
+ # Process the response
395
+ text_output = process_text_output(response.generated_text)
396
+
397
+ if response.audio is not None:
398
+ # Convert to int16 for Gradio
399
+ audio_data = (response.audio * 32767).astype(np.int16)
400
+ check_return_audio(audio_data)
401
+ return text_output, (response.sampling_rate, audio_data)
402
+ else:
403
+ logger.warning("No audio generated")
404
+ return text_output, None
405
+
406
+ except Exception as e:
407
+ error_msg = f"Error generating speech: {e}"
408
+ logger.error(error_msg)
409
+ gr.Error(error_msg)
410
+ return f"❌ {error_msg}", None
411
+
412
+
413
+ def create_ui():
414
+ # Load voice presets
415
+ VOICE_PRESETS = load_voice_presets()
416
+
417
+ # Load theme with fallback
418
+ try:
419
+ my_theme = gr.Theme.load("theme.json")
420
+ except Exception:
421
+ my_theme = gr.themes.Default()
422
+ logger.warning("Using default theme - theme.json not found")
423
+
424
+ # ... rest of your UI code ...
425
+
426
+ # Custom OnlyFans-inspired CSS
427
+ custom_css = """
428
+ .gradio-container {
429
+ max-width: 1200px;
430
+ margin: 0 auto;
431
+ border-radius: 20px;
432
+ background: rgba(255, 255, 255, 0.9);
433
+ backdrop-filter: blur(10px);
434
+ box-shadow: 0 10px 30px rgba(0, 0, 0, 0.1);
435
+ padding: 25px;
436
+ }
437
+ h1 {
438
+ background: linear-gradient(90deg, #ff4da6 0%, #9d4dff 100%);
439
+ -webkit-background-clip: text;
440
+ -webkit-text-fill-color: transparent;
441
+ text-align: center;
442
+ font-size: 2.5em;
443
+ font-weight: 800;
444
+ margin-bottom: 10px;
445
+ }
446
+ .gr-markdown p {
447
+ text-align: center;
448
+ color: #666;
449
+ font-size: 1.1em;
450
+ margin-bottom: 25px;
451
+ }
452
+ .gr-box {
453
+ border-radius: 15px;
454
+ border: 2px solid #ff4da6;
455
+ padding: 15px;
456
+ }
457
+ textarea, select, input {
458
+ border-radius: 15px;
459
+ border: 2px solid #ff9ec4;
460
+ padding: 12px;
461
+ font-size: 1em;
462
+ transition: all 0.3s ease;
463
+ }
464
+ textarea:focus, select:focus, input:focus {
465
+ border-color: #ff4da6;
466
+ box-shadow: 0 0 0 2px rgba(255, 77, 166, 0.2);
467
+ outline: none;
468
+ }
469
+ button {
470
+ background: linear-gradient(90deg, #ff4da6 0%, #9d4dff 100%);
471
+ color: white;
472
+ border: none;
473
+ border-radius: 15px;
474
+ padding: 12px 25px;
475
+ font-weight: 600;
476
+ font-size: 1.1em;
477
+ cursor: pointer;
478
+ transition: all 0.3s ease;
479
+ }
480
+ button:hover {
481
+ transform: translateY(-2px);
482
+ box-shadow: 0 5px 15px rgba(255, 77, 166, 0.4);
483
+ }
484
+ .gr-accordion {
485
+ border-radius: 15px;
486
+ border: 2px solid #ff9ec4;
487
+ margin-bottom: 15px;
488
+ }
489
+ .gr-accordion .gr-button {
490
+ background: transparent;
491
+ color: #ff4da6;
492
+ font-weight: 600;
493
+ }
494
+ .label-wrap {
495
+ font-weight: 600;
496
+ color: #ff4da6;
497
+ margin-bottom: 8px;
498
+ }
499
+ .tooltip {
500
+ background: #ff4da6;
501
+ color: white;
502
+ }
503
+ """
504
+
505
+ default_template = "smart-voice"
506
+
507
+ """Create the Gradio UI."""
508
+ with gr.Blocks(theme=my_theme, css=custom_css, title="OnlyAgencies Audio Text-to-Speech") as demo:
509
+ gr.Markdown("# OnlyAgencies Audio Text-to-Speech")
510
+ gr.Markdown("Create irresistible audio messages that keep your fans coming back for more 😘")
511
+
512
+ # Main UI section
513
+ with gr.Row():
514
+ with gr.Column(scale=2):
515
+ # Template selection dropdown
516
+ template_dropdown = gr.Dropdown(
517
+ label="TTS Template",
518
+ choices=list(PREDEFINED_EXAMPLES.keys()),
519
+ value=default_template,
520
+ info="Select a predefined example for system and input messages.",
521
+ )
522
+
523
+ # Template description display
524
+ template_description = gr.HTML(
525
+ value=f'<p style="font-size: 0.85em; color: #ff4da6; margin: 0; padding: 0;">{PREDEFINED_EXAMPLES[default_template]["description"]}</p>',
526
+ visible=True,
527
+ )
528
+
529
+ system_prompt = gr.TextArea(
530
+ label="System Prompt",
531
+ placeholder="Enter system prompt to guide the model...",
532
+ value=PREDEFINED_EXAMPLES[default_template]["system_prompt"],
533
+ lines=2,
534
+ )
535
+
536
+ input_text = gr.TextArea(
537
+ label="Input Text",
538
+ placeholder="Type the text you want to convert to speech...",
539
+ value=PREDEFINED_EXAMPLES[default_template]["input_text"],
540
+ lines=5,
541
+ )
542
+
543
+ voice_preset = gr.Dropdown(
544
+ label="Voice Preset",
545
+ choices=list(VOICE_PRESETS.keys()),
546
+ value="EMPTY",
547
+ interactive=False, # Disabled by default since default template is not voice-clone
548
+ visible=False,
549
+ )
550
+
551
+ with gr.Accordion(
552
+ "Custom Reference (Optional)", open=False, visible=False
553
+ ) as custom_reference_accordion:
554
+ reference_audio = gr.Audio(label="Reference Audio", type="filepath")
555
+ reference_text = gr.TextArea(
556
+ label="Reference Text (transcript of the reference audio)",
557
+ placeholder="Enter the transcript of your reference audio...",
558
+ lines=3,
559
+ )
560
+
561
+ with gr.Accordion("Advanced Parameters", open=False):
562
+ max_completion_tokens = gr.Slider(
563
+ minimum=128,
564
+ maximum=4096,
565
+ value=1024,
566
+ step=10,
567
+ label="Max Completion Tokens",
568
+ )
569
+ temperature = gr.Slider(
570
+ minimum=0.0,
571
+ maximum=1.5,
572
+ value=1.0,
573
+ step=0.1,
574
+ label="Temperature",
575
+ )
576
+ top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top P")
577
+ top_k = gr.Slider(minimum=-1, maximum=100, value=50, step=1, label="Top K")
578
+ ras_win_len = gr.Slider(
579
+ minimum=0,
580
+ maximum=10,
581
+ value=7,
582
+ step=1,
583
+ label="RAS Window Length",
584
+ info="Window length for repetition avoidance sampling",
585
+ )
586
+ ras_win_max_num_repeat = gr.Slider(
587
+ minimum=1,
588
+ maximum=10,
589
+ value=2,
590
+ step=1,
591
+ label="RAS Max Num Repeat",
592
+ info="Maximum number of repetitions allowed in the window",
593
+ )
594
+ # Add stop strings component
595
+ stop_strings = gr.Dataframe(
596
+ label="Stop Strings",
597
+ headers=["stops"],
598
+ datatype=["str"],
599
+ value=[[s] for s in DEFAULT_STOP_STRINGS],
600
+ interactive=True,
601
+ col_count=(1, "fixed"),
602
+ )
603
+
604
+ submit_btn = gr.Button("Generate Speech", variant="primary", scale=1)
605
+
606
+ with gr.Column(scale=2):
607
+ output_text = gr.TextArea(label="Model Response", lines=2)
608
+
609
+ # Audio output
610
+ output_audio = gr.Audio(label="Generated Audio", interactive=False, autoplay=True)
611
+
612
+ stop_btn = gr.Button("Stop Playback", variant="primary")
613
+
614
+ # Example voice
615
+ with gr.Row(visible=False) as voice_samples_section:
616
+ voice_samples_table = gr.Dataframe(
617
+ headers=["Voice Preset", "Sample Text"],
618
+ datatype=["str", "str"],
619
+ value=[[preset, text] for preset, text in VOICE_PRESETS.items() if preset != "EMPTY"],
620
+ interactive=False,
621
+ )
622
+ sample_audio = gr.Audio(label="Voice Sample")
623
+
624
+ # Function to play voice sample when clicking on a row
625
+ def play_voice_sample(evt: gr.SelectData):
626
+ """
627
+ Play a voice sample when a row is clicked in the voice samples table.
628
+
629
+ Args:
630
+ evt: The select event containing the clicked row index
631
+
632
+ Returns:
633
+ Path to the voice sample audio file, or None if not found
634
+ """
635
+ try:
636
+ # Get the preset name from the clicked row
637
+ preset_names = [preset for preset in VOICE_PRESETS.keys() if preset != "EMPTY"]
638
+ if evt.index[0] < len(preset_names):
639
+ preset = preset_names[evt.index[0]]
640
+ voice_path, _ = get_voice_preset(preset)
641
+ if voice_path and os.path.exists(voice_path):
642
+ return voice_path
643
+ else:
644
+ gr.Warning(f"Voice sample file not found for preset: {preset}")
645
+ return None
646
+ else:
647
+ gr.Warning("Invalid voice preset selection")
648
+ return None
649
+ except Exception as e:
650
+ logger.error(f"Error playing voice sample: {e}")
651
+ gr.Error(f"Error playing voice sample: {e}")
652
+ return None
653
+
654
+ voice_samples_table.select(fn=play_voice_sample, outputs=[sample_audio])
655
+
656
+ # Function to handle template selection
657
+ def apply_template(template_name):
658
+ """
659
+ Apply a predefined template to the UI components.
660
+
661
+ Args:
662
+ template_name: Name of the template to apply
663
+
664
+ Returns:
665
+ Tuple of updated values for system_prompt, input_text, template_description,
666
+ voice_preset, custom_reference_accordion, voice_samples_section, and ras_win_len
667
+ """
668
+ if template_name in PREDEFINED_EXAMPLES:
669
+ template = PREDEFINED_EXAMPLES[template_name]
670
+ # Enable voice preset and custom reference only for voice-clone template
671
+ is_voice_clone = template_name == "voice-clone"
672
+ voice_preset_value = "belinda" if is_voice_clone else "EMPTY"
673
+ # Set ras_win_len to 0 for single-speaker-bgm, 7 for others
674
+ ras_win_len_value = 0 if template_name == "single-speaker-bgm" else 7
675
+ description_text = f'<p style="font-size: 0.85em; color: #ff4da6; margin: 0; padding: 0;">{template["description"]}</p>'
676
+ return (
677
+ template["system_prompt"], # system_prompt
678
+ template["input_text"], # input_text
679
+ description_text, # template_description
680
+ gr.update(
681
+ value=voice_preset_value, interactive=is_voice_clone, visible=is_voice_clone
682
+ ), # voice_preset (value and interactivity)
683
+ gr.update(visible=is_voice_clone), # custom reference accordion visibility
684
+ gr.update(visible=is_voice_clone), # voice samples section visibility
685
+ ras_win_len_value, # ras_win_len
686
+ )
687
+ else:
688
+ return (
689
+ gr.update(),
690
+ gr.update(),
691
+ gr.update(),
692
+ gr.update(),
693
+ gr.update(),
694
+ gr.update(),
695
+ gr.update(),
696
+ ) # No change if template not found
697
+
698
+ # Set up event handlers
699
+
700
+ # Connect template dropdown to handler
701
+ template_dropdown.change(
702
+ fn=apply_template,
703
+ inputs=[template_dropdown],
704
+ outputs=[
705
+ system_prompt,
706
+ input_text,
707
+ template_description,
708
+ voice_preset,
709
+ custom_reference_accordion,
710
+ voice_samples_section,
711
+ ras_win_len,
712
+ ],
713
+ )
714
+
715
+ # Connect submit button to the TTS function
716
+ submit_btn.click(
717
+ fn=text_to_speech,
718
+ inputs=[
719
+ input_text,
720
+ voice_preset,
721
+ reference_audio,
722
+ reference_text,
723
+ max_completion_tokens,
724
+ temperature,
725
+ top_p,
726
+ top_k,
727
+ system_prompt,
728
+ stop_strings,
729
+ ras_win_len,
730
+ ras_win_max_num_repeat,
731
+ ],
732
+ outputs=[output_text, output_audio],
733
+ api_name="generate_speech",
734
+ )
735
+
736
+ # Stop button functionality
737
+ stop_btn.click(
738
+ fn=lambda: None,
739
+ inputs=[],
740
+ outputs=[output_audio],
741
+ js="() => {const audio = document.querySelector('audio'); if(audio) audio.pause(); return null;}",
742
+ )
743
+
744
+ return demo
745
+
746
+
747
+ def main():
748
+ """Main function to parse arguments and launch the UI."""
749
+ global DEFAULT_MODEL_PATH, DEFAULT_AUDIO_TOKENIZER_PATH, VOICE_PRESETS
750
+
751
+ parser = argparse.ArgumentParser(description="Gradio UI for Text-to-Speech using HiggsAudioServeEngine")
752
+ parser.add_argument(
753
+ "--device",
754
+ type=str,
755
+ default="cuda",
756
+ choices=["cuda", "cpu"],
757
+ help="Device to run the model on.",
758
+ )
759
+ parser.add_argument("--host", type=str, default="0.0.0.0", help="Host for the Gradio interface.")
760
+ parser.add_argument("--port", type=int, default=7860, help="Port for the Gradio interface.")
761
+
762
+ args = parser.parse_args()
763
+
764
+ # Update default values if provided via command line
765
+ VOICE_PRESETS = load_voice_presets()
766
+
767
+ # Create and launch the UI
768
+ demo = create_ui()
769
+ demo.launch(server_name=args.host, server_port=args.port, share=True)
770
+
771
+
772
+ if __name__ == "__main__":
773
+ main()
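For context (not part of the commit), the Gradio callbacks in app.py above reduce to one call into the serve engine. Below is a minimal sketch of that path, assuming the bosonai checkpoints are downloadable and a CUDA (or CPU) device is available; all names mirror the code in app.py.

```python
# A minimal sketch (assumptions: bosonai checkpoints reachable, GPU or CPU available)
# of the generation path that app.py's text_to_speech() wraps.
import numpy as np

from higgs_audio.serve.serve_engine import HiggsAudioServeEngine
from higgs_audio.data_types import ChatMLSample, Message

engine = HiggsAudioServeEngine(
    model_name_or_path="bosonai/higgs-audio-v2-generation-3B-base",
    audio_tokenizer_name_or_path="bosonai/higgs-audio-v2-tokenizer",
    device="cuda",  # app.py falls back to "cpu" when CUDA is unavailable
)

sample = ChatMLSample(messages=[Message(role="user", content="Hello from Banana Voice.")])
response = engine.generate(
    chat_ml_sample=sample,
    max_new_tokens=1024,
    temperature=1.0,
    top_p=0.95,
    top_k=50,
    stop_strings=["<|end_of_text|>", "<|eot_id|>"],
    ras_win_len=7,
    ras_win_max_num_repeat=2,
)

# As in text_to_speech(), scale the float waveform to int16 before handing it to Gradio.
if response.audio is not None:
    pcm = (response.audio * 32767).astype(np.int16)
    print(response.sampling_rate, pcm.shape)
```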
higgs_audio/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .model import HiggsAudioConfig, HiggsAudioModel
higgs_audio/audio_processing/LICENSE ADDED
@@ -0,0 +1,51 @@
1
+ Third-Party License Attribution for Audio Processing Module
2
+ ===========================================================
3
+
4
+ This directory contains code derived from multiple open-source projects.
5
+ The following sections detail the licenses and attributions for third-party code.
6
+
7
+ ## XCodec Repository
8
+ The code in this directory is derived from:
9
+ https://github.com/zhenye234/xcodec
10
+
11
+ ## Individual File Attributions
12
+
13
+ ### Quantization Module (quantization/)
14
+ - Several files contain code derived from Meta Platforms, Inc. and the vector-quantize-pytorch repository
15
+ - Individual files contain their own license headers where applicable
16
+ - The vector-quantize-pytorch portions are licensed under the MIT License
17
+
18
+ ## License Terms
19
+
20
+ ### MIT License (for applicable portions)
21
+ Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ of this software and associated documentation files (the "Software"), to deal
23
+ in the Software without restriction, including without limitation the rights
24
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ copies of the Software, and to permit persons to whom the Software is
26
+ furnished to do so, subject to the following conditions:
27
+
28
+ The above copyright notice and this permission notice shall be included in all
29
+ copies or substantial portions of the Software.
30
+
31
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ SOFTWARE.
38
+
39
+ ## Attribution Requirements
40
+ When using this code, please ensure proper attribution to:
41
+ 1. The original xcodec repository: https://github.com/zhenye234/xcodec
42
+ 2. Any other repositories mentioned in individual file headers
43
+ 3. This derivative work and its modifications
44
+
45
+ ## Disclaimer
46
+ This directory contains modified versions of the original code. Please refer to
47
+ the original repositories for the canonical implementations and their specific
48
+ license terms.
49
+
50
+ For any questions about licensing or attribution, please check the individual
51
+ file headers and the original source repositories.
higgs_audio/audio_processing/descriptaudiocodec/__init__.py ADDED
File without changes
higgs_audio/audio_processing/descriptaudiocodec/dac/model/base.py ADDED
@@ -0,0 +1,286 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import tqdm
9
+ from audiotools import AudioSignal
10
+ from torch import nn
11
+
12
+ SUPPORTED_VERSIONS = ["1.0.0"]
13
+
14
+
15
+ @dataclass
16
+ class DACFile:
17
+ codes: torch.Tensor
18
+
19
+ # Metadata
20
+ chunk_length: int
21
+ original_length: int
22
+ input_db: float
23
+ channels: int
24
+ sample_rate: int
25
+ padding: bool
26
+ dac_version: str
27
+
28
+ def save(self, path):
29
+ artifacts = {
30
+ "codes": self.codes.numpy().astype(np.uint16),
31
+ "metadata": {
32
+ "input_db": self.input_db.numpy().astype(np.float32),
33
+ "original_length": self.original_length,
34
+ "sample_rate": self.sample_rate,
35
+ "chunk_length": self.chunk_length,
36
+ "channels": self.channels,
37
+ "padding": self.padding,
38
+ "dac_version": SUPPORTED_VERSIONS[-1],
39
+ },
40
+ }
41
+ path = Path(path).with_suffix(".dac")
42
+ with open(path, "wb") as f:
43
+ np.save(f, artifacts)
44
+ return path
45
+
46
+ @classmethod
47
+ def load(cls, path):
48
+ artifacts = np.load(path, allow_pickle=True)[()]
49
+ codes = torch.from_numpy(artifacts["codes"].astype(int))
50
+ if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
51
+ raise RuntimeError(f"Given file {path} can't be loaded with this version of descript-audio-codec.")
52
+ return cls(codes=codes, **artifacts["metadata"])
53
+
54
+
55
+ class CodecMixin:
56
+ @property
57
+ def padding(self):
58
+ if not hasattr(self, "_padding"):
59
+ self._padding = True
60
+ return self._padding
61
+
62
+ @padding.setter
63
+ def padding(self, value):
64
+ assert isinstance(value, bool)
65
+
66
+ layers = [l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))]
67
+
68
+ for layer in layers:
69
+ if value:
70
+ if hasattr(layer, "original_padding"):
71
+ layer.padding = layer.original_padding
72
+ else:
73
+ layer.original_padding = layer.padding
74
+ layer.padding = tuple(0 for _ in range(len(layer.padding)))
75
+
76
+ self._padding = value
77
+
78
+ def get_delay(self):
79
+ # Any number works here, delay is invariant to input length
80
+ l_out = self.get_output_length(0)
81
+ L = l_out
82
+
83
+ layers = []
84
+ for layer in self.modules():
85
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
86
+ layers.append(layer)
87
+
88
+ for layer in reversed(layers):
89
+ d = layer.dilation[0]
90
+ k = layer.kernel_size[0]
91
+ s = layer.stride[0]
92
+
93
+ if isinstance(layer, nn.ConvTranspose1d):
94
+ L = ((L - d * (k - 1) - 1) / s) + 1
95
+ elif isinstance(layer, nn.Conv1d):
96
+ L = (L - 1) * s + d * (k - 1) + 1
97
+
98
+ L = math.ceil(L)
99
+
100
+ l_in = L
101
+
102
+ return (l_in - l_out) // 2
103
+
104
+ def get_output_length(self, input_length):
105
+ L = input_length
106
+ # Calculate output length
107
+ for layer in self.modules():
108
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
109
+ d = layer.dilation[0]
110
+ k = layer.kernel_size[0]
111
+ s = layer.stride[0]
112
+
113
+ if isinstance(layer, nn.Conv1d):
114
+ L = ((L - d * (k - 1) - 1) / s) + 1
115
+ elif isinstance(layer, nn.ConvTranspose1d):
116
+ L = (L - 1) * s + d * (k - 1) + 1
117
+
118
+ L = math.floor(L)
119
+ return L
120
+
121
+ @torch.no_grad()
122
+ def compress(
123
+ self,
124
+ audio_path_or_signal: Union[str, Path, AudioSignal],
125
+ win_duration: float = 1.0,
126
+ verbose: bool = False,
127
+ normalize_db: float = -16,
128
+ n_quantizers: int = None,
129
+ ) -> DACFile:
130
+ """Processes an audio signal from a file or AudioSignal object into
131
+ discrete codes. This function processes the signal in short windows,
132
+ using constant GPU memory.
133
+
134
+ Parameters
135
+ ----------
136
+ audio_path_or_signal : Union[str, Path, AudioSignal]
137
+ audio signal to compress
138
+ win_duration : float, optional
139
+ window duration in seconds, by default 1.0
140
+ verbose : bool, optional
141
+ by default False
142
+ normalize_db : float, optional
143
+ normalize db, by default -16
144
+
145
+ Returns
146
+ -------
147
+ DACFile
148
+ Object containing compressed codes and metadata
149
+ required for decompression
150
+ """
151
+ audio_signal = audio_path_or_signal
152
+ if isinstance(audio_signal, (str, Path)):
153
+ audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
154
+
155
+ self.eval()
156
+ original_padding = self.padding
157
+ original_device = audio_signal.device
158
+
159
+ audio_signal = audio_signal.clone()
160
+ original_sr = audio_signal.sample_rate
161
+
162
+ resample_fn = audio_signal.resample
163
+ loudness_fn = audio_signal.loudness
164
+
165
+ # If audio is > 10 minutes long, use the ffmpeg versions
166
+ if audio_signal.signal_duration >= 10 * 60 * 60:
167
+ resample_fn = audio_signal.ffmpeg_resample
168
+ loudness_fn = audio_signal.ffmpeg_loudness
169
+
170
+ original_length = audio_signal.signal_length
171
+ resample_fn(self.sample_rate)
172
+ input_db = loudness_fn()
173
+
174
+ if normalize_db is not None:
175
+ audio_signal.normalize(normalize_db)
176
+ audio_signal.ensure_max_of_audio()
177
+
178
+ nb, nac, nt = audio_signal.audio_data.shape
179
+ audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
180
+ win_duration = audio_signal.signal_duration if win_duration is None else win_duration
181
+
182
+ if audio_signal.signal_duration <= win_duration:
183
+ # Unchunked compression (used if signal length < win duration)
184
+ self.padding = True
185
+ n_samples = nt
186
+ hop = nt
187
+ else:
188
+ # Chunked inference
189
+ self.padding = False
190
+ # Zero-pad signal on either side by the delay
191
+ audio_signal.zero_pad(self.delay, self.delay)
192
+ n_samples = int(win_duration * self.sample_rate)
193
+ # Round n_samples to nearest hop length multiple
194
+ n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
195
+ hop = self.get_output_length(n_samples)
196
+
197
+ codes = []
198
+ range_fn = range if not verbose else tqdm.trange
199
+
200
+ for i in range_fn(0, nt, hop):
201
+ x = audio_signal[..., i : i + n_samples]
202
+ x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
203
+
204
+ audio_data = x.audio_data.to(self.device)
205
+ audio_data = self.preprocess(audio_data, self.sample_rate)
206
+ _, c, _, _, _ = self.encode(audio_data, n_quantizers)
207
+ codes.append(c.to(original_device))
208
+ chunk_length = c.shape[-1]
209
+
210
+ codes = torch.cat(codes, dim=-1)
211
+
212
+ dac_file = DACFile(
213
+ codes=codes,
214
+ chunk_length=chunk_length,
215
+ original_length=original_length,
216
+ input_db=input_db,
217
+ channels=nac,
218
+ sample_rate=original_sr,
219
+ padding=self.padding,
220
+ dac_version=SUPPORTED_VERSIONS[-1],
221
+ )
222
+
223
+ if n_quantizers is not None:
224
+ codes = codes[:, :n_quantizers, :]
225
+
226
+ self.padding = original_padding
227
+ return dac_file
228
+
229
+ @torch.no_grad()
230
+ def decompress(
231
+ self,
232
+ obj: Union[str, Path, DACFile],
233
+ verbose: bool = False,
234
+ ) -> AudioSignal:
235
+ """Reconstruct audio from a given .dac file
236
+
237
+ Parameters
238
+ ----------
239
+ obj : Union[str, Path, DACFile]
240
+ .dac file location or corresponding DACFile object.
241
+ verbose : bool, optional
242
+ Prints progress if True, by default False
243
+
244
+ Returns
245
+ -------
246
+ AudioSignal
247
+ Object with the reconstructed audio
248
+ """
249
+ self.eval()
250
+ if isinstance(obj, (str, Path)):
251
+ obj = DACFile.load(obj)
252
+
253
+ original_padding = self.padding
254
+ self.padding = obj.padding
255
+
256
+ range_fn = range if not verbose else tqdm.trange
257
+ codes = obj.codes
258
+ original_device = codes.device
259
+ chunk_length = obj.chunk_length
260
+ recons = []
261
+
262
+ for i in range_fn(0, codes.shape[-1], chunk_length):
263
+ c = codes[..., i : i + chunk_length].to(self.device)
264
+ z = self.quantizer.from_codes(c)[0]
265
+ r = self.decode(z)
266
+ recons.append(r.to(original_device))
267
+
268
+ recons = torch.cat(recons, dim=-1)
269
+ recons = AudioSignal(recons, self.sample_rate)
270
+
271
+ resample_fn = recons.resample
272
+ loudness_fn = recons.loudness
273
+
274
+ # If audio is > 10 minutes long, use the ffmpeg versions
275
+ if recons.signal_duration >= 10 * 60 * 60:
276
+ resample_fn = recons.ffmpeg_resample
277
+ loudness_fn = recons.ffmpeg_loudness
278
+
279
+ recons.normalize(obj.input_db)
280
+ resample_fn(obj.sample_rate)
281
+ recons = recons[..., : obj.original_length]
282
+ loudness_fn()
283
+ recons.audio_data = recons.audio_data.reshape(-1, obj.channels, obj.original_length)
284
+
285
+ self.padding = original_padding
286
+ return recons
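For reference (not part of the commit), the round trip that CodecMixin.compress/decompress and DACFile.save/load provide looks roughly like the sketch below; it assumes the DAC model from dac.py (the next file) is importable via the vendored `dac` package and that `audiotools` is installed, and the filename is illustrative.

```python
# Hedged sketch of the compress -> save -> load -> decompress round trip.
import torch
from audiotools import AudioSignal

from dac.model.dac import DAC  # assumes the vendored `dac` package is on the import path

model = DAC()  # CodecMixin supplies compress() / decompress()
signal = AudioSignal(torch.randn(1, 1, 44100 * 5), 44100)  # 5 s of noise at 44.1 kHz

dac_file = model.compress(signal, win_duration=1.0)  # DACFile: uint16 codes + metadata
path = dac_file.save("clip.dac")                     # serialized with np.save
recons = model.decompress(path)                      # AudioSignal reconstruction
print(recons.audio_data.shape)                       # (1, channels, original_length)
```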
higgs_audio/audio_processing/descriptaudiocodec/dac/model/dac.py ADDED
@@ -0,0 +1,365 @@
1
+ import math
2
+ from typing import List
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from audiotools.ml import BaseModel
9
+ from torch import nn
10
+
11
+ from .base import CodecMixin
12
+ from dac.nn.layers import Snake1d
13
+ from dac.nn.layers import WNConv1d
14
+ from dac.nn.layers import WNConvTranspose1d
15
+ from dac.nn.quantize import ResidualVectorQuantize
16
+
17
+
18
+ def init_weights(m):
19
+ if isinstance(m, nn.Conv1d):
20
+ nn.init.trunc_normal_(m.weight, std=0.02)
21
+ nn.init.constant_(m.bias, 0)
22
+
23
+
24
+ class ResidualUnit(nn.Module):
25
+ def __init__(self, dim: int = 16, dilation: int = 1):
26
+ super().__init__()
27
+ pad = ((7 - 1) * dilation) // 2
28
+ self.block = nn.Sequential(
29
+ Snake1d(dim),
30
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
31
+ Snake1d(dim),
32
+ WNConv1d(dim, dim, kernel_size=1),
33
+ )
34
+
35
+ def forward(self, x):
36
+ y = self.block(x)
37
+ pad = (x.shape[-1] - y.shape[-1]) // 2
38
+ if pad > 0:
39
+ x = x[..., pad:-pad]
40
+ return x + y
41
+
42
+
43
+ class EncoderBlock(nn.Module):
44
+ def __init__(self, dim: int = 16, stride: int = 1):
45
+ super().__init__()
46
+ self.block = nn.Sequential(
47
+ ResidualUnit(dim // 2, dilation=1),
48
+ ResidualUnit(dim // 2, dilation=3),
49
+ ResidualUnit(dim // 2, dilation=9),
50
+ Snake1d(dim // 2),
51
+ WNConv1d(
52
+ dim // 2,
53
+ dim,
54
+ kernel_size=2 * stride,
55
+ stride=stride,
56
+ padding=math.ceil(stride / 2),
57
+ ),
58
+ )
59
+
60
+ def forward(self, x):
61
+ return self.block(x)
62
+
63
+
64
+ class Encoder(nn.Module):
65
+ def __init__(
66
+ self,
67
+ d_model: int = 64,
68
+ strides: list = [2, 4, 8, 8],
69
+ d_latent: int = 256,
70
+ ):
71
+ super().__init__()
72
+ # Create first convolution
73
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
74
+
75
+ # Create EncoderBlocks that double channels as they downsample by `stride`
76
+ for stride in strides:
77
+ d_model *= 2
78
+ self.block += [EncoderBlock(d_model, stride=stride)]
79
+
80
+ # Create last convolution
81
+ self.block += [
82
+ Snake1d(d_model),
83
+ WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
84
+ ]
85
+
86
+ # Wrap block into nn.Sequential
87
+ self.block = nn.Sequential(*self.block)
88
+ self.enc_dim = d_model
89
+
90
+ def forward(self, x):
91
+ return self.block(x)
92
+
93
+
94
+ class DecoderBlock(nn.Module):
95
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, out_pad=0):
96
+ super().__init__()
97
+ self.block = nn.Sequential(
98
+ Snake1d(input_dim),
99
+ WNConvTranspose1d(
100
+ input_dim,
101
+ output_dim,
102
+ kernel_size=2 * stride,
103
+ stride=stride,
104
+ padding=math.ceil(stride / 2),
105
+ output_padding=stride % 2, # out_pad,
106
+ ),
107
+ ResidualUnit(output_dim, dilation=1),
108
+ ResidualUnit(output_dim, dilation=3),
109
+ ResidualUnit(output_dim, dilation=9),
110
+ )
111
+
112
+ def forward(self, x):
113
+ return self.block(x)
114
+
115
+
116
+ class Decoder(nn.Module):
117
+ def __init__(
118
+ self,
119
+ input_channel,
120
+ channels,
121
+ rates,
122
+ d_out: int = 1,
123
+ ):
124
+ super().__init__()
125
+
126
+ # Add first conv layer
127
+ layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
128
+
129
+ # Add upsampling + MRF blocks
130
+ for i, stride in enumerate(rates):
131
+ input_dim = channels // 2**i
132
+ output_dim = channels // 2 ** (i + 1)
133
+ if i == 1:
134
+ out_pad = 1
135
+ else:
136
+ out_pad = 0
137
+ layers += [DecoderBlock(input_dim, output_dim, stride, out_pad)]
138
+
139
+ # Add final conv layer
140
+ layers += [
141
+ Snake1d(output_dim),
142
+ WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
143
+ # nn.Tanh(),
144
+ ]
145
+
146
+ self.model = nn.Sequential(*layers)
147
+
148
+ def forward(self, x):
149
+ return self.model(x)
150
+
151
+
152
+ class DAC(BaseModel, CodecMixin):
153
+ def __init__(
154
+ self,
155
+ encoder_dim: int = 64,
156
+ encoder_rates: List[int] = [2, 4, 8, 8],
157
+ latent_dim: int = None,
158
+ decoder_dim: int = 1536,
159
+ decoder_rates: List[int] = [8, 8, 4, 2],
160
+ n_codebooks: int = 9,
161
+ codebook_size: int = 1024,
162
+ codebook_dim: Union[int, list] = 8,
163
+ quantizer_dropout: bool = False,
164
+ sample_rate: int = 44100,
165
+ ):
166
+ super().__init__()
167
+
168
+ self.encoder_dim = encoder_dim
169
+ self.encoder_rates = encoder_rates
170
+ self.decoder_dim = decoder_dim
171
+ self.decoder_rates = decoder_rates
172
+ self.sample_rate = sample_rate
173
+
174
+ if latent_dim is None:
175
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
176
+
177
+ self.latent_dim = latent_dim
178
+
179
+ self.hop_length = np.prod(encoder_rates)
180
+ self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
181
+
182
+ self.n_codebooks = n_codebooks
183
+ self.codebook_size = codebook_size
184
+ self.codebook_dim = codebook_dim
185
+ self.quantizer = ResidualVectorQuantize(
186
+ input_dim=latent_dim,
187
+ n_codebooks=n_codebooks,
188
+ codebook_size=codebook_size,
189
+ codebook_dim=codebook_dim,
190
+ quantizer_dropout=quantizer_dropout,
191
+ )
192
+
193
+ self.decoder = Decoder(
194
+ latent_dim,
195
+ decoder_dim,
196
+ decoder_rates,
197
+ )
198
+ self.sample_rate = sample_rate
199
+ self.apply(init_weights)
200
+
201
+ self.delay = self.get_delay()
202
+
203
+ def preprocess(self, audio_data, sample_rate):
204
+ if sample_rate is None:
205
+ sample_rate = self.sample_rate
206
+ assert sample_rate == self.sample_rate
207
+
208
+ length = audio_data.shape[-1]
209
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
210
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
211
+
212
+ return audio_data
213
+
214
+ def encode(
215
+ self,
216
+ audio_data: torch.Tensor,
217
+ n_quantizers: int = None,
218
+ ):
219
+ """Encode given audio data and return quantized latent codes
220
+
221
+ Parameters
222
+ ----------
223
+ audio_data : Tensor[B x 1 x T]
224
+ Audio data to encode
225
+ n_quantizers : int, optional
226
+ Number of quantizers to use, by default None
227
+ If None, all quantizers are used.
228
+
229
+ Returns
230
+ -------
231
+ dict
232
+ A dictionary with the following keys:
233
+ "z" : Tensor[B x D x T]
234
+ Quantized continuous representation of input
235
+ "codes" : Tensor[B x N x T]
236
+ Codebook indices for each codebook
237
+ (quantized discrete representation of input)
238
+ "latents" : Tensor[B x N*D x T]
239
+ Projected latents (continuous representation of input before quantization)
240
+ "vq/commitment_loss" : Tensor[1]
241
+ Commitment loss to train encoder to predict vectors closer to codebook
242
+ entries
243
+ "vq/codebook_loss" : Tensor[1]
244
+ Codebook loss to update the codebook
245
+ "length" : int
246
+ Number of samples in input audio
247
+ """
248
+ z = self.encoder(audio_data)
249
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
250
+ return z, codes, latents, commitment_loss, codebook_loss
251
+
252
+ def decode(self, z: torch.Tensor):
253
+ """Decode given latent codes and return audio data
254
+
255
+ Parameters
256
+ ----------
257
+ z : Tensor[B x D x T]
258
+ Quantized continuous representation of input
259
+ length : int, optional
260
+ Number of samples in output audio, by default None
261
+
262
+ Returns
263
+ -------
264
+ dict
265
+ A dictionary with the following keys:
266
+ "audio" : Tensor[B x 1 x length]
267
+ Decoded audio data.
268
+ """
269
+ return self.decoder(z)
270
+
271
+ def forward(
272
+ self,
273
+ audio_data: torch.Tensor,
274
+ sample_rate: int = None,
275
+ n_quantizers: int = None,
276
+ ):
277
+ """Model forward pass
278
+
279
+ Parameters
280
+ ----------
281
+ audio_data : Tensor[B x 1 x T]
282
+ Audio data to encode
283
+ sample_rate : int, optional
284
+ Sample rate of audio data in Hz, by default None
285
+ If None, defaults to `self.sample_rate`
286
+ n_quantizers : int, optional
287
+ Number of quantizers to use, by default None.
288
+ If None, all quantizers are used.
289
+
290
+ Returns
291
+ -------
292
+ dict
293
+ A dictionary with the following keys:
294
+ "z" : Tensor[B x D x T]
295
+ Quantized continuous representation of input
296
+ "codes" : Tensor[B x N x T]
297
+ Codebook indices for each codebook
298
+ (quantized discrete representation of input)
299
+ "latents" : Tensor[B x N*D x T]
300
+ Projected latents (continuous representation of input before quantization)
301
+ "vq/commitment_loss" : Tensor[1]
302
+ Commitment loss to train encoder to predict vectors closer to codebook
303
+ entries
304
+ "vq/codebook_loss" : Tensor[1]
305
+ Codebook loss to update the codebook
306
+ "length" : int
307
+ Number of samples in input audio
308
+ "audio" : Tensor[B x 1 x length]
309
+ Decoded audio data.
310
+ """
311
+ length = audio_data.shape[-1]
312
+ audio_data = self.preprocess(audio_data, sample_rate)
313
+ z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
314
+
315
+ x = self.decode(z)
316
+ return {
317
+ "audio": x[..., :length],
318
+ "z": z,
319
+ "codes": codes,
320
+ "latents": latents,
321
+ "vq/commitment_loss": commitment_loss,
322
+ "vq/codebook_loss": codebook_loss,
323
+ }
324
+
325
+
326
+ if __name__ == "__main__":
327
+ import numpy as np
328
+ from functools import partial
329
+
330
+ model = DAC().to("cpu")
331
+
332
+ for n, m in model.named_modules():
333
+ o = m.extra_repr()
334
+ p = sum([np.prod(p.size()) for p in m.parameters()])
335
+ fn = lambda o, p: o + f" {p / 1e6:<.3f}M params."
336
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
337
+ print(model)
338
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
339
+
340
+ length = 88200 * 2
341
+ x = torch.randn(1, 1, length).to(model.device)
342
+ x.requires_grad_(True)
343
+ x.retain_grad()
344
+
345
+ # Make a forward pass
346
+ out = model(x)["audio"]
347
+ print("Input shape:", x.shape)
348
+ print("Output shape:", out.shape)
349
+
350
+ # Create gradient variable
351
+ grad = torch.zeros_like(out)
352
+ grad[:, :, grad.shape[-1] // 2] = 1
353
+
354
+ # Make a backward pass
355
+ out.backward(grad)
356
+
357
+ # Check non-zero values
358
+ gradmap = x.grad.squeeze(0)
359
+ gradmap = (gradmap != 0).sum(0) # sum across features
360
+ rf = (gradmap != 0).sum()
361
+
362
+ print(f"Receptive field: {rf.item()}")
363
+
364
+ x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
365
+ model.decompress(model.compress(x, verbose=True), verbose=True)
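# --- Hedged usage sketch (commentary, not part of the committed dac.py) ---
# A minimal encode/decode round trip with the DAC model above, assuming the default
# constructor arguments and that the class is in scope (as in the __main__ block).
# Note that `encode` returns the tuple (z, codes, latents, commitment_loss,
# codebook_loss), while `forward` returns a dict keyed by "audio", "z", "codes",
# "latents" and the two VQ losses.
import torch

model = DAC().eval()
wav = torch.randn(1, 1, 44100)                  # some mono audio, [B x 1 x T]
wav = model.preprocess(wav, sample_rate=None)   # right-pad to a multiple of the hop length
with torch.no_grad():
    z, codes, latents, commitment_loss, codebook_loss = model.encode(wav)
    recon = model.decode(z)                     # Tensor[B x 1 x T_padded]
print(codes.shape, recon.shape)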
higgs_audio/audio_processing/descriptaudiocodec/dac/nn/layers.py ADDED
@@ -0,0 +1,33 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from torch.nn.utils import weight_norm
+
+
+ def WNConv1d(*args, **kwargs):
+ return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+ def WNConvTranspose1d(*args, **kwargs):
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+ # Scripting this brings model speed up 1.4x
+ @torch.jit.script
+ def snake(x, alpha):
+ shape = x.shape
+ x = x.reshape(shape[0], shape[1], -1)
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
+ x = x.reshape(shape)
+ return x
+
+
+ class Snake1d(nn.Module):
+ def __init__(self, channels):
+ super().__init__()
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
+
+ def forward(self, x):
+ return snake(x, self.alpha)
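# --- Hedged usage sketch (commentary, not part of the committed layers.py) ---
# Snake1d applies the shape-preserving activation x + (1/alpha) * sin(alpha * x)^2
# with one learnable alpha per channel (initialized to 1). A quick shape check,
# assuming the class above is in scope:
import torch

act = Snake1d(channels=16)
x = torch.randn(2, 16, 100)        # [batch, channels, time]
y = act(x)
assert y.shape == x.shape
# With alpha == 1 this reduces to x + sin(x) ** 2 (up to the 1e-9 stabilizer on 1/alpha).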
higgs_audio/audio_processing/descriptaudiocodec/dac/nn/quantize.py ADDED
@@ -0,0 +1,251 @@
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+ from dac.nn.layers import WNConv1d
11
+
12
+
13
+ class VectorQuantize(nn.Module):
14
+ """
15
+ Implementation of VQ similar to Karpathy's repo:
16
+ https://github.com/karpathy/deep-vector-quantization
17
+ Additionally uses the following tricks from Improved VQGAN
18
+ (https://arxiv.org/pdf/2110.04627.pdf):
19
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
20
+ for improved codebook usage
21
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
22
+ improves training stability
23
+ """
24
+
25
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
26
+ super().__init__()
27
+ self.codebook_size = codebook_size
28
+ self.codebook_dim = codebook_dim
29
+
30
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
31
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
32
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
33
+
34
+ def forward(self, z):
35
+ """Quantize the input tensor using a fixed codebook and return
36
+ the corresponding codebook vectors
37
+
38
+ Parameters
39
+ ----------
40
+ z : Tensor[B x D x T]
41
+
42
+ Returns
43
+ -------
44
+ Tensor[B x D x T]
45
+ Quantized continuous representation of input
46
+ Tensor[1]
47
+ Commitment loss to train encoder to predict vectors closer to codebook
48
+ entries
49
+ Tensor[1]
50
+ Codebook loss to update the codebook
51
+ Tensor[B x T]
52
+ Codebook indices (quantized discrete representation of input)
53
+ Tensor[B x D x T]
54
+ Projected latents (continuous representation of input before quantization)
55
+ """
56
+
57
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
58
+ z_e = self.in_proj(z) # z_e : (B x D x T)
59
+ z_q, indices = self.decode_latents(z_e)
60
+
61
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
62
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
63
+
64
+ z_q = z_e + (z_q - z_e).detach() # noop in forward pass, straight-through gradient estimator in backward pass
65
+
66
+ z_q = self.out_proj(z_q)
67
+
68
+ return z_q, commitment_loss, codebook_loss, indices, z_e
69
+
70
+ def embed_code(self, embed_id):
71
+ return F.embedding(embed_id, self.codebook.weight)
72
+
73
+ def decode_code(self, embed_id):
74
+ return self.embed_code(embed_id).transpose(1, 2)
75
+
76
+ def decode_latents(self, latents):
77
+ encodings = rearrange(latents, "b d t -> (b t) d")
78
+ codebook = self.codebook.weight # codebook: (N x D)
79
+
80
+ # L2 normalize encodings and codebook (ViT-VQGAN)
81
+ encodings = F.normalize(encodings)
82
+ codebook = F.normalize(codebook)
83
+
84
+ # Compute euclidean distance with codebook
85
+ dist = (
86
+ encodings.pow(2).sum(1, keepdim=True)
87
+ - 2 * encodings @ codebook.t()
88
+ + codebook.pow(2).sum(1, keepdim=True).t()
89
+ )
90
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
91
+ z_q = self.decode_code(indices)
92
+ return z_q, indices
93
+
94
+
95
+ class ResidualVectorQuantize(nn.Module):
96
+ """
97
+ Introduced in SoundStream: An end2end neural audio codec
98
+ https://arxiv.org/abs/2107.03312
99
+ """
100
+
101
+ def __init__(
102
+ self,
103
+ input_dim: int = 512,
104
+ n_codebooks: int = 9,
105
+ codebook_size: int = 1024,
106
+ codebook_dim: Union[int, list] = 8,
107
+ quantizer_dropout: float = 0.0,
108
+ ):
109
+ super().__init__()
110
+ if isinstance(codebook_dim, int):
111
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
112
+
113
+ self.n_codebooks = n_codebooks
114
+ self.codebook_dim = codebook_dim
115
+ self.codebook_size = codebook_size
116
+
117
+ self.quantizers = nn.ModuleList(
118
+ [VectorQuantize(input_dim, codebook_size, codebook_dim[i]) for i in range(n_codebooks)]
119
+ )
120
+ self.quantizer_dropout = quantizer_dropout
121
+
122
+ def forward(self, z, n_quantizers: int = None):
123
+ """Quantize the input tensor using a fixed set of `n` codebooks and return
124
+ the corresponding codebook vectors
125
+ Parameters
126
+ ----------
127
+ z : Tensor[B x D x T]
128
+ n_quantizers : int, optional
129
+ No. of quantizers to use
130
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
131
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
132
+ when in training mode, and a random number of quantizers is used.
133
+ Returns
134
+ -------
135
+ tuple
+ A tuple with the following entries (in order):
137
+
138
+ "z" : Tensor[B x D x T]
139
+ Quantized continuous representation of input
140
+ "codes" : Tensor[B x N x T]
141
+ Codebook indices for each codebook
142
+ (quantized discrete representation of input)
143
+ "latents" : Tensor[B x N*D x T]
144
+ Projected latents (continuous representation of input before quantization)
145
+ "vq/commitment_loss" : Tensor[1]
146
+ Commitment loss to train encoder to predict vectors closer to codebook
147
+ entries
148
+ "vq/codebook_loss" : Tensor[1]
149
+ Codebook loss to update the codebook
150
+ """
151
+ z_q = 0
152
+ residual = z
153
+ commitment_loss = 0
154
+ codebook_loss = 0
155
+
156
+ codebook_indices = []
157
+ latents = []
158
+
159
+ if n_quantizers is None:
160
+ n_quantizers = self.n_codebooks
161
+ if self.training:
162
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
163
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
164
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
165
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
166
+ n_quantizers = n_quantizers.to(z.device)
167
+
168
+ for i, quantizer in enumerate(self.quantizers):
169
+ if self.training is False and i >= n_quantizers:
170
+ break
171
+
172
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual)
173
+
174
+ # Create mask to apply quantizer dropout
175
+ mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
176
+ z_q = z_q + z_q_i * mask[:, None, None]
177
+ residual = residual - z_q_i
178
+
179
+ # Sum losses
180
+ commitment_loss += (commitment_loss_i * mask).mean()
181
+ codebook_loss += (codebook_loss_i * mask).mean()
182
+
183
+ codebook_indices.append(indices_i)
184
+ latents.append(z_e_i)
185
+
186
+ codes = torch.stack(codebook_indices, dim=1)
187
+ latents = torch.cat(latents, dim=1)
188
+
189
+ return z_q, codes, latents, commitment_loss, codebook_loss
190
+
191
+ def from_codes(self, codes: torch.Tensor):
192
+ """Given the quantized codes, reconstruct the continuous representation
193
+ Parameters
194
+ ----------
195
+ codes : Tensor[B x N x T]
196
+ Quantized discrete representation of input
197
+ Returns
198
+ -------
199
+ Tensor[B x D x T]
200
+ Quantized continuous representation of input
201
+ """
202
+ z_q = 0.0
203
+ z_p = []
204
+ n_codebooks = codes.shape[1]
205
+ for i in range(n_codebooks):
206
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
207
+ z_p.append(z_p_i)
208
+
209
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
210
+ z_q = z_q + z_q_i
211
+ return z_q, torch.cat(z_p, dim=1), codes
212
+
213
+ def from_latents(self, latents: torch.Tensor):
214
+ """Given the unquantized latents, reconstruct the
215
+ continuous representation after quantization.
216
+
217
+ Parameters
218
+ ----------
219
+ latents : Tensor[B x N x T]
220
+ Continuous representation of input after projection
221
+
222
+ Returns
223
+ -------
224
+ Tensor[B x D x T]
225
+ Quantized representation of full-projected space
226
+ Tensor[B x D x T]
227
+ Quantized representation of latent space
228
+ """
229
+ z_q = 0
230
+ z_p = []
231
+ codes = []
232
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
233
+
234
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
235
+ for i in range(n_codebooks):
236
+ j, k = dims[i], dims[i + 1]
237
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
238
+ z_p.append(z_p_i)
239
+ codes.append(codes_i)
240
+
241
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
242
+ z_q = z_q + z_q_i
243
+
244
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
245
+
246
+
247
+ if __name__ == "__main__":
248
+ rvq = ResidualVectorQuantize(quantizer_dropout=True)
249
+ x = torch.randn(16, 512, 80)
250
+ z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
+ print(latents.shape)
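# --- Hedged usage sketch (commentary, not part of the committed quantize.py) ---
# The discrete codes returned by ResidualVectorQuantize are enough to rebuild the
# quantized latent via `from_codes`. The sketch below uses the default configuration
# (9 codebooks of size 1024 over a 512-dim input) in eval mode, assuming the classes
# above are in scope.
import torch

rvq = ResidualVectorQuantize().eval()
x = torch.randn(2, 512, 40)                                   # [B, input_dim, T]
z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
print(codes.shape)                                            # torch.Size([2, 9, 40])
z_q_rebuilt, z_projected, _ = rvq.from_codes(codes)
print(torch.allclose(z_q, z_q_rebuilt, atol=1e-5))            # same codebook lookups, up to float round-off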
higgs_audio/audio_processing/higgs_audio_tokenizer.py ADDED
@@ -0,0 +1,341 @@
1
+ # Based on code from: https://github.com/zhenye234/xcodec
2
+ # Licensed under MIT License
3
+ # Modifications by BosonAI
4
+
5
+ import math
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from typing import Optional, Union, Sequence
11
+ import numpy as np
12
+ from transformers import AutoModel
13
+ import torchaudio
14
+ import json
15
+ import librosa
16
+ from huggingface_hub import snapshot_download
17
+
18
+ from vector_quantize_pytorch import ResidualFSQ
19
+ from .descriptaudiocodec.dac.model import dac as dac2
20
+ from .quantization.vq import ResidualVectorQuantizer
21
+ from .semantic_module import Encoder, Decoder
22
+
23
+
24
+ class EncodedResult:
25
+ def __init__(self, audio_codes):
26
+ self.audio_codes = audio_codes
27
+
28
+
29
+ class HiggsAudioFeatureExtractor(nn.Module):
30
+ def __init__(self, sampling_rate=16000):
31
+ super().__init__()
32
+ self.sampling_rate = sampling_rate
33
+
34
+ def forward(self, raw_audio, sampling_rate=16000, return_tensors="pt"):
35
+ # Convert from librosa to torch
36
+ audio_signal = torch.tensor(raw_audio)
37
+ audio_signal = audio_signal.unsqueeze(0)
38
+ if len(audio_signal.shape) < 3:
39
+ audio_signal = audio_signal.unsqueeze(0)
40
+ return {"input_values": audio_signal}
41
+
42
+
43
+ class HiggsAudioTokenizer(nn.Module):
44
+ def __init__(
45
+ self,
46
+ n_filters: int = 32,
47
+ D: int = 128,
48
+ target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6],
49
+ ratios: Sequence[int] = [8, 5, 4, 2], # downsampling by 320
50
+ sample_rate: int = 16000,
51
+ bins: int = 1024,
52
+ n_q: int = 8,
53
+ codebook_dim: int = None,
54
+ normalize: bool = False,
55
+ causal: bool = False,
56
+ semantic_techer: str = "hubert_base_general",
57
+ last_layer_semantic: bool = True,
58
+ merge_mode: str = "concat",
59
+ downsample_mode: str = "step_down",
60
+ semantic_mode: str = "classic",
61
+ vq_scale: int = 1,
62
+ semantic_sample_rate: int = None,
63
+ device: str = "cuda",
64
+ ):
65
+ super().__init__()
66
+ self.hop_length = np.prod(ratios)
67
+ self.semantic_techer = semantic_techer
68
+
69
+ self.frame_rate = math.ceil(sample_rate / np.prod(ratios)) # 50 Hz
70
+
71
+ self.target_bandwidths = target_bandwidths
72
+ self.n_q = n_q
73
+ self.sample_rate = sample_rate
74
+ self.encoder = dac2.Encoder(64, ratios, D)
75
+
76
+ self.decoder_2 = dac2.Decoder(D, 1024, ratios)
77
+ self.last_layer_semantic = last_layer_semantic
78
+ self.device = device
79
+ if semantic_techer == "hubert_base":
80
+ self.semantic_model = AutoModel.from_pretrained("facebook/hubert-base-ls960")
81
+ self.semantic_sample_rate = 16000
82
+ self.semantic_dim = 768
83
+ self.encoder_semantic_dim = 768
84
+
85
+ elif semantic_techer == "wavlm_base_plus":
86
+ self.semantic_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus")
87
+ self.semantic_sample_rate = 16000
88
+ self.semantic_dim = 768
89
+ self.encoder_semantic_dim = 768
90
+
91
+ elif semantic_techer == "hubert_base_general":
92
+ self.semantic_model = AutoModel.from_pretrained("ZhenYe234/hubert_base_general_audio")
93
+ self.semantic_sample_rate = 16000
94
+ self.semantic_dim = 768
95
+ self.encoder_semantic_dim = 768
96
+
97
+ # Overwrite semantic model sr to ensure semantic_downsample_factor is an integer
98
+ if semantic_sample_rate is not None:
99
+ self.semantic_sample_rate = semantic_sample_rate
100
+
101
+ self.semantic_model.eval()
102
+
103
+ # freeze the semantic model so its parameters do not receive gradients
104
+ for param in self.semantic_model.parameters():
105
+ param.requires_grad = False
106
+
107
+ self.semantic_downsample_factor = int(self.hop_length / (self.sample_rate / self.semantic_sample_rate) / 320)
108
+
109
+ self.quantizer_dim = int((D + self.encoder_semantic_dim) // vq_scale)
110
+ self.encoder_semantic = Encoder(input_channels=self.semantic_dim, encode_channels=self.encoder_semantic_dim)
111
+ self.decoder_semantic = Decoder(
112
+ code_dim=self.encoder_semantic_dim,
113
+ output_channels=self.semantic_dim,
114
+ decode_channels=self.semantic_dim,
115
+ )
116
+
117
+ # out_D=D+768
118
+ if isinstance(bins, int): # RVQ
119
+ self.quantizer = ResidualVectorQuantizer(
120
+ dimension=self.quantizer_dim,
121
+ codebook_dim=codebook_dim,
122
+ n_q=n_q,
123
+ bins=bins,
124
+ )
125
+ self.quantizer_type = "RVQ"
126
+ else: # RFSQ
127
+ self.quantizer = ResidualFSQ(dim=self.quantizer_dim, levels=bins, num_quantizers=n_q)
128
+ self.quantizer_type = "RFSQ"
129
+
130
+ self.fc_prior = nn.Linear(D + self.encoder_semantic_dim, self.quantizer_dim)
131
+ self.fc_post1 = nn.Linear(self.quantizer_dim, self.encoder_semantic_dim)
132
+ self.fc_post2 = nn.Linear(self.quantizer_dim, D)
133
+
134
+ self.downsample_mode = downsample_mode
135
+ if downsample_mode == "avg":
136
+ self.semantic_pooling = nn.AvgPool1d(
137
+ kernel_size=self.semantic_downsample_factor,
138
+ stride=self.semantic_downsample_factor,
139
+ )
140
+
141
+ self.audio_tokenizer_feature_extractor = HiggsAudioFeatureExtractor(sampling_rate=self.sample_rate)
142
+
143
+ @property
144
+ def tps(self):
145
+ return self.frame_rate
146
+
147
+ @property
148
+ def sampling_rate(self):
149
+ return self.sample_rate
150
+
151
+ @property
152
+ def num_codebooks(self):
153
+ return self.n_q
154
+
155
+ @property
156
+ def codebook_size(self):
157
+ return self.quantizer_dim
158
+
159
+ def get_last_layer(self):
160
+ return self.decoder.layers[-1].weight
161
+
162
+ def calculate_rec_loss(self, rec, target):
163
+ target = target / target.norm(dim=-1, keepdim=True)
164
+ rec = rec / rec.norm(dim=-1, keepdim=True)
165
+ rec_loss = (1 - (target * rec).sum(-1)).mean()
166
+
167
+ return rec_loss
168
+
169
+ @torch.no_grad()
170
+ def get_regress_target(self, x):
171
+ x = torchaudio.functional.resample(x, self.sample_rate, self.semantic_sample_rate)
172
+
173
+ if (
174
+ self.semantic_techer == "hubert_base"
175
+ or self.semantic_techer == "hubert_base_general"
176
+ or self.semantic_techer == "wavlm_base_plus"
177
+ ):
178
+ x = x[:, 0, :]
179
+ x = F.pad(x, (160, 160))
180
+ target = self.semantic_model(x, output_hidden_states=True).hidden_states
181
+ target = torch.stack(target, dim=1) # .transpose(-1, -2)#.flatten(start_dim=1, end_dim=2)
182
+
183
+ # average for all layers
184
+ target = target.mean(1)
185
+ # target = target[9]
186
+ # if self.hop_length > 320:
187
+ # target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
188
+
189
+ elif self.semantic_techer == "w2v_bert2":
190
+ target = self.semantic_model(x)
191
+
192
+ elif self.semantic_techer.startswith("whisper"):
193
+ if self.last_layer_semantic:
194
+ target = self.semantic_model(x, avg_layers=False)
195
+ else:
196
+ target = self.semantic_model(x, avg_layers=True)
197
+
198
+ elif self.semantic_techer.startswith("mert_music"):
199
+ if self.last_layer_semantic:
200
+ target = self.semantic_model(x, avg_layers=False)
201
+ else:
202
+ target = self.semantic_model(x, avg_layers=True)
203
+
204
+ elif self.semantic_techer.startswith("qwen_audio_omni"):
205
+ target = self.semantic_model(x)
206
+
207
+ if self.downsample_mode == "step_down":
208
+ if self.semantic_downsample_factor > 1:
209
+ target = target[:, :: self.semantic_downsample_factor, :]
210
+
211
+ elif self.downsample_mode == "avg":
212
+ target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
213
+ return target
214
+
215
+ def forward(self, x: torch.Tensor, bw: int):
216
+ e_semantic_input = self.get_regress_target(x).detach()
217
+
218
+ e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
219
+ e_acoustic = self.encoder(x)
220
+
221
+ e = torch.cat([e_acoustic, e_semantic], dim=1)
222
+
223
+ e = self.fc_prior(e.transpose(1, 2))
224
+
225
+ if self.quantizer_type == "RVQ":
226
+ e = e.transpose(1, 2)
227
+ quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
228
+ quantized = quantized.transpose(1, 2)
229
+ else:
230
+ quantized, codes = self.quantizer(e)
231
+ commit_loss = torch.tensor(0.0)
232
+
233
+ quantized_semantic = self.fc_post1(quantized).transpose(1, 2)
234
+ quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)
235
+
236
+ o = self.decoder_2(quantized_acoustic)
237
+
238
+ o_semantic = self.decoder_semantic(quantized_semantic)
239
+ semantic_recon_loss = F.mse_loss(e_semantic_input.transpose(1, 2).detach(), o_semantic)
240
+
241
+ return o, commit_loss, semantic_recon_loss, None
242
+
243
+ def encode(
244
+ self,
245
+ audio_path_or_wv,
246
+ sr=None,
247
+ loudness_normalize=False,
248
+ loudness_threshold=-23.0,
249
+ ):
250
+ if isinstance(audio_path_or_wv, str):
251
+ wv, sr = librosa.load(audio_path_or_wv, mono=True, sr=None)
252
+ else:
253
+ wv = audio_path_or_wv
254
+ assert sr is not None
255
+ if loudness_normalize:
256
+ import pyloudnorm as pyln
257
+
258
+ meter = pyln.Meter(sr)
259
+ l = meter.integrated_loudness(wv)
260
+ wv = pyln.normalize.loudness(wv, l, loudness_threshold)
261
+ if sr != self.sampling_rate:
262
+ wv = librosa.resample(wv, orig_sr=sr, target_sr=self.sampling_rate)
263
+ if self.audio_tokenizer_feature_extractor is not None:
264
+ inputs = self.audio_tokenizer_feature_extractor(
265
+ raw_audio=wv,
266
+ sampling_rate=self.audio_tokenizer_feature_extractor.sampling_rate,
267
+ return_tensors="pt",
268
+ )
269
+ input_values = inputs["input_values"].to(self.device)
270
+ else:
271
+ input_values = torch.from_numpy(wv).float().unsqueeze(0)
272
+ with torch.no_grad():
273
+ encoder_outputs = self._xcodec_encode(input_values)
274
+ vq_code = encoder_outputs.audio_codes[0]
275
+ return vq_code
276
+
277
+ def _xcodec_encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> EncodedResult:
278
+ bw = target_bw
279
+
280
+ e_semantic_input = self.get_regress_target(x).detach()
281
+
282
+ e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
283
+ e_acoustic = self.encoder(x)
284
+
285
+ if e_acoustic.shape[2] != e_semantic.shape[2]:
286
+ pad_size = 160 * self.semantic_downsample_factor
287
+ e_acoustic = self.encoder(F.pad(x[:, 0, :], (pad_size, pad_size)).unsqueeze(0))
288
+
289
+ if e_acoustic.shape[2] != e_semantic.shape[2]:
290
+ if e_acoustic.shape[2] > e_semantic.shape[2]:
291
+ e_acoustic = e_acoustic[:, :, : e_semantic.shape[2]]
292
+ else:
293
+ e_semantic = e_semantic[:, :, : e_acoustic.shape[2]]
294
+
295
+ e = torch.cat([e_acoustic, e_semantic], dim=1)
296
+
297
+ e = self.fc_prior(e.transpose(1, 2))
298
+
299
+ if self.quantizer_type == "RVQ":
300
+ e = e.transpose(1, 2)
301
+ quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
302
+ codes = codes.permute(1, 0, 2)
303
+ else:
304
+ quantized, codes = self.quantizer(e)
305
+ codes = codes.permute(0, 2, 1)
306
+
307
+ # return codes
308
+ return EncodedResult(codes)
309
+
310
+ def decode(self, vq_code: torch.Tensor) -> np.ndarray:
311
+ if self.quantizer_type == "RVQ":
312
+ vq_code = vq_code.permute(1, 0, 2)
313
+ quantized = self.quantizer.decode(vq_code)
314
+ quantized = quantized.transpose(1, 2)
315
+ else:
316
+ vq_code = vq_code.permute(0, 2, 1)
317
+ quantized = self.quantizer.get_output_from_indices(vq_code)
318
+ quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)
319
+
320
+ o = self.decoder_2(quantized_acoustic)
321
+ return o.cpu().numpy()
322
+
323
+
324
+ def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"):
325
+ is_local = os.path.exists(tokenizer_name_or_path)
326
+ if not is_local:
327
+ tokenizer_path = snapshot_download(tokenizer_name_or_path)
328
+ else:
329
+ tokenizer_path = tokenizer_name_or_path
330
+ config_path = os.path.join(tokenizer_path, "config.json")
331
+ model_path = os.path.join(tokenizer_path, "model.pth")
332
+ config = json.load(open(config_path))
333
+ model = HiggsAudioTokenizer(
334
+ **config,
335
+ device=device,
336
+ )
337
+ parameter_dict = torch.load(model_path, map_location=device)
338
+ model.load_state_dict(parameter_dict, strict=False)
339
+ model.to(device)
340
+ model.eval()
341
+ return model
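# --- Hedged usage sketch (commentary, not part of the committed file) ---
# Round trip through the tokenizer. The checkpoint path is a placeholder for a local
# directory (or Hub repo id) containing config.json and model.pth, and the code path
# shown assumes an RVQ-style checkpoint (integer `bins`). With the default ratios
# [8, 5, 4, 2] the hop length is 8 * 5 * 4 * 2 = 320 samples, so 16 kHz audio maps to
# ceil(16000 / 320) = 50 code frames per second.
import numpy as np

tok = load_higgs_audio_tokenizer("/path/to/higgs-audio-tokenizer", device="cpu")  # placeholder path
wv = np.random.randn(16000).astype(np.float32)      # 1 s of (fake) 16 kHz audio
codes = tok.encode(wv, sr=16000)                     # Tensor[n_q x T_frames], roughly [8 x 50]
audio = tok.decode(codes.unsqueeze(0))               # numpy array, [1 x 1 x T_samples]
print(codes.shape, audio.shape)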
higgs_audio/audio_processing/quantization/__init__.py ADDED
@@ -0,0 +1,8 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ #
+ # flake8: noqa
+ from .vq import QuantizedResult, ResidualVectorQuantizer
higgs_audio/audio_processing/quantization/ac.py ADDED
@@ -0,0 +1,301 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Arithmetic coder."""
8
+
9
+ import io
10
+ import math
11
+ import random
12
+ import typing as tp
13
+ import torch
14
+
15
+ from ..binary import BitPacker, BitUnpacker
16
+
17
+
18
+ def build_stable_quantized_cdf(
19
+ pdf: torch.Tensor,
20
+ total_range_bits: int,
21
+ roundoff: float = 1e-8,
22
+ min_range: int = 2,
23
+ check: bool = True,
24
+ ) -> torch.Tensor:
25
+ """Turn the given PDF into a quantized CDF that splits
26
+ [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
27
+ to the PDF.
28
+
29
+ Args:
30
+ pdf (torch.Tensor): probability distribution, shape should be `[N]`.
31
+ total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
32
+ during the coding process is `[0, 2 ** total_range_bits - 1]`.
33
+ roundoff (float): will round the pdf up to that level to remove difference coming
34
+ from e.g. evaluating the Language Model on different architectures.
35
+ min_range (int): minimum range width. Should always be at least 2 for numerical
36
+ stability. Use this to avoid pathological behavior is a value
37
+ that is expected to be rare actually happens in real life.
38
+ check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
39
+ """
40
+ pdf = pdf.detach()
41
+ if roundoff:
42
+ pdf = (pdf / roundoff).floor() * roundoff
43
+ # interpolate with uniform distribution to achieve desired minimum probability.
44
+ total_range = 2**total_range_bits
45
+ cardinality = len(pdf)
46
+ alpha = min_range * cardinality / total_range
47
+ assert alpha <= 1, "you must reduce min_range"
48
+ ranges = (((1 - alpha) * total_range) * pdf).floor().long()
49
+ ranges += min_range
50
+ quantized_cdf = torch.cumsum(ranges, dim=-1)
51
+ if min_range < 2:
52
+ raise ValueError("min_range must be at least 2.")
53
+ if check:
54
+ assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
55
+ if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
56
+ raise ValueError("You must increase your total_range_bits.")
57
+ return quantized_cdf
58
+
59
+
60
+ class ArithmeticCoder:
61
+ """ArithmeticCoder,
62
+ Let us take a distribution `p` over `N` symbols, and assume we have a stream
63
+ of random variables `s_t` sampled from `p`. Let us assume that we have a budget
64
+ of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
65
+ corresponding to the range `[0, 2 ** B - 1]`. We can map each of those number to a single
66
+ sequence `(s_t)` by doing the following:
67
+
68
+ 1) Initialize the current range to` [0 ** 2 B - 1]`.
69
+ 2) For each time step t, split the current range into contiguous chunks,
70
+ one for each possible outcome, with size roughly proportional to `p`.
71
+ For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
72
+ would be `{[0, 2], [3, 3]}`.
73
+ 3) Select the chunk corresponding to `s_t`, and replace the current range with this.
74
+ 4) When done encoding all the values, just select any value remaining in the range.
75
+
76
+ You will notice that this procedure can fail: for instance if at any point in time
77
+ the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
78
+ possible outcome. Intuitively, the more likely a value is, the less the range width
79
+ will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
80
+ coding scheme, likely outcomes would take less bits, and more of them can be coded
81
+ with a fixed budget.
82
+
83
+ In practice, we do not know `B` ahead of time, but we have a way to inject new bits
84
+ when the current range decreases below a given limit (given by `total_range_bits`), without
85
+ having to redo all the computations. If we encode mostly likely values, we will seldom
86
+ need to inject new bits, but a single rare value can deplete our stock of entropy!
87
+
88
+ In this explanation, we assumed that the distribution `p` was constant. In fact, the present
89
+ code works for any sequence `(p_t)` possibly different for each timestep.
90
+ We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
91
+ the KL between the true distribution and `p_t`, the most efficient the coding will be.
92
+
93
+ Args:
94
+ fo (IO[bytes]): file-like object to which the bytes will be written to.
95
+ total_range_bits (int): the range `M` described above is `2 ** total_range_bits.
96
+ Any time the current range width fall under this limit, new bits will
97
+ be injected to rescale the initial range.
98
+ """
99
+
100
+ def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
101
+ assert total_range_bits <= 30
102
+ self.total_range_bits = total_range_bits
103
+ self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time.
104
+ self.low: int = 0
105
+ self.high: int = 0
106
+ self.max_bit: int = -1
107
+ self._dbg: tp.List[tp.Any] = []
108
+ self._dbg2: tp.List[tp.Any] = []
109
+
110
+ @property
111
+ def delta(self) -> int:
112
+ """Return the current range width."""
113
+ return self.high - self.low + 1
114
+
115
+ def _flush_common_prefix(self):
116
+ # If self.low and self.high start with the sames bits,
117
+ # those won't change anymore as we always just increase the range
118
+ # by powers of 2, and we can flush them out to the bit stream.
119
+ assert self.high >= self.low, (self.low, self.high)
120
+ assert self.high < 2 ** (self.max_bit + 1)
121
+ while self.max_bit >= 0:
122
+ b1 = self.low >> self.max_bit
123
+ b2 = self.high >> self.max_bit
124
+ if b1 == b2:
125
+ self.low -= b1 << self.max_bit
126
+ self.high -= b1 << self.max_bit
127
+ assert self.high >= self.low, (self.high, self.low, self.max_bit)
128
+ assert self.low >= 0
129
+ self.max_bit -= 1
130
+ self.packer.push(b1)
131
+ else:
132
+ break
133
+
134
+ def push(self, symbol: int, quantized_cdf: torch.Tensor):
135
+ """Push the given symbol on the stream, flushing out bits
136
+ if possible.
137
+
138
+ Args:
139
+ symbol (int): symbol to encode with the AC.
140
+ quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
141
+ to build this from your pdf estimate.
142
+ """
143
+ while self.delta < 2**self.total_range_bits:
144
+ self.low *= 2
145
+ self.high = self.high * 2 + 1
146
+ self.max_bit += 1
147
+
148
+ range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
149
+ range_high = quantized_cdf[symbol].item() - 1
150
+ effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
151
+ effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
152
+ assert self.low <= self.high
153
+ self.high = self.low + effective_high
154
+ self.low = self.low + effective_low
155
+ assert self.low <= self.high, (
156
+ effective_low,
157
+ effective_high,
158
+ range_low,
159
+ range_high,
160
+ )
161
+ self._dbg.append((self.low, self.high))
162
+ self._dbg2.append((self.low, self.high))
163
+ outs = self._flush_common_prefix()
164
+ assert self.low <= self.high
165
+ assert self.max_bit >= -1
166
+ assert self.max_bit <= 61, self.max_bit
167
+ return outs
168
+
169
+ def flush(self):
170
+ """Flush the remaining information to the stream."""
171
+ while self.max_bit >= 0:
172
+ b1 = (self.low >> self.max_bit) & 1
173
+ self.packer.push(b1)
174
+ self.max_bit -= 1
175
+ self.packer.flush()
176
+
177
+
178
+ class ArithmeticDecoder:
179
+ """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
180
+
181
+ Note that this must be called with **exactly** the same parameters and sequence
182
+ of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
183
+
184
+ If the AC encoder current range is [L, H], with `L` and `H` having the some common
185
+ prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
186
+ For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
187
+ `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained
188
+ for a specific sequence of symbols and a binary-search allows us to decode those symbols.
189
+ At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
190
+ and we will need to read new bits from the stream and repeat the process.
191
+
192
+ """
193
+
194
+ def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
195
+ self.total_range_bits = total_range_bits
196
+ self.low: int = 0
197
+ self.high: int = 0
198
+ self.current: int = 0
199
+ self.max_bit: int = -1
200
+ self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time.
201
+ # Following is for debugging
202
+ self._dbg: tp.List[tp.Any] = []
203
+ self._dbg2: tp.List[tp.Any] = []
204
+ self._last: tp.Any = None
205
+
206
+ @property
207
+ def delta(self) -> int:
208
+ return self.high - self.low + 1
209
+
210
+ def _flush_common_prefix(self):
211
+ # Given the current range [L, H], if both have a common prefix,
212
+ # we know we can remove it from our representation to avoid handling large numbers.
213
+ while self.max_bit >= 0:
214
+ b1 = self.low >> self.max_bit
215
+ b2 = self.high >> self.max_bit
216
+ if b1 == b2:
217
+ self.low -= b1 << self.max_bit
218
+ self.high -= b1 << self.max_bit
219
+ self.current -= b1 << self.max_bit
220
+ assert self.high >= self.low
221
+ assert self.low >= 0
222
+ self.max_bit -= 1
223
+ else:
224
+ break
225
+
226
+ def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
227
+ """Pull a symbol, reading as many bits from the stream as required.
228
+ This returns `None` when the stream has been exhausted.
229
+
230
+ Args:
231
+ quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
232
+ to build this from your pdf estimate. This must be **exatly**
233
+ the same cdf as the one used at encoding time.
234
+ """
235
+ while self.delta < 2**self.total_range_bits:
236
+ bit = self.unpacker.pull()
237
+ if bit is None:
238
+ return None
239
+ self.low *= 2
240
+ self.high = self.high * 2 + 1
241
+ self.current = self.current * 2 + bit
242
+ self.max_bit += 1
243
+
244
+ def bin_search(low_idx: int, high_idx: int):
245
+ # Binary search is not just for coding interviews :)
246
+ if high_idx < low_idx:
247
+ raise RuntimeError("Binary search failed")
248
+ mid = (low_idx + high_idx) // 2
249
+ range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
250
+ range_high = quantized_cdf[mid].item() - 1
251
+ effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
252
+ effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
253
+ low = effective_low + self.low
254
+ high = effective_high + self.low
255
+ if self.current >= low:
256
+ if self.current <= high:
257
+ return (mid, low, high, self.current)
258
+ else:
259
+ return bin_search(mid + 1, high_idx)
260
+ else:
261
+ return bin_search(low_idx, mid - 1)
262
+
263
+ self._last = (self.low, self.high, self.current, self.max_bit)
264
+ sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
265
+ self._dbg.append((self.low, self.high, self.current))
266
+ self._flush_common_prefix()
267
+ self._dbg2.append((self.low, self.high, self.current))
268
+
269
+ return sym
270
+
271
+
272
+ def test():
273
+ torch.manual_seed(1234)
274
+ random.seed(1234)
275
+ for _ in range(4):
276
+ pdfs = []
277
+ cardinality = random.randrange(4000)
278
+ steps = random.randrange(100, 500)
279
+ fo = io.BytesIO()
280
+ encoder = ArithmeticCoder(fo)
281
+ symbols = []
282
+ for step in range(steps):
283
+ pdf = torch.softmax(torch.randn(cardinality), dim=0)
284
+ pdfs.append(pdf)
285
+ q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
286
+ symbol = torch.multinomial(pdf, 1).item()
287
+ symbols.append(symbol)
288
+ encoder.push(symbol, q_cdf)
289
+ encoder.flush()
290
+
291
+ fo.seek(0)
292
+ decoder = ArithmeticDecoder(fo)
293
+ for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
294
+ q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
295
+ decoded_symbol = decoder.pull(q_cdf)
296
+ assert decoded_symbol == symbol, idx
297
+ assert decoder.pull(torch.zeros(1)) is None
298
+
299
+
300
+ if __name__ == "__main__":
301
+ test()
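# --- Hedged illustration (commentary, not part of the committed ac.py) ---
# A small worked example of `build_stable_quantized_cdf`: with total_range_bits=8 the
# integer range [0, 255] is split into chunks roughly proportional to the pdf, and every
# symbol keeps at least `min_range` slots even when its probability is tiny.
import torch

pdf = torch.tensor([0.75, 0.20, 0.05])
cdf = build_stable_quantized_cdf(pdf, total_range_bits=8)
print(cdf)                                  # increasing chunk boundaries
print(int(cdf[-1]) <= 2 ** 8)               # True: the boundaries stay inside the range
print(int(cdf[0]), int(cdf[1] - cdf[0]))    # per-symbol widths, each >= min_range (2)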
higgs_audio/audio_processing/quantization/core_vq.py ADDED
@@ -0,0 +1,360 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # This implementation is inspired from
8
+ # https://github.com/lucidrains/vector-quantize-pytorch
9
+ # which is released under MIT License. Hereafter, the original license:
10
+ # MIT License
11
+ #
12
+ # Copyright (c) 2020 Phil Wang
13
+ #
14
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ # of this software and associated documentation files (the "Software"), to deal
16
+ # in the Software without restriction, including without limitation the rights
17
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ # copies of the Software, and to permit persons to whom the Software is
19
+ # furnished to do so, subject to the following conditions:
20
+ #
21
+ # The above copyright notice and this permission notice shall be included in all
22
+ # copies or substantial portions of the Software.
23
+ #
24
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
+ # SOFTWARE.
31
+
32
+ """Core vector quantization implementation."""
33
+
34
+ import typing as tp
35
+
36
+ from einops import rearrange, repeat
37
+ import torch
38
+ from torch import nn
39
+ import torch.nn.functional as F
40
+
41
+ from .distrib import broadcast_tensors, rank
42
+
43
+
44
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
45
+ return val if val is not None else d
46
+
47
+
48
+ def ema_inplace(moving_avg, new, decay: float):
49
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
50
+
51
+
52
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
53
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
54
+
55
+
56
+ def uniform_init(*shape: int):
57
+ t = torch.empty(shape)
58
+ nn.init.kaiming_uniform_(t)
59
+ return t
60
+
61
+
62
+ def sample_vectors(samples, num: int):
63
+ num_samples, device = samples.shape[0], samples.device
64
+
65
+ if num_samples >= num:
66
+ indices = torch.randperm(num_samples, device=device)[:num]
67
+ else:
68
+ indices = torch.randint(0, num_samples, (num,), device=device)
69
+
70
+ return samples[indices]
71
+
72
+
73
+ def kmeans(samples, num_clusters: int, num_iters: int = 10):
74
+ dim, dtype = samples.shape[-1], samples.dtype
75
+
76
+ means = sample_vectors(samples, num_clusters)
77
+
78
+ for _ in range(num_iters):
79
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
80
+ dists = -(diffs**2).sum(dim=-1)
81
+
82
+ buckets = dists.max(dim=-1).indices
83
+ bins = torch.bincount(buckets, minlength=num_clusters)
84
+ zero_mask = bins == 0
85
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
86
+
87
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
88
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
89
+ new_means = new_means / bins_min_clamped[..., None]
90
+
91
+ means = torch.where(zero_mask[..., None], means, new_means)
92
+
93
+ return means, bins
94
+
95
+
96
+ class EuclideanCodebook(nn.Module):
97
+ """Codebook with Euclidean distance.
98
+ Args:
99
+ dim (int): Dimension.
100
+ codebook_size (int): Codebook size.
101
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
102
+ If set to true, run the k-means algorithm on the first training batch and use
103
+ the learned centroids as initialization.
104
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
105
+ decay (float): Decay for exponential moving average over the codebooks.
106
+ epsilon (float): Epsilon value for numerical stability.
107
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
108
+ that have an exponential moving average cluster size less than the specified threshold with
109
+ randomly selected vector from the current batch.
110
+ """
111
+
112
+ def __init__(
113
+ self,
114
+ dim: int,
115
+ codebook_size: int,
116
+ kmeans_init: int = False,
117
+ kmeans_iters: int = 10,
118
+ decay: float = 0.99,
119
+ epsilon: float = 1e-5,
120
+ threshold_ema_dead_code: int = 2,
121
+ ):
122
+ super().__init__()
123
+ self.decay = decay
124
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
125
+ embed = init_fn(codebook_size, dim)
126
+
127
+ self.codebook_size = codebook_size
128
+
129
+ self.kmeans_iters = kmeans_iters
130
+ self.epsilon = epsilon
131
+ self.threshold_ema_dead_code = threshold_ema_dead_code
132
+
133
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
134
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
135
+ self.register_buffer("embed", embed)
136
+ self.register_buffer("embed_avg", embed.clone())
137
+
138
+ @torch.jit.ignore
139
+ def init_embed_(self, data):
140
+ if self.inited:
141
+ return
142
+
143
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
144
+ self.embed.data.copy_(embed)
145
+ self.embed_avg.data.copy_(embed.clone())
146
+ self.cluster_size.data.copy_(cluster_size)
147
+ self.inited.data.copy_(torch.Tensor([True]))
148
+ # Make sure all buffers across workers are in sync after initialization
149
+ broadcast_tensors(self.buffers())
150
+
151
+ def replace_(self, samples, mask):
152
+ modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
153
+ self.embed.data.copy_(modified_codebook)
154
+
155
+ def expire_codes_(self, batch_samples):
156
+ if self.threshold_ema_dead_code == 0:
157
+ return
158
+
159
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
160
+ if not torch.any(expired_codes):
161
+ return
162
+
163
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
164
+ self.replace_(batch_samples, mask=expired_codes)
165
+ broadcast_tensors(self.buffers())
166
+
167
+ def preprocess(self, x):
168
+ x = rearrange(x, "... d -> (...) d")
169
+ return x
170
+
171
+ def quantize(self, x):
172
+ embed = self.embed.t()
173
+ dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
174
+ embed_ind = dist.max(dim=-1).indices
175
+ return embed_ind
176
+
177
+ def postprocess_emb(self, embed_ind, shape):
178
+ return embed_ind.view(*shape[:-1])
179
+
180
+ def dequantize(self, embed_ind):
181
+ quantize = F.embedding(embed_ind, self.embed) # get embedding based on index
182
+ return quantize
183
+
184
+ def encode(self, x):
185
+ shape = x.shape
186
+ # pre-process
187
+ x = self.preprocess(x)
188
+ # quantize
189
+ embed_ind = self.quantize(x) # get index based on Euclidean distance
190
+ # post-process
191
+ embed_ind = self.postprocess_emb(embed_ind, shape)
192
+ return embed_ind
193
+
194
+ def decode(self, embed_ind):
195
+ quantize = self.dequantize(embed_ind)
196
+ return quantize
197
+
198
+ def forward(self, x):
199
+ shape, dtype = x.shape, x.dtype
200
+ x = self.preprocess(x)
201
+
202
+ self.init_embed_(x)
203
+
204
+ embed_ind = self.quantize(x)
205
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
206
+ embed_ind = self.postprocess_emb(embed_ind, shape)
207
+ quantize = self.dequantize(embed_ind)
208
+
209
+ if self.training:
210
+ # We do the expiry of code at that point as buffers are in sync
211
+ # and all the workers will take the same decision.
212
+ self.expire_codes_(x)
213
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
214
+ embed_sum = x.t() @ embed_onehot
215
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
216
+ cluster_size = (
217
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
218
+ )
219
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
220
+ self.embed.data.copy_(embed_normalized)
221
+
222
+ return quantize, embed_ind
223
+
224
+
225
+ class VectorQuantization(nn.Module):
226
+ """Vector quantization implementation.
227
+ Currently supports only euclidean distance.
228
+ Args:
229
+ dim (int): Dimension
230
+ codebook_size (int): Codebook size
231
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
232
+ decay (float): Decay for exponential moving average over the codebooks.
233
+ epsilon (float): Epsilon value for numerical stability.
234
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
235
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
236
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
237
+ that have an exponential moving average cluster size less than the specified threshold with
238
+ randomly selected vector from the current batch.
239
+ commitment_weight (float): Weight for commitment loss.
240
+ """
241
+
242
+ def __init__(
243
+ self,
244
+ dim: int,
245
+ codebook_size: int,
246
+ codebook_dim: tp.Optional[int] = None,
247
+ decay: float = 0.99,
248
+ epsilon: float = 1e-5,
249
+ kmeans_init: bool = True,
250
+ kmeans_iters: int = 50,
251
+ threshold_ema_dead_code: int = 2,
252
+ commitment_weight: float = 1.0,
253
+ ):
254
+ super().__init__()
255
+ _codebook_dim: int = default(codebook_dim, dim)
256
+
257
+ requires_projection = _codebook_dim != dim
258
+ self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
259
+ self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
260
+
261
+ self.epsilon = epsilon
262
+ self.commitment_weight = commitment_weight
263
+
264
+ self._codebook = EuclideanCodebook(
265
+ dim=_codebook_dim,
266
+ codebook_size=codebook_size,
267
+ kmeans_init=kmeans_init,
268
+ kmeans_iters=kmeans_iters,
269
+ decay=decay,
270
+ epsilon=epsilon,
271
+ threshold_ema_dead_code=threshold_ema_dead_code,
272
+ )
273
+ self.codebook_size = codebook_size
274
+
275
+ @property
276
+ def codebook(self):
277
+ return self._codebook.embed
278
+
279
+ def encode(self, x):
280
+ x = rearrange(x, "b d n -> b n d")
281
+ x = self.project_in(x)
282
+ embed_in = self._codebook.encode(x)
283
+ return embed_in
284
+
285
+ def decode(self, embed_ind):
286
+ quantize = self._codebook.decode(embed_ind)
287
+ quantize = self.project_out(quantize)
288
+ quantize = rearrange(quantize, "b n d -> b d n")
289
+ return quantize
290
+
291
+ def forward(self, x):
292
+ device = x.device
293
+ x = rearrange(x, "b d n -> b n d")
294
+ x = self.project_in(x)
295
+
296
+ quantize, embed_ind = self._codebook(x)
297
+
298
+ if self.training:
299
+ quantize = x + (quantize - x).detach()
300
+
301
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
302
+
303
+ if self.training:
304
+ if self.commitment_weight > 0:
305
+ commit_loss = F.mse_loss(quantize.detach(), x)
306
+ loss = loss + commit_loss * self.commitment_weight
307
+
308
+ quantize = self.project_out(quantize)
309
+ quantize = rearrange(quantize, "b n d -> b d n")
310
+ return quantize, embed_ind, loss
311
+
312
+
313
+ class ResidualVectorQuantization(nn.Module):
314
+ """Residual vector quantization implementation.
315
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
316
+ """
317
+
318
+ def __init__(self, *, num_quantizers, **kwargs):
319
+ super().__init__()
320
+ self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
321
+
322
+ def forward(self, x, n_q: tp.Optional[int] = None):
323
+ quantized_out = 0.0
324
+ residual = x
325
+
326
+ all_losses = []
327
+ all_indices = []
328
+
329
+ n_q = n_q or len(self.layers)
330
+
331
+ for layer in self.layers[:n_q]:
332
+ quantized, indices, loss = layer(residual)
333
+ residual = residual - quantized
334
+ quantized_out = quantized_out + quantized
335
+
336
+ all_indices.append(indices)
337
+ all_losses.append(loss)
338
+
339
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
340
+ return quantized_out, out_indices, out_losses
341
+
342
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
343
+ residual = x
344
+ all_indices = []
345
+ n_q = n_q or len(self.layers)
346
+ for layer in self.layers[:n_q]:
347
+ indices = layer.encode(residual)
348
+ quantized = layer.decode(indices)
349
+ residual = residual - quantized
350
+ all_indices.append(indices)
351
+ out_indices = torch.stack(all_indices)
352
+ return out_indices
353
+
354
+ def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
355
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
356
+ for i, indices in enumerate(q_indices):
357
+ layer = self.layers[i]
358
+ quantized = layer.decode(indices)
359
+ quantized_out = quantized_out + quantized
360
+ return quantized_out
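# --- Hedged usage sketch (commentary, not part of the committed core_vq.py) ---
# Single-process use of the EMA-codebook residual VQ above, assuming the classes are in
# scope. In eval mode no k-means init, EMA update, or dead-code expiry runs, so each
# stage is a plain nearest-neighbour lookup on the residual.
import torch

rvq = ResidualVectorQuantization(num_quantizers=4, dim=128, codebook_size=256, kmeans_init=False).eval()
x = torch.randn(2, 128, 50)                  # [B, D, T]
quantized, indices, losses = rvq(x, n_q=4)   # indices: [n_q, B, T]
codes = rvq.encode(x, n_q=4)                 # same layout as `indices`
recon = rvq.decode(codes)                    # [B, D, T], sum of per-stage lookups
print(indices.shape, recon.shape)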
higgs_audio/audio_processing/quantization/core_vq_lsx_version.py ADDED
@@ -0,0 +1,431 @@
1
+ # Copyright (c)
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # This implementation is inspired from
6
+ # https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py and
7
+ # https://github.com/clementchadebec/benchmark_VAE/blob/dfa0dcf6c79172df5d27769c09c860c42008baaa/src/pythae/models/vq_vae/vq_vae_utils.py#L81
8
+ #
9
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
10
+ # All rights reserved.
11
+ #
12
+ # This source code is licensed under the license found in the
13
+ # LICENSE file in the root directory of this source tree.
14
+ #
15
+ # This implementation is inspired from
16
+ # https://github.com/lucidrains/vector-quantize-pytorch
17
+ # which is released under MIT License. Hereafter, the original license:
18
+ # MIT License
19
+ #
20
+ # Copyright (c) 2020 Phil Wang
21
+ #
22
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
23
+ # of this software and associated documentation files (the "Software"), to deal
24
+ # in the Software without restriction, including without limitation the rights
25
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
26
+ # copies of the Software, and to permit persons to whom the Software is
27
+ # furnished to do so, subject to the following conditions:
28
+ #
29
+ # The above copyright notice and this permission notice shall be included in all
30
+ # copies or substantial portions of the Software.
31
+ #
32
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
33
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
34
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
35
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
36
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
37
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
38
+ # SOFTWARE.
39
+
40
+ """Core vector quantization implementation."""
41
+
42
+ import typing as tp
43
+
44
+ from einops import rearrange
45
+ import torch
46
+ from torch import nn
47
+ import torch.nn.functional as F
48
+ import torch.distributed as dist
49
+
50
+ from .distrib import broadcast_tensors, is_distributed
51
+ from .ddp_utils import SyncFunction
52
+
53
+
54
+ def default(val: tp.Any, d: tp.Any) -> tp.Any:
55
+ return val if val is not None else d
56
+
57
+
58
+ def ema_inplace(moving_avg, new, decay: float):
59
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
60
+
61
+
62
+ def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
63
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
64
+
65
+
66
+ def uniform_init(*shape: int):
67
+ t = torch.empty(shape)
68
+ nn.init.kaiming_uniform_(t)
69
+ return t
70
+
71
+
72
+ def sample_vectors(samples, num: int):
73
+ num_samples, device = samples.shape[0], samples.device
74
+
75
+ if num_samples >= num:
76
+ indices = torch.randperm(num_samples, device=device)[:num]
77
+ else:
78
+ indices = torch.randint(0, num_samples, (num,), device=device)
79
+
80
+ return samples[indices]
81
+
82
+
83
+ def kmeans(
84
+ samples,
85
+ num_clusters: int,
86
+ num_iters: int = 10,
87
+ frames_to_use: int = 10_000,
88
+ batch_size: int = 64,
89
+ ):
90
+ """
91
+ Memory-efficient K-means clustering.
92
+ Args:
93
+ samples (tensor): shape [N, D]
94
+ num_clusters (int): number of centroids.
95
+ num_iters (int): number of iterations.
96
+ frames_to_use (int): subsample size from total samples.
97
+ batch_size (int): batch size used in distance computation.
98
+ Returns:
99
+ means: [num_clusters, D]
100
+ bins: [num_clusters] (number of points per cluster)
101
+ """
102
+ N, D = samples.shape
103
+ dtype, device = samples.dtype, samples.device
104
+
105
+ if frames_to_use < N:
106
+ indices = torch.randperm(N, device=device)[:frames_to_use]
107
+ samples = samples[indices]
108
+
109
+ means = sample_vectors(samples, num_clusters)
110
+
111
+ for _ in range(num_iters):
112
+ # Store cluster assignments
113
+ all_assignments = []
114
+
115
+ for i in range(0, samples.shape[0], batch_size):
116
+ batch = samples[i : i + batch_size] # [B, D]
117
+ dists = torch.cdist(batch, means, p=2) # [B, C]
118
+ assignments = dists.argmin(dim=1) # [B]
119
+ all_assignments.append(assignments)
120
+
121
+ buckets = torch.cat(all_assignments, dim=0) # [N]
122
+ bins = torch.bincount(buckets, minlength=num_clusters)
123
+ zero_mask = bins == 0
124
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
125
+
126
+ # Compute new means
127
+ new_means = torch.zeros_like(means)
128
+ for i in range(num_clusters):
129
+ mask = buckets == i
130
+ if mask.any():
131
+ new_means[i] = samples[mask].mean(dim=0)
132
+
133
+ means = torch.where(zero_mask[:, None], means, new_means)
134
+
135
+ return means, bins
136
+
137
+
138
+ class EuclideanCodebook(nn.Module):
139
+ """Codebook with Euclidean distance.
140
+ Args:
141
+ dim (int): Dimension.
142
+ codebook_size (int): Codebook size.
143
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
144
+ If set to true, run the k-means algorithm on the first training batch and use
145
+ the learned centroids as initialization.
146
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
147
+ decay (float): Decay for exponential moving average over the codebooks.
148
+ epsilon (float): Epsilon value for numerical stability.
149
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
150
+ that have an exponential moving average cluster size less than the specified threshold with
151
+ randomly selected vector from the current batch.
152
+ """
153
+
154
+ def __init__(
155
+ self,
156
+ dim: int,
157
+ codebook_size: int,
158
+ kmeans_init: bool = False,
159
+ kmeans_iters: int = 10,
160
+ decay: float = 0.99,
161
+ epsilon: float = 1e-5,
162
+ threshold_ema_dead_code: int = 2,
163
+ ):
164
+ super().__init__()
165
+ self.decay = decay
166
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
167
+ embed = init_fn(codebook_size, dim)
168
+
169
+ self.codebook_size = codebook_size
170
+
171
+ self.kmeans_iters = kmeans_iters
172
+ self.epsilon = epsilon
173
+ self.threshold_ema_dead_code = threshold_ema_dead_code
174
+
175
+ # Flag variable to indicate whether the codebook is initialized
176
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
177
+ # Running EMA cluster size/count: N_i^t in eq. (6) in vqvae paper
178
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
179
+ # Codebook
180
+ self.register_buffer("embed", embed)
181
+ # EMA codebook: eq. (7) in vqvae paper
182
+ self.register_buffer("embed_avg", embed.clone())
183
+
184
+ @torch.jit.ignore
185
+ def init_embed_(self, data):
186
+ """Initialize codebook.
187
+ Args:
188
+ data (tensor): [B * T, D].
189
+ """
190
+ if self.inited:
191
+ return
192
+
193
+ ## NOTE (snippet added by Songxiang Liu): gather data from all gpus
194
+ if dist.is_available() and dist.is_initialized():
195
+ # [B * T * world_size, D]
196
+ data = SyncFunction.apply(data)
197
+
198
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
199
+ self.embed.data.copy_(embed)
200
+ self.embed_avg.data.copy_(embed.clone())
201
+ self.cluster_size.data.copy_(cluster_size)
202
+ self.inited.data.copy_(torch.Tensor([True]))
203
+ # Make sure all buffers across workers are in sync after initialization
204
+ broadcast_tensors(self.buffers())
205
+
206
+ def replace_(self, samples, mask):
207
+ modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
208
+ self.embed.data.copy_(modified_codebook)
209
+
210
+ def expire_codes_(self, batch_samples):
211
+ if self.threshold_ema_dead_code == 0:
212
+ return
213
+
214
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
215
+ if not torch.any(expired_codes):
216
+ return
217
+
218
+ ## NOTE (snippet added by Songxiang Liu): gather data from all gpus
219
+ if is_distributed():
220
+ # [B * T * world_size, D]
221
+ batch_samples = SyncFunction.apply(batch_samples)
222
+
223
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
224
+ self.replace_(batch_samples, mask=expired_codes)
225
+ broadcast_tensors(self.buffers())
226
+
227
+ def preprocess(self, x):
228
+ x = rearrange(x, "... d -> (...) d")
229
+ return x
230
+
231
+ def quantize(self, x):
232
+ embed = self.embed.t()
233
+ dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
234
+ embed_ind = dist.max(dim=-1).indices
235
+ return embed_ind
236
+
237
+ def postprocess_emb(self, embed_ind, shape):
238
+ return embed_ind.view(*shape[:-1])
239
+
240
+ def dequantize(self, embed_ind):
241
+ quantize = F.embedding(embed_ind, self.embed)
242
+ return quantize
243
+
244
+ def encode(self, x):
245
+ shape = x.shape
246
+ # pre-process
247
+ x = self.preprocess(x) # [B, T, D] -> [B*T, D]
248
+ # quantize
249
+ embed_ind = self.quantize(x)
250
+ # post-process
251
+ embed_ind = self.postprocess_emb(embed_ind, shape)
252
+ return embed_ind
253
+
254
+ def decode(self, embed_ind):
255
+ quantize = self.dequantize(embed_ind)
256
+ return quantize
257
+
258
+ def forward(self, x):
259
+ # shape: [B, T, D]
260
+ shape, dtype = x.shape, x.dtype
261
+ x = self.preprocess(x) # [B, T, D] -> [B*T, D]
262
+
263
+ # Initialize codebook
264
+ self.init_embed_(x)
265
+
266
+ embed_ind = self.quantize(x) # [B*T,]
267
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) # [B*T, cb-size]
268
+ embed_ind = self.postprocess_emb(embed_ind, shape) # [B, T]
269
+ quantize = self.dequantize(embed_ind) # [B, T, D]
270
+
271
+ if self.training:
272
+ ### Update codebook by EMA
273
+ embed_onehot_sum = embed_onehot.sum(0) # [cb-size,]
274
+ embed_sum = x.t() @ embed_onehot # [D, cb-size]
275
+ if is_distributed():
276
+ dist.all_reduce(embed_onehot_sum)
277
+ dist.all_reduce(embed_sum)
278
+ # Update ema cluster count N_i^t, eq. (6) in vqvae paper
279
+ self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
280
+ # Update ema embed: eq. (7) in vqvae paper
281
+ self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay)
282
+ # apply laplace smoothing
283
+ n = self.cluster_size.sum()
284
+ cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n
285
+ # Update ema embed: eq. (8) in vqvae paper
286
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
287
+ self.embed.data.copy_(embed_normalized)
288
+
289
+ # We do the expiry of code at that point as buffers are in sync
290
+ # and all the workers will take the same decision.
291
+ self.expire_codes_(x)
292
+
293
+ return quantize, embed_ind
294
+
295
+
296
+ class VectorQuantization(nn.Module):
297
+ """Vector quantization implementation.
298
+ Currently supports only euclidean distance.
299
+ Args:
300
+ dim (int): Dimension
301
+ codebook_size (int): Codebook size
302
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
303
+ decay (float): Decay for exponential moving average over the codebooks.
304
+ epsilon (float): Epsilon value for numerical stability.
305
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
306
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
307
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
308
+ that have an exponential moving average cluster size less than the specified threshold with
309
+ randomly selected vector from the current batch.
310
+ commitment_weight (float): Weight for commitment loss.
311
+ """
312
+
313
+ def __init__(
314
+ self,
315
+ dim: int,
316
+ codebook_size: int,
317
+ codebook_dim: tp.Optional[int] = None,
318
+ decay: float = 0.99,
319
+ epsilon: float = 1e-5,
320
+ kmeans_init: bool = True,
321
+ kmeans_iters: int = 50,
322
+ threshold_ema_dead_code: int = 2,
323
+ commitment_weight: float = 1.0,
324
+ ):
325
+ super().__init__()
326
+ _codebook_dim: int = default(codebook_dim, dim)
327
+
328
+ requires_projection = _codebook_dim != dim
329
+ self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
330
+ self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
331
+
332
+ self.epsilon = epsilon
333
+ self.commitment_weight = commitment_weight
334
+
335
+ self._codebook = EuclideanCodebook(
336
+ dim=_codebook_dim,
337
+ codebook_size=codebook_size,
338
+ kmeans_init=kmeans_init,
339
+ kmeans_iters=kmeans_iters,
340
+ decay=decay,
341
+ epsilon=epsilon,
342
+ threshold_ema_dead_code=threshold_ema_dead_code,
343
+ )
344
+ self.codebook_size = codebook_size
345
+
346
+ @property
347
+ def codebook(self):
348
+ return self._codebook.embed
349
+
350
+ def encode(self, x):
351
+ x = rearrange(x, "b d n -> b n d")
352
+ x = self.project_in(x)
353
+ embed_in = self._codebook.encode(x)
354
+ return embed_in
355
+
356
+ def decode(self, embed_ind):
357
+ quantize = self._codebook.decode(embed_ind)
358
+ quantize = self.project_out(quantize)
359
+ quantize = rearrange(quantize, "b n d -> b d n")
360
+ return quantize
361
+
362
+ def forward(self, x):
363
+ device = x.device
364
+ x = x.transpose(1, 2).contiguous() # [b d n] -> [b n d]
365
+ x = self.project_in(x)
366
+
367
+ quantize, embed_ind = self._codebook(x)
368
+
369
+ if self.training:
370
+ quantize = x + (quantize - x).detach()
371
+
372
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
373
+
374
+ if self.training:
375
+ if self.commitment_weight > 0:
376
+ commit_loss = F.mse_loss(quantize.detach(), x)
377
+ loss = loss + commit_loss * self.commitment_weight
378
+
379
+ quantize = self.project_out(quantize)
380
+ quantize = quantize.transpose(1, 2).contiguous() # [b n d] -> [b d n]
381
+ return quantize, embed_ind, loss
382
+
383
+
384
+ class ResidualVectorQuantization(nn.Module):
385
+ """Residual vector quantization implementation.
386
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
387
+ """
388
+
389
+ def __init__(self, *, num_quantizers, **kwargs):
390
+ super().__init__()
391
+ self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
392
+
393
+ def forward(self, x, n_q: tp.Optional[int] = None):
394
+ quantized_out = 0.0
395
+ residual = x
396
+
397
+ all_losses = []
398
+ all_indices = []
399
+
400
+ n_q = n_q or len(self.layers)
401
+
402
+ for layer in self.layers[:n_q]:
403
+ quantized, indices, loss = layer(residual)
404
+ residual = residual - quantized
405
+ quantized_out = quantized_out + quantized
406
+
407
+ all_indices.append(indices)
408
+ all_losses.append(loss)
409
+
410
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
411
+ return quantized_out, out_indices, out_losses
412
+
413
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
414
+ residual = x
415
+ all_indices = []
416
+ n_q = n_q or len(self.layers)
417
+ for layer in self.layers[:n_q]:
418
+ indices = layer.encode(residual)
419
+ quantized = layer.decode(indices)
420
+ residual = residual - quantized
421
+ all_indices.append(indices)
422
+ out_indices = torch.stack(all_indices)
423
+ return out_indices
424
+
425
+ def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
426
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
427
+ for i, indices in enumerate(q_indices):
428
+ layer = self.layers[i]
429
+ quantized = layer.decode(indices)
430
+ quantized_out = quantized_out + quantized
431
+ return quantized_out
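
A quick usage note on the module above: the codebooks are learned with EMA updates (eqs. 6-8 of the VQ-VAE paper) rather than gradients, and ResidualVectorQuantization stacks several quantizers so each one codes the residual left by the previous ones. Below is a minimal, single-process sketch of how it can be exercised; the import path is inferred from the file location in this commit, and the [batch, dim, frames] layout follows VectorQuantization.forward.

import torch
from higgs_audio.audio_processing.quantization.core_vq_lsx_version import ResidualVectorQuantization

rvq = ResidualVectorQuantization(num_quantizers=4, dim=256, codebook_size=1024)
x = torch.randn(2, 256, 50)            # [batch, dim, frames]
quantized, indices, losses = rvq(x)    # quantized: [2, 256, 50]; one indices/loss entry per quantizer
codes = rvq.encode(x)                  # [num_quantizers, batch, frames]
recon = rvq.decode(codes)              # sum of per-layer dequantized residuals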
higgs_audio/audio_processing/quantization/ddp_utils.py ADDED
@@ -0,0 +1,197 @@
1
+ import logging.config  # provides logging.config.dictConfig used in get_logger (also binds `logging`)
2
+ import random
3
+ import subprocess
4
+ from datetime import datetime
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.distributed as dist
9
+ from torch.nn.parallel import DistributedDataParallel
10
+ from torch.nn.parallel.distributed import _find_tensors
11
+ import torch.optim
12
+ import torch.utils.data
13
+ from packaging import version
14
+ from omegaconf import OmegaConf
15
+
16
+
17
+ def set_random_seed(seed):
18
+ random.seed(seed)
19
+ np.random.seed(seed)
20
+ torch.manual_seed(seed)
21
+ torch.cuda.manual_seed_all(seed)
22
+
23
+
24
+ def is_logging_process():
25
+ return not dist.is_initialized() or dist.get_rank() == 0
26
+
27
+
28
+ def get_logger(cfg, name=None):
29
+ # log_file_path is used when unit testing
30
+ if is_logging_process():
31
+ logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_config, resolve=True))
32
+ return logging.getLogger(name)
33
+
34
+
35
+ # from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
36
+ class SyncFunction(torch.autograd.Function):
37
+ @staticmethod
38
+ # @torch.no_grad()
39
+ def forward(ctx, tensor):
40
+ ctx.batch_size = tensor.shape[0]
41
+
42
+ gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
43
+
44
+ torch.distributed.all_gather(gathered_tensor, tensor)
45
+ gathered_tensor = torch.cat(gathered_tensor, 0)
46
+
47
+ return gathered_tensor
48
+
49
+ @staticmethod
50
+ def backward(ctx, grad_output):
51
+ grad_input = grad_output.clone()
52
+ torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
53
+
54
+ idx_from = torch.distributed.get_rank() * ctx.batch_size
55
+ idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
56
+ return grad_input[idx_from:idx_to]
57
+
58
+
59
+ def get_timestamp():
60
+ return datetime.now().strftime("%y%m%d-%H%M%S")
61
+
62
+
63
+ def get_commit_hash():
64
+ message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
65
+ return message.strip().decode("utf-8")
66
+
67
+
68
+ class DDP(DistributedDataParallel):
69
+ """
70
+ Override the DistributedDataParallel forward call (Lightning-style) so it dispatches to the module's training_step, test_step, or validation_step.
71
+ """
72
+
73
+ def forward(self, *inputs, **kwargs): # pragma: no cover
74
+ if version.parse(torch.__version__[:6]) < version.parse("1.11"):
75
+ self._sync_params()
76
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
77
+ assert len(self.device_ids) == 1
78
+ if self.module.training:
79
+ output = self.module.training_step(*inputs[0], **kwargs[0])
80
+ elif self.module.testing:
81
+ output = self.module.test_step(*inputs[0], **kwargs[0])
82
+ else:
83
+ output = self.module.validation_step(*inputs[0], **kwargs[0])
84
+ if torch.is_grad_enabled():
85
+ # We'll return the output object verbatim since it is a freeform
86
+ # object. We need to find any tensors in this object, though,
87
+ # because we need to figure out which parameters were used during
88
+ # this forward pass, to ensure we short circuit reduction for any
89
+ # unused parameters. Only if `find_unused_parameters` is set.
90
+ if self.find_unused_parameters:
91
+ self.reducer.prepare_for_backward(list(_find_tensors(output)))
92
+ else:
93
+ self.reducer.prepare_for_backward([])
94
+ else:
95
+ from torch.nn.parallel.distributed import (
96
+ logging,
97
+ Join,
98
+ _DDPSink,
99
+ _tree_flatten_with_rref,
100
+ _tree_unflatten_with_rref,
101
+ )
102
+
103
+ with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
104
+ if torch.is_grad_enabled() and self.require_backward_grad_sync:
105
+ self.logger.set_runtime_stats_and_log()
106
+ self.num_iterations += 1
107
+ self.reducer.prepare_for_forward()
108
+
109
+ # Notify the join context that this process has not joined, if
110
+ # needed
111
+ work = Join.notify_join_context(self)
112
+ if work:
113
+ self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size)
114
+
115
+ # Calling _rebuild_buckets before forward computation,
116
+ # It may allocate new buckets before deallocating old buckets
117
+ # inside _rebuild_buckets. To save peak memory usage,
118
+ # call _rebuild_buckets before the peak memory usage increases
119
+ # during forward computation.
120
+ # This should be called only once during whole training period.
121
+ if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
122
+ logging.info("Reducer buckets have been rebuilt in this iteration.")
123
+ self._has_rebuilt_buckets = True
124
+
125
+ # sync params according to location (before/after forward) user
126
+ # specified as part of hook, if hook was specified.
127
+ buffer_hook_registered = hasattr(self, "buffer_hook")
128
+ if self._check_sync_bufs_pre_fwd():
129
+ self._sync_buffers()
130
+
131
+ if self._join_config.enable:
132
+ # Notify joined ranks whether they should sync in backwards pass or not.
133
+ self._check_global_requires_backward_grad_sync(is_joined_rank=False)
134
+
135
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
136
+ if self.module.training:
137
+ output = self.module.training_step(*inputs[0], **kwargs[0])
138
+ elif self.module.testing:
139
+ output = self.module.test_step(*inputs[0], **kwargs[0])
140
+ else:
141
+ output = self.module.validation_step(*inputs[0], **kwargs[0])
142
+
143
+ # sync params according to location (before/after forward) user
144
+ # specified as part of hook, if hook was specified.
145
+ if self._check_sync_bufs_post_fwd():
146
+ self._sync_buffers()
147
+
148
+ if torch.is_grad_enabled() and self.require_backward_grad_sync:
149
+ self.require_forward_param_sync = True
150
+ # We'll return the output object verbatim since it is a freeform
151
+ # object. We need to find any tensors in this object, though,
152
+ # because we need to figure out which parameters were used during
153
+ # this forward pass, to ensure we short circuit reduction for any
154
+ # unused parameters. Only if `find_unused_parameters` is set.
155
+ if self.find_unused_parameters and not self.static_graph:
156
+ # Do not need to populate this for static graph.
157
+ self.reducer.prepare_for_backward(list(_find_tensors(output)))
158
+ else:
159
+ self.reducer.prepare_for_backward([])
160
+ else:
161
+ self.require_forward_param_sync = False
162
+
163
+ # TODO: DDPSink is currently enabled for unused parameter detection and
164
+ # static graph training for first iteration.
165
+ if (self.find_unused_parameters and not self.static_graph) or (
166
+ self.static_graph and self.num_iterations == 1
167
+ ):
168
+ state_dict = {
169
+ "static_graph": self.static_graph,
170
+ "num_iterations": self.num_iterations,
171
+ }
172
+
173
+ output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output)
174
+ output_placeholders = [None for _ in range(len(output_tensor_list))]
175
+ # Do not touch tensors that have no grad_fn, which can cause issues
176
+ # such as https://github.com/pytorch/pytorch/issues/60733
177
+ for i, output in enumerate(output_tensor_list):
178
+ if torch.is_tensor(output) and output.grad_fn is None:
179
+ output_placeholders[i] = output
180
+
181
+ # When find_unused_parameters=True, makes tensors which require grad
182
+ # run through the DDPSink backward pass. When not all outputs are
183
+ # used in loss, this makes those corresponding tensors receive
184
+ # undefined gradient which the reducer then handles to ensure
185
+ # param.grad field is not touched and we don't error out.
186
+ passthrough_tensor_list = _DDPSink.apply(
187
+ self.reducer,
188
+ state_dict,
189
+ *output_tensor_list,
190
+ )
191
+ for i in range(len(output_placeholders)):
192
+ if output_placeholders[i] is None:
193
+ output_placeholders[i] = passthrough_tensor_list[i]
194
+
195
+ # Reconstruct output data structure.
196
+ output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref)
197
+ return output
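
SyncFunction above is what allows the k-means codebook initialization (in core_vq_lsx_version.py) to see features from every rank: forward all-gathers the tensor across workers, and backward all-reduces the incoming gradient and returns only this rank's slice. A hedged sketch of the guarded pattern used there (the helper name below is hypothetical):

import torch
import torch.distributed as dist
from higgs_audio.audio_processing.quantization.ddp_utils import SyncFunction

def gather_across_ranks(feats: torch.Tensor) -> torch.Tensor:
    # Only route through SyncFunction when a process group exists; otherwise
    # fall back to the local batch (single-GPU / CPU runs).
    if dist.is_available() and dist.is_initialized():
        return SyncFunction.apply(feats)  # [world_size * N, D]
    return feats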
higgs_audio/audio_processing/quantization/distrib.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Torch distributed utilities."""
8
+
9
+ import typing as tp
10
+
11
+ import torch
12
+
13
+
14
+ def rank():
15
+ if torch.distributed.is_initialized():
16
+ return torch.distributed.get_rank()
17
+ else:
18
+ return 0
19
+
20
+
21
+ def world_size():
22
+ if torch.distributed.is_initialized():
23
+ return torch.distributed.get_world_size()
24
+ else:
25
+ return 1
26
+
27
+
28
+ def is_distributed():
29
+ return world_size() > 1
30
+
31
+
32
+ def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
33
+ if is_distributed():
34
+ return torch.distributed.all_reduce(tensor, op)
35
+
36
+
37
+ def _is_complex_or_float(tensor):
38
+ return torch.is_floating_point(tensor) or torch.is_complex(tensor)
39
+
40
+
41
+ def _check_number_of_params(params: tp.List[torch.Tensor]):
42
+ # utility function to check that the number of params in all workers is the same,
43
+ # and thus avoid a deadlock with distributed all reduce.
44
+ if not is_distributed() or not params:
45
+ return
46
+ # print('params[0].device ', params[0].device)
47
+ tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
48
+ all_reduce(tensor)
49
+ if tensor.item() != len(params) * world_size():
50
+ # If not all the workers have the same number, for at least one of them,
51
+ # this inequality will be verified.
52
+ raise RuntimeError(
53
+ f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one."
54
+ )
55
+
56
+
57
+ def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
58
+ """Broadcast the tensors from the given parameters to all workers.
59
+ This can be used to ensure that all workers have the same model to start with.
60
+ """
61
+ if not is_distributed():
62
+ return
63
+ tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
64
+ _check_number_of_params(tensors)
65
+ handles = []
66
+ for tensor in tensors:
67
+ handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
68
+ handles.append(handle)
69
+ for handle in handles:
70
+ handle.wait()
71
+
72
+
73
+ def sync_buffer(buffers, average=True):
74
+ """
75
+ Sync buffers across workers. If average is False, broadcast from rank 0 instead of averaging.
76
+ """
77
+ if not is_distributed():
78
+ return
79
+ handles = []
80
+ for buffer in buffers:
81
+ if torch.is_floating_point(buffer.data):
82
+ if average:
83
+ handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
84
+ else:
85
+ handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
86
+ handles.append((buffer, handle))
87
+ for buffer, handle in handles:
88
+ handle.wait()
89
+ if average:
90
+ buffer.data /= world_size()
91
+
92
+
93
+ def sync_grad(params):
94
+ """
95
+ Simpler alternative to DistributedDataParallel, that doesn't rely
96
+ on any black magic. For simple models it can also be as fast.
97
+ Just call this on your model parameters after the call to backward!
98
+ """
99
+ if not is_distributed():
100
+ return
101
+ handles = []
102
+ for p in params:
103
+ if p.grad is not None:
104
+ handle = torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
105
+ handles.append((p, handle))
106
+ for p, handle in handles:
107
+ handle.wait()
108
+ p.grad.data /= world_size()
109
+
110
+
111
+ def average_metrics(metrics: tp.Dict[str, float], count=1.0):
112
+ """Average a dictionary of metrics across all workers, using the optional
113
+ `count` as unnormalized weight.
114
+ """
115
+ if not is_distributed():
116
+ return metrics
117
+ keys, values = zip(*metrics.items())
118
+ device = "cuda" if torch.cuda.is_available() else "cpu"
119
+ tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
120
+ tensor *= count
121
+ all_reduce(tensor)
122
+ averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
123
+ return dict(zip(keys, averaged))
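
A worked example of the count-weighted averaging in average_metrics (numbers chosen for illustration; the import path is assumed from this commit): each worker contributes value * count together with the count itself, the all-reduce sums both, and the final division recovers the weighted mean.

from higgs_audio.audio_processing.quantization.distrib import average_metrics

# Worker A: {"loss": 2.0}, count=10 -> contributes [2.0 * 10, 10] = [20, 10]
# Worker B: {"loss": 4.0}, count=30 -> contributes [4.0 * 30, 30] = [120, 30]
# all_reduce(SUM) -> [140, 40] -> averaged loss = 140 / 40 = 3.5
print(average_metrics({"loss": 2.0}, count=10))  # not distributed here, so the input is returned unchanged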
higgs_audio/audio_processing/quantization/vq.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Residual vector quantizer implementation."""
8
+
9
+ from dataclasses import dataclass, field
10
+ import math
11
+ import typing as tp
12
+
13
+ import torch
14
+ from torch import nn
15
+
16
+ # from .core_vq import ResidualVectorQuantization
17
+ from .core_vq_lsx_version import ResidualVectorQuantization
18
+
19
+
20
+ @dataclass
21
+ class QuantizedResult:
22
+ quantized: torch.Tensor
23
+ codes: torch.Tensor
24
+ bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
25
+ penalty: tp.Optional[torch.Tensor] = None
26
+ metrics: dict = field(default_factory=dict)
27
+
28
+
29
+ class ResidualVectorQuantizer(nn.Module):
30
+ """Residual Vector Quantizer.
31
+ Args:
32
+ dimension (int): Dimension of the codebooks.
33
+ n_q (int): Number of residual vector quantizers used.
34
+ bins (int): Codebook size.
35
+ decay (float): Decay for exponential moving average over the codebooks.
36
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
37
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
38
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
39
+ that have an exponential moving average cluster size less than the specified threshold with
40
+ randomly selected vector from the current batch.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ dimension: int = 256,
46
+ codebook_dim: int = None,
47
+ n_q: int = 8,
48
+ bins: int = 1024,
49
+ decay: float = 0.99,
50
+ kmeans_init: bool = True,
51
+ kmeans_iters: int = 50,
52
+ threshold_ema_dead_code: int = 2,
53
+ ):
54
+ super().__init__()
55
+ self.n_q = n_q
56
+ self.dimension = dimension
57
+ self.codebook_dim = codebook_dim
58
+ self.bins = bins
59
+ self.decay = decay
60
+ self.kmeans_init = kmeans_init
61
+ self.kmeans_iters = kmeans_iters
62
+ self.threshold_ema_dead_code = threshold_ema_dead_code
63
+ self.vq = ResidualVectorQuantization(
64
+ dim=self.dimension,
65
+ codebook_dim=self.codebook_dim,
66
+ codebook_size=self.bins,
67
+ num_quantizers=self.n_q,
68
+ decay=self.decay,
69
+ kmeans_init=self.kmeans_init,
70
+ kmeans_iters=self.kmeans_iters,
71
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
72
+ )
73
+
74
+ def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None): # -> QuantizedResult:
75
+ """Residual vector quantization on the given input tensor.
76
+ Args:
77
+ x (torch.Tensor): Input tensor.
78
+ sample_rate (int): Sample rate of the input tensor.
79
+ bandwidth (float): Target bandwidth.
80
+ Returns:
81
+ QuantizedResult:
82
+ The quantized (or approximately quantized) representation with
83
+ the associated bandwidth and any penalty term for the loss.
84
+ """
85
+ bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
86
+ n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
87
+ quantized, codes, commit_loss = self.vq(x, n_q=n_q)
88
+ bw = torch.tensor(n_q * bw_per_q).to(x)
89
+ return quantized, codes, bw, torch.mean(commit_loss)
90
+ # return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
91
+
92
+ def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int:
93
+ """Return n_q based on specified target bandwidth."""
94
+ bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
95
+ n_q = self.n_q
96
+ if bandwidth and bandwidth > 0.0:
97
+ n_q = int(max(1, math.floor(bandwidth / bw_per_q)))
98
+ return n_q
99
+
100
+ def get_bandwidth_per_quantizer(self, sample_rate: int):
101
+ """Return bandwidth per quantizer for a given input sample rate."""
102
+ return math.log2(self.bins) * sample_rate / 1000
103
+
104
+ def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor:
105
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
106
+ The RVQ encode method sets the appropriate number of quantizer to use
107
+ and returns indices for each quantizer.
108
+ """
109
+ n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
110
+ codes = self.vq.encode(x, n_q=n_q)
111
+ return codes
112
+
113
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
114
+ """Decode the given codes to the quantized representation."""
115
+ quantized = self.vq.decode(codes)
116
+ return quantized
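
The bandwidth bookkeeping above is plain arithmetic: each quantizer costs log2(bins) bits per frame, so at the latent frame rate passed in as `sample_rate` the per-quantizer bandwidth is log2(bins) * frame_rate / 1000 kbps, and a target bandwidth maps to a quantizer count by flooring. A short worked example with assumed numbers (1024-entry codebooks at 50 frames per second):

import math

bins, frame_rate = 1024, 50                      # 10 bits per frame per quantizer
bw_per_q = math.log2(bins) * frame_rate / 1000   # 0.5 kbps per quantizer
n_q = int(max(1, math.floor(2.0 / bw_per_q)))    # a 2.0 kbps target -> 4 quantizers
print(bw_per_q, n_q)                             # 0.5 4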
higgs_audio/audio_processing/semantic_module.py ADDED
@@ -0,0 +1,310 @@
1
+ # Based on code from: https://github.com/zhenye234/xcodec
2
+ # Licensed under MIT License
3
+ # Modifications by BosonAI
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ class Conv1d1x1(nn.Conv1d):
10
+ """1x1 Conv1d."""
11
+
12
+ def __init__(self, in_channels, out_channels, bias=True):
13
+ super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, bias=bias)
14
+
15
+
16
+ class Conv1d(nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_channels: int,
20
+ out_channels: int,
21
+ kernel_size: int,
22
+ stride: int = 1,
23
+ padding: int = -1,
24
+ dilation: int = 1,
25
+ groups: int = 1,
26
+ bias: bool = True,
27
+ ):
28
+ super().__init__()
29
+ self.in_channels = in_channels
30
+ self.out_channels = out_channels
31
+ self.kernel_size = kernel_size
32
+ if padding < 0:
33
+ padding = (kernel_size - 1) // 2 * dilation
34
+ self.dilation = dilation
35
+ self.conv = nn.Conv1d(
36
+ in_channels=in_channels,
37
+ out_channels=out_channels,
38
+ kernel_size=kernel_size,
39
+ stride=stride,
40
+ padding=padding,
41
+ dilation=dilation,
42
+ groups=groups,
43
+ bias=bias,
44
+ )
45
+
46
+ def forward(self, x):
47
+ """
48
+ Args:
49
+ x (Tensor): Float tensor variable with the shape (B, C, T).
50
+ Returns:
51
+ Tensor: Float tensor variable with the shape (B, C, T).
52
+ """
53
+ x = self.conv(x)
54
+ return x
55
+
56
+
57
+ class ResidualUnit(nn.Module):
58
+ def __init__(
59
+ self,
60
+ in_channels: int,
61
+ out_channels: int,
62
+ kernel_size=3,
63
+ dilation=1,
64
+ bias=False,
65
+ nonlinear_activation="ELU",
66
+ nonlinear_activation_params={},
67
+ ):
68
+ super().__init__()
69
+ self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
70
+ self.conv1 = Conv1d(
71
+ in_channels=in_channels,
72
+ out_channels=out_channels,
73
+ kernel_size=kernel_size,
74
+ stride=1,
75
+ dilation=dilation,
76
+ bias=bias,
77
+ )
78
+ self.conv2 = Conv1d1x1(out_channels, out_channels, bias)
79
+
80
+ def forward(self, x):
81
+ y = self.conv1(self.activation(x))
82
+ y = self.conv2(self.activation(y))
83
+ return x + y
84
+
85
+
86
+ class ConvTranspose1d(nn.Module):
87
+ def __init__(
88
+ self,
89
+ in_channels: int,
90
+ out_channels: int,
91
+ kernel_size: int,
92
+ stride: int,
93
+ padding=-1,
94
+ output_padding=-1,
95
+ groups=1,
96
+ bias=True,
97
+ ):
98
+ super().__init__()
99
+ if padding < 0:
100
+ padding = (stride + 1) // 2
101
+ if output_padding < 0:
102
+ output_padding = 1 if stride % 2 else 0
103
+ self.deconv = nn.ConvTranspose1d(
104
+ in_channels=in_channels,
105
+ out_channels=out_channels,
106
+ kernel_size=kernel_size,
107
+ stride=stride,
108
+ padding=padding,
109
+ output_padding=output_padding,
110
+ groups=groups,
111
+ bias=bias,
112
+ )
113
+
114
+ def forward(self, x):
115
+ """
116
+ Args:
117
+ x (Tensor): Float tensor variable with the shape (B, C, T).
118
+ Returns:
119
+ Tensor: Float tensor variable with the shape (B, C', T').
120
+ """
121
+ x = self.deconv(x)
122
+ return x
123
+
124
+
125
+ class EncoderBlock(nn.Module):
126
+ def __init__(
127
+ self,
128
+ in_channels: int,
129
+ out_channels: int,
130
+ stride: int,
131
+ dilations=(1, 1),
132
+ unit_kernel_size=3,
133
+ bias=True,
134
+ ):
135
+ super().__init__()
136
+ self.res_units = torch.nn.ModuleList()
137
+ for dilation in dilations:
138
+ self.res_units += [
139
+ ResidualUnit(
140
+ in_channels,
141
+ in_channels,
142
+ kernel_size=unit_kernel_size,
143
+ dilation=dilation,
144
+ )
145
+ ]
146
+ self.num_res = len(self.res_units)
147
+
148
+ self.conv = Conv1d(
149
+ in_channels=in_channels,
150
+ out_channels=out_channels,
151
+ kernel_size=3 if stride == 1 else (2 * stride), # special case: stride=1, do not use kernel=2
152
+ stride=stride,
153
+ bias=bias,
154
+ )
155
+
156
+ def forward(self, x):
157
+ for idx in range(self.num_res):
158
+ x = self.res_units[idx](x)
159
+ x = self.conv(x)
160
+ return x
161
+
162
+
163
+ class Encoder(nn.Module):
164
+ def __init__(
165
+ self,
166
+ input_channels: int,
167
+ encode_channels: int,
168
+ channel_ratios=(1, 1),
169
+ strides=(1, 1),
170
+ kernel_size=3,
171
+ bias=True,
172
+ block_dilations=(1, 1),
173
+ unit_kernel_size=3,
174
+ ):
175
+ super().__init__()
176
+ assert len(channel_ratios) == len(strides)
177
+
178
+ self.conv = Conv1d(
179
+ in_channels=input_channels,
180
+ out_channels=encode_channels,
181
+ kernel_size=kernel_size,
182
+ stride=1,
183
+ bias=False,
184
+ )
185
+ self.conv_blocks = torch.nn.ModuleList()
186
+ in_channels = encode_channels
187
+ for idx, stride in enumerate(strides):
188
+ out_channels = int(encode_channels * channel_ratios[idx]) # could be float
189
+ self.conv_blocks += [
190
+ EncoderBlock(
191
+ in_channels,
192
+ out_channels,
193
+ stride,
194
+ dilations=block_dilations,
195
+ unit_kernel_size=unit_kernel_size,
196
+ bias=bias,
197
+ )
198
+ ]
199
+ in_channels = out_channels
200
+ self.num_blocks = len(self.conv_blocks)
201
+ self.out_channels = out_channels
202
+
203
+ def forward(self, x):
204
+ x = self.conv(x)
205
+ for i in range(self.num_blocks):
206
+ x = self.conv_blocks[i](x)
207
+ return x
208
+
209
+
210
+ class DecoderBlock(nn.Module):
211
+ """Decoder block (no up-sampling)"""
212
+
213
+ def __init__(
214
+ self,
215
+ in_channels: int,
216
+ out_channels: int,
217
+ stride: int,
218
+ dilations=(1, 1),
219
+ unit_kernel_size=3,
220
+ bias=True,
221
+ ):
222
+ super().__init__()
223
+
224
+ if stride == 1:
225
+ self.conv = Conv1d(
226
+ in_channels=in_channels,
227
+ out_channels=out_channels,
228
+ kernel_size=3, # fix kernel=3 when stride=1 for unchanged shape
229
+ stride=stride,
230
+ bias=bias,
231
+ )
232
+ else:
233
+ self.conv = ConvTranspose1d(
234
+ in_channels=in_channels,
235
+ out_channels=out_channels,
236
+ kernel_size=(2 * stride),
237
+ stride=stride,
238
+ bias=bias,
239
+ )
240
+
241
+ self.res_units = torch.nn.ModuleList()
242
+ for idx, dilation in enumerate(dilations):
243
+ self.res_units += [
244
+ ResidualUnit(
245
+ out_channels,
246
+ out_channels,
247
+ kernel_size=unit_kernel_size,
248
+ dilation=dilation,
249
+ )
250
+ ]
251
+ self.num_res = len(self.res_units)
252
+
253
+ def forward(self, x):
254
+ x = self.conv(x)
255
+ for idx in range(self.num_res):
256
+ x = self.res_units[idx](x)
257
+ return x
258
+
259
+
260
+ class Decoder(nn.Module):
261
+ def __init__(
262
+ self,
263
+ code_dim: int,
264
+ output_channels: int,
265
+ decode_channels: int,
266
+ channel_ratios=(1, 1),
267
+ strides=(1, 1),
268
+ kernel_size=3,
269
+ bias=True,
270
+ block_dilations=(1, 1),
271
+ unit_kernel_size=3,
272
+ ):
273
+ super().__init__()
274
+ assert len(channel_ratios) == len(strides)
275
+
276
+ self.conv1 = Conv1d(
277
+ in_channels=code_dim,
278
+ out_channels=int(decode_channels * channel_ratios[0]),
279
+ kernel_size=kernel_size,
280
+ stride=1,
281
+ bias=False,
282
+ )
283
+
284
+ self.conv_blocks = torch.nn.ModuleList()
285
+ for idx, stride in enumerate(strides):
286
+ in_channels = int(decode_channels * channel_ratios[idx])
287
+ if idx < (len(channel_ratios) - 1):
288
+ out_channels = int(decode_channels * channel_ratios[idx + 1])
289
+ else:
290
+ out_channels = decode_channels
291
+ self.conv_blocks += [
292
+ DecoderBlock(
293
+ in_channels,
294
+ out_channels,
295
+ stride,
296
+ dilations=block_dilations,
297
+ unit_kernel_size=unit_kernel_size,
298
+ bias=bias,
299
+ )
300
+ ]
301
+ self.num_blocks = len(self.conv_blocks)
302
+
303
+ self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False)
304
+
305
+ def forward(self, z):
306
+ x = self.conv1(z)
307
+ for i in range(self.num_blocks):
308
+ x = self.conv_blocks[i](x)
309
+ x = self.conv2(x)
310
+ return x
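
As a shape sanity check for the Encoder/Decoder pair above: each EncoderBlock downsamples time by its stride, so the total downsampling factor is the product of `strides`, while channel width scales with `channel_ratios`; DecoderBlock mirrors this with transposed convolutions. A minimal sketch with hypothetical hyper-parameters, assuming the module path from this commit:

import torch
from higgs_audio.audio_processing.semantic_module import Encoder

enc = Encoder(input_channels=1, encode_channels=32, channel_ratios=(2, 4), strides=(2, 2))
x = torch.randn(1, 1, 160)   # [batch, channels, samples]
print(enc(x).shape)          # torch.Size([1, 128, 40]): 32 * 4 channels, 4x shorter in time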
higgs_audio/constants.py ADDED
@@ -0,0 +1,3 @@
1
+ AUDIO_IN_TOKEN = "<|AUDIO|>"
2
+ AUDIO_OUT_TOKEN = "<|AUDIO_OUT|>"
3
+ EOS_TOKEN = "<|end_of_text|>"
higgs_audio/data_collator/__init__.py ADDED
File without changes
higgs_audio/data_collator/higgs_audio_collator.py ADDED
@@ -0,0 +1,583 @@
1
+ import librosa
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import math
5
+ import numpy as np
6
+ from typing import List, Tuple, Dict
7
+
8
+ from dataclasses import dataclass
9
+ from typing import List, Optional
10
+ from transformers.models.whisper.processing_whisper import WhisperProcessor
11
+
12
+ from ..dataset.chatml_dataset import ChatMLDatasetSample, RankedChatMLDatasetSampleTuple
13
+ from ..model.utils import build_delay_pattern_mask
14
+
15
+
16
+ def _ceil_to_nearest(n, round_to):
17
+ return (n + round_to - 1) // round_to * round_to
18
+
19
+
20
+ @dataclass
21
+ class HiggsAudioBatchInput:
22
+ input_ids: torch.LongTensor # shape (bsz, seq_len).
23
+ attention_mask: torch.Tensor # shape (bsz, seq_len).
24
+ audio_features: Optional[torch.Tensor] # shape (num_audio_in, feature_dim, max_mel_seq_len).
25
+ audio_feature_attention_mask: Optional[torch.Tensor] # shape (num_audio_in, max_mel_seq_len).
26
+ audio_out_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_out_total_length)
27
+ audio_out_ids_start: Optional[torch.LongTensor] # shape (num_audio_out,)
28
+ # audio_out_ids_start_group_loc has the same length as audio_out_ids_start. It is used to recover each audio segment's group (sample) location in the batch.
29
+ # Currently, we concatenate audio segments along dim 0 to handle variable audio segment lengths. However, in the alignment stage, we need the location information.
30
+ # For example,
31
+ # audio_out_ids_start = [0, 2, 4, 8], where the first two audio segments come from the same sample in the batch and the other two come from two different samples.
32
+ # For this batch of 3 samples, the group locations are:
33
+ # audio_out_ids_start_group_loc = [0, 0, 1, 2]
34
+ audio_out_ids_start_group_loc: Optional[
35
+ torch.LongTensor
36
+ ] # shape (num_audio_out,), specifies each audio segment's group (sample) location in the batch
37
+ audio_in_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_in_total_length)
38
+ audio_in_ids_start: Optional[torch.LongTensor] # shape (num_audio_in,)
39
+ label_ids: Optional[torch.LongTensor] # shape (bsz, seq_len)
40
+ label_audio_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_out_total_length)
41
+ reward: Optional[float] = None
42
+
43
+
44
+ class HiggsAudioSampleCollator:
45
+ """Sample collator for Higgs-Audio model.
46
+
47
+ Args:
48
+ whisper_processor (WhisperProcessor): The whisper processor.
49
+ audio_in_token_id (int): The token id for audio-in.
50
+ audio_out_token_id (int): The token id for audio-out.
51
+ pad_token_id (int): The token id for padding.
52
+ audio_stream_bos_id (int): The token id for audio-stream beginning of sentence.
53
+ audio_stream_eos_id (int): The token id for audio-stream end of sentence.
54
+ round_to (int): The round-to value.
55
+ pad_left (bool): Whether to pad left.
56
+ return_audio_in_tokens (bool): Whether to return audio-in tokens.
57
+ use_delay_pattern (bool): Whether to use delay pattern.
58
+ disable_audio_codes_transform (bool): If True, skip adding the audio-stream bos/eos tokens to audio codes.
59
+ chunk_size_seconds (int): The chunk size in seconds.
60
+ add_new_bos_eos_for_long_chunk (bool): Whether to add new bos and eos tokens for long chunks.
61
+ mask_audio_out_token_label (bool): Whether to always mask the label associated with <|AUDIO_OUT|> token. Since we will always have `<|AUDIO_OUT|>` after `<|audio_bos|>`, we can safely mask <|AUDIO_OUT|>.
62
+
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ whisper_processor: WhisperProcessor,
68
+ audio_in_token_id,
69
+ audio_out_token_id,
70
+ pad_token_id,
71
+ audio_stream_bos_id,
72
+ audio_stream_eos_id,
73
+ round_to=8,
74
+ pad_left=False,
75
+ encode_whisper_embed=True,
76
+ return_audio_in_tokens=True,
77
+ audio_num_codebooks=None,
78
+ use_delay_pattern=False,
79
+ disable_audio_codes_transform=False,
80
+ chunk_size_seconds=30, # Maximum duration for each chunk
81
+ add_new_bos_eos_for_long_chunk=True,
82
+ mask_audio_out_token_label=True,
83
+ ):
84
+ self.whisper_processor = whisper_processor
85
+ self.round_to = round_to
86
+ self.pad_left = pad_left
87
+ self.audio_in_token_id = audio_in_token_id
88
+ self.audio_out_token_id = audio_out_token_id
89
+ self.audio_stream_bos_id = audio_stream_bos_id
90
+ self.audio_stream_eos_id = audio_stream_eos_id
91
+ self.pad_token_id = pad_token_id
92
+ self.encode_whisper_embed = encode_whisper_embed
93
+ self.return_audio_in_tokens = return_audio_in_tokens
94
+ self.audio_num_codebooks = audio_num_codebooks
95
+ self.use_delay_pattern = use_delay_pattern
96
+ if encode_whisper_embed:
97
+ self.chunk_size_seconds = chunk_size_seconds
98
+ self.chunk_size_samples = int(chunk_size_seconds * whisper_processor.feature_extractor.sampling_rate)
99
+ else:
100
+ self.chunk_size_seconds = None
101
+ self.chunk_size_samples = None
102
+ self.disable_audio_codes_transform = disable_audio_codes_transform
103
+ self.add_new_bos_eos_for_long_chunk = add_new_bos_eos_for_long_chunk
104
+ self.mask_audio_out_token_label = mask_audio_out_token_label
105
+
106
+ def _process_and_duplicate_audio_tokens(
107
+ self,
108
+ input_ids: torch.Tensor,
109
+ audio_idx: int,
110
+ wv: torch.Tensor,
111
+ sr: int,
112
+ labels: Optional[torch.Tensor] = None,
113
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
114
+ """Process long audio and duplicate corresponding audio tokens.
115
+
116
+ Args:
117
+ input_ids: Input token ids
118
+ audio_idx: Index of the audio token in the sequence
119
+ wv: Audio waveform
120
+ sr: Sample rate
121
+ labels: Optional label ids to be duplicated alongside input ids
122
+
123
+ Returns:
124
+ Tuple of:
125
+ - New input ids with duplicated audio tokens
126
+ - New label ids (if labels were provided) or None
127
+ - Number of chunks created
128
+ """
129
+ # Calculate number of chunks needed
130
+ total_samples = len(wv)
131
+ num_chunks = math.ceil(total_samples / self.chunk_size_samples)
132
+
133
+ if num_chunks <= 1:
134
+ return input_ids, labels, 1
135
+
136
+ # Get the three tokens: <|audio_bos|><|AUDIO|><|audio_eos|>
137
+ audio_token_seq = input_ids[audio_idx - 1 : audio_idx + 2]
138
+ # Duplicate sequence for each chunk
139
+ duplicated_sequence = audio_token_seq.repeat(num_chunks)
140
+
141
+ # Create new input_ids with duplicated tokens
142
+ new_input_ids = torch.cat(
143
+ [
144
+ input_ids[: audio_idx - 1],
145
+ duplicated_sequence,
146
+ input_ids[audio_idx + 2 :],
147
+ ]
148
+ )
149
+
150
+ # If labels are provided, duplicate them as well
151
+ new_labels = None
152
+ if labels is not None:
153
+ label_seq = labels[audio_idx - 1 : audio_idx + 2]
154
+ duplicated_labels = label_seq.repeat(num_chunks)
155
+ new_labels = torch.cat([labels[: audio_idx - 1], duplicated_labels, labels[audio_idx + 2 :]])
156
+
157
+ return new_input_ids, new_labels, num_chunks
158
+
159
+ def __call__(self, batch: List[ChatMLDatasetSample]):
160
+ """Collate the input data with support for long audio processing."""
161
+
162
+ label_ids = None
163
+ label_audio_ids = None
164
+ if all([ele.label_ids is None for ele in batch]):
165
+ return_labels = False
166
+ else:
167
+ return_labels = True
168
+
169
+ if self.encode_whisper_embed:
170
+ # Process each sample in the batch to handle long audio
171
+ # TODO(?) The implementation here can be optimized.
172
+ processed_batch = []
173
+ for i in range(len(batch)):
174
+ sample = batch[i]
175
+ audio_in_mask = sample.input_ids == self.audio_in_token_id
176
+ audio_in_indices = torch.where(audio_in_mask)[0]
177
+ audio_out_mask = sample.input_ids == self.audio_out_token_id
178
+
179
+ # Process each audio token and duplicate if needed
180
+ modified_input_ids = sample.input_ids
181
+ modified_labels = sample.label_ids if return_labels else None
182
+ modified_waveforms_concat = []
183
+ modified_waveforms_start = []
184
+ modified_sample_rate = []
185
+ offset = 0 # Track position changes from duplicating tokens
186
+ curr_wv_offset = 0
187
+
188
+ # Process input audio tokens
189
+ for idx, audio_idx in enumerate(audio_in_indices):
190
+ # Get the audio for this token
191
+ wv, sr = sample.get_wv(idx) # Use idx since we want the original audio index
192
+ if sr != self.whisper_processor.feature_extractor.sampling_rate:
193
+ resampled_wv = librosa.resample(
194
+ wv.cpu().numpy(),
195
+ orig_sr=sr,
196
+ target_sr=self.whisper_processor.feature_extractor.sampling_rate,
197
+ )
198
+ else:
199
+ resampled_wv = wv.cpu().numpy()
200
+ wv = torch.tensor(resampled_wv, device=wv.device)
201
+ sr = self.whisper_processor.feature_extractor.sampling_rate
202
+
203
+ # Process and duplicate tokens if necessary
204
+ token_pos = audio_idx + offset
205
+ modified_input_ids, modified_labels, num_chunks = self._process_and_duplicate_audio_tokens(
206
+ modified_input_ids, token_pos, wv, sr, modified_labels
207
+ )
208
+
209
+ # Update audio data
210
+ for chunk_idx in range(num_chunks):
211
+ chunk_start = chunk_idx * self.chunk_size_samples
212
+ chunk_end = min((chunk_idx + 1) * self.chunk_size_samples, len(wv))
213
+ chunk_wv = wv[chunk_start:chunk_end]
214
+ modified_waveforms_concat.append(chunk_wv)
215
+ modified_waveforms_start.append(curr_wv_offset)
216
+ curr_wv_offset += len(chunk_wv)
217
+ modified_sample_rate.append(sr)
218
+
219
+ # Update offset for next iteration
220
+ offset += (num_chunks - 1) * 3 # Each new chunk adds 3 more tokens
221
+
222
+ # Create new sample with modified tokens and audio data
223
+ processed_sample = ChatMLDatasetSample(
224
+ input_ids=modified_input_ids,
225
+ label_ids=modified_labels if return_labels else sample.label_ids,
226
+ audio_ids_concat=sample.audio_ids_concat,
227
+ audio_ids_start=sample.audio_ids_start,
228
+ audio_waveforms_concat=torch.cat(modified_waveforms_concat)
229
+ if modified_waveforms_concat
230
+ else sample.audio_waveforms_concat,
231
+ audio_waveforms_start=torch.tensor(modified_waveforms_start, dtype=torch.long)
232
+ if modified_waveforms_start
233
+ else sample.audio_waveforms_start,
234
+ audio_sample_rate=torch.tensor(modified_sample_rate)
235
+ if modified_sample_rate
236
+ else sample.audio_sample_rate,
237
+ audio_speaker_indices=torch.tensor([]),
238
+ # FIXME(sxjscience): The logic here is not correct for audio_label_ids_concat.
239
+ audio_label_ids_concat=sample.audio_label_ids_concat,
240
+ )
241
+ # audio_in_chunk_len = len(torch.where(modified_input_ids == self.audio_in_token_id)[0])
242
+ # assert audio_in_chunk_len == processed_sample.num_audios(), f"Mismatch: audio_in_chunk_len={audio_in_chunk_len}, processed_sample.num_audios()={processed_sample.num_audios()}"
243
+ processed_batch.append(processed_sample)
244
+ else:
245
+ processed_batch = batch
246
+
247
+ # Get the max sequence length based on processed batch
248
+ max_seq_length = _ceil_to_nearest(max([len(sample.input_ids) for sample in processed_batch]), self.round_to)
249
+
250
+ # Get the ids for audio-in and audio-out for each batch
251
+ audio_in_wv_l = []
252
+ audio_in_ids_l = []
253
+ audio_out_ids_l = []
254
+ audio_out_ids_group_loc_l = []
255
+ audio_in_label_ids_l = None
256
+ audio_out_label_ids_l = None
257
+ reward_l = []
258
+
259
+ if return_labels:
260
+ audio_out_no_train_flag = [] # Whether the audio-out data should be trained on or not.
261
+
262
+ # Process the audio inputs and outputs
263
+ for i in range(len(processed_batch)):
264
+ audio_in_mask = processed_batch[i].input_ids == self.audio_in_token_id
265
+ audio_out_mask = processed_batch[i].input_ids == self.audio_out_token_id
266
+ audio_ids = torch.ones_like(processed_batch[i].input_ids)
267
+ audio_ids[audio_in_mask ^ audio_out_mask] = torch.cumsum(audio_ids[audio_in_mask ^ audio_out_mask], 0) - 1
268
+ audio_in_ids = audio_ids[audio_in_mask]
269
+ audio_out_ids = audio_ids[audio_out_mask]
270
+
271
+ if return_labels:
272
+ audio_out_no_train_flag.append(processed_batch[i].label_ids[audio_out_mask] < 0)
273
+ if self.mask_audio_out_token_label:
274
+ processed_batch[i].label_ids[audio_out_mask] = -100
275
+
276
+ # Process audio inputs
277
+ if self.return_audio_in_tokens:
278
+ audio_in_ids_l.extend(
279
+ [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_in_ids]
280
+ )
281
+ if processed_batch[i].audio_label_ids_concat is not None:
282
+ if audio_in_label_ids_l is None:
283
+ audio_in_label_ids_l = []
284
+ audio_in_label_ids_l.extend(
285
+ [
286
+ processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
287
+ for idx in audio_in_ids
288
+ ]
289
+ )
290
+
291
+ audio_out_ids_l.extend(
292
+ [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids]
293
+ )
294
+ audio_out_ids_group_loc_l.append(i)
295
+ if processed_batch[i].reward is not None:
296
+ reward_l.append(processed_batch[i].reward)
297
+
298
+ if processed_batch[i].audio_label_ids_concat is not None:
299
+ if audio_out_label_ids_l is None:
300
+ audio_out_label_ids_l = []
301
+ audio_out_label_ids_l.extend(
302
+ [
303
+ processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
304
+ for idx in audio_out_ids
305
+ ]
306
+ )
307
+
308
+ if self.encode_whisper_embed:
309
+ for idx in audio_in_ids:
310
+ wv, sr = processed_batch[i].get_wv(idx)
311
+ resampled_wv = wv.cpu().numpy()
312
+ # Split long audio into chunks
313
+ total_samples = len(resampled_wv)
314
+ for chunk_start in range(0, total_samples, self.chunk_size_samples):
315
+ chunk_end = min(chunk_start + self.chunk_size_samples, total_samples)
316
+ chunk = resampled_wv[chunk_start:chunk_end]
317
+ audio_in_wv_l.append(chunk)
318
+ # assert len(audio_in_wv_l) == processed_batch[i].num_audios(), \
319
+ # f"Assertion failed: Mismatch in number of audios. " \
320
+ # f"Expected {processed_batch[i].num_audios()}, but got {len(audio_in_wv_l)} at index {i}."
321
+
322
+ if return_labels:
323
+ audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0)
324
+
325
+ # Process all audio features
326
+ if len(audio_in_wv_l) > 0:
327
+ feature_ret = self.whisper_processor.feature_extractor(
328
+ audio_in_wv_l,
329
+ sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
330
+ return_attention_mask=True,
331
+ padding="max_length",
332
+ )
333
+ audio_features = torch.from_numpy(feature_ret["input_features"])
334
+ audio_feature_attention_mask = torch.from_numpy(feature_ret["attention_mask"])
335
+ else:
336
+ if self.encode_whisper_embed:
337
+ audio_features = torch.zeros(
338
+ (
339
+ 0,
340
+ self.whisper_processor.feature_extractor.feature_size,
341
+ self.whisper_processor.feature_extractor.nb_max_frames,
342
+ ),
343
+ dtype=torch.float32,
344
+ )
345
+ audio_feature_attention_mask = torch.zeros(
346
+ (0, self.whisper_processor.feature_extractor.nb_max_frames),
347
+ dtype=torch.int32,
348
+ )
349
+ else:
350
+ audio_features = None
351
+ audio_feature_attention_mask = None
352
+
353
+ # Process audio input tokens
354
+ if len(audio_in_ids_l) > 0:
355
+ # Append audio-stream-bos and eos tokens
356
+ new_audio_in_ids_l = []
357
+ for ele in audio_in_ids_l:
358
+ if self.disable_audio_codes_transform:
359
+ # Do not add audio-stream-bos or eos tokens.
360
+ # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
361
+ audio_codes = ele
362
+ else:
363
+ audio_codes = torch.cat(
364
+ [
365
+ torch.full(
366
+ (ele.shape[0], 1),
367
+ self.audio_stream_bos_id,
368
+ dtype=torch.long,
369
+ ),
370
+ ele,
371
+ torch.full(
372
+ (ele.shape[0], 1),
373
+ self.audio_stream_eos_id,
374
+ dtype=torch.long,
375
+ ),
376
+ ],
377
+ dim=1,
378
+ )
379
+ if self.use_delay_pattern:
380
+ audio_codes = build_delay_pattern_mask(
381
+ audio_codes.unsqueeze(0),
382
+ bos_token_id=self.audio_stream_bos_id,
383
+ pad_token_id=self.audio_stream_eos_id,
384
+ )[0].squeeze(0)
385
+ new_audio_in_ids_l.append(audio_codes)
386
+ audio_in_ids = torch.cat(new_audio_in_ids_l, dim=1).long()
387
+ audio_in_ids_start = torch.cumsum(
388
+ torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_in_ids_l[:-1]]),
389
+ dim=0,
390
+ )
391
+ else:
392
+ audio_in_ids = torch.zeros((0, 0), dtype=torch.long)
393
+ audio_in_ids_start = torch.zeros(0, dtype=torch.long)
394
+
395
+ # Process audio output tokens
396
+ audio_out_ids_start_group_loc = None
397
+ if len(audio_out_ids_l) > 0:
398
+ new_audio_out_ids_l = []
399
+ label_audio_ids_l = []
400
+ for idx, ele in enumerate(audio_out_ids_l):
401
+ if self.disable_audio_codes_transform:
402
+ # Do not add audio-stream-bos or eos tokens.
403
+ # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
404
+ audio_codes = ele
405
+ if return_labels:
406
+ label_audio_ids = audio_out_label_ids_l[idx]
407
+ else:
408
+ audio_codes = torch.cat(
409
+ [
410
+ torch.full(
411
+ (ele.shape[0], 1),
412
+ self.audio_stream_bos_id,
413
+ dtype=torch.long,
414
+ ),
415
+ ele,
416
+ torch.full(
417
+ (ele.shape[0], 1),
418
+ self.audio_stream_eos_id,
419
+ dtype=torch.long,
420
+ ),
421
+ ],
422
+ dim=1,
423
+ )
424
+ if return_labels:
425
+ label_audio_ids = torch.cat(
426
+ [
427
+ torch.full((ele.shape[0], 1), -100, dtype=torch.long),
428
+ ele,
429
+ torch.full(
430
+ (ele.shape[0], 1),
431
+ self.audio_stream_eos_id,
432
+ dtype=torch.long,
433
+ ),
434
+ ],
435
+ dim=1,
436
+ )
437
+ if self.use_delay_pattern:
438
+ audio_codes = build_delay_pattern_mask(
439
+ audio_codes.unsqueeze(0),
440
+ bos_token_id=self.audio_stream_bos_id,
441
+ pad_token_id=self.audio_stream_eos_id,
442
+ )[0].squeeze(0)
443
+ if return_labels:
444
+ label_audio_ids = build_delay_pattern_mask(
445
+ label_audio_ids.unsqueeze(0),
446
+ bos_token_id=-100,
447
+ pad_token_id=-100,
448
+ )[0].squeeze(0)
449
+ new_audio_out_ids_l.append(audio_codes)
450
+
451
+ if return_labels:
452
+ if audio_out_no_train_flag[idx]:
453
+ label_audio_ids[:] = -100
454
+ label_audio_ids_l.append(label_audio_ids)
455
+
456
+ audio_out_ids = torch.cat(new_audio_out_ids_l, dim=1).long()
457
+ if return_labels:
458
+ label_audio_ids = torch.cat(label_audio_ids_l, dim=1).long()
459
+ audio_out_ids_start = torch.cumsum(
460
+ torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_out_ids_l[:-1]]),
461
+ dim=0,
462
+ )
463
+ audio_out_ids_start_group_loc = torch.tensor(audio_out_ids_group_loc_l, dtype=torch.long)
464
+ else:
465
+ audio_out_ids = torch.zeros((0, 0), dtype=torch.long)
466
+ audio_out_ids_start = torch.zeros(0, dtype=torch.long)
467
+ if return_labels:
468
+ label_audio_ids = torch.zeros((0, 0), dtype=torch.long)
469
+
470
+ reward = torch.tensor(reward_l, dtype=torch.float32)
471
+
472
+ # Handle padding for input ids and attention mask
473
+ if self.pad_left:
474
+ input_ids = torch.stack(
475
+ [
476
+ F.pad(
477
+ ele.input_ids,
478
+ (max_seq_length - len(ele.input_ids), 0),
479
+ value=self.pad_token_id,
480
+ )
481
+ for ele in processed_batch
482
+ ]
483
+ )
484
+ if return_labels:
485
+ label_ids = torch.stack(
486
+ [
487
+ F.pad(
488
+ ele.label_ids,
489
+ (max_seq_length - len(ele.label_ids), 0),
490
+ value=-100,
491
+ )
492
+ for ele in processed_batch
493
+ ]
494
+ )
495
+ attention_mask = torch.stack(
496
+ [
497
+ F.pad(
498
+ torch.ones_like(ele.input_ids),
499
+ (max_seq_length - len(ele.input_ids), 0),
500
+ value=0,
501
+ )
502
+ for ele in processed_batch
503
+ ]
504
+ )
505
+ else:
506
+ input_ids = torch.stack(
507
+ [
508
+ F.pad(
509
+ ele.input_ids,
510
+ (0, max_seq_length - len(ele.input_ids)),
511
+ value=self.pad_token_id,
512
+ )
513
+ for ele in processed_batch
514
+ ]
515
+ )
516
+ if return_labels:
517
+ label_ids = torch.stack(
518
+ [
519
+ F.pad(
520
+ ele.label_ids,
521
+ (0, max_seq_length - len(ele.label_ids)),
522
+ value=-100,
523
+ )
524
+ for ele in processed_batch
525
+ ]
526
+ )
527
+ attention_mask = torch.stack(
528
+ [
529
+ F.pad(
530
+ torch.ones_like(ele.input_ids),
531
+ (0, max_seq_length - len(ele.input_ids)),
532
+ value=0,
533
+ )
534
+ for ele in processed_batch
535
+ ]
536
+ )
537
+
538
+ if not self.return_audio_in_tokens:
539
+ audio_in_ids = None
540
+ audio_in_ids_start = None
541
+
542
+ # Apply audio_num_codebooks limit if specified
543
+ if self.audio_num_codebooks is not None:
544
+ if audio_in_ids is not None:
545
+ audio_in_ids = audio_in_ids[: self.audio_num_codebooks]
546
+ if audio_out_ids is not None:
547
+ audio_out_ids = audio_out_ids[: self.audio_num_codebooks]
548
+ if label_audio_ids is not None:
549
+ label_audio_ids = label_audio_ids[: self.audio_num_codebooks]
550
+
551
+ return HiggsAudioBatchInput(
552
+ input_ids=input_ids,
553
+ attention_mask=attention_mask,
554
+ audio_features=audio_features,
555
+ audio_feature_attention_mask=audio_feature_attention_mask,
556
+ audio_out_ids=audio_out_ids,
557
+ audio_out_ids_start=audio_out_ids_start,
558
+ audio_out_ids_start_group_loc=audio_out_ids_start_group_loc,
559
+ audio_in_ids=audio_in_ids,
560
+ audio_in_ids_start=audio_in_ids_start,
561
+ label_ids=label_ids,
562
+ label_audio_ids=label_audio_ids,
563
+ reward=reward,
564
+ )
565
+
566
+
567
+ class HiggsAudioDPOSamplesCollator(HiggsAudioSampleCollator):
568
+ def __init__(self, *args, **kwargs):
569
+ super().__init__(*args, **kwargs)
570
+
571
+ def __call__(self, batch: List[RankedChatMLDatasetSampleTuple]) -> HiggsAudioBatchInput:
572
+ # flatten ranked chatml samples
573
+ chosen = []
574
+ rejected = []
575
+
576
+ for sample in batch:
577
+ chosen.append(sample.max_score_sample())
578
+ rejected.append(sample.min_score_sample())
579
+
580
+ merged = chosen
581
+ merged.extend(rejected)
582
+
583
+ return super().__call__(batch=merged)
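
A hedged usage sketch of the DPO collator above: because the chosen samples are appended first and the rejected samples second, the returned batch can be split in half downstream. `ranked_tuples` and `collator_kwargs` are placeholder names, not part of this repository.

# Illustrative only; assumes `ranked_tuples` is a list of RankedChatMLDatasetSampleTuple
# and `collator_kwargs` matches the HiggsAudioSampleCollator constructor arguments.
collator = HiggsAudioDPOSamplesCollator(**collator_kwargs)
batch = collator(ranked_tuples)
n = len(ranked_tuples)
chosen_input_ids = batch.input_ids[:n]      # max-score samples
rejected_input_ids = batch.input_ids[n:]    # min-score samples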
higgs_audio/data_types.py ADDED
@@ -0,0 +1,38 @@
1
+ """Basic data types for multimodal ChatML format."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, List, Optional, Union
5
+
6
+
7
+ @dataclass
8
+ class AudioContent:
9
+ audio_url: str
10
+ # Base64 encoded audio bytes
11
+ raw_audio: Optional[str] = None
12
+ offset: Optional[float] = None
13
+ duration: Optional[float] = None
14
+ row_id: Optional[int] = None
15
+ type: str = "audio"
16
+
17
+
18
+ @dataclass
19
+ class TextContent:
20
+ text: str
21
+ type: str = "text"
22
+
23
+
24
+ @dataclass
25
+ class Message:
26
+ role: str
27
+ content: Union[str, AudioContent, TextContent, List[Union[str, AudioContent, TextContent]]]
28
+ recipient: Optional[str] = None
29
+
30
+
31
+ @dataclass
32
+ class ChatMLSample:
33
+ """Dataclass to hold multimodal ChatML data."""
34
+
35
+ messages: List[Message]
36
+ start_index: Optional[int] = None # We will mask the messages[:start_index] when finetuning the LLM.
37
+ misc: Optional[Dict] = None
38
+ speaker: Optional[str] = None
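
A minimal construction sketch for these dataclasses; the file name, text, and speaker tag below are placeholders for illustration only.

from higgs_audio.data_types import AudioContent, ChatMLSample, Message, TextContent

sample = ChatMLSample(
    messages=[
        Message(
            role="user",
            content=[
                TextContent(text="Please read this sentence aloud."),
                AudioContent(audio_url="reference_voice.wav"),  # placeholder path
            ],
        ),
        Message(role="assistant", content="Sure, here is the audio."),
    ],
    speaker="belinda",  # illustrative speaker tag
)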
higgs_audio/dataset/__init__.py ADDED
File without changes
higgs_audio/dataset/chatml_dataset.py ADDED
@@ -0,0 +1,554 @@
1
+ import dacite
2
+ import pandas as pd
3
+ import torch
4
+ import json
5
+
6
+ import numpy as np
7
+ import multiprocessing as mp
8
+
9
+ from dataclasses import dataclass, fields
10
+ from abc import ABC, abstractmethod
11
+ from typing import Union, List, Dict, Optional
12
+
13
+ from ..data_types import ChatMLSample, TextContent, AudioContent
14
+ from ..constants import AUDIO_IN_TOKEN, AUDIO_OUT_TOKEN
15
+
16
+ from loguru import logger
17
+
18
+ # Whisper processor, 30 sec -> 3000 features
19
+ # Then we downsample by a factor of 4 in the audio tower, which reduces the 3000 features to 750, i.e. 25 Hz
20
+ WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC = 25
21
+
22
+
23
+ @dataclass
24
+ class ChatMLDatasetSample:
25
+ input_ids: torch.LongTensor # Shape (seq_len,): The input text tokens.
26
+ label_ids: torch.LongTensor # Shape (seq_len,): The label ids.
27
+ audio_ids_concat: torch.LongTensor # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated.
28
+ # Here `audio_seq_len` is the length of the concatenated audio tokens.
29
+ audio_ids_start: (
30
+ torch.LongTensor
31
+ ) # Shape (num_audios,): The start index of each audio token in the concatenated audio tokens.
32
+ audio_waveforms_concat: (
33
+ torch.Tensor
34
+ ) # Shape (total_wv_length,): The concatenated audio waveforms for audio-in features.
35
+ audio_waveforms_start: (
36
+ torch.LongTensor
37
+ ) # Shape (num_audios,): The start index of each audio waveform in the concatenated audio waveforms.
38
+ audio_sample_rate: torch.Tensor # Shape (num_audios,): The sampling rate of the audio waveforms.
39
+ audio_speaker_indices: (
40
+ torch.LongTensor
41
+ ) # Shape (num_audios,) -1 means unknown speaker: The speaker indices for each audio.
42
+ audio_label_ids_concat: Optional[torch.LongTensor] = (
43
+ None # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated.
44
+ )
45
+ # Here `audio_seq_len` is the length of the concatenated audio tokens.
46
+ reward: Optional[float] = None
47
+
48
+ def num_audios(self):
49
+ return max(len(self.audio_waveforms_start), len(self.audio_ids_start))
50
+
51
+ def get_audio_codes(self, idx):
52
+ code_start = self.audio_ids_start[idx]
53
+ if idx < len(self.audio_ids_start) - 1:
54
+ code_end = self.audio_ids_start[idx + 1]
55
+ else:
56
+ code_end = self.audio_ids_concat.shape[-1]
57
+
58
+ return self.audio_ids_concat[:, code_start:code_end]
59
+
60
+ def get_audio_codes_labels(self, idx):
61
+ if self.audio_label_ids_concat is None:
62
+ return None
63
+ code_start = self.audio_ids_start[idx]
64
+ if idx < len(self.audio_ids_start) - 1:
65
+ code_end = self.audio_ids_start[idx + 1]
66
+ else:
67
+ code_end = self.audio_ids_concat.shape[-1]
68
+
69
+ return self.audio_label_ids_concat[:, code_start:code_end]
70
+
71
+ def get_wv(self, idx):
72
+ wv_start = self.audio_waveforms_start[idx]
73
+ sr = self.audio_sample_rate[idx]
74
+ if idx < len(self.audio_waveforms_start) - 1:
75
+ wv_end = self.audio_waveforms_start[idx + 1]
76
+ else:
77
+ wv_end = self.audio_waveforms_concat.shape[-1]
78
+ return self.audio_waveforms_concat[wv_start:wv_end], sr
79
+
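
To make the concatenated-storage layout of `get_audio_codes` and `get_wv` concrete, here is a small self-contained sketch using the dataclass above, with made-up shapes (two clips of 30 and 20 code frames, 2 s and 1 s of 16 kHz audio):

import torch

sample = ChatMLDatasetSample(
    input_ids=torch.tensor([1, 2, 3]),
    label_ids=torch.tensor([-100, -100, 3]),
    audio_ids_concat=torch.zeros((8, 50), dtype=torch.long),   # 8 codebooks, 30 + 20 frames
    audio_ids_start=torch.tensor([0, 30]),
    audio_waveforms_concat=torch.zeros(48000),                 # 32000 + 16000 samples
    audio_waveforms_start=torch.tensor([0, 32000]),
    audio_sample_rate=torch.tensor([16000, 16000]),
    audio_speaker_indices=torch.tensor([-1, -1]),
)
codes_1 = sample.get_audio_codes(1)   # shape (8, 20): slice [:, 30:50]
wv_0, sr_0 = sample.get_wv(0)         # first 32000 samples and sample rate 16000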
80
+ def cal_num_tokens(
81
+ self,
82
+ encode_whisper_embed: bool = True,
83
+ encode_audio_in_tokens: bool = False,
84
+ encode_audio_out_tokens: bool = True,
85
+ audio_in_token_id: int = 128015,
86
+ audio_out_token_id: int = 128016,
87
+ ) -> int:
88
+ # We first exclude <|AUDIO|> and <|AUDIO_OUT|> because we do late merging and replace those positions with the actual audio features and audio token ids.
89
+ # It's assumed that we always have audio_ids when audio_waveforms are there (but not vice-versa)
90
+ num_tokens = len(self.input_ids) - len(self.audio_ids_start)
91
+
92
+ if encode_whisper_embed and len(self.audio_waveforms_concat) > 0:
93
+ audio_lengths = torch.diff(self.audio_waveforms_start)
94
+ if len(audio_lengths):
95
+ # Sum before calling .item()
96
+ num_tokens += (
97
+ (
98
+ np.ceil(WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC * audio_lengths / self.audio_sample_rate[:-1])
99
+ ).sum()
100
+ ).item()
101
+ # add the last audio's token estimation
102
+ num_tokens += (
103
+ np.ceil(
104
+ WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC
105
+ * (self.audio_waveforms_concat.shape[0] - self.audio_waveforms_start[-1])
106
+ / self.audio_sample_rate[-1]
107
+ )
108
+ ).item()
109
+
110
+ if self.audio_ids_concat.size(1) > 0:
111
+ audio_io_ids = self.input_ids[
112
+ (self.input_ids == audio_in_token_id) | (self.input_ids == audio_out_token_id)
113
+ ]
114
+ audio_io_id_lengths = torch.concat(
115
+ [
116
+ torch.diff(self.audio_ids_start),
117
+ torch.tensor([self.audio_ids_concat.shape[-1] - self.audio_ids_start[-1]]),
118
+ ]
119
+ )
120
+ if encode_audio_in_tokens:
121
+ num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_in_token_id]).item()
122
+
123
+ if encode_audio_out_tokens:
124
+ num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_out_token_id]).item()
125
+
126
+ return int(num_tokens)
127
+
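
As a worked example of the Whisper-embedding estimate in `cal_num_tokens` above (using the 25 hidden states per second defined by WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC; the clip length is illustrative):

import math

# A hypothetical 3.2 s clip at 16 kHz -> 51200 samples
num_whisper_tokens = math.ceil(25 * 51200 / 16000)  # == 80 tokens added to the estimate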
128
+ @classmethod
129
+ def merge(
130
+ cls,
131
+ samples: List["ChatMLDatasetSample"],
132
+ eos_token_id: int,
133
+ ignore_index: int,
134
+ padding_size: Optional[int] = None,
135
+ ) -> "ChatMLDatasetSample":
136
+ """Merges a list of ChatMLDatasetSample instances, inserting eos_token_id and ignore_index between them, and adjusting offsets for audio_ids_start and audio_waveforms_start.
137
+
138
+ Args:
139
+ samples (List[ChatMLDatasetSample]): List of samples to merge.
140
+ eos_token_id (int): Token to be inserted into input_ids between samples.
141
+ ignore_index (int): Default label for padding.
142
+ padding_size (Optional[int]): If provided, pad the sequence with this many extra tokens at the end.
143
+
144
+ Returns:
145
+ ChatMLDatasetSample: Merged and potentially padded sample.
146
+ """
147
+ if not samples:
148
+ logger.fatal("The samples list is empty and cannot be merged.")
149
+ raise ValueError("The samples list is empty and cannot be merged.")
150
+
151
+ # Initialize empty lists for concatenation
152
+ input_ids_list = []
153
+ label_ids_list = []
154
+ audio_ids_concat_list = []
155
+ audio_ids_start_list = []
156
+ audio_waveforms_concat_list = []
157
+ audio_waveforms_start_list = []
158
+ audio_sample_rate_list = []
159
+ audio_speaker_indices_list = []
160
+
161
+ # Track offsets
162
+ audio_ids_offset = 0
163
+ audio_waveforms_offset = 0
164
+
165
+ for sample in samples:
166
+ # Add input_ids and label_ids with padding
167
+ if input_ids_list:
168
+ input_ids_list.append(torch.tensor([eos_token_id], dtype=torch.long))
169
+ label_ids_list.append(torch.tensor([ignore_index], dtype=torch.long))
170
+ input_ids_list.append(sample.input_ids)
171
+ label_ids_list.append(sample.label_ids)
172
+
173
+ # Add audio_ids_concat and handle empty audio ids
174
+ if sample.audio_ids_concat.size(1) > 0:
175
+ audio_ids_concat_list.append(sample.audio_ids_concat)
176
+
177
+ # Offset and add audio_ids_start
178
+ audio_ids_start_list.append(sample.audio_ids_start + audio_ids_offset)
179
+ audio_ids_offset += sample.audio_ids_concat.size(
180
+ 1
181
+ ) # (num_codebooks, seq_len): Update offset by audio_seq_len
182
+
183
+ # Add audio_waveforms_concat
184
+ if sample.audio_waveforms_concat.size(0) > 0:
185
+ # Check dimensions of the audio waveform to ensure consistency
186
+ if (
187
+ audio_waveforms_concat_list
188
+ and sample.audio_waveforms_concat.dim() != audio_waveforms_concat_list[0].dim()
189
+ ):
190
+ logger.warning(
191
+ f"Skipping audio waveform with inconsistent dimensions: expected {audio_waveforms_concat_list[0].dim()}D, got {sample.audio_waveforms_concat.dim()}D"
192
+ )
193
+ continue
194
+
195
+ audio_waveforms_concat_list.append(sample.audio_waveforms_concat)
196
+ audio_waveforms_start_list.append(sample.audio_waveforms_start + audio_waveforms_offset)
197
+ audio_waveforms_offset += sample.audio_waveforms_concat.size(0)
198
+
199
+ # Add audio_sample_rate and audio_speaker_indices
200
+ audio_sample_rate_list.append(sample.audio_sample_rate)
201
+
202
+ audio_speaker_indices_list.append(sample.audio_speaker_indices)
203
+
204
+ # Concatenate all tensors
205
+ input_ids = torch.cat(input_ids_list, dim=0)
206
+ label_ids = torch.cat(label_ids_list, dim=0)
207
+
208
+ # Apply padding if padding_size is specified
209
+ if padding_size is not None and padding_size > 0:
210
+ input_ids = torch.cat(
211
+ [
212
+ input_ids,
213
+ torch.full((padding_size,), eos_token_id, dtype=torch.long),
214
+ ],
215
+ dim=0,
216
+ )
217
+ label_ids = torch.cat(
218
+ [
219
+ label_ids,
220
+ torch.full((padding_size,), ignore_index, dtype=torch.long),
221
+ ],
222
+ dim=0,
223
+ )
224
+
225
+ # Safely concatenate audio tensors with proper error handling
226
+ try:
227
+ audio_ids_concat = torch.cat(audio_ids_concat_list, dim=1) if audio_ids_concat_list else torch.tensor([[]])
228
+ audio_ids_start = torch.cat(audio_ids_start_list, dim=0) if audio_ids_start_list else torch.tensor([])
229
+
230
+ # Check for dimensional consistency in audio waveforms
231
+ if audio_waveforms_concat_list:
232
+ dims = [t.dim() for t in audio_waveforms_concat_list]
233
+ if not all(d == dims[0] for d in dims):
234
+ # If dimensions don't match, log warning and filter out the problematic tensors
235
+ logger.warning(
236
+ f"Inconsistent dimensions in audio waveforms: {dims}. Filtering to keep only consistent ones."
237
+ )
238
+ expected_dim = max(set(dims), key=dims.count) # Most common dimension
239
+ audio_waveforms_concat_list = [t for t in audio_waveforms_concat_list if t.dim() == expected_dim]
240
+
241
+ # Recalculate audio_waveforms_start with the filtered list
242
+ if audio_waveforms_concat_list:
243
+ audio_waveforms_offset = 0
244
+ audio_waveforms_start_list = []
245
+ for waveform in audio_waveforms_concat_list:
246
+ audio_waveforms_start_list.append(torch.tensor([audio_waveforms_offset]))
247
+ audio_waveforms_offset += waveform.size(0)
248
+
249
+ audio_waveforms_concat = (
250
+ torch.cat(audio_waveforms_concat_list, dim=0) if audio_waveforms_concat_list else torch.tensor([])
251
+ )
252
+ audio_waveforms_start = (
253
+ torch.cat(audio_waveforms_start_list, dim=0) if audio_waveforms_start_list else torch.tensor([])
254
+ )
255
+ audio_sample_rate = (
256
+ torch.cat(audio_sample_rate_list, dim=0) if audio_sample_rate_list else torch.tensor([])
257
+ )
258
+ audio_speaker_indices = (
259
+ torch.cat(audio_speaker_indices_list, dim=0) if audio_speaker_indices_list else torch.tensor([])
260
+ )
261
+
262
+ except RuntimeError as e:
263
+ logger.error(f"Error during tensor concatenation: {str(e)}")
264
+ logger.warning("Falling back to empty audio tensors")
265
+ # Fall back to empty tensors
266
+ audio_ids_concat = torch.tensor([[]])
267
+ audio_ids_start = torch.tensor([])
268
+ audio_waveforms_concat = torch.tensor([])
269
+ audio_waveforms_start = torch.tensor([])
270
+ audio_sample_rate = torch.tensor([])
271
+ audio_speaker_indices = torch.tensor([])
272
+
273
+ # Create the merged sample
274
+ merged_sample = cls(
275
+ input_ids=input_ids,
276
+ label_ids=label_ids,
277
+ audio_ids_concat=audio_ids_concat,
278
+ audio_ids_start=audio_ids_start,
279
+ audio_waveforms_concat=audio_waveforms_concat,
280
+ audio_waveforms_start=audio_waveforms_start,
281
+ audio_sample_rate=audio_sample_rate,
282
+ audio_speaker_indices=audio_speaker_indices,
283
+ )
284
+
285
+ return merged_sample
286
+
287
+
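
A hedged sketch of calling `merge`; `sample_a` and `sample_b` stand for two ChatMLDatasetSample instances prepared elsewhere, and the EOS id is illustrative.

merged = ChatMLDatasetSample.merge(
    samples=[sample_a, sample_b],
    eos_token_id=128001,   # illustrative; use the tokenizer's actual EOS id
    ignore_index=-100,
    padding_size=None,
)
# merged.input_ids == cat([sample_a.input_ids, [eos_token_id], sample_b.input_ids])
# sample_b's audio_ids_start entries are shifted by sample_a.audio_ids_concat.shape[-1]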
288
+ @dataclass
289
+ class RankedChatMLDatasetSampleTuple:
290
+ samples: List[ChatMLDatasetSample]
291
+ scores: List[float]
292
+
293
+ def max_score_sample(self) -> ChatMLDatasetSample:
294
+ idx = self.scores.index(max(self.scores))
295
+ self.samples[idx].reward = self.scores[idx]
296
+ return self.samples[idx]
297
+
298
+ def min_score_sample(self) -> ChatMLDatasetSample:
299
+ idx = self.scores.index(min(self.scores))
300
+ self.samples[idx].reward = self.scores[idx]
301
+ return self.samples[idx]
302
+
303
+
304
+ @dataclass
305
+ class ChatMLDatasetStorageSample:
306
+ input_tokens: torch.LongTensor
307
+ label_tokens: torch.LongTensor
308
+ audio_bytes_cache_dir_index: int
309
+ audio_codes_cache_dir_index: int
310
+ audio_bytes_indices: torch.LongTensor
311
+ audio_codes_indices: torch.LongTensor
312
+ speaker_indices: torch.LongTensor
313
+ file_index: int
314
+ original_sample_index: int
315
+
316
+
317
+ # TODO(sxjscience): We need to revisit the logic for parsing speaker ids.
318
+ # Currently, we assume that the speaker id is stored in the "misc" field of ChatMLSample.
319
+ def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
320
+ """Preprocess the ChatML sample to get the tokens for the text part.
321
+
322
+ Args:
323
+ sample (ChatMLSample): The ChatML sample to preprocess.
324
+ tokenizer: The tokenizer to use for encoding the text.
325
+
326
+ """
327
+
328
+ try:
329
+ if not isinstance(sample, ChatMLSample):
330
+ # Handle all fields that could be NaN
331
+ if "speaker" in sample and pd.isna(sample["speaker"]):
332
+ sample["speaker"] = None
333
+ if "start_index" in sample and pd.isna(sample["start_index"]):
334
+ sample["start_index"] = None
335
+ if "content" in sample and pd.isna(sample["content"]):
336
+ sample["content"] = ""
337
+
338
+ # Convert any other potential NaN values in nested structures
339
+ def convert_nan_to_none(obj):
340
+ import numpy as np
341
+
342
+ if isinstance(obj, (pd.Series, np.ndarray)):
343
+ return obj.tolist()
344
+ elif pd.api.types.is_scalar(obj) and pd.isna(obj):
345
+ return None
346
+ elif isinstance(obj, dict):
347
+ return {k: convert_nan_to_none(v) for k, v in obj.items()}
348
+ elif isinstance(obj, (list, tuple)): # Fixed: Handle both list and tuple
349
+ return [convert_nan_to_none(item) for item in obj]
350
+ return obj
351
+
352
+ # Clean the sample data
353
+ clean_sample = convert_nan_to_none(sample)
354
+
355
+ val_keys = []
356
+ for field in fields(ChatMLSample):
357
+ if field.name in clean_sample:
358
+ val_keys.append(field.name)
359
+ clean_sample = {k: clean_sample[k] for k in val_keys}
360
+
361
+ try:
362
+ sample = dacite.from_dict(
363
+ data_class=ChatMLSample,
364
+ data=clean_sample,
365
+ config=dacite.Config(strict=True, check_types=True),
366
+ )
367
+ except Exception as e:
368
+ print(f"Failed to convert to ChatMLSample: {e}")
369
+ print(f"Clean sample: {json.dumps(clean_sample, indent=2)}")
370
+ return None, None, None, None
371
+
372
+ input_tokens = []
373
+ label_tokens = []
374
+ audio_contents = []
375
+ speaker_id = None
376
+ if sample.speaker is not None:
377
+ speaker_id = sample.speaker
378
+ elif sample.misc is not None:
379
+ if "speaker" in sample.misc:
380
+ speaker_id = sample.misc["speaker"]
381
+
382
+ total_m = len(sample.messages)
383
+ for turn_id, message in enumerate(sample.messages):
384
+ role = message.role
385
+ recipient = message.recipient
386
+ content = message.content
387
+ content_l = []
388
+
389
+ if isinstance(content, str):
390
+ content_l.append(TextContent(text=content))
391
+ elif isinstance(content, TextContent):
392
+ content_l.append(content)
393
+ elif isinstance(content, AudioContent):
394
+ content_l.append(content)
395
+ elif isinstance(content, list):
396
+ for ele in content:
397
+ if isinstance(ele, str):
398
+ content_l.append(TextContent(text=ele))
399
+ else:
400
+ content_l.append(ele)
401
+ if turn_id == 0:
402
+ prefix = f"<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>\n\n"
403
+ else:
404
+ prefix = f"<|start_header_id|>{role}<|end_header_id|>\n\n"
405
+ eot_postfix = "<|eot_id|>"
406
+ eom_postfix = "<|eom_id|>"
407
+
408
+ prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
409
+ input_tokens.extend(prefix_tokens)
410
+ label_tokens.extend([-100 for _ in prefix_tokens])
411
+
412
+ if recipient:
413
+ assert role == "assistant", "Recipient is only available for assistant role."
414
+ recipient_tokens = tokenizer.encode(f"{recipient}<|recipient|>", add_special_tokens=False)
415
+ input_tokens.extend(recipient_tokens)
416
+ label_tokens.extend(recipient_tokens)
417
+
418
+ for content in content_l:
419
+ if content.type == "text":
420
+ text_tokens = tokenizer.encode(content.text, add_special_tokens=False)
421
+ input_tokens.extend(text_tokens)
422
+ if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
423
+ label_tokens.extend(text_tokens)
424
+ else:
425
+ label_tokens.extend([-100 for _ in text_tokens])
426
+
427
+ elif content.type == "audio":
428
+ # Generate the text-part of the audio tokens
429
+ audio_contents.append(content)
430
+ if role == "user" or role == "system":
431
+ # Add the text tokens
432
+ text_tokens = tokenizer.encode(
433
+ f"<|audio_bos|><|AUDIO|><|audio_eos|>",
434
+ add_special_tokens=False,
435
+ )
436
+ input_tokens.extend(text_tokens)
437
+ label_tokens.extend([-100 for _ in text_tokens])
438
+ elif role == "assistant":
439
+ # Add the text tokens for audio-out part.
440
+ text_tokens = tokenizer.encode(
441
+ f"<|audio_out_bos|><|AUDIO_OUT|><|audio_eos|>",
442
+ add_special_tokens=False,
443
+ )
444
+ input_tokens.extend(text_tokens)
445
+ if sample.start_index is None or turn_id >= sample.start_index:
446
+ label_tokens.extend(text_tokens)
447
+ else:
448
+ label_tokens.extend([-100 for _ in text_tokens])
449
+ next_id = turn_id + 1
450
+ if role == "assistant" and next_id != total_m and sample.messages[next_id].role == "assistant":
451
+ postfix_tokens = tokenizer.encode(eom_postfix, add_special_tokens=False)
452
+ input_tokens.extend(postfix_tokens)
453
+ else:
454
+ postfix_tokens = tokenizer.encode(eot_postfix, add_special_tokens=False)
455
+ input_tokens.extend(postfix_tokens)
456
+ if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
457
+ label_tokens.extend(postfix_tokens)
458
+ else:
459
+ label_tokens.extend([-100 for _ in postfix_tokens])
460
+
461
+ return input_tokens, label_tokens, audio_contents, speaker_id
462
+
463
+ except Exception as e:
464
+ print(f"Error in prepare_chatml_sample: {str(e)}")
465
+ print(f"Sample data: {json.dumps(sample, indent=2)}")
466
+ return None, None, None, None
467
+
468
+
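
A minimal sketch of calling `prepare_chatml_sample`; the tokenizer checkpoint is an assumption (any Llama-3.1-style tokenizer that carries the audio special tokens should behave similarly):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")  # assumed checkpoint
input_tokens, label_tokens, audio_contents, speaker_id = prepare_chatml_sample(sample, tokenizer)
# input_tokens: flat list of token ids with <|AUDIO|> / <|AUDIO_OUT|> placeholders still unexpanded
# label_tokens: same length, with -100 at every position that should not contribute to the loss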
469
+ def extract_generation_prompt_from_input_tokens(input_tokens, tokenizer):
470
+ """Extract the generation prompt and reference answer from the input tokens.
471
+
472
+ For example:
473
+
474
+ Input Text = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
475
+ What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
476
+ <|start_header_id|>assistant<|end_header_id|>\n\nAt first they went by quick, too quick to even get.<|eot_id|>'
477
+
478
+ -->
479
+
480
+ Prompt = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
481
+ What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
482
+ <|start_header_id|>assistant<|end_header_id|>\n\n',
483
+ Reference = 'At first they went by quick, too quick to even get.'
484
+
485
+ Args:
486
+ input_tokens: The input tokens.
487
+ audio_contents: The audio contents.
488
+ tokenizer: The tokenizer to use for decoding the text.
489
+
490
+ Returns:
491
+ prompt_tokens: The tokens for the prompt.
492
+ reference_answer: The reference answer.
493
+ num_audios_in_reference: The number of audios in the reference answer.
494
+
495
+ """
496
+ input_text = tokenizer.decode(input_tokens)
497
+ generation_prefix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
498
+ postfix = "<|eot_id|>"
499
+ assert generation_prefix in input_text
500
+ generation_prompt_end_loc = input_text.rfind(generation_prefix) + len(generation_prefix)
501
+ generation_prompt = input_text[:generation_prompt_end_loc]
502
+ reference_answer = input_text[generation_prompt_end_loc : input_text.find(postfix, generation_prompt_end_loc)]
503
+ num_audios_in_reference = reference_answer.count(AUDIO_IN_TOKEN) + reference_answer.count(AUDIO_OUT_TOKEN)
504
+ return (
505
+ tokenizer.encode(generation_prompt, add_special_tokens=False),
506
+ reference_answer,
507
+ num_audios_in_reference,
508
+ )
509
+
510
+
511
+ def prepare_chatml_dataframe_single_process(df, tokenizer):
512
+ """Prepare the ChatML DataFrame."""
513
+ ret = []
514
+ for _, row in df.iterrows():
515
+ input_tokens, label_tokens, audio_contents, speaker_id = prepare_chatml_sample(row.to_dict(), tokenizer)
516
+ ret.append((input_tokens, label_tokens, audio_contents, speaker_id))
517
+ return ret
518
+
519
+
520
+ def prepare_chatml_dataframe(df, tokenizer, num_process=16):
521
+ if num_process is None:
522
+ return prepare_chatml_dataframe_single_process(df, tokenizer)
523
+ else:
524
+ num_process = max(min(len(df) // 1000, num_process), 1)
525
+ workloads = np.array_split(df, num_process)
526
+ with mp.Pool(num_process) as pool:
527
+ ret = pool.starmap(
528
+ prepare_chatml_dataframe_single_process,
529
+ [(workload, tokenizer) for workload in workloads],
530
+ )
531
+ return sum(ret, [])
532
+
533
+
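
Usage is a one-liner; passing `num_process=None` keeps the work in the current process, which is easier to debug (a sketch assuming `df` holds one ChatML sample per row):

rows = prepare_chatml_dataframe(df, tokenizer, num_process=None)
# rows: list of (input_tokens, label_tokens, audio_contents, speaker_id) tuples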
534
+ class DatasetInterface(ABC):
535
+ @abstractmethod
536
+ def __getitem__(self, idx) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
537
+ """Retrieve a dataset sample by index."""
538
+ raise NotImplementedError
539
+
540
+
541
+ class IterableDatasetInterface(ABC):
542
+ @abstractmethod
543
+ def __iter__(
544
+ self,
545
+ ) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
546
+ """Retrieve a sample by iterating through the dataset."""
547
+ raise NotImplementedError
548
+
549
+
550
+ @dataclass
551
+ class DatasetInfo:
552
+ dataset_type: str
553
+ group_type: Optional[str] = None
554
+ mask_text: Optional[bool] = None # Whether to mask the text tokens for pretraining samples.
higgs_audio/model/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from transformers import AutoConfig, AutoModel
2
+
3
+ from .configuration_higgs_audio import HiggsAudioConfig, HiggsAudioEncoderConfig
4
+ from .modeling_higgs_audio import HiggsAudioModel
5
+
6
+
7
+ AutoConfig.register("higgs_audio_encoder", HiggsAudioEncoderConfig)
8
+ AutoConfig.register("higgs_audio", HiggsAudioConfig)
9
+ AutoModel.register(HiggsAudioConfig, HiggsAudioModel)
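
Because of these registrations, the model can be loaded through the Auto classes once the package has been imported. The checkpoint id below is a placeholder for illustration.

from transformers import AutoConfig, AutoModel

import higgs_audio.model  # runs the register() calls above

config = AutoConfig.from_pretrained("your-org/your-higgs-audio-checkpoint")  # placeholder id
model = AutoModel.from_config(config)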
higgs_audio/model/audio_head.py ADDED
@@ -0,0 +1,139 @@
1
+ """Projector that maps hidden states from the LLM component to multimodal logits."""
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional, Tuple
8
+
9
+ from .common import HiggsAudioPreTrainedModel
10
+ from .configuration_higgs_audio import HiggsAudioConfig
11
+
12
+
13
+ @dataclass
14
+ class HiggsAudioDecoderLayerOutput:
15
+ logits: torch.FloatTensor
16
+ audio_logits: torch.FloatTensor
17
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
18
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
19
+
20
+
21
+ class HiggsAudioDecoderProjector(HiggsAudioPreTrainedModel):
22
+ """Projection layers that map hidden states from the LLM component to audio / text logits.
23
+
24
+ We support two types of audio heads:
25
+ - Basic Audio Head:
26
+ Directly map the hidden states to audio logits for all the codebooks.
27
+ """
28
+
29
+ def __init__(self, config: HiggsAudioConfig, layer_idx: Optional[int] = None):
30
+ super().__init__(config)
31
+ self.text_lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
32
+ self.audio_lm_head = nn.Linear(
33
+ config.text_config.hidden_size,
34
+ config.audio_num_codebooks * (config.audio_codebook_size + 2),
35
+ bias=False,
36
+ )
37
+
38
+ # Initialize weights and apply final processing
39
+ self.post_init()
40
+
41
+ def forward(
42
+ self,
43
+ hidden_states,
44
+ audio_out_mask,
45
+ label_audio_ids=None,
46
+ attention_mask=None,
47
+ position_ids=None,
48
+ past_key_values=None,
49
+ use_cache=None,
50
+ output_attentions=None,
51
+ output_hidden_states=None,
52
+ output_audio_hidden_states=False,
53
+ cache_position=None,
54
+ ):
55
+ """
56
+ Args:
57
+ hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
58
+ Hidden states from the LLM component
59
+ audio_out_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
60
+ Mask for identifying the audio out tokens.
61
+ label_audio_ids (`torch.Tensor` of shape `(num_codebooks, num_audio_out_tokens)`):
62
+ Label tokens for the audio-out part. This is used for calculating the logits if RQ-Transformer is used.
63
+ attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
64
+ Mask to avoid performing attention on padding token indices
65
+ position_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
66
+ Position ids for the input tokens
67
+
68
+ Returns:
69
+ logits (`torch.Tensor` of shape `(batch_size, seq_len, vocab_size)`):
70
+ Logits for text tokens
71
+ audio_logits (`torch.Tensor` of shape `(num_audio_out_tokens, audio_num_codebooks * audio_codebook_size)`):
72
+ Logits for audio tokens. We ensure `num_text_tokens + num_audio_tokens == batch_size * seq_len`
73
+ """
74
+ logits = self.text_lm_head(hidden_states)
75
+
76
+ all_hidden_states = () if output_hidden_states else None
77
+ all_self_attns = () if output_attentions else None
78
+ next_decoder_cache = None
79
+
80
+ # TODO(sxjscience) Need to check if DeepSpeed Zero3 supports zero-shape input.
81
+ if self.config.audio_decoder_proj_num_layers > 0:
82
+ # create position embeddings to be shared across the decoder layers
83
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
84
+ for decoder_layer in self.transformer_layers:
85
+ if output_hidden_states:
86
+ all_hidden_states += (hidden_states,)
87
+
88
+ if self.gradient_checkpointing and self.training:
89
+ layer_outputs = self._gradient_checkpointing_func(
90
+ decoder_layer.__call__,
91
+ hidden_states,
92
+ attention_mask,
93
+ position_ids,
94
+ past_key_values,
95
+ output_attentions,
96
+ use_cache,
97
+ cache_position,
98
+ position_embeddings,
99
+ )
100
+ else:
101
+ layer_outputs = decoder_layer(
102
+ hidden_states,
103
+ attention_mask=attention_mask,
104
+ position_ids=position_ids,
105
+ past_key_value=past_key_values,
106
+ output_attentions=output_attentions,
107
+ use_cache=use_cache,
108
+ cache_position=cache_position,
109
+ position_embeddings=position_embeddings,
110
+ )
111
+ hidden_states = layer_outputs[0]
112
+ hidden_states = self.norm(hidden_states)
113
+
114
+ if output_hidden_states:
115
+ all_hidden_states += (hidden_states,)
116
+
117
+ if output_attentions:
118
+ all_self_attns += (layer_outputs[1],)
119
+
120
+ if use_cache:
121
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
122
+
123
+ next_cache = next_decoder_cache if use_cache else None
124
+
125
+ audio_logits = self.audio_lm_head(hidden_states[audio_out_mask])
126
+
127
+ if output_audio_hidden_states:
128
+ audio_hidden_states = hidden_states[audio_out_mask]
129
+ else:
130
+ audio_hidden_states = None
131
+
132
+ return (
133
+ logits,
134
+ audio_logits,
135
+ all_self_attns,
136
+ all_hidden_states,
137
+ audio_hidden_states,
138
+ next_cache,
139
+ )
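
The audio head emits one flat vector of size `audio_num_codebooks * (audio_codebook_size + 2)` per audio-out position, where the `+ 2` covers the stream BOS/EOS ids. A minimal reshaping sketch under those assumptions (the tensor below is random stand-in data):

import torch

num_tokens, num_codebooks, codebook_size = 5, 12, 1024
audio_logits = torch.randn(num_tokens, num_codebooks * (codebook_size + 2))  # stand-in for the head output
per_codebook = audio_logits.view(num_tokens, num_codebooks, codebook_size + 2)
next_codes = per_codebook.argmax(dim=-1)  # greedy pick per codebook, shape (num_tokens, num_codebooks)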
higgs_audio/model/common.py ADDED
@@ -0,0 +1,27 @@
1
+ from torch import nn
2
+
3
+ from transformers.modeling_utils import PreTrainedModel
4
+
5
+ from .configuration_higgs_audio import HiggsAudioConfig
6
+
7
+
8
+ class HiggsAudioPreTrainedModel(PreTrainedModel):
9
+ config_class = HiggsAudioConfig
10
+ base_model_prefix = "model"
11
+ supports_gradient_checkpointing = True
12
+ _no_split_modules = []
13
+ _skip_keys_device_placement = "past_key_values"
14
+ _supports_flash_attn_2 = True
15
+ _supports_sdpa = True
16
+
17
+ def _init_weights(self, module):
18
+ std = self.config.init_std if hasattr(self.config, "init_std") else self.config.audio_encoder_config.init_std
19
+
20
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
21
+ module.weight.data.normal_(mean=0.0, std=std)
22
+ if module.bias is not None:
23
+ module.bias.data.zero_()
24
+ elif isinstance(module, nn.Embedding):
25
+ module.weight.data.normal_(mean=0.0, std=std)
26
+ if module.padding_idx is not None:
27
+ module.weight.data[module.padding_idx].zero_()
higgs_audio/model/configuration_higgs_audio.py ADDED
@@ -0,0 +1,235 @@
1
+ from transformers.configuration_utils import PretrainedConfig
2
+ from transformers.models.auto import CONFIG_MAPPING
3
+
4
+
5
+ class HiggsAudioEncoderConfig(PretrainedConfig):
6
+ """Configuration of the Audio encoder in Higgs-Audio."""
7
+
8
+ model_type = "higgs_audio_encoder"
9
+
10
+ def __init__(
11
+ self,
12
+ num_mel_bins=128,
13
+ encoder_layers=32,
14
+ encoder_attention_heads=20,
15
+ encoder_ffn_dim=5120,
16
+ encoder_layerdrop=0.0,
17
+ d_model=1280,
18
+ dropout=0.0,
19
+ attention_dropout=0.0,
20
+ activation_function="gelu",
21
+ activation_dropout=0.0,
22
+ scale_embedding=False,
23
+ init_std=0.02,
24
+ max_source_positions=1500,
25
+ pad_token_id=128001,
26
+ **kwargs,
27
+ ):
28
+ super().__init__(**kwargs)
29
+
30
+ self.num_mel_bins = num_mel_bins
31
+ self.d_model = d_model
32
+ self.encoder_layers = encoder_layers
33
+ self.encoder_attention_heads = encoder_attention_heads
34
+ self.encoder_ffn_dim = encoder_ffn_dim
35
+ self.dropout = dropout
36
+ self.attention_dropout = attention_dropout
37
+ self.activation_function = activation_function
38
+ self.activation_dropout = activation_dropout
39
+ self.encoder_layerdrop = encoder_layerdrop
40
+ self.num_hidden_layers = encoder_layers
41
+ self.init_std = init_std
42
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
43
+ self.max_source_positions = max_source_positions
44
+ self.pad_token_id = pad_token_id
45
+
46
+
47
+ class HiggsAudioConfig(PretrainedConfig):
48
+ r"""
49
+ This is the configuration class for the HiggsAudioModel.
50
+
51
+ Args:
52
+ text_config (`Union[AutoConfig, dict]`):
53
+ The config object or dictionary of the text backbone.
54
+ audio_encoder_config (`Union[AutoConfig, dict]`):
55
+ The config object or dictionary of the whisper encoder.
56
+ The audio encoder will be bidirectional and will only be available for audio understanding.
57
+ audio_tokenizer_config
58
+ The config object or dictionary of the audio tokenizer.
59
+ audio_adapter_type
60
+ The type of audio adapter to use. We support three types of adapters:
61
+ - stack:
62
+ We stack additional Transformer layers after the main LLM backbone for audio generation.
63
+ - dual_ffn:
64
+ For selected part of the LLM backbone, we replace the text FFN with a dual FFN architecture
65
+ that contains an additional audio FFN. The audio FFN will be triggered when the location is marked for audio tokens.
66
+ - dual_ffn_fast_forward:
67
+ We pick a few layers in the LLM backbone to plug-in the audio FFN. For the remaining layers,
68
+ the audio hidden states will be directly fast-forward to the next layer.
69
+ This reduces the computational cost for audio generation.
70
+ audio_embed_avg (`bool`, *optional*, defaults to False):
71
+ Whether to average the audio embeddings before sending them to the text attention layer.
72
+ audio_ffn_hidden_size
73
+ The hidden size of the audio feedforward network in dual-path FFN
74
+ audio_ffn_intermediate_size
75
+ The intermediate size of the audio feedforward network in dual-path FFN
76
+ audio_dual_ffn_layers
77
+ The layers in the LLM backbone where the dual FFN layer (a mixture of audio FFN and text FFN) is plugged in.
78
+ audio_decoder_proj_num_layers (`int`, *optional*, defaults to 0):
79
+ The number of extra Transformer layers in the audio decoder projection module.
80
+ use_delay_pattern (`bool`, *optional*, defaults to False):
81
+ Whether to use delay pattern in the audio decoder.
82
+ skip_audio_tower (`bool`, *optional*, defaults to False):
83
+ Whether to skip the audio tower in the audio encoder.
84
+ use_audio_out_embed_projector (`bool`, *optional*, defaults to False):
85
+ Whether to use an embedding projector to map audio out embeddings.
86
+ use_audio_out_self_attention (`bool`, *optional*, defaults to False):
87
+ Whether to use self-attention to aggregate information from audio-tokens before sending to the text attention layer.
88
+ audio_num_codebooks (`int`, *optional*, defaults to 12):
89
+ The number of codebooks in RVQGAN.
90
+ audio_codebook_size (`int`, *optional*, defaults to 1024):
91
+ The size of each codebook in RVQGAN.
92
+ audio_stream_bos_id
93
+ The id of the bos in the audio stream
94
+ audio_stream_eos_id
95
+ The id of the eos in the audio stream
96
+ audio_bos_token (`str`, *optional*, defaults to "<|audio_bos|>"):
97
+ The special `<|audio_bos|>` token. In Higgs-Audio, it is mapped to 128011,
98
+ which is the index of `<|reserved_special_token_3|>` in Llama-3.1-8B-Instruct's tokenizer.
99
+ audio_eos_token (`str`, *optional*, defaults to "<|audio_eos|>"):
100
+ The special `<|audio_eos|>` token. We use 128012 as the default value,
101
+ which is the index of `<|reserved_special_token_4|>` in Llama-3.1-8B-Instruct's tokenizer.
102
+ audio_out_bos_token (`str`, *optional*, defaults to "<|audio_out_bos|>"):
103
+ The special `<|audio_out_bos|>` token. We use 128013 as the default value,
104
+ which is the index of `<|reserved_special_token_5|>` in Llama-3.1-8B-Instruct's tokenizer.
105
+ audio_token (`str`, *optional*, defaults to "<|AUDIO|>"):
106
+ The special `<|AUDIO|>` token. We use 128015 as the default value,
107
+ which is the index of `<|reserved_special_token_7|>` in Llama-3.1-8B-Instruct's tokenizer.
108
+ This token indicates that the location should be filled in with whisper features.
109
+ audio_out_token (`str`, *optional*, defaults to "<|AUDIO_OUT|>"):
110
+ The special `<|AUDIO_OUT|>` token. We use 128016 as the default value,
111
+ which is the index of `<|reserved_special_token_8|>` in Llama-3.1-8B-Instruct's tokenizer.
112
+ This token indicates that the location should be filled in with audio tokens extracted via audio tokenizer.
113
+ """
114
+
115
+ model_type = "higgs_audio"
116
+ is_composition = True
117
+
118
+ def __init__(
119
+ self,
120
+ text_config=None,
121
+ audio_encoder_config=None,
122
+ audio_tokenizer_config=None,
123
+ audio_adapter_type="stack",
124
+ audio_embed_avg=False,
125
+ audio_ffn_hidden_size=4096,
126
+ audio_ffn_intermediate_size=14336,
127
+ audio_dual_ffn_layers=None,
128
+ audio_decoder_proj_num_layers=0,
129
+ encode_whisper_embed=True,
130
+ encode_audio_in_tokens=False,
131
+ use_delay_pattern=False,
132
+ skip_audio_tower=False,
133
+ use_audio_out_embed_projector=False,
134
+ use_audio_out_self_attention=False,
135
+ use_rq_transformer=False,
136
+ rq_transformer_hidden_size=None,
137
+ rq_transformer_intermediate_size=None,
138
+ rq_transformer_num_attention_heads=None,
139
+ rq_transformer_num_key_value_heads=None,
140
+ rq_transformer_num_hidden_layers=3,
141
+ audio_num_codebooks=12,
142
+ audio_codebook_size=1024,
143
+ audio_stream_bos_id=1024,
144
+ audio_stream_eos_id=1025,
145
+ audio_bos_token="<|audio_bos|>",
146
+ audio_eos_token="<|audio_eos|>",
147
+ audio_out_bos_token="<|audio_out_bos|>",
148
+ audio_in_token="<|AUDIO|>",
149
+ audio_out_token="<|AUDIO_OUT|>",
150
+ audio_in_token_idx=128015,
151
+ audio_out_token_idx=128016,
152
+ pad_token_id=128001,
153
+ audio_out_bos_token_id=128013,
154
+ audio_eos_token_id=128012,
155
+ **kwargs,
156
+ ):
157
+ if isinstance(audio_encoder_config, dict):
158
+ audio_encoder_config["model_type"] = (
159
+ audio_encoder_config["model_type"] if "model_type" in audio_encoder_config else "higgs_audio_encoder"
160
+ )
161
+ audio_encoder_config = CONFIG_MAPPING[audio_encoder_config["model_type"]](**audio_encoder_config)
162
+ elif audio_encoder_config is None:
163
+ audio_encoder_config = HiggsAudioEncoderConfig()
164
+
165
+ if isinstance(text_config, dict):
166
+ text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
167
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
168
+ elif text_config is None:
169
+ text_config = CONFIG_MAPPING["llama"]()
170
+
171
+ assert audio_adapter_type in [
172
+ "stack",
173
+ "dual_ffn",
174
+ "dual_ffn_fast_forward",
175
+ ], f"Invalid audio adapter type: {audio_adapter_type}"
176
+ if audio_adapter_type.startswith("dual_ffn"):
177
+ assert audio_dual_ffn_layers is not None, (
178
+ "audio_dual_ffn_layers must be specified when using dual_ffn adapter."
179
+ )
180
+ self.text_config = text_config
181
+ self.audio_encoder_config = audio_encoder_config
182
+ self.audio_tokenizer_config = audio_tokenizer_config
183
+ self.audio_adapter_type = audio_adapter_type
184
+ self.audio_embed_avg = audio_embed_avg
185
+ self.audio_ffn_hidden_size = audio_ffn_hidden_size
186
+ self.audio_ffn_intermediate_size = audio_ffn_intermediate_size
187
+ self.audio_dual_ffn_layers = audio_dual_ffn_layers
188
+ self.audio_decoder_proj_num_layers = audio_decoder_proj_num_layers
189
+ self.encode_whisper_embed = encode_whisper_embed
190
+ self.encode_audio_in_tokens = encode_audio_in_tokens
191
+ self.use_delay_pattern = use_delay_pattern
192
+ self.skip_audio_tower = skip_audio_tower
193
+ self.use_audio_out_embed_projector = use_audio_out_embed_projector
194
+ self.use_audio_out_self_attention = use_audio_out_self_attention
195
+
196
+ self.use_rq_transformer = use_rq_transformer
197
+
198
+ if self.use_rq_transformer:
199
+ assert not self.use_delay_pattern, "Delay pattern is not supported if you turned on RQ-Transformer!"
200
+ self.rq_transformer_hidden_size = rq_transformer_hidden_size
201
+ self.rq_transformer_intermediate_size = rq_transformer_intermediate_size
202
+ self.rq_transformer_num_attention_heads = rq_transformer_num_attention_heads
203
+ self.rq_transformer_num_key_value_heads = rq_transformer_num_key_value_heads
204
+ self.rq_transformer_num_hidden_layers = rq_transformer_num_hidden_layers
205
+
206
+ if use_rq_transformer:
207
+ # For RQ-Transformer, we set the hidden_size to the same as the text model's hidden size if it is not specified.
208
+ if self.rq_transformer_hidden_size is None:
209
+ self.rq_transformer_hidden_size = text_config.hidden_size
210
+ assert self.rq_transformer_hidden_size % 128 == 0
211
+ if self.rq_transformer_intermediate_size is None:
212
+ self.rq_transformer_intermediate_size = text_config.intermediate_size
213
+ if self.rq_transformer_num_attention_heads is None:
214
+ self.rq_transformer_num_attention_heads = self.rq_transformer_hidden_size // 128
215
+ if self.rq_transformer_num_key_value_heads is None:
216
+ self.rq_transformer_num_key_value_heads = self.rq_transformer_hidden_size // 128 // 4
217
+ assert self.rq_transformer_hidden_size % self.rq_transformer_num_attention_heads == 0
218
+ assert self.rq_transformer_hidden_size % self.rq_transformer_num_key_value_heads == 0
219
+
220
+ self.audio_num_codebooks = audio_num_codebooks
221
+ self.audio_codebook_size = audio_codebook_size
222
+ self.audio_bos_token = audio_bos_token
223
+ self.audio_eos_token = audio_eos_token
224
+ self.audio_out_bos_token = audio_out_bos_token
225
+ self.audio_in_token = audio_in_token
226
+ self.audio_out_token = audio_out_token
227
+ self.audio_in_token_idx = audio_in_token_idx
228
+ self.audio_out_token_idx = audio_out_token_idx
229
+ self.audio_stream_bos_id = audio_stream_bos_id
230
+ self.audio_stream_eos_id = audio_stream_eos_id
231
+ self.audio_out_bos_token_id = audio_out_bos_token_id
232
+ self.audio_eos_token_id = audio_eos_token_id
233
+
234
+ super().__init__(**kwargs)
235
+ self.pad_token_id = pad_token_id
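
A hedged construction sketch for the config with the dual-FFN adapter; the layer indices and sizes below are illustrative, not the released model's values.

config = HiggsAudioConfig(
    audio_adapter_type="dual_ffn",
    audio_dual_ffn_layers=[0, 8, 16, 24],  # illustrative layer indices
    audio_num_codebooks=12,
    audio_codebook_size=1024,
    use_delay_pattern=True,
)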
higgs_audio/model/cuda_graph_runner.py ADDED
@@ -0,0 +1,129 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Optional, List, Dict, Tuple, Union
4
+ import gc
5
+
6
+ from transformers.cache_utils import Cache
7
+
8
+
9
+ _NUM_WARMUP_ITERS = 2
10
+
11
+
12
+ class CUDAGraphRunner(nn.Module):
13
+ def __init__(self, model):
14
+ super().__init__()
15
+ self.model = model
16
+
17
+ self.input_buffers: Dict[str, torch.Tensor] = {}
18
+ self.output_buffers: Dict[str, torch.Tensor] = {}
19
+
20
+ self._graph: Optional[torch.cuda.CUDAGraph] = None
21
+
22
+ @property
23
+ def graph(self):
24
+ assert self._graph is not None
25
+ return self._graph
26
+
27
+ def capture(
28
+ self,
29
+ hidden_states: torch.Tensor,
30
+ causal_mask: torch.Tensor,
31
+ position_ids: torch.Tensor,
32
+ audio_discrete_codes_mask: torch.Tensor,
33
+ cache_position: torch.Tensor,
34
+ past_key_values: Union[Cache, List[torch.FloatTensor]],
35
+ use_cache: bool,
36
+ audio_attention_mask: torch.Tensor,
37
+ fast_forward_attention_mask: torch.Tensor,
38
+ output_attentions: bool,
39
+ output_hidden_states: bool,
40
+ is_decoding_audio_token: Optional[bool] = None,
41
+ is_using_cuda_graph: Optional[bool] = False,
42
+ stream: torch.cuda.Stream = None,
43
+ memory_pool: Optional[Tuple[int, int]] = None,
44
+ ):
45
+ assert self._graph is None
46
+ # Run warmup iterations
47
+ for _ in range(_NUM_WARMUP_ITERS):
48
+ self.model(
49
+ hidden_states=hidden_states,
50
+ causal_mask=causal_mask,
51
+ position_ids=position_ids,
52
+ audio_discrete_codes_mask=audio_discrete_codes_mask,
53
+ cache_position=cache_position,
54
+ past_key_values=past_key_values,
55
+ use_cache=use_cache,
56
+ audio_attention_mask=audio_attention_mask,
57
+ fast_forward_attention_mask=fast_forward_attention_mask,
58
+ output_attentions=output_attentions,
59
+ output_hidden_states=output_hidden_states,
60
+ is_decoding_audio_token=is_decoding_audio_token,
61
+ is_using_cuda_graph=is_using_cuda_graph,
62
+ )
63
+
64
+ torch.cuda.synchronize()
65
+
66
+ # Capture the graph
67
+ self._graph = torch.cuda.CUDAGraph()
68
+ with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
69
+ out_hidden_states, all_hidden_states, all_self_attns = self.model(
70
+ hidden_states=hidden_states,
71
+ causal_mask=causal_mask,
72
+ position_ids=position_ids,
73
+ audio_discrete_codes_mask=audio_discrete_codes_mask,
74
+ cache_position=cache_position,
75
+ past_key_values=past_key_values,
76
+ use_cache=use_cache,
77
+ audio_attention_mask=audio_attention_mask,
78
+ fast_forward_attention_mask=fast_forward_attention_mask,
79
+ output_attentions=output_attentions,
80
+ output_hidden_states=output_hidden_states,
81
+ is_decoding_audio_token=is_decoding_audio_token,
82
+ is_using_cuda_graph=is_using_cuda_graph,
83
+ )
84
+ # hidden_states_out = torch.ops._C.weak_ref_tensor(outputs[0])
85
+ # del outputs
86
+ gc.collect()
87
+ torch.cuda.synchronize()
88
+
89
+ # Save input and output buffers
90
+ self.input_buffers = {
91
+ "hidden_states": hidden_states,
92
+ "causal_mask": causal_mask,
93
+ "position_ids": position_ids,
94
+ "audio_discrete_codes_mask": audio_discrete_codes_mask,
95
+ "cache_position": cache_position,
96
+ "past_key_values": past_key_values,
97
+ "audio_attention_mask": audio_attention_mask,
98
+ "fast_forward_attention_mask": fast_forward_attention_mask,
99
+ }
100
+ self.output_buffers = {
101
+ "hidden_states": out_hidden_states,
102
+ "all_hidden_states": all_hidden_states,
103
+ "all_self_attns": all_self_attns,
104
+ }
105
+
106
+ def forward(
107
+ self,
108
+ hidden_states: torch.Tensor,
109
+ causal_mask: torch.Tensor,
110
+ position_ids: torch.Tensor,
111
+ audio_discrete_codes_mask: torch.Tensor,
112
+ cache_position: torch.Tensor,
113
+ audio_attention_mask: torch.Tensor,
114
+ fast_forward_attention_mask: torch.Tensor,
115
+ **kwargs,
116
+ ) -> torch.Tensor:
117
+ # Copy input tensors to buffers
118
+ self.input_buffers["hidden_states"].copy_(hidden_states, non_blocking=True)
119
+ self.input_buffers["causal_mask"].copy_(causal_mask, non_blocking=True)
120
+ self.input_buffers["position_ids"].copy_(position_ids, non_blocking=True)
121
+ self.input_buffers["audio_discrete_codes_mask"].copy_(audio_discrete_codes_mask, non_blocking=True)
122
+ self.input_buffers["cache_position"].copy_(cache_position, non_blocking=True)
123
+ self.input_buffers["audio_attention_mask"].copy_(audio_attention_mask, non_blocking=True)
124
+ self.input_buffers["fast_forward_attention_mask"].copy_(fast_forward_attention_mask, non_blocking=True)
125
+
126
+ # Run the captured graph
127
+ self.graph.replay()
128
+
129
+ return self.output_buffers["hidden_states"], None, None
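
The runner follows the standard CUDA Graph idiom: a few eager warm-up iterations, one recorded forward pass into static input/output buffers, then replay with fresh values copied into those buffers. A hedged sketch, where `decoder_core` and the input tensors are placeholders with fixed shapes (CUDA graphs require static shapes):

runner = CUDAGraphRunner(decoder_core)  # decoder_core: the module being captured (placeholder)
runner.capture(
    hidden_states, causal_mask, position_ids, audio_discrete_codes_mask,
    cache_position, past_key_values, use_cache=True,
    audio_attention_mask=audio_attention_mask,
    fast_forward_attention_mask=fast_forward_attention_mask,
    output_attentions=False, output_hidden_states=False,
)
# Later decode steps reuse the graph; new inputs are copied into the captured buffers and replayed.
out_hidden_states, _, _ = runner(
    hidden_states, causal_mask, position_ids, audio_discrete_codes_mask,
    cache_position, audio_attention_mask, fast_forward_attention_mask,
)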
higgs_audio/model/custom_modules.py ADDED
@@ -0,0 +1,155 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class PartiallyFrozenEmbedding(nn.Module):
6
+ """Split an existing `nn.Embedding` module that splits the embedding into:
7
+
8
+ - A frozen embedding for indices [0..freeze_until_idx].
9
+ - A trainable embedding for indices [freeze_until_idx+1..vocab_size-1].
10
+
11
+ This should work with both Zero-2 and Zero-3 seamlessly
12
+ """
13
+
14
+ def __init__(self, original_embedding: nn.Embedding, freeze_until_idx: int):
15
+ """
16
+ :param original_embedding: An instance of nn.Embedding (the original embedding layer).
17
+ :param freeze_until_idx: The index up to which the embedding is frozen (exclusive); index `freeze_until_idx` itself remains trainable.
18
+ """
19
+ super().__init__()
20
+ self.freeze_until_idx = freeze_until_idx
21
+ self.original_vocab_size = original_embedding.num_embeddings
22
+ self.embedding_dim = original_embedding.embedding_dim
23
+
24
+ # Split the original embedding into frozen and trainable parts
25
+ self.embedding_frozen = nn.Embedding(
26
+ freeze_until_idx,
27
+ self.embedding_dim,
28
+ dtype=original_embedding.weight.dtype,
29
+ device=original_embedding.weight.device,
30
+ )
31
+ self.embedding_trainable = nn.Embedding(
32
+ self.original_vocab_size - freeze_until_idx,
33
+ self.embedding_dim,
34
+ dtype=original_embedding.weight.dtype,
35
+ device=original_embedding.weight.device,
36
+ )
37
+
38
+ # Copy weights from the original embedding into the frozen and trainable parts
39
+ with torch.no_grad():
40
+ self.embedding_frozen.weight.copy_(original_embedding.weight[:freeze_until_idx])
41
+ self.embedding_trainable.weight.copy_(original_embedding.weight[freeze_until_idx:])
42
+
43
+ # Freeze the frozen embedding
44
+ self.embedding_frozen.weight.requires_grad = False
45
+
46
+ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
47
+ """
48
+ Forward pass for the split embedding wrapper.
49
+ :param input_ids: Tensor of shape [batch_size, seq_len] with indices in [0..original_vocab_size-1].
50
+ """
51
+ # Masks to separate frozen and trainable indices
52
+ # (bsz, seq_len)
53
+ mask_frozen = input_ids < self.freeze_until_idx
54
+ mask_trainable = ~mask_frozen
55
+
56
+ # Output tensor for embedding results
57
+ batch_size, seq_len = input_ids.shape
58
+ embeddings = torch.zeros(
59
+ batch_size,
60
+ seq_len,
61
+ self.embedding_dim,
62
+ device=input_ids.device,
63
+ dtype=self.embedding_frozen.weight.dtype,
64
+ )
65
+
66
+ # Handle frozen embedding
67
+ if mask_frozen.any():
68
+ frozen_ids = input_ids[mask_frozen]
69
+ frozen_emb = self.embedding_frozen(frozen_ids)
70
+ embeddings[mask_frozen] = frozen_emb
71
+
72
+ # Handle trainable embedding
73
+ if mask_trainable.any():
74
+ # Adjust trainable IDs to the local index space of the trainable embedding
75
+ trainable_ids = input_ids[mask_trainable] - (self.freeze_until_idx)
76
+ trainable_emb = self.embedding_trainable(trainable_ids)
77
+ embeddings[mask_trainable] = trainable_emb
78
+
79
+ return embeddings
80
+
81
+ def to_unsplit(self) -> nn.Embedding:
82
+ unsplit_embedding = nn.Embedding(
83
+ self.original_vocab_size,
84
+ self.embedding_dim,
85
+ dtype=self.embedding_frozen.weight.dtype,
86
+ device=self.embedding_frozen.weight.device,
87
+ )
88
+
89
+ with torch.no_grad():
90
+ unsplit_embedding.weight[: self.freeze_until_idx].copy_(self.embedding_frozen.weight)
91
+ unsplit_embedding.weight[self.freeze_until_idx :].copy_(self.embedding_trainable.weight)
92
+
93
+ return unsplit_embedding
94
+
95
+
96
+ class PartiallyFrozenLinear(nn.Module):
97
+ """A wrapper around nn.Linear to partially freeze part of the weight matrix."""
98
+
99
+ def __init__(self, original_linear: nn.Linear, freeze_until_idx: int):
100
+ """
101
+ :param original_linear: The original nn.Linear layer.
102
+ :param freeze_until_idx: The index up to which the rows of the weight matrix are frozen.
103
+ """
104
+ super().__init__()
105
+ assert original_linear.bias is None, "Currently only support linear module without bias"
106
+
107
+ self.freeze_until_idx = freeze_until_idx
108
+ self.input_dim = original_linear.in_features
109
+ self.output_dim = original_linear.out_features
110
+
111
+ # Create frozen and trainable linear layers
112
+ self.linear_frozen = nn.Linear(
113
+ self.input_dim,
114
+ freeze_until_idx,
115
+ bias=False,
116
+ dtype=original_linear.weight.dtype,
117
+ device=original_linear.weight.device,
118
+ )
119
+ self.linear_trainable = nn.Linear(
120
+ self.input_dim,
121
+ self.output_dim - freeze_until_idx,
122
+ bias=False,
123
+ dtype=original_linear.weight.dtype,
124
+ device=original_linear.weight.device,
125
+ )
126
+
127
+ # Copy weights from the original linear layer
128
+ with torch.no_grad():
129
+ self.linear_frozen.weight.copy_(original_linear.weight[:freeze_until_idx])
130
+ self.linear_trainable.weight.copy_(original_linear.weight[freeze_until_idx:])
131
+
132
+ # Freeze the frozen linear layer
133
+ self.linear_frozen.weight.requires_grad = False
134
+
135
+ def forward(self, input_tensor):
136
+ # input_tensor: (bsz, seq_len, hidden_state_dim)
137
+ frozen_output = self.linear_frozen(input_tensor)
138
+ trainable_output = self.linear_trainable(input_tensor)
139
+ return torch.cat((frozen_output, trainable_output), dim=-1)
140
+
141
+ def to_unsplit(self) -> nn.Linear:
142
+ unsplit_linear = nn.Linear(
143
+ self.input_dim,
144
+ self.output_dim,
145
+ bias=False,
146
+ dtype=self.linear_frozen.weight.dtype,
147
+ device=self.linear_frozen.weight.device,
148
+ )
149
+
150
+ # Copy weights from the frozen and trainable layers into the unsplit linear layer
151
+ with torch.no_grad():
152
+ unsplit_linear.weight[: self.freeze_until_idx].copy_(self.linear_frozen.weight)
153
+ unsplit_linear.weight[self.freeze_until_idx :].copy_(self.linear_trainable.weight)
154
+
155
+ return unsplit_linear
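Likewise, a quick hedged check for `PartiallyFrozenLinear` (import path assumed as above): concatenating the frozen and trainable partial projections should reproduce the original bias-free linear layer.

```python
import torch
import torch.nn as nn
from higgs_audio.model.custom_modules import PartiallyFrozenLinear  # assumed import path

lin = nn.Linear(4, 10, bias=False)
split = PartiallyFrozenLinear(lin, freeze_until_idx=6)    # output rows 0..5 frozen, 6..9 trainable

x = torch.randn(2, 3, 4)                                  # (bsz, seq_len, hidden_dim)
assert torch.allclose(split(x), lin(x), atol=1e-6)        # concat of the two halves == original
assert torch.equal(split.to_unsplit().weight, lin.weight)
```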
higgs_audio/model/modeling_higgs_audio.py ADDED
The diff for this file is too large to render. See raw diff
 
higgs_audio/model/utils.py ADDED
@@ -0,0 +1,778 @@
 
1
+ import contextlib
2
+ from contextlib import contextmanager
3
+ from functools import wraps
4
+ import torch
5
+ from transformers.integrations import is_deepspeed_available
6
+
7
+ if is_deepspeed_available():
8
+ from deepspeed.utils import groups as deepspeed_groups
9
+ from deepspeed.sequence.layer import _SeqAllToAll
10
+ else:
11
+ deepspeed_groups = None
12
+ _SeqAllToAll = None
13
+
14
+
15
+ def _ceil_to_nearest(n, round_to):
16
+ return (n + round_to - 1) // round_to * round_to
17
+
18
+
19
+ def count_parameters(model, trainable_only=True):
20
+ if trainable_only:
21
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
22
+ else:
23
+ return sum(p.numel() for p in model.parameters())
24
+
25
+
26
+ # TODO(sxjscience) Consider to move the function to audio_processing/utils.py
27
+ def build_delay_pattern_mask(
28
+ input_ids: torch.LongTensor,
29
+ bos_token_id: int,
30
+ pad_token_id: int,
31
+ ):
32
+ """Implement the delay pattern proposed in "Simple and Controllable Music Generation", https://arxiv.org/pdf/2306.05284
33
+
34
+ In the delay pattern, each codebook is offset from the previous codebook by
35
+ one. We insert a special delay token at the start of the sequence if it is delayed, and append a pad token once the sequence finishes.
36
+
37
+ Take the example where there are 4 codebooks and audio sequence length=5. After shifting, the output should have length seq_len + num_codebooks - 1
38
+
39
+ - [ *, *, *, *, *, P, P, P]
40
+ - [ B, *, *, *, *, *, P, P]
41
+ - [ B, B, *, *, *, *, *, P]
42
+ - [ B, B, B, *, *, *, *, *]
43
+
44
+ where B indicates the delay token id, P is the special padding token id, and `*` indicates an original audio token.
45
+
46
+ Now let's consider the case where we have a sequence of audio tokens to condition on.
47
+ The audio tokens were originally in the following non-delayed form:
48
+
49
+ - [a, b]
50
+ - [c, d]
51
+ - [e, f]
52
+ - [g, h]
53
+
54
+ After conversion, we get the following delayed form:
55
+ - [a, b, -1, -1, -1]
56
+ - [B, c, d, -1, -1]
57
+ - [B, B, e, f, -1]
58
+ - [B, B, B, g, h]
59
+
60
+ Note that we have a special token `-1` that indicates a position to be filled by a newly generated token during the generation phase.
61
+ In that case, we should override the `-1` tokens in auto-regressive generation.
62
+
63
+ Args:
64
+ input_ids (:obj:`torch.LongTensor`):
65
+ The input ids of the prompt. It will have shape (bsz, num_codebooks, seq_len).
66
+ bos_token_id (:obj:`int`):
67
+ The id of the special delay token
68
+ pad_token_id (:obj:`int`):
69
+ The id of the padding token. Should be the same as eos_token_id.
70
+
71
+ Returns:
72
+ input_ids (:obj:`torch.LongTensor`):
73
+ The transformed input ids with delay pattern applied. It will have shape (bsz, num_codebooks, seq_len + num_codebooks - 1).
74
+ input_ids_with_gen_mask (:obj:`torch.LongTensor`):
75
+ The transformed input ids with delay pattern applied. The -1 in the output indicates new tokens that should be generated.
76
+
77
+ """
78
+ bsz, num_codebooks, seq_len = input_ids.shape
79
+
80
+ new_seq_len = seq_len + num_codebooks - 1
81
+ input_ids_with_gen_mask = torch.ones((bsz, num_codebooks, new_seq_len), dtype=torch.long, device=input_ids.device)
82
+ bos_mask = torch.tril(input_ids_with_gen_mask, -1) > 0
83
+ eos_mask = torch.triu(input_ids_with_gen_mask, seq_len) > 0
84
+ input_ids_with_gen_mask[bos_mask] = bos_token_id
85
+ input_ids_with_gen_mask[(~bos_mask) & (~eos_mask)] = input_ids.reshape(-1)
86
+ input_ids = input_ids_with_gen_mask.clone()
87
+ input_ids[eos_mask] = pad_token_id
88
+ input_ids_with_gen_mask[eos_mask] = -1
89
+ return input_ids, input_ids_with_gen_mask
90
+
91
+
92
+ def revert_delay_pattern(data):
93
+ """Convert samples encoded with delay pattern back to the original form.
94
+
95
+ Args:
96
+ data (:obj:`torch.Tensor`):
97
+ The data with delay pattern applied. It will have shape (num_codebooks, seq_len + num_codebooks - 1).
98
+
99
+ Returns:
100
+ ret (:obj:`torch.Tensor`):
101
+ Recovered data with delay pattern removed. It will have shape (num_codebooks, seq_len).
102
+ """
103
+ assert len(data.shape) == 2
104
+ out_l = []
105
+ num_codebooks = data.shape[0]
106
+ for i in range(num_codebooks):
107
+ out_l.append(data[i : (i + 1), i : (data.shape[1] - num_codebooks + 1 + i)])
108
+ return torch.cat(out_l, dim=0)
109
+
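A small sketch of the two delay-pattern helpers above in action (assuming this file is importable as `higgs_audio.model.utils`); applying the delay and then reverting it recovers the original codes.

```python
import torch
from higgs_audio.model.utils import build_delay_pattern_mask, revert_delay_pattern

# 2 codebooks, 3 timesteps; 0 and 1 are hypothetical delay (bos) and pad token ids.
codes = torch.arange(10, 16).reshape(1, 2, 3)
delayed, delayed_with_gen_mask = build_delay_pattern_mask(codes, bos_token_id=0, pad_token_id=1)
print(delayed[0])
# tensor([[10, 11, 12,  1],
#         [ 0, 13, 14, 15]])
assert torch.equal(revert_delay_pattern(delayed[0]), codes[0])
```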
110
+
111
+ def merge_input_ids_with_audio_features(
112
+ audio_features_embed,
113
+ audio_features_length,
114
+ audio_in_embed,
115
+ audio_in_ids_start,
116
+ audio_out_embed,
117
+ audio_out_ids_start,
118
+ audio_in_token_idx,
119
+ audio_out_token_idx,
120
+ inputs_embeds,
121
+ input_ids,
122
+ attention_mask,
123
+ label_ids,
124
+ pad_token_id,
125
+ ignore_index=-100,
126
+ round_to=8,
127
+ left_padding=True,
128
+ ):
129
+ """
130
+ Merge input_ids with audio features into final embeddings.
131
+
132
+ Args:
133
+ audio_features_embed (`torch.Tensor` of shape `(num_audios, max_audio_tokens, embed_dim)`):
134
+ Encoded vectors of all audios in the batch (obtained from the semantic encoder)
135
+ audio_features_length (`torch.LongTensor` of shape `(num_audios,)`):
136
+ The length of audio embeddings of each audio as stacked in `audio_features_embed`
137
+ audio_in_embed (`torch.Tensor` of shape `(total_num_audio_in_tokens, embed_dim)`):
138
+ The embeddings of audio-in tokens
139
+ audio_in_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
140
+ The start index of the audio-in tokens for each audio
141
+ audio_out_embed (`torch.Tensor` of shape `(total_num_audio_out_tokens, embed_dim)`):
142
+ The embeddings of audio-out tokens
143
+ audio_out_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
144
+ The start index of the audio-out tokens for each audio
145
+ audio_in_token_idx
146
+ The index of the audio-in token in the vocabulary
147
+ audio_out_token_idx
148
+ The index of the audio-out token in the vocabulary
149
+ inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
150
+ Token embeddings before merging with audio embeddings
151
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
152
+ Input_ids of tokens, possibly filled with audio token
153
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
154
+ Mask to avoid performing attention on padding token indices.
155
+ label_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
156
+ labels need to be recalculated to support training (if provided)
157
+ pad_token_id (`int`):
158
+ The index of the pad token in the vocabulary
159
+ ignore_index
160
+ The index to ignore in the loss calculation
161
+ round_to
162
+ The number to round to for padding
163
+ left_padding
164
+ Whether to apply left padding
165
+
166
+ Returns:
167
+ final_embedding
168
+ The final embeddings after merging audio embeddings with text embeddings.
169
+ final_attention_mask
170
+ The final attention mask after merging audio embeddings with text embeddings.
171
+ final_labels
172
+ The labels for the text stream
173
+ position_ids
174
+ Positional ids for the merged data
175
+ final_input_ids
176
+ The final input_ids after merging audio embeddings with text embeddings.
177
+ final_audio_in_mask
178
+ Mask for audio-in embeddings
179
+ final_audio_in_discrete_codes_mask
180
+ Mask for audio-in discrete tokens
181
+ final_audio_out_mask
182
+ Mask for audio-out embeddings
183
+
184
+ Explanation:
185
+ each audio has variable length embeddings, with length specified by
186
+ - audio_features_length
187
+ - audio_in_ids_start
188
+ - audio_out_ids_start
189
+
190
+ Task:
191
+ - fill each <|AUDIO|> with audio embeddings (it can be the combination of embeddings extracted by WhisperEncoder and embeddings from audio codebooks)
192
+ - fill each <|AUDIO_OUT|> with the audio-out embeddings
193
+
194
+ Example:
195
+ <|AUDIO_OUT|>: X (5 tokens), Y (3 tokens)
196
+ <|AUDIO|>: Z (8 tokens)
197
+
198
+ X, Y are in the same sequence (in-context voice-clone). Z is in a different sequence (audio understanding).
199
+ if right padding
200
+ input_ids: [
201
+ a b c d e f X g h i j k Y l m
202
+ o p q r Z s t u v _ _ _ _ _ _
203
+ ]
204
+ input_ids should be: [
205
+ a b c d e f X X X X X g h i j k Y Y Y l m
206
+ o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
207
+ ]
208
+ labels should be: [
209
+ a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
210
+ o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
211
+ ]
212
+ elif left padding
213
+ input_ids: [
214
+ a b c d e f X g h i j k Y l m
215
+ _ _ _ _ _ _ o p q r Z s t u v
216
+ ]
217
+ input_ids should be: [
218
+ a b c d e f X X X X X g h i j k Y Y Y l m
219
+ _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
220
+ ]
221
+ labels should be: [
222
+ a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
223
+ _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
224
+ ]
225
+
226
+ """
227
+ if label_ids is None:
228
+ skip_labels = True
229
+ else:
230
+ skip_labels = False
231
+ if audio_features_embed is not None and audio_features_embed.shape[0] == 0:
232
+ audio_features_embed = None
233
+ if audio_in_embed is not None and audio_in_embed.shape[0] == 0:
234
+ audio_in_embed = None
235
+ if audio_out_embed is not None and audio_out_embed.shape[0] == 0:
236
+ audio_out_embed = None
237
+
238
+ batch_size, sequence_length, embed_dim = inputs_embeds.shape
239
+
240
+ target_device = inputs_embeds.device
241
+ if left_padding is None:
242
+ left_padding = torch.any(attention_mask[:, 0] == 0)
243
+
244
+ audio_in_token_mask = input_ids == audio_in_token_idx
245
+ audio_out_token_mask = input_ids == audio_out_token_idx
246
+ text_token_mask = (input_ids != audio_in_token_idx) & (input_ids != audio_out_token_idx)
247
+
248
+ # 1. Calculate the number of tokens for each placeholder (like [<|AUDIO|>, <|AUDIO_OUT|>]).
249
+ token_placeholder_num = torch.ones_like(input_ids)
250
+
251
+ if audio_features_embed is not None:
252
+ num_audios, max_audio_tokens, _ = audio_features_embed.shape
253
+ audio_in_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to(
254
+ audio_features_length.device
255
+ ) < audio_features_length.unsqueeze(1)
256
+ masked_audio_in_features = audio_features_embed[audio_in_features_mask].view(-1, embed_dim)
257
+ token_placeholder_num[audio_in_token_mask] = audio_features_length.long()
258
+
259
+ if audio_in_embed is not None:
260
+ audio_in_codes_length = torch.concat(
261
+ [
262
+ audio_in_ids_start[1:] - audio_in_ids_start[:-1],
263
+ torch.tensor(
264
+ [audio_in_embed.shape[0] - audio_in_ids_start[-1]],
265
+ device=audio_in_ids_start.device,
266
+ dtype=torch.long,
267
+ ),
268
+ ],
269
+ dim=0,
270
+ )
271
+ if audio_features_embed is not None:
272
+ token_placeholder_num[audio_in_token_mask] += audio_in_codes_length.long()
273
+ else:
274
+ token_placeholder_num[audio_in_token_mask] = audio_in_codes_length.long()
275
+
276
+ if audio_out_embed is not None:
277
+ audio_out_codes_length = torch.concat(
278
+ [
279
+ audio_out_ids_start[1:] - audio_out_ids_start[:-1],
280
+ torch.tensor(
281
+ [audio_out_embed.shape[0] - audio_out_ids_start[-1]],
282
+ device=audio_out_ids_start.device,
283
+ dtype=torch.long,
284
+ ),
285
+ ],
286
+ dim=0,
287
+ )
288
+ token_placeholder_num[audio_out_token_mask] = audio_out_codes_length.long()
289
+
290
+ new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
291
+ max_token_num = _ceil_to_nearest(token_placeholder_num.sum(-1).max(), round_to)
292
+ nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1]
293
+
294
+ if left_padding:
295
+ new_token_positions += nb_audio_pad[:, None] # offset for left padding
296
+
297
+ # 2. Create the full embedding, already padded to the maximum position
298
+ final_embedding = torch.zeros(
299
+ (batch_size, max_token_num, embed_dim),
300
+ dtype=inputs_embeds.dtype,
301
+ device=inputs_embeds.device,
302
+ )
303
+ final_attention_mask = torch.zeros(
304
+ (batch_size, max_token_num),
305
+ dtype=attention_mask.dtype,
306
+ device=inputs_embeds.device,
307
+ )
308
+ final_input_ids = torch.full(
309
+ (batch_size, max_token_num),
310
+ pad_token_id,
311
+ dtype=input_ids.dtype,
312
+ device=inputs_embeds.device,
313
+ )
314
+ if skip_labels:
315
+ final_labels = None
316
+ else:
317
+ final_labels = torch.full(
318
+ (batch_size, max_token_num),
319
+ ignore_index,
320
+ dtype=label_ids.dtype,
321
+ device=inputs_embeds.device,
322
+ )
323
+
324
+ final_audio_in_mask = torch.full(
325
+ (batch_size, max_token_num),
326
+ False,
327
+ dtype=torch.bool,
328
+ device=inputs_embeds.device,
329
+ )
330
+ final_audio_in_discrete_codes_mask = torch.full(
331
+ (batch_size, max_token_num),
332
+ False,
333
+ dtype=torch.bool,
334
+ device=inputs_embeds.device,
335
+ )
336
+ final_audio_out_mask = torch.full(
337
+ (batch_size, max_token_num),
338
+ False,
339
+ dtype=torch.bool,
340
+ device=inputs_embeds.device,
341
+ )
342
+ # 3. Get the audio-in token positions and audio-out token positions
343
+ batch_id = torch.arange(batch_size, device=target_device).unsqueeze(1).expand(batch_size, sequence_length)
344
+ audio_in_batch_id = batch_id[audio_in_token_mask] # Shape (num_audio_in,)
345
+ audio_out_batch_id = batch_id[audio_out_token_mask] # Shape (num_audio_out,)
346
+ audio_features_token_ends = new_token_positions[audio_in_token_mask] # Shape (num_audio_in,)
347
+ audio_out_embed_ends = new_token_positions[audio_out_token_mask] # Shape (num_audio_out,)
348
+
349
+ if audio_in_embed is not None:
350
+ # Fill in the audio-in embeddings
351
+ seq_indices = (
352
+ torch.arange(max_token_num, device=target_device)
353
+ .unsqueeze(0)
354
+ .expand(audio_in_ids_start.shape[0], max_token_num)
355
+ )
356
+ audio_in_embed_token_starts = audio_features_token_ends - audio_in_codes_length + 1
357
+ batch_indices, col_indices = torch.where(
358
+ (seq_indices >= audio_in_embed_token_starts.unsqueeze(1))
359
+ & (seq_indices <= audio_features_token_ends.unsqueeze(1))
360
+ )
361
+ batch_indices = audio_in_batch_id[batch_indices]
362
+ final_embedding[batch_indices, col_indices] = audio_in_embed
363
+ final_input_ids[batch_indices, col_indices] = audio_in_token_idx
364
+ if not skip_labels:
365
+ final_labels[batch_indices, col_indices] = ignore_index
366
+ final_audio_in_mask[batch_indices, col_indices] = True
367
+ final_audio_in_discrete_codes_mask[batch_indices, col_indices] = True
368
+ audio_features_token_ends = audio_features_token_ends - audio_in_codes_length
369
+
370
+ if audio_features_embed is not None:
371
+ # Fill in the audio features
372
+ seq_indices = (
373
+ torch.arange(max_token_num, device=target_device)
374
+ .unsqueeze(0)
375
+ .expand(audio_features_embed.shape[0], max_token_num)
376
+ )
377
+ audio_features_token_starts = audio_features_token_ends - audio_features_length + 1
378
+ batch_indices, col_indices = torch.where(
379
+ (seq_indices >= audio_features_token_starts.unsqueeze(1))
380
+ & (seq_indices <= audio_features_token_ends.unsqueeze(1))
381
+ )
382
+ batch_indices = audio_in_batch_id[batch_indices]
383
+ final_embedding[batch_indices, col_indices] = masked_audio_in_features
384
+ final_input_ids[batch_indices, col_indices] = audio_in_token_idx
385
+ if not skip_labels:
386
+ final_labels[batch_indices, col_indices] = ignore_index
387
+ final_audio_in_mask[batch_indices, col_indices] = True
388
+
389
+ if audio_out_embed is not None:
390
+ # Fill in the audio-out embeddings
391
+ seq_indices = (
392
+ torch.arange(max_token_num, device=target_device)
393
+ .unsqueeze(0)
394
+ .expand(audio_out_ids_start.shape[0], max_token_num)
395
+ )
396
+ audio_out_embed_token_starts = audio_out_embed_ends - audio_out_codes_length + 1
397
+ batch_indices, col_indices = torch.where(
398
+ (seq_indices >= audio_out_embed_token_starts.unsqueeze(1))
399
+ & (seq_indices <= audio_out_embed_ends.unsqueeze(1))
400
+ )
401
+ batch_indices = audio_out_batch_id[batch_indices]
402
+ final_embedding[batch_indices, col_indices] = audio_out_embed
403
+ final_input_ids[batch_indices, col_indices] = audio_out_token_idx
404
+ if not skip_labels:
405
+ final_labels[batch_indices, col_indices] = ignore_index
406
+ final_audio_out_mask[batch_indices, col_indices] = True
407
+
408
+ # Fill in the original text embeddings and labels
409
+ batch_indices, non_audio_indices = torch.where(text_token_mask)
410
+ text_to_overwrite = new_token_positions[batch_indices, non_audio_indices]
411
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_audio_indices]
412
+ if not skip_labels:
413
+ final_labels[batch_indices, text_to_overwrite] = label_ids[batch_indices, non_audio_indices]
414
+ final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_audio_indices]
415
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_audio_indices]
416
+ final_attention_mask = final_attention_mask | final_audio_in_mask | final_audio_out_mask
417
+
418
+ # Trim the tensor if there are redundant padding tokens
419
+ if left_padding:
420
+ first_non_zero_loc = final_attention_mask.sum(0).nonzero()[0]
421
+ first_non_zero_loc = (first_non_zero_loc // round_to) * round_to
422
+ if first_non_zero_loc > 0:
423
+ final_attention_mask = final_attention_mask[:, first_non_zero_loc:]
424
+ final_embedding = final_embedding[:, first_non_zero_loc:]
425
+ if not skip_labels:
426
+ final_labels = final_labels[:, first_non_zero_loc:]
427
+ final_input_ids = final_input_ids[:, first_non_zero_loc:]
428
+ final_audio_in_mask = final_audio_in_mask[:, first_non_zero_loc:]
429
+ final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, first_non_zero_loc:]
430
+ final_audio_out_mask = final_audio_out_mask[:, first_non_zero_loc:]
431
+ else:
432
+ # We have done right padding, so we need to trim the mask
433
+ last_non_zero_loc = final_attention_mask.sum(0).nonzero()[-1] + 1
434
+ last_non_zero_loc = ((last_non_zero_loc + round_to - 1) // round_to) * round_to
435
+ if last_non_zero_loc < max_token_num:
436
+ final_attention_mask = final_attention_mask[:, :last_non_zero_loc]
437
+ final_embedding = final_embedding[:, :last_non_zero_loc]
438
+ if not skip_labels:
439
+ final_labels = final_labels[:, :last_non_zero_loc]
440
+ final_input_ids = final_input_ids[:, :last_non_zero_loc]
441
+ final_audio_in_mask = final_audio_in_mask[:, :last_non_zero_loc]
442
+ final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, :last_non_zero_loc]
443
+ final_audio_out_mask = final_audio_out_mask[:, :last_non_zero_loc]
444
+
445
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
446
+ return (
447
+ final_embedding,
448
+ final_attention_mask,
449
+ final_labels,
450
+ position_ids,
451
+ final_input_ids,
452
+ final_audio_in_mask,
453
+ final_audio_in_discrete_codes_mask,
454
+ final_audio_out_mask,
455
+ )
456
+
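The core indexing trick of the merge function above, shown in isolation as a simplified sketch (99 stands in for a hypothetical audio placeholder id): every placeholder token expands into a variable number of embeddings, and the end position of each original token in the merged sequence is obtained from a cumulative sum of the expanded lengths.

```python
import torch

input_ids = torch.tensor([[5, 7, 99, 8]])                 # 99 marks the audio placeholder
token_placeholder_num = torch.ones_like(input_ids)
token_placeholder_num[input_ids == 99] = 4                # this audio slot expands to 4 embeddings
new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
print(new_token_positions)                                # tensor([[0, 1, 5, 6]])
# Text tokens land at columns 0, 1 and 6 of the merged sequence; the audio embeddings
# occupy columns 2..5 (end position 5 minus expansion length 4, plus one).
```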
457
+
458
+ def is_deepspeed_ulysses_enabled():
459
+ """Check if sequence parallelism is enabled."""
460
+ if deepspeed_groups is None:
461
+ return False
462
+
463
+ return deepspeed_groups._get_sequence_parallel_world_size() > 1
464
+
465
+
466
+ def support_deepspeed_ulysses(module):
467
+ """A decorator around Pytorch module. It is needed for the module that needs access to sequence parallel info."""
468
+ module._sp_size = None
469
+ module._sp_rank = None
470
+ module._sp_group = None
471
+
472
+ @property
473
+ def sp_size(self):
474
+ if self._sp_size is None:
475
+ self._sp_size = 1
476
+ if is_deepspeed_ulysses_enabled():
477
+ self._sp_size = deepspeed_groups._get_sequence_parallel_group().size()
478
+ return self._sp_size
479
+
480
+ @property
481
+ def sp_rank(self):
482
+ if self._sp_rank is None:
483
+ self._sp_rank = 0
484
+ if is_deepspeed_ulysses_enabled():
485
+ self._sp_rank = deepspeed_groups._get_sequence_parallel_rank()
486
+ return self._sp_rank
487
+
488
+ @property
489
+ def sp_group(self):
490
+ if self._sp_group is None and is_deepspeed_ulysses_enabled():
491
+ self._sp_group = deepspeed_groups._get_sequence_parallel_group()
492
+ return self._sp_group
493
+
494
+ module.sp_size = sp_size
495
+ module.sp_rank = sp_rank
496
+ module.sp_group = sp_group
497
+
498
+ return module
499
+
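A hedged usage sketch of the decorator above: it is intended to be applied to an `nn.Module` class (the properties it attaches only take effect at the class level), after which every instance exposes the sequence-parallel size, rank, and group, falling back to a single rank when DeepSpeed Ulysses is not active.

```python
import torch
import torch.nn as nn
from higgs_audio.model.utils import support_deepspeed_ulysses  # assumed import path

@support_deepspeed_ulysses
class MyBlock(nn.Module):
    def forward(self, x):
        # Without sequence parallelism these default to a world of one rank.
        assert self.sp_size == 1 and self.sp_rank == 0
        return x

MyBlock()(torch.zeros(1))
```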
500
+
501
+ def deepspeed_ulysses_attention(seq_dim=1, head_dim=2):
502
+ """Perform all-to-all before and after the attention function."""
503
+
504
+ def attention_decorator(attn_func=None):
505
+ def wrapped(*args, **kwargs):
506
+ if is_deepspeed_ulysses_enabled():
507
+ sp_group = deepspeed_groups._get_sequence_parallel_group()
508
+ scatter_idx = head_dim # Scatter on num_heads dimension
509
+ gather_idx = seq_dim # Gather on seq_len dimension
510
+ batch_dim_idx = 0
511
+ args = list(args)
512
+ args[0] = _SeqAllToAll.apply(sp_group, args[0], scatter_idx, gather_idx, batch_dim_idx)
513
+ args[1] = _SeqAllToAll.apply(sp_group, args[1], scatter_idx, gather_idx, batch_dim_idx)
514
+ args[2] = _SeqAllToAll.apply(sp_group, args[2], scatter_idx, gather_idx, batch_dim_idx)
515
+ args = tuple(args)
516
+
517
+ attn_output = attn_func(*args, **kwargs)
518
+
519
+ if is_deepspeed_ulysses_enabled():
520
+ scatter_idx = seq_dim # Scatter back on seq_len dimension
521
+ gather_idx = head_dim # Gather on num_heads dimension
522
+ batch_dim_idx = 0
523
+ attn_output = _SeqAllToAll.apply(sp_group, attn_output, scatter_idx, gather_idx, batch_dim_idx)
524
+
525
+ return attn_output
526
+
527
+ return wrapped
528
+
529
+ return attention_decorator
530
+
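A hedged sketch of how the attention decorator above can be applied; with sequence parallelism disabled it is a transparent pass-through, and with it enabled the q/k/v tensors are redistributed across ranks before the call and the output is redistributed back afterwards. The `(bsz, seq_len, num_heads, head_dim)` layout is assumed to match `seq_dim=1, head_dim=2`.

```python
import torch
import torch.nn.functional as F
from higgs_audio.model.utils import deepspeed_ulysses_attention  # assumed import path

@deepspeed_ulysses_attention(seq_dim=1, head_dim=2)
def attention(q, k, v):
    # Expects (bsz, seq_len, num_heads, head_dim); SDPA wants heads before the sequence dim.
    out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
    return out.transpose(1, 2)

q = k = v = torch.randn(1, 8, 4, 16)
print(attention(q, k, v).shape)  # torch.Size([1, 8, 4, 16])
```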
531
+
532
+ def deepspeed_ulysses_rope(state_seq_dim=2, trig_seq_dim=1):
533
+ """Slice the corresponding cos and sin chunks for rope."""
534
+
535
+ def rope_decorator(rope_func=None):
536
+ def wrapped(*args, **kwargs):
537
+ if is_deepspeed_ulysses_enabled():
538
+ sp_rank = deepspeed_groups._get_sequence_parallel_rank()
539
+ args = list(args)
540
+ seq_chunk_size = args[0].size(state_seq_dim)
541
+ args[2] = torch.narrow(args[2], trig_seq_dim, sp_rank * seq_chunk_size, seq_chunk_size)
542
+ args[3] = torch.narrow(args[3], trig_seq_dim, sp_rank * seq_chunk_size, seq_chunk_size)
543
+ args = tuple(args)
544
+
545
+ return rope_func(*args, **kwargs)
546
+
547
+ return wrapped
548
+
549
+ return rope_decorator
550
+
551
+
552
+ def _gather_tensors(input_, group=None):
553
+ """Gather tensors and concatenate them along a dimension."""
554
+ input_ = input_.contiguous()
555
+ world_size = torch.distributed.get_world_size(group)
556
+ if world_size == 1:
557
+ return input_
558
+ tensor_shapes = [
559
+ torch.empty(len(input_.size()), dtype=torch.int64, device=input_.device) for _ in range(world_size)
560
+ ]
561
+ input_size = torch.tensor(input_.size(), dtype=torch.int64, device=input_.device)
562
+ torch.distributed.all_gather(tensor_shapes, input_size, group=group)
563
+ gathered_buffers = [
564
+ torch.empty(tensor_shapes[i].tolist(), dtype=input_.dtype, device=input_.device) for i in range(world_size)
565
+ ]
566
+ torch.distributed.all_gather(gathered_buffers, input_, group=group)
567
+ return gathered_buffers
568
+
569
+
570
+ def _scatter_tensors(input_, group=None):
571
+ """Scatter tensors."""
572
+ world_size = torch.distributed.get_world_size(group)
573
+ if world_size == 1:
574
+ return input_
575
+ rank = torch.distributed.get_rank(group)
576
+ return input_[rank]
577
+
578
+
579
+ class _GatherTensors(torch.autograd.Function):
580
+ """All gather tensors among the ranks."""
581
+
582
+ @staticmethod
583
+ def symbolic(graph, input_, group):
584
+ return _gather_tensors(input_, group)
585
+
586
+ @staticmethod
587
+ def forward(ctx, input_, group):
588
+ ctx.group = group
589
+ return torch.nested.as_nested_tensor(_gather_tensors(input_, group), layout=torch.jagged)
590
+
591
+ @staticmethod
592
+ def backward(ctx, grad_output):
593
+ return _scatter_tensors(grad_output, ctx.group), None
594
+
595
+
596
+ def all_gather_tensors(input_, size=None, dim=0, group=None):
597
+ if torch.distributed.get_world_size(group) == 1:
598
+ # no sequence parallelism
599
+ return input_
600
+ gathered_tensors = _GatherTensors.apply(input_, group)
601
+
602
+ if size:
603
+ split_gathered_tensors = []
604
+ for s, gathered_tensor in zip(size, gathered_tensors):
605
+ split_gathered_tensor = torch.split(gathered_tensor, s.tolist())
606
+ split_gathered_tensors.append(split_gathered_tensor)
607
+
608
+ gathered_tensors = [y for x in zip(*split_gathered_tensors) for y in x]
609
+
610
+ return torch.cat(gathered_tensors, dim).contiguous()
611
+
612
+
613
+ def get_sequence_data_parallel_world_size():
614
+ return torch.distributed.get_world_size()
615
+
616
+
617
+ def get_sequence_data_parallel_rank():
618
+ return torch.distributed.get_rank()
619
+
620
+
621
+ def get_sequence_data_parallel_group():
622
+ return torch.distributed.group.WORLD
623
+
624
+
625
+ if is_deepspeed_available():
626
+ deepspeed_groups._get_sequence_data_parallel_world_size = get_sequence_data_parallel_world_size
627
+ deepspeed_groups._get_sequence_data_parallel_rank = get_sequence_data_parallel_rank
628
+ deepspeed_groups._get_sequence_data_parallel_group = get_sequence_data_parallel_group
629
+
630
+
631
+ def _gather_tokens(input_, dim=0, group=None):
632
+ """Gather tensors and concatenate them along a dimension"""
633
+ input_ = input_.contiguous()
634
+ world_size = torch.distributed.get_world_size(group)
635
+ if world_size == 1:
636
+ return input_
637
+
638
+ gather_buffer = torch.empty(world_size * input_.numel(), dtype=input_.dtype, device=input_.device)
639
+ torch.distributed.all_gather_into_tensor(gather_buffer, input_, group=group)
640
+ if dim == 0:
641
+ shape = list(input_.size())
642
+ shape[0] = shape[0] * world_size
643
+ output = gather_buffer.view(shape)
644
+ else:
645
+ tensor_list = [
646
+ gather_buffer.narrow(0, input_.numel() * i, input_.numel()).view_as(input_) for i in range(world_size)
647
+ ]
648
+ # Note: torch.cat already creates a contiguous tensor.
649
+ output = torch.cat(tensor_list, dim=dim).contiguous()
650
+
651
+ return output
652
+
653
+
654
+ def _drop_tokens(input_, dim=0, group=None):
655
+ """Divide a tensor among the sequence parallel ranks"""
656
+ world_size = torch.distributed.get_world_size(group)
657
+ if world_size == 1:
658
+ return input_
659
+ this_rank = torch.distributed.get_rank(group)
660
+ assert input_.shape[dim] % world_size == 0, (
661
+ f"input dimension {dim} ({input_.shape[dim]}) is not divisible by sequence parallel world size ({world_size})"
662
+ )
663
+ chunk_size = input_.shape[dim] // world_size
664
+
665
+ return torch.narrow(input_, dim, this_rank * chunk_size, chunk_size)
666
+
667
+
668
+ class _DropTokens(torch.autograd.Function):
669
+ "Divide tokens equally among the sequence parallel ranks"
670
+
671
+ @staticmethod
672
+ def symbolic(graph, input_, dim, group, grad_scale):
673
+ return _drop_tokens(input_, dim, group)
674
+
675
+ @staticmethod
676
+ def forward(ctx, input_, dim, group, grad_scale):
677
+ ctx.dim = dim
678
+ ctx.group = group
679
+ ctx.grad_scale = grad_scale
680
+ return _drop_tokens(input_, dim, group)
681
+
682
+ @staticmethod
683
+ def backward(ctx, grad_output):
684
+ grad_input = _gather_tokens(grad_output, ctx.dim, ctx.group)
685
+ if ctx.grad_scale != 1:
686
+ grad_input /= ctx.grad_scale
687
+ return grad_input, None, None, None
688
+
689
+
690
+ class _GatherTokens(torch.autograd.Function):
691
+ "Gather tokens among the sequence parallel ranks"
692
+
693
+ @staticmethod
694
+ def symbolic(graph, input_, dim, group, grad_scale):
695
+ return _gather_tokens(input_, dim, group)
696
+
697
+ @staticmethod
698
+ def forward(ctx, input_, dim, group, grad_scale):
699
+ ctx.dim = dim
700
+ ctx.group = group
701
+ ctx.grad_scale = grad_scale
702
+ return _gather_tokens(input_, dim, group)
703
+
704
+ @staticmethod
705
+ def backward(ctx, grad_output):
706
+ grad_input = _drop_tokens(grad_output, ctx.dim, ctx.group)
707
+ if ctx.grad_scale != 1:
708
+ grad_input *= ctx.grad_scale
709
+ return grad_input, None, None, None
710
+
711
+
712
+ def drop_tokens(input_, dim=0, group=None, grad_scale=1):
713
+ if torch.distributed.get_world_size(group) == 1:
714
+ # no sequence parallelism
715
+ return input_
716
+ return _DropTokens.apply(input_, dim, group, grad_scale)
717
+
718
+
719
+ def gather_tokens(input_, dim=0, group=None, grad_scale=1):
720
+ if torch.distributed.get_world_size(group) == 1:
721
+ # no sequence parallelism
722
+ return input_
723
+ return _GatherTokens.apply(input_, dim, group, grad_scale)
724
+
725
+
726
+ def sequence_chunking_per_rank(sp_size, sp_rank, *args, dim=1):
727
+ """
728
+ Slice the inputs to create chunks for the given sequence parallel rank. This is used for context parallel training.
729
+
730
+ Args:
731
+ sp_size (`int`):
732
+ Sequence parallel size.
733
+ sp_rank (`int`):
734
+ Sequence parallel rank for the current process.
735
+ dim (`int`):
736
+ The dimension to slice
737
+ """
738
+ if sp_size == 1:
739
+ return args[0] if len(args) == 1 else args
740
+
741
+ seq_length = args[0].size(dim)
742
+ for arg in args[1:]:
743
+ assert arg.size(dim) == seq_length, (
744
+ f"arg={arg} ({arg.shape[dim]}) does not have the same size as args[0] ({seq_length}) in dimension {dim}"
745
+ )
746
+ assert seq_length % sp_size == 0, (
747
+ f"dimension {dim} ({args[0].shape[dim]}) is not divisible by sequence parallel world size ({sp_size})"
748
+ )
749
+
750
+ sub_seq_length = seq_length // sp_size
751
+ sub_seq_start = sp_rank * sub_seq_length
752
+
753
+ output = []
754
+ for ind in args:
755
+ ind = torch.narrow(ind, dim, sub_seq_start, sub_seq_length)
756
+ output.append(ind)
757
+
758
+ return tuple(output) if len(output) > 1 else output[0]
759
+
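A small example of the chunking helper above: each sequence-parallel rank keeps only its contiguous slice of the sequence dimension.

```python
import torch
from higgs_audio.model.utils import sequence_chunking_per_rank  # assumed import path

hidden = torch.arange(8).reshape(1, 8)                    # (batch, seq_len)
print(sequence_chunking_per_rank(2, 0, hidden, dim=1))    # tensor([[0, 1, 2, 3]])
print(sequence_chunking_per_rank(2, 1, hidden, dim=1))    # tensor([[4, 5, 6, 7]])
```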
760
+
761
+ @contextmanager
762
+ def disable_deepspeed_ulysses():
763
+ """Disable deepspeed ulysses (sequence parallelism) if it is enabled"""
764
+ if is_deepspeed_ulysses_enabled():
765
+ _old_get_sequence_parallel_world_size = deepspeed_groups._get_sequence_parallel_world_size
766
+
767
+ def _get_sequence_parallel_world_size():
768
+ return 1
769
+
770
+ deepspeed_groups._get_sequence_parallel_world_size = _get_sequence_parallel_world_size
771
+ try:
772
+ yield
773
+ finally:
774
+ deepspeed_groups._get_sequence_parallel_world_size = _old_get_sequence_parallel_world_size
775
+ else:
776
+ context = contextlib.nullcontext
777
+ with context():
778
+ yield
higgs_audio/serve/serve_engine.py ADDED
@@ -0,0 +1,474 @@
 
1
+ import asyncio
2
+ import base64
3
+ import torch
4
+ import numpy as np
5
+ from io import BytesIO
6
+ from dataclasses import dataclass, field
7
+ from typing import List, Optional, Union
8
+ from copy import deepcopy
9
+ from transformers import AutoTokenizer, AutoProcessor
10
+ from transformers.cache_utils import StaticCache
11
+ from transformers.generation.streamers import BaseStreamer
12
+ from transformers.generation.stopping_criteria import StoppingCriteria
13
+ from dataclasses import asdict
14
+ from loguru import logger
15
+ import threading
16
+ import librosa
17
+
18
+
19
+ from ..dataset.chatml_dataset import (
20
+ ChatMLSample,
21
+ ChatMLDatasetSample,
22
+ prepare_chatml_sample,
23
+ )
24
+ from ..model import HiggsAudioModel
25
+ from ..model.utils import revert_delay_pattern
26
+ from ..data_collator.higgs_audio_collator import HiggsAudioSampleCollator
27
+ from ..audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer
28
+
29
+
30
+ def normalize_chinese_punctuation(text):
31
+ """
32
+ Convert Chinese (full-width) punctuation marks to English (half-width) equivalents.
33
+ """
34
+ # Mapping of Chinese punctuation to English punctuation
35
+ chinese_to_english_punct = {
36
+ ",": ",", # comma
37
+ "。": ".", # period
38
+ ":": ":", # colon
39
+ ";": ";", # semicolon
40
+ "?": "?", # question mark
41
+ "!": "!", # exclamation mark
42
+ "(": "(", # left parenthesis
43
+ ")": ")", # right parenthesis
44
+ "【": "[", # left square bracket
45
+ "】": "]", # right square bracket
46
+ "《": "<", # left angle quote
47
+ "》": ">", # right angle quote
48
+ "“": '"', # left double quotation
49
+ "”": '"', # right double quotation
50
+ "‘": "'", # left single quotation
51
+ "’": "'", # right single quotation
52
+ "、": ",", # enumeration comma
53
+ "—": "-", # em dash
54
+ "…": "...", # ellipsis
55
+ "·": ".", # middle dot
56
+ "「": '"', # left corner bracket
57
+ "」": '"', # right corner bracket
58
+ "『": '"', # left double corner bracket
59
+ "』": '"', # right double corner bracket
60
+ }
61
+
62
+ # Replace each Chinese punctuation with its English counterpart
63
+ for zh_punct, en_punct in chinese_to_english_punct.items():
64
+ text = text.replace(zh_punct, en_punct)
65
+
66
+ return text
67
+
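A quick example of the punctuation normalizer above (assuming this module is importable as `higgs_audio.serve.serve_engine`):

```python
from higgs_audio.serve.serve_engine import normalize_chinese_punctuation

print(normalize_chinese_punctuation("你好,世界。真的吗?"))  # -> '你好,世界.真的吗?'
```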
68
+
69
+ @dataclass
70
+ class HiggsAudioStreamerDelta:
71
+ """Represents a chunk of generated content, either text or audio tokens."""
72
+
73
+ text: Optional[str] = None
74
+ text_tokens: Optional[torch.Tensor] = None
75
+ audio_tokens: Optional[torch.Tensor] = None
76
+ finish_reason: Optional[str] = None
77
+
78
+
79
+ class AsyncHiggsAudioStreamer(BaseStreamer):
80
+ """
81
+ Async streamer that handles both text and audio token generation from Higgs-Audio model.
82
+ Stores chunks in a queue to be consumed by downstream applications.
83
+
84
+ Parameters:
85
+ tokenizer (`AutoTokenizer`):
86
+ The tokenizer used to decode text tokens.
87
+ skip_prompt (`bool`, *optional*, defaults to `False`):
88
+ Whether to skip the prompt tokens in generation.
89
+ timeout (`float`, *optional*):
90
+ The timeout for the queue. If `None`, the queue will block indefinitely.
91
+ decode_kwargs (`dict`, *optional*):
92
+ Additional keyword arguments to pass to the tokenizer's `decode` method.
93
+
94
+ Examples:
95
+ ```python
96
+ >>> from transformers import AutoTokenizer
97
+ >>> from threading import Thread
98
+ >>> import asyncio
99
+
100
+ >>> tokenizer = AutoTokenizer.from_pretrained("path/to/higgs/tokenizer")
101
+ >>> model = HiggsAudioModel.from_pretrained("path/to/higgs/model")
102
+ >>> inputs = tokenizer(["Generate some text and audio:"], return_tensors="pt")
103
+
104
+ >>> async def main():
105
+ ... streamer = AsyncHiggsAudioStreamer(tokenizer)
106
+ ... generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
107
+ ... thread = Thread(target=model.generate, kwargs=generation_kwargs)
108
+ ... thread.start()
109
+ ...
110
+ ... async for delta in streamer:
111
+ ... if delta.text is not None:
112
+ ... print("Text:", delta.text)
113
+ ... if delta.audio_tokens is not None:
114
+ ... print("Audio tokens shape:", delta.audio_tokens.shape)
115
+ >>> asyncio.run(main())
116
+ ```
117
+ """
118
+
119
+ def __init__(
120
+ self,
121
+ tokenizer: "AutoTokenizer",
122
+ skip_prompt: bool = False,
123
+ timeout: Optional[float] = None,
124
+ audio_num_codebooks: int = 1,
125
+ **decode_kwargs,
126
+ ):
127
+ self.tokenizer = tokenizer
128
+ self.skip_prompt = skip_prompt
129
+ self.timeout = timeout
130
+ self.decode_kwargs = decode_kwargs
131
+ self.audio_num_codebooks = audio_num_codebooks
132
+
133
+ # Queue to store generated chunks
134
+ self.queue = asyncio.Queue()
135
+ self.stop_signal = None
136
+
137
+ # Get running event loop
138
+ self.loop = asyncio.get_running_loop()
139
+ self.has_asyncio_timeout = hasattr(asyncio, "timeout")
140
+
141
+ # State tracking
142
+ self.next_tokens_are_prompt = True
143
+
144
+ def put(self, value: torch.Tensor):
145
+ """
146
+ Receives tokens and processes them as either text or audio tokens.
147
+ For text tokens, decodes and caches them until complete words are formed.
148
+ For audio tokens, directly queues them.
149
+ """
150
+ if value.shape[0] > 1 and not self.next_tokens_are_prompt:
151
+ # This is likely audio tokens (shape: [audio_num_codebooks])
152
+ assert value.shape[0] == self.audio_num_codebooks, "Number of codebooks mismatch"
153
+ delta = HiggsAudioStreamerDelta(audio_tokens=value)
154
+ self.loop.call_soon_threadsafe(self.queue.put_nowait, delta)
155
+ return
156
+
157
+ # Skip prompt tokens if configured
158
+ if self.skip_prompt and self.next_tokens_are_prompt:
159
+ self.next_tokens_are_prompt = False
160
+ return
161
+
162
+ # Process as text tokens
163
+ if len(value.shape) > 1:
164
+ value = value[0]
165
+
166
+ text = self.tokenizer.decode(value, **self.decode_kwargs)
167
+ delta = HiggsAudioStreamerDelta(text=text, text_tokens=value)
168
+ self.loop.call_soon_threadsafe(self.queue.put_nowait, delta)
169
+
170
+ def end(self):
171
+ """Flushes any remaining text tokens and signals the end of generation."""
172
+ self.next_tokens_are_prompt = True
173
+ self.loop.call_soon_threadsafe(self.queue.put_nowait, self.stop_signal)
174
+
175
+ def __aiter__(self):
176
+ return self
177
+
178
+ async def __anext__(self):
179
+ try:
180
+ if self.has_asyncio_timeout:
181
+ async with asyncio.timeout(self.timeout):
182
+ value = await self.queue.get()
183
+ else:
184
+ value = await asyncio.wait_for(self.queue.get(), timeout=self.timeout)
185
+ except asyncio.TimeoutError:
186
+ raise TimeoutError()
187
+ else:
188
+ if value == self.stop_signal:
189
+ raise StopAsyncIteration()
190
+ else:
191
+ return value
192
+
193
+
194
+ class AsyncStoppingCriteria(StoppingCriteria):
195
+ """
196
+ Stopping criteria that checks for stop signal from a threading event.
197
+
198
+ Args:
199
+ stop_signal (threading.Event): Event that will receive stop signals
200
+ """
201
+
202
+ def __init__(self, stop_signal: threading.Event):
203
+ self.stop_signal = stop_signal
204
+
205
+ def __call__(self, input_ids, scores, **kwargs) -> bool:
206
+ if self.stop_signal.is_set():
207
+ logger.info(f"Stop signal received. Can be caused by client disconnection.")
208
+ return True
209
+ return False
210
+
211
+
212
+ @dataclass
213
+ class HiggsAudioResponse:
214
+ audio: Optional[np.ndarray] = None
215
+ generated_audio_tokens: Optional[np.ndarray] = None
216
+ sampling_rate: Optional[int] = None
217
+ generated_text: str = ""
218
+ generated_text_tokens: np.ndarray = field(default_factory=lambda: np.array([]))
219
+ usage: Optional[dict] = None
220
+
221
+
222
+ class HiggsAudioServeEngine:
223
+ def __init__(
224
+ self,
225
+ model_name_or_path: str,
226
+ audio_tokenizer_name_or_path: str,
227
+ tokenizer_name_or_path: Optional[str] = None,
228
+ device: str = "cuda",
229
+ torch_dtype: Union[torch.dtype, str] = "auto",
230
+ kv_cache_lengths: List[int] = [1024, 4096, 8192], # Multiple KV cache sizes
231
+ ):
232
+ """
233
+ Initialize the HiggsAudioServeEngine, a serving wrapper for the HiggsAudioModel.
234
+ The model, tokenizer, and audio tokenizer will be downloaded from the Hugging Face Hub if they are not local.
235
+
236
+ Args:
237
+ model_name_or_path (str):
238
+ The name or path of the model to load.
239
+ audio_tokenizer_name_or_path (str):
240
+ The name or path of the audio tokenizer to load.
241
+ tokenizer_name_or_path (str):
242
+ The name or path of the tokenizer to load.
243
+ device (str):
244
+ The device to use for the model.
245
+ kv_cache_lengths (List[int]):
246
+ The lengths of the KV caches to use for the model. Used for cuda graph capture when device is cuda.
247
+ torch_dtype (Union[torch.dtype, str]):
248
+ The dtype to use for the model.
249
+ """
250
+ self.device = device
251
+ self.model_name_or_path = model_name_or_path
252
+ self.torch_dtype = torch_dtype
253
+
254
+ # Initialize model and tokenizer
255
+ self.model = HiggsAudioModel.from_pretrained(model_name_or_path, torch_dtype=torch_dtype).to(device)
256
+ logger.info(f"Loaded model from {model_name_or_path}, dtype: {self.model.dtype}")
257
+
258
+ if tokenizer_name_or_path is None:
259
+ tokenizer_name_or_path = model_name_or_path
260
+ logger.info(f"Loading tokenizer from {tokenizer_name_or_path}")
261
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
262
+
263
+ logger.info(f"Initializing Higgs Audio Tokenizer")
264
+ self.audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer_name_or_path, device=device)
265
+
266
+ self.audio_num_codebooks = self.model.config.audio_num_codebooks
267
+ self.audio_codebook_size = self.model.config.audio_codebook_size
268
+ self.audio_tokenizer_tps = self.audio_tokenizer.tps
269
+ self.samples_per_token = int(self.audio_tokenizer.sampling_rate // self.audio_tokenizer_tps)
270
+ self.hamming_window_len = 2 * self.audio_num_codebooks * self.samples_per_token
271
+ # Set the audio special tokens
272
+ self.model.set_audio_special_tokens(self.tokenizer)
273
+
274
+ # Prepare KV caches for different lengths
275
+ cache_config = deepcopy(self.model.config.text_config)
276
+ cache_config.num_hidden_layers = self.model.config.text_config.num_hidden_layers
277
+ if self.model.config.audio_dual_ffn_layers:
278
+ cache_config.num_hidden_layers += len(self.model.config.audio_dual_ffn_layers)
279
+ # A list of KV caches for different lengths
280
+ self.kv_caches = {
281
+ length: StaticCache(
282
+ config=cache_config,
283
+ max_batch_size=1,
284
+ max_cache_len=length,
285
+ device=self.model.device,
286
+ dtype=self.model.dtype,
287
+ )
288
+ for length in sorted(kv_cache_lengths)
289
+ }
290
+
291
+ if self.model.config.encode_whisper_embed:
292
+ logger.info(f"Loading whisper processor")
293
+ whisper_processor = AutoProcessor.from_pretrained(
294
+ "openai/whisper-large-v3-turbo",
295
+ trust_remote=True,
296
+ device=self.device,
297
+ )
298
+ else:
299
+ whisper_processor = None
300
+
301
+ # Reuse collator to prepare inference samples
302
+ self.collator = HiggsAudioSampleCollator(
303
+ whisper_processor=whisper_processor,
304
+ encode_whisper_embed=self.model.config.encode_whisper_embed,
305
+ audio_in_token_id=self.model.config.audio_in_token_idx,
306
+ audio_out_token_id=self.model.config.audio_out_token_idx,
307
+ audio_stream_bos_id=self.model.config.audio_stream_bos_id,
308
+ audio_stream_eos_id=self.model.config.audio_stream_eos_id,
309
+ pad_token_id=self.model.config.pad_token_id,
310
+ return_audio_in_tokens=False,
311
+ use_delay_pattern=self.model.config.use_delay_pattern,
312
+ audio_num_codebooks=self.model.config.audio_num_codebooks,
313
+ round_to=1,
314
+ )
315
+
316
+ # Lock to prevent multiple generations from happening at the same time
317
+ self.generate_lock = threading.Lock()
318
+
319
+ # Capture CUDA graphs for each KV cache length
320
+ if device == "cuda":
321
+ logger.info(f"Capturing CUDA graphs for each KV cache length")
322
+ self.model.capture_model(self.kv_caches.values())
323
+
324
+ def _prepare_inputs(self, chat_ml_sample: ChatMLSample, force_audio_gen: bool = False):
325
+ input_tokens, _, audio_contents, _ = prepare_chatml_sample(
326
+ chat_ml_sample,
327
+ self.tokenizer,
328
+ )
329
+
330
+ postfix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
331
+ if force_audio_gen:
332
+ postfix += "<|audio_out_bos|>"
333
+ postfix = self.tokenizer.encode(postfix, add_special_tokens=False)
334
+ input_tokens.extend(postfix)
335
+
336
+ # Configure the audio inputs
337
+ audio_ids_l = []
338
+ for audio_content in audio_contents:
339
+ if audio_content.audio_url not in ["placeholder", ""]:
340
+ raw_audio, _ = librosa.load(audio_content.audio_url, sr=self.audio_tokenizer.sampling_rate)
341
+ elif audio_content.raw_audio is not None:
342
+ raw_audio, _ = librosa.load(
343
+ BytesIO(base64.b64decode(audio_content.raw_audio)),
344
+ sr=self.audio_tokenizer.sampling_rate,
345
+ )
346
+ else:
347
+ raw_audio = None
348
+
349
+ if raw_audio is not None:
350
+ audio_ids = self.audio_tokenizer.encode(raw_audio, self.audio_tokenizer.sampling_rate)
351
+ audio_ids_l.append(audio_ids.squeeze(0).cpu())
352
+
353
+ if len(audio_ids_l) > 0:
354
+ audio_ids_start = torch.tensor(
355
+ np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])),
356
+ dtype=torch.long,
357
+ device=self.device,
358
+ )[0:-1]
359
+ audio_ids_concat = torch.cat(audio_ids_l, dim=1)
360
+ else:
361
+ audio_ids_start = None
362
+ audio_ids_concat = None
363
+
364
+ sample = ChatMLDatasetSample(
365
+ input_ids=torch.LongTensor(input_tokens),
366
+ label_ids=None,
367
+ audio_ids_concat=audio_ids_concat,
368
+ audio_ids_start=audio_ids_start,
369
+ audio_waveforms_concat=None,
370
+ audio_waveforms_start=None,
371
+ audio_sample_rate=None,
372
+ audio_speaker_indices=None,
373
+ )
374
+ data = self.collator([sample])
375
+ inputs = asdict(data)
376
+ for k, v in inputs.items():
377
+ if isinstance(v, torch.Tensor):
378
+ inputs[k] = v.to(self.model.device)
379
+
380
+ return inputs
381
+
382
+ def _prepare_kv_caches(self):
383
+ for kv_cache in self.kv_caches.values():
384
+ kv_cache.reset()
385
+
386
+ def generate(
387
+ self,
388
+ chat_ml_sample: ChatMLSample,
389
+ max_new_tokens: int,
390
+ temperature: float = 0.7,
391
+ top_k: Optional[int] = None,
392
+ top_p: float = 0.95,
393
+ stop_strings: Optional[List[str]] = None,
394
+ force_audio_gen: bool = False,
395
+ ras_win_len: Optional[int] = None,
396
+ ras_win_max_num_repeat: int = 2,
397
+ ):
398
+ """
399
+ Generate audio from a chatml sample.
400
+ Args:
401
+ chat_ml_sample: A chatml sample.
402
+ max_new_tokens: The maximum number of new tokens to generate.
403
+ temperature: The temperature to use for the generation.
404
+ top_p: The top p to use for the generation.
405
+ Returns:
406
+ A dictionary with the following keys:
407
+ audio: The generated audio.
408
+ sampling_rate: The sampling rate of the generated audio.
409
+ """
410
+ # Default stop strings
411
+ if stop_strings is None:
412
+ stop_strings = ["<|end_of_text|>", "<|eot_id|>"]
413
+
414
+ with torch.no_grad(), self.generate_lock:
415
+ inputs = self._prepare_inputs(chat_ml_sample, force_audio_gen=force_audio_gen)
416
+ prompt_token_ids = inputs["input_ids"][0].cpu().numpy()
417
+
418
+ self._prepare_kv_caches()
419
+
420
+ outputs = self.model.generate(
421
+ **inputs,
422
+ max_new_tokens=max_new_tokens,
423
+ use_cache=True,
424
+ stop_strings=stop_strings,
425
+ tokenizer=self.tokenizer,
426
+ do_sample=False if temperature == 0.0 else True,
427
+ temperature=temperature,
428
+ top_k=top_k,
429
+ top_p=top_p,
430
+ past_key_values_buckets=self.kv_caches,
431
+ ras_win_len=ras_win_len,
432
+ ras_win_max_num_repeat=ras_win_max_num_repeat,
433
+ )
434
+
435
+ if len(outputs[1]) > 0:
436
+ wv_list = []
437
+ for output_audio in outputs[1]:
438
+ vq_code = revert_delay_pattern(output_audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1]
439
+ wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
440
+ wv_list.append(wv_numpy)
441
+ wv_numpy = np.concatenate(wv_list)
442
+ else:
443
+ wv_numpy = None
444
+
445
+ # We only support one request at a time now
446
+ generated_text_tokens = outputs[0][0].cpu().numpy()[len(prompt_token_ids) :]
447
+ generated_text = self.tokenizer.decode(generated_text_tokens)
448
+ generated_audio_tokens = outputs[1][0].cpu().numpy()
449
+ return HiggsAudioResponse(
450
+ audio=wv_numpy,
451
+ generated_audio_tokens=generated_audio_tokens,
452
+ sampling_rate=self.audio_tokenizer.sampling_rate,
453
+ generated_text=generated_text,
454
+ generated_text_tokens=generated_text_tokens,
455
+ usage={
456
+ "prompt_tokens": prompt_token_ids.shape[0],
457
+ "completion_tokens": generated_text_tokens.shape[0] + generated_audio_tokens.shape[1],
458
+ "total_tokens": (
459
+ prompt_token_ids.shape[0] + generated_text_tokens.shape[0] + generated_audio_tokens.shape[1]
460
+ ),
461
+ "cached_tokens": 0,
462
+ },
463
+ )
464
+
465
+ def text_normalize(self, text: str) -> str:
466
+ """
467
+ Normalize the text.
468
+ """
469
+ # Perform some basic normalization
470
+ text = normalize_chinese_punctuation(text)
471
+ # Handle parentheses
472
+ text = text.replace("(", " ")
473
+ text = text.replace(")", " ")
474
+ return text
higgs_audio/serve/utils.py ADDED
@@ -0,0 +1,254 @@
 
1
+ import uuid
2
+ import base64
3
+ import re
4
+ import regex
5
+ from typing import AsyncGenerator, Union
6
+ import io
7
+ from pydub import AudioSegment
8
+ import torch
9
+ import numpy as np
10
+ from functools import lru_cache
11
+
12
+ from ..audio_processing.higgs_audio_tokenizer import HiggsAudioTokenizer
13
+
14
+
15
+ def random_uuid() -> str:
16
+ return str(uuid.uuid4().hex)
17
+
18
+
19
+ async def async_generator_wrap(first_element, gen: AsyncGenerator):
20
+ """Wrap an async generator with the first element."""
21
+ yield first_element
22
+ async for item in gen:
23
+ yield item
24
+
25
+
26
+ @lru_cache(maxsize=50)
27
+ def encode_base64_content_from_file(file_path: str) -> str:
28
+ """Encode a content from a local file to base64 format."""
29
+ # Read the MP3 file as binary and encode it directly to Base64
30
+ with open(file_path, "rb") as audio_file:
31
+ audio_base64 = base64.b64encode(audio_file.read()).decode("utf-8")
32
+ return audio_base64
33
+
34
+
35
+ def pcm16_to_target_format(
36
+ np_audio: np.ndarray,
37
+ sample_rate: int,
38
+ bit_depth: int,
39
+ channels: int,
40
+ format: str,
41
+ target_rate: int,
42
+ ):
43
+ wav_audio = AudioSegment(
44
+ np_audio.tobytes(),
45
+ frame_rate=sample_rate,
46
+ sample_width=bit_depth // 8,
47
+ channels=channels,
48
+ )
49
+ if target_rate is not None and target_rate != sample_rate:
50
+ wav_audio = wav_audio.set_frame_rate(target_rate)
51
+
52
+ # Convert WAV to MP3
53
+ target_io = io.BytesIO()
54
+ wav_audio.export(target_io, format=format)
55
+ target_io.seek(0)
56
+
57
+ return target_io
58
+
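A minimal sketch of the PCM conversion helper above; it assumes 16-bit integer samples and relies on pydub (and ffmpeg) for the actual encoding.

```python
import numpy as np
from higgs_audio.serve.utils import pcm16_to_target_format  # assumed import path

pcm = np.zeros(24000, dtype=np.int16)   # one second of silence at 24 kHz
mp3_io = pcm16_to_target_format(
    pcm, sample_rate=24000, bit_depth=16, channels=1, format="mp3", target_rate=16000
)
mp3_bytes = mp3_io.read()               # MP3-encoded bytes, resampled to 16 kHz
```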
59
+
60
+ chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")
61
+
62
+
63
+ def contains_chinese(text: str):
64
+ return bool(chinese_char_pattern.search(text))
65
+
66
+
67
+ # remove blanks between Chinese characters
68
+ def replace_blank(text: str):
69
+ out_str = []
70
+ for i, c in enumerate(text):
71
+ if c == " ":
72
+ if 0 < i < len(text) - 1 and (text[i + 1].isascii() and text[i + 1] != " ") and (text[i - 1].isascii() and text[i - 1] != " "):
73
+ out_str.append(c)
74
+ else:
75
+ out_str.append(c)
76
+ return "".join(out_str)
77
+
78
+
79
+ def replace_corner_mark(text: str):
80
+ text = text.replace("²", "平方")
81
+ text = text.replace("³", "立方")
82
+ return text
83
+
84
+
85
+ # remove meaningless symbols
86
+ def remove_bracket(text: str):
87
+ text = text.replace("(", "").replace(")", "")
88
+ text = text.replace("【", "").replace("】", "")
89
+ text = text.replace("`", "").replace("`", "")
90
+ text = text.replace("——", " ")
91
+ return text
92
+
93
+
94
+ # Paragraph splitting logic:
95
+ # 1. Each sentence has max length token_max_n and min length token_min_n; merge the last sentence if its length is less than merge_len.
96
+ # 2. Calculate sentence length according to the language.
97
+ # 3. Split sentences according to punctuation.
98
+ def split_paragraph(
99
+ text: str,
100
+ tokenize,
101
+ lang="zh",
102
+ token_max_n=80,
103
+ token_min_n=60,
104
+ merge_len=20,
105
+ comma_split=False,
106
+ ):
107
+ def calc_utt_length(_text: str):
108
+ if lang == "zh":
109
+ return len(_text)
110
+ else:
111
+ return len(tokenize(_text))
112
+
113
+ def should_merge(_text: str):
114
+ if lang == "zh":
115
+ return len(_text) < merge_len
116
+ else:
117
+ return len(tokenize(_text)) < merge_len
118
+
119
+ if lang == "zh":
120
+ pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"]
121
+ else:
122
+ pounc = [".", "?", "!", ";", ":"]
123
+ if comma_split:
124
+ pounc.extend([",", ","])
125
+
126
+ if text[-1] not in pounc:
127
+ if lang == "zh":
128
+ text += "。"
129
+ else:
130
+ text += "."
131
+
132
+ st = 0
133
+ utts = []
134
+ for i, c in enumerate(text):
135
+ if c in pounc:
136
+ if len(text[st:i]) > 0:
137
+ utts.append(text[st:i] + c)
138
+ if i + 1 < len(text) and text[i + 1] in ['"', "”"]:
139
+ tmp = utts.pop(-1)
140
+ utts.append(tmp + text[i + 1])
141
+ st = i + 2
142
+ else:
143
+ st = i + 1
144
+
145
+ final_utts = []
146
+ cur_utt = ""
147
+ for utt in utts:
148
+ if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
149
+ final_utts.append(cur_utt)
150
+ cur_utt = ""
151
+ cur_utt = cur_utt + utt
152
+ if len(cur_utt) > 0:
153
+ if should_merge(cur_utt) and len(final_utts) != 0:
154
+ final_utts[-1] = final_utts[-1] + cur_utt
155
+ else:
156
+ final_utts.append(cur_utt)
157
+
158
+ return final_utts
159
+
160
+
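A hedged usage sketch for split_paragraph; the whitespace tokenizer, the sample text, and the thresholds are made up for illustration:

def whitespace_tokenize(s: str):
    return s.split()

chunks = split_paragraph(
    "First sentence. Second one! Is this the third? A trailing clause",
    tokenize=whitespace_tokenize,
    lang="en",
    token_max_n=8,
    token_min_n=4,
    merge_len=2,
)
# The input does not end with punctuation, so a "." is appended before splitting;
# chunks shorter than merge_len are folded into the previous chunk, and each chunk
# stays at roughly token_max_n tokens.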
161
+ def is_only_punctuation(text: str):
162
+ # Regular expression: Match strings that consist only of punctuation marks or are empty.
163
+ punctuation_pattern = r"^[\p{P}\p{S}]*$"
164
+ return bool(regex.fullmatch(punctuation_pattern, text))
165
+
166
+
167
+ # Spell out Arabic numerals as words
168
+ def spell_out_number(text: str, inflect_parser):
169
+ new_text = []
170
+ st = None
171
+ for i, c in enumerate(text):
172
+ if not c.isdigit():
173
+ if st is not None:
174
+ num_str = inflect_parser.number_to_words(text[st:i])
175
+ new_text.append(num_str)
176
+ st = None
177
+ new_text.append(c)
178
+ else:
179
+ if st is None:
180
+ st = i
181
+ if st is not None and st < len(text):
182
+ num_str = inflect_parser.number_to_words(text[st:])
183
+ new_text.append(num_str)
184
+ return "".join(new_text)
185
+
186
+
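An illustrative call for spell_out_number; it assumes the inflect package (not pinned in requirements.txt here) provides the number_to_words parser:

import inflect

parser = inflect.engine()
spell_out_number("I have 3 apples and 25 pears", parser)
# -> "I have three apples and twenty-five pears"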
187
+ def remove_emoji(text: str):
188
+ # Pattern to match emojis and their modifiers
189
+ # - Standard emoji range
190
+ # - Zero-width joiners (U+200D)
191
+ # - Variation selectors (U+FE0F, U+FE0E)
192
+ # - Skin tone modifiers (U+1F3FB to U+1F3FF)
193
+ emoji_pattern = re.compile(
194
+ r"["
195
+ r"\U00010000-\U0010FFFF" # Standard emoji range
196
+ r"\u200D" # Zero-width joiner
197
+ r"\uFE0F\uFE0E" # Variation selectors
198
+ r"\U0001F3FB-\U0001F3FF" # Skin tone modifiers
199
+ r"]+",
200
+ flags=re.UNICODE,
201
+ )
202
+ return emoji_pattern.sub(r"", text)
203
+
204
+
205
+ def remove_repeated_punctuations(text, punctuations):
206
+ if len(punctuations) == 0:
207
+ return text
208
+ pattern = f"[{re.escape(''.join(punctuations))}]" # Create regex pattern for given punctuations
209
+ return re.sub(rf"({pattern})\1+", r"\1", text)
210
+
211
+
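A quick illustration of remove_repeated_punctuations; the backreference in the pattern collapses any run of a listed mark into a single occurrence:

remove_repeated_punctuations("Wait... what?!!! Seriously???", [".", "!", "?"])
# -> "Wait. what?! Seriously?"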
212
+ def full_to_half_width(text: str) -> str:
213
+ """Convert full-width punctuation to half-width in a given string."""
214
+ full_width = "！＂＃＄％＆＇（）＊＋，－．／：；＜＝＞？＠［＼］＾＿｀｛｜｝～"
215
+ half_width = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
216
+ trans_table = str.maketrans(full_width, half_width)
217
+ return text.translate(trans_table)
218
+
219
+
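And a one-line illustration of full_to_half_width on mixed-width input:

full_to_half_width("你好！（测试）？")
# -> "你好!(测试)?"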
220
+ def split_interleaved_delayed_audios(
221
+ audio_data: Union[list[list[int]], torch.Tensor],
222
+ audio_tokenizer: HiggsAudioTokenizer,
223
+ audio_stream_eos_id: int,
224
+ ) -> list[Union[list[list[int]], torch.Tensor]]:
225
+ separator = [audio_stream_eos_id] * audio_tokenizer.num_codebooks
226
+
227
+ # Convert the separator to a tensor if audio_data is a torch tensor
228
+ if isinstance(audio_data, torch.Tensor):
229
+ audio_data = audio_data.transpose(1, 0)
230
+ separator = torch.tensor(separator)
231
+ # Find the indices where the rows equal the separator
232
+ split_indices = torch.where(torch.all(audio_data == separator, dim=1))[0]
233
+ start = 0
234
+ groups = []
235
+ for idx in split_indices:
236
+ groups.append(audio_data[start:idx].transpose(1, 0))
237
+ start = idx + 1
238
+ if start < len(audio_data):
239
+ groups.append(audio_data[start:].transpose(1, 0))
240
+ else:
241
+ groups = []
242
+ current = []
243
+ for row in audio_data:
244
+ current.append(row)
245
+
246
+ if row == separator:
247
+ groups.append(current)
248
+ current = []
249
+
250
+ # Don't forget the last group if there's no trailing separator
251
+ if current:
252
+ groups.append(current)
253
+
254
+ return groups
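A sketch of the tensor branch of split_interleaved_delayed_audios. Only num_codebooks is read from the tokenizer, so a stand-in namespace is enough here; the EOS id and code values are invented for illustration:

from types import SimpleNamespace
import torch

eos = 1024  # assumed audio_stream_eos_id, not necessarily the model's real value
codes = torch.tensor([[1, 2, eos, 3, 4],
                      [5, 6, eos, 7, 8]])  # (num_codebooks=2, frames=5)
groups = split_interleaved_delayed_audios(codes, SimpleNamespace(num_codebooks=2), eos)
# -> two tensors of shape (2, 2): the columns before and after the EOS column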
pyproject.toml ADDED
@@ -0,0 +1,100 @@
1
+ [build-system]
2
+ requires = ["setuptools"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [tool.ruff]
6
+ line-length = 119
7
+ target-version = "py310"
8
+ indent-width = 4
9
+ exclude = [
10
+ ".bzr",
11
+ ".direnv",
12
+ ".eggs",
13
+ ".git",
14
+ ".git-rewrite",
15
+ ".hg",
16
+ ".ipynb_checkpoints",
17
+ ".mypy_cache",
18
+ ".nox",
19
+ ".pants.d",
20
+ ".pyenv",
21
+ ".pytest_cache",
22
+ ".pytype",
23
+ ".ruff_cache",
24
+ ".svn",
25
+ ".tox",
26
+ ".venv",
27
+ ".vscode",
28
+ "__pypackages__",
29
+ "_build",
30
+ "buck-out",
31
+ "build",
32
+ "dist",
33
+ "node_modules",
34
+ "site-packages",
35
+ "venv",
36
+ "external",
37
+ "third_party",
38
+ ]
39
+
40
+ [tool.ruff.lint]
41
+ preview = true
42
+ ignore-init-module-imports = true
43
+ extend-select = [
44
+ "B009", # static getattr
45
+ "B010", # static setattr
46
+ "CPY", # Copyright
47
+ "E", # PEP8 errors
48
+ "F", # PEP8 formatting
49
+ "I", # Import sorting
50
+ "TID251", # Banned API
51
+ "UP", # Pyupgrade
52
+ "W", # PEP8 warnings
53
+ ]
54
+ ignore = [
55
+ "E501", # Line length (handled by ruff-format)
56
+ "E741", # Ambiguous variable name
57
+ "W605", # Invalid escape sequence
58
+ "UP007", # X | Y type annotations
59
+ ]
60
+
61
+ [tool.ruff.lint.per-file-ignores]
62
+ "__init__.py" = [
63
+ "F401", # Ignore seemingly unused imports (they're meant for re-export)
64
+ ]
65
+
66
+ [tool.ruff.lint.isort]
67
+ lines-after-imports = 2
68
+ known-first-party = ["character_tuning"]
69
+
70
+ [tool.ruff.format]
71
+ # Like Black, use double quotes for strings.
72
+ quote-style = "double"
73
+
74
+ # Like Black, indent with spaces, rather than tabs.
75
+ indent-style = "space"
76
+
77
+ # Like Black, respect magic trailing commas.
78
+ skip-magic-trailing-comma = false
79
+
80
+ # Like Black, automatically detect the appropriate line ending.
81
+ line-ending = "auto"
82
+
83
+ # Enable auto-formatting of code examples in docstrings. Markdown,
84
+ # reStructuredText code/literal blocks and doctests are all supported.
85
+ #
86
+ # This is currently disabled by default, but it is planned for this
87
+ # to be opt-out in the future.
88
+ docstring-code-format = false
89
+
90
+ # Set the line length limit used when formatting code snippets in
91
+ # docstrings.
92
+ #
93
+ # This only has an effect when the `docstring-code-format` setting is
94
+ # enabled.
95
+ docstring-code-line-length = "dynamic"
96
+
97
+ [tool.ruff.lint.flake8-tidy-imports.banned-api]
98
+ "os.getenv".msg = "Use os.environ instead"
99
+ "os.putenv".msg = "Use os.environ instead"
100
+ "os.unsetenv".msg = "Use os.environ instead"
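The banned-api table above makes ruff's TID251 rule reject the os.getenv/os.putenv/os.unsetenv helpers. As a rough illustration of what the rule steers code toward (the variable name is arbitrary):

import os

token = os.getenv("MY_TOKEN")        # flagged by TID251: "Use os.environ instead"
token = os.environ.get("MY_TOKEN")   # accepted form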
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ descript-audio-codec
2
+ torch==2.5.1
3
+ torchaudio==2.5.1
4
+ transformers>=4.45.1,<4.47.0
5
+ librosa
6
+ dacite
7
+ boto3==1.35.36
8
+ s3fs
9
+ json_repair
10
+ pandas
11
+ pydantic
12
+ vector_quantize_pytorch
13
+ loguru
14
+ pydub
15
+ ruff==0.12.2
16
+ omegaconf
17
+ click
theme.json ADDED
@@ -0,0 +1,285 @@
1
+ {
2
+ "theme": {
3
+ "_font": [
4
+ {"__gradio_font__": true, "name": "Arial", "class": "font"},
5
+ {"__gradio_font__": true, "name": "sans-serif", "class": "font"}
6
+ ],
7
+ "_font_mono": [
8
+ {"__gradio_font__": true, "name": "Courier New", "class": "font"},
9
+ {"__gradio_font__": true, "name": "monospace", "class": "font"}
10
+ ],
11
+ "_stylesheets": [],
12
+ "background_fill_primary": "black",
13
+ "background_fill_primary_dark": "black",
14
+ "background_fill_secondary": "#1a1a1a",
15
+ "background_fill_secondary_dark": "#1a1a1a",
16
+ "block_background_fill": "#1a1a1a",
17
+ "block_background_fill_dark": "#1a1a1a",
18
+ "block_border_color": "#333333",
19
+ "block_border_color_dark": "#333333",
20
+ "block_border_width": "1px",
21
+ "block_info_text_color": "#cccccc",
22
+ "block_info_text_color_dark": "#cccccc",
23
+ "block_info_text_size": "*text_sm",
24
+ "block_info_text_weight": "600",
25
+ "block_label_background_fill": "#333333",
26
+ "block_label_background_fill_dark": "#333333",
27
+ "block_label_border_color": "#444444",
28
+ "block_label_border_color_dark": "#444444",
29
+ "block_label_border_width": "1px",
30
+ "block_label_margin": "*spacing_md",
31
+ "block_label_padding": "*spacing_sm *spacing_md",
32
+ "block_label_radius": "*radius_md",
33
+ "block_label_right_radius": "0 calc(*radius_lg - 1px) 0 calc(*radius_lg - 1px)",
34
+ "block_label_text_color": "yellow",
35
+ "block_label_text_color_dark": "yellow",
36
+ "block_label_text_size": "*text_md",
37
+ "block_label_text_weight": "700",
38
+ "block_padding": "*spacing_xl calc(*spacing_xl + 2px)",
39
+ "block_radius": "*radius_lg",
40
+ "block_shadow": "none",
41
+ "block_title_background_fill": "#333333",
42
+ "block_title_border_color": "none",
43
+ "block_title_border_width": "0px",
44
+ "block_title_padding": "*block_label_padding",
45
+ "block_title_radius": "*block_label_radius",
46
+ "block_title_text_color": "yellow",
47
+ "block_title_text_color_dark": "yellow",
48
+ "block_title_text_size": "*text_md",
49
+ "block_title_text_weight": "700",
50
+ "body_background_fill": "black",
51
+ "body_background_fill_dark": "black",
52
+ "body_text_color": "white",
53
+ "body_text_color_dark": "white",
54
+ "body_text_color_subdued": "#aaaaaa",
55
+ "body_text_color_subdued_dark": "#aaaaaa",
56
+ "body_text_size": "*text_md",
57
+ "body_text_weight": "600",
58
+ "border_color_accent": "yellow",
59
+ "border_color_accent_dark": "yellow",
60
+ "border_color_primary": "#333333",
61
+ "border_color_primary_dark": "#333333",
62
+ "button_border_width": "*input_border_width",
63
+ "button_border_width_dark": "*input_border_width",
64
+ "button_cancel_background_fill": "#333333",
65
+ "button_cancel_background_fill_dark": "#333333",
66
+ "button_cancel_background_fill_hover": "#444444",
67
+ "button_cancel_background_fill_hover_dark": "#444444",
68
+ "button_cancel_border_color": "#555555",
69
+ "button_cancel_border_color_dark": "#555555",
70
+ "button_cancel_border_color_hover": "#666666",
71
+ "button_cancel_border_color_hover_dark": "#666666",
72
+ "button_cancel_text_color": "white",
73
+ "button_cancel_text_color_dark": "white",
74
+ "button_cancel_text_color_hover": "white",
75
+ "button_cancel_text_color_hover_dark": "white",
76
+ "button_large_padding": "*spacing_lg calc(2 * *spacing_lg)",
77
+ "button_large_radius": "*radius_lg",
78
+ "button_large_text_size": "*text_lg",
79
+ "button_large_text_weight": "700",
80
+ "button_primary_background_fill": "yellow",
81
+ "button_primary_background_fill_dark": "yellow",
82
+ "button_primary_background_fill_hover": "#ffff80",
83
+ "button_primary_background_fill_hover_dark": "#ffff80",
84
+ "button_primary_border_color": "yellow",
85
+ "button_primary_border_color_dark": "yellow",
86
+ "button_primary_border_color_hover": "#ffff80",
87
+ "button_primary_border_color_hover_dark": "#ffff80",
88
+ "button_primary_text_color": "black",
89
+ "button_primary_text_color_dark": "black",
90
+ "button_primary_text_color_hover": "black",
91
+ "button_primary_text_color_hover_dark": "black",
92
+ "button_secondary_background_fill": "#333333",
93
+ "button_secondary_background_fill_dark": "#333333",
94
+ "button_secondary_background_fill_hover": "#444444",
95
+ "button_secondary_background_fill_hover_dark": "#444444",
96
+ "button_secondary_border_color": "#555555",
97
+ "button_secondary_border_color_dark": "#555555",
98
+ "button_secondary_border_color_hover": "#666666",
99
+ "button_secondary_border_color_hover_dark": "#666666",
100
+ "button_secondary_text_color": "white",
101
+ "button_secondary_text_color_dark": "white",
102
+ "button_secondary_text_color_hover": "white",
103
+ "button_secondary_text_color_hover_dark": "white",
104
+ "button_shadow": "*shadow_drop_lg",
105
+ "button_shadow_active": "*shadow_inset",
106
+ "button_shadow_hover": "*shadow_drop_lg",
107
+ "button_small_padding": "*spacing_sm calc(2 * *spacing_sm)",
108
+ "button_small_radius": "*radius_lg",
109
+ "button_small_text_size": "*text_md",
110
+ "button_small_text_weight": "700",
111
+ "button_transition": "background-color 0.2s ease",
112
+ "checkbox_background_color": "#1a1a1a",
113
+ "checkbox_background_color_dark": "#1a1a1a",
114
+ "checkbox_background_color_focus": "#1a1a1a",
115
+ "checkbox_background_color_focus_dark": "#1a1a1a",
116
+ "checkbox_background_color_hover": "#1a1a1a",
117
+ "checkbox_background_color_hover_dark": "#1a1a1a",
118
+ "checkbox_background_color_selected": "yellow",
119
+ "checkbox_background_color_selected_dark": "yellow",
120
+ "checkbox_border_color": "#333333",
121
+ "checkbox_border_color_dark": "#333333",
122
+ "checkbox_border_color_focus": "yellow",
123
+ "checkbox_border_color_focus_dark": "yellow",
124
+ "checkbox_border_color_hover": "#444444",
125
+ "checkbox_border_color_hover_dark": "#444444",
126
+ "checkbox_border_color_selected": "yellow",
127
+ "checkbox_border_color_selected_dark": "yellow",
128
+ "checkbox_border_radius": "*radius_sm",
129
+ "checkbox_border_width": "1px",
130
+ "checkbox_border_width_dark": "*input_border_width",
131
+ "checkbox_check": "url(\"data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='black' xmlns='http://www.w3.org/2000/svg'%3e%3cpath d='M12.207 4.793a1 1 0 010 1.414l-5 5a1 1 0 01-1.414 0l-2-2a1 1 0 011.414-1.414L6.5 9.086l4.293-4.293a1 1 0 011.414 0z'/%3e%3c/svg%3e\")",
132
+ "checkbox_label_background_fill": "#333333",
133
+ "checkbox_label_background_fill_dark": "#333333",
134
+ "checkbox_label_background_fill_hover": "#444444",
135
+ "checkbox_label_background_fill_hover_dark": "#444444",
136
+ "checkbox_label_background_fill_selected": "yellow",
137
+ "checkbox_label_background_fill_selected_dark": "yellow",
138
+ "checkbox_label_border_color": "#555555",
139
+ "checkbox_label_border_color_dark": "#555555",
140
+ "checkbox_label_border_color_hover": "#666666",
141
+ "checkbox_label_border_color_hover_dark": "#666666",
142
+ "checkbox_label_border_width": "*input_border_width",
143
+ "checkbox_label_border_width_dark": "*input_border_width",
144
+ "checkbox_label_gap": "*spacing_lg",
145
+ "checkbox_label_padding": "*spacing_md calc(2 * *spacing_md)",
146
+ "checkbox_label_shadow": "*shadow_drop_lg",
147
+ "checkbox_label_text_color": "white",
148
+ "checkbox_label_text_color_dark": "white",
149
+ "checkbox_label_text_color_selected": "black",
150
+ "checkbox_label_text_color_selected_dark": "black",
151
+ "checkbox_label_text_size": "*text_md",
152
+ "checkbox_label_text_weight": "700",
153
+ "checkbox_shadow": "none",
154
+ "color_accent": "yellow",
155
+ "color_accent_soft": "#333300",
156
+ "color_accent_soft_dark": "#333300",
157
+ "container_radius": "*radius_lg",
158
+ "embed_radius": "*radius_lg",
159
+ "error_background_fill": "#330000",
160
+ "error_background_fill_dark": "#330000",
161
+ "error_border_color": "#660000",
162
+ "error_border_color_dark": "#660000",
163
+ "error_border_width": "1px",
164
+ "error_text_color": "#ff6666",
165
+ "error_text_color_dark": "#ff6666",
166
+ "font": "'Arial', 'sans-serif'",
167
+ "font_mono": "'Courier New', 'monospace'",
168
+ "form_gap_width": "0px",
169
+ "input_background_fill": "#1a1a1a",
170
+ "input_background_fill_dark": "#1a1a1a",
171
+ "input_background_fill_focus": "#333333",
172
+ "input_background_fill_focus_dark": "#333333",
173
+ "input_background_fill_hover": "#1a1a1a",
174
+ "input_background_fill_hover_dark": "#1a1a1a",
175
+ "input_border_color": "#333333",
176
+ "input_border_color_dark": "#333333",
177
+ "input_border_color_focus": "yellow",
178
+ "input_border_color_focus_dark": "yellow",
179
+ "input_border_color_hover": "#444444",
180
+ "input_border_color_hover_dark": "#444444",
181
+ "input_border_width": "1px",
182
+ "input_padding": "*spacing_xl",
183
+ "input_placeholder_color": "#666666",
184
+ "input_placeholder_color_dark": "#666666",
185
+ "input_radius": "*radius_lg",
186
+ "input_shadow": "none",
187
+ "input_shadow_focus": "0 0 0 2px rgba(255,255,0,0.2)",
188
+ "input_text_size": "*text_md",
189
+ "input_text_weight": "600",
190
+ "layout_gap": "*spacing_xxl",
191
+ "link_text_color": "yellow",
192
+ "link_text_color_active": "#ffff80",
193
+ "link_text_color_active_dark": "#ffff80",
194
+ "link_text_color_dark": "yellow",
195
+ "link_text_color_hover": "#ffff80",
196
+ "link_text_color_hover_dark": "#ffff80",
197
+ "link_text_color_visited": "yellow",
198
+ "link_text_color_visited_dark": "yellow",
199
+ "loader_color": "yellow",
200
+ "neutral_100": "#1a1a1a",
201
+ "neutral_200": "#333333",
202
+ "neutral_300": "#444444",
203
+ "neutral_400": "#666666",
204
+ "neutral_50": "#0d0d0d",
205
+ "neutral_500": "#808080",
206
+ "neutral_600": "#999999",
207
+ "neutral_700": "#b3b3b3",
208
+ "neutral_800": "#cccccc",
209
+ "neutral_900": "#e6e6e6",
210
+ "neutral_950": "#f2f2f2",
211
+ "panel_background_fill": "#1a1a1a",
212
+ "panel_background_fill_dark": "#1a1a1a",
213
+ "panel_border_color": "#333333",
214
+ "panel_border_color_dark": "#333333",
215
+ "panel_border_width": "1px",
216
+ "primary_100": "#333300",
217
+ "primary_200": "#666600",
218
+ "primary_300": "#999900",
219
+ "primary_400": "#cccc00",
220
+ "primary_50": "#1a1a00",
221
+ "primary_500": "yellow",
222
+ "primary_600": "#ffff33",
223
+ "primary_700": "#ffff66",
224
+ "primary_800": "#ffff99",
225
+ "primary_900": "#ffffcc",
226
+ "primary_950": "#ffffe6",
227
+ "prose_header_text_weight": "700",
228
+ "prose_text_size": "*text_md",
229
+ "prose_text_weight": "600",
230
+ "radio_circle": "url(\"data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='black' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='8' cy='8' r='3'/%3e%3c/svg%3e\")",
231
+ "radius_lg": "6px",
232
+ "radius_md": "4px",
233
+ "radius_sm": "2px",
234
+ "radius_xl": "8px",
235
+ "radius_xs": "1px",
236
+ "radius_xxl": "12px",
237
+ "radius_xxs": "1px",
238
+ "secondary_100": "#333333",
239
+ "secondary_200": "#444444",
240
+ "secondary_300": "#555555",
241
+ "secondary_400": "#666666",
242
+ "secondary_50": "#1a1a1a",
243
+ "secondary_500": "#777777",
244
+ "secondary_600": "#888888",
245
+ "secondary_700": "#999999",
246
+ "secondary_800": "#aaaaaa",
247
+ "secondary_900": "#bbbbbb",
248
+ "secondary_950": "#cccccc",
249
+ "section_header_text_size": "*text_md",
250
+ "section_header_text_weight": "700",
251
+ "shadow_drop": "0 1px 4px 0 rgba(255,255,0,0.1)",
252
+ "shadow_drop_lg": "0 2px 5px 0 rgba(255,255,0,0.2)",
253
+ "shadow_inset": "rgba(255,255,0,0.1) 0px 2px 4px 0px inset",
254
+ "shadow_spread": "6px",
255
+ "shadow_spread_dark": "1px",
256
+ "slider_color": "yellow",
257
+ "slider_color_dark": "yellow",
258
+ "spacing_lg": "8px",
259
+ "spacing_md": "6px",
260
+ "spacing_sm": "4px",
261
+ "spacing_xl": "10px",
262
+ "spacing_xs": "2px",
263
+ "spacing_xxl": "16px",
264
+ "spacing_xxs": "1px",
265
+ "stat_background_fill": "#333300",
266
+ "stat_background_fill_dark": "#333300",
267
+ "table_border_color": "#333333",
268
+ "table_border_color_dark": "#333333",
269
+ "table_even_background_fill": "#1a1a1a",
270
+ "table_even_background_fill_dark": "#1a1a1a",
271
+ "table_odd_background_fill": "#0d0d0d",
272
+ "table_odd_background_fill_dark": "#0d0d0d",
273
+ "table_radius": "*radius_lg",
274
+ "table_row_focus": "#333300",
275
+ "table_row_focus_dark": "#333300",
276
+ "text_lg": "16px",
277
+ "text_md": "14px",
278
+ "text_sm": "12px",
279
+ "text_xl": "22px",
280
+ "text_xs": "10px",
281
+ "text_xxl": "26px",
282
+ "text_xxs": "9px"
283
+ },
284
+ "version": "1.0.0"
285
+ }
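The theme above is presumably consumed by app.py. As a hedged sketch, a dumped Gradio theme like this can typically be re-applied with gr.themes.Base.load (whether app.py does exactly this is an assumption):

import gradio as gr

theme = gr.themes.Base.load("theme.json")  # load the serialized theme file
with gr.Blocks(theme=theme) as demo:
    gr.Markdown("Higgs Audio demo")        # placeholder UI for illustration
demo.launch()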
voice_examples/config.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "belinda": {
3
+ "transcript": "Twas the night before my birthday. Hooray! It's almost here! It may not be a holiday, but it's the best day of the year.",
4
+ "audio_file": "belinda.wav"
5
+ },
6
+ "broom_salesman": {
7
+ "transcript": "I would imagine so. A wand with a dragon heartstring core is capable of dazzling magic. And the bond between you and your wand should only grow stronger. Do not be surprised at your new wand's ability to perceive your intentions - particularly in a moment of need.",
8
+ "audio_file": "broom_salesman.wav"
9
+ },
10
+ "chadwick": {
11
+ "transcript": "Oh dear, who left all this junk lying around? Whoops, there it goes! Mind your pointed little pink head, starfish man.",
12
+ "audio_file": "chadwick.wav"
13
+ },
14
+ "en_man": {
15
+ "transcript": "Maintaining your ability to learn translates into increased marketability, improved career options and higher salaries.",
16
+ "audio_file": "en_man.wav"
17
+ },
18
+ "en_woman": {
19
+ "transcript": "The device would work during the day as well, if you took steps to either block direct sunlight or point it away from the sun.",
20
+ "audio_file": "en_woman.wav"
21
+ },
22
+ "mabel": {
23
+ "transcript": "You do talk an awful lot about weather, did you know that? Sometimes I wonder if you're actually content to be a wizard or if you're secretly harbouring a desire to become a seer of the clouds.",
24
+ "audio_file": "mabel.wav"
25
+ },
26
+ "vex": {
27
+ "transcript": "Uhh, this is going to take forever. Why is everything so far?",
28
+ "audio_file": "vex.wav"
29
+ }
30
+ }
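A small sketch of how this registry might be consumed: read the JSON, then resolve each entry's audio_file relative to the voice_examples directory (the exact loading code in app.py may differ):

import json
from pathlib import Path

voice_dir = Path("voice_examples")
with open(voice_dir / "config.json") as f:
    voices = json.load(f)

for name, entry in voices.items():
    audio_path = voice_dir / entry["audio_file"]
    print(name, audio_path, entry["transcript"][:40])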