flash-attention-3 (#16)
Browse files- Update app.py (04a9ea9b96182c6852b961b3d43de4386cbf3c39)
- Upload 18 files (850bacaf6cc310d95a68f832b6c0e446d936a049)
- Update requirements.txt (6dc80d8c9808285dcc9ba2425e73901988ea0cc7)
Co-authored-by: Apolinário from multimodal AI art <multimodalart@users.noreply.huggingface.co>
- app.py +82 -38
- qwen_tts/__init__.py +1 -2
- qwen_tts/cli/demo.py +6 -5
- qwen_tts/core/models/modeling_qwen3_tts.py +73 -20
- qwen_tts/inference/qwen3_tts_model.py +7 -4
- requirements.txt +3 -1
app.py
CHANGED
|
@@ -8,39 +8,94 @@ import spaces
|
|
| 8 |
import gradio as gr
|
| 9 |
import numpy as np
|
| 10 |
import torch
|
| 11 |
-
from huggingface_hub import snapshot_download
|
|
|
|
| 12 |
|
| 13 |
-
from huggingface_hub import login
|
| 14 |
HF_TOKEN = os.environ.get('HF_TOKEN')
|
| 15 |
login(token=HF_TOKEN)
|
| 16 |
|
| 17 |
-
# Global model holders - keyed by (model_type, model_size)
|
| 18 |
-
loaded_models = {}
|
| 19 |
-
|
| 20 |
# Model size options
|
| 21 |
MODEL_SIZES = ["0.6B", "1.7B"]
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
def get_model_path(model_type: str, model_size: str) -> str:
|
| 25 |
"""Get model path based on type and size."""
|
| 26 |
return snapshot_download(f"Qwen/Qwen3-TTS-12Hz-{model_size}-{model_type}")
|
| 27 |
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
def _normalize_audio(wav, eps=1e-12, clip=True):
|
|
@@ -89,15 +144,8 @@ def _audio_to_tuple(audio):
|
|
| 89 |
return None
|
| 90 |
|
| 91 |
|
| 92 |
-
# Speaker and language choices for CustomVoice model
|
| 93 |
-
SPEAKERS = [
|
| 94 |
-
"Aiden", "Dylan", "Eric", "Ono_anna", "Ryan", "Serena", "Sohee", "Uncle_fu", "Vivian"
|
| 95 |
-
]
|
| 96 |
-
LANGUAGES = ["Auto", "Chinese", "English", "Japanese", "Korean", "French", "German", "Spanish", "Portuguese", "Russian"]
|
| 97 |
-
|
| 98 |
-
|
| 99 |
@spaces.GPU(duration=60)
|
| 100 |
-
def generate_voice_design(text, language, voice_description):
|
| 101 |
"""Generate speech using Voice Design model (1.7B only)."""
|
| 102 |
if not text or not text.strip():
|
| 103 |
return None, "Error: Text is required."
|
|
@@ -105,8 +153,7 @@ def generate_voice_design(text, language, voice_description):
|
|
| 105 |
return None, "Error: Voice description is required."
|
| 106 |
|
| 107 |
try:
|
| 108 |
-
|
| 109 |
-
wavs, sr = tts.generate_voice_design(
|
| 110 |
text=text.strip(),
|
| 111 |
language=language,
|
| 112 |
instruct=voice_description.strip(),
|
|
@@ -119,7 +166,7 @@ def generate_voice_design(text, language, voice_description):
|
|
| 119 |
|
| 120 |
|
| 121 |
@spaces.GPU(duration=60)
|
| 122 |
-
def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector_only, model_size):
|
| 123 |
"""Generate speech using Base (Voice Clone) model."""
|
| 124 |
if not target_text or not target_text.strip():
|
| 125 |
return None, "Error: Target text is required."
|
|
@@ -132,7 +179,7 @@ def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector
|
|
| 132 |
return None, "Error: Reference text is required when 'Use x-vector only' is not enabled."
|
| 133 |
|
| 134 |
try:
|
| 135 |
-
tts =
|
| 136 |
wavs, sr = tts.generate_voice_clone(
|
| 137 |
text=target_text.strip(),
|
| 138 |
language=language,
|
|
@@ -147,7 +194,7 @@ def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector
|
|
| 147 |
|
| 148 |
|
| 149 |
@spaces.GPU(duration=60)
|
| 150 |
-
def generate_custom_voice(text, language, speaker, instruct, model_size):
|
| 151 |
"""Generate speech using CustomVoice model."""
|
| 152 |
if not text or not text.strip():
|
| 153 |
return None, "Error: Text is required."
|
|
@@ -155,7 +202,7 @@ def generate_custom_voice(text, language, speaker, instruct, model_size):
|
|
| 155 |
return None, "Error: Speaker is required."
|
| 156 |
|
| 157 |
try:
|
| 158 |
-
tts =
|
| 159 |
wavs, sr = tts.generate_custom_voice(
|
| 160 |
text=text.strip(),
|
| 161 |
language=language,
|
|
@@ -184,12 +231,10 @@ def build_ui():
|
|
| 184 |
gr.Markdown(
|
| 185 |
"""
|
| 186 |
# Qwen3-TTS Demo
|
| 187 |
-
|
| 188 |
A unified Text-to-Speech demo featuring three powerful modes:
|
| 189 |
- **Voice Design**: Create custom voices using natural language descriptions
|
| 190 |
- **Voice Clone (Base)**: Clone any voice from a reference audio
|
| 191 |
- **TTS (CustomVoice)**: Generate speech with predefined speakers and optional style instructions
|
| 192 |
-
|
| 193 |
Built with [Qwen3-TTS](https://github.com/QwenLM/Qwen3-TTS) by Alibaba Qwen Team.
|
| 194 |
"""
|
| 195 |
)
|
|
@@ -331,7 +376,6 @@ Built with [Qwen3-TTS](https://github.com/QwenLM/Qwen3-TTS) by Alibaba Qwen Team
|
|
| 331 |
gr.Markdown(
|
| 332 |
"""
|
| 333 |
---
|
| 334 |
-
|
| 335 |
**Note**: This demo uses HuggingFace Spaces Zero GPU. Each generation has a time limit.
|
| 336 |
For longer texts, please split them into smaller segments.
|
| 337 |
"""
|
|
@@ -342,4 +386,4 @@ For longer texts, please split them into smaller segments.
|
|
| 342 |
|
| 343 |
if __name__ == "__main__":
|
| 344 |
demo = build_ui()
|
| 345 |
-
demo.launch()
|
|
|
|
| 8 |
import gradio as gr
|
| 9 |
import numpy as np
|
| 10 |
import torch
|
| 11 |
+
from huggingface_hub import snapshot_download, login
|
| 12 |
+
from qwen_tts import Qwen3TTSModel
|
| 13 |
|
|
|
|
| 14 |
HF_TOKEN = os.environ.get('HF_TOKEN')
|
| 15 |
login(token=HF_TOKEN)
|
| 16 |
|
|
|
|
|
|
|
|
|
|
| 17 |
# Model size options
|
| 18 |
MODEL_SIZES = ["0.6B", "1.7B"]
|
| 19 |
|
| 20 |
+
# Speaker and language choices for CustomVoice model
|
| 21 |
+
SPEAKERS = [
|
| 22 |
+
"Aiden", "Dylan", "Eric", "Ono_anna", "Ryan", "Serena", "Sohee", "Uncle_fu", "Vivian"
|
| 23 |
+
]
|
| 24 |
+
LANGUAGES = ["Auto", "Chinese", "English", "Japanese", "Korean", "French", "German", "Spanish", "Portuguese", "Russian"]
|
| 25 |
+
|
| 26 |
|
| 27 |
def get_model_path(model_type: str, model_size: str) -> str:
|
| 28 |
"""Get model path based on type and size."""
|
| 29 |
return snapshot_download(f"Qwen/Qwen3-TTS-12Hz-{model_size}-{model_type}")
|
| 30 |
|
| 31 |
|
| 32 |
+
# ============================================================================
|
| 33 |
+
# GLOBAL MODEL LOADING - Load all models at startup
|
| 34 |
+
# ============================================================================
|
| 35 |
+
print("Loading all models to CUDA...")
|
| 36 |
+
|
| 37 |
+
# Voice Design model (1.7B only)
|
| 38 |
+
print("Loading VoiceDesign 1.7B model...")
|
| 39 |
+
voice_design_model = Qwen3TTSModel.from_pretrained(
|
| 40 |
+
get_model_path("VoiceDesign", "1.7B"),
|
| 41 |
+
device_map="cuda",
|
| 42 |
+
dtype=torch.bfloat16,
|
| 43 |
+
token=HF_TOKEN,
|
| 44 |
+
attn_implementation="kernels-community/flash-attn3",
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
# Base (Voice Clone) models - both sizes
|
| 48 |
+
print("Loading Base 0.6B model...")
|
| 49 |
+
base_model_0_6b = Qwen3TTSModel.from_pretrained(
|
| 50 |
+
get_model_path("Base", "0.6B"),
|
| 51 |
+
device_map="cuda",
|
| 52 |
+
dtype=torch.bfloat16,
|
| 53 |
+
token=HF_TOKEN,
|
| 54 |
+
attn_implementation="kernels-community/flash-attn3",
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
print("Loading Base 1.7B model...")
|
| 58 |
+
base_model_1_7b = Qwen3TTSModel.from_pretrained(
|
| 59 |
+
get_model_path("Base", "1.7B"),
|
| 60 |
+
device_map="cuda",
|
| 61 |
+
dtype=torch.bfloat16,
|
| 62 |
+
token=HF_TOKEN,
|
| 63 |
+
attn_implementation="kernels-community/flash-attn3",
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
# CustomVoice models - both sizes
|
| 67 |
+
print("Loading CustomVoice 0.6B model...")
|
| 68 |
+
custom_voice_model_0_6b = Qwen3TTSModel.from_pretrained(
|
| 69 |
+
get_model_path("CustomVoice", "0.6B"),
|
| 70 |
+
device_map="cuda",
|
| 71 |
+
dtype=torch.bfloat16,
|
| 72 |
+
token=HF_TOKEN,
|
| 73 |
+
attn_implementation="kernels-community/flash-attn3",
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
print("Loading CustomVoice 1.7B model...")
|
| 77 |
+
custom_voice_model_1_7b = Qwen3TTSModel.from_pretrained(
|
| 78 |
+
get_model_path("CustomVoice", "1.7B"),
|
| 79 |
+
device_map="cuda",
|
| 80 |
+
dtype=torch.bfloat16,
|
| 81 |
+
token=HF_TOKEN,
|
| 82 |
+
attn_implementation="kernels-community/flash-attn3",
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
print("All models loaded successfully!")
|
| 86 |
+
|
| 87 |
+
# Model lookup dictionaries for easy access
|
| 88 |
+
BASE_MODELS = {
|
| 89 |
+
"0.6B": base_model_0_6b,
|
| 90 |
+
"1.7B": base_model_1_7b,
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
CUSTOM_VOICE_MODELS = {
|
| 94 |
+
"0.6B": custom_voice_model_0_6b,
|
| 95 |
+
"1.7B": custom_voice_model_1_7b,
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
# ============================================================================
|
| 99 |
|
| 100 |
|
| 101 |
def _normalize_audio(wav, eps=1e-12, clip=True):
|
|
|
|
| 144 |
return None
|
| 145 |
|
| 146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
@spaces.GPU(duration=60)
|
| 148 |
+
def generate_voice_design(text, language, voice_description, progress=gr.Progress(track_tqdm=True)):
|
| 149 |
"""Generate speech using Voice Design model (1.7B only)."""
|
| 150 |
if not text or not text.strip():
|
| 151 |
return None, "Error: Text is required."
|
|
|
|
| 153 |
return None, "Error: Voice description is required."
|
| 154 |
|
| 155 |
try:
|
| 156 |
+
wavs, sr = voice_design_model.generate_voice_design(
|
|
|
|
| 157 |
text=text.strip(),
|
| 158 |
language=language,
|
| 159 |
instruct=voice_description.strip(),
|
|
|
|
| 166 |
|
| 167 |
|
| 168 |
@spaces.GPU(duration=60)
|
| 169 |
+
def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector_only, model_size, progress=gr.Progress(track_tqdm=True)):
|
| 170 |
"""Generate speech using Base (Voice Clone) model."""
|
| 171 |
if not target_text or not target_text.strip():
|
| 172 |
return None, "Error: Target text is required."
|
|
|
|
| 179 |
return None, "Error: Reference text is required when 'Use x-vector only' is not enabled."
|
| 180 |
|
| 181 |
try:
|
| 182 |
+
tts = BASE_MODELS[model_size]
|
| 183 |
wavs, sr = tts.generate_voice_clone(
|
| 184 |
text=target_text.strip(),
|
| 185 |
language=language,
|
|
|
|
| 194 |
|
| 195 |
|
| 196 |
@spaces.GPU(duration=60)
|
| 197 |
+
def generate_custom_voice(text, language, speaker, instruct, model_size, progress=gr.Progress(track_tqdm=True)):
|
| 198 |
"""Generate speech using CustomVoice model."""
|
| 199 |
if not text or not text.strip():
|
| 200 |
return None, "Error: Text is required."
|
|
|
|
| 202 |
return None, "Error: Speaker is required."
|
| 203 |
|
| 204 |
try:
|
| 205 |
+
tts = CUSTOM_VOICE_MODELS[model_size]
|
| 206 |
wavs, sr = tts.generate_custom_voice(
|
| 207 |
text=text.strip(),
|
| 208 |
language=language,
|
|
|
|
| 231 |
gr.Markdown(
|
| 232 |
"""
|
| 233 |
# Qwen3-TTS Demo
|
|
|
|
| 234 |
A unified Text-to-Speech demo featuring three powerful modes:
|
| 235 |
- **Voice Design**: Create custom voices using natural language descriptions
|
| 236 |
- **Voice Clone (Base)**: Clone any voice from a reference audio
|
| 237 |
- **TTS (CustomVoice)**: Generate speech with predefined speakers and optional style instructions
|
|
|
|
| 238 |
Built with [Qwen3-TTS](https://github.com/QwenLM/Qwen3-TTS) by Alibaba Qwen Team.
|
| 239 |
"""
|
| 240 |
)
|
|
|
|
| 376 |
gr.Markdown(
|
| 377 |
"""
|
| 378 |
---
|
|
|
|
| 379 |
**Note**: This demo uses HuggingFace Spaces Zero GPU. Each generation has a time limit.
|
| 380 |
For longer texts, please split them into smaller segments.
|
| 381 |
"""
|
|
|
|
| 386 |
|
| 387 |
if __name__ == "__main__":
|
| 388 |
demo = build_ui()
|
| 389 |
+
demo.launch()
|
qwen_tts/__init__.py
CHANGED
|
@@ -21,5 +21,4 @@ qwen_tts: Qwen-TTS package.
|
|
| 21 |
from .inference.qwen3_tts_model import Qwen3TTSModel, VoiceClonePromptItem
|
| 22 |
from .inference.qwen3_tts_tokenizer import Qwen3TTSTokenizer
|
| 23 |
|
| 24 |
-
__all__ = ["__version__"]
|
| 25 |
-
__version__ = "0.0.1"
|
|
|
|
| 21 |
from .inference.qwen3_tts_model import Qwen3TTSModel, VoiceClonePromptItem
|
| 22 |
from .inference.qwen3_tts_tokenizer import Qwen3TTSTokenizer
|
| 23 |
|
| 24 |
+
__all__ = ["__version__"]
|
|
|
qwen_tts/cli/demo.py
CHANGED
|
@@ -146,9 +146,11 @@ def build_parser() -> argparse.ArgumentParser:
|
|
| 146 |
help="Path to SSL key file for HTTPS (optional).",
|
| 147 |
)
|
| 148 |
parser.add_argument(
|
| 149 |
-
"--ssl-verify",
|
| 150 |
-
|
| 151 |
-
|
|
|
|
|
|
|
| 152 |
)
|
| 153 |
|
| 154 |
# Optional generation args
|
|
@@ -617,13 +619,12 @@ def main(argv=None) -> int:
|
|
| 617 |
server_name=args.ip,
|
| 618 |
server_port=args.port,
|
| 619 |
share=args.share,
|
|
|
|
| 620 |
)
|
| 621 |
if args.ssl_certfile is not None:
|
| 622 |
launch_kwargs["ssl_certfile"] = args.ssl_certfile
|
| 623 |
if args.ssl_keyfile is not None:
|
| 624 |
launch_kwargs["ssl_keyfile"] = args.ssl_keyfile
|
| 625 |
-
if args.ssl_verify is not None:
|
| 626 |
-
launch_kwargs["ssl_verify"] = args.ssl_verify
|
| 627 |
|
| 628 |
demo.queue(default_concurrency_limit=int(args.concurrency)).launch(**launch_kwargs)
|
| 629 |
return 0
|
|
|
|
| 146 |
help="Path to SSL key file for HTTPS (optional).",
|
| 147 |
)
|
| 148 |
parser.add_argument(
|
| 149 |
+
"--ssl-verify/--no-ssl-verify",
|
| 150 |
+
dest="ssl_verify",
|
| 151 |
+
default=True,
|
| 152 |
+
action=argparse.BooleanOptionalAction,
|
| 153 |
+
help="Whether to verify SSL certificate (default: enabled).",
|
| 154 |
)
|
| 155 |
|
| 156 |
# Optional generation args
|
|
|
|
| 619 |
server_name=args.ip,
|
| 620 |
server_port=args.port,
|
| 621 |
share=args.share,
|
| 622 |
+
ssl_verify=True if args.ssl_verify else False,
|
| 623 |
)
|
| 624 |
if args.ssl_certfile is not None:
|
| 625 |
launch_kwargs["ssl_certfile"] = args.ssl_certfile
|
| 626 |
if args.ssl_keyfile is not None:
|
| 627 |
launch_kwargs["ssl_keyfile"] = args.ssl_keyfile
|
|
|
|
|
|
|
| 628 |
|
| 629 |
demo.queue(default_concurrency_limit=int(args.concurrency)).launch(**launch_kwargs)
|
| 630 |
return 0
|
qwen_tts/core/models/modeling_qwen3_tts.py
CHANGED
|
@@ -19,7 +19,9 @@ import os
|
|
| 19 |
from dataclasses import dataclass
|
| 20 |
from typing import Callable, Optional
|
| 21 |
|
|
|
|
| 22 |
import torch
|
|
|
|
| 23 |
from librosa.filters import mel as librosa_mel_fn
|
| 24 |
from torch import nn
|
| 25 |
from torch.nn import functional as F
|
|
@@ -27,34 +29,69 @@ from transformers.activations import ACT2FN
|
|
| 27 |
from transformers.cache_utils import Cache, DynamicCache
|
| 28 |
from transformers.generation import GenerationMixin
|
| 29 |
from transformers.integrations import use_kernel_forward_from_hub
|
| 30 |
-
from transformers.masking_utils import (
|
| 31 |
-
|
| 32 |
-
create_sliding_window_causal_mask,
|
| 33 |
-
)
|
| 34 |
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 35 |
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 36 |
-
from transformers.modeling_outputs import (
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 43 |
from transformers.processing_utils import Unpack
|
| 44 |
from transformers.utils import can_return_tuple, logging
|
| 45 |
from transformers.utils.hub import cached_file
|
| 46 |
|
| 47 |
from ...inference.qwen3_tts_tokenizer import Qwen3TTSTokenizer
|
| 48 |
-
from .configuration_qwen3_tts import (
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
Qwen3TTSTalkerConfig,
|
| 53 |
-
)
|
| 54 |
|
| 55 |
logger = logging.get_logger(__name__)
|
| 56 |
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
class Res2NetBlock(torch.nn.Module):
|
| 59 |
def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
|
| 60 |
super().__init__()
|
|
@@ -433,7 +470,7 @@ class Qwen3TTSPreTrainedModel(PreTrainedModel):
|
|
| 433 |
supports_gradient_checkpointing = True
|
| 434 |
_no_split_modules = ["Qwen3TTSDecoderLayer"]
|
| 435 |
_skip_keys_device_placement = "past_key_values"
|
| 436 |
-
|
| 437 |
_supports_sdpa = True
|
| 438 |
_supports_cache_class = True
|
| 439 |
_supports_static_cache = False
|
|
@@ -464,8 +501,7 @@ class Qwen3TTSTalkerTextPreTrainedModel(PreTrainedModel):
|
|
| 464 |
supports_gradient_checkpointing = True
|
| 465 |
_no_split_modules = []
|
| 466 |
_skip_keys_device_placement = ["past_key_values"]
|
| 467 |
-
|
| 468 |
-
_supports_flash_attn_2 = True
|
| 469 |
_supports_sdpa = True
|
| 470 |
_supports_flex_attn = True
|
| 471 |
_supports_cache_class = True
|
|
@@ -1178,6 +1214,8 @@ class Qwen3TTSTalkerCodePredictorModelForConditionalGeneration(Qwen3TTSPreTraine
|
|
| 1178 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1179 |
)
|
| 1180 |
|
|
|
|
|
|
|
| 1181 |
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 1182 |
outputs: BaseModelOutputWithPast = self.model(
|
| 1183 |
input_ids=None,
|
|
@@ -1830,6 +1868,11 @@ class Qwen3TTSForConditionalGeneration(Qwen3TTSPreTrainedModel, GenerationMixin)
|
|
| 1830 |
weights_only=True,
|
| 1831 |
**kwargs,
|
| 1832 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1833 |
model = super().from_pretrained(
|
| 1834 |
pretrained_model_name_or_path,
|
| 1835 |
*model_args,
|
|
@@ -1842,8 +1885,18 @@ class Qwen3TTSForConditionalGeneration(Qwen3TTSPreTrainedModel, GenerationMixin)
|
|
| 1842 |
revision=revision,
|
| 1843 |
use_safetensors=use_safetensors,
|
| 1844 |
weights_only=weights_only,
|
|
|
|
| 1845 |
**kwargs,
|
| 1846 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1847 |
speech_tokenizer_path = cached_file(
|
| 1848 |
pretrained_model_name_or_path,
|
| 1849 |
"speech_tokenizer/config.json",
|
|
|
|
| 19 |
from dataclasses import dataclass
|
| 20 |
from typing import Callable, Optional
|
| 21 |
|
| 22 |
+
import huggingface_hub
|
| 23 |
import torch
|
| 24 |
+
from huggingface_hub import snapshot_download
|
| 25 |
from librosa.filters import mel as librosa_mel_fn
|
| 26 |
from torch import nn
|
| 27 |
from torch.nn import functional as F
|
|
|
|
| 29 |
from transformers.cache_utils import Cache, DynamicCache
|
| 30 |
from transformers.generation import GenerationMixin
|
| 31 |
from transformers.integrations import use_kernel_forward_from_hub
|
| 32 |
+
from transformers.masking_utils import (create_causal_mask,
|
| 33 |
+
create_sliding_window_causal_mask)
|
|
|
|
|
|
|
| 34 |
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 35 |
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 36 |
+
from transformers.modeling_outputs import (BaseModelOutputWithPast,
|
| 37 |
+
CausalLMOutputWithPast, ModelOutput)
|
| 38 |
+
from transformers.modeling_rope_utils import (ROPE_INIT_FUNCTIONS,
|
| 39 |
+
dynamic_rope_update)
|
| 40 |
+
from transformers.modeling_utils import (ALL_ATTENTION_FUNCTIONS,
|
| 41 |
+
PreTrainedModel)
|
|
|
|
| 42 |
from transformers.processing_utils import Unpack
|
| 43 |
from transformers.utils import can_return_tuple, logging
|
| 44 |
from transformers.utils.hub import cached_file
|
| 45 |
|
| 46 |
from ...inference.qwen3_tts_tokenizer import Qwen3TTSTokenizer
|
| 47 |
+
from .configuration_qwen3_tts import (Qwen3TTSConfig,
|
| 48 |
+
Qwen3TTSSpeakerEncoderConfig,
|
| 49 |
+
Qwen3TTSTalkerCodePredictorConfig,
|
| 50 |
+
Qwen3TTSTalkerConfig)
|
|
|
|
|
|
|
| 51 |
|
| 52 |
logger = logging.get_logger(__name__)
|
| 53 |
|
| 54 |
|
| 55 |
+
def download_weights_from_hf_specific(
|
| 56 |
+
model_name_or_path: str,
|
| 57 |
+
cache_dir: str | None,
|
| 58 |
+
allow_patterns: list[str],
|
| 59 |
+
revision: str | None = None,
|
| 60 |
+
ignore_patterns: str | list[str] | None = None,
|
| 61 |
+
) -> str:
|
| 62 |
+
"""Download model weights from Hugging Face Hub. Users can specify the
|
| 63 |
+
allow_patterns to download only the necessary weights.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
model_name_or_path (str): The model name or path.
|
| 67 |
+
cache_dir (Optional[str]): The cache directory to store the model
|
| 68 |
+
weights. If None, will use HF defaults.
|
| 69 |
+
allow_patterns (list[str]): The allowed patterns for the
|
| 70 |
+
weight files. Files matched by any of the patterns will be
|
| 71 |
+
downloaded.
|
| 72 |
+
revision (Optional[str]): The revision of the model.
|
| 73 |
+
ignore_patterns (Optional[Union[str, list[str]]]): The patterns to
|
| 74 |
+
filter out the weight files. Files matched by any of the patterns
|
| 75 |
+
will be ignored.
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
str: The path to the downloaded model weights.
|
| 79 |
+
"""
|
| 80 |
+
assert len(allow_patterns) > 0
|
| 81 |
+
local_only = huggingface_hub.constants.HF_HUB_OFFLINE
|
| 82 |
+
|
| 83 |
+
for allow_pattern in allow_patterns:
|
| 84 |
+
hf_folder = snapshot_download(
|
| 85 |
+
model_name_or_path,
|
| 86 |
+
allow_patterns=allow_pattern,
|
| 87 |
+
ignore_patterns=ignore_patterns,
|
| 88 |
+
cache_dir=cache_dir,
|
| 89 |
+
revision=revision,
|
| 90 |
+
local_files_only=local_only,
|
| 91 |
+
)
|
| 92 |
+
return hf_folder
|
| 93 |
+
|
| 94 |
+
|
| 95 |
class Res2NetBlock(torch.nn.Module):
|
| 96 |
def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
|
| 97 |
super().__init__()
|
|
|
|
| 470 |
supports_gradient_checkpointing = True
|
| 471 |
_no_split_modules = ["Qwen3TTSDecoderLayer"]
|
| 472 |
_skip_keys_device_placement = "past_key_values"
|
| 473 |
+
_supports_flash_attn = True
|
| 474 |
_supports_sdpa = True
|
| 475 |
_supports_cache_class = True
|
| 476 |
_supports_static_cache = False
|
|
|
|
| 501 |
supports_gradient_checkpointing = True
|
| 502 |
_no_split_modules = []
|
| 503 |
_skip_keys_device_placement = ["past_key_values"]
|
| 504 |
+
_supports_flash_attn = True
|
|
|
|
| 505 |
_supports_sdpa = True
|
| 506 |
_supports_flex_attn = True
|
| 507 |
_supports_cache_class = True
|
|
|
|
| 1214 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1215 |
)
|
| 1216 |
|
| 1217 |
+
inputs_embeds = self.small_to_mtp_projection(inputs_embeds)
|
| 1218 |
+
|
| 1219 |
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 1220 |
outputs: BaseModelOutputWithPast = self.model(
|
| 1221 |
input_ids=None,
|
|
|
|
| 1868 |
weights_only=True,
|
| 1869 |
**kwargs,
|
| 1870 |
):
|
| 1871 |
+
# Hotfix to enable passing the correct attn implementation which is stored in the config but not in kwargs
|
| 1872 |
+
requested_attn_implementation = kwargs.pop("attn_implementation", None)
|
| 1873 |
+
if requested_attn_implementation is None and config and config._attn_implementation:
|
| 1874 |
+
requested_attn_implementation = config._attn_implementation
|
| 1875 |
+
|
| 1876 |
model = super().from_pretrained(
|
| 1877 |
pretrained_model_name_or_path,
|
| 1878 |
*model_args,
|
|
|
|
| 1885 |
revision=revision,
|
| 1886 |
use_safetensors=use_safetensors,
|
| 1887 |
weights_only=weights_only,
|
| 1888 |
+
attn_implementation=requested_attn_implementation,
|
| 1889 |
**kwargs,
|
| 1890 |
)
|
| 1891 |
+
if not local_files_only and not os.path.isdir(pretrained_model_name_or_path):
|
| 1892 |
+
download_cache_dir = kwargs.get("cache_dir", cache_dir)
|
| 1893 |
+
download_revision = kwargs.get("revision", revision)
|
| 1894 |
+
download_weights_from_hf_specific(
|
| 1895 |
+
pretrained_model_name_or_path,
|
| 1896 |
+
cache_dir=download_cache_dir,
|
| 1897 |
+
allow_patterns=["speech_tokenizer/*"],
|
| 1898 |
+
revision=download_revision,
|
| 1899 |
+
)
|
| 1900 |
speech_tokenizer_path = cached_file(
|
| 1901 |
pretrained_model_name_or_path,
|
| 1902 |
"speech_tokenizer/config.json",
|
qwen_tts/inference/qwen3_tts_model.py
CHANGED
|
@@ -286,7 +286,6 @@ class Qwen3TTSModel:
|
|
| 286 |
|
| 287 |
def _merge_generate_kwargs(
|
| 288 |
self,
|
| 289 |
-
non_streaming_mode: Optional[bool] = None,
|
| 290 |
do_sample: Optional[bool] = None,
|
| 291 |
top_k: Optional[int] = None,
|
| 292 |
top_p: Optional[float] = None,
|
|
@@ -308,7 +307,7 @@ class Qwen3TTSModel:
|
|
| 308 |
- Otherwise, fall back to the hard defaults.
|
| 309 |
|
| 310 |
Args:
|
| 311 |
-
|
| 312 |
subtalker_dosample, subtalker_top_k, subtalker_top_p, subtalker_temperature, max_new_tokens:
|
| 313 |
Common generation parameters.
|
| 314 |
**kwargs:
|
|
@@ -318,7 +317,6 @@ class Qwen3TTSModel:
|
|
| 318 |
Dict[str, Any]: Final kwargs to pass into model.generate().
|
| 319 |
"""
|
| 320 |
hard_defaults = dict(
|
| 321 |
-
non_streaming_mode=False,
|
| 322 |
do_sample=True,
|
| 323 |
top_k=50,
|
| 324 |
top_p=1.0,
|
|
@@ -340,7 +338,6 @@ class Qwen3TTSModel:
|
|
| 340 |
|
| 341 |
merged = dict(kwargs)
|
| 342 |
merged.update(
|
| 343 |
-
non_streaming_mode=pick("non_streaming_mode", non_streaming_mode),
|
| 344 |
do_sample=pick("do_sample", do_sample),
|
| 345 |
top_k=pick("top_k", top_k),
|
| 346 |
top_p=pick("top_p", top_p),
|
|
@@ -478,6 +475,7 @@ class Qwen3TTSModel:
|
|
| 478 |
ref_text: Optional[Union[str, List[Optional[str]]]] = None,
|
| 479 |
x_vector_only_mode: Union[bool, List[bool]] = False,
|
| 480 |
voice_clone_prompt: Optional[Union[Dict[str, Any], List[VoiceClonePromptItem]]] = None,
|
|
|
|
| 481 |
**kwargs,
|
| 482 |
) -> Tuple[List[np.ndarray], int]:
|
| 483 |
"""
|
|
@@ -607,6 +605,7 @@ class Qwen3TTSModel:
|
|
| 607 |
ref_ids=ref_ids,
|
| 608 |
voice_clone_prompt=voice_clone_prompt_dict,
|
| 609 |
languages=languages,
|
|
|
|
| 610 |
**gen_kwargs,
|
| 611 |
)
|
| 612 |
|
|
@@ -640,6 +639,7 @@ class Qwen3TTSModel:
|
|
| 640 |
text: Union[str, List[str]],
|
| 641 |
instruct: Union[str, List[str]],
|
| 642 |
language: Union[str, List[str]] = None,
|
|
|
|
| 643 |
**kwargs,
|
| 644 |
) -> Tuple[List[np.ndarray], int]:
|
| 645 |
"""
|
|
@@ -720,6 +720,7 @@ class Qwen3TTSModel:
|
|
| 720 |
input_ids=input_ids,
|
| 721 |
instruct_ids=instruct_ids,
|
| 722 |
languages=languages,
|
|
|
|
| 723 |
**gen_kwargs,
|
| 724 |
)
|
| 725 |
|
|
@@ -734,6 +735,7 @@ class Qwen3TTSModel:
|
|
| 734 |
speaker: Union[str, List[str]],
|
| 735 |
language: Union[str, List[str]] = None,
|
| 736 |
instruct: Optional[Union[str, List[str]]] = None,
|
|
|
|
| 737 |
**kwargs,
|
| 738 |
) -> Tuple[List[np.ndarray], int]:
|
| 739 |
"""
|
|
@@ -829,6 +831,7 @@ class Qwen3TTSModel:
|
|
| 829 |
instruct_ids=instruct_ids,
|
| 830 |
languages=languages,
|
| 831 |
speakers=speakers,
|
|
|
|
| 832 |
**gen_kwargs,
|
| 833 |
)
|
| 834 |
|
|
|
|
| 286 |
|
| 287 |
def _merge_generate_kwargs(
|
| 288 |
self,
|
|
|
|
| 289 |
do_sample: Optional[bool] = None,
|
| 290 |
top_k: Optional[int] = None,
|
| 291 |
top_p: Optional[float] = None,
|
|
|
|
| 307 |
- Otherwise, fall back to the hard defaults.
|
| 308 |
|
| 309 |
Args:
|
| 310 |
+
do_sample, top_k, top_p, temperature, repetition_penalty,
|
| 311 |
subtalker_dosample, subtalker_top_k, subtalker_top_p, subtalker_temperature, max_new_tokens:
|
| 312 |
Common generation parameters.
|
| 313 |
**kwargs:
|
|
|
|
| 317 |
Dict[str, Any]: Final kwargs to pass into model.generate().
|
| 318 |
"""
|
| 319 |
hard_defaults = dict(
|
|
|
|
| 320 |
do_sample=True,
|
| 321 |
top_k=50,
|
| 322 |
top_p=1.0,
|
|
|
|
| 338 |
|
| 339 |
merged = dict(kwargs)
|
| 340 |
merged.update(
|
|
|
|
| 341 |
do_sample=pick("do_sample", do_sample),
|
| 342 |
top_k=pick("top_k", top_k),
|
| 343 |
top_p=pick("top_p", top_p),
|
|
|
|
| 475 |
ref_text: Optional[Union[str, List[Optional[str]]]] = None,
|
| 476 |
x_vector_only_mode: Union[bool, List[bool]] = False,
|
| 477 |
voice_clone_prompt: Optional[Union[Dict[str, Any], List[VoiceClonePromptItem]]] = None,
|
| 478 |
+
non_streaming_mode: bool = False,
|
| 479 |
**kwargs,
|
| 480 |
) -> Tuple[List[np.ndarray], int]:
|
| 481 |
"""
|
|
|
|
| 605 |
ref_ids=ref_ids,
|
| 606 |
voice_clone_prompt=voice_clone_prompt_dict,
|
| 607 |
languages=languages,
|
| 608 |
+
non_streaming_mode=non_streaming_mode,
|
| 609 |
**gen_kwargs,
|
| 610 |
)
|
| 611 |
|
|
|
|
| 639 |
text: Union[str, List[str]],
|
| 640 |
instruct: Union[str, List[str]],
|
| 641 |
language: Union[str, List[str]] = None,
|
| 642 |
+
non_streaming_mode: bool = True,
|
| 643 |
**kwargs,
|
| 644 |
) -> Tuple[List[np.ndarray], int]:
|
| 645 |
"""
|
|
|
|
| 720 |
input_ids=input_ids,
|
| 721 |
instruct_ids=instruct_ids,
|
| 722 |
languages=languages,
|
| 723 |
+
non_streaming_mode=non_streaming_mode,
|
| 724 |
**gen_kwargs,
|
| 725 |
)
|
| 726 |
|
|
|
|
| 735 |
speaker: Union[str, List[str]],
|
| 736 |
language: Union[str, List[str]] = None,
|
| 737 |
instruct: Optional[Union[str, List[str]]] = None,
|
| 738 |
+
non_streaming_mode: bool = True,
|
| 739 |
**kwargs,
|
| 740 |
) -> Tuple[List[np.ndarray], int]:
|
| 741 |
"""
|
|
|
|
| 831 |
instruct_ids=instruct_ids,
|
| 832 |
languages=languages,
|
| 833 |
speakers=speakers,
|
| 834 |
+
non_streaming_mode=non_streaming_mode,
|
| 835 |
**gen_kwargs,
|
| 836 |
)
|
| 837 |
|
requirements.txt
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
# Qwen3-TTS Dependencies for HuggingFace Spaces
|
|
|
|
| 2 |
transformers==4.57.3
|
| 3 |
accelerate==1.12.0
|
| 4 |
einops
|
|
@@ -10,4 +11,5 @@ sox
|
|
| 10 |
onnxruntime
|
| 11 |
spaces
|
| 12 |
torch
|
| 13 |
-
numpy
|
|
|
|
|
|
| 1 |
# Qwen3-TTS Dependencies for HuggingFace Spaces
|
| 2 |
+
torch==2.8.0
|
| 3 |
transformers==4.57.3
|
| 4 |
accelerate==1.12.0
|
| 5 |
einops
|
|
|
|
| 11 |
onnxruntime
|
| 12 |
spaces
|
| 13 |
torch
|
| 14 |
+
numpy
|
| 15 |
+
kernels
|