littlebird13 multimodalart HF Staff committed on
Commit
8a13284
·
verified ·
1 Parent(s): bb80b9a

flash-attention-3 (#16)

Browse files

- Update app.py (04a9ea9b96182c6852b961b3d43de4386cbf3c39)
- Upload 18 files (850bacaf6cc310d95a68f832b6c0e446d936a049)
- Update requirements.txt (6dc80d8c9808285dcc9ba2425e73901988ea0cc7)


Co-authored-by: Apolinário from multimodal AI art <multimodalart@users.noreply.huggingface.co>

app.py CHANGED
@@ -8,39 +8,94 @@ import spaces
8
  import gradio as gr
9
  import numpy as np
10
  import torch
11
- from huggingface_hub import snapshot_download
 
12
 
13
- from huggingface_hub import login
14
  HF_TOKEN = os.environ.get('HF_TOKEN')
15
  login(token=HF_TOKEN)
16
 
17
- # Global model holders - keyed by (model_type, model_size)
18
- loaded_models = {}
19
-
20
  # Model size options
21
  MODEL_SIZES = ["0.6B", "1.7B"]
22
 
 
 
 
 
 
 
23
 
24
  def get_model_path(model_type: str, model_size: str) -> str:
25
  """Get model path based on type and size."""
26
  return snapshot_download(f"Qwen/Qwen3-TTS-12Hz-{model_size}-{model_type}")
27
 
28
 
29
- def get_model(model_type: str, model_size: str):
30
- """Get or load a model by type and size."""
31
- global loaded_models
32
- key = (model_type, model_size)
33
- if key not in loaded_models:
34
- from qwen_tts import Qwen3TTSModel
35
- model_path = get_model_path(model_type, model_size)
36
- loaded_models[key] = Qwen3TTSModel.from_pretrained(
37
- model_path,
38
- device_map="cuda",
39
- dtype=torch.bfloat16,
40
- token=HF_TOKEN,
41
- # attn_implementation="flash_attention_2",
42
- )
43
- return loaded_models[key]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
 
46
  def _normalize_audio(wav, eps=1e-12, clip=True):
@@ -89,15 +144,8 @@ def _audio_to_tuple(audio):
89
  return None
90
 
91
 
92
- # Speaker and language choices for CustomVoice model
93
- SPEAKERS = [
94
- "Aiden", "Dylan", "Eric", "Ono_anna", "Ryan", "Serena", "Sohee", "Uncle_fu", "Vivian"
95
- ]
96
- LANGUAGES = ["Auto", "Chinese", "English", "Japanese", "Korean", "French", "German", "Spanish", "Portuguese", "Russian"]
97
-
98
-
99
  @spaces.GPU(duration=60)
100
- def generate_voice_design(text, language, voice_description):
101
  """Generate speech using Voice Design model (1.7B only)."""
102
  if not text or not text.strip():
103
  return None, "Error: Text is required."
@@ -105,8 +153,7 @@ def generate_voice_design(text, language, voice_description):
105
  return None, "Error: Voice description is required."
106
 
107
  try:
108
- tts = get_model("VoiceDesign", "1.7B")
109
- wavs, sr = tts.generate_voice_design(
110
  text=text.strip(),
111
  language=language,
112
  instruct=voice_description.strip(),
@@ -119,7 +166,7 @@ def generate_voice_design(text, language, voice_description):
119
 
120
 
121
  @spaces.GPU(duration=60)
122
- def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector_only, model_size):
123
  """Generate speech using Base (Voice Clone) model."""
124
  if not target_text or not target_text.strip():
125
  return None, "Error: Target text is required."
@@ -132,7 +179,7 @@ def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector
132
  return None, "Error: Reference text is required when 'Use x-vector only' is not enabled."
133
 
134
  try:
135
- tts = get_model("Base", model_size)
136
  wavs, sr = tts.generate_voice_clone(
137
  text=target_text.strip(),
138
  language=language,
@@ -147,7 +194,7 @@ def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector
147
 
148
 
149
  @spaces.GPU(duration=60)
150
- def generate_custom_voice(text, language, speaker, instruct, model_size):
151
  """Generate speech using CustomVoice model."""
152
  if not text or not text.strip():
153
  return None, "Error: Text is required."
@@ -155,7 +202,7 @@ def generate_custom_voice(text, language, speaker, instruct, model_size):
155
  return None, "Error: Speaker is required."
156
 
157
  try:
158
- tts = get_model("CustomVoice", model_size)
159
  wavs, sr = tts.generate_custom_voice(
160
  text=text.strip(),
161
  language=language,
@@ -184,12 +231,10 @@ def build_ui():
184
  gr.Markdown(
185
  """
186
  # Qwen3-TTS Demo
187
-
188
  A unified Text-to-Speech demo featuring three powerful modes:
189
  - **Voice Design**: Create custom voices using natural language descriptions
190
  - **Voice Clone (Base)**: Clone any voice from a reference audio
191
  - **TTS (CustomVoice)**: Generate speech with predefined speakers and optional style instructions
192
-
193
  Built with [Qwen3-TTS](https://github.com/QwenLM/Qwen3-TTS) by Alibaba Qwen Team.
194
  """
195
  )
@@ -331,7 +376,6 @@ Built with [Qwen3-TTS](https://github.com/QwenLM/Qwen3-TTS) by Alibaba Qwen Team
331
  gr.Markdown(
332
  """
333
  ---
334
-
335
  **Note**: This demo uses HuggingFace Spaces Zero GPU. Each generation has a time limit.
336
  For longer texts, please split them into smaller segments.
337
  """
@@ -342,4 +386,4 @@ For longer texts, please split them into smaller segments.
342
 
343
  if __name__ == "__main__":
344
  demo = build_ui()
345
- demo.launch()
 
8
  import gradio as gr
9
  import numpy as np
10
  import torch
11
+ from huggingface_hub import snapshot_download, login
12
+ from qwen_tts import Qwen3TTSModel
13
 
 
14
  HF_TOKEN = os.environ.get('HF_TOKEN')
15
  login(token=HF_TOKEN)
16
 
 
 
 
17
  # Model size options
18
  MODEL_SIZES = ["0.6B", "1.7B"]
19
 
20
+ # Speaker and language choices for CustomVoice model
21
+ SPEAKERS = [
22
+ "Aiden", "Dylan", "Eric", "Ono_anna", "Ryan", "Serena", "Sohee", "Uncle_fu", "Vivian"
23
+ ]
24
+ LANGUAGES = ["Auto", "Chinese", "English", "Japanese", "Korean", "French", "German", "Spanish", "Portuguese", "Russian"]
25
+
26
 
27
  def get_model_path(model_type: str, model_size: str) -> str:
28
  """Get model path based on type and size."""
29
  return snapshot_download(f"Qwen/Qwen3-TTS-12Hz-{model_size}-{model_type}")
30
 
31
 
32
# ============================================================================
# GLOBAL MODEL LOADING - Load all models at startup
# ============================================================================
def _load_tts_model(model_type: str, model_size: str):
    """Announce and load one Qwen3-TTS checkpoint.

    The checkpoint directory is resolved through ``get_model_path`` and handed
    to ``Qwen3TTSModel.from_pretrained`` with the settings shared by every
    model in this app (CUDA device map, bfloat16, flash-attn3 kernel).

    Args:
        model_type: Checkpoint family, e.g. "VoiceDesign", "Base", "CustomVoice".
        model_size: Checkpoint size label, e.g. "0.6B" or "1.7B".

    Returns:
        The object returned by ``Qwen3TTSModel.from_pretrained``.
    """
    print(f"Loading {model_type} {model_size} model...")
    checkpoint_dir = get_model_path(model_type, model_size)
    return Qwen3TTSModel.from_pretrained(
        checkpoint_dir,
        device_map="cuda",
        dtype=torch.bfloat16,
        token=HF_TOKEN,
        attn_implementation="kernels-community/flash-attn3",
    )


print("Loading all models to CUDA...")

# Voice Design is only loaded in the 1.7B size.
voice_design_model = _load_tts_model("VoiceDesign", "1.7B")

# Base (Voice Clone) checkpoints, both sizes.
base_model_0_6b = _load_tts_model("Base", "0.6B")
base_model_1_7b = _load_tts_model("Base", "1.7B")

# CustomVoice checkpoints, both sizes.
custom_voice_model_0_6b = _load_tts_model("CustomVoice", "0.6B")
custom_voice_model_1_7b = _load_tts_model("CustomVoice", "1.7B")

print("All models loaded successfully!")

# Size label -> loaded model, used by the generate_* handlers.
BASE_MODELS = dict([
    ("0.6B", base_model_0_6b),
    ("1.7B", base_model_1_7b),
])

CUSTOM_VOICE_MODELS = dict([
    ("0.6B", custom_voice_model_0_6b),
    ("1.7B", custom_voice_model_1_7b),
])

# ============================================================================
99
 
100
 
101
  def _normalize_audio(wav, eps=1e-12, clip=True):
 
144
  return None
145
 
146
 
 
 
 
 
 
 
 
147
  @spaces.GPU(duration=60)
148
+ def generate_voice_design(text, language, voice_description, progress=gr.Progress(track_tqdm=True)):
149
  """Generate speech using Voice Design model (1.7B only)."""
150
  if not text or not text.strip():
151
  return None, "Error: Text is required."
 
153
  return None, "Error: Voice description is required."
154
 
155
  try:
156
+ wavs, sr = voice_design_model.generate_voice_design(
 
157
  text=text.strip(),
158
  language=language,
159
  instruct=voice_description.strip(),
 
166
 
167
 
168
  @spaces.GPU(duration=60)
169
+ def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector_only, model_size, progress=gr.Progress(track_tqdm=True)):
170
  """Generate speech using Base (Voice Clone) model."""
171
  if not target_text or not target_text.strip():
172
  return None, "Error: Target text is required."
 
179
  return None, "Error: Reference text is required when 'Use x-vector only' is not enabled."
180
 
181
  try:
182
+ tts = BASE_MODELS[model_size]
183
  wavs, sr = tts.generate_voice_clone(
184
  text=target_text.strip(),
185
  language=language,
 
194
 
195
 
196
  @spaces.GPU(duration=60)
197
+ def generate_custom_voice(text, language, speaker, instruct, model_size, progress=gr.Progress(track_tqdm=True)):
198
  """Generate speech using CustomVoice model."""
199
  if not text or not text.strip():
200
  return None, "Error: Text is required."
 
202
  return None, "Error: Speaker is required."
203
 
204
  try:
205
+ tts = CUSTOM_VOICE_MODELS[model_size]
206
  wavs, sr = tts.generate_custom_voice(
207
  text=text.strip(),
208
  language=language,
 
231
  gr.Markdown(
232
  """
233
  # Qwen3-TTS Demo
 
234
  A unified Text-to-Speech demo featuring three powerful modes:
235
  - **Voice Design**: Create custom voices using natural language descriptions
236
  - **Voice Clone (Base)**: Clone any voice from a reference audio
237
  - **TTS (CustomVoice)**: Generate speech with predefined speakers and optional style instructions
 
238
  Built with [Qwen3-TTS](https://github.com/QwenLM/Qwen3-TTS) by Alibaba Qwen Team.
239
  """
240
  )
 
376
  gr.Markdown(
377
  """
378
  ---
 
379
  **Note**: This demo uses HuggingFace Spaces Zero GPU. Each generation has a time limit.
380
  For longer texts, please split them into smaller segments.
381
  """
 
386
 
387
  if __name__ == "__main__":
388
  demo = build_ui()
389
+ demo.launch()
qwen_tts/__init__.py CHANGED
@@ -21,5 +21,4 @@ qwen_tts: Qwen-TTS package.
21
  from .inference.qwen3_tts_model import Qwen3TTSModel, VoiceClonePromptItem
22
  from .inference.qwen3_tts_tokenizer import Qwen3TTSTokenizer
23
 
24
- __all__ = ["__version__"]
25
- __version__ = "0.0.1"
 
21
  from .inference.qwen3_tts_model import Qwen3TTSModel, VoiceClonePromptItem
22
  from .inference.qwen3_tts_tokenizer import Qwen3TTSTokenizer
23
 
24
# ``__all__`` advertises ``__version__``, so the name must exist in this module;
# otherwise ``from qwen_tts import *`` fails with AttributeError.
# NOTE(review): value restored from the previous revision - confirm it is still
# the intended version string.
__all__ = ["__version__"]
__version__ = "0.0.1"
 
qwen_tts/cli/demo.py CHANGED
@@ -146,9 +146,11 @@ def build_parser() -> argparse.ArgumentParser:
146
  help="Path to SSL key file for HTTPS (optional).",
147
  )
148
  parser.add_argument(
149
- "--ssl-verify",
150
- default=None,
151
- help="SSL verify setting for Gradio (optional).",
 
 
152
  )
153
 
154
  # Optional generation args
@@ -617,13 +619,12 @@ def main(argv=None) -> int:
617
  server_name=args.ip,
618
  server_port=args.port,
619
  share=args.share,
 
620
  )
621
  if args.ssl_certfile is not None:
622
  launch_kwargs["ssl_certfile"] = args.ssl_certfile
623
  if args.ssl_keyfile is not None:
624
  launch_kwargs["ssl_keyfile"] = args.ssl_keyfile
625
- if args.ssl_verify is not None:
626
- launch_kwargs["ssl_verify"] = args.ssl_verify
627
 
628
  demo.queue(default_concurrency_limit=int(args.concurrency)).launch(**launch_kwargs)
629
  return 0
 
146
  help="Path to SSL key file for HTTPS (optional).",
147
  )
148
  parser.add_argument(
149
+ "--ssl-verify/--no-ssl-verify",
150
+ dest="ssl_verify",
151
+ default=True,
152
+ action=argparse.BooleanOptionalAction,
153
+ help="Whether to verify SSL certificate (default: enabled).",
154
  )
155
 
156
  # Optional generation args
 
619
  server_name=args.ip,
620
  server_port=args.port,
621
  share=args.share,
622
+ ssl_verify=True if args.ssl_verify else False,
623
  )
624
  if args.ssl_certfile is not None:
625
  launch_kwargs["ssl_certfile"] = args.ssl_certfile
626
  if args.ssl_keyfile is not None:
627
  launch_kwargs["ssl_keyfile"] = args.ssl_keyfile
 
 
628
 
629
  demo.queue(default_concurrency_limit=int(args.concurrency)).launch(**launch_kwargs)
630
  return 0
qwen_tts/core/models/modeling_qwen3_tts.py CHANGED
@@ -19,7 +19,9 @@ import os
19
  from dataclasses import dataclass
20
  from typing import Callable, Optional
21
 
 
22
  import torch
 
23
  from librosa.filters import mel as librosa_mel_fn
24
  from torch import nn
25
  from torch.nn import functional as F
@@ -27,34 +29,69 @@ from transformers.activations import ACT2FN
27
  from transformers.cache_utils import Cache, DynamicCache
28
  from transformers.generation import GenerationMixin
29
  from transformers.integrations import use_kernel_forward_from_hub
30
- from transformers.masking_utils import (
31
- create_causal_mask,
32
- create_sliding_window_causal_mask,
33
- )
34
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
35
  from transformers.modeling_layers import GradientCheckpointingLayer
36
- from transformers.modeling_outputs import (
37
- BaseModelOutputWithPast,
38
- CausalLMOutputWithPast,
39
- ModelOutput,
40
- )
41
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
42
- from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
43
  from transformers.processing_utils import Unpack
44
  from transformers.utils import can_return_tuple, logging
45
  from transformers.utils.hub import cached_file
46
 
47
  from ...inference.qwen3_tts_tokenizer import Qwen3TTSTokenizer
48
- from .configuration_qwen3_tts import (
49
- Qwen3TTSConfig,
50
- Qwen3TTSSpeakerEncoderConfig,
51
- Qwen3TTSTalkerCodePredictorConfig,
52
- Qwen3TTSTalkerConfig,
53
- )
54
 
55
  logger = logging.get_logger(__name__)
56
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  class Res2NetBlock(torch.nn.Module):
59
  def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
60
  super().__init__()
@@ -433,7 +470,7 @@ class Qwen3TTSPreTrainedModel(PreTrainedModel):
433
  supports_gradient_checkpointing = True
434
  _no_split_modules = ["Qwen3TTSDecoderLayer"]
435
  _skip_keys_device_placement = "past_key_values"
436
- _supports_flash_attn_2 = True
437
  _supports_sdpa = True
438
  _supports_cache_class = True
439
  _supports_static_cache = False
@@ -464,8 +501,7 @@ class Qwen3TTSTalkerTextPreTrainedModel(PreTrainedModel):
464
  supports_gradient_checkpointing = True
465
  _no_split_modules = []
466
  _skip_keys_device_placement = ["past_key_values"]
467
- _supports_flash_attn_3 = True
468
- _supports_flash_attn_2 = True
469
  _supports_sdpa = True
470
  _supports_flex_attn = True
471
  _supports_cache_class = True
@@ -1178,6 +1214,8 @@ class Qwen3TTSTalkerCodePredictorModelForConditionalGeneration(Qwen3TTSPreTraine
1178
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1179
  )
1180
 
 
 
1181
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1182
  outputs: BaseModelOutputWithPast = self.model(
1183
  input_ids=None,
@@ -1830,6 +1868,11 @@ class Qwen3TTSForConditionalGeneration(Qwen3TTSPreTrainedModel, GenerationMixin)
1830
  weights_only=True,
1831
  **kwargs,
1832
  ):
 
 
 
 
 
1833
  model = super().from_pretrained(
1834
  pretrained_model_name_or_path,
1835
  *model_args,
@@ -1842,8 +1885,18 @@ class Qwen3TTSForConditionalGeneration(Qwen3TTSPreTrainedModel, GenerationMixin)
1842
  revision=revision,
1843
  use_safetensors=use_safetensors,
1844
  weights_only=weights_only,
 
1845
  **kwargs,
1846
  )
 
 
 
 
 
 
 
 
 
1847
  speech_tokenizer_path = cached_file(
1848
  pretrained_model_name_or_path,
1849
  "speech_tokenizer/config.json",
 
19
  from dataclasses import dataclass
20
  from typing import Callable, Optional
21
 
22
+ import huggingface_hub
23
  import torch
24
+ from huggingface_hub import snapshot_download
25
  from librosa.filters import mel as librosa_mel_fn
26
  from torch import nn
27
  from torch.nn import functional as F
 
29
  from transformers.cache_utils import Cache, DynamicCache
30
  from transformers.generation import GenerationMixin
31
  from transformers.integrations import use_kernel_forward_from_hub
32
+ from transformers.masking_utils import (create_causal_mask,
33
+ create_sliding_window_causal_mask)
 
 
34
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
35
  from transformers.modeling_layers import GradientCheckpointingLayer
36
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
37
+ CausalLMOutputWithPast, ModelOutput)
38
+ from transformers.modeling_rope_utils import (ROPE_INIT_FUNCTIONS,
39
+ dynamic_rope_update)
40
+ from transformers.modeling_utils import (ALL_ATTENTION_FUNCTIONS,
41
+ PreTrainedModel)
 
42
  from transformers.processing_utils import Unpack
43
  from transformers.utils import can_return_tuple, logging
44
  from transformers.utils.hub import cached_file
45
 
46
  from ...inference.qwen3_tts_tokenizer import Qwen3TTSTokenizer
47
+ from .configuration_qwen3_tts import (Qwen3TTSConfig,
48
+ Qwen3TTSSpeakerEncoderConfig,
49
+ Qwen3TTSTalkerCodePredictorConfig,
50
+ Qwen3TTSTalkerConfig)
 
 
51
 
52
  logger = logging.get_logger(__name__)
53
 
54
 
55
+ def download_weights_from_hf_specific(
56
+ model_name_or_path: str,
57
+ cache_dir: str | None,
58
+ allow_patterns: list[str],
59
+ revision: str | None = None,
60
+ ignore_patterns: str | list[str] | None = None,
61
+ ) -> str:
62
+ """Download model weights from Hugging Face Hub. Users can specify the
63
+ allow_patterns to download only the necessary weights.
64
+
65
+ Args:
66
+ model_name_or_path (str): The model name or path.
67
+ cache_dir (Optional[str]): The cache directory to store the model
68
+ weights. If None, will use HF defaults.
69
+ allow_patterns (list[str]): The allowed patterns for the
70
+ weight files. Files matched by any of the patterns will be
71
+ downloaded.
72
+ revision (Optional[str]): The revision of the model.
73
+ ignore_patterns (Optional[Union[str, list[str]]]): The patterns to
74
+ filter out the weight files. Files matched by any of the patterns
75
+ will be ignored.
76
+
77
+ Returns:
78
+ str: The path to the downloaded model weights.
79
+ """
80
+ assert len(allow_patterns) > 0
81
+ local_only = huggingface_hub.constants.HF_HUB_OFFLINE
82
+
83
+ for allow_pattern in allow_patterns:
84
+ hf_folder = snapshot_download(
85
+ model_name_or_path,
86
+ allow_patterns=allow_pattern,
87
+ ignore_patterns=ignore_patterns,
88
+ cache_dir=cache_dir,
89
+ revision=revision,
90
+ local_files_only=local_only,
91
+ )
92
+ return hf_folder
93
+
94
+
95
  class Res2NetBlock(torch.nn.Module):
96
  def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
97
  super().__init__()
 
470
  supports_gradient_checkpointing = True
471
  _no_split_modules = ["Qwen3TTSDecoderLayer"]
472
  _skip_keys_device_placement = "past_key_values"
473
+ _supports_flash_attn = True
474
  _supports_sdpa = True
475
  _supports_cache_class = True
476
  _supports_static_cache = False
 
501
  supports_gradient_checkpointing = True
502
  _no_split_modules = []
503
  _skip_keys_device_placement = ["past_key_values"]
504
+ _supports_flash_attn = True
 
505
  _supports_sdpa = True
506
  _supports_flex_attn = True
507
  _supports_cache_class = True
 
1214
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1215
  )
1216
 
1217
+ inputs_embeds = self.small_to_mtp_projection(inputs_embeds)
1218
+
1219
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1220
  outputs: BaseModelOutputWithPast = self.model(
1221
  input_ids=None,
 
1868
  weights_only=True,
1869
  **kwargs,
1870
  ):
1871
+ # Hotfix to enable passing the correct attn implementation which is stored in the config but not in kwargs
1872
+ requested_attn_implementation = kwargs.pop("attn_implementation", None)
1873
+ if requested_attn_implementation is None and config and config._attn_implementation:
1874
+ requested_attn_implementation = config._attn_implementation
1875
+
1876
  model = super().from_pretrained(
1877
  pretrained_model_name_or_path,
1878
  *model_args,
 
1885
  revision=revision,
1886
  use_safetensors=use_safetensors,
1887
  weights_only=weights_only,
1888
+ attn_implementation=requested_attn_implementation,
1889
  **kwargs,
1890
  )
1891
+ if not local_files_only and not os.path.isdir(pretrained_model_name_or_path):
1892
+ download_cache_dir = kwargs.get("cache_dir", cache_dir)
1893
+ download_revision = kwargs.get("revision", revision)
1894
+ download_weights_from_hf_specific(
1895
+ pretrained_model_name_or_path,
1896
+ cache_dir=download_cache_dir,
1897
+ allow_patterns=["speech_tokenizer/*"],
1898
+ revision=download_revision,
1899
+ )
1900
  speech_tokenizer_path = cached_file(
1901
  pretrained_model_name_or_path,
1902
  "speech_tokenizer/config.json",
qwen_tts/inference/qwen3_tts_model.py CHANGED
@@ -286,7 +286,6 @@ class Qwen3TTSModel:
286
 
287
  def _merge_generate_kwargs(
288
  self,
289
- non_streaming_mode: Optional[bool] = None,
290
  do_sample: Optional[bool] = None,
291
  top_k: Optional[int] = None,
292
  top_p: Optional[float] = None,
@@ -308,7 +307,7 @@ class Qwen3TTSModel:
308
  - Otherwise, fall back to the hard defaults.
309
 
310
  Args:
311
- non_streaming_mode, do_sample, top_k, top_p, temperature, repetition_penalty,
312
  subtalker_dosample, subtalker_top_k, subtalker_top_p, subtalker_temperature, max_new_tokens:
313
  Common generation parameters.
314
  **kwargs:
@@ -318,7 +317,6 @@ class Qwen3TTSModel:
318
  Dict[str, Any]: Final kwargs to pass into model.generate().
319
  """
320
  hard_defaults = dict(
321
- non_streaming_mode=False,
322
  do_sample=True,
323
  top_k=50,
324
  top_p=1.0,
@@ -340,7 +338,6 @@ class Qwen3TTSModel:
340
 
341
  merged = dict(kwargs)
342
  merged.update(
343
- non_streaming_mode=pick("non_streaming_mode", non_streaming_mode),
344
  do_sample=pick("do_sample", do_sample),
345
  top_k=pick("top_k", top_k),
346
  top_p=pick("top_p", top_p),
@@ -478,6 +475,7 @@ class Qwen3TTSModel:
478
  ref_text: Optional[Union[str, List[Optional[str]]]] = None,
479
  x_vector_only_mode: Union[bool, List[bool]] = False,
480
  voice_clone_prompt: Optional[Union[Dict[str, Any], List[VoiceClonePromptItem]]] = None,
 
481
  **kwargs,
482
  ) -> Tuple[List[np.ndarray], int]:
483
  """
@@ -607,6 +605,7 @@ class Qwen3TTSModel:
607
  ref_ids=ref_ids,
608
  voice_clone_prompt=voice_clone_prompt_dict,
609
  languages=languages,
 
610
  **gen_kwargs,
611
  )
612
 
@@ -640,6 +639,7 @@ class Qwen3TTSModel:
640
  text: Union[str, List[str]],
641
  instruct: Union[str, List[str]],
642
  language: Union[str, List[str]] = None,
 
643
  **kwargs,
644
  ) -> Tuple[List[np.ndarray], int]:
645
  """
@@ -720,6 +720,7 @@ class Qwen3TTSModel:
720
  input_ids=input_ids,
721
  instruct_ids=instruct_ids,
722
  languages=languages,
 
723
  **gen_kwargs,
724
  )
725
 
@@ -734,6 +735,7 @@ class Qwen3TTSModel:
734
  speaker: Union[str, List[str]],
735
  language: Union[str, List[str]] = None,
736
  instruct: Optional[Union[str, List[str]]] = None,
 
737
  **kwargs,
738
  ) -> Tuple[List[np.ndarray], int]:
739
  """
@@ -829,6 +831,7 @@ class Qwen3TTSModel:
829
  instruct_ids=instruct_ids,
830
  languages=languages,
831
  speakers=speakers,
 
832
  **gen_kwargs,
833
  )
834
 
 
286
 
287
  def _merge_generate_kwargs(
288
  self,
 
289
  do_sample: Optional[bool] = None,
290
  top_k: Optional[int] = None,
291
  top_p: Optional[float] = None,
 
307
  - Otherwise, fall back to the hard defaults.
308
 
309
  Args:
310
+ do_sample, top_k, top_p, temperature, repetition_penalty,
311
  subtalker_dosample, subtalker_top_k, subtalker_top_p, subtalker_temperature, max_new_tokens:
312
  Common generation parameters.
313
  **kwargs:
 
317
  Dict[str, Any]: Final kwargs to pass into model.generate().
318
  """
319
  hard_defaults = dict(
 
320
  do_sample=True,
321
  top_k=50,
322
  top_p=1.0,
 
338
 
339
  merged = dict(kwargs)
340
  merged.update(
 
341
  do_sample=pick("do_sample", do_sample),
342
  top_k=pick("top_k", top_k),
343
  top_p=pick("top_p", top_p),
 
475
  ref_text: Optional[Union[str, List[Optional[str]]]] = None,
476
  x_vector_only_mode: Union[bool, List[bool]] = False,
477
  voice_clone_prompt: Optional[Union[Dict[str, Any], List[VoiceClonePromptItem]]] = None,
478
+ non_streaming_mode: bool = False,
479
  **kwargs,
480
  ) -> Tuple[List[np.ndarray], int]:
481
  """
 
605
  ref_ids=ref_ids,
606
  voice_clone_prompt=voice_clone_prompt_dict,
607
  languages=languages,
608
+ non_streaming_mode=non_streaming_mode,
609
  **gen_kwargs,
610
  )
611
 
 
639
  text: Union[str, List[str]],
640
  instruct: Union[str, List[str]],
641
  language: Union[str, List[str]] = None,
642
+ non_streaming_mode: bool = True,
643
  **kwargs,
644
  ) -> Tuple[List[np.ndarray], int]:
645
  """
 
720
  input_ids=input_ids,
721
  instruct_ids=instruct_ids,
722
  languages=languages,
723
+ non_streaming_mode=non_streaming_mode,
724
  **gen_kwargs,
725
  )
726
 
 
735
  speaker: Union[str, List[str]],
736
  language: Union[str, List[str]] = None,
737
  instruct: Optional[Union[str, List[str]]] = None,
738
+ non_streaming_mode: bool = True,
739
  **kwargs,
740
  ) -> Tuple[List[np.ndarray], int]:
741
  """
 
831
  instruct_ids=instruct_ids,
832
  languages=languages,
833
  speakers=speakers,
834
+ non_streaming_mode=non_streaming_mode,
835
  **gen_kwargs,
836
  )
837
 
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  # Qwen3-TTS Dependencies for HuggingFace Spaces
 
2
  transformers==4.57.3
3
  accelerate==1.12.0
4
  einops
@@ -10,4 +11,5 @@ sox
10
  onnxruntime
11
  spaces
12
  torch
13
- numpy
 
 
1
  # Qwen3-TTS Dependencies for HuggingFace Spaces
2
+ torch==2.8.0
3
  transformers==4.57.3
4
  accelerate==1.12.0
5
  einops
 
11
  onnxruntime
12
  spaces
13
  torch
14
+ numpy
15
+ kernels