File size: 15,821 Bytes
3309e4b
 
 
 
 
 
 
 
 
 
cc6e7cb
3309e4b
ddefb98
6438a49
a9d9345
2cf0f4a
27ddaa6
 
3309e4b
 
 
 
27ddaa6
3309e4b
 
 
27ddaa6
 
 
3309e4b
27ddaa6
 
3309e4b
 
 
27ddaa6
6438a49
 
 
 
 
 
 
 
3309e4b
 
 
 
27ddaa6
fb11c6d
 
 
 
27ddaa6
fb11c6d
 
 
 
3309e4b
a9d9345
 
 
 
 
3309e4b
 
 
 
27ddaa6
3309e4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb11c6d
6438a49
3309e4b
 
 
 
 
 
 
 
 
 
27ddaa6
fb11c6d
 
 
 
 
 
3309e4b
 
 
 
 
6438a49
27ddaa6
6438a49
3309e4b
 
 
ddefb98
a9ab89b
cc6e7cb
 
 
 
 
 
 
607ce9b
 
 
 
ddefb98
 
 
 
 
 
5706ae6
27ddaa6
5706ae6
ddefb98
 
 
 
27ddaa6
fb11c6d
6438a49
a8ab3f2
 
 
471c527
 
e0707a4
 
85916ca
 
 
 
 
 
 
 
27ddaa6
 
 
 
 
 
 
e0707a4
85916ca
 
27ddaa6
 
 
 
 
 
 
 
 
 
cc6e7cb
08f4100
27ddaa6
471c527
08f4100
cc6e7cb
 
 
 
 
 
 
 
 
ddefb98
 
cc6e7cb
ddefb98
3309e4b
27ddaa6
3309e4b
 
 
 
 
0f1cb09
 
3309e4b
 
27ddaa6
3309e4b
27ddaa6
 
 
3309e4b
 
 
 
27ddaa6
 
3309e4b
 
27ddaa6
0f1cb09
3309e4b
 
27ddaa6
3309e4b
27ddaa6
 
 
ddefb98
 
 
 
27ddaa6
0f1cb09
ddefb98
 
27ddaa6
3309e4b
a9d9345
27ddaa6
a9d9345
 
 
 
 
 
27ddaa6
 
 
 
 
 
 
a9d9345
 
27ddaa6
a9d9345
3309e4b
27ddaa6
 
 
3309e4b
a9d9345
 
27ddaa6
 
15ec850
27ddaa6
 
 
 
3309e4b
 
27ddaa6
 
 
 
 
 
 
 
15ec850
3309e4b
ddefb98
 
27ddaa6
 
 
 
ddefb98
27ddaa6
 
15ec850
27ddaa6
ddefb98
27ddaa6
 
 
 
ddefb98
27ddaa6
 
 
15ec850
4e8eacf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de62521
4e8eacf
 
 
 
3309e4b
 
 
27ddaa6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
# coding=utf-8
import os
import sys
import logging
import spaces
import gradio as gr
import numpy as np
import torch
from huggingface_hub import snapshot_download, login
from qwen_tts import Qwen3TTSModel
from qwen_tts.inference.qwen3_tts_model import VoiceClonePromptItem
import functools
import uuid
import random
import whisper
import librosa
from opencc import OpenCC

# Configure logging: INFO level, timestamped records written to stdout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger("Qwen3-TTS-Demo")

# Initialise the Chinese script converter; the 't2s' OpenCC configuration maps
# Traditional Chinese to Simplified Chinese. It is applied to Whisper output below.
cc = OpenCC('t2s')

# Log in to the Hugging Face Hub only when a token is provided via the environment.
HF_TOKEN = os.environ.get('HF_TOKEN')
if HF_TOKEN:
    login(token=HF_TOKEN)

# Checkpoint sizes known to this demo, and the language choices shown in the UI dropdowns.
MODEL_SIZES = ["0.6B", "1.7B"]
LANGUAGES = ["Auto", "Chinese", "English", "Japanese", "Korean", "French", "German", "Spanish", "Portuguese", "Russian"]

def seed_everything(seed=42):
    """Seed every random number generator used by the demo.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU, current CUDA device
    and all CUDA devices), then switches cuDNN to deterministic mode with
    benchmarking turned off.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_model_path(model_type: str, model_size: str) -> str:
    """Fetch a Qwen3-TTS checkpoint from the Hugging Face Hub.

    Args:
        model_type: Variant suffix of the repository name (this file uses
            ``"Base"`` and ``"VoiceDesign"``).
        model_size: Size tag of the repository name (e.g. ``"0.6B"``).

    Returns:
        Whatever ``snapshot_download`` returns for the repository
        ``Qwen/Qwen3-TTS-12Hz-<model_size>-<model_type>``.
    """
    repo_id = f"Qwen/Qwen3-TTS-12Hz-{model_size}-{model_type}"
    local_snapshot = snapshot_download(repo_id)
    return local_snapshot

@functools.lru_cache(maxsize=1)
def load_model(model_type, model_size):
    """Load a Qwen3-TTS model onto CUDA in bfloat16.

    The result is memoised with ``lru_cache(maxsize=1)``, so only the most
    recently requested ``(model_type, model_size)`` pair stays cached;
    asking for a different pair loads again and replaces the cached entry.

    Args:
        model_type: Checkpoint variant, forwarded to ``get_model_path``.
        model_size: Checkpoint size tag, forwarded to ``get_model_path``.

    Returns:
        The object returned by ``Qwen3TTSModel.from_pretrained``.
    """
    checkpoint_dir = get_model_path(model_type, model_size)
    load_options = {
        "device_map": "cuda",
        "dtype": torch.bfloat16,
        "token": HF_TOKEN,
        "attn_implementation": "kernels-community/flash-attn3",
    }
    return Qwen3TTSModel.from_pretrained(checkpoint_dir, **load_options)

@functools.lru_cache(maxsize=1)
def load_whisper_model(model_name="large-v3"):
    """Load a Whisper model, on CUDA when available and on CPU otherwise.

    Memoised with ``lru_cache(maxsize=1)``: only the most recently requested
    model name stays cached.

    Args:
        model_name: Whisper model name passed to ``whisper.load_model``.

    Returns:
        The loaded Whisper model.
    """
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    return whisper.load_model(model_name, device=target_device)

def _normalize_audio(wav, eps=1e-12, clip=True):
    """Convert raw audio samples to float32 in ``[-1, 1]`` and downmix to 1-D.

    Integer input is divided by the largest magnitude its dtype can hold.
    Floating input is left as-is unless its peak exceeds 1 (with a small
    tolerance), in which case it is divided by ``peak + eps``.

    Args:
        wav: Array-like of integer or floating samples.
        eps: Small constant added to the peak before dividing.
        clip: When true, clamp the result to ``[-1.0, 1.0]``.

    Returns:
        A float32 array; input with more than one dimension is averaged
        over its last axis (after clipping).

    Raises:
        TypeError: If the dtype is neither integer nor floating.
    """
    samples = np.asarray(wav)
    dtype = samples.dtype
    is_integer = np.issubdtype(dtype, np.integer)
    if not is_integer and not np.issubdtype(dtype, np.floating):
        raise TypeError(f"Unsupported dtype: {samples.dtype}")

    if is_integer:
        limits = np.iinfo(dtype)
        # Largest representable magnitude, e.g. 32768 for int16.
        scaled = samples.astype(np.float32) / max(abs(limits.min), limits.max)
    else:
        scaled = samples.astype(np.float32)
        peak = np.max(np.abs(scaled)) if scaled.size else 0.0
        # Only rescale floats that actually overshoot the unit range.
        if peak > 1.0 + 1e-6:
            scaled = scaled / (peak + eps)

    if clip:
        scaled = np.clip(scaled, -1.0, 1.0)

    if scaled.ndim <= 1:
        return scaled
    return np.mean(scaled, axis=-1).astype(np.float32)

def _audio_to_tuple(audio):
    """Turn a supported audio value into ``(waveform, sample_rate)``.

    Two input shapes are recognised:

    * a 2-tuple ``(sample_rate, samples)`` whose first element is an ``int``;
    * a dict holding both ``"sampling_rate"`` and ``"data"`` keys.

    Samples go through ``_normalize_audio`` in both cases.

    Returns:
        ``(waveform, sample_rate)`` with ``sample_rate`` as ``int``, or
        ``None`` when ``audio`` is ``None`` or has an unrecognised shape.
    """
    if audio is None:
        return None

    is_rate_samples_pair = (
        isinstance(audio, tuple)
        and len(audio) == 2
        and isinstance(audio[0], int)
    )
    if is_rate_samples_pair:
        rate, samples = audio
        return _normalize_audio(samples), int(rate)

    is_audio_mapping = (
        isinstance(audio, dict)
        and "sampling_rate" in audio
        and "data" in audio
    )
    if is_audio_mapping:
        rate = int(audio["sampling_rate"])
        return _normalize_audio(audio["data"]), rate

    return None

@spaces.GPU
def infer_voice_design(part, language, voice_description):
    """Synthesise ``part`` with the 1.7B VoiceDesign model.

    Args:
        part: Text to speak.
        language: Language value forwarded to the model.
        voice_description: Free-text voice instruction; surrounding
            whitespace is stripped before use.

    Returns:
        ``(waveform, sample_rate)`` where the waveform is the first item of
        the batch returned by ``generate_voice_design``.
    """
    model = load_model("VoiceDesign", "1.7B")
    # Fixed seed so repeated requests start from the same RNG state.
    seed_everything(42)
    request = {
        "text": part,
        "language": language,
        "instruct": voice_description.strip(),
        "non_streaming_mode": True,
        "max_new_tokens": 2048,
    }
    generated, sample_rate = model.generate_voice_design(**request)
    return generated[0], sample_rate

@spaces.GPU
def infer_voice_clone(part, language, audio_tuple, ref_text, use_xvector_only):
    """Clone a voice from reference audio and speak ``part`` with the 0.6B Base model.

    Args:
        part: Text to speak.
        language: Language value forwarded to the model.
        audio_tuple: Reference audio as produced by ``_audio_to_tuple``.
        ref_text: Transcript of the reference audio; a falsy value is sent
            to the model as ``None``, otherwise it is stripped.
        use_xvector_only: Forwarded as ``x_vector_only_mode``.

    Returns:
        ``(waveform, sample_rate)`` where the waveform is the first item of
        the batch returned by ``generate_voice_clone``.
    """
    model = load_model("Base", "0.6B")

    if ref_text:
        transcript = ref_text.strip()
    else:
        transcript = None
    clone_prompt = model.create_voice_clone_prompt(
        ref_audio=audio_tuple,
        ref_text=transcript,
        x_vector_only_mode=use_xvector_only,
    )

    generation_options = {
        "max_new_tokens": 2048,
        "seed": 42,
        "temperature": 0.3,
        "top_p": 0.85,
    }
    generated, sample_rate = model.generate_voice_clone(
        text=part,
        language=language,
        voice_clone_prompt=clone_prompt,
        **generation_options,
    )
    return generated[0], sample_rate

@spaces.GPU
def infer_voice_clone_from_prompt(part, language, prompt_file_path):
    """Speak ``part`` using a previously saved voice-clone prompt file.

    The file is expected to hold what ``extract_voice_clone_prompt`` writes
    (a list of dicts with the ``VoiceClonePromptItem`` fields). Lists of
    ``VoiceClonePromptItem`` objects, and lists mixing both forms, are also
    accepted. Any other non-list payload is handed to the model unchanged.

    Args:
        part: Text to speak.
        language: Language value forwarded to the model.
        prompt_file_path: Path of the ``.pt`` file to load.

    Returns:
        ``(waveform, sample_rate)`` where the waveform is the first item of
        the batch returned by ``generate_voice_clone``.

    Raises:
        ValueError: If the file contains an empty list.
    """
    # SECURITY(review): weights_only=False makes torch.load unpickle arbitrary
    # objects, and this path comes from a file uploaded through the UI/API.
    # A crafted file can execute code in this process. Kept as-is so that files
    # containing pickled VoiceClonePromptItem objects still load; consider
    # weights_only=True (the dict format written by extract_voice_clone_prompt
    # holds only tensors and primitives) — TODO confirm before switching.
    loaded_data = torch.load(prompt_file_path, map_location='cuda', weights_only=False)

    if isinstance(loaded_data, list):
        if not loaded_data:
            raise ValueError("特征文件为空:未包含任何音色特征。")
        if all(isinstance(item, (dict, VoiceClonePromptItem)) for item in loaded_data):
            # Convert per element so dict items and ready-made items can be mixed.
            voice_clone_prompt = [
                item if isinstance(item, VoiceClonePromptItem) else VoiceClonePromptItem(**item)
                for item in loaded_data
            ]
        else:
            voice_clone_prompt = loaded_data
        for item in voice_clone_prompt:
            # Drop a leading axis from 3-D reference codes.
            if item.ref_code is not None and item.ref_code.ndim == 3:
                item.ref_code = item.ref_code.squeeze(0)
    else:
        voice_clone_prompt = loaded_data

    tts = load_model("Base", "0.6B")
    wavs, sr = tts.generate_voice_clone(
        text=part,
        language=language,
        voice_clone_prompt=voice_clone_prompt,
        max_new_tokens=2048,
        seed=42,
        temperature=0.3,
        top_p=0.85
    )
    return wavs[0], sr

def _transcribe_reference_audio(audio_tuple, whisper_size="base"):
    """Transcribe a ``(waveform, sample_rate)`` pair with Whisper and return the text after OpenCC conversion."""
    whisper_model = load_whisper_model(whisper_size)
    waveform, sample_rate = audio_tuple
    if sample_rate != 16000:
        # The audio handed to Whisper is always at 16 kHz.
        waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
    result = whisper_model.transcribe(waveform)

    text = result.get("text", "")
    if isinstance(text, list) and len(text) > 0:
        text = text[0]
    if not isinstance(text, str):
        text = str(text)
    return cc.convert(text.strip())


@spaces.GPU
def extract_voice_clone_prompt(ref_audio, ref_text, use_xvector_only):
    """Build a voice-clone prompt from reference audio and save it to a ``.pt`` file.

    When no reference text is supplied, the audio is transcribed with the
    Whisper ``base`` model. A non-empty transcript is used as the reference
    text and switches x-vector-only mode off (even if it was requested).
    If transcription fails or yields no text, the function falls back to
    x-vector-only mode with no reference text.

    Args:
        ref_audio: Audio value accepted by ``_audio_to_tuple``.
        ref_text: Transcript of the reference audio, or empty/``None`` to
            transcribe automatically.
        use_xvector_only: Whether to request x-vector-only mode.

    Returns:
        Path of the written file, ``voice_clone_prompt_<8 hex chars>.pt``
        in the current working directory. The file holds a list of dicts,
        one per prompt item.

    Raises:
        gr.Error: If ``ref_audio`` is missing or has an unsupported shape.
    """
    # Validate first so a bad request never triggers a model load.
    audio_tuple = _audio_to_tuple(ref_audio)
    if audio_tuple is None:
        # This handler has a single gr.File output, so an error must be raised,
        # not returned as an (audio, status) pair.
        raise gr.Error("错误:需要参考音频。")

    tts = load_model("Base", "0.6B")
    seed_everything(42)

    r_text = ref_text
    uxo = use_xvector_only
    if not r_text or (isinstance(r_text, str) and not r_text.strip()):
        try:
            transcript = _transcribe_reference_audio(audio_tuple, "base")
        except Exception as e:
            # Best effort: without a transcript only the speaker embedding can be used.
            logger.error(f"Whisper 识别失败: {str(e)}", exc_info=True)
            uxo = True
        else:
            if transcript:
                logger.info(f"语音识别成功 :{transcript}")
                r_text = transcript
                uxo = False
            else:
                # An empty transcript cannot serve as reference text, so do not
                # switch to the mode that depends on it.
                logger.warning("Whisper 未识别出任何文本,改用仅 x-vector 模式。")
                uxo = True

    r_text_str = ""
    if isinstance(r_text, str):
        r_text_str = r_text.strip()
    elif isinstance(r_text, list) and len(r_text) > 0 and isinstance(r_text[0], str):
        r_text_str = r_text[0].strip()

    voice_clone_prompt_items = tts.create_voice_clone_prompt(
        ref_audio=audio_tuple,
        ref_text=r_text_str if r_text_str else None,
        x_vector_only_mode=uxo
    )
    # Store plain dicts, not VoiceClonePromptItem objects.
    prompt_data = [
        {
            "ref_code": item.ref_code,
            "ref_spk_embedding": item.ref_spk_embedding,
            "x_vector_only_mode": item.x_vector_only_mode,
            "icl_mode": item.icl_mode,
            "ref_text": item.ref_text
        }
        for item in voice_clone_prompt_items
    ]
    file_id = str(uuid.uuid4())[:8]
    file_path = f"voice_clone_prompt_{file_id}.pt"
    torch.save(prompt_data, file_path)
    return file_path

def generate_voice_design(text, language, voice_description):
    """UI handler for the Voice Design tab.

    Validates the inputs, runs ``infer_voice_design`` and converts any
    failure into a status message instead of raising.

    Args:
        text: Text to speak; must contain non-whitespace characters.
        language: Language value forwarded to the model.
        voice_description: Voice instruction; must contain non-whitespace
            characters.

    Returns:
        ``((sample_rate, waveform), status)`` on success, or
        ``(None, error_message)`` on invalid input or failure.
    """
    text_is_blank = not text or not text.strip()
    if text_is_blank:
        return None, "错误:文本不能为空。"

    description_is_blank = not voice_description or not voice_description.strip()
    if description_is_blank:
        return None, "错误:语音描述不能为空。"

    try:
        waveform, sample_rate = infer_voice_design(text.strip(), language, voice_description)
    except Exception as e:
        logger.error(f"Voice Design 生成失败: {str(e)}", exc_info=True)
        return None, f"错误: {e}"
    else:
        return (sample_rate, waveform), "语音设计生成成功!"

def generate_voice_clone(ref_audio, ref_text, target_text, language, use_xvector_only):
    """UI handler for direct voice cloning from reference audio.

    Checks, in order, that the target text is non-empty, that the reference
    audio can be converted by ``_audio_to_tuple``, and that a reference text
    is present unless x-vector-only mode is enabled. It then runs
    ``infer_voice_clone``, turning any failure into a status message.

    Args:
        ref_audio: Reference audio value from the UI.
        ref_text: Transcript of the reference audio (non-strings count as empty).
        target_text: Text to speak (non-strings count as empty).
        language: Language value forwarded to the model.
        use_xvector_only: Whether x-vector-only mode is enabled.

    Returns:
        ``((sample_rate, waveform), status)`` on success, or
        ``(None, error_message)`` on invalid input or failure.
    """
    cleaned_target = ""
    if isinstance(target_text, str):
        cleaned_target = target_text.strip()
    if cleaned_target == "":
        return None, "错误:目标文本不能为空。"

    reference = _audio_to_tuple(ref_audio)
    if reference is None:
        return None, "错误:需要参考音频。"

    cleaned_ref = ""
    if isinstance(ref_text, str):
        cleaned_ref = ref_text.strip()
    if not (use_xvector_only or cleaned_ref):
        return None, "错误:未启用 '仅使用 x-vector' 时需要参考文本。"

    try:
        waveform, sample_rate = infer_voice_clone(
            cleaned_target, language, reference, cleaned_ref, use_xvector_only
        )
    except Exception as e:
        logger.error(f"Voice Clone 生成失败: {str(e)}", exc_info=True)
        return None, f"错误: {e}"
    else:
        return (sample_rate, waveform), "语音克隆生成成功!"

def generate_voice_clone_from_prompt_file(prompt_file_path, target_text, language):
    """UI handler for voice cloning from a saved prompt file.

    Checks that the target text is non-empty and that a prompt file was
    given, then runs ``infer_voice_clone_from_prompt``, turning any failure
    into a status message.

    Args:
        prompt_file_path: Path of the saved prompt file.
        target_text: Text to speak (non-strings count as empty).
        language: Language value forwarded to the model.

    Returns:
        ``((sample_rate, waveform), status)`` on success, or
        ``(None, error_message)`` on invalid input or failure.
    """
    cleaned_target = ""
    if isinstance(target_text, str):
        cleaned_target = target_text.strip()
    if cleaned_target == "":
        return None, "错误:目标文本不能为空。"

    if not prompt_file_path:
        return None, "错误:需要提供音频特征文件。"

    try:
        waveform, sample_rate = infer_voice_clone_from_prompt(
            cleaned_target, language, prompt_file_path
        )
    except Exception as e:
        logger.error(f"Voice Clone 生成失败: {str(e)}", exc_info=True)
        return None, f"错误: {e}"
    else:
        return (sample_rate, waveform), "语音克隆生成成功(使用特征文件)!"

@spaces.GPU
def infer_whisper_audio(audio_path, model_size="base"):
    """Transcribe an audio file with Whisper for the ASR tab.

    Args:
        audio_path: Path of the audio file; a falsy value yields an error
            message.
        model_size: Whisper model name passed to ``load_whisper_model``.

    Returns:
        The transcript after OpenCC conversion, or an error message string
        when no audio was given or transcription failed.
    """
    if not audio_path:
        return "错误:请上传音频文件或进行录音。"

    try:
        transcription = load_whisper_model(model_size).transcribe(audio_path)

        recognized = transcription.get("text", "")
        # Unwrap a list result and coerce anything else to text.
        if isinstance(recognized, list) and recognized:
            recognized = recognized[0]
        recognized = recognized if isinstance(recognized, str) else str(recognized)

        return cc.convert(recognized.strip())
    except Exception as e:
        logger.error(f"Whisper 识别失败: {str(e)}", exc_info=True)
        return f"识别出错: {e}"

def build_ui():
    """Assemble the Gradio Blocks application.

    The page has three tabs:

    * "ASR (Whisper)": transcribe uploaded or recorded audio.
    * "Voice Design": synthesise speech from text plus a voice description.
    * "Voice Clone (Base)": (1) extract a prompt file from reference audio,
      (2) generate speech from a prompt file, (3) clone directly from
      reference audio.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    theme = gr.themes.Soft(font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"])
    with gr.Blocks(theme=theme, title="Qwen3-TTS Demo") as demo:
        gr.Markdown("# Qwen3-TTS Demo")
        with gr.Tabs():
            # Tab 1: speech recognition. Audio is passed to the handler as a file path.
            with gr.Tab("ASR (Whisper)"):
                with gr.Row():
                    with gr.Column():
                        asr_audio_input = gr.Audio(label="输入音频", type="filepath", sources=["microphone", "upload"])
                        asr_model_size = gr.Dropdown(label="Whisper 模型大小", choices=["base", "small", "medium", "large-v3"], value="base")
                        asr_btn = gr.Button("开始识别", variant="primary")
                    with gr.Column():
                        asr_text_output = gr.Textbox(label="识别结果", lines=10, show_copy_button=True)
                asr_btn.click(infer_whisper_audio, inputs=[asr_audio_input, asr_model_size], outputs=[asr_text_output])
            # Tab 2: voice design. Handler returns (audio, status).
            with gr.Tab("Voice Design"):
                with gr.Row():
                    with gr.Column():
                        design_text = gr.Textbox(label="目标文本", lines=4, value="It's in the top drawer... wait, it's empty?")
                        design_language = gr.Dropdown(label="语言", choices=LANGUAGES, value="Auto")
                        design_instruct = gr.Textbox(label="语音描述", lines=3, value="Speak in an incredulous tone.")
                        design_btn = gr.Button("开始生成", variant="primary")
                    with gr.Column():
                        design_audio_out = gr.Audio(label="生成音频", type="numpy")
                        design_status = gr.Textbox(label="状态", interactive=False)
                design_btn.click(generate_voice_design, inputs=[design_text, design_language, design_instruct], outputs=[design_audio_out, design_status],api_name="generate_voice_design")
            # Tab 3: voice cloning, split into three sections.
            with gr.Tab("Voice Clone (Base)"):
                # Section 1: extract a reusable prompt file. Note the single File output.
                gr.Markdown("### 1. 提取音频特征")
                with gr.Row():
                    with gr.Column():
                        extract_ref_audio = gr.Audio(label="参考音频", type="numpy")
                        extract_ref_text = gr.Textbox(label="参考文本", lines=2)
                        extract_xvector = gr.Checkbox(label="仅使用 x-vector", value=False)
                        extract_btn = gr.Button("提取音频特征", variant="primary")
                    with gr.Column():
                        extract_file_out = gr.File(label="特征文件 (.pt)")
                extract_btn.click(extract_voice_clone_prompt, inputs=[extract_ref_audio, extract_ref_text, extract_xvector], outputs=[extract_file_out],api_name="extract_voice_clone_prompt")
                # Section 2: generate speech from a previously extracted prompt file.
                gr.Markdown("### 2. 使用特征文件生成")
                with gr.Row():
                    with gr.Column():
                        prompt_file = gr.File(label="特征文件 (.pt)")
                        prompt_target_text = gr.Textbox(label="目标文本", lines=4)
                        prompt_language = gr.Dropdown(label="语言", choices=LANGUAGES, value="Auto")
                        prompt_btn = gr.Button("使用特征文件生成", variant="primary")
                    with gr.Column():
                        prompt_audio_out = gr.Audio(label="生成音频", type="numpy")
                        prompt_status = gr.Textbox(label="状态", interactive=False)
                prompt_btn.click(generate_voice_clone_from_prompt_file, inputs=[prompt_file, prompt_target_text, prompt_language], outputs=[prompt_audio_out, prompt_status],api_name="generate_voice_clone_from_prompt")
                gr.Markdown("---")

                # Section 3: Traditional Voice Clone (Original)
                # Clones straight from the reference audio on every request.
                gr.Markdown("### 3. 传统音色克隆(直接使用参考音频)")
                gr.Markdown("直接上传参考音频生成语音(每次都需要提取特征)。")
                with gr.Row():
                    with gr.Column(scale=2):
                        clone_ref_audio = gr.Audio(
                            label="参考音频",
                            type="numpy",
                        )
                        clone_ref_text = gr.Textbox(
                            label="参考文本",
                            lines=2,
                            placeholder="输入参考音频中的确切文字...",
                        )
                        clone_xvector = gr.Checkbox(
                            label="仅使用 x-vector",
                            value=False,
                        )

                    with gr.Column(scale=2):
                        clone_target_text = gr.Textbox(
                            label="目标文本",
                            lines=4,
                            placeholder="输入要让克隆音色说话的文字...",
                        )
                        with gr.Row():
                            clone_language = gr.Dropdown(
                                label="语言",
                                choices=LANGUAGES,
                                value="Auto",
                                interactive=True,
                            )
                        clone_btn = gr.Button("克隆并生成", variant="primary")

                with gr.Row():
                    clone_audio_out = gr.Audio(label="生成的音频", type="numpy")
                    clone_status = gr.Textbox(label="状态", lines=2, interactive=False)

                # Input order must match the parameter order of generate_voice_clone.
                clone_btn.click(
                    generate_voice_clone,
                    inputs=[clone_ref_audio, clone_ref_text, clone_target_text, clone_language, clone_xvector],
                    outputs=[clone_audio_out, clone_status],
                    api_name="generate_voice_clone"
                )

    return demo

if __name__ == "__main__":
    # Build the Blocks app, then start the Gradio server with default settings.
    demo_app = build_ui()
    demo_app.launch()