littlebird13 commited on
Commit
9554dc7
·
verified ·
1 Parent(s): 7949de4

Upload folder using huggingface_hub

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
README.md CHANGED
@@ -1,14 +1,12 @@
1
  ---
2
- title: Qwen3 TTS
3
- emoji: 🏃
4
- colorFrom: green
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 6.3.0
8
- python_version: '3.12'
9
  app_file: app.py
10
  pinned: false
11
  license: apache-2.0
 
12
  ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Qwen3-TTS Demo
3
+ emoji: 🎙️
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 5.33.0
 
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ suggested_hardware: zero-a10g
12
  ---
 
 
app.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# coding=utf-8
# Qwen3-TTS Gradio Demo for HuggingFace Spaces with Zero GPU
# Supports: Voice Design, Voice Clone (Base), TTS (CustomVoice)

import os

import spaces
import gradio as gr
import numpy as np
import torch
from huggingface_hub import login, snapshot_download

# Authenticate against the HuggingFace Hub only when a token is configured.
# Calling login(token=None) falls back to an interactive prompt, which would
# hang forever inside a headless Spaces container.
HF_TOKEN = os.environ.get('HF_TOKEN')
if HF_TOKEN:
    login(token=HF_TOKEN)

# Global model cache - keyed by (model_type, model_size) so each checkpoint
# is loaded at most once per process.
loaded_models = {}

# Model size options offered in the UI.
MODEL_SIZES = ["0.6B", "1.7B"]
21
+
22
+
23
def get_model_path(model_type: str, model_size: str) -> str:
    """Resolve the local snapshot path for a Qwen3-TTS checkpoint.

    Downloads (or reuses the cached copy of) the HuggingFace repo matching
    the requested model type (e.g. "Base", "CustomVoice") and size.
    """
    repo_id = f"Qwen/Qwen3-TTS-12Hz-{model_size}-{model_type}"
    return snapshot_download(repo_id)
26
+
27
+
28
def get_model(model_type: str, model_size: str):
    """Return a cached Qwen3TTSModel for (model_type, model_size), loading it on first use."""
    key = (model_type, model_size)
    cached = loaded_models.get(key)
    if cached is None:
        # Import lazily so the CPU-only UI process can start before the
        # package (and CUDA) is actually needed.
        from qwen_tts import Qwen3TTSModel
        cached = Qwen3TTSModel.from_pretrained(
            get_model_path(model_type, model_size),
            device_map="cuda",
            dtype=torch.bfloat16,
            token=HF_TOKEN,
        )
        loaded_models[key] = cached
    return cached
42
+
43
+
44
+ def _normalize_audio(wav, eps=1e-12, clip=True):
45
+ """Normalize audio to float32 in [-1, 1] range."""
46
+ x = np.asarray(wav)
47
+
48
+ if np.issubdtype(x.dtype, np.integer):
49
+ info = np.iinfo(x.dtype)
50
+ if info.min < 0:
51
+ y = x.astype(np.float32) / max(abs(info.min), info.max)
52
+ else:
53
+ mid = (info.max + 1) / 2.0
54
+ y = (x.astype(np.float32) - mid) / mid
55
+ elif np.issubdtype(x.dtype, np.floating):
56
+ y = x.astype(np.float32)
57
+ m = np.max(np.abs(y)) if y.size else 0.0
58
+ if m > 1.0 + 1e-6:
59
+ y = y / (m + eps)
60
+ else:
61
+ raise TypeError(f"Unsupported dtype: {x.dtype}")
62
+
63
+ if clip:
64
+ y = np.clip(y, -1.0, 1.0)
65
+
66
+ if y.ndim > 1:
67
+ y = np.mean(y, axis=-1).astype(np.float32)
68
+
69
+ return y
70
+
71
+
72
def _audio_to_tuple(audio):
    """Coerce a Gradio audio payload into a (wav, sample_rate) pair.

    Accepts either the (sr, samples) tuple emitted by gr.Audio(type="numpy")
    or the dict form with "sampling_rate"/"data" keys; any other shape
    (including None) yields None so the caller can report a missing input.
    """
    if audio is None:
        return None

    looks_like_pair = (
        isinstance(audio, tuple) and len(audio) == 2 and isinstance(audio[0], int)
    )
    if looks_like_pair:
        rate, samples = audio
        return _normalize_audio(samples), int(rate)

    if isinstance(audio, dict) and "sampling_rate" in audio and "data" in audio:
        samples = _normalize_audio(audio["data"])
        return samples, int(audio["sampling_rate"])

    return None
88
+
89
+
90
# Predefined speaker display names exposed by the CustomVoice checkpoints.
SPEAKERS = [
    "Aiden",
    "Dylan",
    "Eric",
    "Ono_anna",
    "Ryan",
    "Serena",
    "Sohee",
    "Uncle_fu",
    "Vivian",
]

# Languages accepted by all generation endpoints ("Auto" = autodetect).
LANGUAGES = [
    "Auto",
    "Chinese",
    "English",
    "Japanese",
    "Korean",
    "French",
    "German",
    "Spanish",
    "Portuguese",
    "Russian",
]
95
+
96
+
97
@spaces.GPU(duration=120)
def generate_voice_design(text, language, voice_description):
    """Synthesize speech from a natural-language voice description.

    Uses the 1.7B VoiceDesign checkpoint only. Returns a ((sr, wav), status)
    pair; on validation failure or model error the audio slot is None and
    the status carries the message.
    """
    if not text or not text.strip():
        return None, "Error: Text is required."
    if not voice_description or not voice_description.strip():
        return None, "Error: Voice description is required."

    try:
        model = get_model("VoiceDesign", "1.7B")
        waveforms, sample_rate = model.generate_voice_design(
            text=text.strip(),
            language=language,
            instruct=voice_description.strip(),
            max_new_tokens=2048,
        )
        return (sample_rate, waveforms[0]), "Voice design generation completed successfully!"
    except Exception as e:
        # Surface the failure to the UI rather than crashing the worker.
        return None, f"Error: {type(e).__name__}: {e}"
116
+
117
+
118
@spaces.GPU(duration=180)
def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector_only, model_size):
    """Clone the voice in ref_audio and speak target_text with it (Base model).

    Returns a ((sr, wav), status) pair; validation failures and model errors
    put None in the audio slot and an explanation in the status string.
    """
    if not target_text or not target_text.strip():
        return None, "Error: Target text is required."

    reference = _audio_to_tuple(ref_audio)
    if reference is None:
        return None, "Error: Reference audio is required."

    # A transcript of the reference clip is mandatory unless the caller
    # opted into the x-vector-only (speaker-embedding) mode.
    transcript = ref_text.strip() if ref_text else None
    if not use_xvector_only and not transcript:
        return None, "Error: Reference text is required when 'Use x-vector only' is not enabled."

    try:
        model = get_model("Base", model_size)
        waveforms, sample_rate = model.generate_voice_clone(
            text=target_text.strip(),
            language=language,
            ref_audio=reference,
            ref_text=transcript,
            x_vector_only_mode=use_xvector_only,
            max_new_tokens=2048,
        )
        return (sample_rate, waveforms[0]), "Voice clone generation completed successfully!"
    except Exception as e:
        return None, f"Error: {type(e).__name__}: {e}"
144
+
145
+
146
@spaces.GPU(duration=120)
def generate_custom_voice(text, language, speaker, instruct, model_size):
    """Synthesize text with a predefined speaker (CustomVoice model).

    Returns a ((sr, wav), status) pair; on validation failure or model error
    the audio slot is None and the status carries the message.
    """
    if not text or not text.strip():
        return None, "Error: Text is required."
    if not speaker:
        return None, "Error: Speaker is required."

    try:
        model = get_model("CustomVoice", model_size)
        # The UI shows display names (e.g. "Uncle_fu"); the model expects
        # lowercase snake_case speaker ids.
        speaker_id = speaker.lower().replace(" ", "_")
        waveforms, sample_rate = model.generate_custom_voice(
            text=text.strip(),
            language=language,
            speaker=speaker_id,
            instruct=instruct.strip() if instruct else None,
            max_new_tokens=2048,
        )
        return (sample_rate, waveforms[0]), "Generation completed successfully!"
    except Exception as e:
        return None, f"Error: {type(e).__name__}: {e}"
166
+
167
+
168
# Build Gradio UI
def build_ui():
    """Assemble the three-tab Gradio Blocks app and return it (not launched).

    Tabs: Voice Design (1.7B only), Voice Clone (Base), TTS (CustomVoice).
    Each tab wires its Generate button to the matching module-level
    generate_* function, which returns ((sr, wav), status_message).
    """
    theme = gr.themes.Soft(
        font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"],
    )

    css = """
    .gradio-container {max-width: none !important;}
    .tab-content {padding: 20px;}
    """

    with gr.Blocks(theme=theme, css=css, title="Qwen3-TTS Demo") as demo:
        gr.Markdown(
            """
            # Qwen3-TTS Demo

            A unified Text-to-Speech demo featuring three powerful modes:
            - **Voice Design**: Create custom voices using natural language descriptions
            - **Voice Clone (Base)**: Clone any voice from a reference audio
            - **TTS (CustomVoice)**: Generate speech with predefined speakers and optional style instructions

            Built with [Qwen3-TTS](https://github.com/QwenLM/Qwen3-TTS) by Alibaba Qwen Team.
            """
        )

        with gr.Tabs():
            # Tab 1: Voice Design (Default, 1.7B only)
            with gr.Tab("Voice Design"):
                gr.Markdown("### Create Custom Voice with Natural Language")
                with gr.Row():
                    with gr.Column(scale=2):
                        design_text = gr.Textbox(
                            label="Text to Synthesize",
                            lines=4,
                            placeholder="Enter the text you want to convert to speech...",
                            value="It's in the top drawer... wait, it's empty? No way, that's impossible! I'm sure I put it there!"
                        )
                        design_language = gr.Dropdown(
                            label="Language",
                            choices=LANGUAGES,
                            value="Auto",
                            interactive=True,
                        )
                        design_instruct = gr.Textbox(
                            label="Voice Description",
                            lines=3,
                            placeholder="Describe the voice characteristics you want...",
                            value="Speak in an incredulous tone, but with a hint of panic beginning to creep into your voice."
                        )
                        design_btn = gr.Button("Generate with Custom Voice", variant="primary")

                    with gr.Column(scale=2):
                        design_audio_out = gr.Audio(label="Generated Audio", type="numpy")
                        design_status = gr.Textbox(label="Status", lines=2, interactive=False)

                design_btn.click(
                    generate_voice_design,
                    inputs=[design_text, design_language, design_instruct],
                    outputs=[design_audio_out, design_status],
                )

            # Tab 2: Voice Clone (Base)
            with gr.Tab("Voice Clone (Base)"):
                gr.Markdown("### Clone Voice from Reference Audio")
                with gr.Row():
                    with gr.Column(scale=2):
                        clone_ref_audio = gr.Audio(
                            label="Reference Audio (Upload a voice sample to clone)",
                            type="numpy",
                        )
                        clone_ref_text = gr.Textbox(
                            label="Reference Text (Transcript of the reference audio)",
                            lines=2,
                            placeholder="Enter the exact text spoken in the reference audio...",
                        )
                        clone_xvector = gr.Checkbox(
                            label="Use x-vector only (No reference text needed, but lower quality)",
                            value=False,
                        )

                    with gr.Column(scale=2):
                        clone_target_text = gr.Textbox(
                            label="Target Text (Text to synthesize with cloned voice)",
                            lines=4,
                            placeholder="Enter the text you want the cloned voice to speak...",
                        )
                        with gr.Row():
                            clone_language = gr.Dropdown(
                                label="Language",
                                choices=LANGUAGES,
                                value="Auto",
                                interactive=True,
                            )
                            clone_model_size = gr.Dropdown(
                                label="Model Size",
                                choices=MODEL_SIZES,
                                value="1.7B",
                                interactive=True,
                            )
                        clone_btn = gr.Button("Clone & Generate", variant="primary")

                with gr.Row():
                    clone_audio_out = gr.Audio(label="Generated Audio", type="numpy")
                    clone_status = gr.Textbox(label="Status", lines=2, interactive=False)

                clone_btn.click(
                    generate_voice_clone,
                    inputs=[clone_ref_audio, clone_ref_text, clone_target_text, clone_language, clone_xvector, clone_model_size],
                    outputs=[clone_audio_out, clone_status],
                )

            # Tab 3: TTS (CustomVoice)
            with gr.Tab("TTS (CustomVoice)"):
                gr.Markdown("### Text-to-Speech with Predefined Speakers")
                with gr.Row():
                    with gr.Column(scale=2):
                        tts_text = gr.Textbox(
                            label="Text to Synthesize",
                            lines=4,
                            placeholder="Enter the text you want to convert to speech...",
                            value="Hello! Welcome to Text-to-Speech system. This is a demo of our TTS capabilities."
                        )
                        with gr.Row():
                            tts_language = gr.Dropdown(
                                label="Language",
                                choices=LANGUAGES,
                                value="Auto",
                                interactive=True,
                            )
                            tts_speaker = gr.Dropdown(
                                label="Speaker",
                                choices=SPEAKERS,
                                value="Vivian",
                                interactive=True,
                            )
                        with gr.Row():
                            tts_instruct = gr.Textbox(
                                label="Style Instruction (Optional)",
                                lines=2,
                                placeholder="e.g., Speak in a cheerful and energetic tone",
                            )
                            tts_model_size = gr.Dropdown(
                                label="Model Size",
                                choices=MODEL_SIZES,
                                value="1.7B",
                                interactive=True,
                            )
                        tts_btn = gr.Button("Generate Speech", variant="primary")

                    with gr.Column(scale=2):
                        tts_audio_out = gr.Audio(label="Generated Audio", type="numpy")
                        tts_status = gr.Textbox(label="Status", lines=2, interactive=False)

                tts_btn.click(
                    generate_custom_voice,
                    inputs=[tts_text, tts_language, tts_speaker, tts_instruct, tts_model_size],
                    outputs=[tts_audio_out, tts_status],
                )

        gr.Markdown(
            """
            ---
            ### Disclaimer
            The audio is automatically generated by an AI model solely to demonstrate the model's capabilities.
            It may be inaccurate or inappropriate and does not represent the views of the developer/operator.
            Do not use this service to generate unlawful, harmful, or infringing content.

            **Note**: This demo uses HuggingFace Spaces Zero GPU. Each generation has a time limit.
            For longer texts, please split them into smaller segments.
            """
        )

    return demo
341
+
342
+
343
if __name__ == "__main__":
    demo = build_ui()
    # Queue with a small concurrency limit: each request holds a Zero GPU
    # allocation for up to the @spaces.GPU duration.
    demo.queue(default_concurrency_limit=2).launch()
qwen_tts/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# coding=utf-8
# Copyright 2026 The Alibaba Qwen team.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
qwen_tts: Qwen-TTS package.
"""

from .inference.qwen3_tts_model import Qwen3TTSModel, VoiceClonePromptItem
from .inference.qwen3_tts_tokenizer import Qwen3TTSTokenizer

__version__ = "0.0.1"

# Export the full public API. Previously __all__ listed only __version__,
# which hid the re-exported classes from "from qwen_tts import *" and from
# tools that treat __all__ as the package's public surface.
__all__ = [
    "Qwen3TTSModel",
    "Qwen3TTSTokenizer",
    "VoiceClonePromptItem",
    "__version__",
]
qwen_tts/__main__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
def main():
    """Print a short usage hint listing the installed CLI entrypoints."""
    usage = (
        "qwen_tts package.\n"
        "Use CLI entrypoints:\n"
        " - qwen-tts-demo\n"
    )
    print(usage)


if __name__ == "__main__":
    main()
qwen_tts/cli/demo.py ADDED
@@ -0,0 +1,633 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ A gradio demo for Qwen3 TTS models.
18
+ """
19
+
20
+ import argparse
21
+ import os
22
+ import tempfile
23
+ from dataclasses import asdict
24
+ from typing import Any, Dict, List, Optional, Tuple
25
+
26
+ import gradio as gr
27
+ import numpy as np
28
+ import torch
29
+
30
+ from .. import Qwen3TTSModel, VoiceClonePromptItem
31
+
32
+
33
+ def _title_case_display(s: str) -> str:
34
+ s = (s or "").strip()
35
+ s = s.replace("_", " ")
36
+ return " ".join([w[:1].upper() + w[1:] if w else "" for w in s.split()])
37
+
38
+
39
def _build_choices_and_map(items: Optional[List[str]]) -> Tuple[List[str], Dict[str, str]]:
    """Build dropdown display choices plus a display-name -> raw-value lookup.

    Empty or None input yields ([], {}). On duplicate display names the last
    raw value wins in the mapping (the display list keeps all entries).
    """
    if not items:
        return [], {}
    display: List[str] = []
    mapping: Dict[str, str] = {}
    for raw in items:
        shown = _title_case_display(raw)
        display.append(shown)
        mapping[shown] = raw
    return display, mapping
45
+
46
+
47
+ def _dtype_from_str(s: str) -> torch.dtype:
48
+ s = (s or "").strip().lower()
49
+ if s in ("bf16", "bfloat16"):
50
+ return torch.bfloat16
51
+ if s in ("fp16", "float16", "half"):
52
+ return torch.float16
53
+ if s in ("fp32", "float32"):
54
+ return torch.float32
55
+ raise ValueError(f"Unsupported torch dtype: {s}. Use bfloat16/float16/float32.")
56
+
57
+
58
+ def _maybe(v):
59
+ return v if v is not None else gr.update()
60
+
61
+
62
def build_parser() -> argparse.ArgumentParser:
    """Create the CLI argument parser for the qwen-tts-demo entrypoint.

    The checkpoint may be given either positionally or via -c/--checkpoint.
    Boolean toggles use argparse.BooleanOptionalAction, which automatically
    registers the matching --no-* form, so the option string must be the
    positive spelling only (passing "--flash-attn/--no-flash-attn" as one
    string registers that literal text as the option and makes both flags
    unusable).
    """
    parser = argparse.ArgumentParser(
        prog="qwen-tts-demo",
        description=(
            "Launch a Gradio demo for Qwen3 TTS models (CustomVoice / VoiceDesign / Base).\n\n"
            "Examples:\n"
            "  qwen-tts-demo Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice\n"
            "  qwen-tts-demo Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign --port 8000 --ip 127.0.0.1\n"
            "  qwen-tts-demo Qwen/Qwen3-TTS-12Hz-1.7B-Base --device cuda:0\n"
            "  qwen-tts-demo Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice --dtype bfloat16 --no-flash-attn\n"
        ),
        formatter_class=argparse.RawTextHelpFormatter,
        add_help=True,
    )

    # Positional checkpoint (also supports -c/--checkpoint)
    parser.add_argument(
        "checkpoint_pos",
        nargs="?",
        default=None,
        help="Model checkpoint path or HuggingFace repo id (positional).",
    )
    parser.add_argument(
        "-c",
        "--checkpoint",
        default=None,
        help="Model checkpoint path or HuggingFace repo id (optional if positional is provided).",
    )

    # Model loading / from_pretrained args
    parser.add_argument(
        "--device",
        default="cuda:0",
        help="Device for device_map, e.g. cpu, cuda, cuda:0 (default: cuda:0).",
    )
    parser.add_argument(
        "--dtype",
        default="bfloat16",
        choices=["bfloat16", "bf16", "float16", "fp16", "float32", "fp32"],
        help="Torch dtype for loading the model (default: bfloat16).",
    )
    # BooleanOptionalAction derives --no-flash-attn from the positive form.
    parser.add_argument(
        "--flash-attn",
        dest="flash_attn",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="Enable FlashAttention-2 (default: enabled).",
    )

    # Gradio server args
    parser.add_argument(
        "--ip",
        default="0.0.0.0",
        help="Server bind IP for Gradio (default: 0.0.0.0).",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Server port for Gradio (default: 8000).",
    )
    # BooleanOptionalAction derives --no-share from the positive form.
    parser.add_argument(
        "--share",
        dest="share",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="Whether to create a public Gradio link (default: disabled).",
    )
    parser.add_argument(
        "--concurrency",
        type=int,
        default=16,
        help="Gradio queue concurrency (default: 16).",
    )

    # HTTPS args
    parser.add_argument(
        "--ssl-certfile",
        default=None,
        help="Path to SSL certificate file for HTTPS (optional).",
    )
    parser.add_argument(
        "--ssl-keyfile",
        default=None,
        help="Path to SSL key file for HTTPS (optional).",
    )
    parser.add_argument(
        "--ssl-verify",
        default=None,
        help="SSL verify setting for Gradio (optional).",
    )

    # Optional generation args (None means "use the model's default")
    parser.add_argument("--max-new-tokens", type=int, default=None, help="Max new tokens for generation (optional).")
    parser.add_argument("--temperature", type=float, default=None, help="Sampling temperature (optional).")
    parser.add_argument("--top-k", type=int, default=None, help="Top-k sampling (optional).")
    parser.add_argument("--top-p", type=float, default=None, help="Top-p sampling (optional).")
    parser.add_argument("--repetition-penalty", type=float, default=None, help="Repetition penalty (optional).")
    parser.add_argument("--subtalker-top-k", type=int, default=None, help="Subtalker top-k (optional, only for tokenizer v2).")
    parser.add_argument("--subtalker-top-p", type=float, default=None, help="Subtalker top-p (optional, only for tokenizer v2).")
    parser.add_argument(
        "--subtalker-temperature", type=float, default=None, help="Subtalker temperature (optional, only for tokenizer v2)."
    )

    return parser
167
+
168
+
169
+ def _resolve_checkpoint(args: argparse.Namespace) -> str:
170
+ ckpt = args.checkpoint or args.checkpoint_pos
171
+ if not ckpt:
172
+ raise SystemExit(0) # main() prints help
173
+ return ckpt
174
+
175
+
176
+ def _collect_gen_kwargs(args: argparse.Namespace) -> Dict[str, Any]:
177
+ mapping = {
178
+ "max_new_tokens": args.max_new_tokens,
179
+ "temperature": args.temperature,
180
+ "top_k": args.top_k,
181
+ "top_p": args.top_p,
182
+ "repetition_penalty": args.repetition_penalty,
183
+ "subtalker_top_k": args.subtalker_top_k,
184
+ "subtalker_top_p": args.subtalker_top_p,
185
+ "subtalker_temperature": args.subtalker_temperature,
186
+ }
187
+ return {k: v for k, v in mapping.items() if v is not None}
188
+
189
+
190
+ def _normalize_audio(wav, eps=1e-12, clip=True):
191
+ x = np.asarray(wav)
192
+
193
+ if np.issubdtype(x.dtype, np.integer):
194
+ info = np.iinfo(x.dtype)
195
+
196
+ if info.min < 0:
197
+ y = x.astype(np.float32) / max(abs(info.min), info.max)
198
+ else:
199
+ mid = (info.max + 1) / 2.0
200
+ y = (x.astype(np.float32) - mid) / mid
201
+
202
+ elif np.issubdtype(x.dtype, np.floating):
203
+ y = x.astype(np.float32)
204
+ m = np.max(np.abs(y)) if y.size else 0.0
205
+
206
+ if m <= 1.0 + 1e-6:
207
+ pass
208
+ else:
209
+ y = y / (m + eps)
210
+ else:
211
+ raise TypeError(f"Unsupported dtype: {x.dtype}")
212
+
213
+ if clip:
214
+ y = np.clip(y, -1.0, 1.0)
215
+
216
+ if y.ndim > 1:
217
+ y = np.mean(y, axis=-1).astype(np.float32)
218
+
219
+ return y
220
+
221
+
222
def _audio_to_tuple(audio: Any) -> Optional[Tuple[np.ndarray, int]]:
    """Normalize a Gradio audio payload to (float32 mono wav, sample rate).

    Handles both the (sr, samples) tuple form and the dict form with
    "sampling_rate"/"data" keys. Anything else (including None) maps to
    None so callers can report a missing reference audio.
    """
    if audio is None:
        return None

    if isinstance(audio, tuple) and len(audio) == 2 and isinstance(audio[0], int):
        rate, raw = audio
        return _normalize_audio(raw), int(rate)

    if isinstance(audio, dict) and "sampling_rate" in audio and "data" in audio:
        return _normalize_audio(audio["data"]), int(audio["sampling_rate"])

    # Unknown payload shape - let the caller surface the error.
    return None
237
+
238
+
239
+ def _wav_to_gradio_audio(wav: np.ndarray, sr: int) -> Tuple[int, np.ndarray]:
240
+ wav = np.asarray(wav, dtype=np.float32)
241
+ return sr, wav
242
+
243
+
244
def _detect_model_kind(ckpt: str, tts: Qwen3TTSModel) -> str:
    """Read the model-type tag off the loaded model and validate it.

    Returns one of "custom_voice" / "voice_design" / "base"; raises
    ValueError for anything else (including a missing attribute).
    """
    kind = getattr(tts.model, "tts_model_type", None)
    if kind not in ("custom_voice", "voice_design", "base"):
        raise ValueError(f"Unknown Qwen-TTS model type: {kind}")
    return kind
250
+
251
+
252
+ def build_demo(tts: Qwen3TTSModel, ckpt: str, gen_kwargs_default: Dict[str, Any]) -> gr.Blocks:
253
+ model_kind = _detect_model_kind(ckpt, tts)
254
+
255
+ supported_langs_raw = None
256
+ if callable(getattr(tts.model, "get_supported_languages", None)):
257
+ supported_langs_raw = tts.model.get_supported_languages()
258
+
259
+ supported_spks_raw = None
260
+ if callable(getattr(tts.model, "get_supported_speakers", None)):
261
+ supported_spks_raw = tts.model.get_supported_speakers()
262
+
263
+ lang_choices_disp, lang_map = _build_choices_and_map([x for x in (supported_langs_raw or [])])
264
+ spk_choices_disp, spk_map = _build_choices_and_map([x for x in (supported_spks_raw or [])])
265
+
266
+ def _gen_common_kwargs() -> Dict[str, Any]:
267
+ return dict(gen_kwargs_default)
268
+
269
+ theme = gr.themes.Soft(
270
+ font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"],
271
+ )
272
+
273
+ css = ".gradio-container {max-width: none !important;}"
274
+
275
+ with gr.Blocks(theme=theme, css=css) as demo:
276
+ gr.Markdown(
277
+ f"""
278
+ # Qwen3 TTS Demo
279
+ **Checkpoint:** `{ckpt}`
280
+ **Model Type:** `{model_kind}`
281
+ """
282
+ )
283
+
284
+ if model_kind == "custom_voice":
285
+ with gr.Row():
286
+ with gr.Column(scale=2):
287
+ text_in = gr.Textbox(
288
+ label="Text (待合成文本)",
289
+ lines=4,
290
+ placeholder="Enter text to synthesize (输入要合成的文本).",
291
+ )
292
+ with gr.Row():
293
+ lang_in = gr.Dropdown(
294
+ label="Language (语种)",
295
+ choices=lang_choices_disp,
296
+ value="Auto",
297
+ interactive=True,
298
+ )
299
+ spk_in = gr.Dropdown(
300
+ label="Speaker (说话人)",
301
+ choices=spk_choices_disp,
302
+ value="Vivian",
303
+ interactive=True,
304
+ )
305
+ instruct_in = gr.Textbox(
306
+ label="Instruction (Optional) (控制指令,可不输入)",
307
+ lines=2,
308
+ placeholder="e.g. Say it in a very angry tone (例如:用特别伤心的语气说).",
309
+ )
310
+ btn = gr.Button("Generate (生成)", variant="primary")
311
+ with gr.Column(scale=3):
312
+ audio_out = gr.Audio(label="Output Audio (合成结果)", type="numpy")
313
+ err = gr.Textbox(label="Status (状态)", lines=2)
314
+
315
+ def run_instruct(text: str, lang_disp: str, spk_disp: str, instruct: str):
316
+ try:
317
+ if not text or not text.strip():
318
+ return None, "Text is required (必须填写文本)."
319
+ if not spk_disp:
320
+ return None, "Speaker is required (必须选择说话人)."
321
+ language = lang_map.get(lang_disp, "Auto")
322
+ speaker = spk_map.get(spk_disp, spk_disp)
323
+ kwargs = _gen_common_kwargs()
324
+ wavs, sr = tts.generate_custom_voice(
325
+ text=text.strip(),
326
+ language=language,
327
+ speaker=speaker,
328
+ instruct=(instruct or "").strip() or None,
329
+ **kwargs,
330
+ )
331
+ return _wav_to_gradio_audio(wavs[0], sr), "Finished. (生成完成)"
332
+ except Exception as e:
333
+ return None, f"{type(e).__name__}: {e}"
334
+
335
+ btn.click(run_instruct, inputs=[text_in, lang_in, spk_in, instruct_in], outputs=[audio_out, err])
336
+
337
+ elif model_kind == "voice_design":
338
+ with gr.Row():
339
+ with gr.Column(scale=2):
340
+ text_in = gr.Textbox(
341
+ label="Text (待合成文本)",
342
+ lines=4,
343
+ value="It's in the top drawer... wait, it's empty? No way, that's impossible! I'm sure I put it there!"
344
+ )
345
+ with gr.Row():
346
+ lang_in = gr.Dropdown(
347
+ label="Language (语种)",
348
+ choices=lang_choices_disp,
349
+ value="Auto",
350
+ interactive=True,
351
+ )
352
+ design_in = gr.Textbox(
353
+ label="Voice Design Instruction (音色描述)",
354
+ lines=3,
355
+ value="Speak in an incredulous tone, but with a hint of panic beginning to creep into your voice."
356
+ )
357
+ btn = gr.Button("Generate (生成)", variant="primary")
358
+ with gr.Column(scale=3):
359
+ audio_out = gr.Audio(label="Output Audio (合成结果)", type="numpy")
360
+ err = gr.Textbox(label="Status (状态)", lines=2)
361
+
362
+ def run_voice_design(text: str, lang_disp: str, design: str):
363
+ try:
364
+ if not text or not text.strip():
365
+ return None, "Text is required (必须填写文本)."
366
+ if not design or not design.strip():
367
+ return None, "Voice design instruction is required (必须填写音色描述)."
368
+ language = lang_map.get(lang_disp, "Auto")
369
+ kwargs = _gen_common_kwargs()
370
+ wavs, sr = tts.generate_voice_design(
371
+ text=text.strip(),
372
+ language=language,
373
+ instruct=design.strip(),
374
+ **kwargs,
375
+ )
376
+ return _wav_to_gradio_audio(wavs[0], sr), "Finished. (生成完成)"
377
+ except Exception as e:
378
+ return None, f"{type(e).__name__}: {e}"
379
+
380
+ btn.click(run_voice_design, inputs=[text_in, lang_in, design_in], outputs=[audio_out, err])
381
+
382
+ else: # voice_clone for base
383
+ with gr.Tabs():
384
+ with gr.Tab("Clone & Generate (克隆并合成)"):
385
+ with gr.Row():
386
+ with gr.Column(scale=2):
387
+ ref_audio = gr.Audio(
388
+ label="Reference Audio (参考音频)",
389
+ )
390
+ ref_text = gr.Textbox(
391
+ label="Reference Text (参考音频文本)",
392
+ lines=2,
393
+ placeholder="Required if not set use x-vector only (不勾选use x-vector only时必填).",
394
+ )
395
+ xvec_only = gr.Checkbox(
396
+ label="Use x-vector only (仅用说话人向量,效果有限,但不用传入参考音频文本)",
397
+ value=False,
398
+ )
399
+
400
+ with gr.Column(scale=2):
401
+ text_in = gr.Textbox(
402
+ label="Target Text (待合成文本)",
403
+ lines=4,
404
+ placeholder="Enter text to synthesize (输入要合成的文本).",
405
+ )
406
+ lang_in = gr.Dropdown(
407
+ label="Language (语种)",
408
+ choices=lang_choices_disp,
409
+ value="Auto",
410
+ interactive=True,
411
+ )
412
+ btn = gr.Button("Generate (生成)", variant="primary")
413
+
414
+ with gr.Column(scale=3):
415
+ audio_out = gr.Audio(label="Output Audio (合成结果)", type="numpy")
416
+ err = gr.Textbox(label="Status (状态)", lines=2)
417
+
418
+ def run_voice_clone(ref_aud, ref_txt: str, use_xvec: bool, text: str, lang_disp: str):
419
+ try:
420
+ if not text or not text.strip():
421
+ return None, "Target text is required (必须填写待合成文本)."
422
+ at = _audio_to_tuple(ref_aud)
423
+ if at is None:
424
+ return None, "Reference audio is required (必须上传参考音频)."
425
+ if (not use_xvec) and (not ref_txt or not ref_txt.strip()):
426
+ return None, (
427
+ "Reference text is required when use x-vector only is NOT enabled.\n"
428
+ "(未勾选 use x-vector only 时,必须提供参考音频文本;否则请勾选 use x-vector only,但效果会变差.)"
429
+ )
430
+ language = lang_map.get(lang_disp, "Auto")
431
+ kwargs = _gen_common_kwargs()
432
+ wavs, sr = tts.generate_voice_clone(
433
+ text=text.strip(),
434
+ language=language,
435
+ ref_audio=at,
436
+ ref_text=(ref_txt.strip() if ref_txt else None),
437
+ x_vector_only_mode=bool(use_xvec),
438
+ **kwargs,
439
+ )
440
+ return _wav_to_gradio_audio(wavs[0], sr), "Finished. (生成完成)"
441
+ except Exception as e:
442
+ return None, f"{type(e).__name__}: {e}"
443
+
444
+ btn.click(
445
+ run_voice_clone,
446
+ inputs=[ref_audio, ref_text, xvec_only, text_in, lang_in],
447
+ outputs=[audio_out, err],
448
+ )
449
+
450
+ with gr.Tab("Save / Load Voice (保存/加载克隆音色)"):
451
+ with gr.Row():
452
+ with gr.Column(scale=2):
453
+ gr.Markdown(
454
+ """
455
+ ### Save Voice (保存音色)
456
+ Upload reference audio and text, choose use x-vector only or not, then save a reusable voice prompt file.
457
+ (上传参考音频和参考文本,选择是否使用 use x-vector only 模式后保存为可复用的音色文件)
458
+ """
459
+ )
460
+ ref_audio_s = gr.Audio(label="Reference Audio (参考音频)", type="numpy")
461
+ ref_text_s = gr.Textbox(
462
+ label="Reference Text (参考音频文本)",
463
+ lines=2,
464
+ placeholder="Required if not set use x-vector only (不勾选use x-vector only时必填).",
465
+ )
466
+ xvec_only_s = gr.Checkbox(
467
+ label="Use x-vector only (仅用说话人向量,效果有限,但不用传入参考音频文本)",
468
+ value=False,
469
+ )
470
+ save_btn = gr.Button("Save Voice File (保存音色文件)", variant="primary")
471
+ prompt_file_out = gr.File(label="Voice File (音色文件)")
472
+
473
+ with gr.Column(scale=2):
474
+ gr.Markdown(
475
+ """
476
+ ### Load Voice & Generate (加载音色并合成)
477
+ Upload a previously saved voice file, then synthesize new text.
478
+ (上传已保存提示文件后,输入新文本进行合成)
479
+ """
480
+ )
481
+ prompt_file_in = gr.File(label="Upload Prompt File (上传提示文件)")
482
+ text_in2 = gr.Textbox(
483
+ label="Target Text (待合成文本)",
484
+ lines=4,
485
+ placeholder="Enter text to synthesize (输入要合成的文本).",
486
+ )
487
+ lang_in2 = gr.Dropdown(
488
+ label="Language (语种)",
489
+ choices=lang_choices_disp,
490
+ value="Auto",
491
+ interactive=True,
492
+ )
493
+ gen_btn2 = gr.Button("Generate (生成)", variant="primary")
494
+
495
+ with gr.Column(scale=3):
496
+ audio_out2 = gr.Audio(label="Output Audio (合成结果)", type="numpy")
497
+ err2 = gr.Textbox(label="Status (状态)", lines=2)
498
+
499
def save_prompt(ref_aud, ref_txt: str, use_xvec: bool):
    """Build a reusable voice-clone prompt from the reference audio and save it.

    Returns ``(path_to_saved_file, status_message)``; on failure the path is
    ``None`` and the status explains why.
    """
    try:
        # Validate inputs up front.
        audio_tuple = _audio_to_tuple(ref_aud)
        if audio_tuple is None:
            return None, "Reference audio is required (必须上传参考音频)."
        has_ref_text = bool(ref_txt and ref_txt.strip())
        if not use_xvec and not has_ref_text:
            return None, (
                "Reference text is required when use x-vector only is NOT enabled.\n"
                "(未勾选 use x-vector only 时,必须提供参考音频文本;否则请勾选 use x-vector only,但效果会变差.)"
            )

        prompt_items = tts.create_voice_clone_prompt(
            ref_audio=audio_tuple,
            ref_text=(ref_txt.strip() if ref_txt else None),
            x_vector_only_mode=bool(use_xvec),
        )
        # Serialize the dataclass items to plain dicts for torch.save.
        payload = {"items": [asdict(item) for item in prompt_items]}

        # mkstemp hands back an open descriptor we don't need — close it so
        # torch.save can reopen the path on all platforms.
        fd, out_path = tempfile.mkstemp(prefix="voice_clone_prompt_", suffix=".pt")
        os.close(fd)
        torch.save(payload, out_path)
        return out_path, "Finished. (生成完成)"
    except Exception as e:  # report the failure in the status box
        return None, f"{type(e).__name__}: {e}"
523
+
524
def load_prompt_and_gen(file_obj, text: str, lang_disp: str):
    """Load a previously saved voice-clone prompt file and synthesize `text` with it.

    Returns a (gradio_audio, status_message) pair; on failure the audio slot is
    None and the status explains the problem.
    """
    try:
        if file_obj is None:
            return None, "Voice file is required (必须上传音色文件)."
        if not text or not text.strip():
            return None, "Target text is required (必须填写待合成文本)."

        # gr.File may hand back an object exposing .name or .path, or a bare path string.
        path = getattr(file_obj, "name", None) or getattr(file_obj, "path", None) or str(file_obj)
        # weights_only=True keeps torch.load from unpickling arbitrary objects
        # in an untrusted upload.
        payload = torch.load(path, map_location="cpu", weights_only=True)
        if not isinstance(payload, dict) or "items" not in payload:
            return None, "Invalid file format (文件格式不正确)."

        items_raw = payload["items"]
        if not isinstance(items_raw, list) or len(items_raw) == 0:
            return None, "Empty voice items (音色为空)."

        # Rebuild VoiceClonePromptItem objects from the serialized dicts,
        # coercing stored lists/arrays back to tensors where needed.
        items: List[VoiceClonePromptItem] = []
        for d in items_raw:
            if not isinstance(d, dict):
                return None, "Invalid item format in file (文件内部格式错误)."
            ref_code = d.get("ref_code", None)  # optional codec codes; may be absent in x-vector-only prompts
            if ref_code is not None and not torch.is_tensor(ref_code):
                ref_code = torch.tensor(ref_code)
            ref_spk = d.get("ref_spk_embedding", None)
            if ref_spk is None:
                return None, "Missing ref_spk_embedding (缺少说话人向量)."
            if not torch.is_tensor(ref_spk):
                ref_spk = torch.tensor(ref_spk)

            items.append(
                VoiceClonePromptItem(
                    ref_code=ref_code,
                    ref_spk_embedding=ref_spk,
                    x_vector_only_mode=bool(d.get("x_vector_only_mode", False)),
                    # Back-compat: files saved without an explicit icl_mode default
                    # to the complement of x_vector_only_mode.
                    icl_mode=bool(d.get("icl_mode", not bool(d.get("x_vector_only_mode", False)))),
                    ref_text=d.get("ref_text", None),
                )
            )

        language = lang_map.get(lang_disp, "Auto")
        kwargs = _gen_common_kwargs()
        wavs, sr = tts.generate_voice_clone(
            text=text.strip(),
            language=language,
            voice_clone_prompt=items,
            **kwargs,
        )
        return _wav_to_gradio_audio(wavs[0], sr), "Finished. (生成完成)"
    except Exception as e:
        # Any failure (corrupt file, wrong tensor shapes, generation error) is
        # reported in the status box rather than crashing the UI.
        return None, (
            f"Failed to read or use voice file. Check file format/content.\n"
            f"(读取或使用音色文件失败,请检查文件格式或内容)\n"
            f"{type(e).__name__}: {e}"
        )

# Wire both halves of the tab: saving a prompt file and generating from one.
save_btn.click(save_prompt, inputs=[ref_audio_s, ref_text_s, xvec_only_s], outputs=[prompt_file_out, err2])
gen_btn2.click(load_prompt_and_gen, inputs=[prompt_file_in, text_in2, lang_in2], outputs=[audio_out2, err2])
581
+
582
+ gr.Markdown(
583
+ """
584
+ **Disclaimer (免责声明)**
585
+ - The audio is automatically generated/synthesized by an AI model solely to demonstrate the model’s capabilities; it may be inaccurate or inappropriate, does not represent the views of the developer/operator, and does not constitute professional advice. You are solely responsible for evaluating, using, distributing, or relying on this audio; to the maximum extent permitted by applicable law, the developer/operator disclaims liability for any direct, indirect, incidental, or consequential damages arising from the use of or inability to use the audio, except where liability cannot be excluded by law. Do not use this service to intentionally generate or replicate unlawful, harmful, defamatory, fraudulent, deepfake, or privacy/publicity/copyright/trademark‑infringing content; if a user prompts, supplies materials, or otherwise facilitates any illegal or infringing conduct, the user bears all legal consequences and the developer/operator is not responsible.
586
+ - 音频由人工智能模型自动生成/合成,仅用于体验与展示模型效果,可能存在不准确或不当之处;其内容不代表开发者/运营方立场,亦不构成任何专业建议。用户应自行评估并承担使用、传播或依赖该音频所产生的一切风险与责任;在适用法律允许的最大范围内,开发者/运营方不对因使用或无法使用本音频造成的任何直接、间接、附带或后果性损失承担责任(法律另有强制规定的除外)。严禁利用本服务故意引导生成或复制违法、有害、诽谤、欺诈、深度伪造、侵犯隐私/肖像/著作权/商标等内容;如用户通过提示词、素材或其他方式实施或促成任何违法或侵权行为,相关法律后果由用户自行承担,与开发者/运营方无关。
587
+ """
588
+ )
589
+
590
+ return demo
591
+
592
+
593
def main(argv=None) -> int:
    """CLI entry point: load the checkpoint, build the Gradio demo, serve it.

    Returns a process exit code (always 0; argparse errors exit on their own).
    """
    args = build_parser().parse_args(argv)

    # Without a checkpoint there is nothing to serve; print usage and exit cleanly.
    if not (args.checkpoint or args.checkpoint_pos):
        build_parser().print_help()
        return 0

    checkpoint = _resolve_checkpoint(args)

    model = Qwen3TTSModel.from_pretrained(
        checkpoint,
        device_map=args.device,
        dtype=_dtype_from_str(args.dtype),
        attn_implementation="flash_attention_2" if args.flash_attn else None,
    )

    demo = build_demo(model, checkpoint, _collect_gen_kwargs(args))

    launch_kwargs: Dict[str, Any] = {
        "server_name": args.ip,
        "server_port": args.port,
        "share": args.share,
    }
    # SSL options are forwarded only when explicitly provided on the CLI.
    for opt in ("ssl_certfile", "ssl_keyfile", "ssl_verify"):
        value = getattr(args, opt)
        if value is not None:
            launch_kwargs[opt] = value

    demo.queue(default_concurrency_limit=int(args.concurrency)).launch(**launch_kwargs)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
qwen_tts/core/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from .tokenizer_25hz.configuration_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Config
17
+ from .tokenizer_25hz.modeling_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Model
18
+ from .tokenizer_12hz.configuration_qwen3_tts_tokenizer_v2 import Qwen3TTSTokenizerV2Config
19
+ from .tokenizer_12hz.modeling_qwen3_tts_tokenizer_v2 import Qwen3TTSTokenizerV2Model
qwen_tts/core/models/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from .configuration_qwen3_tts import Qwen3TTSConfig
17
+ from .modeling_qwen3_tts import Qwen3TTSForConditionalGeneration
18
+ from .processing_qwen3_tts import Qwen3TTSProcessor
qwen_tts/core/models/configuration_qwen3_tts.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from transformers.configuration_utils import PretrainedConfig, layer_type_validation
16
+ from transformers.modeling_rope_utils import rope_config_validation
17
+ from transformers.utils import logging
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
class Qwen3TTSSpeakerEncoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3TTSSpeakerEncoder`].
    It is used to instantiate a Qwen3TTS speaker encoder model according to the specified arguments, defining the model
    architecture. The architecture is based on the ECAPA-TDNN model.

    Args:
        mel_dim (`int`, *optional*, defaults to 128):
            The dimension of the input mel-spectrogram.
        enc_dim (`int`, *optional*, defaults to 1024):
            The dimension of the final speaker embedding.
        enc_channels (`list[int]`, *optional*, defaults to `[512, 512, 512, 512, 1536]`):
            A list of output channels for each TDNN/SERes2Net layer in the encoder. The first channel size is for the
            initial TDNN layer, the intermediate ones for the `SqueezeExcitationRes2NetBlock` layers, and the last one
            for the multi-layer feature aggregation.
        enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`):
            A list of kernel sizes for each layer in the encoder, corresponding to `enc_channels`.
        enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`):
            A list of dilations for each layer in the encoder, corresponding to `enc_channels`.
        enc_attention_channels (`int`, *optional*, defaults to 128):
            The number of attention channels in the `AttentiveStatisticsPooling` layer.
        enc_res2net_scale (`int`, *optional*, defaults to 8):
            The scale of the `Res2NetBlock` in the encoder.
        enc_se_channels (`int`, *optional*, defaults to 128):
            The number of channels in the squeeze part of the `SqueezeExcitationBlock`.
        sample_rate (`int`, *optional*, defaults to 24000):
            Sampling rate in Hz associated with the encoder's input audio.
    """

    def __init__(
        self,
        mel_dim=128,
        enc_dim=1024,
        enc_channels=None,
        enc_kernel_sizes=None,
        enc_dilations=None,
        enc_attention_channels=128,
        enc_res2net_scale=8,
        enc_se_channels=128,
        sample_rate=24000,
        **kwargs,
    ):
        # Forward unknown keys (e.g. "model_type" from a serialized config dict)
        # to PretrainedConfig instead of raising TypeError, matching the sibling
        # config classes in this module.
        super().__init__(**kwargs)
        self.mel_dim = mel_dim
        self.enc_dim = enc_dim
        # None-sentinels avoid sharing one mutable default list across instances.
        self.enc_channels = [512, 512, 512, 512, 1536] if enc_channels is None else enc_channels
        self.enc_kernel_sizes = [5, 3, 3, 3, 1] if enc_kernel_sizes is None else enc_kernel_sizes
        self.enc_dilations = [1, 2, 3, 4, 1] if enc_dilations is None else enc_dilations
        self.enc_attention_channels = enc_attention_channels
        self.enc_res2net_scale = enc_res2net_scale
        self.enc_se_channels = enc_se_channels
        self.sample_rate = sample_rate
68
+
69
+
70
class Qwen3TTSTalkerCodePredictorConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3TTSTalkerCodePredictorModel`]. It is used to
    instantiate a Qwen3TTSTalkerCodePredictor model according to the specified arguments, defining the model
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 2048):
            Vocabulary size of the Qwen3TTSTalkerCodePredictor model. Defines the number of different tokens that can
            be represented by the `inputs_ids` passed when calling [`Qwen3TTSTalkerCodePredictorModel`].
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 5):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). Falls back to `num_attention_heads` when `None`.
        head_dim (`int`, *optional*, defaults to 128):
            The attention head dimension.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000):
            The base period of the RoPE embeddings.
        rope_scaling (`dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings, validated by
            `rope_config_validation`. A legacy `"type"` key is migrated to `"rope_type"`. See the [`PretrainedConfig`]
            RoPE documentation for the accepted sub-fields (`rope_type`, `factor`,
            `original_max_position_embeddings`, `attention_factor`, `beta_fast`, `beta_slow`, `short_factor`,
            `long_factor`, `low_freq_factor`, `high_freq_factor`).
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention. When `False`, `sliding_window` is stored as `None`.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers using full attention. The first `max_window_layers` layers will use full attention,
            while any additional layer afterwards will use SWA (Sliding Window Attention).
        layer_types (`list`, *optional*):
            Attention pattern for each layer. Derived from `sliding_window`/`max_window_layers` when `None`.
        attention_dropout (`float`, *optional*, defaults to 0):
            The dropout ratio for the attention probabilities.
        num_code_groups (`int`, *optional*, defaults to 32):
            Number of codec code groups handled by the code predictor.
    """

    model_type = "qwen3_tts_talker_code_predictor"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Qwen3TTSTalkerCodePredictor`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Default pipeline parallel plan: stage boundaries and tensor flow per module.
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=2048,
        hidden_size=1024,
        intermediate_size=3072,
        num_hidden_layers=5,
        num_attention_heads=16,
        num_key_value_heads=8,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=0.000001,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000,
        rope_scaling=None,
        attention_bias=False,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        layer_types=None,
        attention_dropout=0,
        num_code_groups=32,
        **kwargs,
    ):
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        # sliding_window is only meaningful when SWA is enabled; store None otherwise.
        self.sliding_window = sliding_window if self.use_sliding_window else None
        self.max_window_layers = max_window_layers

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        self.layer_types = layer_types
        if self.layer_types is None:
            # Layers past max_window_layers use sliding attention when SWA is enabled.
            self.layer_types = [
                "sliding_attention"
                if self.sliding_window is not None and i >= self.max_window_layers
                else "full_attention"
                for i in range(self.num_hidden_layers)
            ]
        layer_type_validation(self.layer_types)
        self.num_code_groups = num_code_groups
257
+
258
+
259
class Qwen3TTSTalkerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3TTSTalkerModel`]. It is used to instantiate
    a Qwen3TTSTalker model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        code_predictor_config (`Qwen3TTSTalkerCodePredictorConfig` or `dict`, *optional*):
            Configuration of the nested code-predictor sub-model. Initialized with default values when `None`.
        vocab_size (`int`, *optional*, defaults to 3072):
            Vocabulary size of the Qwen3TTSTalker model. Defines the number of different tokens that can be
            represented by the `inputs_ids` passed when calling [`Qwen3TTSTalkerModel`].
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 2048):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 20):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 2):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. For more
            details, check out [this paper](https://huggingface.co/papers/2305.13245).
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000):
            The base period of the RoPE embeddings.
        rope_scaling (`dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. A legacy `"type"` key is
            migrated to `"rope_type"`. See the [`PretrainedConfig`] RoPE documentation for the accepted sub-fields.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention. When `False`, `sliding_window` is stored as `None`.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size.
        attention_dropout (`float`, *optional*, defaults to 0):
            The dropout ratio for the attention probabilities.
        num_code_groups (`int`, *optional*, defaults to 32):
            Number of codec code groups per frame.
        text_hidden_size (`int`, *optional*, defaults to 2048):
            Hidden size of the text-side representations consumed by the talker.
        codec_eos_token_id (`int`, *optional*, defaults to 4198):
            Id of the end-of-sequence token in the codec stream.
        codec_think_id (`int`, *optional*, defaults to 4202):
            Id of the "think" marker token in the codec stream.
        codec_nothink_id (`int`, *optional*, defaults to 4203):
            Id of the "no-think" marker token in the codec stream.
        codec_think_bos_id (`int`, *optional*, defaults to 4204):
            Id of the think begin-of-segment token in the codec stream.
        codec_think_eos_id (`int`, *optional*, defaults to 4205):
            Id of the think end-of-segment token in the codec stream.
        codec_pad_id (`int`, *optional*, defaults to 4196):
            Id of the padding token in the codec stream.
        codec_bos_id (`int`, *optional*, defaults to 4197):
            Id of the begin-of-sequence token in the codec stream.
        spk_id (*optional*):
            Speaker-id table; stored as-is. NOTE(review): semantics are defined by the checkpoint — not inferable
            from this module.
        spk_is_dialect (*optional*):
            Per-speaker dialect flags; stored as-is. NOTE(review): semantics defined by the checkpoint.
        codec_language_id (*optional*):
            Language-id table for the codec stream; stored as-is. NOTE(review): semantics defined by the checkpoint.
    """

    model_type = "qwen3_tts_talker"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Qwen3TTSTalker`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Default pipeline parallel plan: stage boundaries and tensor flow per module.
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }
    sub_configs = {"code_predictor_config": Qwen3TTSTalkerCodePredictorConfig}

    def __init__(
        self,
        code_predictor_config=None,
        vocab_size=3072,
        hidden_size=1024,
        intermediate_size=2048,
        num_hidden_layers=20,
        num_attention_heads=16,
        num_key_value_heads=2,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=0.000001,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000,
        rope_scaling=None,
        attention_bias=False,
        use_sliding_window=False,
        sliding_window=4096,
        attention_dropout=0,
        num_code_groups=32,
        text_hidden_size=2048,
        codec_eos_token_id=4198,
        codec_think_id=4202,
        codec_nothink_id=4203,
        codec_think_bos_id=4204,
        codec_think_eos_id=4205,
        codec_pad_id=4196,
        codec_bos_id=4197,
        spk_id=None,
        spk_is_dialect=None,
        codec_language_id=None,
        **kwargs,
    ):
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        # sliding_window is only meaningful when SWA is enabled; store None otherwise.
        self.sliding_window = sliding_window if use_sliding_window else None

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # Validate the correctness of rotary position embeddings parameters
        # BC: if there is a 'type' field, move it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]

        # Accept a ready-made sub-config, a plain dict, or nothing at all.
        if code_predictor_config is None:
            self.code_predictor_config = Qwen3TTSTalkerCodePredictorConfig()
            logger.info("code_predictor_config is None. Initializing code_predictor model with default values")
        elif isinstance(code_predictor_config, Qwen3TTSTalkerCodePredictorConfig):
            self.code_predictor_config = code_predictor_config
        else:
            self.code_predictor_config = Qwen3TTSTalkerCodePredictorConfig(**code_predictor_config)
        self.num_code_groups = num_code_groups
        self.text_hidden_size = text_hidden_size
        self.codec_eos_token_id = codec_eos_token_id
        self.codec_think_id = codec_think_id
        self.codec_language_id = codec_language_id
        self.codec_nothink_id = codec_nothink_id
        self.codec_think_bos_id = codec_think_bos_id
        self.codec_think_eos_id = codec_think_eos_id
        self.codec_pad_id = codec_pad_id
        self.codec_bos_id = codec_bos_id
        self.spk_id = spk_id
        self.spk_is_dialect = spk_is_dialect
452
+
453
+
454
class Qwen3TTSConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3TTSForConditionalGeneration`].

    Args:
        talker_config (`Qwen3TTSTalkerConfig` or `dict`, *optional*):
            Configuration of the talker sub-model. Initialized with default values when `None`.
        speaker_encoder_config (`Qwen3TTSSpeakerEncoderConfig` or `dict`, *optional*):
            Configuration of the speaker encoder sub-model. Initialized with default values when `None`.
        tokenizer_type (`str`, *optional*):
            Codec tokenizer variant tag; stored as-is. NOTE(review): allowed values are defined by the
            checkpoints, not by this module.
        tts_model_size (`str`, *optional*):
            Size tag of the TTS model; stored as-is.
        tts_model_type (`str`, *optional*):
            Type tag of the TTS model; stored as-is.
        im_start_token_id (`int`, *optional*, defaults to 151644):
            Token id of the im_start marker.
        im_end_token_id (`int`, *optional*, defaults to 151645):
            Token id of the im_end marker.
        tts_pad_token_id (`int`, *optional*, defaults to 151671):
            Token id used for TTS padding.
        tts_bos_token_id (`int`, *optional*, defaults to 151672):
            Token id of the TTS begin-of-sequence marker.
        tts_eos_token_id (`int`, *optional*, defaults to 151673):
            Token id of the TTS end-of-sequence marker.
    """

    model_type = "qwen3_tts"
    sub_configs = {
        "talker_config": Qwen3TTSTalkerConfig,
        "speaker_encoder_config": Qwen3TTSSpeakerEncoderConfig,
    }

    def __init__(
        self,
        talker_config=None,
        speaker_encoder_config=None,
        tokenizer_type=None,
        tts_model_size=None,
        tts_model_type=None,
        im_start_token_id=151644,
        im_end_token_id=151645,
        tts_pad_token_id=151671,
        tts_bos_token_id=151672,
        tts_eos_token_id=151673,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if talker_config is None:
            talker_config = {}
            logger.info("talker_config is None. Initializing talker model with default values")
        if speaker_encoder_config is None:
            speaker_encoder_config = {}
            # Fixed copy-paste in the log message: this branch concerns the
            # speaker encoder, not the talker.
            logger.info("speaker_encoder_config is None. Initializing speaker encoder model with default values")

        self.talker_config = Qwen3TTSTalkerConfig(**talker_config)
        self.speaker_encoder_config = Qwen3TTSSpeakerEncoderConfig(**speaker_encoder_config)

        self.tokenizer_type = tokenizer_type
        self.tts_model_size = tts_model_size
        self.tts_model_type = tts_model_type

        self.im_start_token_id = im_start_token_id
        self.im_end_token_id = im_end_token_id
        self.tts_pad_token_id = tts_pad_token_id
        self.tts_bos_token_id = tts_bos_token_id
        self.tts_eos_token_id = tts_eos_token_id


__all__ = ["Qwen3TTSConfig", "Qwen3TTSTalkerConfig", "Qwen3TTSSpeakerEncoderConfig"]
qwen_tts/core/models/modeling_qwen3_tts.py ADDED
@@ -0,0 +1,2246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Qwen3TTS model."""
16
+
17
+ import json
18
+ import os
19
+ from dataclasses import dataclass
20
+ from typing import Callable, Optional
21
+
22
+ import torch
23
+ from librosa.filters import mel as librosa_mel_fn
24
+ from torch import nn
25
+ from torch.nn import functional as F
26
+ from transformers.activations import ACT2FN
27
+ from transformers.cache_utils import Cache, DynamicCache
28
+ from transformers.generation import GenerationMixin
29
+ from transformers.integrations import use_kernel_forward_from_hub
30
+ from transformers.masking_utils import (
31
+ create_causal_mask,
32
+ create_sliding_window_causal_mask,
33
+ )
34
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
35
+ from transformers.modeling_layers import GradientCheckpointingLayer
36
+ from transformers.modeling_outputs import (
37
+ BaseModelOutputWithPast,
38
+ CausalLMOutputWithPast,
39
+ ModelOutput,
40
+ )
41
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
42
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
43
+ from transformers.processing_utils import Unpack
44
+ from transformers.utils import can_return_tuple, logging
45
+ from transformers.utils.hub import cached_file
46
+
47
+ from ...inference.qwen3_tts_tokenizer import Qwen3TTSTokenizer
48
+ from .configuration_qwen3_tts import (
49
+ Qwen3TTSConfig,
50
+ Qwen3TTSSpeakerEncoderConfig,
51
+ Qwen3TTSTalkerCodePredictorConfig,
52
+ Qwen3TTSTalkerConfig,
53
+ )
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+
58
class Res2NetBlock(torch.nn.Module):
    """Res2Net-style multi-scale block.

    The channel axis is split into `scale` groups. The first group passes
    through untouched; each later group is transformed by a small TDNN block,
    with the previous group's transformed output added in (hierarchical
    residual connections), and all groups are re-concatenated.
    """

    def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
        super().__init__()

        group_in = in_channels // scale
        group_hidden = out_channels // scale

        # One TDNN block per group, except the first pass-through group.
        self.blocks = nn.ModuleList(
            [
                TimeDelayNetBlock(
                    group_in,
                    group_hidden,
                    kernel_size=kernel_size,
                    dilation=dilation,
                )
                for _ in range(scale - 1)
            ]
        )
        self.scale = scale

    def forward(self, hidden_states):
        chunks = torch.chunk(hidden_states, self.scale, dim=1)
        processed = [chunks[0]]  # group 0: identity
        previous = None
        for idx, chunk in enumerate(chunks[1:]):
            if previous is None:
                # Group 1 sees only its own chunk.
                previous = self.blocks[idx](chunk)
            else:
                # Later groups also receive the previous group's output.
                previous = self.blocks[idx](chunk + previous)
            processed.append(previous)
        return torch.cat(processed, dim=1)
90
+
91
+
92
class SqueezeExcitationBlock(nn.Module):
    """Squeeze-and-Excitation channel gating.

    Squeeze: global average over the time axis. Excite: a two-layer 1x1-conv
    bottleneck with ReLU, ending in a sigmoid gate in [0, 1] that rescales
    every channel of the input.
    """

    def __init__(self, in_channels, se_channels, out_channels):
        super().__init__()

        self.conv1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=se_channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(
            in_channels=se_channels,
            out_channels=out_channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, hidden_states):
        # Squeeze: (batch, channels, time) -> (batch, channels, 1).
        gate = hidden_states.mean(dim=2, keepdim=True)
        # Excite: bottleneck -> ReLU -> expand -> sigmoid.
        gate = self.relu(self.conv1(gate))
        gate = self.sigmoid(self.conv2(gate))
        # Rescale each channel by its learned gate.
        return hidden_states * gate
120
+
121
+
122
class AttentiveStatisticsPooling(nn.Module):
    """This class implements an attentive statistic pooling layer for each channel.
    It returns the concatenated mean and std of the input tensor.
    """

    def __init__(self, channels, attention_channels=128):
        super().__init__()

        # Floor for the variance inside the sqrt, to keep std finite.
        self.eps = 1e-12
        # The attention TDNN sees [features, broadcast mean, broadcast std] -> channels * 3 inputs.
        self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = nn.Conv1d(
            in_channels=attention_channels,
            out_channels=channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )

    def _length_to_mask(self, length, max_len=None, dtype=None, device=None):
        """Creates a binary mask for each sequence.

        Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3

        Arguments
        ---------
        length : torch.LongTensor
            Containing the length of each sequence in the batch. Must be 1D.
        max_len : int
            Max length for the mask, also the size of the second dimension.
        dtype : torch.dtype, default: None
            The dtype of the generated mask.
        device: torch.device, default: None
            The device to put the mask variable.

        Returns
        -------
        mask : tensor
            The binary mask.
        """

        if max_len is None:
            max_len = length.max().long().item()  # using arange to generate mask
        # Position i is True while i < length for that sequence.
        mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(
            len(length), max_len
        ) < length.unsqueeze(1)

        mask = torch.as_tensor(mask, dtype=dtype, device=device)
        return mask

    def _compute_statistics(self, x, m, dim=2):
        # Weighted mean/std over `dim`; `m` is expected to sum to 1 along `dim`
        # (normalized mask or softmax attention weights).
        mean = (m * x).sum(dim)
        std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps))
        return mean, std

    def forward(self, hidden_states):
        seq_length = hidden_states.shape[-1]
        # NOTE(review): lengths are hard-coded to full length here, so the mask
        # below is all-ones — padding-aware pooling appears unused in this port.
        lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device)

        # Make binary mask of shape [N, 1, L]
        mask = self._length_to_mask(
            lengths * seq_length, max_len=seq_length, dtype=hidden_states.dtype, device=hidden_states.device
        )
        mask = mask.unsqueeze(1)

        # Expand the temporal context of the pooling layer by allowing the
        # self-attention to look at global properties of the utterance.
        total = mask.sum(dim=2, keepdim=True)

        # Global (mask-normalized) statistics, broadcast back over time and
        # concatenated onto the features as extra context channels.
        mean, std = self._compute_statistics(hidden_states, mask / total)
        mean = mean.unsqueeze(2).repeat(1, 1, seq_length)
        std = std.unsqueeze(2).repeat(1, 1, seq_length)
        attention = torch.cat([hidden_states, mean, std], dim=1)

        # Apply layers
        attention = self.conv(self.tanh(self.tdnn(attention)))

        # Filter out zero-paddings
        attention = attention.masked_fill(mask == 0, float("-inf"))

        # Attention weights sum to 1 over time; reuse them as the weighting
        # for the final attentive mean/std.
        attention = F.softmax(attention, dim=2)
        mean, std = self._compute_statistics(hidden_states, attention)
        # Append mean and std of the batch
        pooled_stats = torch.cat((mean, std), dim=1)
        pooled_stats = pooled_stats.unsqueeze(2)

        return pooled_stats
209
+
210
class TimeDelayNetBlock(nn.Module):
    """One TDNN layer: a dilated 1-D convolution with reflect 'same' padding,
    followed by a ReLU non-linearity."""

    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding="same",
            padding_mode="reflect",
        )
        self.activation = nn.ReLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Convolve over time, then apply the pointwise non-linearity.
        features = self.conv(hidden_states)
        return self.activation(features)
231
+
232
class SqueezeExcitationRes2NetBlock(nn.Module):
    """An implementation of building block in ECAPA-TDNN, i.e.,
    TDNN-Res2Net-TDNN-SqueezeExcitationBlock, wrapped in a residual connection.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        res2net_scale=8,
        se_channels=128,
        kernel_size=1,
        dilation=1,
    ):
        super().__init__()
        self.out_channels = out_channels
        # 1x1 TDNN in, multi-scale Res2Net core, 1x1 TDNN out, then SE gating.
        self.tdnn1 = TimeDelayNetBlock(in_channels, out_channels, kernel_size=1, dilation=1)
        self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation)
        self.tdnn2 = TimeDelayNetBlock(out_channels, out_channels, kernel_size=1, dilation=1)
        self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels)

    def forward(self, hidden_state):
        # Residual add around the whole TDNN -> Res2Net -> TDNN -> SE pipeline.
        out = self.tdnn1(hidden_state)
        out = self.res2net_block(out)
        out = self.tdnn2(out)
        out = self.se_block(out)
        return out + hidden_state
272
+
273
+
274
class Qwen3TTSSpeakerEncoder(torch.nn.Module):
    """An implementation of the speaker embedding model in a paper.
    "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
    TDNN Based Speaker Verification" (https://huggingface.co/papers/2005.07143).
    Use for Qwen3TTS extract speaker embedding.

    Pipeline: initial TDNN -> stack of SE-Res2Net blocks -> multi-layer feature
    aggregation -> attentive statistics pooling -> 1x1 conv projection.
    """

    def __init__(self, config: Qwen3TTSSpeakerEncoderConfig):
        super().__init__()
        num_stages = len(config.enc_channels)
        if num_stages != len(config.enc_kernel_sizes) or num_stages != len(config.enc_dilations):
            raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length")
        self.channels = config.enc_channels

        layers = [
            # The initial TDNN layer operating directly on mel features.
            TimeDelayNetBlock(
                config.mel_dim,
                config.enc_channels[0],
                config.enc_kernel_sizes[0],
                config.enc_dilations[0],
            )
        ]
        # SE-Res2Net layers for every intermediate stage.
        for stage in range(1, num_stages - 1):
            layers.append(
                SqueezeExcitationRes2NetBlock(
                    config.enc_channels[stage - 1],
                    config.enc_channels[stage],
                    res2net_scale=config.enc_res2net_scale,
                    se_channels=config.enc_se_channels,
                    kernel_size=config.enc_kernel_sizes[stage],
                    dilation=config.enc_dilations[stage],
                )
            )
        self.blocks = nn.ModuleList(layers)

        # Multi-layer feature aggregation over the concatenated block outputs.
        self.mfa = TimeDelayNetBlock(
            config.enc_channels[-1],
            config.enc_channels[-1],
            config.enc_kernel_sizes[-1],
            config.enc_dilations[-1],
        )

        # Attentive statistics pooling collapses time into (mean, std).
        self.asp = AttentiveStatisticsPooling(
            config.enc_channels[-1],
            attention_channels=config.enc_attention_channels,
        )

        # Final 1x1 projection to the speaker-embedding dimension.
        self.fc = nn.Conv1d(
            in_channels=config.enc_channels[-1] * 2,
            out_channels=config.enc_dim,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )

    def forward(self, hidden_states):
        # (batch, time, mel) -> (batch, mel, time) once, up front, for Conv1d.
        hidden_states = hidden_states.transpose(1, 2)

        stage_outputs = []
        for block in self.blocks:
            hidden_states = block(hidden_states)
            stage_outputs.append(hidden_states)

        # Aggregate every stage output after the initial TDNN along channels.
        hidden_states = self.mfa(torch.cat(stage_outputs[1:], dim=1))

        # Pool over time, then project; result is (batch, enc_dim, 1).
        hidden_states = self.fc(self.asp(hidden_states))

        # Drop the trailing singleton time dimension -> (batch, enc_dim).
        return hidden_states.squeeze(-1)
357
+
358
+
359
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress a magnitude tensor: log(max(x, clip_val) * C)."""
    clamped = torch.clamp(x, min=clip_val)
    return torch.log(clamped * C)
361
+
362
def mel_spectrogram(
    y: torch.Tensor,
    n_fft: int,
    num_mels: int,
    sampling_rate: int,
    hop_size: int,
    win_size: int,
    fmin: int,
    fmax: Optional[int] = None,
    center: bool = False,
) -> torch.Tensor:
    """
    Calculate the mel spectrogram of an input signal.
    This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).

    Args:
        y (torch.Tensor): Input signal.
        n_fft (int): FFT size.
        num_mels (int): Number of mel bins.
        sampling_rate (int): Sampling rate of the input signal.
        hop_size (int): Hop size for STFT.
        win_size (int): Window size for STFT.
        fmin (int): Minimum frequency for mel filterbank.
        fmax (Optional[int]): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
        center (bool): Whether to pad the input to center the frames. Default is False.

    Returns:
        torch.Tensor: Mel spectrogram.
    """
    # Input is expected to be a waveform normalized to [-1, 1]; warn (but
    # continue) if it is not.
    if torch.min(y) < -1.0:
        print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
    if torch.max(y) > 1.0:
        print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")

    device = y.device

    # Slaney-normalized mel filterbank, moved to the waveform's device.
    mel = librosa_mel_fn(
        sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
    )
    mel_basis = torch.from_numpy(mel).float().to(device)
    hann_window = torch.hann_window(win_size).to(device)

    # Manual reflect padding so the frames line up when center=False.
    padding = (n_fft - hop_size) // 2
    y = torch.nn.functional.pad(
        y.unsqueeze(1), (padding, padding), mode="reflect"
    ).squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    # Magnitude; the small epsilon keeps the sqrt gradient finite at zero.
    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)

    mel_spec = torch.matmul(mel_basis, spec)
    mel_spec = dynamic_range_compression_torch(mel_spec)

    return mel_spec
428
+
429
+
430
class Qwen3TTSPreTrainedModel(PreTrainedModel):
    """Base class wiring Qwen3TTS modules into the HF `PreTrainedModel` machinery."""

    config_class = Qwen3TTSConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Qwen3TTSDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_static_cache = False
    _supports_attention_backend = True

    def _init_weights(self, module):
        # important: this ported version of Qwen2.5OmniThinker isn't meant for training from scratch - only
        # inference and fine-tuning - so the proper init weights code has been removed
        std = getattr(self.config, "initializer_range", 0.02)

        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv3d, nn.ConvTranspose1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()
460
+
461
+
462
class Qwen3TTSTalkerTextPreTrainedModel(PreTrainedModel):
    """Base `PreTrainedModel` for the talker's text decoder stack."""

    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = []
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = False
    _supports_attention_backend = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Keep the padding embedding at exactly zero.
                module.weight.data[module.padding_idx].zero_()
            return
        if isinstance(module, Qwen3TTSRMSNorm):
            module.weight.data.fill_(1.0)
488
+
489
+
490
class Qwen3TTSTalkerRotaryEmbedding(nn.Module):
    """Rotary position embedding producing the cos/sin tables used by the
    talker's multimodal (3-axis) rotary attention."""

    def __init__(self, config: Qwen3TTSTalkerConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        # In contrast to other models, Qwen3TTSThinkerText has different
        # position ids for the grids, so inv_freq is broadcast to shape (3, ...).
        freq = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        positions = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            angles = (freq.float() @ positions.float()).transpose(2, 3)
            table = torch.cat((angles, angles), dim=-1)
            cos = table.cos() * self.attention_scaling
            sin = table.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
524
+
525
class Qwen3TTSRotaryEmbedding(nn.Module):
    """Standard single-axis rotary position embedding for Qwen3TTS."""

    def __init__(self, config: Qwen3TTSConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None:
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        freq = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            angles = (freq.float() @ positions.float()).transpose(1, 2)
            table = torch.cat((angles, angles), dim=-1)
            cos = table.cos() * self.attention_scaling
            sin = table.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
557
+
558
+
559
@use_kernel_forward_from_hub("RMSNorm")
class Qwen3TTSRMSNorm(nn.Module):
    """Root-mean-square layer norm (equivalent to T5LayerNorm): normalizes by
    the RMS of the last dimension, computed in float32, then applies a
    learned per-channel gain."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        # Compute in float32 for numerical stability, cast back afterwards.
        states_fp32 = hidden_states.to(torch.float32)
        mean_square = states_fp32.pow(2).mean(-1, keepdim=True)
        normalized = states_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
578
+
579
def rotate_half(x):
    """Rotates half the hidden dims of the input: (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
584
+
585
+
586
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    if n_rep == 1:
        # Nothing to repeat; return the input tensor unchanged.
        return hidden_states
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
596
+
597
+
598
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Reference (non-fused) scaled-dot-product attention with GQA support.

    Returns the attention output in (batch, seq, heads, head_dim) layout along
    with the post-dropout attention weights.
    """
    # Expand KV heads so they match the number of query heads (GQA).
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Slice the mask to the current key length (keys may grow via KV cache).
        scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]

    # Softmax in float32 for numerical stability, then cast back to the
    # query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value_states)
    context = context.transpose(1, 2).contiguous()

    return context, probs
622
+
623
+
624
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, mrope_interleaved=False, unsqueeze_dim=1):
    """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).

    Explanation:
        Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
        sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
        vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
        Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
        For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
        height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
        difference with modern LLMs.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        mrope_section(`List(int)`):
            Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
        mrope_interleaved (`bool`, *optional*, defaults to `False`):
            Whether the per-axis rotary channels are interleaved across the head dimension rather than laid out in
            contiguous sections.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    if mrope_interleaved:

        def apply_interleaved_rope(x, modality_num):
            # Start from the first modality's table, then overwrite the
            # strided channel slots belonging to each later modality.
            # NOTE(review): assumes x's leading axis indexes the modality
            # (temporal/height/width) tables — confirm against the rotary
            # embedding that produced cos/sin.
            x_t = x[0].clone()
            index_ranges = []
            for i, n in enumerate(mrope_section[1:], 1):
                beg_idx = i
                end_idx = n * modality_num
                index_ranges.append((beg_idx, end_idx))
            for beg_idx, end_idx in index_ranges:
                x_t[..., beg_idx:end_idx:modality_num] = x[beg_idx, ..., beg_idx:end_idx:modality_num]
            return x_t

        dim = cos.shape[-1]
        modality_num = len(mrope_section)
        # Only the first half of the table is unique (cos/sin are duplicated
        # across halves), so interleave the half-table and re-duplicate it.
        cos = torch.cat([apply_interleaved_rope(cos[..., : dim // 2], modality_num)] * 2, dim=-1).unsqueeze(
            unsqueeze_dim
        )
        sin = torch.cat([apply_interleaved_rope(sin[..., : dim // 2], modality_num)] * 2, dim=-1).unsqueeze(
            unsqueeze_dim
        )
    else:
        # Contiguous layout: split the head dim into the per-axis sections
        # (doubled, since cos/sin duplicate across the two halves) and pick
        # section i from modality i % 3.
        mrope_section = mrope_section * 2
        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
            unsqueeze_dim
        )
        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
            unsqueeze_dim
        )

    # Standard rotary application once the per-channel tables are assembled.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
689
+
690
+
691
class Qwen3TTSTalkerAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Talker-stack variant: queries/keys are RMS-normalized per head and rotated
    with the *multimodal* RoPE (`apply_multimodal_rotary_pos_emb`), whose
    section layout comes from ``config.rope_scaling``.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        # layer_idx lets the shared KV cache route updates to this layer's slot.
        self.layer_idx = layer_idx
        # Fall back to hidden_size // num_heads when config has no explicit head_dim.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # GQA: number of query heads sharing one key/value head.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5  # 1/sqrt(head_dim) softmax temperature
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = Qwen3TTSRMSNorm(
            self.head_dim, eps=config.rms_norm_eps
        )  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3TTSRMSNorm(
            self.head_dim, eps=config.rms_norm_eps
        )  # thus post q_norm does not need reshape
        # No per-layer layer_types here (unlike Qwen3TTSAttention): window applies if configured at all.
        self.sliding_window = getattr(config, "sliding_window", None)
        self.rope_scaling = config.rope_scaling

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Compute self-attention for one talker layer.

        Args:
            hidden_states: ``(batch, seq_len, hidden_size)`` input.
            position_embeddings: Precomputed ``(cos, sin)`` multimodal RoPE tensors.
            attention_mask: Causal mask, or ``None``.
            past_key_values: Optional KV cache, updated in place for this layer.
            cache_position: Absolute positions of the current tokens in the cache.

        Returns:
            ``(attn_output, attn_weights)``; weights may be ``None`` depending on
            the attention backend.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project, split into heads, RMS-normalize q/k per head, then (b, h, s, d).
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        # Multimodal RoPE: sections/interleaving are driven by config.rope_scaling.
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"], self.rope_scaling["interleaved"]
        )

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Dispatch to the configured attention backend (eager / sdpa / flash-attn / ...).
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        # Merge heads back and project to hidden_size.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
770
+
771
+
772
class Qwen3TTSTalkerResizeMLP(nn.Module):
    """Two-layer feed-forward projector that resizes hidden states.

    Maps ``input_size`` -> ``intermediate_size`` -> ``output_size`` with a
    configurable activation (looked up in ``ACT2FN`` by name) in between.
    """

    def __init__(self, input_size: int, intermediate_size: int, output_size: int, act: str, bias=False):
        super().__init__()
        self.linear_fc1 = nn.Linear(input_size, intermediate_size, bias=bias)
        self.linear_fc2 = nn.Linear(intermediate_size, output_size, bias=bias)
        self.act_fn = ACT2FN[act]

    def forward(self, hidden_state):
        """Project ``hidden_state`` through fc1 -> activation -> fc2."""
        projected = self.linear_fc1(hidden_state)
        activated = self.act_fn(projected)
        return self.linear_fc2(activated)
781
+
782
+
783
@dataclass
class Qwen3TTSTalkerCodePredictorOutputWithPast(ModelOutput):
    r"""
    Output of the talker code-predictor forward pass.

    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    generation_steps (`int`, *optional*):
        Index of the codebook group decoded at this step; threaded back into
        `model_kwargs` by `_update_model_kwargs_for_generation` so the next
        step selects the matching embedding table and `lm_head`.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[list[torch.FloatTensor]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    generation_steps: Optional[int] = None
804
+
805
+
806
class Qwen3TTSTalkerTextMLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block: down(act(gate(x)) * up(x)).

    ``intermediate_size`` may be overridden per instance; otherwise it is taken
    from the config.
    """

    def __init__(self, config, intermediate_size=None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        if intermediate_size is None:
            self.intermediate_size = config.intermediate_size
        else:
            self.intermediate_size = intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated MLP to ``x`` and return a tensor of the same hidden size."""
        gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
820
+
821
+
822
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*): Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so they
            broadcast against `q`/`k`. Use 1 for `(batch, heads, seq, dim)`
            layouts and 2 for `(batch, seq, heads, dim)` layouts.

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    rotated_q = q * cos + rotate_half(q) * sin
    rotated_k = k * cos + rotate_half(k) * sin
    return rotated_q, rotated_k
847
+
848
+
849
class Qwen3TTSAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Standard (1-D RoPE) variant; sliding-window attention is enabled per layer
    through ``config.layer_types``.
    """

    def __init__(self, config: Qwen3TTSConfig, layer_idx: int):
        super().__init__()
        self.config = config
        # layer_idx lets the shared KV cache route updates to this layer's slot.
        self.layer_idx = layer_idx
        # Fall back to hidden_size // num_heads when config has no explicit head_dim.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # GQA: number of query heads sharing one key/value head.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5  # 1/sqrt(head_dim) softmax temperature
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        self.q_norm = Qwen3TTSRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
        self.k_norm = Qwen3TTSRMSNorm(
            self.head_dim, eps=config.rms_norm_eps
        )  # thus post q_norm does not need reshape
        # Window only applies to layers tagged "sliding_attention" in layer_types.
        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """Compute self-attention for one layer.

        Args:
            hidden_states: ``(batch, seq_len, hidden_size)`` input.
            position_embeddings: Precomputed ``(cos, sin)`` RoPE tensors.
            attention_mask: Causal (and optionally sliding-window) mask, or ``None``.
            past_key_values: Optional KV cache, updated in place for this layer.
            cache_position: Absolute positions of the current tokens in the cache.

        Returns:
            ``(attn_output, attn_weights)``; weights may be ``None`` depending on
            the attention backend.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project, split into heads, RMS-normalize q/k per head, then (b, h, s, d).
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Dispatch to the configured attention backend (eager / sdpa / flash-attn / ...).
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        # Merge heads back and project to hidden_size.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
923
+
924
+
925
class Qwen3TTSDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer decoder layer used by the code-predictor model:
    RMSNorm -> self-attention -> residual, then RMSNorm -> gated MLP -> residual.
    """

    def __init__(self, config: Qwen3TTSConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Qwen3TTSAttention(config=config, layer_idx=layer_idx)

        self.mlp = Qwen3TTSTalkerTextMLP(config)
        self.input_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # "full_attention" or "sliding_attention" — the parent model uses this key
        # to pick the matching causal mask for the layer.
        self.attention_type = config.layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run one decoder layer.

        Returns ``(hidden_states,)``, extended with the attention weights when
        ``output_attentions`` is set.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
977
+
978
+
979
class Qwen3TTSTalkerCodePredictorModel(Qwen3TTSPreTrainedModel):
    """Decoder-only transformer that predicts the residual codebook groups.

    It carries one codec embedding table per codebook group (minus the first,
    which is produced by the talker itself) and a stack of
    ``Qwen3TTSDecoderLayer`` blocks.
    """

    config_class = Qwen3TTSTalkerCodePredictorConfig
    base_model_prefix = "talker.code_predictor.model"

    def __init__(self, config: Qwen3TTSTalkerCodePredictorConfig, embedding_dim: int):
        # embedding_dim is the talker hidden size — codec embeddings are produced
        # in talker space and projected by the wrapper (small_to_mtp_projection).
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.layers = nn.ModuleList(
            [Qwen3TTSDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3TTSRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types
        # One embedding table per residual codebook group (groups 1..N-1).
        self.codec_embedding = nn.ModuleList(
            [nn.Embedding(config.vocab_size, embedding_dim) for _ in range(config.num_code_groups - 1)]
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Returns the per-group codec embedding list, not a single table.
        return self.codec_embedding

    def set_input_embeddings(self, value):
        # NOTE(review): writes `embed_tokens` while the getter returns
        # `codec_embedding` — asymmetric; confirm this is intentional.
        self.embed_tokens = value

    @can_return_tuple
    def forward(
        self,
        input_ids=None,  # must be None; embeddings are always passed pre-computed
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,  # required: (batch, seq, hidden) embeddings
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        cache_position=None,
        generation_steps=None,  # accepted for interface compatibility; unused here
        **flash_attn_kwargs,
    ) -> BaseModelOutputWithPast:
        """Run the code-predictor decoder stack over ``inputs_embeds``.

        Raises:
            ValueError: if ``input_ids`` is provided, if ``inputs_embeds`` is
                missing, or if ``past_key_values`` is neither ``None`` nor a
                ``Cache`` instance.
        """
        if input_ids is not None:
            raise ValueError("`input_ids` is expected to be `None`")
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # With input_ids forced to None above, this effectively requires inputs_embeds.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
        if not isinstance(past_key_values, (type(None), Cache)):
            raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

        # NOTE(review): unreachable given the checks above (inputs_embeds is always
        # set here), and `self.embed_tokens` only exists after set_input_embeddings().
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding window alternating layers are not always activated depending on the config
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # Each layer picks the mask matching its attention_type (full / sliding).
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
1118
+
1119
+
1120
class Qwen3TTSTalkerCodePredictorModelForConditionalGeneration(Qwen3TTSPreTrainedModel, GenerationMixin):
    """Generation wrapper around the code-predictor decoder.

    Holds one ``lm_head`` per residual codebook group; during generation the
    current group index (``generation_steps``) selects which embedding table
    consumes the previous token and which head produces the next logits.
    """

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config_class = Qwen3TTSTalkerCodePredictorConfig
    base_model_prefix = "talker.code_predictor"

    def __init__(self, config: Qwen3TTSTalkerCodePredictorConfig, talker_config: Qwen3TTSTalkerConfig):
        super().__init__(config)
        # Codec embeddings live in talker space (talker hidden size).
        self.model = Qwen3TTSTalkerCodePredictorModel(config, talker_config.hidden_size)
        self.vocab_size = config.vocab_size
        # One output head per residual codebook group (groups 1..N-1).
        self.lm_head = nn.ModuleList(
            [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_code_groups - 1)]
        )

        # Bridge talker hidden size -> predictor hidden size (identity when equal).
        if config.hidden_size != talker_config.hidden_size:
            self.small_to_mtp_projection = torch.nn.Linear(talker_config.hidden_size, config.hidden_size, bias=True)
        else:
            self.small_to_mtp_projection = torch.nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward_finetune(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        cache_position=None,
        generation_steps=None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Teacher-forced pass: scores all residual groups in one call.

        Position ``i`` of the decoder output is fed to ``lm_head[i-1]``, i.e.
        each sequence position predicts its codebook group; logits are stacked
        to ``(batch, num_code_groups - 1, vocab)``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state

        logits = []
        for i in range(1, self.config.num_code_groups):
            # Group i is predicted from sequence position i by head i-1.
            logits.append(self.lm_head[i-1](hidden_states[:, i]))
        logits = torch.stack(logits, dim=1)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return Qwen3TTSTalkerCodePredictorOutputWithPast(
            loss=loss,
            logits=logits
        )

    @can_return_tuple
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        cache_position=None,
        generation_steps=None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # Prefill stage
        if inputs_embeds is not None and inputs_embeds.shape[1] > 1:
            generation_steps = inputs_embeds.shape[1] - 2  # hidden & layer 0
        # Generation stage
        else:
            # Embed the previously generated token with the table of the group
            # just decoded (generation_steps - 1).
            inputs_embeds = self.model.get_input_embeddings()[generation_steps - 1](input_ids)
        # Project talker-space embeddings into the predictor's hidden size.
        # NOTE(review): source formatting is ambiguous here — confirm whether the
        # projection is meant to apply in both stages or only after re-embedding.
        inputs_embeds = self.small_to_mtp_projection(inputs_embeds)

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Head for the group being decoded at this step.
        logits = self.lm_head[generation_steps](hidden_states)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return Qwen3TTSTalkerCodePredictorOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            # Advance the group index; fed back via _update_model_kwargs_for_generation.
            generation_steps=generation_steps + 1,
        )

    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False, num_new_tokens=1):
        # In addition to the standard cache/attention bookkeeping, thread the
        # codebook-group counter through successive generate() steps.
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder, num_new_tokens
        )
        model_kwargs["generation_steps"] = outputs.generation_steps
        return model_kwargs
1282
+
1283
+
1284
@dataclass
class Qwen3TTSTalkerOutputWithPast(ModelOutput):
    r"""
    Output of the talker forward pass.

    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
        `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[list[torch.FloatTensor]] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    # Hidden state carried over between talker decoding steps — presumably the
    # last-step hidden fed to the code predictor; confirm against the caller.
    past_hidden: Optional[torch.FloatTensor] = None
    # Step counter threaded through generation (cf. generation_steps in the
    # code-predictor output).
    generation_step: Optional[int] = None
    # Remaining text hidden states not yet consumed by the talker — assumption,
    # verify against the talker generate loop.
    trailing_text_hidden: Optional[torch.FloatTensor] = None
    # Embedding used to pad the text stream once exhausted — assumption, verify.
    tts_pad_embed: Optional[torch.FloatTensor] = None
1308
+
1309
+
1310
class Qwen3TTSTalkerDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer decoder layer of the talker model.

    Structure: RMSNorm -> multimodal-RoPE self-attention -> residual add, then
    RMSNorm -> gated MLP -> residual add.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen3TTSTalkerAttention(config, layer_idx)

        self.mlp = Qwen3TTSTalkerTextMLP(config, intermediate_size=config.intermediate_size)

        self.input_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run one talker decoder layer.

        Args:
            hidden_states: Layer input of shape ``(batch, seq_len, embed_dim)``.
            attention_mask: Mask with padding positions zeroed, or ``None``.
            position_embeddings: Precomputed ``(cos, sin)`` RoPE tensors shared
                across layers.
            cache_position: Positions of the current tokens in the sequence.
            output_attentions: When true, also return the attention weights.
            use_cache: When true, the KV cache is updated for reuse in decoding.

        Returns:
            ``(hidden_states,)`` — extended with the self-attention weights when
            ``output_attentions`` is set.
        """
        # Attention sub-block: pre-norm, attend, residual add.
        attn_input = self.input_layernorm(hidden_states)
        attn_output, self_attn_weights = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block: pre-norm, gated MLP, residual add.
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_output

        if output_attentions:
            return (hidden_states, self_attn_weights)
        return (hidden_states,)
1387
+
1388
+
1389
+ class Qwen3TTSTalkerModel(Qwen3TTSTalkerTextPreTrainedModel):
1390
+ config_class = Qwen3TTSTalkerConfig
1391
+ base_model_prefix = "talker.model"
1392
+
1393
    def __init__(self, config):
        """Build the talker decoder: embedding tables, decoder layers, final norm.

        Keeps separate embeddings for codec tokens (``codec_embedding``, sized by
        ``vocab_size``/``hidden_size``) and text tokens (``text_embedding``, sized
        by ``text_vocab_size``/``text_hidden_size``).
        """
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.layers = nn.ModuleList(
            [Qwen3TTSTalkerDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3TTSRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Talker uses the multimodal rotary embedding (3 position axes).
        self.rotary_emb = Qwen3TTSTalkerRotaryEmbedding(config)
        self.gradient_checkpointing = False
        self.codec_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.text_embedding = nn.Embedding(config.text_vocab_size, config.text_hidden_size)

        # Initialize weights and apply final processing
        self.post_init()
1408
+
1409
    def get_input_embeddings(self):
        # The talker's "input" embeddings are the codec-token table.
        return self.codec_embedding
1411
+
1412
    def get_text_embeddings(self):
        # Separate table for text-conditioning tokens (text_vocab_size x text_hidden_size).
        return self.text_embedding
1414
+
1415
    def set_input_embeddings(self, value):
        # NOTE(review): writes `embed_tokens` while get_input_embeddings() returns
        # `codec_embedding` — asymmetric; confirm this is intentional.
        self.embed_tokens = value
1417
+
1418
@can_return_tuple
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
    """Run the talker decoder stack over one segment of the sequence.

    Standard transformers decoder forward: resolves output flags from config,
    embeds ``input_ids`` (unless ``inputs_embeds`` is given), builds a causal
    (or sliding-window) mask, applies each decoder layer with shared rotary
    position embeddings, and returns the final-norm hidden states plus the
    (possibly updated) KV cache.

    Exactly one of ``input_ids`` / ``inputs_embeds`` must be provided.
    ``position_ids`` uses the 3-row mRoPE layout ``(3, batch, seq)``; a 4-row
    input is interpreted as text positions stacked on top of the 3 mRoPE rows.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    # XOR: raises when both or neither of input_ids / inputs_embeds are given.
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    if use_cache and past_key_values is None:
        past_key_values = DynamicCache()

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    # cache_position indexes the new tokens within the full (cached + new) sequence.
    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    # the hard coded `3` is for temporal, height and width.
    if position_ids is None:
        position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
    elif position_ids.ndim == 2:
        position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

    # A 4-row position_ids carries text positions in row 0 and mRoPE rows after;
    # otherwise row 0 of the 3-row tensor doubles as the text positions.
    if position_ids.ndim == 3 and position_ids.shape[0] == 4:
        text_position_ids = position_ids[0]
        position_ids = position_ids[1:]
    else:
        text_position_ids = position_ids[0]

    # Pick plain causal vs. sliding-window masking from the config.
    mask_function = create_causal_mask if self.config.sliding_window is None else create_sliding_window_causal_mask
    causal_mask = mask_function(
        config=self.config,
        input_embeds=inputs_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        past_key_values=past_key_values,
        position_ids=text_position_ids,
    )

    hidden_states = inputs_embeds

    # create position embeddings to be shared across the decoder layers
    position_embeddings = self.rotary_emb(hidden_states, position_ids)

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=causal_mask,
            position_ids=text_position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **flash_attn_kwargs,
        )

        hidden_states = layer_outputs[0]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=past_key_values,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
1526
class Qwen3TTSTalkerForConditionalGeneration(Qwen3TTSTalkerTextPreTrainedModel, GenerationMixin):
    """Autoregressive "talker" that emits multi-codebook speech codec tokens.

    Each generation step produces the first-codebook token with ``codec_head``,
    then delegates the remaining ``num_code_groups - 1`` codebooks to the
    nested ``code_predictor`` (sub-talker). Text conditioning is injected by
    adding projected text embeddings onto the codec embeddings at each step.
    """

    # NOTE(review): these plans reference `lm_head`, but __init__ only defines
    # `codec_head`; `get_output_embeddings` would raise AttributeError unless
    # `lm_head` is installed elsewhere (e.g. by weight-tying machinery) — verify.
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config_class = Qwen3TTSTalkerConfig
    base_model_prefix = "talker"

    def __init__(self, config: Qwen3TTSTalkerConfig):
        super().__init__(config)
        self.model = Qwen3TTSTalkerModel(config)
        self.vocab_size = config.vocab_size
        # Projects text-encoder hidden states into the talker's hidden size.
        self.text_projection = Qwen3TTSTalkerResizeMLP(
            config.text_hidden_size, config.text_hidden_size, config.hidden_size, config.hidden_act, bias=True
        )

        # Head for the first codebook; the remaining codebooks come from code_predictor.
        self.codec_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.code_predictor = Qwen3TTSTalkerCodePredictorModelForConditionalGeneration(
            config=config.code_predictor_config,
            talker_config=config
        )
        # Cached mRoPE offset, reset on each prefill (see forward()).
        self.rope_deltas = None

        # Initialize weights and apply final processing
        self.post_init()

    # TODO: hack, modular cannot inherit multiple classes

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def get_text_embeddings(self):
        return self.model.get_text_embeddings()

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        # NOTE(review): `self.lm_head` is never assigned in __init__ — see class note.
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward_sub_talker_finetune(self, codec_ids, talker_hidden_states):
        """Teacher-forced fine-tuning pass of the code predictor (sub-talker).

        Args:
            codec_ids: ``(batch, num_code_groups)`` ground-truth codebook ids
                for one frame.
            talker_hidden_states: ``(batch, hidden_size)`` talker output used
                as the first input position of the predictor.

        Returns:
            Tuple of the predictor's ``(logits, loss)``; groups 1..N-1 serve
            as the labels.
        """
        assert len(codec_ids.shape) == 2
        assert len(talker_hidden_states.shape) == 2
        assert codec_ids.shape[0] == talker_hidden_states.shape[0]
        assert talker_hidden_states.shape[1] == self.config.hidden_size
        assert codec_ids.shape[1] == self.config.num_code_groups

        sub_talker_inputs_embeds = [talker_hidden_states.unsqueeze(1)]

        # Group 0 uses the talker's embedding table; groups i>0 use the
        # predictor's per-group tables (index i-1).
        for i in range(self.config.num_code_groups - 1):
            if i == 0:
                sub_talker_inputs_embeds.append(self.get_input_embeddings()(codec_ids[:, :1]))
            else:
                sub_talker_inputs_embeds.append(self.code_predictor.get_input_embeddings()[i-1](codec_ids[:, i:i+1]))
        sub_talker_inputs_embeds = torch.cat(sub_talker_inputs_embeds, dim=1)

        sub_talker_outputs = self.code_predictor.forward_finetune(inputs_embeds=sub_talker_inputs_embeds,
                                                                  labels=codec_ids[:, 1:])

        sub_talker_logits = sub_talker_outputs.logits
        sub_talker_loss = sub_talker_outputs.loss
        return sub_talker_logits, sub_talker_loss

    @can_return_tuple
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        cache_position=None,
        past_hidden=None,
        trailing_text_hidden=None,
        tts_pad_embed=None,
        generation_step=None,
        subtalker_dosample=None,
        subtalker_top_p=None,
        subtalker_top_k=None,
        subtalker_temperature=None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        r"""
        One decoding step (or the prefill) of the talker.

        Prefill is detected by multi-token ``inputs_embeds``; decode steps take
        the previously sampled first-codebook id in ``input_ids``, run the
        code predictor to fill the remaining codebooks, and feed the summed
        codec embeddings (plus the next text embedding, or ``tts_pad_embed``
        once the text is exhausted) back into the decoder.

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        # Prefill
        if inputs_embeds is not None and inputs_embeds.shape[1] > 1:
            generation_step = -1  # becomes 0 in the returned output below
            codec_ids = None
        # Generate
        else:
            last_id_hidden = self.get_input_embeddings()(input_ids)
            # Sub-talker completes codebooks 1..N-1 conditioned on the talker's
            # last hidden state and the sampled first-codebook embedding.
            predictor_result = self.code_predictor.generate(
                inputs_embeds=torch.cat((past_hidden, last_id_hidden), dim=1),
                max_new_tokens=self.config.num_code_groups - 1,
                do_sample=subtalker_dosample,
                top_p=subtalker_top_p,
                top_k=subtalker_top_k,
                temperature=subtalker_temperature,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            codec_ids = torch.cat((input_ids, predictor_result.sequences), dim=-1)
            codec_hiddens = torch.cat(
                [last_id_hidden]
                + [self.code_predictor.get_input_embeddings()[i](predictor_result.sequences[..., i:i+1]) for i in range(self.config.num_code_groups - 1)],
                dim=1,
            )
            # Sum the per-codebook embeddings into one frame embedding.
            inputs_embeds = codec_hiddens.sum(1, keepdim=True)

            # Add the next text-token embedding while text remains, else the pad.
            if generation_step < trailing_text_hidden.shape[1]:
                inputs_embeds = inputs_embeds + trailing_text_hidden[:, generation_step].unsqueeze(1)
            else:
                inputs_embeds = inputs_embeds + tts_pad_embed
        if attention_mask is not None:
            # Recompute mRoPE positions on prefill (or missing cache/deltas);
            # on decode steps just offset an arange by the cached deltas.
            if (
                cache_position is None
                or (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
            ):
                delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)  # left-pad counts
                position_ids, rope_deltas = self.get_rope_index(
                    attention_mask,
                )
                rope_deltas = rope_deltas - delta0
                self.rope_deltas = rope_deltas
            else:
                batch_size, seq_length = input_ids.shape
                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
                position_ids = torch.arange(seq_length, device=input_ids.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs: BaseModelOutputWithPast = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        logits = self.codec_head(hidden_states)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)


        # hidden_states is overloaded: (decoder hidden states, this step's codec ids);
        # the outer generate() relies on this packing to recover the codes.
        return Qwen3TTSTalkerOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=(outputs.hidden_states, codec_ids),
            attentions=outputs.attentions,
            past_hidden=hidden_states[:, -1:, :],
            generation_step=generation_step + 1,
            trailing_text_hidden=trailing_text_hidden,
            tts_pad_embed=tts_pad_embed,
        )

    def get_rope_index(
        self,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute text-only mRoPE position ids from a (left-padded) attention mask.

        Positions count real tokens via a cumulative sum of the mask (padded
        slots are masked to 1), replicated across the 3 mRoPE axes (temporal,
        height, width) — for pure text all three rows are identical.

        Args:
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

        Returns:
            position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
            mrope_position_deltas (`torch.Tensor` of shape `(batch_size, 1)`):
                offset between the max position and the number of real tokens,
                used to continue positions during decoding.
        """
        mrope_position_deltas = []

        position_ids = attention_mask.float().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
        max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
        mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True)

        return position_ids, mrope_position_deltas

    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder=False, num_new_tokens=1):
        """Carry the talker's custom state (past hidden, step counter, text
        conditioning) from one generation step to the next."""
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder, num_new_tokens
        )
        model_kwargs["past_hidden"] = outputs.past_hidden
        model_kwargs["generation_step"] = outputs.generation_step
        model_kwargs["trailing_text_hidden"] = outputs.trailing_text_hidden
        model_kwargs["tts_pad_embed"] = outputs.tts_pad_embed
        return model_kwargs
1775
class Qwen3TTSForConditionalGeneration(Qwen3TTSPreTrainedModel, GenerationMixin):
    """Top-level Qwen3-TTS model: wraps the talker, optional speaker encoder,
    and speech tokenizer, and exposes a list-in / list-out ``generate``.

    Supports three conditioning modes per sample: built-in speakers (by name),
    voice cloning via x-vector speaker embeddings, and in-context-learning
    (ICL) cloning via reference text + reference codec codes.
    """

    config_class = Qwen3TTSConfig

    def __init__(self, config: Qwen3TTSConfig):
        super().__init__(config)
        self.config = config

        self.talker = Qwen3TTSTalkerForConditionalGeneration(self.config.talker_config)

        # Only the "base" (voice-clone) variant ships a speaker encoder.
        if config.tts_model_type == "base":
            self.speaker_encoder = Qwen3TTSSpeakerEncoder(self.config.speaker_encoder_config)
        else:
            self.speaker_encoder = None

        # Filled in by from_pretrained() via load_speech_tokenizer / load_generate_config.
        self.speech_tokenizer = None
        self.generate_config = None

        self.supported_speakers = self.config.talker_config.spk_id.keys()
        # Dialects are selected implicitly via the speaker, so they are not
        # exposed as user-selectable languages.
        self.supported_languages = ["auto"]
        for language_id in self.config.talker_config.codec_language_id.keys():
            if "dialect" not in language_id:
                self.supported_languages.append(language_id)

        self.speaker_encoder_sample_rate = self.config.speaker_encoder_config.sample_rate
        self.tokenizer_type = self.config.tokenizer_type
        self.tts_model_size = self.config.tts_model_size
        self.tts_model_type = self.config.tts_model_type

        self.post_init()

    def load_speech_tokenizer(self, speech_tokenizer):
        """Attach the codec (speech) tokenizer used to decode generated codes."""
        self.speech_tokenizer = speech_tokenizer

    def load_generate_config(self, generate_config):
        """Attach the raw generation_config.json contents as a dict."""
        self.generate_config = generate_config

    def get_supported_speakers(self):
        return self.supported_speakers

    def get_supported_languages(self):
        return self.supported_languages

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *model_args,
        config=None,
        cache_dir=None,
        ignore_mismatched_sizes=False,
        force_download=False,
        local_files_only=False,
        token=None,
        revision="main",
        use_safetensors=None,
        weights_only=True,
        **kwargs,
    ):
        """Load the model, then additionally load the bundled speech tokenizer
        (``speech_tokenizer/``) and raw ``generation_config.json``.

        Raises:
            ValueError: if the speech tokenizer config cannot be resolved.
        """
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            **kwargs,
        )

        # Pop kwargs-only download options ONCE and reuse them for both
        # cached_file calls. (Previously each call popped its own copy, so the
        # second call silently got defaults; and named parameters such as
        # cache_dir/revision were popped from kwargs — where they can never
        # live — so user-supplied values were ignored.)
        subfolder = kwargs.pop("subfolder", None)
        proxies = kwargs.pop("proxies", None)
        resume_download = kwargs.pop("resume_download", None)
        use_auth_token = kwargs.pop("use_auth_token", None)  # deprecated alias of `token`
        resolved_token = token if token is not None else use_auth_token

        speech_tokenizer_path = cached_file(
            pretrained_model_name_or_path,
            "speech_tokenizer/config.json",
            subfolder=subfolder,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            token=resolved_token,
            revision=revision,
        )
        if speech_tokenizer_path is None:
            raise ValueError(f"""{pretrained_model_name_or_path}/speech_tokenizer/config.json not exists""")
        speech_tokenizer_dir = os.path.dirname(speech_tokenizer_path)
        speech_tokenizer = Qwen3TTSTokenizer.from_pretrained(
            speech_tokenizer_dir,
            *model_args,
            **kwargs,
        )
        model.load_speech_tokenizer(speech_tokenizer)

        generate_config_path = cached_file(
            pretrained_model_name_or_path,
            "generation_config.json",
            subfolder=subfolder,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            token=resolved_token,
            revision=revision,
        )
        with open(generate_config_path, "r", encoding="utf-8") as f:
            generate_config = json.load(f)
        model.load_generate_config(generate_config)

        return model

    @torch.inference_mode()
    def extract_speaker_embedding(self, audio, sr):
        """Compute an x-vector speaker embedding from a 24 kHz mono waveform.

        Args:
            audio: 1-D numpy float waveform at 24 kHz.
            sr: sample rate; must be 24000.
        """
        assert sr == 24000, "Only support 24kHz audio"
        mels = mel_spectrogram(
            torch.from_numpy(audio).unsqueeze(0),
            n_fft=1024,
            num_mels=128,
            sampling_rate=24000,
            hop_size=256,
            win_size=1024,
            fmin=0,
            fmax=12000
        ).transpose(1, 2)
        speaker_embedding = self.speaker_encoder(mels.to(self.device).to(self.dtype))[0]
        return speaker_embedding

    @torch.inference_mode()
    def generate_speaker_prompt(
        self,
        voice_clone_prompt: list[dict]
    ):
        """Move the precomputed reference speaker embeddings onto the talker's
        device/dtype and return them as a list (one per batch item).

        NOTE(review): despite the annotation, ``voice_clone_prompt`` is used as
        a dict of parallel lists (keys: ``ref_spk_embedding``, ``ref_code``,
        ``x_vector_only_mode``, ``icl_mode``) — confirm against callers.
        """
        voice_clone_spk_embeds = []
        for index in range(len(voice_clone_prompt['ref_spk_embedding'])):
            ref_spk_embedding = voice_clone_prompt["ref_spk_embedding"][index].to(self.talker.device).to(self.talker.dtype)
            voice_clone_spk_embeds.append(ref_spk_embedding)

        return voice_clone_spk_embeds

    def generate_icl_prompt(
        self,
        text_id: torch.Tensor,
        ref_id: torch.Tensor,
        ref_code: torch.Tensor,
        tts_pad_embed: torch.Tensor,
        tts_eos_embed: torch.Tensor,
        non_streaming_mode: bool,
    ):
        """Build the in-context-learning prefix: reference+target text embeddings
        fused (added) with the reference codec embeddings.

        Returns:
            ``(prefix_embeds, trailing_text_hidden)`` where the trailing hidden
            is the leftover text conditioning to feed during decoding.
        """
        # text embed (ref id + text id + eos) -> (1, T1, D)
        text_embed = self.talker.text_projection(
            self.talker.get_text_embeddings()(torch.cat([ref_id, text_id],
                                                        dim=-1)))
        text_embed = torch.cat([text_embed, tts_eos_embed], dim=1)
        # codec embed (codec bos + summed per-group codec embeds) -> (1, T2, D)
        codec_embed = []
        for i in range(self.talker.config.num_code_groups):
            if i == 0:
                codec_embed.append(self.talker.get_input_embeddings()(ref_code[:, :1]))
            else:
                codec_embed.append(self.talker.code_predictor.get_input_embeddings()[i-1](ref_code[:, i:i+1]))
        codec_embed = torch.cat(codec_embed, dim=1).sum(1).unsqueeze(0)
        codec_embed = torch.cat([self.talker.get_input_embeddings()(
            torch.tensor(
                [[
                    self.config.talker_config.codec_bos_id,
                ]],
                device=self.talker.device,
                dtype=text_id.dtype,
            )
        ), codec_embed], dim=1)
        # compute lens
        text_lens = text_embed.shape[1]
        codec_lens = codec_embed.shape[1]
        if non_streaming_mode:
            # All text first (over codec-pad), then all reference codec (over text-pad).
            icl_input_embed = text_embed + self.talker.get_input_embeddings()(
                torch.tensor(
                    [[
                        self.config.talker_config.codec_pad_id,
                    ] * text_lens],
                    device=self.talker.device,
                    dtype=text_id.dtype,
                )
            )
            icl_input_embed = torch.cat([icl_input_embed, codec_embed + tts_pad_embed], dim=1)
            return icl_input_embed, tts_pad_embed
        else:
            # Streaming: overlap text and codec position-wise; any text overflow
            # becomes the trailing conditioning, otherwise pad text to codec length.
            if text_lens > codec_lens:
                return text_embed[:, :codec_lens] + codec_embed, text_embed[:, codec_lens:]
            else:
                text_embed = torch.cat([text_embed] + [tts_pad_embed] * (codec_lens - text_lens), dim=1)
                return text_embed + codec_embed, tts_pad_embed

    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[list[torch.Tensor]] = None,
        instruct_ids: Optional[list[torch.Tensor]] = None,
        ref_ids: Optional[list[torch.Tensor]] = None,
        voice_clone_prompt: list[dict] = None,
        languages: list[str] = None,
        speakers: list[str] = None,
        non_streaming_mode = False,
        max_new_tokens: int = 4096,
        do_sample: bool = True,
        top_k: int = 50,
        top_p: float = 1.0,
        temperature: float = 0.9,
        subtalker_dosample: bool = True,
        subtalker_top_k: int = 50,
        subtalker_top_p: float = 1.0,
        subtalker_temperature: float = 0.9,
        eos_token_id: Optional[int] = None,
        repetition_penalty: float = 1.05,
        **kwargs,
    ):
        """Generate codec codes for a batch of texts.

        Args:
            input_ids: per-sample text token tensors of shape ``(1, T)``; the
                layout ``[:, :3]`` = chat role prefix, ``[:, 3:-5]`` = text,
                last 5 = template suffix is assumed by the slicing below.
            instruct_ids: optional per-sample instruction token tensors.
            ref_ids: per-sample reference-text tokens (ICL cloning only).
            voice_clone_prompt: dict of parallel lists with keys
                ``ref_spk_embedding``, ``ref_code``, ``x_vector_only_mode``,
                ``icl_mode`` (None when using built-in speakers).
            languages / speakers: per-sample language and speaker names.
            non_streaming_mode: feed all text up front instead of interleaving.

        Returns:
            ``(codes, hiddens)`` — per-sample codec-code tensors truncated at
            the first EOS frame, and the matching talker hidden states.
        """
        talker_kwargs = {
            "max_new_tokens": max_new_tokens,
            "min_new_tokens": 2,
            "do_sample": do_sample,
            "top_k": top_k,
            "top_p": top_p,
            "temperature": temperature,
            "subtalker_dosample": subtalker_dosample,
            "subtalker_top_k": subtalker_top_k,
            "subtalker_top_p": subtalker_top_p,
            "subtalker_temperature": subtalker_temperature,
            "eos_token_id": eos_token_id
            if eos_token_id is not None
            else self.config.talker_config.codec_eos_token_id,
            "repetition_penalty": repetition_penalty,
            # Block the reserved last-1024 control ids except codec EOS.
            "suppress_tokens": [
                i
                for i in range(self.config.talker_config.vocab_size - 1024, self.config.talker_config.vocab_size)
                if i not in (self.config.talker_config.codec_eos_token_id,)
            ],
            # Must be True: the codes/hiddens are recovered from
            # talker_result.hidden_states below. (The previous
            # `getattr(kwargs, ...)` on a dict always returned the default
            # anyway, so this is behavior-identical but no longer misleading.)
            "output_hidden_states": True,
            "return_dict_in_generate": True,
        }

        talker_input_embeds = [[] for _ in range(len(input_ids))]

        # Voice-clone speaker embeddings (x-vector and/or ICL modes).
        voice_clone_spk_embeds = None
        if voice_clone_prompt is not None:
            voice_clone_spk_embeds = self.generate_speaker_prompt(voice_clone_prompt)

        # Optional instruction prefix per sample.
        if instruct_ids is not None:
            for index, instruct_id in enumerate(instruct_ids):
                if instruct_id is not None:
                    talker_input_embeds[index].append(self.talker.text_projection(
                        self.talker.get_text_embeddings()(instruct_id)))

        # Build the TTS prompt for each sample.
        trailing_text_hiddens = []
        if speakers is None:
            speakers = [None] * len(input_ids)
        for index, (input_id, language, speaker) in enumerate(zip(input_ids, languages, speakers)):
            if voice_clone_spk_embeds is None:
                if speaker == "" or speaker is None:  # instruct-created speaker
                    speaker_embed = None
                else:
                    if speaker.lower() not in self.config.talker_config.spk_id:
                        raise NotImplementedError(f"Speaker {speaker} not implemented")
                    else:
                        spk_id = self.config.talker_config.spk_id[speaker.lower()]
                        speaker_embed = self.talker.get_input_embeddings()(
                            torch.tensor(
                                spk_id,
                                device=self.talker.device,
                                dtype=input_id.dtype,
                            )
                        )
            else:
                if voice_clone_prompt["x_vector_only_mode"][index] or voice_clone_prompt["icl_mode"][index]:
                    speaker_embed = voice_clone_spk_embeds[index]
                else:
                    speaker_embed = None

            assert language is not None

            if language.lower() == "auto":
                language_id = None
            else:
                if language.lower() not in self.config.talker_config.codec_language_id:
                    raise NotImplementedError(f"Language {language} not implemented")
                else:
                    language_id = self.config.talker_config.codec_language_id[language.lower()]

            # Built-in dialect speakers override chinese/auto with their dialect id.
            if (language.lower() in ["chinese", "auto"] and \
                speaker != "" and speaker is not None and \
                self.config.talker_config.spk_is_dialect[speaker.lower()] != False):
                dialect = self.config.talker_config.spk_is_dialect[speaker.lower()]
                language_id = self.config.talker_config.codec_language_id[dialect]

            tts_bos_embed, tts_eos_embed, tts_pad_embed = self.talker.text_projection(
                self.talker.get_text_embeddings()(
                    torch.tensor(
                        [[self.config.tts_bos_token_id, self.config.tts_eos_token_id, self.config.tts_pad_token_id]],
                        device=self.talker.device,
                        dtype=input_id.dtype,
                    )
                )
            ).chunk(3, dim=1)  # 3 * [1 1 d]

            # Codec-side prefix: think/language tags (+ optional speaker embed).
            if language_id is None:
                codec_prefill_list = [[
                    self.config.talker_config.codec_nothink_id,
                    self.config.talker_config.codec_think_bos_id,
                    self.config.talker_config.codec_think_eos_id,
                ]]
            else:
                codec_prefill_list = [[
                    self.config.talker_config.codec_think_id,
                    self.config.talker_config.codec_think_bos_id,
                    language_id,
                    self.config.talker_config.codec_think_eos_id,
                ]]

            codec_input_embedding_0 = self.talker.get_input_embeddings()(
                torch.tensor(
                    codec_prefill_list,
                    device=self.talker.device,
                    dtype=input_id.dtype,
                )
            )
            codec_input_embedding_1 = self.talker.get_input_embeddings()(
                torch.tensor(
                    [[
                        self.config.talker_config.codec_pad_id,
                        self.config.talker_config.codec_bos_id,
                    ]],
                    device=self.talker.device,
                    dtype=input_id.dtype,
                )
            )
            if speaker_embed is None:
                codec_input_embedding = torch.cat([codec_input_embedding_0,
                                                   codec_input_embedding_1], dim=1)
            else:
                codec_input_embedding = torch.cat([codec_input_embedding_0,
                                                   speaker_embed.view(1, 1, -1),
                                                   codec_input_embedding_1], dim=1)

            # Example input layout (chat template):
            # '<|im_start|>assistant\n我叫通义千问,是阿里云的开源大模型。<|im_end|>\n<|im_start|>assistant\n'

            # role prefix: <|im_start|>assistant\n  (first 3 tokens)
            _talker_input_embed_role = self.talker.text_projection(
                self.talker.get_text_embeddings()(input_id[:, :3])
            )

            # Text side so far: tts_pad repeated + tts_bos, added onto the codec prefix.
            _talker_input_embed = torch.cat((tts_pad_embed.expand(-1, codec_input_embedding.shape[1] - 2, -1),
                                             tts_bos_embed,
                                             ), dim=1) + codec_input_embedding[:, :-1]

            talker_input_embed = torch.cat((_talker_input_embed_role, _talker_input_embed), dim=1)

            if voice_clone_prompt is not None and voice_clone_prompt["ref_code"] is not None and voice_clone_prompt["icl_mode"][index]:
                icl_input_embed, trailing_text_hidden = self.generate_icl_prompt(
                    text_id=input_id[:, 3:-5],
                    ref_id=ref_ids[index][:, 3:-2],
                    ref_code=voice_clone_prompt["ref_code"][index].to(self.talker.device),
                    tts_pad_embed=tts_pad_embed,
                    tts_eos_embed=tts_eos_embed,
                    non_streaming_mode=non_streaming_mode,
                )
                talker_input_embed = torch.cat([talker_input_embed, icl_input_embed], dim=1)
            else:
                # First text token overlapped with codec BOS.
                talker_input_embed = torch.cat([talker_input_embed,
                                                self.talker.text_projection(self.talker.get_text_embeddings()(input_id[:, 3:4])) + codec_input_embedding[:, -1:]],
                                               dim=1)
                if non_streaming_mode:
                    talker_input_embed = talker_input_embed[:, :-1]  # drop the text token appended just above
                    # Feed the whole text (+EOS) over codec-pad, then codec BOS over tts-pad.
                    talker_input_embed = torch.cat([talker_input_embed,
                                                    torch.cat((self.talker.text_projection(
                                                        self.talker.get_text_embeddings()(input_id[:, 3:-5])
                                                    ), tts_eos_embed), dim=1) + self.talker.get_input_embeddings()(
                                                        torch.tensor(
                                                            [[
                                                                self.config.talker_config.codec_pad_id,
                                                            ] * (input_id[:, 3:-5].shape[1] + 1)],
                                                            device=self.talker.device,
                                                            dtype=input_id.dtype,
                                                        )
                                                    ),
                                                    tts_pad_embed + self.talker.get_input_embeddings()(
                                                        torch.tensor(
                                                            [[
                                                                self.config.talker_config.codec_bos_id,
                                                            ]],
                                                            device=self.talker.device,
                                                            dtype=input_id.dtype,
                                                        )
                                                    )
                                                    ], dim=1)
                    trailing_text_hidden = tts_pad_embed
                else:
                    # Remaining text tokens (everything after the first) + EOS,
                    # streamed in during decoding.
                    trailing_text_hidden = torch.cat((self.talker.text_projection(
                        self.talker.get_text_embeddings()(input_id[:, 4:-5])
                    ), tts_eos_embed), dim=1)
            talker_input_embeds[index].append(talker_input_embed)
            trailing_text_hiddens.append(trailing_text_hidden)

        for index, talker_input_embed in enumerate(talker_input_embeds):
            talker_input_embeds[index] = torch.cat([item for item in talker_input_embed if item is not None], dim=1)

        # for batched inference
        original_lengths = torch.tensor([t.shape[1] for t in talker_input_embeds])
        # left padding for talker input embeds (flip, right-pad, flip back)
        sequences = [t.squeeze(0) for t in talker_input_embeds]
        sequences_reversed = [t.flip(dims=[0]) for t in sequences]
        padded_reversed = torch.nn.utils.rnn.pad_sequence(
            sequences_reversed,
            batch_first=True,
            padding_value=0.0
        )
        talker_input_embeds = padded_reversed.flip(dims=[1])
        # attention mask: 0 over the left padding, 1 over real positions
        batch_size, max_len = talker_input_embeds.shape[0], talker_input_embeds.shape[1]
        indices = torch.arange(max_len).expand(batch_size, -1)
        num_pads = max_len - original_lengths
        talker_attention_mask = (indices >= num_pads.unsqueeze(1)).long().to(talker_input_embeds.device)
        # right-pad trailing text hiddens with the tts-pad embedding
        # (note: tts_pad_embed from the last loop iteration; identical for all samples)
        pad_embedding_vector = tts_pad_embed.squeeze()
        sequences_to_pad = [t.squeeze(0) for t in trailing_text_hiddens]
        trailing_text_original_lengths = [s.shape[0] for s in sequences_to_pad]
        padded_hiddens = torch.nn.utils.rnn.pad_sequence(
            sequences_to_pad,
            batch_first=True,
            padding_value=0.0
        )
        arange_tensor = torch.arange(max(trailing_text_original_lengths),
                                     device=padded_hiddens.device).expand(len(trailing_text_original_lengths), -1)
        lengths_tensor = torch.tensor(trailing_text_original_lengths, device=padded_hiddens.device).unsqueeze(1)
        padding_mask = arange_tensor >= lengths_tensor
        padded_hiddens[padding_mask] = pad_embedding_vector
        trailing_text_hiddens = padded_hiddens

        # forward
        talker_result = self.talker.generate(
            inputs_embeds=talker_input_embeds,
            attention_mask=talker_attention_mask,
            trailing_text_hidden=trailing_text_hiddens,
            tts_pad_embed=tts_pad_embed,
            **talker_kwargs,
        )

        # Each step's hidden_states is (decoder hiddens, codec ids) — see the
        # talker's forward(); the prefill step carries codec_ids=None and is skipped.
        talker_codes = torch.stack([hid[-1] for hid in talker_result.hidden_states if hid[-1] is not None], dim=1)
        talker_hidden_states = torch.cat([hid[0][-1][:, -1:] for hid in talker_result.hidden_states], dim=1)[:, :-1]

        # Truncate each sample at its first EOS frame (first codebook only).
        first_codebook = talker_codes[:, :, 0]
        is_stop_token = (first_codebook == self.config.talker_config.codec_eos_token_id)
        stop_indices = torch.argmax(is_stop_token.int(), dim=1)
        has_stop_token = is_stop_token.any(dim=1)
        effective_lengths = torch.where(has_stop_token, stop_indices, talker_codes.shape[1])

        talker_codes_list = [talker_codes[i, :length, ] for i, length in enumerate(effective_lengths)]
        talker_hidden_states_list = [talker_hidden_states[i, :length, :] for i, length in enumerate(effective_lengths)]

        return talker_codes_list, talker_hidden_states_list
2241
# Public API of this module (what `from ... import *` exposes).
__all__ = [
    "Qwen3TTSForConditionalGeneration",
    "Qwen3TTSTalkerForConditionalGeneration",
    "Qwen3TTSPreTrainedModel",
    "Qwen3TTSTalkerModel",
]
qwen_tts/core/models/processing_qwen3_tts.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from transformers.feature_extraction_utils import BatchFeature
16
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
17
+
18
+
19
class Qwen3TTSProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword-argument schema for `Qwen3TTSProcessor`, with tokenizer defaults."""

    # Default tokenizer behavior: no padding; when padding is requested by the
    # caller it is applied on the left side.
    _defaults = {"text_kwargs": {"padding": False, "padding_side": "left"}}
26
+
27
class Qwen3TTSProcessor(ProcessorMixin):
    r"""
    Constructs a Qwen3TTS processor.

    Wraps a Qwen2 tokenizer so text inputs can be prepared for the Qwen3-TTS
    model through the standard `ProcessorMixin` interface.

    Args:
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The text tokenizer.
        chat_template (`Optional[str]`, *optional*):
            The Jinja template to use for formatting the conversation. If not provided, the default chat template is used.
    """

    attributes = ["tokenizer"]
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, tokenizer=None, chat_template=None):
        super().__init__(tokenizer, chat_template=chat_template)

    def __call__(self, text=None, **kwargs) -> BatchFeature:
        """
        Main method to prepare one or several text sequence(s) for the model. This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] to encode the text.

        Args:
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).

        Returns:
            [`BatchFeature`]: The tokenized inputs.

        Raises:
            ValueError: If no `text` input is provided.
        """
        if text is None:
            # Fixed wording: previous message read "specify either a `text` input"
            # although `text` is the only input this processor accepts.
            raise ValueError("You need to specify a `text` input to process.")

        output_kwargs = self._merge_kwargs(
            Qwen3TTSProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        # Normalize a single string into a batch of one for the tokenizer.
        if not isinstance(text, list):
            text = [text]

        texts_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

        return BatchFeature(
            data={**texts_inputs},
            tensor_type=kwargs.get("return_tensors"),
        )

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def apply_chat_template(self, conversations, chat_template=None, **kwargs):
        """Apply the chat template, wrapping a single conversation (a list of dicts) into a batch first."""
        if isinstance(conversations[0], dict):
            conversations = [conversations]
        return super().apply_chat_template(conversations, chat_template, **kwargs)

    @property
    def model_input_names(self):
        """Model input names: the tokenizer's input names, de-duplicated while preserving order."""
        tokenizer_input_names = self.tokenizer.model_input_names
        # dict.fromkeys keeps first-seen order while removing duplicates.
        return list(dict.fromkeys(tokenizer_input_names))
104
+
105
+
106
# Public name exported by this processing module.
__all__ = ["Qwen3TTSProcessor"]
qwen_tts/core/tokenizer_12hz/configuration_qwen3_tts_tokenizer_v2.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Qwen3TTSTokenizerV2 model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+ from transformers import MimiConfig
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
class Qwen3TTSTokenizerV2DecoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3TTSTokenizerV2DecoderConfig`].

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        codebook_size (`int`, *optional*, defaults to 2048):
            Number of entries in each residual codebook used for acoustic token quantization.
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the hidden states and embeddings in the autoregressive transformer decoder.
        latent_dim (`int`, *optional*, defaults to 1024):
            Dimensionality of the latent feature space used by the decoder. (Stored as-is; exact usage is
            defined by the modeling code.)
        max_position_embeddings (`int`, *optional*, defaults to 8000):
            Maximum sequence length that the autoregressive decoder can handle. Determines positional embedding size.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period for rotary position embeddings (RoPE) applied to attention layers.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the decoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            Number of key and value attention heads used in grouped-query attention (if applicable).
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the attention projection layers.
        sliding_window (`int`, *optional*, defaults to 72):
            Window size for local attention mechanism, limiting attention context to improve efficiency.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the feed-forward (intermediate) layer in each transformer block.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function used in the feed-forward layers. Supports `"silu"`, `"relu"`, `"gelu"`, etc.
        layer_scale_initial_scale (`float`, *optional*, defaults to 0.01):
            Initial value for LayerScale applied in transformer blocks, helping stabilize training.
        rms_norm_eps (`float`, *optional*, defaults to 1e-5):
            Epsilon value for RMS normalization layers to prevent division by zero.
        num_hidden_layers (`int`, *optional*, defaults to 8):
            Number of transformer blocks in the autoregressive decoder.
        num_quantizers (`int`, *optional*, defaults to 16):
            Number of residual vector quantizers used in the vocoder for fine-grained audio reconstruction.
        upsample_rates (`Tuple[int]`, *optional*, defaults to `(8, 5, 4, 3)`):
            Rate at which features are upsampled in the final waveform synthesis stage.
        upsampling_ratios (`Tuple[int]`, *optional*, defaults to `(2, 2)`):
            Ratios used in transposed convolutional layers to progressively upsample feature maps to waveform.
        decoder_dim (`int`, *optional*, defaults to 1536):
            Final dimensionality of the decoder's output before waveform generation.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied to attention weights in the decoder.
    """

    def __init__(
        self,
        codebook_size=2048,
        hidden_size=1024,
        latent_dim=1024,
        max_position_embeddings=8000,
        rope_theta=10000,
        num_attention_heads=16,
        num_key_value_heads=16,
        attention_bias=False,
        sliding_window=72,
        intermediate_size=3072,
        hidden_act="silu",
        layer_scale_initial_scale=0.01,
        rms_norm_eps=1e-5,
        num_hidden_layers=8,
        num_quantizers=16,
        upsample_rates=(8, 5, 4, 3),
        upsampling_ratios=(2, 2),
        decoder_dim=1536,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.codebook_size = codebook_size
        self.hidden_size = hidden_size
        self.latent_dim = latent_dim
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.attention_bias = attention_bias
        self.sliding_window = sliding_window
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.layer_scale_initial_scale = layer_scale_initial_scale
        self.rms_norm_eps = rms_norm_eps
        self.num_hidden_layers = num_hidden_layers
        self.num_quantizers = num_quantizers
        self.upsample_rates = upsample_rates
        self.upsampling_ratios = upsampling_ratios
        self.decoder_dim = decoder_dim
        self.attention_dropout = attention_dropout

    @property
    def layer_types(self):
        """
        All layers in code2wav use sliding-window attention.
        """
        return ["sliding_attention"] * self.num_hidden_layers
122
+
123
+
124
class Qwen3TTSTokenizerV2Config(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`Qwen3TTSTokenizerV2Config`]. It is used to instantiate a Qwen3TTSTokenizerV2Model
    model according to the specified sub-models configurations, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        encoder_config (`dict`, *optional*): Configuration of the underlying encoder sub-model.
        decoder_config (`dict`, *optional*): Configuration of the underlying decoder sub-model.
    """

    model_type = "qwen3_tts_tokenizer_12hz"
    sub_configs = {
        "encoder_config": MimiConfig,
        "decoder_config": Qwen3TTSTokenizerV2DecoderConfig,
    }

    def __init__(
        self,
        encoder_config=None,
        decoder_config=None,
        encoder_valid_num_quantizers=16,
        input_sample_rate=24000,
        output_sample_rate=24000,
        decode_upsample_rate=1920,
        encode_downsample_rate=1920,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Fall back to default sub-configurations when none are provided.
        if encoder_config is None:
            logger.info("encoder_config is None. Initializing encoder with default values")
            encoder_config = {}
        if decoder_config is None:
            logger.info("decoder_config is None. Initializing decoder with default values")
            decoder_config = {}

        self.encoder_config = MimiConfig(**encoder_config)
        self.decoder_config = Qwen3TTSTokenizerV2DecoderConfig(**decoder_config)

        self.encoder_valid_num_quantizers = encoder_valid_num_quantizers
        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        # With the defaults, one code frame corresponds to 1920 audio samples
        # (24000 / 1920 = 12.5 code frames per second).
        self.decode_upsample_rate = decode_upsample_rate
        self.encode_downsample_rate = encode_downsample_rate
170
+
171
+
172
+ __all__ = ["Qwen3TTSTokenizerV2Config", "Qwen3TTSTokenizerV2DecoderConfig"]
qwen_tts/core/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py ADDED
@@ -0,0 +1,1025 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Qwen3TTSTokenizerV2 model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Callable, Optional, Union, List
20
+
21
+ import numpy as np
22
+ import torch
23
+ from torch import nn
24
+ from torch.nn import Parameter
25
+ from torch.nn import functional as F
26
+ from transformers import MimiConfig, MimiModel
27
+ from transformers.activations import ACT2FN
28
+ from transformers.cache_utils import Cache, DynamicCache
29
+ from transformers.integrations import use_kernel_forward_from_hub
30
+ from transformers.masking_utils import (
31
+ create_causal_mask,
32
+ create_sliding_window_causal_mask,
33
+ )
34
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
35
+ from transformers.modeling_layers import GradientCheckpointingLayer
36
+ from transformers.modeling_outputs import BaseModelOutputWithPast
37
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
38
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
39
+ from transformers.processing_utils import Unpack
40
+ from transformers.utils import ModelOutput, auto_docstring, logging
41
+ from transformers.utils.deprecation import deprecate_kwarg
42
+ from transformers.utils.generic import check_model_inputs
43
+
44
+ from .configuration_qwen3_tts_tokenizer_v2 import (
45
+ Qwen3TTSTokenizerV2Config,
46
+ Qwen3TTSTokenizerV2DecoderConfig,
47
+ )
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
@dataclass
@auto_docstring
class Qwen3TTSTokenizerV2EncoderOutput(ModelOutput):
    r"""
    audio_codes (`List[torch.LongTensor]`):
        Discrete code embeddings computed using `model.encode`, each tensor has shape (codes_length_i, num_quantizers).
    """

    # One tensor per input segment; lengths may differ between segments.
    audio_codes: List[torch.LongTensor] = None
61
+
62
+
63
@dataclass
@auto_docstring
class Qwen3TTSTokenizerV2DecoderOutput(ModelOutput):
    r"""
    audio_values (`List[torch.FloatTensor]`):
        Decoded audio values, obtained using the decoder part of Qwen3TTSTokenizerV2.
        Each tensor has shape (segment_length_i).
    """

    # One waveform tensor per decoded segment; lengths may differ between segments.
    audio_values: List[torch.FloatTensor] = None
73
+
74
+
75
def rotate_half(x):
    """Rotates half the hidden dims of the input (RoPE helper)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
80
+
81
+
82
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos`/`sin` are unsqueezed so they broadcast
            against `q` and `k`. Use 1 for (batch, heads, seq, head_dim) layouts
            and 2 for (batch, seq, heads, head_dim) layouts.

    Returns:
        `tuple(torch.Tensor)`: the query and key tensors rotated with RoPE.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _rotate_half(t):
        # Swap the two halves of the last dim, negating the (new) first half.
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    rotated_q = (q * cos) + (_rotate_half(q) * sin)
    rotated_k = (k * cos) + (_rotate_half(k) * sin)
    return rotated_q, rotated_k
107
+
108
+
109
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep): expands
    (batch, num_key_value_heads, seqlen, head_dim) into
    (batch, num_key_value_heads * n_rep, seqlen, head_dim) for grouped-query attention.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        # Nothing to repeat; return the input unchanged.
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
119
+
120
+
121
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Reference (non-fused) scaled dot-product attention with GQA key/value expansion."""
    expanded_keys = repeat_kv(key, module.num_key_value_groups)
    expanded_values = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # The mask may be longer than the key sequence; crop before adding.
        scores = scores + attention_mask[:, :, :, : expanded_keys.shape[-2]]

    # Softmax in float32 for stability, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)
    context = torch.matmul(probs, expanded_values)
    context = context.transpose(1, 2).contiguous()

    return context, probs
145
+
146
+
147
@auto_docstring
class Qwen3TTSTokenizerV2DecoderPreTrainedModel(PreTrainedModel):
    # Base class: declares the config type and capability flags for the decoder stack.
    config: Qwen3TTSTokenizerV2DecoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Cache objects handle their own device placement.
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = False
    _supports_attention_backend = True
157
+
158
+
159
class Qwen3TTSTokenizerV2CausalConvNet(nn.Module):
    """1D convolution with left (causal) padding so outputs never depend on future samples."""

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
        )
        self.stride = stride
        # Effective receptive field of the dilated kernel.
        self.kernel_size = (kernel_size - 1) * dilation + 1
        self.dilation = dilation
        # Amount of left padding needed for causality.
        self.padding = self.kernel_size - self.stride

    def _get_extra_padding_for_conv1d(self, hidden_state: torch.Tensor) -> int:
        """Right padding needed so the final stride window lines up with the input end."""
        length = hidden_state.shape[-1]
        n_frames = (length - self.kernel_size + self.padding) / self.stride + 1
        ideal_length = (math.ceil(n_frames) - 1) * self.stride + (self.kernel_size - self.padding)
        return ideal_length - length

    def forward(self, hidden_state):
        pad_right = self._get_extra_padding_for_conv1d(hidden_state)
        padded = F.pad(hidden_state, (self.padding, pad_right), mode="constant", value=0)
        return self.conv(padded).contiguous()
193
+
194
+
195
class Qwen3TTSTokenizerV2CausalTransConvNet(nn.Module):
    """Transposed 1D convolution trimmed to a causal alignment.

    A plain ConvTranspose1d produces ``(L - 1) * stride + kernel_size`` samples;
    the surplus ``kernel_size - stride`` samples are trimmed so the output length
    is exactly ``L * stride`` (required for the exact code-to-waveform upsampling
    factor used by the tokenizer).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super().__init__()
        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride=stride)

        pad = kernel_size - stride
        self.left_pad = math.ceil(pad)
        # BUGFIX: was `self.right_pad = pad = self.left_pad`, a chained assignment
        # that set right_pad equal to left_pad and trimmed `pad` extra samples,
        # making the output shorter than L * stride. The intended trim is the
        # remainder after the (ceil'ed) left trim.
        self.right_pad = pad - self.left_pad

    def forward(self, hidden_state):
        hidden_state = self.conv(hidden_state)
        # Drop the non-causal surplus samples introduced by the transposed conv.
        hidden_state = hidden_state[..., self.left_pad : hidden_state.shape[-1] - self.right_pad]
        return hidden_state.contiguous()
208
+
209
+
210
class Qwen3TTSTokenizerV2ConvNeXtBlock(nn.Module):
    """ConvNeXt-style residual block: depthwise causal conv -> LayerNorm -> pointwise MLP."""

    def __init__(self, dim: int):
        super().__init__()
        # Depthwise (groups=dim) causal convolution over the time axis.
        self.dwconv = Qwen3TTSTokenizerV2CausalConvNet(dim, dim, kernel_size=7, groups=dim, dilation=1)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # Learnt per-channel residual scale, initialised near zero.
        self.gamma = nn.Parameter(1e-6 * torch.ones(dim))

    def forward(self, hidden_states):
        residual = hidden_states

        # (B, C, T) -> conv -> (B, T, C) so LayerNorm/Linear act on channels.
        out = self.dwconv(hidden_states)
        out = out.permute(0, 2, 1)
        out = self.norm(out)
        out = self.pwconv2(self.act(self.pwconv1(out)))
        out = self.gamma * out
        out = out.permute(0, 2, 1)

        return residual + out
243
+
244
+
245
class Qwen3TTSTokenizerV2DecoderRotatoryEmbedding(nn.Module):
    """Computes rotary position embedding (RoPE) cos/sin tables for the decoder attention."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        batch = position_ids.shape[0]
        freq_col = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        pos_row = position_ids[:, None, :].float()

        # The tables are computed in float32 (autocast disabled) for numerical stability.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (freq_col.float() @ pos_row.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
279
+
280
+
281
class Qwen3TTSTokenizerV2DecoderAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        q_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, q_width, bias=config.attention_bias)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=config.attention_bias)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=config.attention_bias)
        self.o_proj = nn.Linear(q_width, config.hidden_size, bias=config.attention_bias)
        # No q/k normalization in this model; identities keep call sites uniform.
        self.q_norm = nn.Identity()
        self.k_norm = nn.Identity()
        self.sliding_window = config.sliding_window

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, self.head_dim)

        # Project and reshape to (batch, heads, seq, head_dim).
        query_states = self.q_norm(self.q_proj(hidden_states).view(per_head_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(per_head_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(per_head_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Dispatch to the configured attention backend, defaulting to the eager path.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        attn_output = attn_output.reshape(*leading_shape, -1).contiguous()
        return self.o_proj(attn_output), attn_weights
354
+
355
+
356
class Qwen3TTSTokenizerV2DecoderMlp(nn.Module):
    """Gated (SwiGLU-style) feed-forward block used in each decoder layer."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # down( act(gate(x)) * up(x) )
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
370
+
371
+
372
@use_kernel_forward_from_hub("RMSNorm")
class Qwen3TTSTokenizerV2DecoderRMSNorm(nn.Module):
    """RMS normalization with a learnt scale (equivalent to T5LayerNorm)."""

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        original_dtype = hidden_states.dtype
        # Compute statistics in float32 for stability, cast back afterwards.
        states = hidden_states.to(torch.float32)
        mean_square = states.pow(2).mean(-1, keepdim=True)
        normed = states * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normed.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
391
+
392
+
393
class Qwen3TTSTokenizerV2DecoderLayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://huggingface.co/papers/2103.17239).
    This rescales diagonally the residual outputs close to 0, with a learnt scale.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        init_value = config.layer_scale_initial_scale
        self.scale = nn.Parameter(torch.full((width,), init_value, requires_grad=True))

    def forward(self, x: torch.Tensor):
        # Per-channel scaling of the residual branch.
        return self.scale * x
406
+
407
+
408
class Qwen3TTSTokenizerV2DecoderTransformerLayer(GradientCheckpointingLayer):
    """Pre-norm transformer block (self-attention + MLP) with LayerScale applied
    to both residual branches. `attention_type` is fixed to "sliding_attention",
    so the enclosing model always feeds this layer the sliding-window causal mask.
    """

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen3TTSTokenizerV2DecoderAttention(config, layer_idx)
        self.mlp = Qwen3TTSTokenizerV2DecoderMlp(config)
        self.input_layernorm = Qwen3TTSTokenizerV2DecoderRMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3TTSTokenizerV2DecoderRMSNorm(config.hidden_size, config.rms_norm_eps)
        # LayerScale gates each residual branch separately.
        self.self_attn_layer_scale = Qwen3TTSTokenizerV2DecoderLayerScale(config)
        self.mlp_layer_scale = Qwen3TTSTokenizerV2DecoderLayerScale(config)
        self.attention_type = "sliding_attention"

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
                into the model
        """
        residual = hidden_states

        # Pre-norm: normalise before attention, add the input back afterwards.
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention (attention weights returned by self_attn are discarded).
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + self.self_attn_layer_scale(hidden_states)

        # Fully Connected (pre-norm MLP branch, also gated by LayerScale).
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.mlp_layer_scale(hidden_states)

        return hidden_states
472
+
473
+
474
@auto_docstring
class Qwen3TTSTokenizerV2DecoderTransformerModel(Qwen3TTSTokenizerV2DecoderPreTrainedModel):
    """Transformer over continuous latents: projects `latent_dim` -> `hidden_size`,
    runs the stack of decoder layers, and projects back to `latent_dim`.

    Unlike a token LM, this model only accepts `inputs_embeds` (continuous
    vectors); `input_ids` is explicitly rejected in `forward`.
    """

    # Presumably consumed by the transformers output-recording utilities
    # (`check_model_inputs`) to collect hidden states / attentions — confirm.
    _can_record_outputs = {
        "hidden_states": Qwen3TTSTokenizerV2DecoderTransformerLayer,
        "attentions": Qwen3TTSTokenizerV2DecoderAttention,
    }

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Qwen3TTSTokenizerV2DecoderTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3TTSTokenizerV2DecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3TTSTokenizerV2DecoderRotatoryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types
        self.window_size = config.sliding_window

        # In/out projections between the codec latent space and the transformer width.
        self.input_proj = nn.Linear(config.latent_dim, config.hidden_size)
        self.output_proj = nn.Linear(config.hidden_size, config.latent_dim)

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        cache_position=None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        # NOTE(review): since input_ids is forbidden here, the XOR check below
        # effectively just enforces that inputs_embeds is provided, and the
        # `embed_tokens` fallback is unreachable (no embed_tokens attribute is
        # defined on this model) — confirm before relying on input_ids support.
        if input_ids is not None:
            raise ValueError("input_ids is not expected")
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Project codec latents into the transformer's hidden size.
        inputs_embeds = self.input_proj(inputs_embeds)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding window alternating layers are not always activated depending on the config
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # Each layer picks the mask matching its own attention_type.
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        # Project back to the codec latent dimension.
        hidden_states = self.output_proj(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
575
+
576
+
577
class SnakeBeta(nn.Module):
    """
    A modified Snake function which uses separate parameters for the magnitude of the periodic components:

        SnakeBeta(x) := x + 1/b * sin^2(x * a)

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable per-channel parameter that controls frequency (stored in log scale)
        - beta - trainable per-channel parameter that controls magnitude (stored in log scale)
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
          https://huggingface.co/papers/2006.08195

    Args:
        in_features (int): number of channels C.
        alpha (float): kept for interface compatibility. The parameters are stored
            in log scale and initialised to zero (i.e. effective alpha = beta = 1
            after `exp`), so this value has no effect on the initial state — the
            original code multiplied it into a zero tensor, which is identically
            zero for any finite value.
    """

    def __init__(self, in_features, alpha=1.0):
        super().__init__()
        self.in_features = in_features

        # Parameters live in log scale and are exponentiated in forward();
        # zero-init therefore corresponds to effective alpha = beta = 1.
        # (Fixes the dead `torch.zeros(in_features) * alpha` multiplication.)
        self.alpha = Parameter(torch.zeros(in_features))
        self.beta = Parameter(torch.zeros(in_features))

        # Small constant guarding against division by zero when exp(beta) underflows.
        self.no_div_by_zero = 0.000000001

    def forward(self, hidden_states):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2(x * a)
        """
        # Reshape (C,) -> (1, C, 1) so the parameters broadcast over (B, C, T).
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        # Parameters are stored in log scale.
        alpha = torch.exp(alpha)
        beta = torch.exp(beta)
        hidden_states = hidden_states + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(
            torch.sin(hidden_states * alpha), 2
        )

        return hidden_states
616
+
617
+
618
class Qwen3TTSTokenizerV2DecoderDecoderResidualUnit(nn.Module):
    """Residual unit: two Snake activations + causal convolutions with a skip path."""

    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        self.act1 = SnakeBeta(dim)
        self.conv1 = Qwen3TTSTokenizerV2CausalConvNet(dim, dim, kernel_size=7, dilation=dilation)
        self.act2 = SnakeBeta(dim)
        self.conv2 = Qwen3TTSTokenizerV2CausalConvNet(dim, dim, kernel_size=1)

    def forward(self, hidden_state):
        """Apply act1 -> conv1 -> act2 -> conv2 and add the input back."""
        out = self.conv1(self.act1(hidden_state))
        out = self.conv2(self.act2(out))
        return out + hidden_state
635
+
636
+
637
class Qwen3TTSTokenizerV2DecoderDecoderBlock(Qwen3TTSTokenizerV2DecoderPreTrainedModel):
    """One upsampling stage of the waveform decoder.

    Halves the channel count, upsamples time by `config.upsample_rates[layer_idx]`,
    then refines the result with three dilated residual units.
    """

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, layer_idx):
        super().__init__(config)
        # Channel width halves at every stage: decoder_dim / 2^i -> decoder_dim / 2^(i+1).
        in_dim = config.decoder_dim // 2**layer_idx
        out_dim = config.decoder_dim // 2 ** (layer_idx + 1)
        rate = config.upsample_rates[layer_idx]

        layers = [
            SnakeBeta(in_dim),
            Qwen3TTSTokenizerV2CausalTransConvNet(in_dim, out_dim, 2 * rate, rate),
        ]
        layers.extend(
            Qwen3TTSTokenizerV2DecoderDecoderResidualUnit(out_dim, dilation)
            for dilation in (1, 3, 9)
        )
        self.block = nn.ModuleList(layers)

    def forward(self, hidden):
        """Run the stage's sub-modules in sequence."""
        for module in self.block:
            hidden = module(hidden)
        return hidden
658
+
659
+
660
class EuclideanCodebook(nn.Module):
    """Codebook lookup (decode-only path).

    Entries are stored as running sums (`embedding_sum`) together with per-entry
    usage counts (`cluster_usage`); the effective embedding of an entry is the
    sum divided by its clamped usage count.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        epsilon: float = 1e-5,
    ):
        super().__init__()
        self.dim = dim
        self.codebook_size = codebook_size
        # Lower bound on the usage counts to avoid division by ~0.
        self.epsilon = epsilon

        self.cluster_usage = nn.Parameter(torch.ones(codebook_size))
        self.embedding_sum = nn.Parameter(torch.zeros(codebook_size, dim))

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for integer `codes` (shape S -> shape S + (dim,))."""
        usage = self.cluster_usage.clamp(min=self.epsilon)[:, None]
        embedding = self.embedding_sum / usage
        return F.embedding(codes, embedding)
679
+
680
+
681
class VectorQuantization(nn.Module):
    """Single vector-quantization layer (decode-only path).

    When the codebook lives in a smaller space than the model dimension, a
    linear projection maps decoded vectors back up; otherwise it is an identity.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: Optional[int] = None,
        epsilon: float = 1e-5,
    ):
        super().__init__()
        effective_dim = dim if codebook_dim is None else codebook_dim

        needs_projection = effective_dim != dim
        self.project_out = nn.Linear(effective_dim, dim) if needs_projection else nn.Identity()
        self.epsilon = epsilon
        self._codebook = EuclideanCodebook(
            dim=effective_dim,
            codebook_size=codebook_size,
            epsilon=epsilon,
        )
        self.codebook_size = codebook_size

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Map codes of shape [B, T] to quantized vectors of shape [B, dim, T]."""
        decoded = self._codebook.decode(codes)
        projected = self.project_out(decoded)
        return projected.transpose(1, 2)
711
+
712
+
713
class ResidualVectorQuantization(nn.Module):
    """Stack of VQ layers applied residually; decoding sums the per-layer outputs."""

    def __init__(self, *, num_quantizers: int, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            VectorQuantization(**kwargs) for _ in range(num_quantizers)
        )

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode stacked codes [K, B, T] into the summed quantized tensor [B, dim, T]."""
        # Broadcastable scalar zero on the same device as the codes.
        quantized = torch.zeros([1], device=codes.device)[0]
        for idx, layer_codes in enumerate(codes):
            layer = self.layers[idx]
            assert isinstance(layer, VectorQuantization)
            quantized = quantized + layer.decode(layer_codes)
        return quantized
727
+
728
+
729
class ResidualVectorQuantizer(nn.Module):
    """Residual VQ with optional 1x1-conv input/output projections (decode-only)."""

    def __init__(
        self,
        dimension: int = 128,
        input_dimension: Optional[int] = None,
        output_dimension: Optional[int] = None,
        n_q: int = 8,
        q_dropout: bool = False,
        no_quantization_rate: float = 0.0,
        bins: int = 1024,
        decay: float = 0.99,
        force_projection: bool = False,
    ):
        super().__init__()
        self.max_n_q = n_q
        self.n_q = n_q
        self.q_dropout = q_dropout
        self.no_quantization_rate = no_quantization_rate
        self.dimension = dimension
        self.input_dimension = input_dimension or dimension
        self.output_dimension = output_dimension or dimension
        self.bins = bins
        self.decay = decay
        self.input_proj: torch.nn.Module
        self.output_proj: torch.nn.Module
        # 1x1 convolutions project in/out of the quantizer space; identity when
        # the dimensions already match and projection is not forced.
        if force_projection or self.input_dimension != self.dimension:
            self.input_proj = torch.nn.Conv1d(
                self.input_dimension, self.dimension, 1, bias=False
            )
        else:
            self.input_proj = torch.nn.Identity()
        if force_projection or self.output_dimension != self.dimension:
            self.output_proj = torch.nn.Conv1d(
                self.dimension, self.output_dimension, 1, bias=False
            )
        else:
            self.output_proj = torch.nn.Identity()
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
        )

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode codes [B, K, T] into vectors [B, output_dimension, T]."""
        stacked = codes.transpose(0, 1)  # [K, B, T], as expected by the inner RVQ
        quantized = self.vq.decode(stacked)
        return self.output_proj(quantized)
777
+
778
+
779
class SplitResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer with separate projections for the first quantizer and the rest.

    The first `n_q_semantic` quantizers form the "semantic" part; the remaining
    ones add the "acoustic" detail on top of it.

    Args:
        n_q (int): Number of residual vector quantizers used.
        n_q_semantic (int): Number of residual vector quantizers used for the semantic quantizer.
        **kwargs: Arguments to the constructor of `ResidualVectorQuantizer` that are shared between both.
    """

    def __init__(
        self,
        *,
        n_q: int = 8,
        n_q_semantic: int = 1,
        **kwargs,
    ):
        super().__init__()
        assert n_q > n_q_semantic, (
            f"Number of quantizers {n_q} must be larger "
            f"than the number of semantic quantizers {n_q_semantic}."
        )
        self.max_n_q = n_q
        self.n_q_semantic = n_q_semantic
        self.n_q_acoustic = n_q - n_q_semantic
        # Quantizer dropout, if requested, only ever applies to the acoustic part.
        acoustic_dropout = kwargs.pop("q_dropout", False)
        self.rvq_first = ResidualVectorQuantizer(
            n_q=n_q_semantic, force_projection=True, q_dropout=False, **kwargs
        )
        self.rvq_rest = ResidualVectorQuantizer(
            n_q=self.n_q_acoustic,
            force_projection=True,
            q_dropout=acoustic_dropout,
            **kwargs,
        )

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        # codes is [B, K, T], with T frames, K nb of codebooks.
        quantized = self.rvq_first.decode(codes[:, : self.n_q_semantic])
        if codes.shape[1] > self.n_q_semantic:
            quantized = quantized + self.rvq_rest.decode(codes[:, self.n_q_semantic :])
        return quantized
821
+
822
+
823
class Qwen3TTSTokenizerV2Decoder(Qwen3TTSTokenizerV2DecoderPreTrainedModel):
    """Code-to-waveform decoder.

    Pipeline: RVQ dequantization -> causal pre-conv -> latent transformer ->
    transposed-conv/ConvNeXt upsampling -> conv decoder stack with Snake
    activations -> waveform clamped to [-1, 1].
    """

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig):
        super().__init__(config)
        # Samples produced per input code frame (list concat, then product).
        self.total_upsample = np.prod(config.upsample_rates + config.upsampling_ratios)
        self.pre_transformer = Qwen3TTSTokenizerV2DecoderTransformerModel._from_config(config)

        # First quantizer is semantic, the remaining ones acoustic.
        self.quantizer = SplitResidualVectorQuantizer(
            dimension=config.codebook_dim // 2,
            n_q=config.num_quantizers,
            n_q_semantic=1,
            bins=config.codebook_size,
            input_dimension=config.codebook_dim,
            output_dimension=config.codebook_dim,
        )

        self.pre_conv = Qwen3TTSTokenizerV2CausalConvNet(
            config.codebook_dim,
            config.latent_dim,
            kernel_size=3,
        )

        # Latent-space upsampling stages (transposed conv + ConvNeXt refinement).
        upsample = []
        for factor in config.upsampling_ratios:
            upsample.append(
                nn.ModuleList(
                    [
                        Qwen3TTSTokenizerV2CausalTransConvNet(config.latent_dim, config.latent_dim, factor, factor),
                        Qwen3TTSTokenizerV2ConvNeXtBlock(config.latent_dim),
                    ]
                )
            )
        self.upsample = nn.ModuleList(upsample)

        # Waveform decoder: widen to decoder_dim, then halve channels per block
        # while upsampling, finishing with a single-channel output conv.
        decoder = [Qwen3TTSTokenizerV2CausalConvNet(config.latent_dim, config.decoder_dim, 7)]
        for i in range(len(config.upsample_rates)):
            decoder.append(Qwen3TTSTokenizerV2DecoderDecoderBlock(config, i))
        output_dim = config.decoder_dim // 2 ** len(config.upsample_rates)
        decoder += [
            SnakeBeta(output_dim),
            Qwen3TTSTokenizerV2CausalConvNet(output_dim, 1, 7),
        ]
        self.decoder = nn.ModuleList(decoder)

        self.post_init()

    def forward(self, codes):
        """Decode codes [B, num_quantizers, T] into a waveform clamped to [-1, 1]."""
        if codes.shape[1] != self.config.num_quantizers:
            raise ValueError(f"Expected {self.config.num_quantizers} layer of codes, got {codes.shape[1]}")

        hidden = self.quantizer.decode(codes)
        # [B, latent_dim, T] -> [B, T, latent_dim] for the transformer.
        hidden = self.pre_conv(hidden).transpose(1, 2)

        hidden = self.pre_transformer(inputs_embeds=hidden).last_hidden_state
        # Back to channels-first for the convolutional stages.
        hidden = hidden.permute(0, 2, 1)
        for blocks in self.upsample:
            for block in blocks:
                hidden = block(hidden)
        wav = hidden
        for block in self.decoder:
            wav = block(wav)
        return wav.clamp(min=-1, max=1)

    def chunked_decode(self, codes, chunk_size=300, left_context_size=25):
        """Decode long code sequences in chunks to bound memory usage.

        Each chunk is decoded together with up to `left_context_size` frames of
        left context (so causal convolutions see history); the context portion of
        the produced waveform is trimmed off before concatenation.
        """
        wavs = []
        start_index = 0
        while start_index < codes.shape[-1]:
            end_index = min(start_index + chunk_size, codes.shape[-1])
            # Clamp the context so it never reaches before the sequence start.
            context_size = left_context_size if start_index - left_context_size > 0 else start_index
            codes_chunk = codes[..., start_index - context_size : end_index]
            wav_chunk = self(codes_chunk)
            # Drop the samples that correspond to the context frames.
            wavs.append(wav_chunk[..., context_size * self.total_upsample :])
            start_index = end_index
        return torch.cat(wavs, dim=-1)
896
+
897
+
898
class Qwen3TTSTokenizerV2Encoder(MimiModel):
    """Encode-only wrapper around `MimiModel`.

    The decoder-side sub-modules inherited from Mimi are replaced with `None`,
    since this class is only used through its `encode` path.
    """

    def __init__(self, config: MimiConfig):
        super().__init__(config)
        self.config = config

        # Drop the Mimi decoder stack; only the encoder path is needed here.
        self.upsample = None
        self.decoder_transformer = None
        self.decoder = None

        self.post_init()
908
+
909
+
910
@auto_docstring
class Qwen3TTSTokenizerV2PreTrainedModel(PreTrainedModel):
    """Base class wiring Qwen3TTSTokenizerV2 models into the transformers
    pretrained-model machinery (weight init, checkpoint loading, attention backends)."""

    config: Qwen3TTSTokenizerV2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # KV caches manage their own device placement.
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    # Full-graph torch.compile is explicitly not supported for this model.
    _can_compile_fullgraph = False
    _supports_attention_backend = True
920
+
921
+
922
@auto_docstring(
    custom_intro="""
    The Qwen3TTSTokenizerV2 model.
    """
)
class Qwen3TTSTokenizerV2Model(Qwen3TTSTokenizerV2PreTrainedModel):
    """Neural audio codec: a Mimi-based encoder producing discrete codes and a
    transformer + upsampling decoder reconstructing the waveform from them."""

    def __init__(self, config: Qwen3TTSTokenizerV2Config):
        super().__init__(config)
        self.config = config

        # Only the first N quantizer levels produced by the encoder are kept.
        self.encoder_valid_num_quantizers = config.encoder_valid_num_quantizers

        self.input_sample_rate = config.input_sample_rate
        self.output_sample_rate = config.output_sample_rate

        self.decode_upsample_rate = config.decode_upsample_rate
        self.encode_downsample_rate = config.encode_downsample_rate

        self.encoder = Qwen3TTSTokenizerV2Encoder._from_config(self.config.encoder_config)
        self.decoder = Qwen3TTSTokenizerV2Decoder._from_config(self.config.decoder_config)

        self.post_init()

    def get_model_type(self):
        """Return the configured model type string."""
        return self.config.model_type

    def get_input_sample_rate(self):
        """Sample rate (Hz) expected for waveforms passed to `encode`."""
        return self.input_sample_rate

    def get_output_sample_rate(self):
        """Sample rate (Hz) of waveforms produced by `decode`."""
        return self.output_sample_rate

    def get_encode_downsample_rate(self):
        """Number of input samples per code frame."""
        return self.encode_downsample_rate

    def get_decode_upsample_rate(self):
        """Number of output samples per code frame."""
        return self.decode_upsample_rate

    def encode(
        self,
        input_values: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor, Optional[torch.Tensor]], Qwen3TTSTokenizerV2EncoderOutput]:
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                Float values of the input audio waveform.
            padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked*
                or 0 for *masked*. Defaults to an all-ones mask (no padding).
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Bug fix: `padding_mask` was declared Optional but the list comprehension
        # below iterated over it unconditionally, crashing on `None`.
        if padding_mask is None:
            padding_mask = torch.ones_like(input_values, dtype=torch.long)

        encoded_frames = self.encoder.encode(input_values=input_values.unsqueeze(1),
                                             return_dict=True)
        # Keep only the leading (valid) quantizer levels.
        audio_codes = encoded_frames.audio_codes[:, :self.encoder_valid_num_quantizers]
        # Per sample: trim to ceil(num_valid_samples / downsample_rate) frames
        # (`-(-a // b)` is ceil division), then lay out as [frames, quantizers].
        audio_codes = [
            code[..., : -(-mask.sum() // self.encode_downsample_rate)].transpose(0, 1)
            for code, mask in zip(audio_codes, padding_mask)
        ]

        if not return_dict:
            return (
                audio_codes,
            )

        return Qwen3TTSTokenizerV2EncoderOutput(audio_codes)

    def decode(
        self,
        audio_codes: torch.Tensor,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor, torch.Tensor], Qwen3TTSTokenizerV2DecoderOutput]:
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
        trimmed.

        Args:
            audio_codes (`torch.LongTensor` of shape `(batch_size, codes_length, num_quantizers)`, *optional*):
                Discrete code embeddings computed using `model.encode`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # The decoder expects [batch, quantizers, frames].
        audio_values = self.decoder.chunked_decode(audio_codes.transpose(1, 2)).squeeze(1)

        # NOTE(review): frames whose first-quantizer code is <= 0 are treated as
        # padding when computing per-sample lengths — confirm code 0 is reserved.
        audio_lengths = (audio_codes[..., 0] > 0).sum(1) * self.decode_upsample_rate
        audio_values = [a[:l] for a, l in zip(audio_values, audio_lengths)]

        if not return_dict:
            return (
                audio_values,
            )

        return Qwen3TTSTokenizerV2DecoderOutput(audio_values)
1023
+
1024
+
1025
# Public API of this module.
__all__ = ["Qwen3TTSTokenizerV2Model", "Qwen3TTSTokenizerV2PreTrainedModel"]
qwen_tts/core/tokenizer_25hz/configuration_qwen3_tts_tokenizer_v1.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Qwen3TTSTokenizerV1 model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
class Qwen3TTSTokenizerV1DecoderDiTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1DecoderToken2WavDiT.
    It defines the architecture of the DiT model, which is used for generating mel-spectrograms from tokens.

    Args:
        hidden_size (`int`, *optional*, defaults to 1024):
            The dimension of the model.
        num_hidden_layers (`int`, *optional*, defaults to 22):
            The number of transformer blocks in the DiT model.
        num_attention_heads (`int`, *optional*, defaults to 16):
            The number of attention heads in each transformer block.
        ff_mult (`int`, *optional*, defaults to 2):
            The multiplier for the feedforward layer in each transformer block.
        emb_dim (`int`, *optional*, defaults to 512):
            The dimension of the embedding layer.
        head_dim (`int`, *optional*, defaults to 64):
            The dimension of each attention head.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the rotary position embeddings.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length the rotary position embeddings are prepared for.
        block_size (`int`, *optional*, defaults to 24):
            Block size used by the block-wise attention pattern (implementation-dependent).
        look_ahead_layers (`list[int]`, *optional*, defaults to `[10]`):
            Layer indices with look-ahead attention (implementation-dependent; verify against the DiT attention code).
        look_backward_layers (`list[int]`, *optional*, defaults to `[0, 20]`):
            Layer indices with extra look-backward attention (implementation-dependent; verify against the DiT attention code).
        repeats (`int`, *optional*, defaults to 2):
            The number of times the codec embeddings are repeated.
        num_embeds (`int`, *optional*, defaults to 8193):
            The number of unique embeddings in the codec.
        mel_dim (`int`, *optional*, defaults to 80):
            The dimension of the mel-spectrogram.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout rate for the transformer blocks.

        enc_emb_dim (`int`, *optional*, defaults to 192):
            The dimension of the pre-trained speaker embedding.
        enc_dim (`int`, *optional*, defaults to 128):
            The dimension of the encoder output.
        enc_channels (`list[int]`, *optional*, defaults to `[256, 256, 256, 256, 768]`):
            A list of output channels for each TDNN/SERes2Net layer in the encoder.
        enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`):
            A list of kernel sizes for each layer in the encoder.
        enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`):
            A list of dilations for each layer in the encoder.
        enc_attention_channels (`int`, *optional*, defaults to 64):
            The number of attention channels in the SqueezeExcitationBlock.
        enc_res2net_scale (`int`, *optional*, defaults to 2):
            The scale of the Res2Net block in the encoder.
        enc_se_channels (`int`, *optional*, defaults to 64):
            The number of output channels after squeeze in the SqueezeExcitationBlock.
    """

    model_type = "qwen3_tts_tokenizer_v1_decoder_dit"

    def __init__(
        self,
        hidden_size=1024,
        num_hidden_layers=22,
        num_attention_heads=16,
        ff_mult=2,
        emb_dim=512,
        head_dim=64,
        rope_theta=10000.0,
        max_position_embeddings=32768,
        block_size=24,
        look_ahead_layers=[10],
        look_backward_layers=[0, 20],
        repeats=2,
        num_embeds=8193,
        mel_dim=80,
        dropout=0.1,
        enc_emb_dim=192,
        enc_dim=128,
        enc_channels=[256, 256, 256, 256, 768],
        enc_kernel_sizes=[5, 3, 3, 3, 1],
        enc_dilations=[1, 2, 3, 4, 1],
        enc_attention_channels=64,
        enc_res2net_scale=2,
        enc_se_channels=64,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.ff_mult = ff_mult
        self.emb_dim = emb_dim
        self.head_dim = head_dim
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.block_size = block_size
        self.look_ahead_layers = look_ahead_layers
        self.look_backward_layers = look_backward_layers
        self.repeats = repeats
        self.num_embeds = num_embeds
        self.mel_dim = mel_dim
        self.dropout = dropout
        self.enc_emb_dim = enc_emb_dim
        self.enc_dim = enc_dim
        self.enc_channels = enc_channels
        self.enc_kernel_sizes = enc_kernel_sizes
        self.enc_dilations = enc_dilations
        self.enc_attention_channels = enc_attention_channels
        self.enc_res2net_scale = enc_res2net_scale
        self.enc_se_channels = enc_se_channels
        super().__init__(**kwargs)
122
+
123
+
124
class Qwen3TTSTokenizerV1DecoderBigVGANConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1DecoderToken2WavBigVGAN module.
    It defines the architecture of the BigVGAN model, which is used for converting mel-spectrograms to waveforms.

    Args:
        mel_dim (`int`, *optional*, defaults to 80):
            The dimension of the mel-spectrogram.
        upsample_initial_channel (`int`, *optional*, defaults to 1536):
            The number of channels in the initial upsampling layer.
        resblock_kernel_sizes (`list[int]`, *optional*, defaults to `[3, 7, 11]`):
            A list of kernel sizes for each residual block.
        resblock_dilation_sizes (`list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
            A list of dilation sizes for each residual block.
        upsample_rates (`list[int]`, *optional*, defaults to `[5, 3, 2, 2, 2, 2]`):
            A list of upsampling rates for each upsampling layer. The total mel-frame-to-sample
            upsampling factor is their product (240 with the defaults).
        upsample_kernel_sizes (`list[int]`, *optional*, defaults to `[11, 7, 4, 4, 4, 4]`):
            A list of kernel sizes for each upsampling layer.
    """

    model_type = "qwen3_tts_tokenizer_v1_decoder_bigvgan"

    def __init__(
        self,
        mel_dim=80,
        upsample_initial_channel=1536,
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_rates=[5, 3, 2, 2, 2, 2],
        upsample_kernel_sizes=[11, 7, 4, 4, 4, 4],
        **kwargs,
    ):
        self.mel_dim = mel_dim
        self.upsample_initial_channel = upsample_initial_channel
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_kernel_sizes = upsample_kernel_sizes
        super().__init__(**kwargs)
163
+
164
+
165
class Qwen3TTSTokenizerV1DecoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3TTSTokenizerV1DecoderConfig`].

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        dit_config ([`DiT_Args`], *optional*):
            Configuration class for the Diffusion Transformer (DiT) module responsible for generating mel-spectrograms.
        bigvgan_config ([`BigVGAN_Args`], *optional*):
            Configuration class for the BigVGAN module responsible for converting mel-spectrograms to waveforms.
    """

    model_type = "qwen3_tts_tokenizer_v1_decoder"
    sub_configs = {
        "dit_config": Qwen3TTSTokenizerV1DecoderDiTConfig,
        "bigvgan_config": Qwen3TTSTokenizerV1DecoderBigVGANConfig,
    }

    def __init__(self, dit_config=None, bigvgan_config=None, **kwargs):
        # Missing sub-configs fall back to the sub-config class defaults.
        dit_kwargs = {} if dit_config is None else dit_config
        bigvgan_kwargs = {} if bigvgan_config is None else bigvgan_config
        self.dit_config = Qwen3TTSTokenizerV1DecoderDiTConfig(**dit_kwargs)
        self.bigvgan_config = Qwen3TTSTokenizerV1DecoderBigVGANConfig(**bigvgan_kwargs)
        super().__init__(**kwargs)
193
+
194
+
195
class Qwen3TTSTokenizerV1EncoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1 Encoder.

    The encoder typically takes mel-spectrogram features and produces high-level audio representations, then (optionally)
    applies an Audio-VQ module (e.g., GRVQ) to discretize continuous representations into codes.

    Args:
        n_mels (`int`, *optional*, defaults to 128):
            Number of mel bins in the input mel-spectrogram.
        n_ctx (`int`, *optional*, defaults to 1500):
            Maximum input sequence length (in frames/tokens) for the encoder.
        n_state (`int`, *optional*, defaults to 1280):
            Hidden size (model dimension) of the encoder transformer.
        n_head (`int`, *optional*, defaults to 20):
            Number of attention heads in each transformer layer.
        n_layer (`int`, *optional*, defaults to 32):
            Number of transformer layers.
        n_window (`int`, *optional*, defaults to 100):
            Window size used by the model for local attention / chunking (implementation-dependent).
        output_dim (`int`, *optional*, defaults to 3584):
            Output feature dimension produced by the encoder head (before/after projection, implementation-dependent).

        grad_checkpointing (`bool`, *optional*, defaults to `False`):
            Whether to enable gradient checkpointing to reduce memory usage during training.
        enable_mp (`bool`, *optional*, defaults to `False`):
            Whether to enable model parallel features (implementation-dependent).
        audio_sequence_parallel (`bool`, *optional*, defaults to `False`):
            Whether to enable sequence parallelism for audio branch (implementation-dependent).

        audio_vq_type (`str`, *optional*, defaults to `"GRVQ"`):
            Type of audio vector-quantization module. Common choices: `"GRVQ"`, `"RVQ"`, etc.
        audio_vq_layers (`int`, *optional*, defaults to 6):
            Number of VQ layers / quantizers (e.g., number of residual quantizers for RVQ/GRVQ-like designs).
        audio_vq_codebook_size (`int`, *optional*, defaults to 32768):
            Size of each codebook (number of entries).
        audio_vq_codebook_dim (`int`, *optional*, defaults to 1280):
            Dimension of codebook vectors (often equals encoder hidden size).
        audio_vq_pe (`bool`, *optional*, defaults to `True`):
            Whether to use positional encoding (or position embeddings) inside the VQ module.
        audio_vq_ds_rate (`int`, *optional*, defaults to 2):
            Downsampling rate applied before VQ (e.g., temporal downsample factor).
    """

    model_type = "qwen3_tts_tokenizer_v1_encoder"

    def __init__(
        self,
        n_mels=128,
        n_ctx=1500,
        n_state=1280,
        n_head=20,
        n_layer=32,
        n_window=100,
        output_dim=3584,
        grad_checkpointing=False,
        enable_mp=False,
        audio_sequence_parallel=False,
        audio_vq_type="GRVQ",
        audio_vq_layers=6,
        audio_vq_codebook_size=32768,
        audio_vq_codebook_dim=1280,
        audio_vq_pe=True,
        audio_vq_ds_rate=2,
        **kwargs,
    ):
        # NOTE(review): unlike the sibling decoder/top-level configs, this class calls
        # super().__init__ *before* setting its own attributes — harmless, but inconsistent.
        super().__init__(**kwargs)
        # Whisper-style transformer encoder hyperparameters.
        self.n_mels = n_mels
        self.n_ctx = n_ctx
        self.n_state = n_state
        self.n_head = n_head
        self.n_layer = n_layer
        self.n_window = n_window
        self.output_dim = output_dim
        # Training / parallelism switches.
        self.grad_checkpointing = grad_checkpointing
        self.enable_mp = enable_mp
        self.audio_sequence_parallel = audio_sequence_parallel
        # Vector-quantization settings for discretizing encoder outputs.
        self.audio_vq_type = audio_vq_type
        self.audio_vq_layers = audio_vq_layers
        self.audio_vq_codebook_size = audio_vq_codebook_size
        self.audio_vq_codebook_dim = audio_vq_codebook_dim
        self.audio_vq_pe = audio_vq_pe
        self.audio_vq_ds_rate = audio_vq_ds_rate
278
+
279
+
280
class Qwen3TTSTokenizerV1Config(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`Qwen3TTSTokenizerV1Config`]. It is used to
    instantiate a Qwen3TTSTokenizerV1Model according to the specified sub-model configurations, defining the model
    architecture. It also records the sample rates and resampling factors relating waveforms to codec frames.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        encoder_config (`dict`, *optional*): Configuration of the underlying encoder sub-model.
        decoder_config (`dict`, *optional*): Configuration of the underlying decoder sub-model.
        input_sample_rate (`int`, *optional*, defaults to 24000): Sample rate of audio fed to the encoder.
        output_sample_rate (`int`, *optional*, defaults to 24000): Sample rate of decoded audio.
        decode_upsample_rate (`int`, *optional*, defaults to 1920): Decoder upsampling factor.
        encode_downsample_rate (`int`, *optional*, defaults to 1920): Encoder downsampling factor.
    """

    model_type = "qwen3_tts_tokenizer_25hz"
    sub_configs = {
        "encoder_config": Qwen3TTSTokenizerV1EncoderConfig,
        "decoder_config": Qwen3TTSTokenizerV1DecoderConfig,
    }

    def __init__(
        self,
        encoder_config=None,
        decoder_config=None,
        input_sample_rate=24000,
        output_sample_rate=24000,
        decode_upsample_rate=1920,
        encode_downsample_rate=1920,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Missing sub-configs fall back to each sub-config class's defaults.
        if encoder_config is None:
            logger.info("encoder_config is None. Initializing encoder with default values")
            encoder_config = {}
        if decoder_config is None:
            logger.info("decoder_config is None. Initializing decoder with default values")
            decoder_config = {}

        self.encoder_config = Qwen3TTSTokenizerV1EncoderConfig(**encoder_config)
        self.decoder_config = Qwen3TTSTokenizerV1DecoderConfig(**decoder_config)

        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.decode_upsample_rate = decode_upsample_rate
        self.encode_downsample_rate = encode_downsample_rate
324
+
325
+
326
# Public API of this module, for `from ... import *` and the transformers doc tooling.
__all__ = [
    "Qwen3TTSTokenizerV1Config",
    "Qwen3TTSTokenizerV1EncoderConfig",
    "Qwen3TTSTokenizerV1DecoderConfig",
    "Qwen3TTSTokenizerV1DecoderBigVGANConfig",
    "Qwen3TTSTokenizerV1DecoderDiTConfig"
]
qwen_tts/core/tokenizer_25hz/modeling_qwen3_tts_tokenizer_v1.py ADDED
@@ -0,0 +1,1528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Qwen3TTSTokenizerV1 model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Union, List
20
+
21
+ import numpy as np
22
+ import torch
23
+ from torch import nn
24
+ from torch.nn import Parameter
25
+ from torch.nn import functional as F
26
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
27
+ from transformers.utils import ModelOutput, auto_docstring, logging
28
+ from transformers.utils.hub import cached_file
29
+
30
+ from torch.nn.utils.rnn import pad_sequence
31
+
32
+ from .vq.whisper_encoder import get_mel_audio, get_T_after_cnn
33
+ from .vq.speech_vq import WhisperEncoderVQ, XVectorExtractor
34
+
35
+ from .configuration_qwen3_tts_tokenizer_v1 import (
36
+ Qwen3TTSTokenizerV1Config,
37
+ Qwen3TTSTokenizerV1EncoderConfig,
38
+ Qwen3TTSTokenizerV1DecoderConfig,
39
+ Qwen3TTSTokenizerV1DecoderBigVGANConfig,
40
+ Qwen3TTSTokenizerV1DecoderDiTConfig
41
+ )
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+
46
@dataclass
@auto_docstring
class Qwen3TTSTokenizerV1EncoderOutput(ModelOutput):
    r"""
    audio_codes (`List[torch.LongTensor]`):
        Discrete code embeddings computed using `model.encode`, each tensor has shape (codes_length_i,).
    xvectors (`List[torch.FloatTensor]`):
        X-vector embeddings computed using `model.encode`, each tensor has shape (xvector_dim,).
    ref_mels (`List[torch.FloatTensor]`):
        Reference mel spectrogram computed using `model.encode`, each tensor has shape (mel_length_i, mel_dim,).
    """

    # Defaults are None (not mutable list literals), so instances never share state.
    audio_codes: List[torch.LongTensor] = None
    xvectors: List[torch.FloatTensor] = None
    ref_mels: List[torch.FloatTensor] = None
61
+
62
+
63
@dataclass
@auto_docstring
class Qwen3TTSTokenizerV1DecoderOutput(ModelOutput):
    r"""
    audio_values (`List[torch.FloatTensor]`):
        Decoded audio values, obtained using the decoder part of Qwen3TTSTokenizerV1.
        Each tensor has shape (segment_length_i).
    """

    # Default is None (not a mutable list literal), so instances never share state.
    audio_values: List[torch.FloatTensor] = None
73
+
74
+
75
@auto_docstring
class Qwen3TTSTokenizerV1DecoderPreTrainedModel(PreTrainedModel):
    """Base class wiring the decoder into the transformers PreTrainedModel machinery
    (weight init, attention backend selection, checkpoint loading)."""

    config: Qwen3TTSTokenizerV1DecoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    # Full-graph torch.compile is disabled for this model.
    _can_compile_fullgraph = False
    _supports_attention_backend = True
85
+
86
+
87
@auto_docstring
class Qwen3TTSTokenizerV1EncoderPreTrainedModel(PreTrainedModel):
    """Base class wiring the encoder into the transformers PreTrainedModel machinery
    (weight init, attention backend selection, checkpoint loading)."""

    config: Qwen3TTSTokenizerV1EncoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    # Full-graph torch.compile is disabled for this model.
    _can_compile_fullgraph = False
    _supports_attention_backend = True
97
+
98
+
99
class Qwen3TTSTokenizerV1DecoderDiTRotaryEmbedding(nn.Module):
    """Rotary position embedding tables for the DiT decoder.

    Precomputes inverse frequencies at construction time; `forward` returns
    per-position cos/sin tables of shape (batch, seq_len, dim), cast back to
    the input's dtype. Each frequency appears twice (interleaved) to match the
    pairwise rotation used by `apply_rotary_pos_emb`.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, dim, base=10000):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base**exponents))

    def forward(self, x):
        batch_size, seq_len = x.shape[:2]
        positions = torch.arange(seq_len, device=x.device)
        # autocast is unsupported on MPS, so fall back to the CPU device type there.
        device_type = "cpu" if x.device.type == "mps" else x.device.type
        with torch.autocast(device_type=device_type, enabled=False):
            # (seq, 1) @ (1, dim/2) -> (seq, dim/2) angle table.
            angles = positions.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float()
            # Duplicate each frequency pairwise: (seq, dim/2) -> (seq, dim).
            angles = torch.stack((angles, angles), dim=-1)
            angles = angles.reshape(*angles.shape[:-2], -1)
            # Broadcast over the batch dimension.
            angles = angles.repeat(batch_size, *([1] * angles.dim()))
            cos = angles.cos()
            sin = angles.sin()

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
122
+
123
+
124
class TimeDelayNetBlock(nn.Module):
    """One TDNN layer: a length-preserving dilated 1-D convolution (reflect
    padding) followed by ReLU.

    Input and output shapes are (batch, channels, time).
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding="same",
            padding_mode="reflect",
        )
        self.activation = nn.ReLU()

    def forward(self, hidden_states: torch.Tensor):
        conv_out = self.conv(hidden_states)
        return self.activation(conv_out)
145
+
146
+
147
class Res2NetBlock(torch.nn.Module):
    """Res2Net multi-scale block.

    Channels are split into `scale` groups: group 0 is an identity shortcut,
    group 1 passes through its own TDNN block, and every later group is summed
    with the previous group's output before its TDNN block. The per-group
    outputs are concatenated back along the channel axis.
    """

    def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
        super().__init__()

        group_in = in_channels // scale
        group_hidden = out_channels // scale

        self.blocks = nn.ModuleList(
            TimeDelayNetBlock(group_in, group_hidden, kernel_size=kernel_size, dilation=dilation)
            for _ in range(scale - 1)
        )
        self.scale = scale

    def forward(self, hidden_states):
        groups = torch.chunk(hidden_states, self.scale, dim=1)
        outputs = [groups[0]]  # identity shortcut for the first group
        previous = None
        for index, group in enumerate(groups[1:]):
            if previous is None:
                previous = self.blocks[index](group)
            else:
                # Hierarchical residual: feed in the previous group's output.
                previous = self.blocks[index](group + previous)
            outputs.append(previous)
        return torch.cat(outputs, dim=1)
179
+
180
+
181
class SqueezeExcitationBlock(nn.Module):
    """Squeeze-and-Excitation gate for 1-D feature maps.

    Averages over time to get one value per channel, runs it through a
    two-layer 1x1-conv bottleneck, and rescales the input channel-wise with
    the resulting sigmoid gate.
    """

    def __init__(self, in_channels, se_channels, out_channels):
        super().__init__()

        self.conv1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=se_channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(
            in_channels=se_channels,
            out_channels=out_channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, hidden_states):
        # Squeeze: global average over the time axis.
        gate = hidden_states.mean(dim=2, keepdim=True)
        # Excite: bottleneck MLP realized as 1x1 convolutions.
        gate = self.relu(self.conv1(gate))
        gate = self.sigmoid(self.conv2(gate))
        return hidden_states * gate
209
+
210
+
211
class AttentiveStatisticsPooling(nn.Module):
    """This class implements an attentive statistic pooling layer for each channel.
    It returns the concatenated mean and std of the input tensor.

    Input shape is (batch, channels, time); the output has shape
    (batch, channels * 2, 1) — attention-weighted mean stacked over std.
    """

    def __init__(self, channels, attention_channels=128):
        super().__init__()

        # Floor for the variance before the sqrt, keeping std finite.
        self.eps = 1e-12
        # Attention input is [features, global mean, global std] -> channels * 3.
        self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = nn.Conv1d(
            in_channels=attention_channels,
            out_channels=channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )

    def _length_to_mask(self, length, max_len=None, dtype=None, device=None):
        """Creates a binary mask for each sequence.

        Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3

        Arguments
        ---------
        length : torch.LongTensor
            Containing the length of each sequence in the batch. Must be 1D.
        max_len : int
            Max length for the mask, also the size of the second dimension.
        dtype : torch.dtype, default: None
            The dtype of the generated mask.
        device: torch.device, default: None
            The device to put the mask variable.

        Returns
        -------
        mask : tensor
            The binary mask.
        """

        if max_len is None:
            max_len = length.max().long().item()  # using arange to generate mask
        mask = torch.arange(max_len, device=length.device, dtype=length.dtype).expand(
            len(length), max_len
        ) < length.unsqueeze(1)

        mask = torch.as_tensor(mask, dtype=dtype, device=device)
        return mask

    def _compute_statistics(self, x, m, dim=2):
        # Weighted mean/std along `dim`; `m` is expected to sum to 1 along that dim.
        mean = (m * x).sum(dim)
        std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(self.eps))
        return mean, std

    def forward(self, hidden_states):
        seq_length = hidden_states.shape[-1]
        # All sequences are treated as full length here (lengths are all ones),
        # so the mask below is all-True; it is kept for the general masked path.
        lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device)

        # Make binary mask of shape [N, 1, L]
        mask = self._length_to_mask(
            lengths * seq_length, max_len=seq_length, dtype=hidden_states.dtype, device=hidden_states.device
        )
        mask = mask.unsqueeze(1)

        # Expand the temporal context of the pooling layer by allowing the
        # self-attention to look at global properties of the utterance.
        total = mask.sum(dim=2, keepdim=True)

        mean, std = self._compute_statistics(hidden_states, mask / total)
        mean = mean.unsqueeze(2).repeat(1, 1, seq_length)
        std = std.unsqueeze(2).repeat(1, 1, seq_length)
        attention = torch.cat([hidden_states, mean, std], dim=1)

        # Apply layers
        attention = self.conv(self.tanh(self.tdnn(attention)))

        # Filter out zero-paddings
        attention = attention.masked_fill(mask == 0, float("-inf"))

        attention = F.softmax(attention, dim=2)
        mean, std = self._compute_statistics(hidden_states, attention)
        # Append mean and std of the batch
        pooled_stats = torch.cat((mean, std), dim=1)
        pooled_stats = pooled_stats.unsqueeze(2)

        return pooled_stats
298
+
299
+
300
class SqueezeExcitationRes2NetBlock(nn.Module):
    """An implementation of building block in ECAPA-TDNN, i.e.,
    TDNN-Res2Net-TDNN-SqueezeExcitationBlock, wrapped in a residual connection.

    Note: `kernel_size` and `dilation` only affect the Res2Net stage; the
    surrounding TDNN layers are fixed 1x1 convolutions.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        res2net_scale=8,
        se_channels=128,
        kernel_size=1,
        dilation=1,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.tdnn1 = TimeDelayNetBlock(in_channels, out_channels, kernel_size=1, dilation=1)
        self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation)
        self.tdnn2 = TimeDelayNetBlock(out_channels, out_channels, kernel_size=1, dilation=1)
        self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels)

    def forward(self, hidden_state):
        residual = hidden_state
        out = self.tdnn1(hidden_state)
        out = self.res2net_block(out)
        out = self.tdnn2(out)
        out = self.se_block(out)
        return out + residual
340
+
341
+
342
class ECAPA_TimeDelayNet(torch.nn.Module):
    """An implementation of the speaker embedding model in a paper.
    "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
    TDNN Based Speaker Verification" (https://huggingface.co/papers/2005.07143).

    Maps a mel spectrogram (batch, time, mel_dim) to a fixed-size speaker
    embedding of shape (batch, enc_dim).
    """

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig):
        super().__init__()
        channels = config.enc_channels
        kernels = config.enc_kernel_sizes
        dilations = config.enc_dilations
        if not (len(channels) == len(kernels) == len(dilations)):
            raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length")
        self.channels = channels
        self.blocks = nn.ModuleList()

        # The initial TDNN layer operating directly on mel features.
        self.blocks.append(TimeDelayNetBlock(config.mel_dim, channels[0], kernels[0], dilations[0]))

        # SE-Res2Net layers (all but the first and last configured entries).
        for i in range(1, len(channels) - 1):
            self.blocks.append(
                SqueezeExcitationRes2NetBlock(
                    channels[i - 1],
                    channels[i],
                    res2net_scale=config.enc_res2net_scale,
                    se_channels=config.enc_se_channels,
                    kernel_size=kernels[i],
                    dilation=dilations[i],
                )
            )

        # Multi-layer feature aggregation over the concatenated SE-Res2Net outputs.
        self.mfa = TimeDelayNetBlock(channels[-1], channels[-1], kernels[-1], dilations[-1])

        # Attentive statistical pooling collapses time to a single frame.
        self.asp = AttentiveStatisticsPooling(channels[-1], attention_channels=config.enc_attention_channels)

        # Final linear projection to the embedding dimension.
        self.fc = nn.Conv1d(
            in_channels=channels[-1] * 2,
            out_channels=config.enc_dim,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )

    def forward(self, hidden_states):
        # (batch, time, mel) -> (batch, mel, time); transposed once up front.
        hidden_states = hidden_states.transpose(1, 2)

        layer_outputs = []
        out = hidden_states
        for block in self.blocks:
            out = block(out)
            layer_outputs.append(out)

        # Aggregate every SE-Res2Net output (the initial TDNN output is skipped).
        out = self.mfa(torch.cat(layer_outputs[1:], dim=1))
        out = self.asp(out)
        out = self.fc(out)

        return out.squeeze(-1)
424
+
425
+
426
class DiTInputEmbedding(nn.Module):
    """Builds the DiT input sequence.

    Concatenates the noised mel frames, a speaker-encoder summary of the
    reference mel, the codec embeddings and the x-vector speaker embedding,
    then projects the result to the DiT hidden size.
    """

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig):
        super().__init__()
        self.proj = nn.Linear(
            config.mel_dim + config.enc_dim + config.enc_emb_dim + config.emb_dim,
            config.hidden_size,
        )
        self.spk_encoder = ECAPA_TimeDelayNet(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        speaker_embedding: torch.Tensor,
        condition_vector: torch.Tensor,
        code_embed: torch.Tensor,
        drop_audio_cond: Optional[bool] = False,
        # NOTE(review): this is concatenated as a tensor below, so the original
        # `Optional[bool]` annotation was wrong; it must be provided when
        # `apply_cfg` is True.
        code_embed_uncond: Optional[torch.Tensor] = None,
        apply_cfg: Optional[bool] = True,
    ):
        if apply_cfg:
            # Classifier-free guidance: stack [conditional, unconditional] along
            # the batch axis so both passes run in a single forward.
            hidden_states = torch.cat([hidden_states, hidden_states], dim=0)
            speaker_embedding = torch.cat([speaker_embedding, torch.zeros_like(speaker_embedding)], dim=0)
            condition_vector = torch.cat([condition_vector, torch.zeros_like(condition_vector)], dim=0)
            code_embed = torch.cat([code_embed, code_embed_uncond], dim=0)
        elif drop_audio_cond:  # cfg for cond audio
            condition_vector = torch.zeros_like(condition_vector)
            speaker_embedding = torch.zeros_like(speaker_embedding)
        # Collapse the reference mel into one speaker vector and broadcast it
        # across the sequence before the joint projection.
        condition_vector = self.spk_encoder(condition_vector).unsqueeze(1).repeat(1, hidden_states.size(1), 1)
        hidden_states = self.proj(torch.cat((hidden_states, condition_vector, code_embed, speaker_embedding), dim=-1))

        return hidden_states
457
+
458
+
459
# Codec-id embedding used by the DiT backbone.
class DiTCodecEmbedding(nn.Module):
    """Embeds discrete codec ids and upsamples them along time.

    Each embedded code is repeated `repeats` times so the code frame rate
    matches the mel frame rate. When `drop_code` is set, all ids are replaced
    by 0 (the embedding table has `codec_num_embeds + 1` entries, so index 0
    serves as the dropped/unconditional code).
    """

    def __init__(self, codec_num_embeds, codec_dim, repeats):
        super().__init__()
        self.repeats = repeats
        self.codec_embed = nn.Embedding(codec_num_embeds + 1, codec_dim)

    def forward(self, code, drop_code=False):
        lookup = torch.zeros_like(code) if drop_code else code
        embedded = self.codec_embed(lookup)
        # Upsample along the time axis to the mel frame rate.
        return torch.repeat_interleave(embedded, repeats=self.repeats, dim=1)
473
+
474
+
475
# Adaptive LayerNorm (adaLN-Zero) conditioning:
# returns the modulated input for attention plus the gates/shift/scale for the MLP.
class AdaLayerNormZero(nn.Module):
    """Predicts six modulation tensors from a conditioning embedding.

    The returned hidden states are already normalized, scaled and shifted for
    the attention input; the attention gate and the MLP shift/scale/gate are
    handed back to the caller for later use in the block.
    """

    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 6)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, hidden_states, emb=None):
        modulation = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=1)

        modulated = self.norm(hidden_states) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
492
+
493
+
494
# Adaptive LayerNorm for the final DiT layer:
# only scale/shift are predicted (no gates), since no MLP modulation follows.
class AdaLayerNormZero_Final(nn.Module):
    """Applies conditioning-driven scale and shift after a parameter-free LayerNorm."""

    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 2)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, hidden_states, emb):
        # NOTE: the projection is split as (scale, shift) — in that order.
        scale, shift = self.linear(self.silu(emb)).chunk(2, dim=1)
        return self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
511
+
512
+
513
# Position-wise feed-forward network.
class DiTMLP(nn.Module):
    """Expand -> GELU(tanh) -> dropout -> project back.

    Implemented as an indexed Sequential so checkpoint keys stay `ff.0`/`ff.3`,
    identical to the original ModuleList layout.
    """

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)

        self.ff = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(approximate="tanh"),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
        )

    def forward(self, hidden_states):
        return self.ff(hidden_states)
532
+
533
+
534
# Modified from Llama with a different rotate function, will fixed in next release
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Uses the interleaved-pair rotation (..., x1, x2, ...) -> (..., -x2, x1, ...)
    rather than Llama's half-split rotation.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding, shape
            [batch_size, seq_len, head_dim].
        sin (`torch.Tensor`): The sine part of the rotary embedding, same shape as `cos`.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos`/`sin` are unsqueezed so they broadcast
            against q/k: 1 for [batch, heads, seq, head_dim] layouts, 2 for
            [batch, seq, heads, head_dim] layouts.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """

    def _rotate_pairs(t):
        # (..., x1, x2) pairs -> (..., -x2, x1), interleaved back into one axis.
        pairs = t.reshape(*t.shape[:-1], -1, 2)
        first, second = pairs.unbind(dim=-1)
        rotated = torch.stack((-second, first), dim=-1)
        return rotated.reshape(*rotated.shape[:-2], -1)

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = q * cos + _rotate_pairs(q) * sin
    k_embed = k * cos + _rotate_pairs(k) * sin
    return q_embed, k_embed
568
+
569
+
570
class DiTAttention(nn.Module):
    """Non-causal self-attention for DiT blocks with interleaved RoPE.

    The attention backend is looked up from the transformers registry via the
    config's `_attn_implementation`.
    """

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig):
        super().__init__()

        self.config = config
        self.dim = config.hidden_size
        self.heads = config.num_attention_heads
        self.inner_dim = config.head_dim * config.num_attention_heads
        self.dropout = config.dropout
        self.is_causal = False

        self.to_q = nn.Linear(config.hidden_size, self.inner_dim)
        self.to_k = nn.Linear(config.hidden_size, self.inner_dim)
        self.to_v = nn.Linear(config.hidden_size, self.inner_dim)

        self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)])

    def forward(
        self,
        hidden_states,  # noised input x
        position_embeddings=None,  # rotary position embedding for x
        attention_mask=None,
    ) -> torch.Tensor:
        batch_size = hidden_states.shape[0]
        head_dim = self.inner_dim // self.heads

        def _split_heads(t):
            # (batch, seq, inner) -> (batch, heads, seq, head_dim)
            return t.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

        query = _split_heads(self.to_q(hidden_states))
        key = _split_heads(self.to_k(hidden_states))
        value = _split_heads(self.to_v(hidden_states))

        # Apply rotary position embedding.
        # NOTE (upstream): due to the training process, only the first head is
        # meant to get RoPE; to be fixed in the next release.
        cos, sin = position_embeddings
        query, key = apply_rotary_pos_emb(query, key, cos, sin)

        # Backend chosen via the transformers attention registry (sdpa/flash/eager...).
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attn_output, _ = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask=attention_mask,
            is_causal=False,
        )

        # Merge heads back and restore the query dtype (backends may upcast).
        attn_output = attn_output.reshape(batch_size, -1, self.heads * head_dim)
        attn_output = attn_output.to(query.dtype)

        # Output projection followed by dropout.
        attn_output = self.to_out[0](attn_output)
        attn_output = self.to_out[1](attn_output)

        return attn_output
631
+
632
+
633
# Sinusoidal embedding for (continuous) diffusion timesteps.
class SinusPositionEmbedding(nn.Module):
    """Maps a (batch,) tensor to (batch, dim): sines in the first half,
    cosines in the second, over geometrically spaced frequencies."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, hidden_states, scale=1000):
        half_dim = self.dim // 2
        log_step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(-log_step * torch.arange(half_dim, device=hidden_states.device).float())
        angles = scale * hidden_states.unsqueeze(1) * freqs.unsqueeze(0)
        table = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return table.type_as(hidden_states)
647
+
648
+
649
class DiTTimestepEmbedding(nn.Module):
    """Embeds a diffusion timestep: sinusoidal features followed by a small MLP.

    Implemented as an indexed Sequential so checkpoint keys stay
    `time_mlp.0`/`time_mlp.2`, identical to the original ModuleList layout.
    """

    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, timestep):
        features = self.time_embed(timestep).to(timestep.dtype)
        return self.time_mlp(features)  # (batch, dim)
661
+
662
+
663
class DiTDecoderLayer(nn.Module):
    """One DiT transformer block: AdaLN-Zero modulated self-attention plus a
    gated feed-forward, both conditioned on the timestep embedding.

    ``look_ahead_block`` / ``look_backward_block`` bound, in units of attention
    blocks, how far attention may look forward / backward; the bound is applied
    through the boolean mask derived from ``block_diff`` in ``forward``.
    """

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderDiTConfig, look_ahead_block=0, look_backward_block=0):
        super().__init__()
        # AdaLN-Zero: returns the normalized input plus timestep-conditioned
        # gate/shift/scale modulation parameters.
        self.attn_norm = AdaLayerNormZero(config.hidden_size)

        self.attn = DiTAttention(config)
        self.look_ahead_block = look_ahead_block
        self.look_backward_block = look_backward_block
        self.ff_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6)
        self.ff = DiTMLP(dim=config.hidden_size, mult=config.ff_mult, dropout=config.dropout)

    def forward(
        self, hidden_states, timestep, position_embeddings=None, block_diff=None
    ):  # x: noised input, t: time embedding
        # pre-norm & modulation for attention input
        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(hidden_states, emb=timestep)

        # Attention restricted to key blocks within
        # [-look_backward_block, +look_ahead_block] of the query's block
        # (block_diff = key_block - query_block).
        attn_output = self.attn(
            hidden_states=norm,
            position_embeddings=position_embeddings,
            attention_mask=(block_diff >= -float(self.look_backward_block))
            & (block_diff <= float(self.look_ahead_block)),
        )

        # process attention output for input x (gated residual connection)
        hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output

        # Feed-forward path with its own shift/scale modulation and gate.
        norm = self.ff_norm(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        ff_output = self.ff(norm)
        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output

        return hidden_states
696
+
697
+
698
class SnakeBeta(nn.Module):
    """Snake activation with separate learnable frequency and magnitude.

    Computes ``x + (1 / b) * sin^2(a * x)`` per channel, where the per-channel
    parameters are stored in log scale (``a = exp(alpha)``, ``b = exp(beta)``).

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - Modified from the Snake activation of Liu Ziyin, Tilman Hartwig,
          Masahito Ueda: https://huggingface.co/papers/2006.08195
    """

    def __init__(self, in_features, alpha=1.0):
        super().__init__()
        self.in_features = in_features

        # Log-scale parameters initialized to zero so exp(.) starts at 1.
        # (Multiplying a zero tensor by `alpha` is a numeric no-op; the form
        # is kept for parity with the reference implementation.)
        self.alpha = Parameter(torch.zeros(in_features) * alpha)
        self.beta = Parameter(torch.zeros(in_features) * alpha)

        self.no_div_by_zero = 0.000000001

    def forward(self, hidden_states):
        """Apply SnakeBeta elementwise, broadcasting the per-channel
        parameters over the (B, C, T) input."""
        # Align the (C,) parameters with the (B, C, T) input.
        frequency = torch.exp(self.alpha.unsqueeze(0).unsqueeze(-1))
        magnitude = torch.exp(self.beta.unsqueeze(0).unsqueeze(-1))
        periodic = torch.pow(torch.sin(hidden_states * frequency), 2)
        return hidden_states + (1.0 / (magnitude + self.no_div_by_zero)) * periodic
737
+
738
+
739
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
    """Generates a 1D Kaiser-windowed sinc low-pass filter.

    Args:
        cutoff (float): Normalized cutoff frequency (0 to 0.5).
        half_width (float): Transition bandwidth.
        kernel_size (int): Number of filter taps.

    Returns:
        torch.Tensor: Filter taps of shape (1, 1, kernel_size), normalized to
        sum to 1; all zeros when ``cutoff == 0``.
    """
    half_size = kernel_size // 2

    # Kaiser's empirical formula: beta from the estimated attenuation.
    delta_f = 4 * half_width
    attenuation = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if attenuation > 50.0:
        beta = 0.1102 * (attenuation - 8.7)
    elif attenuation >= 21.0:
        beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0)
    else:
        beta = 0.0

    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32)

    # Tap positions centred on zero (half-sample offset for even sizes).
    if kernel_size % 2 == 0:
        time_indices = torch.arange(-half_size, half_size) + 0.5
    else:
        time_indices = torch.arange(kernel_size) - half_size

    if cutoff == 0:
        # Degenerate low-pass passes nothing; keep the (1, 1, K) shape.
        return torch.zeros((1, 1, kernel_size), dtype=torch.float32)

    taps = 2 * cutoff * window * torch.sinc(2 * cutoff * time_indices)
    # Normalize so the DC gain is exactly 1 (no leakage of the constant component).
    taps = taps / taps.sum()
    return taps.view(1, 1, kernel_size)
783
+
784
+
785
class UpSample1d(nn.Module):
    """Anti-aliased 1D upsampling by an integer ``ratio``.

    Upsamples with a transposed depthwise convolution whose kernel is a
    Kaiser-windowed sinc low-pass filter, then trims the transposed-conv
    boundary so the output length is exactly ``ratio`` times the input length.
    """

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        # Default: roughly 6 taps per unit of ratio, rounded to an even size.
        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
        self.stride = ratio
        self.pad = self.kernel_size // ratio - 1
        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
        self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2

        filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
        self.register_buffer("filter", filter, persistent=False)

    def forward(self, hidden_states):
        """Upsample; shape (B, C, T) -> (B, C, ratio * T)."""
        num_channels = hidden_states.shape[1]

        # Replicate-pad so the filter sees valid context at the edges, apply
        # the depthwise transposed convolution, then trim to the exact length.
        padded = F.pad(hidden_states, (self.pad, self.pad), mode="replicate")
        upsampled = self.ratio * F.conv_transpose1d(
            padded, self.filter.expand(num_channels, -1, -1), stride=self.stride, groups=num_channels
        )
        return upsampled[..., self.pad_left : -self.pad_right]
808
+
809
+
810
class DownSample1d(nn.Module):
    """Anti-aliased 1D downsampling by an integer ``ratio``.

    Low-pass filters with a Kaiser-windowed sinc kernel and decimates via a
    strided depthwise convolution.
    """

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        cutoff = 0.5 / ratio
        half_width = 0.6 / ratio

        if cutoff < 0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")

        # Bug fix: the documented default kernel_size=None previously crashed
        # on `kernel_size % 2`. Derive the same default as UpSample1d
        # (roughly 6 taps per unit of ratio, rounded to an even size).
        kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size

        self.even = kernel_size % 2 == 0
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = ratio
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter, persistent=False)

    def forward(self, hidden_states):
        """Filter and decimate; shape (B, C, T) -> (B, C, ~T / ratio)."""
        channels = hidden_states.shape[1]
        # Replicate-pad so edge samples get full filter support.
        hidden_states = F.pad(hidden_states, (self.pad_left, self.pad_right), mode="replicate")
        out = F.conv1d(hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels)
        return out
833
+
834
+
835
class TorchActivation1d(nn.Module):
    """Alias-free activation wrapper: upsample -> activation -> downsample.

    Running the nonlinearity at a higher sample rate and band-limiting the
    result suppresses the aliasing the nonlinearity would otherwise introduce.
    """

    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        if not callable(activation):
            raise TypeError("Activation function must be callable")
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, hidden_states):
        """Apply the activation at ``up_ratio`` times the input sample rate."""
        return self.downsample(self.act(self.upsample(hidden_states)))
857
+
858
+
859
class CausalConv1d(nn.Conv1d):
    """1D convolution made causal by zero-padding only on the left.

    The left pad of ``dilation * (kernel_size - 1)`` guarantees no output
    sample depends on future inputs, and (at stride 1) the output length
    equals the input length.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.causal_padding = self.dilation[0] * (self.kernel_size[0] - 1)

    def forward(self, x):
        padded = F.pad(x, [self.causal_padding, 0])
        return self._conv_forward(padded, self.weight, self.bias)
866
+
867
+
868
class AMPBlock(torch.nn.Module):
    """BigVGAN-style anti-aliased multi-periodicity (AMP) residual block.

    Pairs three dilated convolutions (`convs1`) with three dilation-1
    convolutions (`convs2`); every convolution is preceded by an alias-free
    SnakeBeta activation, and each conv1/conv2 pair contributes a residual
    added back onto the input.

    Args:
        channels: number of input/output channels.
        kernel_size: kernel size shared by all convolutions.
        dilation: dilations for the three convolutions in `convs1`.
        causal_type: '1' -> causal `convs1` with symmetric-padded `convs2`;
            any other value -> fully causal `convs2` as well, and '2' adds an
            extra symmetric pre-conv + activation before the residual stack.
    """

    def __init__(
        self,
        channels,
        kernel_size=3,
        dilation=(1, 3, 5),
        causal_type='1',
    ):
        super().__init__()

        # Dilated causal convolutions, one per dilation value.
        self.convs1 = nn.ModuleList(
            [
                CausalConv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[0],
                ),
                CausalConv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[1],
                ),
                CausalConv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=dilation[2],
                ),
            ]
        )

        if causal_type == '1':
            # Dilation-1 convolutions with symmetric ("same") padding.
            self.convs2 = nn.ModuleList(
                [
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=self._get_padding(kernel_size, 1),
                    ),
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=self._get_padding(kernel_size, 1),
                    ),
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=self._get_padding(kernel_size, 1),
                    ),
                ]
            )
        else:
            # Fully causal variant of the dilation-1 convolutions.
            self.convs2 = nn.ModuleList(
                [
                    CausalConv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                    ),
                    CausalConv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                    ),
                    CausalConv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                    ),
                ]
            )

        self.num_layers = len(self.convs1) + len(self.convs2)  # total number of conv layers

        # One alias-free SnakeBeta activation per convolution.
        self.activations = nn.ModuleList(
            [TorchActivation1d(activation=SnakeBeta(channels)) for _ in range(self.num_layers)]
        )

        if causal_type == '2':
            # Extra symmetric conv + activation applied before the residual stack.
            self.pre_conv = nn.Conv1d(
                channels,
                channels,
                kernel_size,
                stride=1,
                padding=self._get_padding(kernel_size, 1),
            )
            self.pre_act = TorchActivation1d(activation=SnakeBeta(channels))
        else:
            self.pre_conv = nn.Identity()
            self.pre_act = nn.Identity()

    def _get_padding(self, kernel_size, dilation=1):
        # "Same" padding for stride-1 convolutions with odd kernel sizes.
        return int((kernel_size * dilation - dilation) / 2)

    def forward(self, x):
        """Apply the residual AMP stack; shape (B, C, T) -> (B, C, T)."""
        hidden_states = self.pre_conv(x)
        hidden_states = self.pre_act(hidden_states)
        # Even-indexed activations precede convs1, odd-indexed precede convs2.
        acts1, acts2 = self.activations[::2], self.activations[1::2]
        for conv1, conv2, act1, act2 in zip(self.convs1, self.convs2, acts1, acts2):
            hidden_states = act1(hidden_states)
            hidden_states = conv1(hidden_states)
            hidden_states = act2(hidden_states)
            hidden_states = conv2(hidden_states)
            # Residual accumulates after every conv pair.
            x = x + hidden_states
        return x
993
+
994
+
995
@auto_docstring
class Qwen3TTSTokenizerV1DecoderBigVGANModel(Qwen3TTSTokenizerV1DecoderPreTrainedModel):
    """BigVGAN-style vocoder: log-mel spectrogram in, waveform in [-1, 1] out."""

    config: Qwen3TTSTokenizerV1DecoderBigVGANConfig

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig):
        super().__init__(config)
        self.num_residual_blocks = len(config.resblock_kernel_sizes)
        self.num_upsample_layers = len(config.upsample_rates)

        self.conv_pre = nn.Conv1d(config.mel_dim, config.upsample_initial_channel, 5, 1, padding=2)

        # Removing extra ModuleList breaks official state dict
        ups = [
            nn.ModuleList(
                [
                    nn.ConvTranspose1d(
                        config.upsample_initial_channel // (2**layer_idx),
                        config.upsample_initial_channel // (2 ** (layer_idx + 1)),
                        kernel_size,
                        stride,
                        padding=(kernel_size - stride) // 2,
                    )
                ]
            )
            for layer_idx, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes))
        ]
        self.ups = nn.ModuleList(ups)

        # num_residual_blocks AMP blocks per upsampling stage; the first two
        # stages use causal_type '2', deeper stages use '1'.
        self.resblocks = nn.ModuleList(
            [
                AMPBlock(config.upsample_initial_channel // (2 ** (layer_idx + 1)), kernel_size, dilation, '1' if layer_idx > 1 else '2')
                for layer_idx in range(self.num_upsample_layers)
                for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes)
            ]
        )

        self.activation_post = TorchActivation1d(
            activation=SnakeBeta(config.upsample_initial_channel // (2**self.num_upsample_layers))
        )
        self.conv_post = nn.Conv1d(
            config.upsample_initial_channel // (2**self.num_upsample_layers), 1, 7, 1, padding=3, bias=False
        )

    def normalize_spectrogram(self, spectrogram, max_value, min_db):
        """Linearly map [min_db, 0] dB to [-max_value, max_value], clamped."""
        return torch.clamp((2 * max_value) * ((spectrogram - min_db) / (-min_db)) - max_value, -max_value, max_value)

    def amplitude_to_db(self, amplitude, min_db_level):
        """Convert amplitude to decibels, flooring at `min_db_level` dB."""
        min_level = torch.exp(
            torch.tensor(min_db_level / 20.0 * np.log(10), device=amplitude.device, dtype=amplitude.dtype)
        )
        return 20 * torch.log10(torch.clamp(amplitude, min=min_level))

    def process_mel_spectrogram(self, mel_spectrogram):
        """Convert a log-amplitude mel spectrogram to the normalized dB
        representation the vocoder was trained on (-20 dB reference,
        -115 dB floor)."""
        amplitude_spectrum = torch.exp(mel_spectrogram)
        decibel_spectrum = self.amplitude_to_db(amplitude_spectrum, -115) - 20
        return self.normalize_spectrogram(decibel_spectrum, 1, -115)

    def forward(self, mel_spectrogram):
        """Vocode a (B, mel_dim, T) spectrogram into a (B, T * prod(upsample_rates)) waveform."""
        processed_spectrogram = self.process_mel_spectrogram(mel_spectrogram)
        hidden_representation = self.conv_pre(processed_spectrogram)

        for layer_index in range(self.num_upsample_layers):
            # Upsample, then average the outputs of this stage's AMP blocks.
            hidden_representation = self.ups[layer_index][0](hidden_representation)
            residual_output = sum(
                self.resblocks[layer_index * self.num_residual_blocks + block_index](hidden_representation)
                for block_index in range(self.num_residual_blocks)
            )
            residual_output = residual_output / self.num_residual_blocks
            hidden_representation = residual_output

        hidden_representation = self.activation_post(hidden_representation)
        output_waveform = self.conv_post(hidden_representation)
        # Clamp to valid audio range and drop the singleton channel dim.
        return torch.clamp(output_waveform, min=-1.0, max=1.0).squeeze(1)
1068
+
1069
+
1070
@auto_docstring
class Qwen3TTSTokenizerV1DecoderDiTModel(Qwen3TTSTokenizerV1DecoderPreTrainedModel):
    """Flow-matching DiT that converts discrete codec codes plus speaker and
    reference-mel conditioning into a mel spectrogram."""

    config: Qwen3TTSTokenizerV1DecoderDiTConfig
    _no_split_modules = ["DiTDecoderLayer"]

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderDiTConfig):
        super().__init__(config)
        self.mel_dim = config.mel_dim
        # Number of mel frames produced per codec code.
        self.repeats = config.repeats
        self.time_embed = DiTTimestepEmbedding(config.hidden_size)

        self.text_embed = DiTCodecEmbedding(config.num_embeds, config.emb_dim, config.repeats)
        self.input_embed = DiTInputEmbedding(config)

        self.rotary_embed = Qwen3TTSTokenizerV1DecoderDiTRotaryEmbedding(config.head_dim)

        self.hidden_size = config.hidden_size
        self.layers = config.num_hidden_layers
        self.block_size = config.block_size
        self.num_attention_heads = config.num_attention_heads

        # Selected layers get one block of look-ahead / look-backward context.
        self.transformer_blocks = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            self.transformer_blocks.append(
                DiTDecoderLayer(
                    config,
                    look_ahead_block=1 if i in config.look_ahead_layers else 0,
                    look_backward_block=1 if i in config.look_backward_layers else 0,
                )
            )

        self.norm_out = AdaLayerNormZero_Final(config.hidden_size)  # final modulation
        self.proj_out = nn.Linear(config.hidden_size, config.mel_dim)

    def _create_block_diff(self, hidden_states):
        """Return key_block - query_block for every position pair, expanded to
        (batch, heads, seq, seq); consumed by the layers' block-local mask."""
        batch, seq_len = hidden_states.shape[0], hidden_states.shape[1]
        block_indices = torch.arange(seq_len, device=hidden_states.device) // self.block_size  # [seq_length]

        block_i = block_indices.unsqueeze(1)  # [seq_length, 1]
        block_j = block_indices.unsqueeze(0)  # [1, seq_length]
        block_diff = block_j - block_i  # (n, n)

        return block_diff.expand(batch, self.num_attention_heads, seq_len, seq_len)

    def forward(
        self,
        hidden_states,
        condition_vector,
        speaker_embedding,
        quantized_code,
        time_step,
        drop_audio_conditioning=False,
        drop_code=False,
        apply_cfg=True,
    ):
        """Predict the flow velocity for the noised input `hidden_states`.

        With `apply_cfg=True` the batch is doubled (conditional + unconditional
        halves) inside `input_embed`; the caller is expected to chunk the output.
        """
        # Timestep embedding covers the CFG-doubled batch.
        # NOTE(review): this doubling happens even when apply_cfg=False —
        # presumably callers always use CFG; confirm against DiTInputEmbedding.
        batch_size = hidden_states.shape[0] * 2
        if time_step.ndim == 0:
            time_step = time_step.repeat(batch_size)

        # Compute embeddings
        time_embedding = self.time_embed(time_step)
        text_embedding = self.text_embed(quantized_code, drop_code=False if apply_cfg else drop_code)
        # Unconditional code embedding for the CFG branch.
        text_embedding_unconditioned = self.text_embed(quantized_code, drop_code=True) if apply_cfg else None

        hidden_states = self.input_embed(
            hidden_states,
            speaker_embedding,
            condition_vector,
            text_embedding,
            drop_audio_cond=drop_audio_conditioning,
            code_embed_uncond=text_embedding_unconditioned,
            apply_cfg=apply_cfg,
        )

        # Compute positional encodings
        position_embeddings = self.rotary_embed(hidden_states)
        blockwise_difference = self._create_block_diff(hidden_states)

        # Transformer blocks
        for transformer_block in self.transformer_blocks:
            hidden_states = transformer_block(
                hidden_states,
                time_embedding,
                position_embeddings=position_embeddings,
                block_diff=blockwise_difference,
            )

        hidden_states = self.norm_out(hidden_states, time_embedding)
        output = self.proj_out(hidden_states)

        return output

    def optimized_scale(self, positive_flat, negative_flat):
        """Per-sample projection scale of the conditional onto the unconditional
        direction (not used by `sample`; kept for experimentation)."""
        # Calculate dot production
        dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
        # Squared norm of uncondition
        squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
        # st_star = v_cond^T * v_uncond / ||v_uncond||^2
        st_star = dot_product / squared_norm
        return st_star

    @torch.no_grad()
    def sample(
        self,
        conditioning_vector,
        reference_mel_spectrogram,
        quantized_code,
        num_steps=10,
        guidance_scale=0.5,
        sway_coefficient=-1.0,
    ):
        """Generate a mel spectrogram by Euler-integrating the learned flow.

        Args:
            conditioning_vector: speaker embedding, shape (B, D).
            reference_mel_spectrogram: reference mel for timbre conditioning.
            quantized_code: codec codes, shape (B, T_code); output covers
                T_code * repeats mel frames (capped at 30000).
            num_steps: Euler integration steps.
            guidance_scale: classifier-free guidance weight (CFG skipped below 1e-5).
            sway_coefficient: sway-sampling warp of the time grid (None disables).

        Returns:
            Mel spectrogram of shape (B, mel_dim, T_code * repeats).
        """
        noise_initialization = torch.randn([quantized_code.shape[0], 30000, self.mel_dim], dtype=reference_mel_spectrogram.dtype)
        maximum_duration = quantized_code.shape[1] * self.repeats
        initial_state = noise_initialization[:, :maximum_duration].to(quantized_code.device)
        # Broadcast the speaker embedding across all mel frames.
        conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, maximum_duration, 1)

        def ode_function(time_step, hidden_states):
            if guidance_scale < 1e-5:
                # No CFG: single conditional forward pass.
                prediction = self(
                    hidden_states=hidden_states,
                    speaker_embedding=conditioning_vector,
                    condition_vector=reference_mel_spectrogram,
                    quantized_code=quantized_code,
                    time_step=time_step,
                    drop_audio_conditioning=False,
                    drop_code=False,
                )
                return prediction

            # CFG: forward returns stacked [conditional; unconditional] halves.
            model_output = self(
                hidden_states=hidden_states,
                quantized_code=quantized_code,
                speaker_embedding=conditioning_vector,
                condition_vector=reference_mel_spectrogram,
                time_step=time_step,
                apply_cfg=True,
            )
            guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0)

            return guided_prediction + (guided_prediction - null_prediction) * guidance_scale

        initial_time = 0
        time_embedding = torch.linspace(
            initial_time, 1, num_steps, device=quantized_code.device, dtype=conditioning_vector.dtype
        )

        # Sway sampling: warp the time grid toward small t (more early steps).
        if sway_coefficient is not None:
            time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding)

        # Explicit Euler integration from t=0 (noise) to t=1 (data).
        values = initial_state.clone()
        for t0, t1 in zip(time_embedding[:-1], time_embedding[1:]):
            dt = t1 - t0
            vt = ode_function(t0, values)
            values = values + vt * dt

        # (B, T, mel_dim) -> (B, mel_dim, T) for the vocoder.
        generated_mel_spectrogram = values.permute(0, 2, 1)
        return generated_mel_spectrogram
1227
+
1228
+
1229
@auto_docstring
class Qwen3TTSTokenizerV1Decoder(Qwen3TTSTokenizerV1DecoderPreTrainedModel):
    """Two-stage codec decoder: a flow-matching DiT produces a mel spectrogram
    from codec codes, then a BigVGAN vocoder renders the waveform."""

    config: Qwen3TTSTokenizerV1DecoderConfig
    base_model_prefix = "model"
    _no_split_modules = ["Qwen3TTSTokenizerV1DecoderDiTModel", "Qwen3TTSTokenizerV1DecoderBigVGANModel"]

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderConfig):
        super().__init__(config)
        # This decoder runs in fp32, so flash-attention (fp16/bf16 only) and
        # eager are both downgraded to SDPA.
        attn_impl = config._attn_implementation
        if config._attn_implementation == "flash_attention_2":
            logger.warning_once(
                "Qwen3TTSTokenizerV1Decoder must inference with fp32, but flash_attention_2 only supports fp16 and bf16, "
                "attention implementation of Qwen3TTSTokenizerV1Decoder will fallback to sdpa."
            )
            attn_impl = "sdpa"
        elif config._attn_implementation == "eager":
            logger.warning_once(
                "Qwen3TTSTokenizerV1Decoder does not support eager attention implementation, fall back to sdpa"
            )
            attn_impl = "sdpa"
        self.dit = Qwen3TTSTokenizerV1DecoderDiTModel._from_config(
            config.dit_config, attn_implementation=attn_impl
        )
        self.bigvgan = Qwen3TTSTokenizerV1DecoderBigVGANModel._from_config(
            config.bigvgan_config, attn_implementation=attn_impl
        )

    def forward(
        self,
        code,
        conditioning,
        reference_mel,
        num_steps=10,
        guidance_scale=0.5,
        sway_coefficient=-1.0,
        **kwargs,
    ):
        """Generates a waveform from input code and conditioning parameters.

        Args:
            code: codec codes, shape (B, T_code).
            conditioning: speaker embedding passed to the DiT sampler.
            reference_mel: reference mel spectrogram for timbre conditioning.
            num_steps: flow-matching Euler steps.
            guidance_scale: classifier-free guidance weight.
            sway_coefficient: sway-sampling coefficient for the time grid.

        Returns:
            Waveform tensor of shape (B, num_samples), values in [-1, 1].
        """

        mel_spectrogram = self.dit.sample(
            conditioning,
            reference_mel,
            code,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            sway_coefficient=sway_coefficient,
        )

        waveform = self.bigvgan(mel_spectrogram)

        return waveform
1280
+
1281
+
1282
class Qwen3TTSTokenizerV1Encoder(Qwen3TTSTokenizerV1EncoderPreTrainedModel):
    """Speech-to-code encoder: a Whisper-style encoder with vector quantization
    that maps waveforms to discrete codec indices."""

    config: Qwen3TTSTokenizerV1EncoderConfig

    def __init__(self, config: Qwen3TTSTokenizerV1EncoderConfig):
        super().__init__(config)

        self.tokenizer = WhisperEncoderVQ(
            n_mels=config.n_mels,
            n_ctx=config.n_ctx,
            n_state=config.n_state,
            n_head=config.n_head,
            n_layer=config.n_layer,
            n_window=config.n_window,
            output_dim=config.output_dim,
            grad_checkpointing=config.grad_checkpointing,
            enable_mp=config.enable_mp,
            audio_sequence_parallel=config.audio_sequence_parallel,
            audio_vq_type=config.audio_vq_type,
            audio_vq_layers=config.audio_vq_layers,
            audio_vq_codebook_size=config.audio_vq_codebook_size,
            audio_vq_codebook_dim=config.audio_vq_codebook_dim,
            audio_vq_pe=config.audio_vq_pe,
            audio_vq_ds_rate=config.audio_vq_ds_rate,
        )

        self.padding = True
        # Temporal downsampling applied by the VQ stage.
        self.audio_vq_ds_rate = self.tokenizer.audio_vq_ds_rate

    def speech2mel(self, speechs):
        """Convert a list of 1-D waveforms to mel spectrograms, moved to the
        encoder's device and matching each waveform's dtype."""
        mels = [
            get_mel_audio(
                speech, padding = self.padding, audio_vq_ds_rate = self.audio_vq_ds_rate
            ).to(speech.dtype).to(self.tokenizer.conv1.weight.device)
            for speech in speechs
        ]
        return mels

    def mel2code(self, mels):
        """Quantize mel spectrograms into padded code indices.

        Returns:
            indices: (B, T_code_max) padded code indices (pad value 0).
            indice_lens: per-item valid code lengths.
        """
        audio_mellens = [mel.size(-1) for mel in mels]
        # Sequence lengths after the conv frontend; +2 accounts for the
        # special positions added by the encoder.
        audio_aftercnnlens = [get_T_after_cnn(T) for T in audio_mellens]
        audio_seqlens = [T + 2 for T in audio_aftercnnlens]

        with torch.no_grad():
            _, indices = self.tokenizer(
                x_list = mels,
                audio_mellens = audio_mellens,
                audio_aftercnnlens = audio_aftercnnlens,
                audio_seqlens = audio_seqlens,
                return_indices=True,
            )

        # Split the flat index stream back per item, then right-pad to a batch.
        indice_lens = [T // self.tokenizer.audio_vq_ds_rate for T in audio_aftercnnlens]
        indices = pad_sequence(torch.split(indices, indice_lens), batch_first=True, padding_value=0)

        return indices, indice_lens

    def quantize_speech(self, speechs):
        """Waveforms -> (padded code indices, per-item code lengths)."""
        mels = self.speech2mel(speechs)
        indices, indice_lens = self.mel2code(mels)
        return indices, indice_lens
1341
+
1342
+
1343
@auto_docstring
class Qwen3TTSTokenizerV1PreTrainedModel(PreTrainedModel):
    """Base class wiring the Qwen3-TTS tokenizer into the Transformers
    loading machinery (weight init, attention backends, device placement)."""

    config: Qwen3TTSTokenizerV1Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Cache tensors manage their own device placement.
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = False
    _supports_attention_backend = True
1354
+
1355
@auto_docstring(
    custom_intro="""
    The Qwen3TTSTokenizerV1 model.
    """
)
class Qwen3TTSTokenizerV1Model(Qwen3TTSTokenizerV1PreTrainedModel):
    """End-to-end neural codec: `encode` maps waveforms to discrete codes plus
    speaker x-vectors and reference mels; `decode` reconstructs waveforms."""

    def __init__(self, config: Qwen3TTSTokenizerV1Config):
        super().__init__(config)
        self.config = config

        self.input_sample_rate = config.input_sample_rate
        self.output_sample_rate = config.output_sample_rate

        # Samples-per-code ratios for decode / encode.
        self.decode_upsample_rate = config.decode_upsample_rate
        self.encode_downsample_rate = config.encode_downsample_rate

        self.encoder = Qwen3TTSTokenizerV1Encoder._from_config(self.config.encoder_config)
        self.decoder = Qwen3TTSTokenizerV1Decoder._from_config(self.config.decoder_config)

        # ONNX speaker-embedding extractor; loaded separately in from_pretrained.
        self.encoder_xvector_extractor = None

        self.post_init()

    def load_encoder_xvector_extractor(self, model_path):
        """Attach the CAM++ ONNX x-vector extractor from `model_path`."""
        self.encoder_xvector_extractor = XVectorExtractor(model_path)

    def get_model_type(self):
        return self.config.model_type

    def get_input_sample_rate(self):
        return self.input_sample_rate

    def get_output_sample_rate(self):
        return self.output_sample_rate

    def get_encode_downsample_rate(self):
        return self.encode_downsample_rate

    def get_decode_upsample_rate(self):
        return self.decode_upsample_rate

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *model_args,
        config=None,
        cache_dir=None,
        ignore_mismatched_sizes=False,
        force_download=False,
        local_files_only=False,
        token=None,
        revision="main",
        use_safetensors=None,
        weights_only=True,
        **kwargs,
    ):
        """Load model weights, then fetch the companion `campplus.onnx`
        x-vector extractor from the same repository/directory."""
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            **kwargs,
        )
        # NOTE(review): cache_dir/force_download/local_files_only/revision were
        # consumed by this method's named parameters, so these kwargs.pop calls
        # always yield the defaults (None/False) and user-supplied values are
        # not forwarded to cached_file — confirm whether that is intended.
        encoder_xvector_extractor_path = cached_file(
            pretrained_model_name_or_path,
            "campplus.onnx",
            subfolder=kwargs.pop("subfolder", None),
            cache_dir=kwargs.pop("cache_dir", None),
            force_download=kwargs.pop("force_download", False),
            proxies=kwargs.pop("proxies", None),
            resume_download=kwargs.pop("resume_download", None),
            local_files_only=kwargs.pop("local_files_only", False),
            token=kwargs.pop("use_auth_token", None),
            revision=kwargs.pop("revision", None),
        )
        if encoder_xvector_extractor_path is None:
            # NOTE(review): the path variable is None in this branch, so the
            # message prints ".../None" — consider naming the missing file.
            raise ValueError(f"""{pretrained_model_name_or_path}/{encoder_xvector_extractor_path} not exists""")
        model.load_encoder_xvector_extractor(encoder_xvector_extractor_path)

        return model

    def encode(
        self,
        input_values: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor, Optional[torch.Tensor]], Qwen3TTSTokenizerV1EncoderOutput]:
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                Float values of the input audio waveform.
            padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0
                for *masked*. Note: despite the Optional annotation, a mask is
                required by the implementation below.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Strip per-item padding; assumes the mask is contiguous 1s then 0s.
        wavs = [value[:mask.sum()] for value, mask in zip(input_values, padding_mask)]

        codes, codes_lens = self.encoder.quantize_speech(wavs)
        codes = [c[:l] for c, l in zip(codes, codes_lens)]

        # Speaker x-vector and reference mel per item (ONNX extractor runs on CPU).
        xvectors = []
        ref_mels = []
        for wav in wavs:
            xvector, ref_mel = self.encoder_xvector_extractor.extract_code(wav.cpu().numpy())
            xvector = torch.tensor(xvector).to(wav.dtype).to(wav.device)
            ref_mel = torch.tensor(ref_mel).to(wav.dtype).to(wav.device)
            xvectors.append(xvector)
            ref_mels.append(ref_mel)

        if not return_dict:
            return (
                codes,
                xvectors,
                ref_mels
            )

        return Qwen3TTSTokenizerV1EncoderOutput(codes, xvectors, ref_mels)

    def decode(
        self,
        audio_codes: torch.Tensor,
        xvectors: torch.Tensor,
        ref_mels: torch.Tensor,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor, torch.Tensor], Qwen3TTSTokenizerV1DecoderOutput]:
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
        trimmed.

        Args:
            audio_codes (`torch.LongTensor` of shape `(batch_size, codes_length)`, *optional*):
                Discret code embeddings computed using `model.encode`.
            xvectors (`torch.FloatTensor` of shape `(batch_size, xvector_dim)`, *optional*):
                X-vector embeddings computed using `model.encode`.
            ref_mels (`torch.FloatTensor` of shape `(batch_size, mel_length, mel_dim)`, *optional*):
                Reference mel spectrogram computed using `model.encode`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        audio_values = self.decoder(code=audio_codes,
                                    reference_mel=ref_mels,
                                    conditioning=xvectors)

        # Trim each waveform to its valid length; code value 0 is treated as
        # padding (matches the padding_value used in the encoder).
        audio_lengths = (audio_codes > 0).sum(1) * self.decode_upsample_rate
        audio_values = [a[:l] for a, l in zip(audio_values, audio_lengths)]

        if not return_dict:
            return (
                audio_values,
            )

        return Qwen3TTSTokenizerV1DecoderOutput(audio_values)
1526
+
1527
+
1528
+ __all__ = ["Qwen3TTSTokenizerV1Model", "Qwen3TTSTokenizerV1PreTrainedModel"]
qwen_tts/core/tokenizer_25hz/vq/assets/mel_filters.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7450ae70723a5ef9d341e3cee628c7cb0177f36ce42c44b7ed2bf3325f0f6d4c
3
+ size 4271
qwen_tts/core/tokenizer_25hz/vq/core_vq.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # This implementation is inspired from
8
+ # https://github.com/lucidrains/vector-quantize-pytorch
9
+ # which is released under MIT License. Hereafter, the original license:
10
+ # MIT License
11
+ #
12
+ # Copyright (c) 2020 Phil Wang
13
+ #
14
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ # of this software and associated documentation files (the "Software"), to deal
16
+ # in the Software without restriction, including without limitation the rights
17
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ # copies of the Software, and to permit persons to whom the Software is
19
+ # furnished to do so, subject to the following conditions:
20
+ #
21
+ # The above copyright notice and this permission notice shall be included in all
22
+ # copies or substantial portions of the Software.
23
+ #
24
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
+ # SOFTWARE.
31
+
32
+ """Core vector quantization implementation."""
33
+ import random
34
+ import typing as tp
35
+ from random import randrange
36
+
37
+ import numpy as np
38
+ from einops import rearrange, repeat
39
+ from math import ceil
40
+ import torch
41
+ from torch import nn
42
+ import torch.nn.functional as F
43
+
44
+
45
def round_up_multiple(num, mult):
    """Round ``num`` up to the nearest multiple of ``mult``."""
    quotient = ceil(num / mult)
    return quotient * mult

def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Return ``val`` unless it is None, in which case return fallback ``d``."""
    return d if val is None else val


def ema_inplace(moving_avg, new, decay: float):
    """In-place exponential moving average: avg <- decay * avg + (1 - decay) * new."""
    moving_avg.data.mul_(decay)
    moving_avg.data.add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Additively smooth counts ``x`` so no category has zero probability."""
    denominator = x.sum() + n_categories * epsilon
    return (x + epsilon) / denominator


def uniform_init(*shape: int):
    """Allocate a tensor of ``shape`` and fill it with Kaiming-uniform values."""
    weight = torch.empty(shape)
    nn.init.kaiming_uniform_(weight)
    return weight


def sample_vectors(samples, num: int):
    """Pick ``num`` rows from ``samples``: without replacement when enough rows
    exist, otherwise uniformly with replacement."""
    pool_size = samples.shape[0]
    dev = samples.device

    if pool_size < num:
        picked = torch.randint(0, pool_size, (num,), device=dev)
    else:
        picked = torch.randperm(pool_size, device=dev)[:num]

    return samples[picked]
75
+
76
+
77
@torch.no_grad()
def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Run plain Lloyd k-means on ``samples`` (shape ``(N, dim)``).

    Returns:
        means: ``(num_clusters, dim)`` centroids after ``num_iters`` iterations.
        bins: per-cluster assignment counts from the final iteration.
    """
    dim, dtype = samples.shape[-1], samples.dtype

    # Initialize centroids by sampling rows from the data.
    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        # Negated squared Euclidean distance, so argmax == nearest centroid.
        dists = -(
            samples.pow(2).sum(1, keepdim=True)
            - 2 * torch.matmul(samples, means.t())
            + means.t().pow(2).sum(0, keepdim=True)
        )

        buckets = dists.max(dim=-1).indices
        del dists  # free the (N, num_clusters) distance matrix early
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Clamp empty-cluster counts to 1 to avoid division by zero below.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        # Sum assigned samples per cluster, then divide by counts -> new means.
        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        # Empty clusters (zero_mask True) keep their previous centroid.
        means = torch.where(zero_mask[..., None], means, new_means)
    return means, bins
102
+
103
+
104
def preprocess(x):
    """Collapse every leading dimension so ``x`` becomes a 2-D ``(N, d)`` matrix."""
    return x.reshape(-1, x.shape[-1])


def postprocess_emb(embed_ind, shape):
    """Fold flat code indices back into the leading dims of the original ``shape``."""
    return embed_ind.view(*shape[:-1])
111
+
112
+
113
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    Note: the codebook state (``inited``, ``cluster_size``, ``embed``,
    ``embed_avg``) is NOT owned by this module — it is injected per call via
    the ``buffers`` argument of ``encode``/``decode``/``forward`` so that one
    module instance can serve multiple quantizer layers.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (float): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: float = 2.0,
    ):
        super().__init__()
        self.decay = decay
        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # State slots; populated from the externally supplied buffers on each call.
        self.inited = None
        self.cluster_size = None
        self.embed = None
        self.embed_avg = None
        # NOTE(review): manually forcing ``training`` here overrides nn.Module's
        # train()/eval() bookkeeping — presumably intentional; verify with callers.
        self.training = True

    def init_embed_(self, data):
        """Lazily initialize the codebook from the first batch via k-means."""
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        # distrib.broadcast_tensors([self.embed, self.embed_avg, self.cluster_size, self.inited])

    def replace_(self, samples, mask):
        """Overwrite masked codebook entries with random vectors from ``samples``."""
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        """Replace codes whose (normalized) EMA usage fell below the threshold."""
        if self.threshold_ema_dead_code == 0:
            return

        # Normalize usage so the threshold is independent of batch statistics.
        cluster_size = self.cluster_size / sum(self.cluster_size) * self.codebook_size
        expired_codes = cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return
        else:
            print(f"VQ expire infos: num_expire={sum(expired_codes)}, cluster_size[:5]={cluster_size[:5]}")

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        # sync buffers outside for efficiency
        # distrib.broadcast_tensors(self.buffers())

    def quantize(self, x):
        """Return the index of the nearest (Euclidean) codebook entry per row."""
        embed = self.embed.t()
        # Negated squared distance; argmax picks the nearest code.
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def dequantize(self, embed_ind):
        """Look up codebook vectors for the given indices."""
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x, buffers):
        """Quantize ``x`` to code indices using the injected codebook state."""
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers

        shape = x.shape
        # pre-process
        x = preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind, buffers):
        """Map code indices back to vectors using the injected codebook state."""
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers

        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x, buffers):
        """Quantize + dequantize ``x``; in training mode also update the EMA codebook.

        Returns:
            quantize: dequantized vectors, same leading shape as ``x``.
            embed_ind: chosen code indices.
        """
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers

        shape, dtype = x.shape, x.dtype
        x = preprocess(x)

        self.init_embed_(x)
        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # EMA update of per-code usage counts and running vector sums,
            # then normalize to get the new codebook (VQ-VAE style).
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)
            # Note: after ema update, there is a very small difference between codebooks on GPUs.
            # The impact can be very small, ignore it.

        return quantize, embed_ind
249
+
250
+
251
class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently, supports only euclidean distance.

    Codebook state is injected per call through ``buffers`` (see
    ``EuclideanCodebook``) rather than stored on this module.

    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (float): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: float = 2.0,
        commitment_weight: float = 1.,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        # Only project when the codebook lives in a different dimension.
        requires_projection = _codebook_dim != dim
        self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity())
        self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity())

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
                                           kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
                                           decay=decay, epsilon=epsilon,
                                           threshold_ema_dead_code=threshold_ema_dead_code)
        self.codebook_size = codebook_size
        # NOTE(review): manual ``training`` flag overrides nn.Module semantics;
        # presumably intentional to force training-path behavior — confirm.
        self.training = True

    @property
    def codebook(self):
        # Raw codebook tensor of the wrapped EuclideanCodebook.
        return self._codebook.embed

    def encode(self, x, buffers):
        """Project ``x`` into codebook space and return code indices."""
        # x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x, buffers)
        return embed_in

    def decode(self, embed_ind, buffers):
        """Look up code vectors and project them back to the input space."""
        quantize = self._codebook.decode(embed_ind, buffers)
        quantize = self.project_out(quantize)
        # quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def forward(self, x, buffers):
        """Quantize ``x``; returns (quantized, indices, loss).

        In training mode applies the straight-through estimator and adds the
        commitment loss; in eval mode ``loss`` is a zero tensor.
        """
        device = x.device
        # x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x, buffers)

        if self.training:
            # Straight-through estimator: gradients flow through x unchanged.
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                # Pull encoder outputs toward the (frozen) chosen codes.
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        # quantize = rearrange(quantize, "b n d -> b d n")
        return quantize, embed_ind, loss
332
+
333
+
334
class DistributedResidualVectorQuantization(nn.Module):
    """Efficient distributed residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf

    Each successive quantizer encodes the residual left by the previous ones.
    Codebook state for all layers is held in registered buffers on this module
    and sliced per layer when calling into ``VectorQuantization``.
    """
    def __init__(self, *,
                 num_quantizers,
                 quantize_dropout: bool = False,
                 rand_num_quant: tp.Optional[tp.List] = None,
                 **kwargs):
        super().__init__()
        """
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        """
        # The conditional binds only to the second element: codebook_dim falls
        # back to ``dim`` when kwargs["codebook_dim"] is falsy.
        codebook_size, codebook_dim = kwargs["codebook_size"], kwargs["codebook_dim"] if kwargs["codebook_dim"] else kwargs["dim"]
        kmeans_init = kwargs["kmeans_init"]
        if isinstance(kmeans_init, bool):
            if not kwargs["kmeans_init"]:
                # use uniform init
                embed = uniform_init(num_quantizers, codebook_size, codebook_dim)
                inited = True
            else:
                # to perform kmeans init on first batch
                embed = torch.zeros(num_quantizers, codebook_size, codebook_dim)
                inited = False
        elif isinstance(kmeans_init, str):
            # use prepared kmeans init
            embed = np.load(kmeans_init)
            embed = torch.from_numpy(embed)
            if embed.dim() == 2:
                embed = embed.unsqueeze(0)
            inited = True
        else:
            raise TypeError("kmeans_init should be either a bool or string path to init weights.")

        # Shared per-layer codebook state, injected into each VQ layer at call time.
        self.register_buffer("inited", torch.Tensor([[inited] for _ in range(num_quantizers)]))
        self.register_buffer("cluster_size", torch.zeros(num_quantizers, codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

        # Optional temporal downsampling applied to the first quantizer only.
        self.q0_ds_ratio = 1
        if "q0_ds_ratio" in kwargs:
            self.q0_ds_ratio = kwargs.pop("q0_ds_ratio")

        self.layers = nn.ModuleList()
        for i in range(num_quantizers):
            vq_args = dict(**kwargs)
            vq = VectorQuantization(**vq_args)
            self.layers.append(vq)

        self.quantize_dropout = quantize_dropout
        self.rand_num_quant = rand_num_quant

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Residually quantize ``x`` (shape ``(B, C, T)``).

        Returns (quantized_sum, stacked indices, stacked losses); with quantize
        dropout active, dropped layers contribute sentinel -1 indices / 0 losses.
        """
        quantized_out = torch.zeros_like(x)
        residual = x
        bb, cc, tt = x.shape
        device = x.device

        all_losses = []
        all_indices = []
        all_sub_quants = []
        n_q = n_q or len(self.layers)

        # Randomly truncate the number of active quantizers during training.
        should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None
        if should_quantize_dropout:
            rand_quantize_dropout_index = random.choice(self.rand_num_quant)

        # Sentinel outputs used for dropped quantizer layers.
        null_indices_shape = (x.shape[0], x.shape[2])
        null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long)
        null_loss = torch.full((1,), 0., device=device, dtype=x.dtype)
        null_sub_quant = torch.full(x.shape, -1, device=device, dtype=x.dtype)

        for quantizer_index, layer in enumerate(self.layers[:n_q]):
            # dropout except the first quantizer
            if should_quantize_dropout and quantizer_index >= rand_quantize_dropout_index:
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                all_sub_quants.append(null_sub_quant)
                continue

            quant_in = residual
            # Optionally quantize the first layer at half temporal resolution.
            if self.q0_ds_ratio > 1 and quantizer_index == 0:
                quant_in = F.interpolate(quant_in, size=[tt//2])
            quantized, indices, loss = layer(quant_in, [
                self.inited[quantizer_index],
                self.cluster_size[quantizer_index],
                self.embed[quantizer_index],
                self.embed_avg[quantizer_index]
            ])
            if self.q0_ds_ratio > 1 and quantizer_index == 0:
                # Upsample back to full length; indices via nearest-style float interp.
                quantized = F.interpolate(quantized, size=[tt])
                indices = F.interpolate(indices.unsqueeze(1).float(), size=[tt]).squeeze(1).long()
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)
            all_sub_quants.append(quantized)

        # sync buffers after one forward step
        # distrib.broadcast_tensors(self.buffers())
        out_losses, out_indices, out_sub_quants = map(torch.stack, (all_losses, all_indices, all_sub_quants))

        return quantized_out, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Encode ``x`` into stacked per-layer code indices (no state updates)."""
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for i, layer in enumerate(self.layers[:n_q]):
            indices = layer.encode(residual, [
                self.inited[i],
                self.cluster_size[i],
                self.embed[i],
                self.embed_avg[i]
            ])
            quantized = layer.decode(indices, [
                self.inited[i],
                self.cluster_size[i],
                self.embed[i],
                self.embed_avg[i]
            ])
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Sum the decoded contributions of each quantizer layer."""
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices, [
                self.inited[i],
                self.cluster_size[i],
                self.embed[i],
                self.embed_avg[i]
            ])
            quantized_out = quantized_out + quantized
        return quantized_out
475
+
476
+
477
class DistributedGroupResidualVectorQuantization(nn.Module):
    """Efficient distributed group residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/abs/2305.02765

    The channel dimension is split into ``num_groups`` equal slices, each
    handled by its own residual VQ; results are concatenated/stacked back.
    """
    def __init__(self, *,
                 num_groups,
                 num_quantizers,
                 quantize_dropout: bool = False,
                 rand_num_quant: tp.Optional[tp.List] = None,
                 **kwargs):
        super().__init__()
        group_rvqs = []
        for _ in range(num_groups):
            group_rvqs.append(
                DistributedResidualVectorQuantization(
                    num_quantizers=num_quantizers,
                    quantize_dropout=quantize_dropout,
                    rand_num_quant=rand_num_quant,
                    **kwargs
                )
            )
        self.rvqs = nn.ModuleList(group_rvqs)
        self.num_groups = num_groups

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Quantize each channel group independently.

        Returns (quantized concat over channels, indices stacked over groups,
        losses averaged over groups).
        """
        group_inputs = torch.chunk(x, chunks=self.num_groups, dim=1)
        quantized_parts = []
        index_parts = []
        loss_parts = []
        for rvq, group_in in zip(self.rvqs, group_inputs):
            q, idx, losses = rvq(group_in, n_q)
            quantized_parts.append(q)
            index_parts.append(idx)
            loss_parts.append(losses)

        out_losses = torch.stack(loss_parts, dim=1).mean(dim=1)

        return torch.cat(quantized_parts, dim=1), torch.stack(index_parts, dim=1), out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Encode per group; result is stacked along a new group dimension."""
        group_inputs = torch.chunk(x, chunks=self.num_groups, dim=1)
        per_group = [rvq.encode(group_in, n_q) for rvq, group_in in zip(self.rvqs, group_inputs)]
        return torch.stack(per_group, dim=1)

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Decode per group and concatenate along the channel dimension."""
        per_group_indices = torch.chunk(q_indices, chunks=self.num_groups, dim=1)
        decoded = [rvq.decode(chunk.squeeze(1)) for rvq, chunk in zip(self.rvqs, per_group_indices)]
        return torch.cat(decoded, dim=1)
qwen_tts/core/tokenizer_25hz/vq/speech_vq.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import sox
17
+ import copy
18
+ import torch
19
+ import operator
20
+ import onnxruntime
21
+
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ import torchaudio.compliance.kaldi as kaldi
25
+
26
+ from librosa.filters import mel as librosa_mel_fn
27
+ from itertools import accumulate
28
+ from typing import List
29
+ from torch import Tensor
30
+
31
+ from .core_vq import DistributedGroupResidualVectorQuantization
32
+ from .whisper_encoder import WhisperEncoder, Conv1d, ConvTranspose1d
33
+
34
+
35
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress magnitudes; values below ``clip_val`` are clamped first."""
    clamped = torch.clamp(x, min=clip_val)
    return torch.log(clamped * C)

def spectral_normalize_torch(magnitudes):
    """Apply standard dynamic-range compression to a magnitude spectrogram."""
    return dynamic_range_compression_torch(magnitudes)
41
+
42
class MelSpectrogramFeatures(nn.Module):
    """
    Calculate the BigVGAN style mel spectrogram of an input signal.
    Args:
        filter_length (int): The number of samples in the filter window, used for the Fourier Transform. Default is 1024.
        hop_length (int): The number of samples between successive frames (stride of the STFT). Default is 160.
        win_length (int): The length of the window function applied to each frame, usually less than or equal to the filter length. Default is 640.
        n_mel_channels (int): The number of Mel-frequency channels to output from the Mel-scale spectrogram. Default is 80.
        mel_fmin (int): The minimum frequency (in Hz) of the Mel-scale spectrogram. Default is 0.
        mel_fmax (int): The maximum frequency (in Hz) of the Mel-scale spectrogram. Default is 8000.
        sampling_rate (int): The sampling rate of the audio data (in Hz). Default is 16000.
        sampling_rate_org (int, optional): The original sampling rate of the audio data before any resampling (in Hz), if applicable. Default is None.
        padding (str): The padding mode for the input signal. 'center' pads the signal symmetrically around its center. Default is 'center'.
        use_db (bool): Accepted for config compatibility but currently unused.

    Returns:
        torch.Tensor: Mel spectrogram of shape (batch, n_mel_channels, frames).
    """
    def __init__(self,
                 filter_length=1024,
                 hop_length=160,
                 win_length=640,
                 n_mel_channels=80,
                 mel_fmin=0,
                 mel_fmax=8000,
                 sampling_rate=16000,
                 sampling_rate_org=None,
                 padding='center',
                 use_db = False,
                 ):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding

        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.sampling_rate = sampling_rate
        self.sampling_rate_org = sampling_rate_org if sampling_rate_org is not None else sampling_rate
        # Per-device caches for the mel filterbank and Hann window.
        self.mel_basis = {}
        self.hann_window = {}

    def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
        """Compute the mel spectrogram without tracking gradients."""
        with torch.no_grad():
            feats = self.extract(audio, **kwargs)
        return feats

    def extract(self, audio, **kwargs):
        """Compute the BigVGAN-style mel spectrogram of ``audio``.

        Accepts (B, T), (B, 1, T) or (B, T, 1) input and returns
        (B, n_mel_channels, frames).
        """
        if len(audio.shape) == 3:
            audio = audio.squeeze(1) if audio.shape[1] == 1 else audio.squeeze(2)
        assert len(audio.shape) == 2

        y = audio
        # BUG FIX: the caches were previously filled only when the cache dict
        # was completely empty, but lookups are keyed by (fmax, device) /
        # device — a later call on a different device raised KeyError.
        # Build each entry lazily on first use instead.
        mel_key = str(self.mel_fmax) + '_' + str(y.device)
        if mel_key not in self.mel_basis:
            mel = librosa_mel_fn(sr=self.sampling_rate, n_fft=self.filter_length, n_mels=self.n_mel_channels, fmin=self.mel_fmin, fmax=self.mel_fmax)
            self.mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
        win_key = str(y.device)
        if win_key not in self.hann_window:
            self.hann_window[win_key] = torch.hann_window(self.win_length).to(y.device)

        # Reflect-pad so frames are centered like the reference implementation.
        pad = int((self.filter_length - self.hop_length) / 2)
        y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
        y = y.squeeze(1)

        spec = torch.stft(y, self.filter_length, hop_length=self.hop_length, win_length=self.win_length, window=self.hann_window[win_key],
                          center=False, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
        spec = torch.view_as_real(spec)
        # Magnitude with a small epsilon for numerical stability.
        spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))

        # Project linear-frequency bins onto the mel filterbank, then log-compress.
        spec = torch.matmul(self.mel_basis[mel_key], spec)
        spec = spectral_normalize_torch(spec)

        return spec
116
+
117
+
118
class XVectorExtractor(nn.Module):
    """Speaker-embedding (x-vector) + reference-mel extractor.

    Wraps an ONNX speaker-verification model (run on CPU) plus a sox loudness
    normalizer and a mel-spectrogram frontend.
    """
    def __init__(self, audio_codec_with_xvector):
        """``audio_codec_with_xvector``: path to the ONNX x-vector model."""
        super().__init__()
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        # Single-threaded CPU inference; keeps the session lightweight.
        option.intra_op_num_threads = 1
        providers = ["CPUExecutionProvider"]
        self.ort_session = onnxruntime.InferenceSession(audio_codec_with_xvector, sess_options=option, providers=providers)

        # Loudness-normalize input audio to -6 dB before feature extraction.
        self.tfm = sox.Transformer()
        self.tfm.norm(db_level=-6)

        # 16 kHz, 80-bin mel frontend used for the reference mel.
        self.mel_ext = MelSpectrogramFeatures(
            filter_length=1024,
            hop_length=160,
            win_length=640,
            n_mel_channels=80,
            mel_fmin=0,
            mel_fmax=8000,
            sampling_rate=16000
        )

    def extract_code(self, audio):
        """Extract (x-vector, reference mel) from a 1-D 16 kHz numpy waveform.

        Returns:
            norm_embedding: L2-normalized speaker embedding as a numpy array.
            ref_mel: reference mel spectrogram, shape (frames, n_mels), numpy.
        """
        with torch.no_grad():
            norm_audio = self.sox_norm(audio)

            # deepcopy: sox may hand back a non-owning/readonly array view.
            norm_audio = torch.from_numpy(copy.deepcopy(norm_audio)).unsqueeze(0)
            feat = kaldi.fbank(norm_audio,
                               num_mel_bins=80,
                               dither=0,
                               sample_frequency=16000)
            # Cepstral mean normalization over time.
            feat = feat - feat.mean(dim=0, keepdim=True)
            norm_embedding = self.ort_session.run(None, {self.ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten()
            norm_embedding = F.normalize(torch.from_numpy(norm_embedding), dim=0)

            ref_mel = self.mel_ext.extract(audio=norm_audio)

        return norm_embedding.numpy(), ref_mel.permute(0,2,1).squeeze(0).numpy()

    def sox_norm(self, audio):
        """Apply the configured -6 dB sox loudness normalization (16 kHz input)."""
        wav_norm = self.tfm.build_array(input_array=audio, sample_rate_in=16000)
        return wav_norm
160
+
161
+
162
class WhisperEncoderVQ(WhisperEncoder):
    """Whisper-style audio encoder with an internal vector-quantization stage.

    After ``audio_vq_layers`` transformer blocks, hidden states are optionally
    downsampled, quantized with a (group) residual VQ, optionally re-injected
    with positional embeddings, upsampled, and passed through the remaining
    blocks. Only GRVQ inference is implemented here.
    """
    def __init__(
        self,
        n_mels: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        n_window: int = 1500,
        output_dim: int = 512,
        grad_checkpointing: bool = False,
        enable_mp: bool = False,
        audio_sequence_parallel: bool = False,
        audio_vq_layers: int = -1,          # transformer layer index after which VQ is applied (>0 required)
        audio_vq_type: str = "NULL",        # only "GRVQ" is supported
        audio_vq_codebook_size: int = 4096,
        audio_vq_pe: bool = False,          # re-add positional embedding after quantization
        audio_vq_commit_loss: float = 0.0,
        audio_vq_out_commit_loss: float = 0.0,
        audio_vq_no_quantize: bool = False,
        audio_vq_ff_layer: int = 0,
        audio_vq_threshold_ema_dead_code: float = 0.1,
        audio_vq_codebook_dim: int = None,
        audio_vq_ds_rate: int = None,       # temporal downsample rate before VQ
    ):
        super().__init__(n_mels, n_ctx, n_state, n_head, n_layer, n_window, output_dim, grad_checkpointing, enable_mp, audio_sequence_parallel)

        self.audio_vq_layers = audio_vq_layers
        self.audio_vq_type = audio_vq_type
        self.audio_vq_codebook_size = audio_vq_codebook_size
        self.audio_vq_pe = audio_vq_pe
        self.audio_vq_commit_loss = audio_vq_commit_loss
        self.audio_vq_out_commit_loss = audio_vq_out_commit_loss
        self.audio_vq_no_quantize = audio_vq_no_quantize
        self.audio_vq_ff_layer = audio_vq_ff_layer

        if audio_vq_layers > 0:
            self.vq_feature_dim = self.n_state
            self.audio_vq_ds_rate = 1
        else:
            raise NotImplementedError(f"Unsupported audio_vq_layers: {audio_vq_layers}")

        if self.audio_vq_ds_rate == audio_vq_ds_rate:
            # No resampling needed around the quantizer.
            self.audio_vq_downsample = nn.Identity()
            self.audio_vq_upsample = nn.Identity()
        else:
            assert audio_vq_ds_rate % self.audio_vq_ds_rate == 0
            stride = audio_vq_ds_rate // self.audio_vq_ds_rate
            # Strided conv down / transposed conv up around the quantizer.
            self.audio_vq_downsample = Conv1d(self.vq_feature_dim, self.vq_feature_dim, kernel_size=stride, stride=stride)
            self.audio_vq_upsample = ConvTranspose1d(self.vq_feature_dim, self.vq_feature_dim, kernel_size=stride, stride=stride)
            self.audio_vq_ds_rate = audio_vq_ds_rate

        if audio_vq_type == "GRVQ":
            # NOTE(review): ``self.vq_codebook_dim`` is not assigned anywhere in
            # this class — reaching this branch with audio_vq_codebook_dim=None
            # would raise AttributeError unless the base class defines it; confirm.
            self.audio_quantizer = DistributedGroupResidualVectorQuantization(
                codebook_size = audio_vq_codebook_size,
                dim = self.vq_feature_dim,
                codebook_dim = self.vq_codebook_dim if audio_vq_codebook_dim is None else audio_vq_codebook_dim,
                num_groups=1,
                num_quantizers=1,
                kmeans_init=False,
                threshold_ema_dead_code = audio_vq_threshold_ema_dead_code
            )
        else:
            raise NotImplementedError(f"Unsupported audio_vq_type: {audio_vq_type}")

        if self.audio_vq_pe:
            self.project_after_vq_pe = nn.Linear(self.n_state, self.n_state)

    def _calc_quantize_activities(self, indices):
        """Count how many distinct codes were used and total tokens quantized."""
        indices_onehot = F.one_hot(indices.long().flatten(), self.audio_vq_codebook_size).sum(dim=0)
        vq_num_activities = sum(indices_onehot>0)
        vq_num_tokens = sum(indices_onehot)
        return {
            "vq_num_activities": vq_num_activities,
            "vq_num_tokens": vq_num_tokens,
        }

    def _do_quantize(self, x, pe=None, y=None):
        """
        Downsample, quantize, optionally re-add PE, and upsample hidden states.

        x: torch.Tensor, shape = (T, D)
        q: torch.Tensor, shape = (T, D)
        i: torch.Tensor, shape = (T)
        """
        if self.audio_vq_out_commit_loss > 0:
            # Keep the pre-quantization features as a distillation target.
            x_teacher = x.clone()
        x = x.unsqueeze(0)

        # (1, T, D) -> (1, D, T) for the conv downsampler, then back.
        x = self.audio_vq_downsample(x.transpose(1, 2))
        x = x.transpose(1, 2)

        vq_stats = {}

        if self.audio_vq_type == "GRVQ":
            if self.training:
                # Training-time quantization path not implemented here.
                raise NotImplementedError
            else:
                indices = self.audio_quantizer.encode(x)
                x = self.audio_quantizer.decode(indices)
                indices = indices.squeeze(2).squeeze(1)

            vq_stats.update(self._calc_quantize_activities(indices))

        x, indices = x.squeeze(0), indices.squeeze(0)
        if self.audio_vq_pe:
            # Re-inject positional information lost by quantization.
            x = x + pe
            x = self.project_after_vq_pe(x)

        x = self.audio_vq_upsample(x.unsqueeze(0).transpose(1, 2))
        x = x.transpose(1, 2).squeeze(0)

        if self.audio_vq_out_commit_loss > 0:
            vq_out_commit_loss = F.mse_loss(x_teacher.detach(), x)
            vq_stats["vq_out_commit_loss"] = vq_out_commit_loss * self.audio_vq_out_commit_loss

        return x, indices, vq_stats

    def forward(self, x_list: List[Tensor], audio_mellens:List[int], audio_aftercnnlens:List[int], audio_seqlens:List[int], return_indices=False, audio_pitchs=None):
        """
        x : torch.Tensor, shape = (n_mels, n_ctx)
            the mel spectrogram of the audio

        Processes a list of mel spectrograms packed into one varlen batch,
        applies the transformer (with VQ after layer ``audio_vq_layers``), and
        wraps each segment with learned BOS/EOS embeddings in the output.
        If ``return_indices`` is True, returns (features, code indices) right
        after quantization instead.
        """

        # Conv frontend, applied per window of at most 2*n_window mel frames.
        aftercnn_x_list = []
        pe_for_vq_list = []
        for each_x in x_list:
            each_x_split_list = each_x.split(self.n_window * 2, dim=1)
            for each_x_split in each_x_split_list:
                each_x_split = F.gelu(self.conv1(each_x_split))
                each_x_split = F.gelu(self.conv2(each_x_split))
                each_x_split = each_x_split.permute(1, 0) # L,D

                each_positional_embedding_split = self.positional_embedding[:each_x_split.shape[0]]
                aftercnn_x_list.append(each_x_split+each_positional_embedding_split.to(each_x_split.dtype))

                # PE at the (possibly downsampled) VQ resolution, for re-injection.
                pe_for_vq_split = self.positional_embedding[:each_x_split.shape[0] // self.audio_vq_ds_rate]
                pe_for_vq_list.append(pe_for_vq_split.to(each_x_split.dtype))

        pe_for_vq = torch.cat(pe_for_vq_list, dim=0)
        x = torch.cat(aftercnn_x_list, dim=0)
        src_len = x.size(0)

        # Split each segment into chunks of at most n_window for varlen attention.
        output_list = []
        for item in audio_aftercnnlens:
            while item > self.n_window:
                output_list.append(self.n_window)
                item -= self.n_window
            output_list.append(item)

        cu_seqlens = list(accumulate(output_list, func=operator.add,initial=0))
        cu_seqlens = torch.Tensor(cu_seqlens).to(device=x.device, dtype=torch.int32)

        layer_id = 0

        for block in self.blocks:
            layer_id+=1

            x = block(x, cu_seqlens=cu_seqlens)

            if self.audio_vq_layers == layer_id: # vq inside encoder
                x, indices, vq_stats = self._do_quantize(x, pe_for_vq)
                if return_indices:
                    return x, indices

        # NOTE(review): if audio_vq_layers never matches a layer index,
        # ``vq_stats`` is unbound and the final return would raise — confirm
        # that configs always set 0 < audio_vq_layers <= n_layer.
        if self.avg_pooler:
            # Temporal pooling per segment (pooler defined in the base class).
            x_list = x.split(audio_aftercnnlens, dim=0)
            token_x_list = []
            for x in x_list:
                x = x.permute(1, 0)
                x = self.avg_pooler(x)
                x = x.permute(1, 0)
                token_x_list.append(x)
            x = torch.cat(token_x_list, dim=0)

        x = self.ln_post(x)

        x = self.proj(x)

        # Allocate room for two extra tokens (BOS/EOS) per segment.
        output = torch.zeros(
            (x.size(0) + len(audio_seqlens) * 2, x.size(1)),
            device=x.device, dtype=x.dtype
        )

        audio_seqlens_acc = list(accumulate(audio_seqlens, func=operator.add, initial=0))
        start_ids = torch.tensor(audio_seqlens_acc[:-1], device=x.device, dtype=torch.int32)
        end_ids = torch.tensor(audio_seqlens_acc[1:], device=x.device, dtype=torch.int32) - 1

        # Scatter BOS/EOS embeddings at segment boundaries, content in between.
        audio_tokens_mask = torch.ones(output.size(0), device=x.device, dtype=torch.bool)
        audio_tokens_mask[start_ids] = False
        audio_tokens_mask[end_ids] = False
        output[start_ids] = self.audio_bos_eos_token.weight[0].to(x.dtype)
        output[end_ids] = self.audio_bos_eos_token.weight[1].to(x.dtype)
        output[audio_tokens_mask] = x

        if self.audio_vq_type != "NULL":
            return output, vq_stats
        return output
qwen_tts/core/tokenizer_25hz/vq/whisper_encoder.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import os
17
+ import math
18
+ import torch
19
+ import operator
20
+
21
+ import numpy as np
22
+ import torch.nn.functional as F
23
+
24
+ from functools import lru_cache
25
+ from typing import Optional, Union, List
26
+ from torch import nn, Tensor
27
+ from itertools import accumulate
28
+
29
+ try:
30
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
31
+ except ImportError:
32
+ try:
33
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_varlen_func
34
+ except ImportError:
35
+ print("\n********\nWarning: flash-attn is not installed. Will only run the manual PyTorch version. Please install flash-attn for faster inference.\n********\n ")
36
+ flash_attn_varlen_func = None
37
+
38
+
39
+ N_FFT = 400
40
+ HOP_LENGTH = 160
41
+
42
+
43
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    Load (and cache) the Mel filterbank matrix used to project an STFT
    magnitude spectrogram onto Mel bands.

    The matrices are shipped as an npz asset to avoid a librosa dependency;
    they were produced with librosa.filters.mel(sr=16000, n_fft=400,
    n_mels=80/128).

    Returns a tensor of shape (n_mels, n_fft // 2 + 1) on *device*.
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

    asset_dir = os.path.dirname(__file__)
    filters_path = os.path.join(asset_dir, "assets", "mel_filters.npz")
    with np.load(filters_path, allow_pickle=False) as archive:
        bank = archive[f"mel_{n_mels}"]
        return torch.from_numpy(bank).to(device)
60
+
61
+
62
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = 80,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the log-Mel spectrogram of a 16 kHz waveform.

    Parameters
    ----------
    audio:
        Waveform as a NumPy array or Tensor. (Despite the type alias, a path
        string is not decoded here — it must already be a waveform.)
    n_mels:
        Number of Mel bands; must match an available filterbank (80 or 128).
    padding:
        Number of zero samples appended on the right before the STFT.
    device:
        If given, move the waveform to this device before the STFT.

    Returns
    -------
    torch.Tensor of shape (n_mels, n_frames), normalized to roughly [-1, 1].
    """
    waveform = audio if torch.is_tensor(audio) else torch.from_numpy(audio)

    if device is not None:
        waveform = waveform.to(device)
    if padding > 0:
        waveform = F.pad(waveform, (0, padding))

    hann = torch.hann_window(N_FFT).to(waveform.device)
    spectrum = torch.stft(waveform, N_FFT, HOP_LENGTH, window=hann, return_complex=True)
    # Drop the final (partial) frame, as Whisper's reference implementation does.
    power = spectrum[..., :-1].abs() ** 2

    mel_spec = mel_filters(waveform.device, n_mels) @ power

    # Log-compress, clamp dynamic range to 8 dB decades, rescale.
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    return (log_spec + 4.0) / 4.0
108
+
109
+
110
def get_T_after_cnn(L_in, dilation=1):
    """Return the sequence length after the encoder's two Conv1d layers.

    Mirrors conv1 (padding=1, kernel=3, stride=1) and conv2 (padding=1,
    kernel=3, stride=2) of WhisperEncoder, using the standard Conv1d output
    length formula. Net effect: the length is roughly halved.

    Args:
        L_in: input length in frames.
        dilation: convolution dilation (both layers use the same value).

    Returns:
        Output length after both convolutions.
    """
    # (padding, kernel_size, stride) for conv1 then conv2. The original code
    # built this constant list via eval() on a string, which was needless
    # and unsafe; a plain literal is equivalent.
    for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]:
        L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
        L_out = 1 + L_out // stride
        L_in = L_out
    return L_out
116
+
117
+
118
def get_mel_audio(audio, padding=False, audio_vq_ds_rate=1, n_mels=128):
    """Compute a log-Mel spectrogram, optionally right-padding the waveform.

    When *padding* is True, the waveform is zero-padded so its sample count
    is a multiple of hop (160) * conv downsampling (2) * *audio_vq_ds_rate*,
    keeping downstream frame counts exact.

    Returns a (n_mels, T) tensor.
    """
    if not padding:
        return log_mel_spectrogram(audio, n_mels=n_mels)  # [F,T]

    reduction = 160 * 2 * audio_vq_ds_rate
    n_samples = len(audio)
    pad_amount = math.ceil(n_samples / reduction) * reduction - n_samples
    return log_mel_spectrogram(audio, n_mels=n_mels, padding=pad_amount)
127
+
128
+
129
def sinusoids(length, channels, max_timescale=10000):
    """Return sinusoidal positional embeddings of shape (length, channels).

    The first half of each row holds sines, the second half cosines, over
    geometrically spaced timescales from 1 to *max_timescale*.
    """
    assert channels % 2 == 0
    half = channels // 2
    log_timescale_increment = np.log(max_timescale) / (half - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    positions = torch.arange(length)[:, np.newaxis]
    scaled_time = positions * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
136
+
137
+
138
class Conv1d(nn.Conv1d):
    """nn.Conv1d that casts its weight/bias to the input dtype on the fly,
    so the module works under mixed precision without explicit casting."""

    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        w = weight.to(x.dtype)
        b = bias.to(x.dtype) if bias is not None else None
        return super()._conv_forward(x, w, b)
145
+
146
+
147
class ConvTranspose1d(nn.ConvTranspose1d):
    # Mirrors the Conv1d subclass above: casts weight/bias to the input
    # dtype for mixed-precision safety.
    # NOTE(review): nn.ConvTranspose1d.forward does not call _conv_forward
    # (and its base class defines no such method), so this override appears
    # to be unused dead code copied from the Conv1d pattern — confirm before
    # relying on it.
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )
154
+
155
+
156
class Linear(nn.Linear):
    """nn.Linear that casts its weight/bias to the input dtype on the fly,
    so the layer works under mixed precision without explicit casting."""

    def forward(self, x: Tensor) -> Tensor:
        weight = self.weight.to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return F.linear(x, weight, bias)
159
+
160
+
161
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over variable-length sequences packed along dim 0.

    `cu_seqlens` holds cumulative sequence boundaries in the flash-attn
    "varlen" convention. The flash-attn kernel is used when available and the
    dtype is fp16/bf16; otherwise a manual padded softmax attention is used.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)  # no key bias, as in Whisper
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

        # Sticky preference: permanently disabled after the first call with an
        # unsupported dtype, so we don't re-check every forward.
        self.use_flash_attention = True

    def forward(
        self,
        x: Tensor,
        cu_seqlens = None,
    ):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        if self.use_flash_attention:
            if flash_attn_varlen_func is None:
                x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens)
            else:
                if q.dtype not in [torch.float16, torch.bfloat16]:
                    # flash-attn supports only half precision.
                    x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens)
                    self.use_flash_attention = False
                else:
                    x = self.qkv_flash_attention(q, k, v, cu_seqlens=cu_seqlens)
        else:
            x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens)

        output = self.out(x)
        return output

    def qkv_flash_attention(
        self, q: Tensor, k: Tensor, v: Tensor, cu_seqlens=None
    ):
        """Varlen flash-attn path; q/k/v are packed (n_ctx, n_state)."""
        n_ctx, n_state = q.shape
        q = q.view(n_ctx, self.n_head, -1)  # (total_tokens, nheads, headdim)
        k = k.view(n_ctx, self.n_head, -1)
        v = v.view(n_ctx, self.n_head, -1)

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

        x = flash_attn_varlen_func(
            q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0
        )
        x = x.reshape(n_ctx, n_state)
        return x

    def qkv_attention_manual(
        self, q: Tensor, k: Tensor, v: Tensor, cu_seqlens: Tensor
    ):
        """Padded-attention fallback equivalent to the varlen flash kernel.

        Fix: the padding mask is built as an *additive float* tensor
        (0 for valid key positions, -finfo.max for padding). The previous
        code called masked_fill on a bool tensor with a float fill value,
        which raises at runtime — and even if it had been cast, it would
        have added +1 to valid positions instead of -inf to padded ones.
        """
        n_ctx, n_state = q.shape
        head_dim = n_state // self.n_head
        scale = head_dim ** -0.5

        q = q.view(n_ctx, self.n_head, head_dim)
        k = k.view(n_ctx, self.n_head, head_dim)
        v = v.view(n_ctx, self.n_head, head_dim)

        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        batch_size = len(seqlens)
        max_seqlen = max(seqlens)

        # Unpack the packed (total_tokens, ...) layout into a padded batch.
        q_padded = torch.zeros(batch_size, max_seqlen, self.n_head, head_dim, dtype=q.dtype, device=q.device)
        k_padded = torch.zeros_like(q_padded)
        v_padded = torch.zeros_like(q_padded)

        for i in range(batch_size):
            start_idx = cu_seqlens[i]
            end_idx = cu_seqlens[i + 1]
            seq_len = seqlens[i]
            q_padded[i, :seq_len] = q[start_idx:end_idx]
            k_padded[i, :seq_len] = k[start_idx:end_idx]
            v_padded[i, :seq_len] = v[start_idx:end_idx]

        q_padded = q_padded.transpose(1, 2)  # (B, H, T, D)
        k_padded = k_padded.transpose(1, 2)
        v_padded = v_padded.transpose(1, 2)

        # Additive mask over key positions: 0 where valid, -max where padded.
        valid = torch.arange(max_seqlen, device=q.device)[None, :] < torch.tensor(seqlens, device=q.device)[:, None]
        attn_mask = torch.zeros(batch_size, 1, 1, max_seqlen, dtype=q.dtype, device=q.device)
        attn_mask.masked_fill_(~valid[:, None, None, :], -torch.finfo(q.dtype).max)

        attn_scores = torch.matmul(q_padded, k_padded.transpose(-2, -1)) * scale
        attn_scores = attn_scores + attn_mask
        attn_weights = F.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_weights, v_padded)  # (B, H, T, D)

        context = context.transpose(1, 2).contiguous().view(batch_size, max_seqlen, n_state)

        # Re-pack to the original (total_tokens, n_state) layout.
        output_packed = torch.cat([context[i, :seqlens[i]] for i in range(batch_size)], dim=0)

        assert output_packed.shape == (n_ctx, n_state)

        return output_packed
263
+
264
+
265
class ResidualAttentionBlock(nn.Module):
    """Pre-norm Transformer block: varlen self-attention + 4x GELU MLP,
    each wrapped in a residual connection.

    `enable_mp` / `sequence_parallel` are accepted for interface
    compatibility but are not used by this implementation.
    """

    def __init__(self, n_state: int, n_head: int,
                 enable_mp: bool = False, sequence_parallel: bool = False):
        super().__init__()
        hidden = n_state * 4
        self.attn_ln = nn.LayerNorm(n_state)
        self.mlp_ln = nn.LayerNorm(n_state)

        self.attn = MultiHeadAttention(n_state, n_head)
        self.mlp = nn.Sequential(
            Linear(n_state, hidden), nn.GELU(), Linear(hidden, n_state)
        )

    def forward(
        self,
        x: Tensor,
        cu_seqlens = None
    ):
        attn_out = self.attn(self.attn_ln(x), cu_seqlens=cu_seqlens)
        x = x + attn_out
        return x + self.mlp(self.mlp_ln(x))
286
+
287
+
288
class WhisperEncoder(nn.Module):
    """Whisper-style audio encoder operating on packed variable-length batches.

    Pipeline: two Conv1d layers (2x time downsample) + sinusoidal positions ->
    n_layer attention blocks over windowed packed sequences -> AvgPool1d
    (another 2x downsample) -> LayerNorm -> projection to `output_dim`, with
    learned BOS/EOS embeddings inserted around each sample's token span.
    """

    def __init__(
        self,
        n_mels: int,              # mel bins of the input spectrogram
        n_ctx: int,               # max positions for the sinusoidal table
        n_state: int,             # transformer width
        n_head: int,              # attention heads
        n_layer: int,             # number of attention blocks
        n_window: int = 1500,     # max attention window length (post-conv frames)
        output_dim: int = 512,    # projected output feature size
        grad_checkpointing: bool = False,  # stored only; not used in forward here
        enable_mp: bool = False,
        audio_sequence_parallel: bool = False,
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
        self.n_layer = n_layer
        self.n_mels = n_mels

        self.blocks = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head, enable_mp=enable_mp, sequence_parallel=audio_sequence_parallel)
             for _ in range(n_layer)]
        )
        self.ln_post = nn.LayerNorm(n_state)
        self.avg_pooler = nn.AvgPool1d(2, stride=2)

        self.proj = torch.nn.Linear(n_state, output_dim)

        # Row 0 = BOS embedding, row 1 = EOS embedding (see forward()).
        self.audio_bos_eos_token = nn.Embedding(2, output_dim)

        self.output_dim = output_dim
        self.grad_checkpointing = grad_checkpointing
        self.enable_mp = enable_mp
        self.n_head = n_head
        self.n_state = n_state
        self.n_window = n_window

        self.audio_sequence_parallel = audio_sequence_parallel

        self.tp_world_size = 1

        self.set_audio_sync()

    def set_audio_sync(self):
        # Tag all non-transformer parameters (convs, norms, proj, embeddings)
        # with an `audio_sync` attribute — presumably consumed by an external
        # distributed-training hook; not used elsewhere in this file.
        for name, param in self.named_parameters():
            if not name.startswith("blocks"):
                setattr(param, "audio_sync", True)

    def forward(self, x_list: List[Tensor], audio_mellens:List[int], audio_aftercnnlens:List[int], audio_seqlens:List[int]):
        """
        x : torch.Tensor, shape = (n_mels, n_ctx)
            the mel spectrogram of the audio

        Args (per original calling convention):
            x_list: per-sample mel spectrograms, each (n_mels, T_mel).
            audio_mellens: mel-frame lengths per sample (unused here).
            audio_aftercnnlens: per-sample lengths after the conv stack.
            audio_seqlens: per-sample output token counts used for BOS/EOS
                placement — assumed to already include the +2 BOS/EOS slots
                (TODO confirm against the caller).

        Returns:
            (sum(audio_seqlens)-compatible, output_dim) tensor with BOS/EOS
            embeddings at each sample's boundary positions.
        """
        # Chunk each sample into windows of n_window*2 mel frames (the convs
        # halve time, so each chunk yields at most n_window positions), run
        # the conv stem, and add positional embeddings per chunk.
        aftercnn_x_list = []
        for each_x in x_list:
            each_x_split_list = each_x.split(self.n_window * 2, dim=1)
            for each_x_split in each_x_split_list:
                each_x_split = F.gelu(self.conv1(each_x_split))
                each_x_split = F.gelu(self.conv2(each_x_split))
                each_x_split = each_x_split.permute(1, 0) # L,D
                each_positional_embedding_split = self.positional_embedding[:each_x_split.shape[0]]
                aftercnn_x_list.append(each_x_split+each_positional_embedding_split.to(each_x_split.dtype))

        # Pack all chunks along dim 0 for varlen attention.
        x = torch.cat(aftercnn_x_list, dim=0)
        src_len = x.size(0)  # NOTE: computed but unused below

        # Re-derive the per-chunk lengths (each sample contributes
        # ceil(len / n_window) chunks) to build cumulative boundaries.
        output_list = []
        for item in audio_aftercnnlens:
            while item > self.n_window:
                output_list.append(self.n_window)
                item -= self.n_window
            output_list.append(item)

        cu_seqlens = list(accumulate(output_list, func=operator.add,initial=0))
        cu_seqlens = torch.Tensor(cu_seqlens).to(device=x.device, dtype=torch.int32)

        layer_id = 0
        for block in self.blocks:
            layer_id+=1  # NOTE: counter kept from a variant that branched on layer_id
            x = block(x, cu_seqlens=cu_seqlens)

        # 2x temporal average pooling, applied per sample (not per chunk).
        if self.avg_pooler:
            x_list = x.split(audio_aftercnnlens, dim=0)
            token_x_list = []
            for x in x_list:
                x = x.permute(1, 0)
                x = self.avg_pooler(x)
                x = x.permute(1, 0)
                token_x_list.append(x)
            x = torch.cat(token_x_list, dim=0)

        x = self.ln_post(x)

        x = self.proj(x)

        # Allocate output with 2 extra slots (BOS + EOS) per sample.
        output = torch.zeros(
            (x.size(0) + len(audio_seqlens) * 2, x.size(1)),
            device=x.device, dtype=x.dtype
        )

        # Boundary indices in the output layout; start of each span gets BOS,
        # last position of each span gets EOS.
        audio_seqlens_acc = list(accumulate(audio_seqlens, func=operator.add, initial=0))
        start_ids = torch.tensor(audio_seqlens_acc[:-1], device=x.device, dtype=torch.int32)
        end_ids = torch.tensor(audio_seqlens_acc[1:], device=x.device, dtype=torch.int32) - 1

        # All remaining (non-boundary) positions receive the encoded tokens
        # in order; scatter via a boolean mask.
        audio_tokens_mask = torch.ones(output.size(0), device=x.device, dtype=torch.bool)
        audio_tokens_mask[start_ids] = False
        audio_tokens_mask[end_ids] = False
        output[start_ids] = self.audio_bos_eos_token.weight[0].to(x.dtype)
        output[end_ids] = self.audio_bos_eos_token.weight[1].to(x.dtype)
        output[audio_tokens_mask] = x
        return output

    def lock(self, layers: int):
        # Freeze the conv stem and the first `layers` attention blocks
        # (e.g. for partial fine-tuning).
        self.conv1.requires_grad_(False)
        self.conv2.requires_grad_(False)
        for i in range(min(layers, len(self.blocks))):
            self.blocks[i].requires_grad_(False)
qwen_tts/inference/qwen3_tts_model.py ADDED
@@ -0,0 +1,874 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import base64
17
+ import io
18
+ import urllib.request
19
+ from dataclasses import dataclass
20
+ from typing import Any, Dict, List, Optional, Tuple, Union
21
+ from urllib.parse import urlparse
22
+
23
+ import librosa
24
+ import numpy as np
25
+ import soundfile as sf
26
+ import torch
27
+ from transformers import AutoConfig, AutoModel, AutoProcessor
28
+
29
+ from ..core.models import Qwen3TTSConfig, Qwen3TTSForConditionalGeneration, Qwen3TTSProcessor
30
+
31
+ AudioLike = Union[
32
+ str, # wav path, URL, base64
33
+ np.ndarray, # waveform (requires sr)
34
+ Tuple[np.ndarray, int], # (waveform, sr)
35
+ ]
36
+
37
+ MaybeList = Union[Any, List[Any]]
38
+
39
+
40
@dataclass
class VoiceClonePromptItem:
    """
    Container for one sample's voice-clone prompt information that can be fed to the model.

    Fields are aligned with `Qwen3TTSForConditionalGeneration.generate(..., voice_clone_prompt=...)`.
    """
    ref_code: Optional[torch.Tensor]  # (T, Q) or (T,) depending on tokenizer 25Hz/12Hz; None in x-vector-only mode
    ref_spk_embedding: torch.Tensor  # (D,) speaker embedding extracted from the reference audio
    x_vector_only_mode: bool  # clone from the speaker embedding alone (ignores ref_code/ref_text)
    icl_mode: bool  # in-context-learning mode; mutually exclusive with x_vector_only_mode
    ref_text: Optional[str] = None  # transcript of the reference audio; required for ICL mode
52
+
53
+
54
+ class Qwen3TTSModel:
55
+ """
56
+ A HuggingFace-style wrapper for Qwen3 TTS models (CustomVoice/VoiceDesign/Base) that provides:
57
+ - from_pretrained() initialization via AutoModel/AutoProcessor
58
+ - generation APIs for:
59
+ * CustomVoice: generate_custom_voice()
60
+ * VoiceDesign: generate_voice_design()
61
+ * Base: generate_voice_clone() + create_voice_clone_prompt()
62
+ - consistent output: (wavs: List[np.ndarray], sample_rate: int)
63
+
64
+ Notes:
65
+ - This wrapper expects the underlying model class to be `Qwen3TTSForConditionalGeneration`
66
+ - Language / speaker validation is done via model methods:
67
+ model.get_supported_languages(), model.get_supported_speakers()
68
+ """
69
+
70
+ def __init__(self, model: Qwen3TTSForConditionalGeneration, processor, generate_defaults: Optional[Dict[str, Any]] = None):
71
+ self.model = model
72
+ self.processor = processor
73
+ self.generate_defaults = generate_defaults or {}
74
+
75
+ self.device = getattr(model, "device", None)
76
+ if self.device is None:
77
+ try:
78
+ self.device = next(model.parameters()).device
79
+ except StopIteration:
80
+ self.device = torch.device("cpu")
81
+
82
+ @classmethod
83
+ def from_pretrained(
84
+ cls,
85
+ pretrained_model_name_or_path: str,
86
+ **kwargs,
87
+ ) -> "Qwen3TTSModel":
88
+ """
89
+ Load a Qwen3 TTS model and its processor in HuggingFace `from_pretrained` style.
90
+
91
+ This method:
92
+ 1) Loads config via AutoConfig (so your side can register model_type -> config/model).
93
+ 2) Loads the model via AutoModel.from_pretrained(...), forwarding `kwargs` unchanged.
94
+ 3) Loads the processor via AutoProcessor.from_pretrained(model_path).
95
+ 4) Loads optional `generate_config.json` from the model directory/repo snapshot if present.
96
+
97
+ Args:
98
+ pretrained_model_name_or_path (str):
99
+ HuggingFace repo id or local directory of the model.
100
+ **kwargs:
101
+ Forwarded as-is into `AutoModel.from_pretrained(...)`.
102
+ Typical examples: device_map="cuda:0", dtype=torch.bfloat16, attn_implementation="flash_attention_2".
103
+
104
+ Returns:
105
+ Qwen3TTSModel:
106
+ Wrapper instance containing `model`, `processor`, and generation defaults.
107
+ """
108
+ AutoConfig.register("qwen3_tts", Qwen3TTSConfig)
109
+ AutoModel.register(Qwen3TTSConfig, Qwen3TTSForConditionalGeneration)
110
+ AutoProcessor.register(Qwen3TTSConfig, Qwen3TTSProcessor)
111
+
112
+ model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
113
+ if not isinstance(model, Qwen3TTSForConditionalGeneration):
114
+ raise TypeError(
115
+ f"AutoModel returned {type(model)}, expected Qwen3TTSForConditionalGeneration. "
116
+ )
117
+
118
+ processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True,)
119
+
120
+ generate_defaults = model.generate_config
121
+ return cls(model=model, processor=processor, generate_defaults=generate_defaults)
122
+
123
+ def _supported_languages_set(self) -> Optional[set]:
124
+ langs = getattr(self.model, "get_supported_languages", None)
125
+ if callable(langs):
126
+ v = langs()
127
+ if v is None:
128
+ return None
129
+ return set([str(x).lower() for x in v])
130
+ return None
131
+
132
+ def _supported_speakers_set(self) -> Optional[set]:
133
+ spks = getattr(self.model, "get_supported_speakers", None)
134
+ if callable(spks):
135
+ v = spks()
136
+ if v is None:
137
+ return None
138
+ return set([str(x).lower() for x in v])
139
+ return None
140
+
141
+ def _validate_languages(self, languages: List[str]) -> None:
142
+ """
143
+ Validate that requested languages are supported by the model.
144
+
145
+ Args:
146
+ languages (List[str]): Language names for each sample.
147
+
148
+ Raises:
149
+ ValueError: If any language is not supported.
150
+ """
151
+ supported = self._supported_languages_set()
152
+ if supported is None:
153
+ return
154
+
155
+ bad = []
156
+ for lang in languages:
157
+ if lang is None:
158
+ bad.append(lang)
159
+ continue
160
+ if str(lang).lower() not in supported:
161
+ bad.append(lang)
162
+ if bad:
163
+ raise ValueError(f"Unsupported languages: {bad}. Supported: {sorted(supported)}")
164
+
165
+ def _validate_speakers(self, speakers: List[Optional[str]]) -> None:
166
+ """
167
+ Validate that requested speakers are supported by the Instruct model.
168
+
169
+ Args:
170
+ speakers (List[Optional[str]]): Speaker names for each sample.
171
+
172
+ Raises:
173
+ ValueError: If any speaker is not supported.
174
+ """
175
+ supported = self._supported_speakers_set()
176
+ if supported is None:
177
+ return
178
+
179
+ bad = []
180
+ for spk in speakers:
181
+ if spk is None or spk == "":
182
+ continue
183
+ if str(spk).lower() not in supported:
184
+ bad.append(spk)
185
+ if bad:
186
+ raise ValueError(f"Unsupported speakers: {bad}. Supported: {sorted(supported)}")
187
+
188
+ def _is_probably_base64(self, s: str) -> bool:
189
+ if s.startswith("data:audio"):
190
+ return True
191
+ if ("/" not in s and "\\" not in s) and len(s) > 256:
192
+ return True
193
+ return False
194
+
195
+ def _is_url(self, s: str) -> bool:
196
+ try:
197
+ u = urlparse(s)
198
+ return u.scheme in ("http", "https") and bool(u.netloc)
199
+ except Exception:
200
+ return False
201
+
202
+ def _decode_base64_to_wav_bytes(self, b64: str) -> bytes:
203
+ if "," in b64 and b64.strip().startswith("data:"):
204
+ b64 = b64.split(",", 1)[1]
205
+ return base64.b64decode(b64)
206
+
207
    def _load_audio_to_np(self, x: str) -> Tuple[np.ndarray, int]:
        """Load audio from a URL, base64 string, or local path.

        Returns (mono float32 waveform, original sample rate). URL and
        base64 inputs are decoded in-memory with soundfile; local paths go
        through librosa (mono=True, native sample rate).
        """
        if self._is_url(x):
            # Fetch the whole file into memory, then decode from the buffer.
            with urllib.request.urlopen(x) as resp:
                audio_bytes = resp.read()
            with io.BytesIO(audio_bytes) as f:
                audio, sr = sf.read(f, dtype="float32", always_2d=False)
        elif self._is_probably_base64(x):
            wav_bytes = self._decode_base64_to_wav_bytes(x)
            with io.BytesIO(wav_bytes) as f:
                audio, sr = sf.read(f, dtype="float32", always_2d=False)
        else:
            audio, sr = librosa.load(x, sr=None, mono=True)

        # soundfile may return multi-channel audio; average down to mono.
        if audio.ndim > 1:
            audio = np.mean(audio, axis=-1)

        return audio.astype(np.float32), int(sr)
224
+
225
+ def _normalize_audio_inputs(self, audios: Union[AudioLike, List[AudioLike]]) -> List[Tuple[np.ndarray, int]]:
226
+ """
227
+ Normalize audio inputs into a list of (waveform, sr).
228
+
229
+ Supported forms:
230
+ - str: wav path / URL / base64 audio string
231
+ - (np.ndarray, sr): waveform + sampling rate
232
+ - list of the above
233
+
234
+ Args:
235
+ audios:
236
+ Audio input(s).
237
+
238
+ Returns:
239
+ List[Tuple[np.ndarray, int]]:
240
+ List of (float32 waveform, original sr).
241
+
242
+ Raises:
243
+ ValueError: If a numpy waveform is provided without sr.
244
+ """
245
+ if isinstance(audios, list):
246
+ items = audios
247
+ else:
248
+ items = [audios]
249
+
250
+ out: List[Tuple[np.ndarray, int]] = []
251
+ for a in items:
252
+ if isinstance(a, str):
253
+ out.append(self._load_audio_to_np(a))
254
+ elif isinstance(a, tuple) and len(a) == 2 and isinstance(a[0], np.ndarray):
255
+ out.append((a[0].astype(np.float32), int(a[1])))
256
+ elif isinstance(a, np.ndarray):
257
+ raise ValueError("For numpy waveform input, pass a tuple (audio, sr).")
258
+ else:
259
+ raise TypeError(f"Unsupported audio input type: {type(a)}")
260
+ for i, a in enumerate(out):
261
+ if a[0].ndim > 1:
262
+ a[0] = np.mean(a[0], axis=-1).astype(np.float32)
263
+ out[i] = (a[0], a[1])
264
+ return out
265
+
266
+ def _ensure_list(self, x: MaybeList) -> List[Any]:
267
+ return x if isinstance(x, list) else [x]
268
+
269
+ def _build_assistant_text(self, text: str) -> str:
270
+ return f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
271
+
272
+ def _build_ref_text(self, text: str) -> str:
273
+ return f"<|im_start|>assistant\n{text}<|im_end|>\n"
274
+
275
+ def _build_instruct_text(self, instruct: str) -> str:
276
+ return f"<|im_start|>user\n{instruct}<|im_end|>\n"
277
+
278
+ def _tokenize_texts(self, texts: List[str]) -> List[torch.Tensor]:
279
+ input_ids = []
280
+ for text in texts:
281
+ input = self.processor(text=text, return_tensors="pt", padding=True)
282
+ input_id = input["input_ids"].to(self.device)
283
+ input_id = input_id.unsqueeze(0) if input_id.dim() == 1 else input_id
284
+ input_ids.append(input_id)
285
+ return input_ids
286
+
287
+ def _merge_generate_kwargs(
288
+ self,
289
+ non_streaming_mode: Optional[bool] = None,
290
+ do_sample: Optional[bool] = None,
291
+ top_k: Optional[int] = None,
292
+ top_p: Optional[float] = None,
293
+ temperature: Optional[float] = None,
294
+ repetition_penalty: Optional[float] = None,
295
+ subtalker_dosample: Optional[bool] = None,
296
+ subtalker_top_k: Optional[int] = None,
297
+ subtalker_top_p: Optional[float] = None,
298
+ subtalker_temperature: Optional[float] = None,
299
+ max_new_tokens: Optional[int] = None,
300
+ **kwargs,
301
+ ) -> Dict[str, Any]:
302
+ """
303
+ Merge user-provided generation arguments with defaults from `generate_config.json`.
304
+
305
+ Rule:
306
+ - If the user explicitly passes a value (not None), use it.
307
+ - Otherwise, use the value from generate_config.json if present.
308
+ - Otherwise, fall back to the hard defaults.
309
+
310
+ Args:
311
+ non_streaming_mode, do_sample, top_k, top_p, temperature, repetition_penalty,
312
+ subtalker_dosample, subtalker_top_k, subtalker_top_p, subtalker_temperature, max_new_tokens:
313
+ Common generation parameters.
314
+ **kwargs:
315
+ Other arguments forwarded to model.generate().
316
+
317
+ Returns:
318
+ Dict[str, Any]: Final kwargs to pass into model.generate().
319
+ """
320
+ hard_defaults = dict(
321
+ non_streaming_mode=False,
322
+ do_sample=True,
323
+ top_k=50,
324
+ top_p=1.0,
325
+ temperature=0.9,
326
+ repetition_penalty=1.05,
327
+ subtalker_dosample=True,
328
+ subtalker_top_k=50,
329
+ subtalker_top_p=1.0,
330
+ subtalker_temperature=0.9,
331
+ max_new_tokens=2048,
332
+ )
333
+
334
+ def pick(name: str, user_val: Any) -> Any:
335
+ if user_val is not None:
336
+ return user_val
337
+ if name in self.generate_defaults:
338
+ return self.generate_defaults[name]
339
+ return hard_defaults[name]
340
+
341
+ merged = dict(kwargs)
342
+ merged.update(
343
+ non_streaming_mode=pick("non_streaming_mode", non_streaming_mode),
344
+ do_sample=pick("do_sample", do_sample),
345
+ top_k=pick("top_k", top_k),
346
+ top_p=pick("top_p", top_p),
347
+ temperature=pick("temperature", temperature),
348
+ repetition_penalty=pick("repetition_penalty", repetition_penalty),
349
+ subtalker_dosample=pick("subtalker_dosample", subtalker_dosample),
350
+ subtalker_top_k=pick("subtalker_top_k", subtalker_top_k),
351
+ subtalker_top_p=pick("subtalker_top_p", subtalker_top_p),
352
+ subtalker_temperature=pick("subtalker_temperature", subtalker_temperature),
353
+ max_new_tokens=pick("max_new_tokens", max_new_tokens),
354
+ )
355
+ return merged
356
+
357
+ # voice clone model
358
    @torch.inference_mode()
    def create_voice_clone_prompt(
        self,
        ref_audio: Union[AudioLike, List[AudioLike]],
        ref_text: Optional[Union[str, List[Optional[str]]]] = None,
        x_vector_only_mode: Union[bool, List[bool]] = False,
    ) -> List[VoiceClonePromptItem]:
        """
        Build voice-clone prompt items from reference audio (and optionally reference text) using Base model.

        Modes:
            - x_vector_only_mode=True:
                Only speaker embedding is used to clone voice; ref_text/ref_code are ignored.
                This is mutually exclusive with ICL.
            - x_vector_only_mode=False:
                ICL mode is enabled automatically (icl_mode=True). In this case ref_text is required,
                because the model continues/conditions on the reference text + reference speech codes.

        Batch behavior:
            - ref_audio can be a single item or a list.
            - ref_text and x_vector_only_mode can be scalars or lists.
            - If any of them are lists with length > 1, lengths must match.

        Audio input:
            - str: local wav path / URL / base64
            - (np.ndarray, sr): waveform + sampling rate

        Args:
            ref_audio:
                Reference audio(s) used to extract:
                - ref_code via `model.speech_tokenizer.encode(...)`
                - ref_spk_embedding via `model.extract_speaker_embedding(...)` (resampled to the
                  speaker encoder's sample rate)
            ref_text:
                Reference transcript(s). Required when x_vector_only_mode=False (ICL mode).
            x_vector_only_mode:
                Whether to use speaker embedding only. If False, ICL mode will be used.

        Returns:
            List[VoiceClonePromptItem]:
                List of prompt items that can be converted into `voice_clone_prompt` dict.

        Raises:
            ValueError:
                - If the loaded model is not a Base model.
                - If x_vector_only_mode=False but ref_text is missing.
                - If batch lengths mismatch.
        """
        # Only the Base model exposes the speech tokenizer / speaker encoder
        # needed for cloning.
        if self.model.tts_model_type != "base":
            raise ValueError(
                f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
                f"tts_model_size: {self.model.tts_model_size}\n"
                f"tts_model_type: {self.model.tts_model_type}\n"
                "does not support create_voice_clone_prompt, Please check Model Card or Readme for more details."
            )

        # Broadcast scalar ref_text / x_vector_only_mode across the audio batch.
        ref_audio_list = self._ensure_list(ref_audio)
        ref_text_list = self._ensure_list(ref_text) if isinstance(ref_text, list) else ([ref_text] * len(ref_audio_list))
        xvec_list = self._ensure_list(x_vector_only_mode) if isinstance(x_vector_only_mode, list) else ([x_vector_only_mode] * len(ref_audio_list))

        if len(ref_text_list) != len(ref_audio_list) or len(xvec_list) != len(ref_audio_list):
            raise ValueError(
                f"Batch size mismatch: ref_audio={len(ref_audio_list)}, ref_text={len(ref_text_list)}, x_vector_only_mode={len(xvec_list)}"
            )

        normalized = self._normalize_audio_inputs(ref_audio_list)

        ref_wavs_for_code: List[np.ndarray] = []
        ref_sr_for_code: List[int] = []
        for wav, sr in normalized:
            ref_wavs_for_code.append(wav)
            ref_sr_for_code.append(sr)

        # Batched encode when all references share one sample rate; otherwise
        # encode one by one at each native rate.
        if len(set(ref_sr_for_code)) == 1:
            enc = self.model.speech_tokenizer.encode(ref_wavs_for_code, sr=ref_sr_for_code[0])
            ref_codes = enc.audio_codes
        else:
            ref_codes = []
            for wav, sr in normalized:
                ref_codes.append(self.model.speech_tokenizer.encode(wav, sr=sr).audio_codes[0])

        items: List[VoiceClonePromptItem] = []
        for i, ((wav, sr), code, rtext, xvec_only) in enumerate(zip(normalized, ref_codes, ref_text_list, xvec_list)):
            # ICL mode needs the reference transcript to condition on.
            if not xvec_only:
                if rtext is None or rtext == "":
                    raise ValueError(f"ref_text is required when x_vector_only_mode=False (ICL mode). Bad index={i}")

            # Speaker embedding is always extracted, in both modes, at the
            # speaker encoder's expected sample rate.
            wav_resample = wav
            if sr != self.model.speaker_encoder_sample_rate:
                wav_resample = librosa.resample(y=wav_resample.astype(np.float32),
                                                orig_sr=int(sr),
                                                target_sr=self.model.speaker_encoder_sample_rate)

            spk_emb = self.model.extract_speaker_embedding(audio=wav_resample,
                                                           sr=self.model.speaker_encoder_sample_rate)

            items.append(
                VoiceClonePromptItem(
                    ref_code=None if xvec_only else code,
                    ref_spk_embedding=spk_emb,
                    x_vector_only_mode=bool(xvec_only),
                    icl_mode=bool(not xvec_only),
                    ref_text=rtext,
                )
            )
        return items
462
+
463
+ def _prompt_items_to_voice_clone_prompt(self, items: List[VoiceClonePromptItem]) -> Dict[str, Any]:
464
+ return dict(
465
+ ref_code=[it.ref_code for it in items],
466
+ ref_spk_embedding=[it.ref_spk_embedding for it in items],
467
+ x_vector_only_mode=[it.x_vector_only_mode for it in items],
468
+ icl_mode=[it.icl_mode for it in items],
469
+ )
470
+
471
# voice clone model
@torch.no_grad()
def generate_voice_clone(
    self,
    text: Union[str, List[str]],
    language: Union[str, List[str]] = None,
    ref_audio: Optional[Union[AudioLike, List[AudioLike]]] = None,
    ref_text: Optional[Union[str, List[Optional[str]]]] = None,
    x_vector_only_mode: Union[bool, List[bool]] = False,
    voice_clone_prompt: Optional[Union[Dict[str, Any], List[VoiceClonePromptItem]]] = None,
    **kwargs,
) -> Tuple[List[np.ndarray], int]:
    """
    Voice clone speech using the Base model.

    You can provide either:
    - (ref_audio, ref_text, x_vector_only_mode) and let this method build the prompt, OR
    - `VoiceClonePromptItem` returned by `create_voice_clone_prompt`, OR
    - a list of `VoiceClonePromptItem` returned by `create_voice_clone_prompt`.

    `ref_audio` Supported forms:
    - str: wav path / URL / base64 audio string
    - (np.ndarray, sr): waveform + sampling rate
    - list of the above

    Input flexibility:
    - text/language can be scalar or list.
    - prompt can be single or batch.
    - If batch mode (len(text)>1), lengths must match.

    Args:
        text:
            Text(s) to synthesize.
        language:
            Language(s) for each sample. Defaults to "Auto" per sample when omitted.
        ref_audio:
            Reference audio(s) for prompt building. Required if voice_clone_prompt is not provided.
        ref_text:
            Reference text(s) used for ICL mode (required when x_vector_only_mode=False).
        x_vector_only_mode:
            If True, only speaker embedding is used (ignores ref_text/ref_code).
            If False, ICL mode is used automatically.
        voice_clone_prompt:
            list[VoiceClonePromptItem] from `create_voice_clone_prompt`, or an
            already-batched prompt dict (in which case no ref_ids are built here).
        non_streaming_mode:
            Using non-streaming text input, this option currently only simulates streaming text input when set to `false`,
            rather than enabling true streaming input or streaming generation.
        do_sample:
            Whether to use sampling, recommended to be set to `true` for most use cases.
        top_k:
            Top-k sampling parameter.
        top_p:
            Top-p sampling parameter.
        temperature:
            Sampling temperature; higher => more random.
        repetition_penalty:
            Penalty to reduce repeated tokens/codes.
        subtalker_dosample:
            Sampling switch for the sub-talker (only valid for qwen3-tts-tokenizer-v2) if applicable.
        subtalker_top_k:
            Top-k for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
        subtalker_top_p:
            Top-p for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
        subtalker_temperature:
            Temperature for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
        max_new_tokens:
            Maximum number of new codec tokens to generate.
        **kwargs:
            Any other keyword arguments supported by HuggingFace Transformers `generate()` can be passed.
            They will be forwarded to the underlying `Qwen3TTSForConditionalGeneration.generate(...)`.

    Returns:
        Tuple[List[np.ndarray], int]:
            (wavs, sample_rate)

    Raises:
        ValueError:
            If batch sizes mismatch or required prompt inputs are missing.
    """
    # Voice cloning is only supported by the "base" model variant.
    if self.model.tts_model_type != "base":
        raise ValueError(
            f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
            f"tts_model_size: {self.model.tts_model_size}\n"
            f"tts_model_type: {self.model.tts_model_type}\n"
            "does not support generate_voice_clone, Please check Model Card or Readme for more details."
        )

    texts = self._ensure_list(text)
    # Normalize `language` to one entry per text; missing language => "Auto".
    languages = self._ensure_list(language) if isinstance(language, list) else ([language] * len(texts) if language is not None else ["Auto"] * len(texts))
    if len(languages) == 1 and len(texts) > 1:
        # Broadcast a single language across the batch.
        languages = languages * len(texts)
    if len(texts) != len(languages):
        raise ValueError(f"Batch size mismatch: text={len(texts)}, language={len(languages)}")

    self._validate_languages(languages)

    if voice_clone_prompt is None:
        # Build the prompt from raw reference audio/text.
        if ref_audio is None:
            raise ValueError("Either `voice_clone_prompt` or `ref_audio` must be provided.")
        prompt_items = self.create_voice_clone_prompt(ref_audio=ref_audio, ref_text=ref_text, x_vector_only_mode=x_vector_only_mode)
        if len(prompt_items) == 1 and len(texts) > 1:
            # Broadcast a single reference voice across the batch.
            prompt_items = prompt_items * len(texts)
        if len(prompt_items) != len(texts):
            raise ValueError(f"Batch size mismatch: prompt={len(prompt_items)}, text={len(texts)}")
        voice_clone_prompt_dict = self._prompt_items_to_voice_clone_prompt(prompt_items)
        ref_texts_for_ids = [it.ref_text for it in prompt_items]
    else:
        if isinstance(voice_clone_prompt, list):
            # Pre-built prompt items: batch-broadcast and collate just like above.
            prompt_items = voice_clone_prompt
            if len(prompt_items) == 1 and len(texts) > 1:
                prompt_items = prompt_items * len(texts)
            if len(prompt_items) != len(texts):
                raise ValueError(f"Batch size mismatch: prompt={len(prompt_items)}, text={len(texts)}")
            voice_clone_prompt_dict = self._prompt_items_to_voice_clone_prompt(prompt_items)
            ref_texts_for_ids = [it.ref_text for it in prompt_items]
        else:
            # Caller supplied an already-batched prompt dict; no ref texts to
            # tokenize, so ref_ids stays None below.
            voice_clone_prompt_dict = voice_clone_prompt
            ref_texts_for_ids = None

    input_texts = [self._build_assistant_text(t) for t in texts]
    input_ids = self._tokenize_texts(input_texts)

    ref_ids = None
    if ref_texts_for_ids is not None:
        # Tokenize per-sample reference transcripts (None/"" => no ref ids for
        # that sample, e.g. x-vector-only mode).
        ref_ids = []
        for i, rt in enumerate(ref_texts_for_ids):
            if rt is None or rt == "":
                ref_ids.append(None)
            else:
                ref_tok = self._tokenize_texts([self._build_ref_text(rt)])[0]
                ref_ids.append(ref_tok)

    gen_kwargs = self._merge_generate_kwargs(**kwargs)

    talker_codes_list, _ = self.model.generate(
        input_ids=input_ids,
        ref_ids=ref_ids,
        voice_clone_prompt=voice_clone_prompt_dict,
        languages=languages,
        **gen_kwargs,
    )

    # ICL mode: prepend the reference codes so the decoder is conditioned on the
    # reference audio; the reference portion of the waveform is trimmed below.
    codes_for_decode = []
    for i, codes in enumerate(talker_codes_list):
        ref_code_list = voice_clone_prompt_dict.get("ref_code", None)
        if ref_code_list is not None and ref_code_list[i] is not None:
            codes_for_decode.append(torch.cat([ref_code_list[i].to(codes.device), codes], dim=0))
        else:
            codes_for_decode.append(codes)

    wavs_all, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in codes_for_decode])

    wavs_out: List[np.ndarray] = []
    for i, wav in enumerate(wavs_all):
        ref_code_list = voice_clone_prompt_dict.get("ref_code", None)
        if ref_code_list is not None and ref_code_list[i] is not None:
            # Cut the leading reference segment, located proportionally to the
            # share of reference codes in the decoded code sequence.
            ref_len = int(ref_code_list[i].shape[0])
            total_len = int(codes_for_decode[i].shape[0])
            cut = int(ref_len / max(total_len, 1) * wav.shape[0])
            wavs_out.append(wav[cut:])
        else:
            wavs_out.append(wav)

    return wavs_out, fs
635
+
636
# voice design model
@torch.no_grad()
def generate_voice_design(
    self,
    text: Union[str, List[str]],
    instruct: Union[str, List[str]],
    language: Union[str, List[str]] = None,
    **kwargs,
) -> Tuple[List[np.ndarray], int]:
    """
    Synthesize speech with the VoiceDesign model, steering the voice/style via
    natural-language instructions.

    Args:
        text:
            Text(s) to synthesize; scalar or list.
        instruct:
            Instruction(s) describing the desired voice/style. An empty string
            (or None entry) is treated as "no instruction".
        language:
            Language(s) per sample; defaults to "Auto" for every sample when
            omitted. Singletons are broadcast across the batch.
        **kwargs:
            Sampling/generation options (do_sample, top_k, top_p, temperature,
            repetition_penalty, subtalker_* options for qwen3-tts-tokenizer-v2,
            max_new_tokens, non_streaming_mode, ...) plus any keyword argument
            accepted by HuggingFace Transformers `generate()`. All are merged
            with the wrapper defaults and forwarded to
            `Qwen3TTSForConditionalGeneration.generate(...)`.

    Returns:
        Tuple[List[np.ndarray], int]:
            (wavs, sample_rate)

    Raises:
        ValueError:
            If the loaded model is not a voice_design model, if batch sizes
            mismatch, or if a language is unsupported.
    """
    # Only the voice_design variant can honor style instructions this way.
    if self.model.tts_model_type != "voice_design":
        raise ValueError(
            f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
            f"tts_model_size: {self.model.tts_model_size}\n"
            f"tts_model_type: {self.model.tts_model_type}\n"
            "does not support generate_voice_design, Please check Model Card or Readme for more details."
        )

    texts = self._ensure_list(text)
    if isinstance(language, list):
        languages = self._ensure_list(language)
    elif language is not None:
        languages = [language] * len(texts)
    else:
        languages = ["Auto"] * len(texts)
    instructs = self._ensure_list(instruct)

    # Broadcast singleton language/instruct across the batch.
    batch = len(texts)
    if batch > 1:
        if len(languages) == 1:
            languages = languages * batch
        if len(instructs) == 1:
            instructs = instructs * batch

    if not (len(texts) == len(languages) == len(instructs)):
        raise ValueError(f"Batch size mismatch: text={len(texts)}, language={len(languages)}, instruct={len(instructs)}")

    self._validate_languages(languages)

    input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts])

    # None/"" instruction => no instruct ids for that sample.
    instruct_ids: List[Optional[torch.Tensor]] = []
    for entry in instructs:
        if entry is None or entry == "":
            instruct_ids.append(None)
        else:
            instruct_ids.append(self._tokenize_texts([self._build_instruct_text(entry)])[0])

    gen_kwargs = self._merge_generate_kwargs(**kwargs)

    talker_codes_list, _ = self.model.generate(
        input_ids=input_ids,
        instruct_ids=instruct_ids,
        languages=languages,
        **gen_kwargs,
    )

    decode_payload = [{"audio_codes": codes} for codes in talker_codes_list]
    wavs, fs = self.model.speech_tokenizer.decode(decode_payload)
    return wavs, fs
728
+
729
# custom voice model
@torch.no_grad()
def generate_custom_voice(
    self,
    text: Union[str, List[str]],
    speaker: Union[str, List[str]],
    language: Union[str, List[str]] = None,
    instruct: Optional[Union[str, List[str]]] = None,
    **kwargs,
) -> Tuple[List[np.ndarray], int]:
    """
    Generate speech with the CustomVoice model using a predefined speaker id, optionally controlled by instruction text.

    Args:
        text:
            Text(s) to synthesize.
        language:
            Language(s) for each sample. Defaults to "Auto" per sample when omitted.
        speaker:
            Speaker name(s). Will be validated against `model.get_supported_speakers()` (case-insensitive).
        instruct:
            Optional instruction(s). If None, treated as empty (no instruction).
            Ignored for the 0.6B ("0b6") model, which does not support instructions.
        non_streaming_mode:
            Using non-streaming text input, this option currently only simulates streaming text input when set to `false`,
            rather than enabling true streaming input or streaming generation.
        do_sample:
            Whether to use sampling, recommended to be set to `true` for most use cases.
        top_k:
            Top-k sampling parameter.
        top_p:
            Top-p sampling parameter.
        temperature:
            Sampling temperature; higher => more random.
        repetition_penalty:
            Penalty to reduce repeated tokens/codes.
        subtalker_dosample:
            Sampling switch for the sub-talker (only valid for qwen3-tts-tokenizer-v2) if applicable.
        subtalker_top_k:
            Top-k for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
        subtalker_top_p:
            Top-p for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
        subtalker_temperature:
            Temperature for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
        max_new_tokens:
            Maximum number of new codec tokens to generate.
        **kwargs:
            Any other keyword arguments supported by HuggingFace Transformers `generate()` can be passed.
            They will be forwarded to the underlying `Qwen3TTSForConditionalGeneration.generate(...)`.

    Returns:
        Tuple[List[np.ndarray], int]:
            (wavs, sample_rate)

    Raises:
        ValueError:
            If any speaker/language is unsupported or batch sizes mismatch.
    """
    if self.model.tts_model_type != "custom_voice":
        raise ValueError(
            f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
            f"tts_model_size: {self.model.tts_model_size}\n"
            f"tts_model_type: {self.model.tts_model_type}\n"
            "does not support generate_custom_voice, Please check Model Card or Readme for more details."
        )

    texts = self._ensure_list(text)
    languages = self._ensure_list(language) if isinstance(language, list) else ([language] * len(texts) if language is not None else ["Auto"] * len(texts))
    speakers = self._ensure_list(speaker)
    # BUGFIX: the original check was `tts_model_size in "0b6"`, a substring test
    # that also matched sizes "0", "b", "6", "0b", "b6" and even "" — use
    # equality so only the 0b6 model drops the instruction.
    if self.model.tts_model_size == "0b6":  # for 0b6 model, instruct is not supported
        instruct = None
    instructs = self._ensure_list(instruct) if isinstance(instruct, list) else ([instruct] * len(texts) if instruct is not None else [""] * len(texts))

    # Broadcast singleton language/speaker/instruct across the batch.
    if len(languages) == 1 and len(texts) > 1:
        languages = languages * len(texts)
    if len(speakers) == 1 and len(texts) > 1:
        speakers = speakers * len(texts)
    if len(instructs) == 1 and len(texts) > 1:
        instructs = instructs * len(texts)

    if not (len(texts) == len(languages) == len(speakers) == len(instructs)):
        raise ValueError(
            f"Batch size mismatch: text={len(texts)}, language={len(languages)}, speaker={len(speakers)}, instruct={len(instructs)}"
        )

    self._validate_languages(languages)
    self._validate_speakers(speakers)

    input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts])

    # None/"" instruction => no instruct ids for that sample.
    instruct_ids: List[Optional[torch.Tensor]] = []
    for ins in instructs:
        if ins is None or ins == "":
            instruct_ids.append(None)
        else:
            instruct_ids.append(self._tokenize_texts([self._build_instruct_text(ins)])[0])

    gen_kwargs = self._merge_generate_kwargs(**kwargs)

    talker_codes_list, _ = self.model.generate(
        input_ids=input_ids,
        instruct_ids=instruct_ids,
        languages=languages,
        speakers=speakers,
        **gen_kwargs,
    )

    wavs, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in talker_codes_list])
    return wavs, fs
837
+
838
+
839
def get_supported_speakers(self) -> Optional[List[str]]:
    """
    List supported speaker names for the current model.

    Thin convenience wrapper over `_supported_speakers_set()`: returns the
    names sorted (lowercased, as provided by the model), or None when the
    underlying model exposes no speaker constraints.

    Returns:
        Optional[List[str]]:
            Sorted speaker names, or None if unavailable.
    """
    speaker_set = self._supported_speakers_set()
    return None if speaker_set is None else sorted(speaker_set)
856
+
857
+
858
def get_supported_languages(self) -> Optional[List[str]]:
    """
    List supported language names for the current model.

    Thin convenience wrapper over `_supported_languages_set()`: returns the
    names sorted (lowercased, as provided by the model), or None when the
    underlying model exposes no language constraints.

    Returns:
        Optional[List[str]]:
            Sorted language names, or None if unavailable.
    """
    language_set = self._supported_languages_set()
    return None if language_set is None else sorted(language_set)
qwen_tts/inference/qwen3_tts_tokenizer.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import base64
17
+ import io
18
+ import urllib.request
19
+ from typing import List, Optional, Tuple, Union
20
+ from urllib.parse import urlparse
21
+
22
+ import librosa
23
+ import numpy as np
24
+ import soundfile as sf
25
+ import torch
26
+ from torch.nn.utils.rnn import pad_sequence
27
+ from transformers import AutoConfig, AutoFeatureExtractor, AutoModel
28
+
29
+ from ..core import (
30
+ Qwen3TTSTokenizerV1Config,
31
+ Qwen3TTSTokenizerV1Model,
32
+ Qwen3TTSTokenizerV2Config,
33
+ Qwen3TTSTokenizerV2Model,
34
+ )
35
+
36
+ AudioInput = Union[
37
+ str, # wav path, or base64 string
38
+ np.ndarray, # 1-D float array
39
+ List[str],
40
+ List[np.ndarray],
41
+ ]
42
+
43
+
44
class Qwen3TTSTokenizer:
    """
    A wrapper for Qwen3 TTS Tokenizer 25Hz/12Hz with HuggingFace-style loading.

    - from_pretrained(): loads speech tokenizer model via AutoModel and feature_extractor via AutoFeatureExtractor.
    - encode(): supports wav path(s), base64 audio string(s), numpy array(s).
    - decode(): accepts either the raw model encode output, or a minimal dict/list-of-dicts.

    Notes:
        - For numpy array input, you must pass `sr` so the audio can be resampled to model sample rate.
        - Returned audio is float32 numpy arrays and the output sample rate.
    """

    def __init__(self):
        # All populated by from_pretrained(); a bare instance cannot encode/decode.
        self.model = None
        self.feature_extractor = None
        self.config = None
        self.device = None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "Qwen3TTSTokenizer":
        """
        Initialize tokenizer with HuggingFace `from_pretrained` style.

        Args:
            pretrained_model_name_or_path (str):
                HuggingFace repo id or local directory.
            **kwargs (Any):
                Forwarded to `AutoModel.from_pretrained(...)` directly.
                Typical examples: device_map="cuda:0", dtype=torch.bfloat16, attn_implementation="eager".

        Returns:
            Qwen3TTSTokenizer:
                Initialized instance with `model`, `feature_extractor`, `config`.
        """
        inst = cls()

        # BUGFIX: register with exist_ok=True. Without it, transformers raises
        # ValueError on duplicate registration, so a second from_pretrained()
        # call in the same process would fail.
        AutoConfig.register("qwen3_tts_tokenizer_25hz", Qwen3TTSTokenizerV1Config, exist_ok=True)
        AutoModel.register(Qwen3TTSTokenizerV1Config, Qwen3TTSTokenizerV1Model, exist_ok=True)

        AutoConfig.register("qwen3_tts_tokenizer_12hz", Qwen3TTSTokenizerV2Config, exist_ok=True)
        AutoModel.register(Qwen3TTSTokenizerV2Config, Qwen3TTSTokenizerV2Model, exist_ok=True)

        inst.feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)
        inst.model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
        inst.config = inst.model.config

        inst.device = getattr(inst.model, "device", None)
        if inst.device is None:
            # fallback: infer from first parameter device
            try:
                inst.device = next(inst.model.parameters()).device
            except StopIteration:
                inst.device = torch.device("cpu")

        return inst

    def _is_probably_base64(self, s: str) -> bool:
        """Heuristically decide whether `s` is base64 audio rather than a path."""
        if s.startswith("data:audio"):
            return True
        # Heuristic: no filesystem path separators and long enough.
        if ("/" not in s and "\\" not in s) and len(s) > 256:
            return True
        return False

    def _is_url(self, s: str) -> bool:
        """Return True for well-formed http(s) URLs with a network location."""
        try:
            u = urlparse(s)
            return u.scheme in ("http", "https") and bool(u.netloc)
        except Exception:
            return False

    def _decode_base64_to_wav_bytes(self, b64: str) -> bytes:
        """Decode a base64 audio string (raw or `data:` URL) into raw bytes."""
        # Accept both "data:audio/wav;base64,...." and raw base64
        if "," in b64 and b64.strip().startswith("data:"):
            b64 = b64.split(",", 1)[1]
        return base64.b64decode(b64)

    def load_audio(
        self,
        x: str,
        target_sr: int,
    ) -> np.ndarray:
        """
        Load audio from wav path, URL or base64 string, then resample to target_sr.

        Args:
            x (str):
                A wav file path, an http(s) URL, or a base64 audio string (raw or data URL).
            target_sr (int):
                Target sampling rate.

        Returns:
            np.ndarray:
                1-D float32 waveform at target_sr.
        """
        if self._is_url(x):
            with urllib.request.urlopen(x) as resp:
                audio_bytes = resp.read()
            with io.BytesIO(audio_bytes) as f:
                audio, sr = sf.read(f, dtype="float32", always_2d=False)
        elif self._is_probably_base64(x):
            wav_bytes = self._decode_base64_to_wav_bytes(x)
            with io.BytesIO(wav_bytes) as f:
                audio, sr = sf.read(f, dtype="float32", always_2d=False)
        else:
            audio, sr = librosa.load(x, sr=None, mono=True)

        # Downmix multi-channel audio to mono.
        if audio.ndim > 1:
            audio = np.mean(audio, axis=-1)

        if sr != target_sr:
            audio = librosa.resample(y=audio, orig_sr=sr, target_sr=target_sr)

        return audio.astype(np.float32)

    def _normalize_audio_inputs(
        self,
        audios: AudioInput,
        sr: Optional[int],
    ) -> List[np.ndarray]:
        """
        Normalize all supported input types into a list of 1-D numpy float32 waveforms
        at `self.feature_extractor.sampling_rate`.

        Args:
            audios (AudioInput):
                - str: wav path OR base64 audio string
                - np.ndarray: raw waveform (sr must be provided)
                - list[str] / list[np.ndarray]
            sr (Optional[int]):
                Sampling rate for raw numpy input. Required if input is np.ndarray or list[np.ndarray].

        Returns:
            List[np.ndarray]:
                List of float32 waveforms resampled to model input SR.
        """
        target_sr = int(self.feature_extractor.sampling_rate)

        if isinstance(audios, (str, np.ndarray)):
            audios = [audios]

        if len(audios) == 0:
            return []

        # The first element decides the interpretation; mixed lists are rejected below.
        if isinstance(audios[0], str):
            # wav path list or base64 list
            return [self.load_audio(x, target_sr=target_sr) for x in audios]  # type: ignore[arg-type]

        # numpy list
        if sr is None:
            raise ValueError("For numpy waveform input, you must provide `sr` (original sampling rate).")

        out: List[np.ndarray] = []
        for a in audios:  # type: ignore[assignment]
            if not isinstance(a, np.ndarray):
                raise TypeError("Mixed input types are not supported. Use all paths/base64 or all numpy arrays.")
            if a.ndim > 1:
                a = np.mean(a, axis=-1)
            if int(sr) != target_sr:
                a = librosa.resample(y=a.astype(np.float32), orig_sr=int(sr), target_sr=target_sr)
            out.append(a.astype(np.float32))
        return out

    def encode(
        self,
        audios: AudioInput,
        sr: Optional[int] = None,
        return_dict: bool = True,
    ):
        """
        Batch-encode audio into discrete codes (and optional conditioning, depending on 25Hz/12Hz).

        Args:
            audios (AudioInput):
                Supported forms:
                - np.ndarray: waveform (requires sr)
                - list[np.ndarray]: waveforms (requires sr)
                - str: wav path OR base64 audio string
                - list[str]: wav paths and/or base64 strings
            sr (Optional[int], default=None):
                Original sampling rate for numpy waveform input.
            return_dict (bool, default=True):
                Forwarded to model.encode(...). If True, returns ModelOutput.

        Returns:
            25Hz:
                Qwen3TTSTokenizerV1EncoderOutput (if return_dict=True) with fields:
                - audio_codes: List[torch.LongTensor] each (codes_len,)
                - xvectors: List[torch.FloatTensor] each (xvector_dim,)
                - ref_mels: List[torch.FloatTensor] each (mel_len, mel_dim)
            12Hz:
                Qwen3TTSTokenizerV2EncoderOutput (if return_dict=True) with fields:
                - audio_codes: List[torch.LongTensor] each (codes_len, num_quantizers)

            If return_dict=False, returns the raw tuple from model.encode.
        """
        wavs = self._normalize_audio_inputs(audios, sr=sr)

        inputs = self.feature_extractor(
            raw_audio=wavs,
            sampling_rate=int(self.feature_extractor.sampling_rate),
            return_tensors="pt",
        )
        inputs = inputs.to(self.device).to(self.model.dtype)

        with torch.inference_mode():
            # model.encode expects (B, T) and (B, T)
            enc = self.model.encode(
                inputs["input_values"].squeeze(1),
                inputs["padding_mask"].squeeze(1),
                return_dict=return_dict,
            )
        return enc

    def decode(
        self,
        encoded,
    ) -> Tuple[List[np.ndarray], int]:
        """
        Decode back to waveform.

        Usage:
            1) Pass the raw output of `encode(...)` directly (recommended).
               - 25Hz: expects fields audio_codes, xvectors, ref_mels
               - 12Hz: expects field audio_codes
            2) Pass a dict or list[dict] (minimal form) for custom pipelines:
               - 25Hz dict keys: {"audio_codes", "xvectors", "ref_mels"}
               - 12Hz dict keys: {"audio_codes"}
               Values can be torch tensors or numpy arrays.

        Args:
            encoded (Any):
                - ModelOutput returned by `encode()`, OR
                - dict, OR
                - list[dict] (must be non-empty)

        Returns:
            Tuple[List[np.ndarray], int]:
                - wavs: list of 1-D float32 numpy arrays
                - sample_rate: int, model output sampling rate
        """
        model_type = self.model.get_model_type()

        def _to_tensor(x, dtype=None):
            # Accept torch tensors as-is; convert numpy/array-like, with optional cast.
            if isinstance(x, torch.Tensor):
                return x
            x = np.asarray(x)
            t = torch.from_numpy(x)
            if dtype is not None:
                t = t.to(dtype)
            return t

        # Normalize `encoded` into the same shapes as the official demo uses.
        if hasattr(encoded, "audio_codes"):
            # ModelOutput from encode()
            audio_codes_list = encoded.audio_codes
            xvectors_list = getattr(encoded, "xvectors", None)
            ref_mels_list = getattr(encoded, "ref_mels", None)
        elif isinstance(encoded, dict):
            audio_codes_list = encoded["audio_codes"]
            xvectors_list = encoded.get("xvectors", None)
            ref_mels_list = encoded.get("ref_mels", None)
        elif isinstance(encoded, list):
            # list of dicts; the first element decides which optional keys exist.
            # NOTE(review): an empty list raises IndexError here.
            audio_codes_list = [e["audio_codes"] for e in encoded]
            xvectors_list = [e["xvectors"] for e in encoded] if ("xvectors" in encoded[0]) else None
            ref_mels_list = [e["ref_mels"] for e in encoded] if ("ref_mels" in encoded[0]) else None
        else:
            raise TypeError("`encoded` must be an encode output, a dict, or a list of dicts.")

        # Ensure list form for per-sample tensors
        if isinstance(audio_codes_list, torch.Tensor):
            # Could be a single sample tensor or an already padded batch tensor.
            # NOTE(review): a pre-padded 25Hz batch (B, C) is indistinguishable
            # from a 12Hz single sample (C, Q) here and would be unsqueezed too;
            # pass per-sample lists to avoid the ambiguity.
            t = audio_codes_list
            if t.dim() == 1:
                # 25Hz single sample: (C,) -> (1, C)
                t = t.unsqueeze(0)
            elif t.dim() == 2:
                # 12Hz single sample: (C, Q) -> (1, C, Q)
                t = t.unsqueeze(0)
            audio_codes_padded = t.to(self.device)
        else:
            # List[Tensor/np]
            audio_codes_list = [_to_tensor(c, dtype=torch.long) for c in audio_codes_list]
            audio_codes_padded = pad_sequence(audio_codes_list, batch_first=True, padding_value=0).to(self.device)

        with torch.inference_mode():
            if model_type == "qwen3_tts_tokenizer_25hz":
                if xvectors_list is None or ref_mels_list is None:
                    raise ValueError("25Hz decode requires `xvectors` and `ref_mels`.")

                if isinstance(xvectors_list, torch.Tensor):
                    xvectors_batch = xvectors_list
                    if xvectors_batch.dim() == 1:  # (D,) -> (1, D)
                        xvectors_batch = xvectors_batch.unsqueeze(0)
                    xvectors_batch = xvectors_batch.to(self.device).to(self.model.dtype)
                else:
                    xvectors_list = [_to_tensor(x, dtype=torch.float32) for x in xvectors_list]
                    xvectors_batch = torch.stack(xvectors_list, dim=0).to(self.device).to(self.model.dtype)

                if isinstance(ref_mels_list, torch.Tensor):
                    ref_mels_padded = ref_mels_list
                    if ref_mels_padded.dim() == 2:  # (T, M) -> (1, T, M)
                        ref_mels_padded = ref_mels_padded.unsqueeze(0)
                    ref_mels_padded = ref_mels_padded.to(self.device).to(self.model.dtype)
                else:
                    ref_mels_list = [_to_tensor(m, dtype=torch.float32) for m in ref_mels_list]
                    ref_mels_padded = pad_sequence(ref_mels_list, batch_first=True, padding_value=0).to(self.device).to(self.model.dtype)

                dec = self.model.decode(audio_codes_padded, xvectors_batch, ref_mels_padded, return_dict=True)
                wav_tensors = dec.audio_values

            elif model_type == "qwen3_tts_tokenizer_12hz":
                dec = self.model.decode(audio_codes_padded, return_dict=True)
                wav_tensors = dec.audio_values

            else:
                raise ValueError(f"Unknown model type: {model_type}")

        wavs = [w.to(torch.float32).detach().cpu().numpy() for w in wav_tensors]
        return wavs, int(self.model.get_output_sample_rate())

    def get_model_type(self) -> str:
        """
        Get the underlying tokenizer model type.

        Returns:
            str: Model type string from `self.model.config.model_type`
                 (e.g. "qwen3_tts_tokenizer_25hz" / "qwen3_tts_tokenizer_12hz").
        """
        return self.model.get_model_type()

    def get_input_sample_rate(self) -> int:
        """
        Get the expected input sample rate for encoding.

        Returns:
            int: Input sample rate (Hz).
        """
        return int(self.model.get_input_sample_rate())

    def get_output_sample_rate(self) -> int:
        """
        Get the output sample rate for decoded waveforms.

        Returns:
            int: Output sample rate (Hz).
        """
        return int(self.model.get_output_sample_rate())

    def get_encode_downsample_rate(self) -> int:
        """
        Get the encoder downsample rate (waveform samples per code step).

        Returns:
            int: Encode downsample rate.
        """
        return int(self.model.get_encode_downsample_rate())

    def get_decode_upsample_rate(self) -> int:
        """
        Get the decoder upsample rate (waveform samples per code step).

        Returns:
            int: Decode upsample rate.
        """
        return int(self.model.get_decode_upsample_rate())
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Qwen3-TTS Dependencies for HuggingFace Spaces
2
+ transformers==4.57.3
3
+ accelerate==1.12.0
4
+ einops
5
+ gradio
6
+ librosa
7
+ torchaudio
8
+ soundfile
9
+ sox
10
+ onnxruntime
11
+ spaces
12
+ torch
13
+ numpy