YoussefKejue commited on
Commit
195cb4d
·
verified ·
1 Parent(s): a4f5d1f

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
added_tokens.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</think>": 151668,
3
+ "</tool_call>": 151658,
4
+ "</tool_response>": 151666,
5
+ "<think>": 151667,
6
+ "<tool_call>": 151657,
7
+ "<tool_response>": 151665,
8
+ "<|audio|>": 151669,
9
+ "<|box_end|>": 151649,
10
+ "<|box_start|>": 151648,
11
+ "<|endoftext|>": 151643,
12
+ "<|file_sep|>": 151664,
13
+ "<|fim_middle|>": 151660,
14
+ "<|fim_pad|>": 151662,
15
+ "<|fim_prefix|>": 151659,
16
+ "<|fim_suffix|>": 151661,
17
+ "<|im_end|>": 151645,
18
+ "<|im_start|>": 151644,
19
+ "<|image_pad|>": 151655,
20
+ "<|object_ref_end|>": 151647,
21
+ "<|object_ref_start|>": 151646,
22
+ "<|quad_end|>": 151651,
23
+ "<|quad_start|>": 151650,
24
+ "<|repo_name|>": 151663,
25
+ "<|video_pad|>": 151656,
26
+ "<|vision_end|>": 151653,
27
+ "<|vision_pad|>": 151654,
28
+ "<|vision_start|>": 151652
29
+ }
config.json ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "UltravoxModel"
4
+ ],
5
+ "audio_latency_block_size": null,
6
+ "audio_model_id": "openai/whisper-large-v3-turbo",
7
+ "audio_model_lora_config": {
8
+ "lora_alpha": 8,
9
+ "r": 0,
10
+ "target_modules": [
11
+ "k_proj",
12
+ "q_proj",
13
+ "linear_k",
14
+ "linear_q"
15
+ ]
16
+ },
17
+ "audio_token_index": 151669,
18
+ "auto_map": {
19
+ "AutoConfig": "ultravox_config.UltravoxConfig",
20
+ "AutoModel": "ultravox_model.UltravoxModel"
21
+ },
22
+ "hidden_size": 4096,
23
+ "ignore_index": -100,
24
+ "initializer_range": 0.02,
25
+ "llm_only_training": false,
26
+ "model_type": "ultravox",
27
+ "norm_init": 0.4,
28
+ "pad_token_id": 151645,
29
+ "projector_act": "swiglu",
30
+ "projector_ln_mid": true,
31
+ "stack_factor": 8,
32
+ "text_model_id": "Qwen/Qwen3-32B",
33
+ "text_model_lora_config": {
34
+ "lora_alpha": 8,
35
+ "r": 0,
36
+ "target_modules": [
37
+ "k_proj",
38
+ "q_proj",
39
+ "linear_k",
40
+ "linear_q"
41
+ ]
42
+ },
43
+ "torch_dtype": "bfloat16",
44
+ "transformers_version": "4.51.3",
45
+ "vocab_size": 151936
46
+ }
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 151643,
4
+ "eos_token_id": 151645,
5
+ "pad_token_id": 151645,
6
+ "transformers_version": "4.51.3"
7
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e8c23828793ed1ed3d7f19be86b2a8b0aaa9349bc037aaab5fa6cbc49b1b023
3
+ size 1378876648
preprocessor_config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "ultravox_processing.UltravoxProcessor"
4
+ },
5
+ "chunk_length": 30,
6
+ "dither": 0.0,
7
+ "feature_extractor_type": "WhisperFeatureExtractor",
8
+ "feature_size": 128,
9
+ "hop_length": 160,
10
+ "n_fft": 400,
11
+ "n_samples": 480000,
12
+ "nb_max_frames": 3000,
13
+ "padding_side": "right",
14
+ "padding_value": 0.0,
15
+ "processor_class": "UltravoxProcessor",
16
+ "return_attention_mask": false,
17
+ "sampling_rate": 16000
18
+ }
processor_config.json ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "audio_context_size": 3000,
3
+ "audio_padding": "longest",
4
+ "audio_placeholder": "<|audio|>",
5
+ "auto_map": {
6
+ "AutoProcessor": "ultravox_processing.UltravoxProcessor"
7
+ },
8
+ "encoder_ds_factor": 2,
9
+ "processor_class": "UltravoxProcessor",
10
+ "stack_factor": 8
11
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ {
4
+ "content": "<|audio|>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ }
10
+ ],
11
+ "eos_token": {
12
+ "content": "<|im_end|>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false
17
+ },
18
+ "pad_token": "<|im_end|>"
19
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0422f49acb9b2b8861f8a8c9fc5883758d7be3e805fdc014ade10478ea8fb1c9
3
+ size 11422840
tokenizer_config.json ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ },
213
+ "151669": {
214
+ "content": "<|audio|>",
215
+ "lstrip": false,
216
+ "normalized": false,
217
+ "rstrip": false,
218
+ "single_word": false,
219
+ "special": true
220
+ }
221
+ },
222
+ "additional_special_tokens": [
223
+ "<|audio|>"
224
+ ],
225
+ "auto_map": {
226
+ "AutoProcessor": "ultravox_processing.UltravoxProcessor"
227
+ },
228
+ "bos_token": null,
229
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = 
content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
230
+ "clean_up_tokenization_spaces": false,
231
+ "eos_token": "<|im_end|>",
232
+ "errors": "replace",
233
+ "extra_special_tokens": {},
234
+ "model_max_length": 131072,
235
+ "pad_token": "<|im_end|>",
236
+ "processor_class": "UltravoxProcessor",
237
+ "split_special_tokens": false,
238
+ "tokenizer_class": "Qwen2Tokenizer",
239
+ "unk_token": null
240
+ }
ultravox_config.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import Enum
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ import transformers
6
+
7
+
8
@dataclasses.dataclass
class LoraConfigSimplified:
    """
    Low Rank Approximation (LoRA) configuration.

    Used for language and audio models separately.
    """

    # The rank of the approximation; r == 0 means LoRA is disabled.
    r: int = 0
    # Scaling factor for the LoRA update.
    lora_alpha: float = 8
    # Module-name patterns that receive LoRA adapters.
    target_modules: Optional[List[str]] = dataclasses.field(
        default_factory=lambda: ["k_proj", "q_proj", "linear_k", "linear_q"]
    )
    # A list of module names regex patterns to unfreeze. Only used if r == 0.
    unfreeze_layers: Optional[List[str]] = None
24
+
25
+
26
class LossMaskType(str, Enum):
    """Type of loss mask to use."""

    LAST_ASSISTANT = "last_assistant"
    """This applies the loss mask up until the last assistant token"""
    ALL = "all"  # This does not work with KL loss
    """No loss mask, all inputs are used for loss"""
    AFTER_AUDIO = "after_audio"
    """Applies the loss mask up until the audio token"""
35
+
36
+
37
class LossFunction(str, Enum):
    """Loss functions supported during training."""

    CrossEntropy = "ce"
    KL_Divergence = "kl"
40
+
41
+
42
@dataclasses.dataclass
class LossConfig:
    """Configuration of the training loss computation."""

    loss_function: LossFunction = LossFunction.CrossEntropy
    # Temperature for the KL-divergence (distillation) loss.
    kl_temperature: float = 2.0
    # Number of tokens to ignore from the beginning of the sequence. Only used in LSM
    initial_tokens_to_ignore: int = 0
    # Weight for the EOT token KL loss
    eot_loss_weight: float = 1.0

    @property
    def requires_alt_fields(self):
        # KL distillation needs the alternate (text-only teacher) batch fields.
        return self.loss_function == LossFunction.KL_Divergence
54
+
55
+
56
class UltravoxConfig(transformers.PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`UltravoxForConditionalGeneration`]. It is used to instantiate an
    Ultravox model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        audio_config (`WhisperConfig`, *optional*):
            Custom audio config or dict
        text_config (`Union[AutoConfig, dict]`, *optional*):
            The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
        ignore_index (`int`, *optional*, defaults to -100):
            The ignore index for the loss function.
        audio_token_index (`int`, *optional*, defaults to 32000):
            The audio token index to encode the audio prompt.
        stack_factor (`int`, *optional*, defaults to 8):
            Audio downsampling factor for the multimodal projector.
        norm_init (`float`, *optional*, defaults to 0.4):
            The initialization value for the layer normalization.
        projector_act (`str`, *optional*, defaults to `"swiglu"`):
            The activation function used by the multimodal projector.
        text_model_lora_config (`LoraConfigSimplified`, *optional*):
            The LoRA configuration for finetuning the text model.
        audio_model_lora_config (`LoraConfigSimplified`, *optional*):
            The LoRA configuration for finetuning the audio model.
        audio_latency_block_size (`int`, *optional*, defaults to `None`):
            The latency block size for simulating audio streaming.


    Example:

    ```python
    >>> from transformers import UltravoxModel, WhisperConfig, UltravoxConfig, LlamaConfig

    >>> # Initializing an audio encoder config
    >>> audio_config = WhisperConfig()

    >>> # Initializing a Llama config
    >>> text_config = LlamaConfig()

    >>> # Initializing a default configuration
    >>> configuration = UltravoxConfig(audio_config, text_config)

    >>> # Initializing a completely untrained model from the configuration
    >>> model = UltravoxModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config

    >>> # Initialize a model from pretrained checkpoints and random projector weights
    >>> config = UltravoxConfig(audio_model_id="openai/whisper-tiny", text_model_id="meta-llama/Llama-2-7b-chat-hf")
    ```"""

    model_type = "ultravox"
    is_composition = False

    def __init__(
        self,
        audio_config: dict[str, Any] | transformers.PretrainedConfig | None = None,
        text_config: dict[str, Any] | transformers.PretrainedConfig | None = None,
        audio_model_id: str | None = None,
        text_model_id: str | None = None,
        llm_only_training: bool = False,
        ignore_index: int = -100,
        audio_token_index: int | None = None,
        hidden_size: int = 4096,
        stack_factor: int = 8,
        norm_init: float = 0.4,
        projector_act: str = "swiglu",
        projector_ln_mid: bool = False,  # defaults to False for compatibility with v0.4.1 and below
        text_model_lora_config: LoraConfigSimplified | None = None,
        audio_model_lora_config: LoraConfigSimplified | None = None,
        audio_latency_block_size: int | None = None,
        **kwargs,
    ):
        self.ignore_index = ignore_index

        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id

        self.audio_token_index = audio_token_index

        self.hidden_size = hidden_size
        self.stack_factor = stack_factor
        self.norm_init = norm_init
        self.projector_act = projector_act
        self.projector_ln_mid = projector_ln_mid
        # A model id takes precedence over an explicitly passed config:
        # the backbone config is then fetched from the hub/local cache.
        if text_model_id is not None:
            text_config = transformers.AutoConfig.from_pretrained(text_model_id)
        else:
            text_config = text_config or {}
            if isinstance(text_config, dict):
                # Build a concrete config from the dict; "llama" is the
                # default backbone when no model_type is given.
                text_config = transformers.CONFIG_MAPPING[
                    text_config.get("model_type", "llama")
                ](**text_config)

        if audio_model_id is not None:
            audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
        else:
            audio_config = audio_config or {}
            if isinstance(audio_config, dict):
                # "whisper" is the default audio encoder when unspecified.
                audio_config = transformers.CONFIG_MAPPING[
                    audio_config.get("model_type", "whisper")
                ](**audio_config)

        self.text_config = text_config
        self.audio_config = audio_config

        self.llm_only_training = llm_only_training
        # LoRA configs are stored as plain dicts so the config serializes
        # cleanly to JSON; dataclass instances are converted via asdict.
        self.text_model_lora_config = (
            text_model_lora_config
            if isinstance(text_model_lora_config, dict)
            else dataclasses.asdict(text_model_lora_config or LoraConfigSimplified())
        )
        self.audio_model_lora_config = (
            audio_model_lora_config
            if isinstance(audio_model_lora_config, dict)
            else dataclasses.asdict(audio_model_lora_config or LoraConfigSimplified())
        )
        self.audio_latency_block_size = audio_latency_block_size

        # Some backbone configs nest the text settings one level deeper
        # (composite/multimodal configs); mirror their sizes to the top level.
        if hasattr(text_config, "text_config"):
            text_config.vocab_size = text_config.text_config.vocab_size
            text_config.hidden_size = text_config.text_config.hidden_size

        self.vocab_size = text_config.vocab_size

        self.initializer_range = getattr(text_config, "initializer_range", 0.02)

        super().__init__(**kwargs)

    def to_diff_dict(self) -> Dict[str, Any]:
        """Serialize only the fields that differ from the defaults.

        When a backbone is referenced by model id, its full sub-config is
        dropped from the diff (it can be re-fetched from the id).
        """
        diff_dict = super().to_diff_dict()

        # remove text_config and audio_config if text_model_id and audio_model_id are present
        if self.text_model_id is not None:
            diff_dict.pop("text_config", None)
        elif "text_config" in diff_dict:
            diff_dict["text_config"].pop("_attn_implementation_autoset", None)

        if self.audio_model_id is not None:
            diff_dict.pop("audio_config", None)
        elif "audio_config" in diff_dict:
            diff_dict["audio_config"].pop("_attn_implementation_autoset", None)

        return diff_dict
ultravox_model.py ADDED
@@ -0,0 +1,1003 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import re
3
+ from typing import Any, Dict, Generator, Optional, Set, Tuple, TypeVar, Union
4
+
5
+ import accelerate
6
+ import peft
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import transformers
11
+ import transformers.activations
12
+ import transformers.modeling_outputs
13
+ import transformers.models
14
+ from transformers.generation.utils import GenerationMixin
15
+ from transformers.models.whisper import modeling_whisper as whisper
16
+
17
+ # We must use relative import in this directory to allow uploading to HF Hub
18
+ # Even "from . import X" pattern doesn't work (undocumented and unclear why)
19
+ from .ultravox_config import LossConfig
20
+ from .ultravox_config import LossFunction
21
+ from .ultravox_config import UltravoxConfig
22
+
23
+ FROM_PRETRAINED_KWARGS = {}
24
+ SHARED_PRETRAINED_KWARGS = [
25
+ "tp_plan",
26
+ "device_map",
27
+ "torch_dtype",
28
+ "attn_implementation",
29
+ "use_flash_attention_2",
30
+ ]
31
+
32
+
33
+ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
34
+ """
35
+ The Ultravox model which consists of an audio encoder and a language model.
36
+
37
+ Audio input is processed by the audio encoder, then every `stack_factor` frames are stacked together and
38
+ projected to the language model's embedding space using a few linear layers.
39
+ The text is embedded by the language model as usual and then the audio and text embeddings are merged together.
40
+
41
+ A special token `<|audio|>` is used to indicate the start of the audio embeddings in the merged embeddings.
42
+
43
+ Parameters:
44
+ config: Model configuration class with all the parameters of the model.
45
+ """
46
+
47
+ config_class = UltravoxConfig
48
+ config: UltravoxConfig # for type hinting
49
+ # Usually we load encoder and LLM weights from a pretrained model separately, so they are allowed to be missing
50
+ _keys_to_ignore_on_load_missing = ["audio_tower.*", "language_model.*"]
51
+ # Since we have kwargs in forward, we need to set this to False, otherwise grad_accum_steps will cause incorrect train loss to be reported
52
+ # see https://github.com/huggingface/transformers/issues/35856 and https://github.com/huggingface/trl/pull/2615/files
53
+ accepts_loss_kwargs = False
54
+
55
    def __init__(self, config: UltravoxConfig):
        """Build the composite model: audio tower + projector + language model.

        When `config.llm_only_training` is set, the audio tower and projector
        are not constructed at all and only the language model is built.
        """
        super().__init__(config)
        # Hook runs before state_dict loading (defined elsewhere in this class).
        self._register_load_state_dict_pre_hook(self._pre_load_state_dict_hook)

        self.keep_params: Set[str] = set()
        self.vocab_size = config.vocab_size

        if not config.llm_only_training:
            self.audio_tower = self._create_audio_tower(config)
            self.multi_modal_projector = self._create_multi_modal_projector(config)
            self.audio_tower_context_length = self.audio_tower.max_context_length

        self.language_model = self._create_language_model(config)

        # Re-prefix tied-weight keys so they resolve under this wrapper model.
        if self.language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [
                f"language_model.{k}" for k in self.language_model._tied_weights_keys
            ]

        # Determine no_split_modules dynamically to use with FSDP auto_wrap policy.
        # This would be something like ["LlamaDecoderLayer"] as we don't split audio encoder layers.
        # FSDP throws an error if some of the layer types are not found in the model, and they need to be filtered out.
        # 1. Get the names the language model *wants* to keep intact
        candidate_names = set(
            getattr(self.language_model, "_no_split_modules", []) or []
        )
        # 2. Names that actually exist in the current model
        present_names = {m.__class__.__name__ for m in self.modules()}
        # 3. Keep only those that are both requested and present
        self._no_split_modules = list(candidate_names & present_names)

        # Default loss; callers can override via set_loss_config().
        self.loss_config = LossConfig()
        self.post_init()
88
+
89
    def _init_weights(self, module):
        """Weight-init hook that protects pretrained submodules.

        The language model and audio tower come from pretrained checkpoints,
        so their weights must not be randomly re-initialized; only modules
        outside both (e.g. the projector) fall through to the default init.
        """
        if module is self:
            # Invoked for the root module itself: (re)create the pretrained
            # submodules when their model ids are configured.
            if self.config.text_model_id is not None:
                self.language_model = self._create_language_model(self.config)
            if self.config.audio_model_id is not None:
                self.audio_tower = self._create_audio_tower(self.config)
        elif module in self.language_model.modules():
            pass  # pretrained LLM weights: leave untouched
        elif module in self.audio_tower.modules():
            pass  # pretrained audio-encoder weights: leave untouched
        else:
            super()._init_weights(module)
101
+
102
+ @classmethod
103
+ def from_pretrained(cls, *args, **kwargs):
104
+ global FROM_PRETRAINED_KWARGS
105
+ FROM_PRETRAINED_KWARGS = {
106
+ k: v for k, v in kwargs.items() if k in SHARED_PRETRAINED_KWARGS
107
+ }
108
+ model = super().from_pretrained(*args, **kwargs)
109
+ FROM_PRETRAINED_KWARGS = {}
110
+ return model
111
+
112
    def get_input_embeddings(self):
        # Delegate to the wrapped language model's token-embedding table.
        return self.language_model.get_input_embeddings()
114
+
115
    def set_input_embeddings(self, value):
        # Delegate: the embedding table lives in the language model.
        self.language_model.set_input_embeddings(value)
117
+
118
    def get_output_embeddings(self):
        # Delegate to the wrapped language model's output projection (lm head).
        return self.language_model.get_output_embeddings()
120
+
121
    def set_output_embeddings(self, new_embeddings):
        # Delegate: the output projection lives in the language model.
        self.language_model.set_output_embeddings(new_embeddings)
123
+
124
    def set_decoder(self, decoder):
        # Delegate decoder replacement to the language model.
        self.language_model.set_decoder(decoder)
126
+
127
    def get_decoder(self):
        # Delegate decoder access to the language model.
        return self.language_model.get_decoder()
129
+
130
    def tie_weights(self, **kwargs):
        # Weight tying (input/output embeddings) is handled entirely by the LLM.
        return self.language_model.tie_weights(**kwargs)
132
+
133
    def set_loss_config(self, loss_config: LossConfig):
        # Replace the default LossConfig set in __init__ (e.g. to enable KL loss).
        self.loss_config = loss_config
135
+
136
    def _setup_cache(
        self, cache_cls, max_batch_size: int, max_cache_len: Optional[int] = None
    ):
        # KV-cache setup is delegated to the wrapped language model.
        self.language_model._setup_cache(cache_cls, max_batch_size, max_cache_len)
140
+
141
    def _reorder_cache(self, past_key_values, beam_idx):
        # Beam-search cache reordering is delegated to the language model.
        return self.language_model._reorder_cache(past_key_values, beam_idx)
143
+
144
    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ) -> nn.Embedding:
        """Resize the LLM's token embeddings and keep vocab sizes in sync.

        Delegates the resize to the language model, then mirrors the resulting
        vocabulary size into both config objects and this wrapper so all three
        views of vocab_size stay consistent.
        """
        model_embeds = self.language_model.resize_token_embeddings(
            new_num_tokens, pad_to_multiple_of
        )
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds
157
+
158
+ def _get_prediction_mask(
159
+ self, labels: Optional[torch.Tensor]
160
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
161
+ """Get boolean masks for positions where we want to compute KL divergence.
162
+
163
+ For each label position, we want the position before it since that's where
164
+ the model makes the prediction for that label.
165
+
166
+ Additionally, we want to identify the position right before the EOT token
167
+ (the last token with label != -100).
168
+
169
+ Args:
170
+ labels: Tensor of shape (B, T) where B is batch size and T is sequence length,
171
+ with -100 for masked positions and token ids for label positions
172
+
173
+ Returns:
174
+ Tuple containing:
175
+ - pred_mask: Boolean tensor of shape (B, T) that's True for positions where we want to compute KL divergence
176
+ - eot_mask: Boolean tensor of shape (B, T) that's True only for the last prediction position in each sequence
177
+ """
178
+ if labels is None:
179
+ raise ValueError("labels must be provided")
180
+
181
+ # Shift the label mask right by 1 along the sequence dimension
182
+ # This gives us positions where we make predictions for the next token
183
+ label_mask = labels != -100
184
+ pred_mask = torch.zeros_like(label_mask)
185
+ pred_mask[:, :-1] = label_mask[
186
+ :, 1:
187
+ ] # shift right by 1 along sequence dimension
188
+
189
+ # Create EOT mask - identify only the last prediction position in each sequence
190
+ eot_mask = torch.zeros_like(pred_mask)
191
+ batch_size = labels.shape[0]
192
+
193
+ for i in range(batch_size):
194
+ # Find positions where we make predictions
195
+ pred_positions = torch.where(pred_mask[i])[0]
196
+ if len(pred_positions) > 0:
197
+ # Only mark the last prediction position
198
+ eot_mask[i, pred_positions[-1]] = True
199
+
200
+ return pred_mask, eot_mask
201
+
202
    def _compute_kl_loss(
        self,
        lm_output: transformers.modeling_outputs.CausalLMOutputWithPast,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
        alt_input_ids: Optional[torch.Tensor] = None,
        alt_attention_mask: Optional[torch.Tensor] = None,
        alt_labels: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Distill the text-only (teacher) distribution into the audio (student) model.

        Runs the language model on the text-only `alt_*` inputs without gradients,
        then computes a temperature-scaled KL divergence between student and
        teacher next-token distributions at label positions, plus an optional
        weighted KL term at the final (EOT) prediction position.

        Args:
            lm_output: Student forward output whose `logits` are compared.
            labels: Student labels; -100 marks masked positions.
            past_key_values: KV cache forwarded to the teacher pass.
            alt_input_ids: Text-only input ids for the teacher pass.
            alt_attention_mask: Attention mask for the teacher pass.
            alt_labels: Teacher labels; -100 marks masked positions.
            **kwargs: Extra arguments passed through to the teacher forward.

        Returns:
            Scalar KL-divergence loss tensor.
        """
        # disable gradient computation for the teacher model
        with torch.no_grad():
            # compute the teacher (text-only) model's distribution
            alt_inputs_embeds = self.get_input_embeddings().forward(alt_input_ids)
            alt_lm_output = self.language_model.forward(
                inputs_embeds=alt_inputs_embeds,
                labels=alt_labels,
                attention_mask=alt_attention_mask,
                past_key_values=past_key_values,
                **kwargs,
            )

        # Get prediction masks for regular tokens and EOT tokens
        pred_mask, eot_mask = self._get_prediction_mask(labels)
        alt_pred_mask, alt_eot_mask = self._get_prediction_mask(alt_labels)

        # compute the KL divergence loss between the two models for regular tokens
        # NOTE(review): assumes pred_mask and alt_pred_mask select the same number
        # of positions in the same order (i.e. the alt inputs only swap the audio
        # placeholder for a transcript) — verify against the data collator.
        kl_loss = F.kl_div(
            F.log_softmax(
                lm_output.logits[pred_mask] / self.loss_config.kl_temperature,
                dim=-1,
            ),
            F.softmax(
                alt_lm_output.logits[alt_pred_mask] / self.loss_config.kl_temperature,
                dim=-1,
            ),
            reduction="batchmean",
        )

        # Compute the KL divergence loss for EOT token positions if any exist
        if self.loss_config.eot_loss_weight > 0:
            eot_loss = F.kl_div(
                F.log_softmax(
                    lm_output.logits[eot_mask] / self.loss_config.kl_temperature,
                    dim=-1,
                ),
                F.softmax(
                    alt_lm_output.logits[alt_eot_mask]
                    / self.loss_config.kl_temperature,
                    dim=-1,
                ),
                reduction="batchmean",
            )
            kl_loss += self.loss_config.eot_loss_weight * eot_loss

        return kl_loss
259
+ def _audio_iter(
260
+ self, audio_batch_size: torch.Tensor
261
+ ) -> Generator[Tuple[int, int], None, None]:
262
+ """
263
+ Iterate over the audio batch size and yield the batch index and audio index of each audio item.
264
+
265
+ Args:
266
+ audio_batch_size: A tensor of shape (B,) where B is the batch size.
267
+
268
+ Returns:
269
+ A generator that yields a tuple of (start index, length) for each audio item.
270
+ """
271
+ audio_index = 0
272
+ for i_b, batch_count in enumerate(audio_batch_size):
273
+ for _ in range(batch_count):
274
+ yield i_b, audio_index
275
+ audio_index += 1
276
+
277
    def forward(
        self,
        input_ids: torch.Tensor,
        audio_values: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        audio_token_start_idx: Optional[torch.Tensor] = None,
        audio_lens: Optional[torch.Tensor] = None,
        audio_token_len: Optional[torch.Tensor] = None,
        audio_batch_size: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
        # the alt_* fields are needed for KL divergence loss
        alt_input_ids: Optional[torch.Tensor] = None,
        alt_attention_mask: Optional[torch.Tensor] = None,
        alt_labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> transformers.modeling_outputs.CausalLMOutputWithPast:
        """
        Forward pass for the Ultravox model.

        `input_ids` are the tokenized text input. They are embedded by the language model as usual.
        `audio_values` are processed by the audio encoder and then every `stack_factor` frames are stacked together and
        projected to the language model's embedding space using a few linear layers.
        The audio and text embeddings are merged together. A special token `<|audio|>` is used to indicate the start
        of the audio embeddings in the merged embeddings.

        Args:
            input_ids: The tokenized text input.
            audio_values: The processed audio values.
            inputs_embeds: Precomputed input embeddings; when given, `input_ids` are not re-embedded.
            labels: The tokenized text labels.
            attention_mask: The attention mask for the input.
            audio_token_start_idx: Per-audio start index of its placeholder span in the sequence.
            audio_lens: Per-audio input length, used for audio-tower attention masking.
            audio_token_len: Per-audio number of placeholder tokens to overwrite.
            audio_batch_size: Number of audio items belonging to each batch row.
            past_key_values: The past key value cache for the language model attention layers.
            alt_input_ids: Text-only input ids for the KL-distillation teacher pass.
            alt_attention_mask: Attention mask for the teacher pass.
            alt_labels: Labels for the teacher pass.
            **kwargs: Additional keyword arguments. Passed directly to the language model.
        """
        if inputs_embeds is None:
            # B x T -> B x T x D
            inputs_embeds = self.get_input_embeddings().forward(input_ids)

        # Splice projected audio embeddings over the audio placeholder tokens.
        if audio_values is not None and len(audio_values) > 0:
            inputs_embeds = self._prepare_audio_embeds(
                inputs_embeds=inputs_embeds,
                audio_values=audio_values,
                audio_token_start_idx=audio_token_start_idx,
                audio_lens=audio_lens,
                audio_token_len=audio_token_len,
                audio_batch_size=audio_batch_size,
            )

        lm_output = self.language_model.forward(
            inputs_embeds=inputs_embeds,
            labels=labels,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            **kwargs,
        )
        if self.training:
            if self.loss_config.loss_function == LossFunction.CrossEntropy:
                # the language model already computed the CE loss from `labels`
                pass
            elif self.loss_config.loss_function == LossFunction.KL_Divergence:
                lm_output.loss = self._compute_kl_loss(
                    lm_output=lm_output,
                    labels=labels,
                    past_key_values=past_key_values,
                    alt_input_ids=alt_input_ids,
                    alt_attention_mask=alt_attention_mask,
                    alt_labels=alt_labels,
                    **kwargs,
                )
            else:
                raise ValueError(
                    f"Unsupported loss function: {self.loss_config.loss_function}"
                )
        return lm_output
354
    def _prepare_audio_embeds(
        self,
        inputs_embeds: Optional[torch.Tensor] = None,
        audio_values: Optional[torch.Tensor] = None,
        audio_token_start_idx: Optional[torch.Tensor] = None,
        audio_lens: Optional[torch.Tensor] = None,
        audio_token_len: Optional[torch.Tensor] = None,
        audio_batch_size: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Overwrite audio-placeholder embeddings with projected audio embeddings.

        Runs the audio tower over `audio_values`, projects the output into the
        text embedding space, then writes each audio item's embeddings into
        `inputs_embeds` at its placeholder span. Note: `inputs_embeds` is
        modified in place and also returned.

        Args:
            inputs_embeds: (B, T, D) text embeddings to splice into.
            audio_values: Batched audio features for the audio tower.
            audio_token_start_idx: Per-audio start of its placeholder span.
            audio_lens: Per-audio input length, used for attention masking.
            audio_token_len: Per-audio number of placeholder tokens.
            audio_batch_size: Number of audio items per batch row.

        Returns:
            The (mutated) `inputs_embeds`.
        """
        assert (
            inputs_embeds is not None
            and audio_values is not None
            and audio_token_start_idx is not None
            and audio_token_len is not None
            and audio_lens is not None
            and audio_batch_size is not None
        ), "inputs_embeds/audio_values/audio_token_start_idx/audio_token_len/audio_lens/audio_batch_size must be provided."
        assert (
            len(audio_token_start_idx)
            == len(audio_token_len)
            == len(audio_lens)
            == len(audio_values)
        ), "audio_token_start_idx/audio_token_len/audio_lens/audio_values must have the same batch size."
        assert len(audio_batch_size) == len(
            inputs_embeds
        ), "audio_batch_size and inputs_embeds must have the same batch size."

        # B x A/3200 x (D=max-audio-length-in-batch)
        audio_tower_output = self.audio_tower.forward(
            audio_values.to(self.audio_tower.dtype),
            audio_len=audio_lens,
        ).last_hidden_state
        audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)
        audio_embeds = self.multi_modal_projector.forward(audio_tower_output)

        # combine audio and text embeddings
        # i_b is the batch row, i_a the flat index into the audio tensors
        for i_b, i_a in self._audio_iter(audio_batch_size):
            start_idx = audio_token_start_idx[i_a]
            token_len = audio_token_len[i_a]
            # trim projector output (padded to the batch max) to this item's length
            item_embedding = audio_embeds[i_a][:token_len]
            inputs_embeds[i_b][start_idx : start_idx + token_len] = item_embedding

        return inputs_embeds
398
    def generate(
        self,
        input_ids: torch.Tensor,
        audio_values: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        audio_token_start_idx: Optional[torch.Tensor] = None,
        audio_lens: Optional[torch.Tensor] = None,
        audio_token_len: Optional[torch.Tensor] = None,
        audio_batch_size: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Generate text after splicing projected audio embeddings into the prompt.

        Mirrors `forward`: embeds `input_ids`, overwrites the audio placeholder
        positions with projected audio embeddings, then delegates to the
        language model's `generate`. Audio arguments are as in `forward`;
        `**kwargs` are passed to `language_model.generate`.
        """
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings().forward(input_ids)

        if audio_values is not None and len(audio_values) > 0:
            inputs_embeds = self._prepare_audio_embeds(
                inputs_embeds=inputs_embeds,
                audio_values=audio_values,
                audio_token_start_idx=audio_token_start_idx,
                audio_lens=audio_lens,
                audio_token_len=audio_token_len,
                audio_batch_size=audio_batch_size,
            )

        # NOTE(review): both input_ids and inputs_embeds are forwarded; the LM's
        # generate is expected to use inputs_embeds for the prefill — confirm
        # against the pinned transformers version's GenerationMixin behavior.
        return self.language_model.generate(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
428
+ @classmethod
429
+ def _create_multi_modal_projector(
430
+ cls, config: UltravoxConfig
431
+ ) -> "UltravoxProjector":
432
+ projector = UltravoxProjector(config)
433
+ dtype = config.torch_dtype
434
+ if isinstance(dtype, str):
435
+ dtype = getattr(torch, dtype)
436
+ projector.to(dtype)
437
+ return projector
438
+
439
+ @classmethod
440
+ def _create_audio_tower(
441
+ cls, config: UltravoxConfig
442
+ ) -> Union[transformers.Wav2Vec2Model, "ModifiedWhisperEncoder"]:
443
+ # We probably don't want to pass tp_plan or device_map to the audio tower
444
+ # But potentially other kwargs can be passed in. TODO
445
+ kwargs = {"torch_dtype": config.torch_dtype}
446
+ if (
447
+ getattr(transformers.modeling_utils, "_init_weights", True)
448
+ and config.audio_model_id is not None
449
+ ):
450
+ if "whisper" in config.audio_model_id.lower():
451
+ audio_tower = ModifiedWhisperEncoder.from_pretrained(
452
+ config.audio_model_id, **kwargs
453
+ )
454
+ audio_tower.init_latency_mask(
455
+ config.audio_latency_block_size, dtype=config.torch_dtype
456
+ )
457
+ audio_tower.init_latency_mask(
458
+ config.audio_latency_block_size, dtype=config.torch_dtype
459
+ )
460
+ else:
461
+ assert config.audio_latency_block_size in (
462
+ None,
463
+ 0,
464
+ ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
465
+ audio_tower = transformers.AutoModel.from_pretrained(
466
+ config.audio_model_id, **kwargs
467
+ )
468
+ else:
469
+ with accelerate.init_empty_weights():
470
+ if "whisper" in config.audio_config._name_or_path.lower():
471
+ audio_tower = ModifiedWhisperEncoder(config.audio_config)
472
+ audio_tower.init_latency_mask(
473
+ config.audio_latency_block_size,
474
+ dtype=config.torch_dtype,
475
+ )
476
+ else:
477
+ assert config.audio_latency_block_size in (
478
+ None,
479
+ 0,
480
+ ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
481
+ # we only ever use from_config if the weights are retrained, hence initializing is not
482
+ # required. This makes the model quite creation faster since init on CPU is quite slow.
483
+ audio_tower = transformers.AutoModel.from_config(
484
+ config.audio_config, **kwargs
485
+ )
486
+
487
+ if isinstance(
488
+ audio_tower,
489
+ (transformers.Wav2Vec2BertModel, transformers.WhisperModel),
490
+ ):
491
+ # For these models we only need the encoder part
492
+ # Wav2Vec2BertModel -> Wav2Vec2BertEncoder
493
+ # WhisperModel -> WhisperEncoder
494
+ audio_tower = audio_tower.encoder
495
+
496
+ audio_tower = apply_lora(audio_tower, config.audio_model_lora_config)
497
+ return audio_tower
498
+
499
    @classmethod
    def _create_language_model(
        cls, config: UltravoxConfig
    ) -> transformers.LlamaForCausalLM:
        """Instantiate the backbone LLM.

        Downloads pretrained weights when `_init_weights` is enabled and a text
        model id is configured; otherwise builds the model on empty (meta)
        weights, since a checkpoint will be loaded afterwards. Then applies the
        configured LoRA (or freezes the LM when r == 0).
        """
        if (
            getattr(transformers.modeling_utils, "_init_weights", True)
            and config.text_model_id is not None
        ):
            language_model = transformers.AutoModelForCausalLM.from_pretrained(
                config.text_model_id,
                **{
                    "attn_implementation": config.text_config._attn_implementation,
                    "torch_dtype": config.torch_dtype,
                    **FROM_PRETRAINED_KWARGS,
                },
            )
        else:
            with accelerate.init_empty_weights():
                # we only ever use from_config if the weights are retrained, hence
                # initializing is not required. This makes model creation quite a
                # bit faster since init on CPU is quite slow.
                language_model = transformers.AutoModelForCausalLM.from_config(
                    config.text_config,
                    attn_implementation=config.text_config._attn_implementation,
                    torch_dtype=config.torch_dtype,
                )

        language_model = apply_lora(language_model, config.text_model_lora_config)
        return language_model
528
+ def merge_and_unload(self):
529
+ if isinstance(self.language_model, peft.PeftModel):
530
+ self.language_model = self.language_model.merge_and_unload()
531
+ # no need to download base language model weights anymore, so we can remove the id
532
+ self.config.text_model_id = None
533
+ self.keep_params.update(
534
+ set(
535
+ [
536
+ f"language_model.{name}"
537
+ for name, _ in self.language_model.named_parameters()
538
+ ]
539
+ )
540
+ )
541
+
542
+ if hasattr(self, "audio_tower") and isinstance(
543
+ self.audio_tower, peft.PeftModel
544
+ ):
545
+ self.audio_tower = self.audio_tower.merge_and_unload()
546
+ # no need to download base audio model weights anymore, so we can remove the id
547
+ self.config.audio_model_id = None
548
+ self.keep_params.update(
549
+ set(
550
+ [
551
+ f"audio_tower.{name}"
552
+ for name, _ in self.audio_tower.named_parameters()
553
+ ]
554
+ )
555
+ )
556
+
557
+ for param in ["text_model_lora_config", "audio_model_lora_config"]:
558
+ if hasattr(self.config, param):
559
+ delattr(self.config, param)
560
+
561
    def push_to_hub(self, *args, **kwargs):
        """Merge LoRA adapters into the base weights, then upload as usual."""
        self.merge_and_unload()
        return super().push_to_hub(*args, **kwargs)
565
+ def diff_state_dict(
566
+ self, state_dict: Optional[Dict[str, Any]] = None
567
+ ) -> Dict[str, Any]:
568
+ if state_dict is None:
569
+ state_dict = super().state_dict()
570
+
571
+ trainable_params = {k for k, v in self.named_parameters() if v.requires_grad}
572
+ # normalize the keys to match the original model
573
+ # Example: audio_tower.base_model.model.layers.0._fsdp_wrapped_module.self_attn.k_proj.lora_B.default.weight
574
+ trainable_params = {
575
+ k.replace("_fsdp_wrapped_module.", "") for k in trainable_params
576
+ }
577
+
578
+ state_dict = {
579
+ k: v
580
+ for k, v in state_dict.items()
581
+ if k in self.keep_params or k in trainable_params
582
+ }
583
+
584
+ return state_dict
585
+
586
    def save_pretrained(
        self, *args, state_dict: Optional[Dict[str, Any]] = None, **kwargs
    ):
        """Save only the diffed (kept + trainable) parameters; see `diff_state_dict`."""
        state_dict = self.diff_state_dict(state_dict)

        super().save_pretrained(*args, state_dict=state_dict, **kwargs)
593
    def _pre_load_state_dict_hook(self, state_dict: Dict[str, Any], *args, **kwargs):
        # Any key loaded from a checkpoint must be saved again later even if it
        # is frozen — record it so `diff_state_dict` keeps it.
        self.keep_params.update(set(state_dict.keys()))
596
    def print_trainable_parameters(self):
        """
        Prints the number of trainable parameters in the model (reuses Peft model's method)
        """
        # Borrow PEFT's counter; it works on any nn.Module passed as `self`.
        count_params = peft.peft_model.PeftModel.get_nb_trainable_parameters

        trainable_params, all_param = count_params(self)

        logging.info(
            f"trainable params: {trainable_params:,d} || all params: {all_param:,d}"
            f" || trainable%: {100 * trainable_params / all_param:.1f}%"
        )

        lm_trainable_params, lm_all_params = count_params(self.language_model)
        if hasattr(self, "audio_tower") and self.audio_tower is not None:
            audio_trainable_params, audio_all_params = count_params(self.audio_tower)
        else:
            audio_trainable_params, audio_all_params = 0, 0

        # Projector counts are whatever remains after LM and audio tower.
        projector_trainable_params = (
            trainable_params - lm_trainable_params - audio_trainable_params
        )
        projector_all_params = all_param - lm_all_params - audio_all_params

        # Calculate percentages only if the total parameters are non-zero
        audio_percent = (
            0.0
            if audio_all_params == 0
            else 100 * audio_trainable_params / audio_all_params
        )
        projector_percent = (
            0.0
            if projector_all_params == 0
            else 100 * projector_trainable_params / projector_all_params
        )

        logging.info(
            f"Trainable%: "
            f" LLM: {100 * lm_trainable_params / lm_all_params:.1f}%"
            f" || Audio Encoder: {audio_percent:.1f}%"
            f" || Projector: {projector_percent:.1f}%"
        )
640
def get_checkpoint_files(
    model_id: str,
) -> tuple[list[str], dict | None, list[str]]:
    """Resolve the safetensors checkpoint file(s) for `model_id`.

    Handles both single-file and sharded (index + shards) checkpoints.

    Returns:
        Tuple of:
        - list of resolved checkpoint file paths
        - sharded metadata dict, or None for a single-file checkpoint
        - list of all state-dict keys present in the checkpoint
    """
    resolved_archive_file = transformers.utils.cached_file(
        model_id,
        transformers.utils.SAFE_WEIGHTS_NAME,
        _raise_exceptions_for_missing_entries=False,
    )

    if resolved_archive_file is not None:
        # not sharded
        sharded_metadata = None
        state_dict = transformers.modeling_utils.load_state_dict(resolved_archive_file)
        loaded_state_dict_keys = list(state_dict.keys())
    else:
        # sharded: resolve the index file, then fetch every shard
        resolved_archive_file = transformers.utils.cached_file(
            model_id, transformers.utils.SAFE_WEIGHTS_INDEX_NAME
        )
        resolved_archive_file, sharded_metadata = (
            transformers.modeling_utils.get_checkpoint_shard_files(
                model_id,
                resolved_archive_file,
            )
        )
        loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]

    # normalize the single-file case to a list of paths
    if isinstance(resolved_archive_file, str):
        resolved_archive_file = [resolved_archive_file]

    return resolved_archive_file, sharded_metadata, loaded_state_dict_keys
673
+ # TODO: refactor common parts to a shared module
674
def is_cache_empty(
    past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]],
) -> bool:
    """
    Check if the cache is empty.

    Handles None, the legacy tuple-of-layers format, and Cache objects.
    """
    if past_key_values is None:
        return True
    if isinstance(past_key_values, tuple):
        # legacy format: empty when every per-layer entry is empty
        return not any(len(layer) for layer in past_key_values)
    return past_key_values.get_seq_length() == 0
687
+ T = TypeVar("T", bound=torch.nn.Module)
688
+
689
+
690
def apply_lora(model: T, lora_config: dict) -> T:
    """
    Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.

    Args:
        model: The module to adapt or freeze.
        lora_config: Keyword arguments for `peft.LoraConfig`, plus an optional
            `unfreeze_layers` list of regex patterns naming parameters to keep
            trainable when the model is frozen (r == 0).

    Returns:
        The (possibly PEFT-wrapped) model.
    """
    # Copy before popping so the caller's (config-owned) dict is not mutated.
    lora_config = dict(lora_config or {})
    unfreeze_layers = lora_config.pop("unfreeze_layers", None)
    lora_config = peft.LoraConfig(**lora_config)

    if lora_config.r == 0:
        # freeze the model entirely, except for the specified layers
        for name, param in model.named_parameters():
            if not unfreeze_layers or not any(
                re.match(layer, name) for layer in unfreeze_layers
            ):
                param.requires_grad = False
            else:
                logging.info(f"Unfreezing layer: {name} with #{param.numel()} params")
    else:
        model = peft.get_peft_model(model, lora_config)

    return model
712
class StackAudioFrames(nn.Module):
    """
    Stack the audio embedding frames to reduce the sequence length by a factor
    of `stack_factor`.

    Frames are zero-padded up to a multiple of `stack_factor`, then every
    `stack_factor` consecutive frames are concatenated along the channel axis.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        batch, frames, channels = audio_embeds.shape
        factor = self.stack_factor
        # round the frame count up to the next multiple of the stack factor
        padded_frames = -(-frames // factor) * factor
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, padded_frames - frames))
        return audio_embeds.reshape(batch, padded_frames // factor, channels * factor)
732
+
733
class RMSNorm(transformers.models.llama.modeling_llama.LlamaRMSNorm):
    # Thin wrapper over LlamaRMSNorm that only changes the initial weight value
    # (the base class initializes to ones; the projector uses config.norm_init).
    def __init__(self, hidden_size: int, init: float = 1, eps: float = 1e-6):
        super().__init__(hidden_size=hidden_size, eps=eps)
        self.weight.data.fill_(init)
739
class SwiGLU(nn.Module):
    """SwiGLU activation: split the last dimension into (value, gate) halves
    and return silu(gate) * value. Halves the feature width."""

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * value
745
class UltravoxProjector(nn.Module):
    """Projects stacked audio-encoder frames into the text model's embedding space.

    Pipeline: stack `stack_factor` frames -> RMSNorm -> linear -> activation
    (width is halved when the activation is SwiGLU) -> optional mid-norm ->
    linear -> optional post-norm. See `forward` for shape details.
    """

    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        dim_in = config.audio_config.hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(dim_in, init=config.norm_init)
        self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
        dim_mid = self.hidden_dim
        self.act = transformers.activations.get_activation(config.projector_act)
        # SwiGLU consumes pairs of features, halving the width after activation.
        dim_mid = dim_mid // 2 if config.projector_act == "swiglu" else dim_mid
        dim_out = config.text_config.hidden_size
        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)

        # Ultravox v0.4.1 and below uses layer_norm after the second linear layer,
        # while v0.5.0 and above uses layer_norm after the first linear layer.
        if config.projector_ln_mid:
            self.ln_mid: nn.Module = RMSNorm(dim_mid, init=config.norm_init)
            self.ln_post: nn.Module = nn.Identity()
        else:
            self.ln_mid = nn.Identity()
            self.ln_post = RMSNorm(dim_out, init=config.norm_init)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """
        Takes in audio features from the audio tower and projects them to the text model's embedding space.
        It reduces the number of frames by a factor of `stack_factor` and increases the number of channels by the same factor.
        If the number of audio frames are not a multiple of the stack factor, the last few frames will be padded with zeros.

        Input shape:
            audio_features: B, T*S, C
        Output shape:
            hidden_states: B, T, D
        Where:
            B: batch size
            F: number of frames in the audio tower
            T: number of output embeddings
                T = ceil(F / S)
            S: stack factor
            C: number of channels out of the encoder (aka audio tower)
            H: hidden size of the projector (config.hidden_size)
            D: dimension of the text model (config.text_config.hidden_size)

        """
        # B, F, C -> B, T, C*S
        audio_features = self._pad_and_stack(audio_features)
        audio_features = self.ln_pre(audio_features)
        # B, T, C*S -> B, T, H
        hidden_states = self.linear_1(audio_features)
        # B, T, H -> B, T, H/2 (assuming swiglu)
        hidden_states = self.act(hidden_states)
        hidden_states = self.ln_mid(hidden_states)
        # B, T, H/2 -> B, T, D
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.ln_post(hidden_states)
        return hidden_states
802
+
803
class ModifiedWhisperEncoder(
    whisper.WhisperEncoder, transformers.modeling_utils.ModuleUtilsMixin
):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers' Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained` directly on the encoder
    2. allow less than 30 second of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    """

    base_model_prefix = "model.encoder"
    _no_split_modules = ["WhisperEncoderLayer"]
    # checkpoints contain decoder weights too; don't warn about them
    _keys_to_ignore_on_load_unexpected = ["model.decoder.*"]

    def __init__(self, config: transformers.WhisperConfig):
        super().__init__(config)
        self.config.is_decoder = False

    @property
    def max_context_length(self):
        # Maximum number of mel frames: positional-embedding capacity scaled
        # back up by the two conv strides that downsample the input.
        return (
            self.config.max_source_positions
            * self.conv1.stride[0]
            * self.conv2.stride[0]
        )

    def init_latency_mask(
        self, audio_latency_block_size: int | None, dtype: torch.dtype
    ):
        """Build a block-lower-triangular additive mask that limits each frame's
        attention to its own and earlier latency blocks (for streaming).

        A value of None disables masking. Must be called before `forward`, which
        reads `self.audio_streaming_mask` unconditionally.
        """
        if audio_latency_block_size is None:
            self.audio_streaming_mask = None
            return

        # Use max_context_length directly in the calculation
        max_seqlen = self.max_context_length
        assert (
            max_seqlen > 0
        ), f"maximum sequence length must be positive, got {max_seqlen}"
        assert (
            max_seqlen % audio_latency_block_size == 0
        ), f"audio_latency_block_size {audio_latency_block_size} must divide {max_seqlen} evenly."
        # Given the block size, we calculate number of blocks.
        audio_latency_nblocks = max_seqlen // audio_latency_block_size
        audio_streaming_mask = (
            torch.tril(
                torch.ones(audio_latency_nblocks, audio_latency_nblocks),
                diagonal=0,
            )
            .repeat_interleave(audio_latency_block_size, dim=0)
            .repeat_interleave(audio_latency_block_size, dim=1)
        )
        # 0 where attention is allowed, a large negative value where it is not
        audio_streaming_mask = (1.0 - audio_streaming_mask) * torch.finfo(dtype).min
        audio_streaming_mask = audio_streaming_mask[None, None, :, :]
        self.register_buffer(
            "audio_streaming_mask", audio_streaming_mask, persistent=False
        )

    def forward(
        self,
        input_features,
        audio_len=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Relaxed vs. upstream: inputs shorter than the full 30s context are OK.
        expected_seq_length = self.max_context_length
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length {expected_seq_length} or less, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
            )

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        # sliced (vs. upstream) so shorter inputs get matching positional embeds
        embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # Create attention mask based on audio lengths to mask out padding tokens
        # For each sample in batch:
        # - Convert raw audio length to feature length after convolutions
        # - Create boolean mask that is True for valid positions and False for padding
        # - Convert to extended attention mask format expected by transformer layers
        # (1.0 for positions to attend to, large negative for positions to ignore)
        # This masking ensures consistent behavior between training and inference
        # by preventing the model from attending to padding tokens in both cases
        attention_mask = None
        if audio_len is not None:
            audio_feature_len = self._get_feat_extract_output_lengths(audio_len)
            max_seq_len = hidden_states.shape[1]
            attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[
                None, :
            ].lt(audio_feature_len.view(-1, 1))
            attention_mask = self.get_extended_attention_mask(
                attention_mask,
                None,
                dtype=hidden_states.dtype,
            )

        # NOTE(review): assumes init_latency_mask was called after construction;
        # otherwise this attribute access raises AttributeError — confirm callers.
        if self.audio_streaming_mask is not None:
            seqlen = hidden_states.size(-2)
            if attention_mask is not None:
                # both masks are additive (0 = attend, -inf-ish = blocked), so
                # the elementwise minimum combines them
                attention_mask = torch.minimum(
                    self.audio_streaming_mask[:, :, :seqlen, :seqlen], attention_mask
                )  # merge
            else:
                attention_mask = self.audio_streaming_mask[:, :, :seqlen, :seqlen]
            attention_mask = attention_mask.to(hidden_states.dtype)

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if to_drop:
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        (head_mask[idx] if head_mask is not None else None),
                        output_attentions,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        layer_head_mask=(
                            head_mask[idx] if head_mask is not None else None
                        ),
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, encoder_states, all_attentions]
                if v is not None
            )
        return transformers.modeling_outputs.BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )
996
+
997
# Register Ultravox with the transformers Auto* machinery so that
# AutoConfig/AutoModel can resolve the "ultravox" model type to these classes.
UltravoxConfig.register_for_auto_class()
UltravoxModel.register_for_auto_class()

transformers.AutoConfig.register("ultravox", UltravoxConfig)
transformers.AutoModel.register(UltravoxConfig, UltravoxModel)

# Make "swiglu" resolvable via transformers' activation lookup (used by the projector).
transformers.activations.ACT2FN["swiglu"] = SwiGLU
ultravox_processing.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from typing import Any, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import transformers
8
+
9
+ from .ultravox_config import UltravoxConfig
10
+
11
+
12
+ @dataclasses.dataclass
13
+ class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq):
14
+ # when enabled, the alt_input_ids, alt_attention_mask, and alt_labels fields are used for computing the KL loss in UltravoxModel
15
+ include_alt_fields: bool = False
16
+
17
+ def __call__(self, features, *args, **kwargs):
18
+ audio_values = [x for f in features for x in f.pop("audio_values", [])]
19
+ audio_lens = [x for f in features for x in f.pop("audio_lens", [])]
20
+ audio_token_len = [x for f in features for x in f.pop("audio_token_len", [])]
21
+ audio_token_start_idx = [
22
+ x for f in features for x in f.pop("audio_token_start_idx", [])
23
+ ]
24
+
25
+ if self.include_alt_fields:
26
+ # these fields are hard-coded in the transformer data collator, so they need special handling before calling the super method
27
+ alt_features = [
28
+ {
29
+ "input_ids": f.pop("alt_input_ids"),
30
+ "attention_mask": f.pop("alt_attention_mask"),
31
+ "labels": f.pop("alt_labels"),
32
+ }
33
+ for f in features
34
+ ]
35
+
36
+ batch = super().__call__(features, *args, **kwargs)
37
+ if self.include_alt_fields:
38
+ alt_batch = super().__call__(alt_features, *args, **kwargs)
39
+ batch["alt_input_ids"] = alt_batch["input_ids"]
40
+ batch["alt_attention_mask"] = alt_batch["attention_mask"]
41
+ batch["alt_labels"] = alt_batch["labels"]
42
+
43
+ # Only process audio fields if we have non-empty audio values
44
+ if audio_values and len(audio_values) > 0 and len(audio_values[0]) > 0:
45
+ batch["audio_token_start_idx"] = torch.stack(audio_token_start_idx)
46
+ batch["audio_lens"] = torch.stack(audio_lens)
47
+ batch["audio_token_len"] = torch.stack(audio_token_len)
48
+ # Pad the last dimension of all audio_values to the same length, with 0s on the right.
49
+ max_len = max([x.shape[-1] for x in audio_values])
50
+ batch["audio_values"] = torch.stack(
51
+ [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
52
+ )
53
+ if self.tokenizer.padding_side == "left":
54
+ input_ids_lens = torch.LongTensor(
55
+ [f["input_ids"].shape[-1] for f in features]
56
+ )
57
+ displacement = batch["input_ids"].shape[-1] - input_ids_lens
58
+ displacement = displacement.repeat_interleave(
59
+ batch["audio_batch_size"].squeeze(-1)
60
+ )
61
+ batch["audio_token_start_idx"] += displacement.to(
62
+ batch["audio_token_start_idx"].device
63
+ )
64
+ return batch
65
+
66
+
67
+ class UltravoxProcessor(transformers.ProcessorMixin):
68
+ """
69
+ Constructs an Ultravox processor which wraps an audio processor and a tokenizer into a single processor.
70
+
71
+ Args:
72
+ audio_processor: The audio processor for the audio encoder.
73
+ tokenizer: The tokenizer for the language model.
74
+ """
75
+
76
+ attributes = ["audio_processor", "tokenizer"]
77
+ audio_processor_class = ("WhisperFeatureExtractor",)
78
+ tokenizer_class = (
79
+ "PreTrainedTokenizer",
80
+ "PreTrainedTokenizerFast",
81
+ )
82
+
83
+ tokenizer: transformers.PreTrainedTokenizerBase
84
+ audio_processor: transformers.ProcessorMixin
85
+
86
+ def __init__(
87
+ self,
88
+ audio_processor=None,
89
+ tokenizer=None,
90
+ audio_padding: str = "longest",
91
+ encoder_ds_factor: int = 2,
92
+ stack_factor: int = 8,
93
+ audio_placeholder: str = "<|audio|>",
94
+ # Defaults to whisper encoder context size
95
+ audio_context_size: Optional[int] = 3000,
96
+ ):
97
+ """
98
+ Args:
99
+ audio_processor: The audio processor for the audio encoder.
100
+ tokenizer: The tokenizer for the language model.
101
+ audio_padding: The padding strategy for the audio encoder.
102
+ stack_factor: The factor by which the audio encoder output is stacked in the multimodal projector.
103
+ encoder_ds_factor: The downsampling factor of the audio encoder.
104
+ audio_placeholder: The placeholder for the audio in the text.
105
+ audio_context_size: The maximum number of frames that the audio encoder can handle.
106
+ """
107
+ self.audio_padding = audio_padding
108
+ self.encoder_ds_factor = encoder_ds_factor
109
+ self.stack_factor = stack_factor
110
+ self.audio_placeholder = audio_placeholder
111
+ self.audio_context_size = audio_context_size
112
+ assert (
113
+ tokenizer.eos_token is not None
114
+ ), "The tokenizer has no EOS token. Cannot recover."
115
+ self.vocab = tokenizer.get_vocab()
116
+ # VLLM currently relies on updating audio_token_replacement, hence to be safe
117
+ # we should not update it. This dependency should be removed in the future.
118
+ self.audio_token_replacement = tokenizer.eos_token
119
+ if tokenizer.pad_token_id is None:
120
+ tokenizer.pad_token_id = tokenizer.eos_token_id
121
+
122
+ # Use a dummy audio processor to satisfy the base class for text-only training
123
+ if audio_processor is None:
124
+ audio_processor = transformers.AutoProcessor.from_pretrained(
125
+ "openai/whisper-tiny"
126
+ )
127
+
128
+ # Extract feature extractor if a full processor was passed,
129
+ # as transformers 5.x expects a FeatureExtractionMixin for this attribute.
130
+ if hasattr(audio_processor, "feature_extractor"):
131
+ audio_processor = audio_processor.feature_extractor
132
+
133
+ super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)
134
+
135
+ @classmethod
136
+ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
137
+ config: UltravoxConfig = transformers.AutoConfig.from_pretrained(
138
+ pretrained_model_name_or_path, **kwargs
139
+ )
140
+ audio_processor = transformers.AutoProcessor.from_pretrained(
141
+ config.audio_model_id
142
+ or config.audio_config._name_or_path
143
+ or "openai/whisper-tiny"
144
+ )
145
+
146
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
147
+ pretrained_model_name_or_path, **kwargs
148
+ )
149
+ tokenizer.padding_side = "left"
150
+ tokenizer.pad_token = tokenizer.eos_token
151
+
152
+ return cls(
153
+ audio_processor=audio_processor,
154
+ tokenizer=tokenizer,
155
+ stack_factor=config.stack_factor,
156
+ )
157
+
158
+ def _chunk_and_pad_audio(
159
+ self,
160
+ audio_values: torch.Tensor,
161
+ audio_lens: torch.Tensor,
162
+ include_audio_num_chunks: bool = False,
163
+ ) -> Dict[str, Any]:
164
+ """
165
+ Processes the audio batch by chunking any items in the batch according to the audio_context_size,
166
+ padding the last chunk if needed, and returns a dictionary with updated audio data.
167
+
168
+ Args:
169
+ audio_values (torch.Tensor): A tensor of audio values (e.g., in B, D, T format).
170
+ audio_lens (torch.Tensor): A tensor of audio lengths.
171
+
172
+ Returns:
173
+ Dict[str, Any]: Dictionary with the following keys:
174
+ - "audio_values": The concatenated audio tensor after chunking and padding.
175
+ - "audio_lens": Tensor of lengths for each chunk.
176
+ - "audio_is_continuation": Tensor of booleans indicating if the chunk is a continuation of the previous chunk.
177
+ - "audio_batch_size": A Tensor with one integer representing the number of chunks.
178
+
179
+ """
180
+ chunked_audio_values: List[torch.Tensor] = []
181
+ chunked_audio_lens: List[int] = []
182
+ is_continuation_list: List[bool] = []
183
+ num_chunks: List[int] = []
184
+ context_size = self.audio_context_size or audio_values.shape[-1]
185
+
186
+ for i in range(audio_values.shape[0]): # iterate over the batch
187
+ num_chunks.append(int(np.ceil(audio_lens[i] / context_size)))
188
+ for offset in range(0, audio_lens[i], context_size):
189
+ is_continuation = offset > 0
190
+ chunk = audio_values[i, :, offset : offset + context_size]
191
+ if is_continuation and chunk.shape[-1] < context_size:
192
+ # N.B. We only need to pad continuation chunks. If none of the samples require chunking, the
193
+ # batch might not (need to) be padded all the way to the audio_context_size, in which case
194
+ # we've already included the padding above. On the other hand, if we have any continuation
195
+ # chunks we know that the batch needs to be padded to audio_context_size because that's what
196
+ # we're slicing to.
197
+ chunk = F.pad(chunk, (0, context_size - chunk.shape[-1]))
198
+ chunked_audio_values.append(chunk)
199
+ chunked_audio_lens.append(
200
+ min(int(audio_lens[i].item()) - offset, context_size)
201
+ )
202
+ is_continuation_list.append(is_continuation)
203
+
204
+ data = {
205
+ "audio_values": torch.stack(chunked_audio_values, dim=0),
206
+ "audio_lens": torch.tensor(
207
+ chunked_audio_lens, dtype=torch.int64, device=audio_values.device
208
+ ),
209
+ "audio_is_continuation": torch.tensor(
210
+ is_continuation_list, dtype=torch.bool, device=audio_values.device
211
+ ),
212
+ "audio_batch_size": torch.tensor(
213
+ [len(chunked_audio_values)], device=audio_values.device
214
+ ),
215
+ }
216
+ if include_audio_num_chunks:
217
+ data["audio_num_chunks"] = torch.tensor(
218
+ num_chunks, dtype=torch.int64, device=audio_values.device
219
+ )
220
+ return data
221
+
222
+ def __call__(
223
+ self,
224
+ text: Optional[str] = None,
225
+ audio: Optional[Union[np.ndarray, torch.Tensor]] = None,
226
+ audios: Optional[
227
+ Union[
228
+ List[Union[np.ndarray, torch.Tensor]], Union[np.ndarray, torch.Tensor]
229
+ ]
230
+ ] = None,
231
+ sampling_rate: Optional[int] = None,
232
+ return_tensors: Optional[
233
+ Union[str, transformers.TensorType]
234
+ ] = transformers.TensorType.PYTORCH,
235
+ include_audio_num_chunks: bool = False,
236
+ **kwargs,
237
+ ) -> transformers.BatchFeature:
238
+ """
239
+ Main method to prepare for the model one text sequence and audio. This method forwards the `text`
240
+ and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
241
+ the text. To prepare the audio(s), this method forwards the `audio`, `sampling_rate` and `kwargs` arguments to
242
+ audio processor's [`~WhisperProcessor.__call__`] if `audio` is not `None`. Please refer to the docstring
243
+ of the above two methods for more information.
244
+
245
+ Args:
246
+ text (`str`, `List[str]`):
247
+ The sequence to be encoded. Sequence can be a string or (pretokenized string).
248
+ audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
249
+ The audio to be prepared. Audio can be a single-channel (1-dimensional) NumPy array or PyTorch tensor.
250
+ audios (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
251
+ A list or two dimensional array of audio to be prepared.
252
+ sampling_rate (`int`, *optional*, defaults to 16000):
253
+ Sampling rate of the input audio. We expect 16kHz audio. Don't change this value unless you know what
254
+ you are doing.
255
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
256
+ If set, will return tensors of a particular framework. Acceptable values are:
257
+
258
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
259
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
260
+ - `'np'`: Return NumPy `np.ndarray` objects.
261
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
262
+
263
+ Returns:
264
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
265
+
266
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
267
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
268
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
269
+ `None`).
270
+ - **audio_values** -- Processed audio values to be fed to a model. Returned when `audio` is not `None`.
271
+ - **audio_token_len** -- Predicted number of audio frames: this value is guaranteed to be a close upper bound.
272
+ Returned when `audio` is not `None`.
273
+ - **audio_token_start_idx** -- The index in the tokenized text where the audio starts. Returned when `audio` is not `None`.
274
+ """
275
+ # TODO: Add support for multiple text inputs.
276
+ if audio is not None and audios is not None:
277
+ raise ValueError("Only one of `audio` or `audios` should be provided.")
278
+ elif audio is not None:
279
+ audios = audio if isinstance(audio, list) or audio.ndim == 2 else [audio]
280
+ elif audios is None:
281
+ audios = []
282
+
283
+ data = {}
284
+ audio_is_continuation = []
285
+ if len(audios) > 0:
286
+ audios = [x.numpy() if isinstance(x, torch.Tensor) else x for x in audios]
287
+
288
+ # Pad out each audio to at least 2 hops (the minimum required by the processor).
289
+ fe = getattr(self.audio_processor, "feature_extractor", self.audio_processor)
290
+ hop_length = fe.hop_length
291
+ audios = [
292
+ (
293
+ np.pad(x, (0, 2 * hop_length - len(x)), mode="constant")
294
+ if len(x) < 2 * hop_length
295
+ else x
296
+ )
297
+ for x in audios
298
+ ]
299
+
300
+ # Main audio processing. The processor is model-specific.
301
+ x: transformers.BatchFeature = self.audio_processor(
302
+ audios,
303
+ sampling_rate=sampling_rate,
304
+ padding="longest",
305
+ pad_to_multiple_of=hop_length, # The attention mask effectively gets padded to the hop length, so pad the audio to be consistent.
306
+ truncation=False,
307
+ return_attention_mask=True,
308
+ **kwargs,
309
+ )
310
+
311
+ data.update(
312
+ self._chunk_and_pad_audio(
313
+ audio_values=torch.as_tensor(
314
+ x.input_features if "input_features" in x else x.input_values
315
+ ),
316
+ audio_lens=torch.as_tensor(x.attention_mask).sum(-1),
317
+ include_audio_num_chunks=include_audio_num_chunks,
318
+ )
319
+ )
320
+
321
+ audio_is_continuation = data.pop("audio_is_continuation")
322
+ data["audio_token_len"] = torch.ceil(
323
+ data["audio_lens"] / (self.encoder_ds_factor * self.stack_factor)
324
+ ).to(dtype=torch.int)
325
+
326
+ if text is not None:
327
+ if not isinstance(text, str):
328
+ raise ValueError("Text must be a string. Batch mode not supported yet.")
329
+
330
+ # Special tokens like BOS should already have been added by the caller.
331
+ tokenized_parts = self.tokenizer(
332
+ text.split(
333
+ "<|audio|>" # The placeholder isn't part of the vocabulary, so split the text around it.
334
+ ),
335
+ add_special_tokens=False,
336
+ **kwargs,
337
+ )
338
+
339
+ audio_token_start_idx = []
340
+ placeholder_index = -1
341
+ split_input_ids = tokenized_parts["input_ids"]
342
+ input_ids: List[int] = []
343
+
344
+ audio_replacement_token_id = self.vocab[self.audio_token_replacement]
345
+
346
+ for i, token_len in enumerate(data.get("audio_token_len", [])):
347
+ if not audio_is_continuation[i]:
348
+ placeholder_index += 1
349
+ if placeholder_index >= len(split_input_ids):
350
+ raise ValueError(
351
+ f"Text contains too few audio placeholders. (Expected {len(audios)} placeholders)"
352
+ )
353
+
354
+ input_ids.extend(split_input_ids[placeholder_index])
355
+
356
+ audio_token_start_idx.append(len(input_ids))
357
+
358
+ input_ids.extend([audio_replacement_token_id] * token_len)
359
+
360
+ # Include any tokens after the last audio.
361
+ placeholder_index += 1
362
+ if placeholder_index != len(split_input_ids) - 1:
363
+ raise ValueError(
364
+ f"Text contains too many audio placeholders. (Expected {len(audios)} placeholders)"
365
+ )
366
+ input_ids.extend(split_input_ids[placeholder_index])
367
+
368
+ if "audio_token_len" in data:
369
+ data["audio_token_start_idx"] = torch.as_tensor(audio_token_start_idx)
370
+
371
+ data["input_ids"] = [input_ids]
372
+ data["attention_mask"] = [[1] * len(input_ids)]
373
+
374
+ # Ensure that there are no audio placeholders after the last audio.
375
+
376
+ return transformers.BatchFeature(data=data, tensor_type=return_tensors)
377
+
378
+ def batch_decode(self, *args, **kwargs):
379
+ return self.tokenizer.batch_decode(*args, **kwargs)
380
+
381
+ def decode(self, *args, **kwargs):
382
+ return self.tokenizer.decode(*args, **kwargs)
383
+
384
+ @property
385
+ def model_input_names(self):
386
+ tokenizer_input_names = self.tokenizer.model_input_names
387
+ audio_processor_input_names = self.audio_processor.model_input_names
388
+ return list(set(tokenizer_input_names + audio_processor_input_names))
389
+
390
+
391
+ UltravoxProcessor.register_for_auto_class()
392
+
393
+ transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)
vocab.json ADDED
The diff for this file is too large to render. See raw diff