simler committed on
Commit
58e69fe
·
verified ·
1 Parent(s): 45ca840

Upload 5 files

Browse files
Files changed (5) hide show
  1. Inference.py +115 -0
  2. Internal.py +403 -0
  3. Shared.py +14 -0
  4. TTSPlayer.py +242 -0
  5. app.py +3 -1
Inference.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import onnxruntime as ort
import numpy as np
from typing import List, Optional
import threading

from ..Audio.ReferenceAudio import ReferenceAudio
from ..GetPhonesAndBert import get_phones_and_bert

# Maximum semantic-token sequence length.
# NOTE(review): t2s_cpu() below hard-codes a 500-step decode cap instead of
# using this constant — confirm which limit is intended.
MAX_T2S_LEN = 1000
10
+
11
+
12
class GENIE:
    """Inference client for the exported GPT-SoVITS ONNX graphs.

    The ONNX ``InferenceSession`` objects are owned by the caller (the model
    manager); this class only orchestrates them and holds a ``stop_event`` so
    another thread can abort a long autoregressive decode.
    """

    def __init__(self):
        # Set from another thread (see TTSPlayer.stop) to abort t2s_cpu mid-decode.
        self.stop_event: threading.Event = threading.Event()

    def tts(
        self,
        text: str,
        prompt_audio: ReferenceAudio,
        encoder: ort.InferenceSession,
        first_stage_decoder: ort.InferenceSession,
        stage_decoder: ort.InferenceSession,
        vocoder: ort.InferenceSession,
        prompt_encoder: Optional[ort.InferenceSession],
        language: str = 'japanese',
        text_language: Optional[str] = None,  # target text language; defaults to the reference-audio language
    ) -> Optional[np.ndarray]:
        """Synthesize ``text`` in the voice of ``prompt_audio``.

        Returns the float waveform produced by the vocoder, or ``None`` when
        generation was cancelled via ``stop_event``.
        """
        # Cross-lingual TTS: fall back to the reference audio's language when
        # no explicit text language was requested.
        actual_text_language = text_language if text_language else language
        text = '。' + text  # leading pause mark keeps the first sentence from being clipped
        text_seq, text_bert = get_phones_and_bert(text, language=actual_text_language)

        semantic_tokens: Optional[np.ndarray] = self.t2s_cpu(
            ref_seq=prompt_audio.phonemes_seq,
            ref_bert=prompt_audio.text_bert,
            text_seq=text_seq,
            text_bert=text_bert,
            ssl_content=prompt_audio.ssl_content,
            encoder=encoder,
            first_stage_decoder=first_stage_decoder,
            stage_decoder=stage_decoder,
        )
        # Fix: t2s_cpu returns None when cancelled via stop_event; previously
        # this fell through to the ">= 1024" comparison below and raised a
        # TypeError (caught and logged as a spurious critical error upstream).
        if semantic_tokens is None:
            return None

        # Truncate at the first out-of-vocabulary id (>= 1024, e.g. the EOS token).
        eos_indices = np.where(semantic_tokens >= 1024)
        if len(eos_indices[0]) > 0:
            # NOTE(review): picks the first hit along the LAST axis
            # (eos_indices[-1]); assumes a batch-size-1 layout — confirm.
            first_eos_index = eos_indices[-1][0]
            semantic_tokens = semantic_tokens[..., :first_eos_index]

        if prompt_encoder is None:
            # v2 vocoder: conditioned directly on the 32 kHz reference audio.
            return vocoder.run(None, {
                "text_seq": text_seq,
                "pred_semantic": semantic_tokens,
                "ref_audio": prompt_audio.audio_32k
            })[0]
        else:
            # Added for V2ProPlus: vocoder is conditioned on global speaker embeddings.
            prompt_audio.update_global_emb(prompt_encoder=prompt_encoder)
            audio_chunk = vocoder.run(None, {
                "text_seq": text_seq,
                "pred_semantic": semantic_tokens,
                "ge": prompt_audio.global_emb,
                "ge_advanced": prompt_audio.global_emb_advanced,
            })[0]
            return audio_chunk

    def t2s_cpu(
        self,
        ref_seq: np.ndarray,
        ref_bert: np.ndarray,
        text_seq: np.ndarray,
        text_bert: np.ndarray,
        ssl_content: np.ndarray,
        encoder: ort.InferenceSession,
        first_stage_decoder: ort.InferenceSession,
        stage_decoder: ort.InferenceSession,
    ) -> Optional[np.ndarray]:
        """Run the text-to-semantic (T2S) model on CPU.

        Returns the predicted semantic-token array, or ``None`` when
        ``stop_event`` was set during decoding.
        """
        # Encoder: fuse reference/text phoneme sequences, BERT features and SSL content.
        x, prompts = encoder.run(
            None,
            {
                "ref_seq": ref_seq,
                "text_seq": text_seq,
                "ref_bert": ref_bert,
                "text_bert": text_bert,
                "ssl_content": ssl_content,
            },
        )

        # First decoder step primes the KV cache.
        y, y_emb, *present_key_values = first_stage_decoder.run(
            None, {"x": x, "prompts": prompts}
        )

        # Autoregressive decoding, one semantic token per step.
        # NOTE(review): the 500-step cap here ignores MAX_T2S_LEN (1000)
        # defined above — confirm which limit is intended.
        input_names: List[str] = [inp.name for inp in stage_decoder.get_inputs()]
        idx: int = 0
        for idx in range(0, 500):
            if self.stop_event.is_set():
                return None
            input_feed = {
                name: data
                for name, data in zip(input_names, [y, y_emb, *present_key_values])
            }
            outputs = stage_decoder.run(None, input_feed)
            y, y_emb, stop_condition_tensor, *present_key_values = outputs

            if stop_condition_tensor:
                break

        y[0, -1] = 0  # replace the final (EOS) position with a valid token id
        return np.expand_dims(y[:, -idx:], axis=0)
113
+
114
+
115
# Module-level singleton shared by TTSPlayer's synthesis worker thread.
tts_client: GENIE = GENIE()
Internal.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# NOTE: the import order below is deliberate — keep it.
# 1. Environment variables (must be set before heavy libraries load).
import os
from os import PathLike

os.environ["HF_HUB_ENABLE_PROGRESS_BAR"] = "1"

# 2. Logging & warnings.
import logging
import warnings

warnings.filterwarnings("ignore", category=UserWarning, module="jieba_fast._compat")
logging.basicConfig(level=logging.INFO, format="%(message)s", datefmt="[%X]")
logger = logging.getLogger(__name__)

# 3. ONNX Runtime: silence everything below ERROR level.
import onnxruntime

onnxruntime.set_default_logger_severity(3)

# 4. Remaining imports.

from pathlib import Path
import json
import asyncio
from typing import AsyncIterator, Optional, Union, Dict

from .Audio.ReferenceAudio import ReferenceAudio
from .Core.Resources import ensure_exists, Chinese_G2P_DIR, English_G2P_DIR
from .Core.TTSPlayer import tts_player
from .ModelManager import model_manager
from .Utils.Shared import context
from .Utils.Language import normalize_language
from .PredefinedCharacter import download_chara, CHARA_LANG, CHARA_ALIAS_MAP

# A module-level private dictionary to store reference audio configurations.
_reference_audios: Dict[str, dict] = {}
# Audio file extensions accepted by set_reference_audio().
SUPPORTED_AUDIO_EXTS = {'.wav', '.flac', '.ogg', '.aiff', '.aif'}
39
+
40
+
41
def check_onnx_model_dir(onnx_model_dir: Union[str, os.PathLike]) -> None:
    """
    Validate that *onnx_model_dir* contains the ONNX model files Genie TTS
    needs (v2 or v2ProPlus).

    Raises:
        FileNotFoundError: with detailed instructions if validation fails.
    """
    root = Path(onnx_model_dir)

    # The directory itself must exist before its contents can be inspected.
    if not (root.exists() and root.is_dir()):
        raise FileNotFoundError(f"The model directory '{onnx_model_dir}' does not exist or is not a directory.")

    # Base files required by both the v2 and v2ProPlus layouts.
    required = {
        "t2s_encoder_fp32.bin",
        "t2s_encoder_fp32.onnx",
        "t2s_first_stage_decoder_fp32.onnx",
        "t2s_shared_fp16.bin",
        "t2s_stage_decoder_fp32.onnx",
        "vits_fp16.bin",
        "vits_fp32.onnx",
    }

    present = {entry.name for entry in root.iterdir() if entry.is_file()}
    missing = required - present
    if not missing:
        # All base files found; the v2ProPlus extras are optional here.
        return

    # Missing base files make the model unusable — explain exactly what a
    # valid folder must contain.
    raise FileNotFoundError(
        f"\n\n[Genie Error] Invalid ONNX model directory: '{root}'\n"
        "===============================================================\n"
        f"Missing base files: {', '.join(missing)}\n"
        "A valid model folder must contain at least the following files.\n"
        "1. [v2 Base] (Required for all models):\n"
        " - t2s_encoder_fp32.bin\n"
        " - t2s_encoder_fp32.onnx\n"
        " - t2s_first_stage_decoder_fp32.onnx\n"
        " - t2s_shared_fp16.bin\n"
        " - t2s_stage_decoder_fp32.onnx\n"
        " - vits_fp16.bin\n"
        " - vits_fp32.onnx\n"
        "2. [v2ProPlus Additions] (Required for v2pp features):\n"
        " - prompt_encoder_fp16.bin\n"
        " - prompt_encoder_fp32.onnx\n"
        "===============================================================\n"
    )
92
+
93
+
94
def load_character(
    character_name: str,
    onnx_model_dir: Union[str, PathLike],
    language: str,
) -> None:
    """
    Loads a character model from an ONNX model directory.

    Args:
        character_name (str): The name to assign to the loaded character.
        onnx_model_dir (str | PathLike): The directory path containing the ONNX model files.
        language (str): The language of the character model.
    """
    # Fail fast if the model folder is incomplete.
    check_onnx_model_dir(onnx_model_dir)

    language = normalize_language(language)
    if language not in ('Japanese', 'English', 'Chinese'):
        raise ValueError('Unknown language')

    # Chinese / English require their G2P resources to be present on disk.
    if language == 'Chinese':
        ensure_exists(Chinese_G2P_DIR, "Chinese_G2P_DIR")
    elif language == 'English':
        ensure_exists(English_G2P_DIR, "English_G2P_DIR")

    model_manager.load_character(
        character_name=character_name,
        model_dir=os.fspath(onnx_model_dir),
        language=language,
    )
+ )
124
+
125
+
126
def unload_character(
    character_name: str,
) -> None:
    """
    Unloads a previously loaded character model to free up resources.

    Args:
        character_name (str): The name of the character to unload.
    """
    # Delegates to the shared model manager, which owns the ONNX sessions.
    model_manager.remove_character(character_name=character_name)
+ )
138
+
139
+
140
def set_reference_audio(
    character_name: str,
    audio_path: Union[str, PathLike],
    audio_text: str,
    language: Optional[str] = None,
) -> None:
    """
    Sets the reference audio for a character to be used for voice cloning.

    This must be called for a character before using 'tts' or 'tts_async'.

    Args:
        character_name (str): The name of the character.
        audio_path (str | PathLike): The file path to the reference audio (e.g., a WAV file).
        audio_text (str): The transcript of the reference audio.
        language (str, optional): The language of the reference audio. When None,
            it is inherited from the loaded character model.

    Raises:
        ValueError: If no language can be determined, or it is unsupported.
    """
    audio_path: str = os.fspath(audio_path)

    # Check whether the file extension is supported; unsupported formats are
    # logged and ignored (no exception) to match the lenient public API.
    ext = os.path.splitext(audio_path)[1].lower()
    if ext not in SUPPORTED_AUDIO_EXTS:
        logger.error(
            f"Audio format '{ext}' is not supported. Only the following formats are supported: {SUPPORTED_AUDIO_EXTS}"
        )
        return

    # No explicit language: fall back to the loaded character model's language.
    if language is None:
        gsv_model = model_manager.get(character_name)
        if gsv_model:
            language = gsv_model.LANGUAGE
        else:
            raise ValueError('No language specified')
    language = normalize_language(language)
    if language not in ['Japanese', 'English', 'Chinese']:
        raise ValueError('Unknown language')

    # Remember the configuration so tts()/tts_async() can rebuild the
    # ReferenceAudio for each session.
    _reference_audios[character_name] = {
        'audio_path': audio_path,
        'audio_text': audio_text,
        'language': language,
    }
    # Eagerly build the prompt features into the shared TTS context.
    context.current_prompt_audio = ReferenceAudio(
        prompt_wav=audio_path,
        prompt_text=audio_text,
        language=language,
    )
188
+
189
+
190
async def tts_async(
    character_name: str,
    text: str,
    play: bool = False,
    split_sentence: bool = False,
    save_path: Union[str, PathLike, None] = None,
    text_language: Optional[str] = None,  # target text language, for cross-lingual TTS
) -> AsyncIterator[bytes]:
    """
    Asynchronously generates speech from text and yields audio chunks.

    This function returns an async iterator that provides the audio data in
    real-time as it's being generated.

    Args:
        character_name (str): The name of the character to use for synthesis.
        text (str): The text to be synthesized into speech.
        play (bool, optional): If True, plays the audio as it's generated. Defaults to False.
        split_sentence (bool, optional): If True, splits the text into sentences for synthesis. Defaults to False.
        save_path (str | PathLike | None, optional): If provided, saves the generated audio to this file path. Defaults to None.
        text_language (str, optional): Language of the target text. If None, uses the reference audio language.

    Yields:
        bytes: A chunk of the generated audio data.

    Raises:
        ValueError: If 'set_reference_audio' has not been called for the character.
    """
    if character_name not in _reference_audios:
        raise ValueError("Please call 'set_reference_audio' first to set the reference audio.")

    # Make sure the destination directory exists before the worker writes to it.
    if save_path:
        save_path = os.fspath(save_path)
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # 1. Bridge the synthesis worker thread to asyncio: a queue bound to the
    #    currently running event loop.
    stream_queue: asyncio.Queue[Union[bytes, None]] = asyncio.Queue()
    loop = asyncio.get_running_loop()

    # 2. Callback invoked from the TTS worker thread; call_soon_threadsafe is
    #    the only safe way to touch the loop from another thread.
    def tts_chunk_callback(c: Optional[bytes]):
        """This callback is called from the TTS worker thread."""
        loop.call_soon_threadsafe(stream_queue.put_nowait, c)

    # Configure the shared TTS context for this session.
    context.current_speaker = character_name
    context.current_prompt_audio = ReferenceAudio(
        prompt_wav=_reference_audios[character_name]['audio_path'],
        prompt_text=_reference_audios[character_name]['audio_text'],
        language=_reference_audios[character_name]['language'],
    )
    # Target text language (cross-lingual TTS); None keeps the reference language.
    context.current_text_language = normalize_language(text_language) if text_language else None

    # 3. Start the TTS session using the streaming callback interface.
    tts_player.start_session(
        play=play,
        split=split_sentence,
        save_path=save_path,
        chunk_callback=tts_chunk_callback,
    )

    # Feed the text and signal end of input.
    tts_player.feed(text)
    tts_player.end_session()

    # 4. Drain the queue until the worker delivers the None terminator.
    while True:
        chunk = await stream_queue.get()
        if chunk is None:
            break
        yield chunk
264
+
265
+
266
def tts(
    character_name: str,
    text: str,
    play: bool = False,
    split_sentence: bool = True,
    save_path: Union[str, PathLike, None] = None,
    text_language: Optional[str] = None,  # target text language, for cross-lingual TTS
) -> None:
    """
    Synchronously generates speech from text.

    This is a blocking function that will not return until the entire TTS
    process is complete.

    Args:
        character_name (str): The name of the character to use for synthesis.
        text (str): The text to be synthesized into speech.
        play (bool, optional): If True, plays the audio.
        split_sentence (bool, optional): If True, splits the text into sentences for synthesis.
        save_path (str | PathLike | None, optional): If provided, saves the generated audio to this file path. Defaults to None.
        text_language (str, optional): Language of the target text. If None, uses the reference audio language.
    """
    if character_name not in _reference_audios:
        logger.error("Please call 'set_reference_audio' first to set the reference audio.")
        return

    # Make sure the destination directory exists before the worker writes to it.
    if save_path:
        save_path = os.fspath(save_path)
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # Configure the shared TTS context for this session.
    context.current_speaker = character_name
    context.current_prompt_audio = ReferenceAudio(
        prompt_wav=_reference_audios[character_name]['audio_path'],
        prompt_text=_reference_audios[character_name]['audio_text'],
        language=_reference_audios[character_name]['language'],
    )
    # Target text language (cross-lingual TTS); None keeps the reference language.
    context.current_text_language = normalize_language(text_language) if text_language else None

    tts_player.start_session(
        play=play,
        split=split_sentence,
        save_path=save_path,
    )
    tts_player.feed(text)
    tts_player.end_session()
    # Block until the synthesis worker has drained the session.
    tts_player.wait_for_tts_completion()
315
+
316
+
317
def wait_for_playback_done() -> None:
    """
    Wait until all TTS tasks have finished processing and playback has fully completed.
    """
    # Delegates to the player, which waits on its internal completion events.
    tts_player.wait_for_playback_done()
322
+
323
+
324
def stop() -> None:
    """
    Stops the currently playing text-to-speech audio.

    Also cancels any in-flight synthesis and tears down the worker threads
    (see TTSPlayer.stop).
    """
    tts_player.stop()
329
+
330
+
331
def convert_to_onnx(
    torch_ckpt_path: Union[str, PathLike],
    torch_pth_path: Union[str, PathLike],
    output_dir: Union[str, PathLike],
) -> None:
    """
    Converts PyTorch model checkpoints to the ONNX format.

    This function requires PyTorch to be installed.

    Args:
        torch_ckpt_path (str | PathLike): The path to the T2S model (.ckpt) file.
        torch_pth_path (str | PathLike): The path to the VITS model (.pth) file.
        output_dir (str | PathLike): The directory where the ONNX models will be saved.
    """
    # PyTorch is an optional dependency; bail out with a hint when absent.
    try:
        import torch  # noqa: F401
    except ImportError:
        logger.error("❌ PyTorch is not installed. Please run `pip install torch` first.")
        return

    # Imported lazily because the converter itself depends on torch.
    from .Converter.Converter import convert

    convert(
        torch_pth_path=os.fspath(torch_pth_path),
        torch_ckpt_path=os.fspath(torch_ckpt_path),
        output_dir=os.fspath(output_dir),
    )
363
+
364
+
365
def clear_reference_audio_cache() -> None:
    """
    Clears the cache of reference audio data held by ReferenceAudio.
    """
    ReferenceAudio.clear_cache()
370
+
371
+
372
def load_predefined_character(character_name: str) -> None:
    """
    Download and load a predefined character model for TTS inference.

    Resolves the name through the alias map, downloads the model bundle,
    loads its ONNX models into the model manager, and registers the bundled
    "Normal" prompt as the character's reference audio.
    """
    # Alias resolution is case-insensitive and whitespace-tolerant.
    character_name = character_name.lower().strip()
    if character_name not in CHARA_ALIAS_MAP:
        logger.error(f"No predefined character model found for {character_name}")
        return
    character_name = CHARA_ALIAS_MAP[character_name]

    # Fetch (or reuse) the bundle, then load its ONNX models.
    save_path = download_chara(character_name)
    model_manager.load_character(
        character_name=character_name,
        model_dir=os.path.join(save_path, 'tts_models'),
        language=CHARA_LANG[character_name],
    )

    # The bundle ships prompt metadata; the "Normal" style is the default prompt.
    with open(os.path.join(save_path, "prompt_wav.json"), "r", encoding="utf-8") as f:
        prompt_wav_dict: Dict[str, Dict[str, str]] = json.load(f)

    audio_text = prompt_wav_dict["Normal"]["text"]
    audio_path = os.path.join(save_path, "prompt_wav", prompt_wav_dict["Normal"]["wav"])
    _reference_audios[character_name] = {
        'audio_path': audio_path,
        'audio_text': audio_text,
        'language': CHARA_LANG[character_name],
    }
    context.current_prompt_audio = ReferenceAudio(
        prompt_wav=audio_path,
        prompt_text=audio_text,
        language=CHARA_LANG[character_name],
    )
Shared.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import TYPE_CHECKING, Optional
2
+
3
+ if TYPE_CHECKING:
4
+ from ..Audio.ReferenceAudio import ReferenceAudio
5
+
6
+
7
class Context:
    """Mutable, process-wide TTS session state shared across modules."""

    def __init__(self):
        # Name of the character currently selected for synthesis.
        self.current_speaker: str = ''
        # Reference audio used for voice cloning; set via set_reference_audio().
        self.current_prompt_audio: Optional['ReferenceAudio'] = None
        # Target text language for cross-lingual TTS; None falls back to the
        # reference audio's language.
        self.current_text_language: Optional[str] = None


# Module-level singleton consumed by the TTS pipeline.
context: Context = Context()
TTSPlayer.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# File: .../Core/TTSPlayer.py

import queue
import os
import threading

import numpy as np
import wave
from typing import Optional, List, Callable
import logging

from ..Utils.TextSplitter import TextSplitter
from ..Core.Inference import tts_client
from ..ModelManager import model_manager
from ..Utils.Shared import context
from ..Utils.Utils import clear_queue

logger = logging.getLogger(__name__)

# Sentinel marking the end of a session's text stream on the text queue.
STREAM_END = 'STREAM_END'
# Sentinel marking the end of the audio stream on the playback queue.
AUDIO_STREAM_END = 'AUDIO_STREAM_END'
22
+
23
+
24
class TTSPlayer:
    """Threaded TTS pipeline.

    Text fed via feed() is queued to a synthesis worker that drives
    ``tts_client``; resulting audio chunks are optionally played back on a
    second worker thread, accumulated for saving to a WAV file, and/or
    streamed to a caller-supplied chunk callback.
    """

    def __init__(self, sample_rate: int = 32000):
        self._text_splitter = TextSplitter()

        # Playback / WAV output format.
        self.sample_rate: int = sample_rate
        self.channels: int = 1
        self.bytes_per_sample: int = 2  # 16-bit audio

        # Sentences awaiting synthesis / audio chunks awaiting playback.
        self._text_queue: queue.Queue = queue.Queue()
        self._audio_queue: queue.Queue = queue.Queue()

        self._stop_event: threading.Event = threading.Event()
        self._tts_done_event: threading.Event = threading.Event()
        self._playback_done_event: threading.Event = threading.Event()  # set once playback has drained
        self._api_lock: threading.Lock = threading.Lock()

        self._tts_worker: Optional[threading.Thread] = None
        self._playback_worker: Optional[threading.Thread] = None

        # Per-session options, assigned by start_session().
        self._play: bool = False
        self._current_save_path: Optional[str] = None
        self._session_audio_chunks: List[np.ndarray] = []
        self._split: bool = False

        # Receives 16-bit PCM bytes per chunk, then None at end of stream.
        self._chunk_callback: Optional[Callable[[Optional[bytes]], None]] = None

    @staticmethod
    def _preprocess_for_playback(audio_float: np.ndarray) -> bytes:
        # Convert float audio (assumed in [-1, 1] — TODO confirm) to 16-bit PCM bytes.
        audio_int16 = (audio_float.squeeze() * 32767).astype(np.int16)
        return audio_int16.tobytes()

    def _tts_worker_loop(self):
        """Take sentences from the text queue, synthesize audio, and dispatch it
        via the chunk callback and/or the audio queue."""
        while not self._stop_event.is_set():
            try:
                sentence = self._text_queue.get(timeout=1)
                if sentence is None or self._stop_event.is_set():
                    break
            except queue.Empty:
                continue

            try:
                if sentence is STREAM_END:
                    if self._current_save_path and self._session_audio_chunks:
                        self._save_session_audio()

                    # When the TTS worker finishes, send the end signal via the callback.
                    if self._chunk_callback:
                        self._chunk_callback(None)

                    # If playback is enabled, notify the audio queue the stream ended.
                    if self._play:
                        self._audio_queue.put(AUDIO_STREAM_END)

                    self._tts_done_event.set()
                    continue

                gsv_model = model_manager.get(context.current_speaker)
                if not gsv_model or not context.current_prompt_audio:
                    logger.error("Missing model or reference audio.")
                    continue

                # Clear any stale cancellation flag before starting synthesis.
                tts_client.stop_event.clear()
                audio_chunk = tts_client.tts(
                    text=sentence,
                    prompt_audio=context.current_prompt_audio,
                    encoder=gsv_model.T2S_ENCODER,
                    first_stage_decoder=gsv_model.T2S_FIRST_STAGE_DECODER,
                    stage_decoder=gsv_model.T2S_STAGE_DECODER,
                    vocoder=gsv_model.VITS,
                    prompt_encoder=gsv_model.PROMPT_ENCODER,
                    language=gsv_model.LANGUAGE,
                    text_language=context.current_text_language,  # cross-lingual TTS support
                )

                if audio_chunk is not None:
                    if self._play:
                        self._audio_queue.put(audio_chunk)
                    if self._current_save_path:
                        self._session_audio_chunks.append(audio_chunk)

                    # Stream the chunk to the registered callback as PCM bytes.
                    if self._chunk_callback:
                        audio_data = self._preprocess_for_playback(audio_chunk)
                        self._chunk_callback(audio_data)

            except Exception as e:
                logger.error(f"A critical error occurred while processing the TTS task: {e}", exc_info=True)
                # On failure, still deliver the end-of-stream signal so
                # consumers blocked on the callback/events do not hang.
                if self._chunk_callback:
                    self._chunk_callback(None)
                self._tts_done_event.set()

    def _playback_worker_loop(self):
        """Play queued audio chunks through the default output device."""
        try:
            import sounddevice as sd
            with sd.OutputStream(samplerate=self.sample_rate,
                                 channels=self.channels,
                                 dtype='float32') as stream:
                while not self._stop_event.is_set():
                    try:
                        audio_chunk = self._audio_queue.get(timeout=1)
                        if audio_chunk is None:
                            break
                        if audio_chunk is AUDIO_STREAM_END:
                            self._playback_done_event.set()
                            continue
                        stream.write(audio_chunk.squeeze())
                    except queue.Empty:
                        continue
                    except Exception as e:
                        logger.error(f"Error during audio playback: {e}", exc_info=True)

        except Exception as e:
            logger.warning(f"Failed to initialize sounddevice: {e}. Audio playback will be skipped.")
            # If the audio device fails to initialize, keep draining the queue
            # so end-of-stream markers are still consumed and the main thread
            # does not deadlock.
            while not self._stop_event.is_set():
                try:
                    item = self._audio_queue.get(timeout=0.5)
                    if item is None:
                        break
                    if item is AUDIO_STREAM_END:
                        self._playback_done_event.set()
                except queue.Empty:
                    continue

    def _save_session_audio(self):
        """Concatenate the session's chunks and write them out as a 16-bit WAV file."""
        try:
            full_audio = np.concatenate(self._session_audio_chunks, axis=0)
            with wave.open(self._current_save_path, 'wb') as wf:
                wf.setnchannels(self.channels)
                wf.setsampwidth(self.bytes_per_sample)
                wf.setframerate(self.sample_rate)
                wf.writeframes(self._preprocess_for_playback(full_audio))
            logger.info(f"Audio successfully saved to {os.path.abspath(self._current_save_path)}")
        except Exception as e:
            logger.error(f"Failed to save audio: {e}")
        finally:
            # Reset save state even on failure so the next session starts clean.
            self._session_audio_chunks = []
            self._current_save_path = None

    def start_session(
        self,
        play: bool = False,
        split: bool = False,
        save_path: Optional[str] = None,
        chunk_callback: Optional[Callable[[Optional[bytes]], None]] = None
    ):
        """Begin a new TTS session, (re)starting the worker threads as needed."""
        with self._api_lock:
            self._tts_done_event.clear()
            self._playback_done_event.clear()  # reset the playback-finished flag
            self._chunk_callback = chunk_callback
            self._stop_event.clear()

            if self._tts_worker is None or not self._tts_worker.is_alive():
                self._tts_worker = threading.Thread(target=self._tts_worker_loop, daemon=True)
                self._tts_worker.start()

            if self._playback_worker is None or not self._playback_worker.is_alive():
                self._playback_worker = threading.Thread(target=self._playback_worker_loop, daemon=True)
                self._playback_worker.start()

            # Drop any leftovers from a previous session.
            clear_queue(self._text_queue)
            clear_queue(self._audio_queue)

            self._play = play
            self._split = split
            self._current_save_path = save_path
            self._session_audio_chunks = []

    def feed(self, text_chunk: str):
        """Queue text for synthesis, optionally split into sentences first."""
        with self._api_lock:
            if not text_chunk:
                return
            if self._split:
                sentences = self._text_splitter.split(text_chunk.strip())
                for sentence in sentences:
                    self._text_queue.put(sentence)
            else:
                self._text_queue.put(text_chunk)

    def end_session(self):
        """Mark the end of the current session's text stream."""
        with self._api_lock:
            self._text_queue.put(STREAM_END)

    def stop(self):
        """Abort synthesis and playback, then tear down both worker threads."""
        with self._api_lock:
            if self._tts_worker is None and self._playback_worker is None:
                return
            if self._stop_event.is_set():
                return
            # Cancel any in-flight decode inside tts_client first.
            tts_client.stop_event.set()
            self._stop_event.set()
            self._tts_done_event.set()
            # None sentinels unblock both workers' queue.get() calls.
            self._text_queue.put(None)
            self._audio_queue.put(None)
            if self._tts_worker and self._tts_worker.is_alive():
                self._tts_worker.join()
            if self._playback_worker and self._playback_worker.is_alive():
                self._playback_worker.join()
            self._tts_worker = None
            self._playback_worker = None

    def wait_for_tts_completion(self):
        """Block until the synthesis worker has processed the whole session."""
        if self._tts_done_event.is_set():
            return
        self._tts_done_event.wait()

    def wait_for_playback_done(self):
        """Block until synthesis has finished and, if enabled, playback too."""
        # 1. Wait for TTS generation to fully complete first.
        self.wait_for_tts_completion()

        # 2. If playback is enabled and we were not force-stopped, wait for it to drain.
        if self._play and not self._stop_event.is_set():
            if not self._playback_done_event.is_set():
                self._playback_done_event.wait()
240
+
241
+
242
# Module-level singleton consumed by the public API (Internal.py).
tts_player: TTSPlayer = TTSPlayer()
app.py CHANGED
@@ -114,10 +114,12 @@ async def dynamic_tts(
114
  character_name: str = Form("Base"),
115
  prompt_text: str = Form(None),
116
  prompt_lang: str = Form("zh"),
 
117
  use_default_ref: bool = Form(True)
118
  ):
119
  """
120
  通用 TTS 接口,支持切换已加载的角色
 
121
  """
122
  try:
123
  # 优先使用指定的角色,如果没有则尝试用 Base,如果都没有则报错
@@ -134,7 +136,7 @@ async def dynamic_tts(
134
  genie_tts.set_reference_audio(character_name, ref_info["path"], final_text, prompt_lang)
135
 
136
  out_path = f"/tmp/out_dyn_{int(time.time())}.wav"
137
- genie_tts.tts(character_name, text, save_path=out_path, play=False)
138
 
139
  return StreamingResponse(open(out_path, "rb"), media_type="audio/wav")
140
  except Exception as e:
 
114
  character_name: str = Form("Base"),
115
  prompt_text: str = Form(None),
116
  prompt_lang: str = Form("zh"),
117
+ text_lang: str = Form(None), # 新增:目标文本语言(跨语言TTS)
118
  use_default_ref: bool = Form(True)
119
  ):
120
  """
121
  通用 TTS 接口,支持切换已加载的角色
122
+ text_lang: 目标文本语言,如果和参考音频不同则可实现跨语言合成
123
  """
124
  try:
125
  # 优先使用指定的角色,如果没有则尝试用 Base,如果都没有则报错
 
136
  genie_tts.set_reference_audio(character_name, ref_info["path"], final_text, prompt_lang)
137
 
138
  out_path = f"/tmp/out_dyn_{int(time.time())}.wav"
139
+ genie_tts.tts(character_name, text, save_path=out_path, play=False, text_language=text_lang)
140
 
141
  return StreamingResponse(open(out_path, "rb"), media_type="audio/wav")
142
  except Exception as e: