simler committed on
Commit
d479a15
·
verified ·
1 Parent(s): 30ab98f

Upload 68 files

Browse files
genie_tts/Core/Inference.py CHANGED
@@ -1,112 +1,115 @@
1
- import onnxruntime as ort
2
- import numpy as np
3
- from typing import List, Optional
4
- import threading
5
-
6
- from ..Audio.ReferenceAudio import ReferenceAudio
7
- from ..GetPhonesAndBert import get_phones_and_bert
8
-
9
- MAX_T2S_LEN = 1000
10
-
11
-
12
- class GENIE:
13
- def __init__(self):
14
- self.stop_event: threading.Event = threading.Event()
15
-
16
- def tts(
17
- self,
18
- text: str,
19
- prompt_audio: ReferenceAudio,
20
- encoder: ort.InferenceSession,
21
- first_stage_decoder: ort.InferenceSession,
22
- stage_decoder: ort.InferenceSession,
23
- vocoder: ort.InferenceSession,
24
- prompt_encoder: Optional[ort.InferenceSession],
25
- language: str = 'japanese',
26
- ) -> Optional[np.ndarray]:
27
- text = '。' + text # 防止漏第一句。
28
- text_seq, text_bert = get_phones_and_bert(text, language=language)
29
-
30
- semantic_tokens: np.ndarray = self.t2s_cpu(
31
- ref_seq=prompt_audio.phonemes_seq,
32
- ref_bert=prompt_audio.text_bert,
33
- text_seq=text_seq,
34
- text_bert=text_bert,
35
- ssl_content=prompt_audio.ssl_content,
36
- encoder=encoder,
37
- first_stage_decoder=first_stage_decoder,
38
- stage_decoder=stage_decoder,
39
- )
40
-
41
- eos_indices = np.where(semantic_tokens >= 1024) # 剔除不合法的元素,例如 EOS Token。
42
- if len(eos_indices[0]) > 0:
43
- first_eos_index = eos_indices[-1][0]
44
- semantic_tokens = semantic_tokens[..., :first_eos_index]
45
-
46
- if prompt_encoder is None:
47
- return vocoder.run(None, {
48
- "text_seq": text_seq,
49
- "pred_semantic": semantic_tokens,
50
- "ref_audio": prompt_audio.audio_32k
51
- })[0]
52
- else:
53
- # V2ProPlus 新增。
54
- prompt_audio.update_global_emb(prompt_encoder=prompt_encoder)
55
- audio_chunk = vocoder.run(None, {
56
- "text_seq": text_seq,
57
- "pred_semantic": semantic_tokens,
58
- "ge": prompt_audio.global_emb,
59
- "ge_advanced": prompt_audio.global_emb_advanced,
60
- })[0]
61
- return audio_chunk
62
-
63
- def t2s_cpu(
64
- self,
65
- ref_seq: np.ndarray,
66
- ref_bert: np.ndarray,
67
- text_seq: np.ndarray,
68
- text_bert: np.ndarray,
69
- ssl_content: np.ndarray,
70
- encoder: ort.InferenceSession,
71
- first_stage_decoder: ort.InferenceSession,
72
- stage_decoder: ort.InferenceSession,
73
- ) -> Optional[np.ndarray]:
74
- """在CPU上运行T2S模型"""
75
- # Encoder
76
- x, prompts = encoder.run(
77
- None,
78
- {
79
- "ref_seq": ref_seq,
80
- "text_seq": text_seq,
81
- "ref_bert": ref_bert,
82
- "text_bert": text_bert,
83
- "ssl_content": ssl_content,
84
- },
85
- )
86
-
87
- # First Stage Decoder
88
- y, y_emb, *present_key_values = first_stage_decoder.run(
89
- None, {"x": x, "prompts": prompts}
90
- )
91
-
92
- # Stage Decoder
93
- input_names: List[str] = [inp.name for inp in stage_decoder.get_inputs()]
94
- idx: int = 0
95
- for idx in range(0, 500):
96
- if self.stop_event.is_set():
97
- return None
98
- input_feed = {
99
- name: data
100
- for name, data in zip(input_names, [y, y_emb, *present_key_values])
101
- }
102
- outputs = stage_decoder.run(None, input_feed)
103
- y, y_emb, stop_condition_tensor, *present_key_values = outputs
104
-
105
- if stop_condition_tensor:
106
- break
107
-
108
- y[0, -1] = 0
109
- return np.expand_dims(y[:, -idx:], axis=0)
110
-
111
-
112
- tts_client: GENIE = GENIE()
 
 
 
 
1
+ import onnxruntime as ort
2
+ import numpy as np
3
+ from typing import List, Optional
4
+ import threading
5
+
6
+ from ..Audio.ReferenceAudio import ReferenceAudio
7
+ from ..GetPhonesAndBert import get_phones_and_bert
8
+
9
+ MAX_T2S_LEN = 1000
10
+
11
+
12
class GENIE:
    """Thin wrapper around the GPT-SoVITS ONNX inference pipeline.

    Exposes a ``stop_event`` that other threads may set to abort the
    autoregressive decoding loop in :meth:`t2s_cpu` cooperatively.
    """

    def __init__(self):
        # Cooperative cancellation flag, checked each step of the decode loop.
        self.stop_event: threading.Event = threading.Event()

    def tts(
        self,
        text: str,
        prompt_audio: ReferenceAudio,
        encoder: ort.InferenceSession,
        first_stage_decoder: ort.InferenceSession,
        stage_decoder: ort.InferenceSession,
        vocoder: ort.InferenceSession,
        prompt_encoder: Optional[ort.InferenceSession],
        language: str = 'japanese',
        text_language: Optional[str] = None,  # Target-text language; defaults to the reference language.
    ) -> Optional[np.ndarray]:
        """Synthesize ``text`` into a waveform with the given ONNX sessions.

        Args:
            text: Text to synthesize.
            prompt_audio: Reference audio providing the voice prompt
                (phoneme sequence, BERT features, SSL content, 32 kHz audio).
            encoder: T2S encoder session.
            first_stage_decoder: T2S first-stage decoder session.
            stage_decoder: T2S autoregressive stage decoder session.
            vocoder: VITS vocoder session.
            prompt_encoder: Optional V2ProPlus prompt encoder; when given,
                the vocoder is fed global embeddings instead of raw audio.
            language: Language of the reference audio.
            text_language: Language of ``text``; falls back to ``language``
                when omitted (enables cross-lingual synthesis when set).

        Returns:
            The synthesized waveform, or ``None`` if decoding was aborted
            via ``stop_event``.
        """
        # Use the reference-audio language when no target language is given.
        actual_text_language = text_language if text_language else language
        text = '。' + text  # Leading pause guard: prevents the first sentence from being dropped.
        text_seq, text_bert = get_phones_and_bert(text, language=actual_text_language)

        semantic_tokens: np.ndarray = self.t2s_cpu(
            ref_seq=prompt_audio.phonemes_seq,
            ref_bert=prompt_audio.text_bert,
            text_seq=text_seq,
            text_bert=text_bert,
            ssl_content=prompt_audio.ssl_content,
            encoder=encoder,
            first_stage_decoder=first_stage_decoder,
            stage_decoder=stage_decoder,
        )

        # Strip invalid ids (e.g. the EOS token, id >= 1024) from the tail.
        eos_indices = np.where(semantic_tokens >= 1024)
        if len(eos_indices[0]) > 0:
            # eos_indices[-1] indexes the last (time) axis; [0] is the first EOS hit.
            first_eos_index = eos_indices[-1][0]
            semantic_tokens = semantic_tokens[..., :first_eos_index]

        if prompt_encoder is None:
            return vocoder.run(None, {
                "text_seq": text_seq,
                "pred_semantic": semantic_tokens,
                "ref_audio": prompt_audio.audio_32k
            })[0]
        else:
            # V2ProPlus path: vocoder conditions on global embeddings.
            prompt_audio.update_global_emb(prompt_encoder=prompt_encoder)
            audio_chunk = vocoder.run(None, {
                "text_seq": text_seq,
                "pred_semantic": semantic_tokens,
                "ge": prompt_audio.global_emb,
                "ge_advanced": prompt_audio.global_emb_advanced,
            })[0]
            return audio_chunk

    def t2s_cpu(
        self,
        ref_seq: np.ndarray,
        ref_bert: np.ndarray,
        text_seq: np.ndarray,
        text_bert: np.ndarray,
        ssl_content: np.ndarray,
        encoder: ort.InferenceSession,
        first_stage_decoder: ort.InferenceSession,
        stage_decoder: ort.InferenceSession,
    ) -> Optional[np.ndarray]:
        """Run the T2S (text-to-semantic) model on CPU.

        Encodes the reference/target sequences, runs one first-stage decode
        step, then iterates the stage decoder autoregressively (capped at
        500 steps) until the model's stop condition fires.

        Returns:
            Predicted semantic tokens with an added batch axis, or ``None``
            if ``stop_event`` was set mid-decode.
        """
        # Encoder
        x, prompts = encoder.run(
            None,
            {
                "ref_seq": ref_seq,
                "text_seq": text_seq,
                "ref_bert": ref_bert,
                "text_bert": text_bert,
                "ssl_content": ssl_content,
            },
        )

        # First Stage Decoder: seeds y, y_emb and the KV cache.
        y, y_emb, *present_key_values = first_stage_decoder.run(
            None, {"x": x, "prompts": prompts}
        )

        # Stage Decoder: feed outputs back in by positional input-name order.
        input_names: List[str] = [inp.name for inp in stage_decoder.get_inputs()]
        idx: int = 0
        for idx in range(0, 500):
            if self.stop_event.is_set():
                return None  # Aborted externally.
            input_feed = {
                name: data
                for name, data in zip(input_names, [y, y_emb, *present_key_values])
            }
            outputs = stage_decoder.run(None, input_feed)
            y, y_emb, stop_condition_tensor, *present_key_values = outputs

            if stop_condition_tensor:
                break  # Model signalled end-of-sequence.

        # NOTE(review): overwrites the final sampled token with 0, presumably
        # to neutralize the EOS step before vocoding — confirm against model spec.
        y[0, -1] = 0
        # Keep only the newly generated tokens (last `idx` positions), add batch axis.
        return np.expand_dims(y[:, -idx:], axis=0)


# Module-level singleton shared by the TTS worker thread.
tts_client: GENIE = GENIE()
genie_tts/Core/TTSPlayer.py CHANGED
@@ -1,241 +1,242 @@
1
- # 文件: .../Core/TTSPlayer.py
2
-
3
- import queue
4
- import os
5
- import threading
6
-
7
- import numpy as np
8
- import wave
9
- from typing import Optional, List, Callable
10
- import logging
11
-
12
- from ..Utils.TextSplitter import TextSplitter
13
- from ..Core.Inference import tts_client
14
- from ..ModelManager import model_manager
15
- from ..Utils.Shared import context
16
- from ..Utils.Utils import clear_queue
17
-
18
- logger = logging.getLogger(__name__)
19
-
20
- STREAM_END = 'STREAM_END' # 这是一个特殊的标记,表示文本流结束
21
- AUDIO_STREAM_END = 'AUDIO_STREAM_END' # 新增:特殊的标记,表示音频流播放结束
22
-
23
-
24
- class TTSPlayer:
25
- def __init__(self, sample_rate: int = 32000):
26
- self._text_splitter = TextSplitter()
27
-
28
- self.sample_rate: int = sample_rate
29
- self.channels: int = 1
30
- self.bytes_per_sample: int = 2 # 16-bit audio
31
-
32
- self._text_queue: queue.Queue = queue.Queue()
33
- self._audio_queue: queue.Queue = queue.Queue()
34
-
35
- self._stop_event: threading.Event = threading.Event()
36
- self._tts_done_event: threading.Event = threading.Event()
37
- self._playback_done_event: threading.Event = threading.Event() # 新增:用于标记播放完成
38
- self._api_lock: threading.Lock = threading.Lock()
39
-
40
- self._tts_worker: Optional[threading.Thread] = None
41
- self._playback_worker: Optional[threading.Thread] = None
42
-
43
- self._play: bool = False
44
- self._current_save_path: Optional[str] = None
45
- self._session_audio_chunks: List[np.ndarray] = []
46
- self._split: bool = False
47
-
48
- self._chunk_callback: Optional[Callable[[Optional[bytes]], None]] = None
49
-
50
- @staticmethod
51
- def _preprocess_for_playback(audio_float: np.ndarray) -> bytes:
52
- audio_int16 = (audio_float.squeeze() * 32767).astype(np.int16)
53
- return audio_int16.tobytes()
54
-
55
- def _tts_worker_loop(self):
56
- """从文本队列取句子,生成音频,并通过回调函数或音频队列分发。"""
57
- while not self._stop_event.is_set():
58
- try:
59
- sentence = self._text_queue.get(timeout=1)
60
- if sentence is None or self._stop_event.is_set():
61
- break
62
- except queue.Empty:
63
- continue
64
-
65
- try:
66
- if sentence is STREAM_END:
67
- if self._current_save_path and self._session_audio_chunks:
68
- self._save_session_audio()
69
-
70
- # 在TTS工作线程完成时,通过回调发送结束信号
71
- if self._chunk_callback:
72
- self._chunk_callback(None)
73
-
74
- # 新增:如果开启了播放,通知音频队列流已结束
75
- if self._play:
76
- self._audio_queue.put(AUDIO_STREAM_END)
77
-
78
- self._tts_done_event.set()
79
- continue
80
-
81
- gsv_model = model_manager.get(context.current_speaker)
82
- if not gsv_model or not context.current_prompt_audio:
83
- logger.error("Missing model or reference audio.")
84
- continue
85
-
86
- tts_client.stop_event.clear()
87
- audio_chunk = tts_client.tts(
88
- text=sentence,
89
- prompt_audio=context.current_prompt_audio,
90
- encoder=gsv_model.T2S_ENCODER,
91
- first_stage_decoder=gsv_model.T2S_FIRST_STAGE_DECODER,
92
- stage_decoder=gsv_model.T2S_STAGE_DECODER,
93
- vocoder=gsv_model.VITS,
94
- prompt_encoder=gsv_model.PROMPT_ENCODER,
95
- language=gsv_model.LANGUAGE,
96
- )
97
-
98
- if audio_chunk is not None:
99
- if self._play:
100
- self._audio_queue.put(audio_chunk)
101
- if self._current_save_path:
102
- self._session_audio_chunks.append(audio_chunk)
103
-
104
- # 使用回调函数处理流式数据
105
- if self._chunk_callback:
106
- audio_data = self._preprocess_for_playback(audio_chunk)
107
- self._chunk_callback(audio_data)
108
-
109
- except Exception as e:
110
- logger.error(f"A critical error occurred while processing the TTS task: {e}", exc_info=True)
111
- # 发生错误时,也要确保发送结束信号
112
- if self._chunk_callback:
113
- self._chunk_callback(None)
114
- self._tts_done_event.set()
115
-
116
- def _playback_worker_loop(self):
117
- try:
118
- import sounddevice as sd
119
- with sd.OutputStream(samplerate=self.sample_rate,
120
- channels=self.channels,
121
- dtype='float32') as stream:
122
- while not self._stop_event.is_set():
123
- try:
124
- audio_chunk = self._audio_queue.get(timeout=1)
125
- if audio_chunk is None:
126
- break
127
- if audio_chunk is AUDIO_STREAM_END:
128
- self._playback_done_event.set()
129
- continue
130
- stream.write(audio_chunk.squeeze())
131
- except queue.Empty:
132
- continue
133
- except Exception as e:
134
- logger.error(f"Error during audio playback: {e}", exc_info=True)
135
-
136
- except Exception as e:
137
- logger.warning(f"Failed to initialize sounddevice: {e}. Audio playback will be skipped.")
138
- # 如果音频设备初始化失败,即使不播放,也要消费队列中的结束信号,防止主线程死锁
139
- while not self._stop_event.is_set():
140
- try:
141
- item = self._audio_queue.get(timeout=0.5)
142
- if item is None:
143
- break
144
- if item is AUDIO_STREAM_END:
145
- self._playback_done_event.set()
146
- except queue.Empty:
147
- continue
148
-
149
- def _save_session_audio(self):
150
- try:
151
- full_audio = np.concatenate(self._session_audio_chunks, axis=0)
152
- with wave.open(self._current_save_path, 'wb') as wf:
153
- wf.setnchannels(self.channels)
154
- wf.setsampwidth(self.bytes_per_sample)
155
- wf.setframerate(self.sample_rate)
156
- wf.writeframes(self._preprocess_for_playback(full_audio))
157
- logger.info(f"Audio successfully saved to {os.path.abspath(self._current_save_path)}")
158
- except Exception as e:
159
- logger.error(f"Failed to save audio: {e}")
160
- finally:
161
- self._session_audio_chunks = []
162
- self._current_save_path = None
163
-
164
- def start_session(
165
- self,
166
- play: bool = False,
167
- split: bool = False,
168
- save_path: Optional[str] = None,
169
- chunk_callback: Optional[Callable[[Optional[bytes]], None]] = None
170
- ):
171
- with self._api_lock:
172
- self._tts_done_event.clear()
173
- self._playback_done_event.clear() # 新增:重置播放完成事件
174
- self._chunk_callback = chunk_callback
175
- self._stop_event.clear()
176
-
177
- if self._tts_worker is None or not self._tts_worker.is_alive():
178
- self._tts_worker = threading.Thread(target=self._tts_worker_loop, daemon=True)
179
- self._tts_worker.start()
180
-
181
- if self._playback_worker is None or not self._playback_worker.is_alive():
182
- self._playback_worker = threading.Thread(target=self._playback_worker_loop, daemon=True)
183
- self._playback_worker.start()
184
-
185
- clear_queue(self._text_queue)
186
- clear_queue(self._audio_queue)
187
-
188
- self._play = play
189
- self._split = split
190
- self._current_save_path = save_path
191
- self._session_audio_chunks = []
192
-
193
- def feed(self, text_chunk: str):
194
- with self._api_lock:
195
- if not text_chunk:
196
- return
197
- if self._split:
198
- sentences = self._text_splitter.split(text_chunk.strip())
199
- for sentence in sentences:
200
- self._text_queue.put(sentence)
201
- else:
202
- self._text_queue.put(text_chunk)
203
-
204
- def end_session(self):
205
- with self._api_lock:
206
- self._text_queue.put(STREAM_END)
207
-
208
- def stop(self):
209
- with self._api_lock:
210
- if self._tts_worker is None and self._playback_worker is None:
211
- return
212
- if self._stop_event.is_set():
213
- return
214
- tts_client.stop_event.set()
215
- self._stop_event.set()
216
- self._tts_done_event.set()
217
- self._text_queue.put(None)
218
- self._audio_queue.put(None)
219
- if self._tts_worker and self._tts_worker.is_alive():
220
- self._tts_worker.join()
221
- if self._playback_worker and self._playback_worker.is_alive():
222
- self._playback_worker.join()
223
- self._tts_worker = None
224
- self._playback_worker = None
225
-
226
- def wait_for_tts_completion(self):
227
- if self._tts_done_event.is_set():
228
- return
229
- self._tts_done_event.wait()
230
-
231
- def wait_for_playback_done(self):
232
- # 1. 首先等待TTS生成全部完成
233
- self.wait_for_tts_completion()
234
-
235
- # 2. 如果开启了播放且没有被强制停止,则等待播放结束
236
- if self._play and not self._stop_event.is_set():
237
- if not self._playback_done_event.is_set():
238
- self._playback_done_event.wait()
239
-
240
-
241
- tts_player: TTSPlayer = TTSPlayer()
 
 
1
+ # 文件: .../Core/TTSPlayer.py
2
+
3
+ import queue
4
+ import os
5
+ import threading
6
+
7
+ import numpy as np
8
+ import wave
9
+ from typing import Optional, List, Callable
10
+ import logging
11
+
12
+ from ..Utils.TextSplitter import TextSplitter
13
+ from ..Core.Inference import tts_client
14
+ from ..ModelManager import model_manager
15
+ from ..Utils.Shared import context
16
+ from ..Utils.Utils import clear_queue
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
STREAM_END = 'STREAM_END'  # Sentinel: marks the end of the input text stream.
AUDIO_STREAM_END = 'AUDIO_STREAM_END'  # Sentinel: marks the end of audio playback.
22
+
23
+
24
class TTSPlayer:
    """Streaming TTS pipeline.

    Text fed via :meth:`feed` is queued to a TTS worker thread which
    synthesizes audio chunks; chunks are optionally played through a
    playback worker thread, accumulated for saving to a WAV file, and/or
    streamed to a caller-supplied callback as 16-bit PCM bytes.
    """

    def __init__(self, sample_rate: int = 32000):
        self._text_splitter = TextSplitter()

        # Output audio format: mono, 16-bit PCM, `sample_rate` Hz.
        self.sample_rate: int = sample_rate
        self.channels: int = 1
        self.bytes_per_sample: int = 2  # 16-bit audio

        # Text flows in through _text_queue; synthesized chunks flow out
        # to the playback worker through _audio_queue.
        self._text_queue: queue.Queue = queue.Queue()
        self._audio_queue: queue.Queue = queue.Queue()

        self._stop_event: threading.Event = threading.Event()
        self._tts_done_event: threading.Event = threading.Event()
        self._playback_done_event: threading.Event = threading.Event()  # Set once playback has drained.
        self._api_lock: threading.Lock = threading.Lock()

        self._tts_worker: Optional[threading.Thread] = None
        self._playback_worker: Optional[threading.Thread] = None

        self._play: bool = False
        self._current_save_path: Optional[str] = None
        self._session_audio_chunks: List[np.ndarray] = []
        self._split: bool = False

        # Called with each PCM chunk (bytes), then with None at end-of-stream.
        self._chunk_callback: Optional[Callable[[Optional[bytes]], None]] = None

    @staticmethod
    def _preprocess_for_playback(audio_float: np.ndarray) -> bytes:
        """Convert a float waveform (assumed in [-1, 1]) to 16-bit PCM bytes."""
        audio_int16 = (audio_float.squeeze() * 32767).astype(np.int16)
        return audio_int16.tobytes()

    def _tts_worker_loop(self):
        """Pull sentences from the text queue, synthesize audio, and dispatch it via the callback and/or the audio queue."""
        while not self._stop_event.is_set():
            try:
                sentence = self._text_queue.get(timeout=1)
                if sentence is None or self._stop_event.is_set():
                    break  # None is the shutdown sentinel posted by stop().
            except queue.Empty:
                continue

            try:
                if sentence is STREAM_END:
                    # End of session: flush the saved audio if requested.
                    if self._current_save_path and self._session_audio_chunks:
                        self._save_session_audio()

                    # Signal end-of-stream through the callback.
                    if self._chunk_callback:
                        self._chunk_callback(None)

                    # If playback is enabled, tell the audio queue the stream ended.
                    if self._play:
                        self._audio_queue.put(AUDIO_STREAM_END)

                    self._tts_done_event.set()
                    continue

                gsv_model = model_manager.get(context.current_speaker)
                if not gsv_model or not context.current_prompt_audio:
                    logger.error("Missing model or reference audio.")
                    continue

                tts_client.stop_event.clear()
                audio_chunk = tts_client.tts(
                    text=sentence,
                    prompt_audio=context.current_prompt_audio,
                    encoder=gsv_model.T2S_ENCODER,
                    first_stage_decoder=gsv_model.T2S_FIRST_STAGE_DECODER,
                    stage_decoder=gsv_model.T2S_STAGE_DECODER,
                    vocoder=gsv_model.VITS,
                    prompt_encoder=gsv_model.PROMPT_ENCODER,
                    language=gsv_model.LANGUAGE,
                    text_language=context.current_text_language,  # Cross-lingual TTS support.
                )

                if audio_chunk is not None:
                    if self._play:
                        self._audio_queue.put(audio_chunk)
                    if self._current_save_path:
                        self._session_audio_chunks.append(audio_chunk)

                    # Stream the chunk to the registered callback.
                    if self._chunk_callback:
                        audio_data = self._preprocess_for_playback(audio_chunk)
                        self._chunk_callback(audio_data)

            except Exception as e:
                logger.error(f"A critical error occurred while processing the TTS task: {e}", exc_info=True)
                # On error, still emit the end-of-stream signals so consumers don't hang.
                if self._chunk_callback:
                    self._chunk_callback(None)
                self._tts_done_event.set()

    def _playback_worker_loop(self):
        """Drain the audio queue and play chunks through sounddevice.

        Falls back to a queue-draining loop (no audio) if the output
        device cannot be initialized, so waiters are never deadlocked.
        """
        try:
            import sounddevice as sd
            with sd.OutputStream(samplerate=self.sample_rate,
                                 channels=self.channels,
                                 dtype='float32') as stream:
                while not self._stop_event.is_set():
                    try:
                        audio_chunk = self._audio_queue.get(timeout=1)
                        if audio_chunk is None:
                            break  # Shutdown sentinel from stop().
                        if audio_chunk is AUDIO_STREAM_END:
                            self._playback_done_event.set()
                            continue
                        stream.write(audio_chunk.squeeze())
                    except queue.Empty:
                        continue
                    except Exception as e:
                        logger.error(f"Error during audio playback: {e}", exc_info=True)

        except Exception as e:
            logger.warning(f"Failed to initialize sounddevice: {e}. Audio playback will be skipped.")
            # Even without a device, keep consuming the queue's end-of-stream
            # signals to avoid deadlocking the main thread.
            while not self._stop_event.is_set():
                try:
                    item = self._audio_queue.get(timeout=0.5)
                    if item is None:
                        break
                    if item is AUDIO_STREAM_END:
                        self._playback_done_event.set()
                except queue.Empty:
                    continue

    def _save_session_audio(self):
        """Concatenate the session's chunks and write them as a 16-bit WAV file."""
        try:
            full_audio = np.concatenate(self._session_audio_chunks, axis=0)
            with wave.open(self._current_save_path, 'wb') as wf:
                wf.setnchannels(self.channels)
                wf.setsampwidth(self.bytes_per_sample)
                wf.setframerate(self.sample_rate)
                wf.writeframes(self._preprocess_for_playback(full_audio))
            logger.info(f"Audio successfully saved to {os.path.abspath(self._current_save_path)}")
        except Exception as e:
            logger.error(f"Failed to save audio: {e}")
        finally:
            # Reset session state regardless of save outcome.
            self._session_audio_chunks = []
            self._current_save_path = None

    def start_session(
        self,
        play: bool = False,
        split: bool = False,
        save_path: Optional[str] = None,
        chunk_callback: Optional[Callable[[Optional[bytes]], None]] = None
    ):
        """Begin a new TTS session.

        Args:
            play: Play audio through the local output device as it is generated.
            split: Split fed text into sentences before synthesis.
            save_path: If given, save the full session audio to this WAV path.
            chunk_callback: If given, receives each PCM chunk, then None at the end.
        """
        with self._api_lock:
            self._tts_done_event.clear()
            self._playback_done_event.clear()  # Reset the playback-done event for the new session.
            self._chunk_callback = chunk_callback
            self._stop_event.clear()

            # (Re)start worker threads if they are not already running.
            if self._tts_worker is None or not self._tts_worker.is_alive():
                self._tts_worker = threading.Thread(target=self._tts_worker_loop, daemon=True)
                self._tts_worker.start()

            if self._playback_worker is None or not self._playback_worker.is_alive():
                self._playback_worker = threading.Thread(target=self._playback_worker_loop, daemon=True)
                self._playback_worker.start()

            # Drop any leftovers from a previous session.
            clear_queue(self._text_queue)
            clear_queue(self._audio_queue)

            self._play = play
            self._split = split
            self._current_save_path = save_path
            self._session_audio_chunks = []

    def feed(self, text_chunk: str):
        """Queue a chunk of text for synthesis (split into sentences if enabled)."""
        with self._api_lock:
            if not text_chunk:
                return
            if self._split:
                sentences = self._text_splitter.split(text_chunk.strip())
                for sentence in sentences:
                    self._text_queue.put(sentence)
            else:
                self._text_queue.put(text_chunk)

    def end_session(self):
        """Mark the end of the fed text; the worker finalizes the session on STREAM_END."""
        with self._api_lock:
            self._text_queue.put(STREAM_END)

    def stop(self):
        """Abort synthesis and playback, join both workers, and reset them."""
        with self._api_lock:
            if self._tts_worker is None and self._playback_worker is None:
                return
            if self._stop_event.is_set():
                return
            tts_client.stop_event.set()
            self._stop_event.set()
            self._tts_done_event.set()
            # Post shutdown sentinels so blocked queue reads wake up.
            self._text_queue.put(None)
            self._audio_queue.put(None)
            if self._tts_worker and self._tts_worker.is_alive():
                self._tts_worker.join()
            if self._playback_worker and self._playback_worker.is_alive():
                self._playback_worker.join()
            self._tts_worker = None
            self._playback_worker = None

    def wait_for_tts_completion(self):
        """Block until all queued text has been synthesized."""
        if self._tts_done_event.is_set():
            return
        self._tts_done_event.wait()

    def wait_for_playback_done(self):
        """Block until synthesis is finished and (if enabled) playback has drained."""
        # 1. First wait for all TTS generation to complete.
        self.wait_for_tts_completion()

        # 2. If playback is enabled and not force-stopped, wait for it to finish.
        if self._play and not self._stop_event.is_set():
            if not self._playback_done_event.is_set():
                self._playback_done_event.wait()


# Module-level singleton used by the package's public API.
tts_player: TTSPlayer = TTSPlayer()
genie_tts/Internal.py CHANGED
@@ -1,395 +1,403 @@
1
- # 请严格遵循导入顺序。
2
- # 1、环境变量。
3
- import os
4
- from os import PathLike
5
-
6
- os.environ["HF_HUB_ENABLE_PROGRESS_BAR"] = "1"
7
-
8
- # 2、Logging & Warnings。
9
- import logging
10
- import warnings
11
-
12
- warnings.filterwarnings("ignore", category=UserWarning, module="jieba_fast._compat")
13
- logging.basicConfig(level=logging.INFO, format="%(message)s", datefmt="[%X]")
14
- logger = logging.getLogger(__name__)
15
-
16
- # 3、ONNX。
17
- import onnxruntime
18
-
19
- onnxruntime.set_default_logger_severity(3)
20
-
21
- # 导入剩余库。
22
-
23
- from pathlib import Path
24
- import json
25
- import asyncio
26
- from typing import AsyncIterator, Optional, Union, Dict
27
-
28
- from .Audio.ReferenceAudio import ReferenceAudio
29
- from .Core.Resources import ensure_exists, Chinese_G2P_DIR, English_G2P_DIR
30
- from .Core.TTSPlayer import tts_player
31
- from .ModelManager import model_manager
32
- from .Utils.Shared import context
33
- from .Utils.Language import normalize_language
34
- from .PredefinedCharacter import download_chara, CHARA_LANG, CHARA_ALIAS_MAP
35
-
36
- # A module-level private dictionary to store reference audio configurations.
37
- _reference_audios: Dict[str, dict] = {}
38
- SUPPORTED_AUDIO_EXTS = {'.wav', '.flac', '.ogg', '.aiff', '.aif'}
39
-
40
-
41
- def check_onnx_model_dir(onnx_model_dir: Union[str, os.PathLike]) -> None:
42
- """
43
- Checks if the directory contains the necessary ONNX model files for Genie TTS (v2 or v2ProPlus).
44
- Raises a FileNotFoundError with detailed instructions if validation fails.
45
- """
46
- model_path = Path(onnx_model_dir)
47
-
48
- # 1. Check if directory exists
49
- if not model_path.exists() or not model_path.is_dir():
50
- raise FileNotFoundError(f"The model directory '{onnx_model_dir}' does not exist or is not a directory.")
51
-
52
- # 2. Define required files
53
- # Base files required by both v2 and v2ProPlus
54
- required_base_files = {
55
- "t2s_encoder_fp32.bin",
56
- "t2s_encoder_fp32.onnx",
57
- "t2s_first_stage_decoder_fp32.onnx",
58
- "t2s_shared_fp16.bin",
59
- "t2s_stage_decoder_fp32.onnx",
60
- "vits_fp16.bin",
61
- "vits_fp32.onnx"
62
- }
63
-
64
- # 3. Get current files in directory
65
- existing_files = set(f.name for f in model_path.iterdir() if f.is_file())
66
-
67
- # 4. Validate
68
- # We check if the base files exist. If base files are missing, the model is definitely unusable.
69
- if not required_base_files.issubset(existing_files):
70
- missing = required_base_files - existing_files
71
-
72
- # Construct detailed error message
73
- error_msg = (
74
- f"\n\n[Genie Error] Invalid ONNX model directory: '{model_path}'\n"
75
- "===============================================================\n"
76
- f"Missing base files: {', '.join(missing)}\n"
77
- "A valid model folder must contain at least the following files.\n"
78
- "1. [v2 Base] (Required for all models):\n"
79
- " - t2s_encoder_fp32.bin\n"
80
- " - t2s_encoder_fp32.onnx\n"
81
- " - t2s_first_stage_decoder_fp32.onnx\n"
82
- " - t2s_shared_fp16.bin\n"
83
- " - t2s_stage_decoder_fp32.onnx\n"
84
- " - vits_fp16.bin\n"
85
- " - vits_fp32.onnx\n"
86
- "2. [v2ProPlus Additions] (Required for v2pp features):\n"
87
- " - prompt_encoder_fp16.bin\n"
88
- " - prompt_encoder_fp32.onnx\n"
89
- "===============================================================\n"
90
- )
91
- raise FileNotFoundError(error_msg)
92
-
93
-
94
- def load_character(
95
- character_name: str,
96
- onnx_model_dir: Union[str, PathLike],
97
- language: str,
98
- ) -> None:
99
- """
100
- Loads a character model from an ONNX model directory.
101
-
102
- Args:
103
- character_name (str): The name to assign to the loaded character.
104
- onnx_model_dir (str | PathLike): The directory path containing the ONNX model files.
105
- language (str): The language of the character model.
106
- """
107
- check_onnx_model_dir(onnx_model_dir)
108
-
109
- language = normalize_language(language)
110
- if language not in ['Japanese', 'English', 'Chinese']:
111
- raise ValueError('Unknown language')
112
-
113
- if language == 'Chinese':
114
- ensure_exists(Chinese_G2P_DIR, "Chinese_G2P_DIR")
115
- elif language == 'English':
116
- ensure_exists(English_G2P_DIR, "English_G2P_DIR")
117
-
118
- model_path: str = os.fspath(onnx_model_dir)
119
- model_manager.load_character(
120
- character_name=character_name,
121
- model_dir=model_path,
122
- language=language,
123
- )
124
-
125
-
126
- def unload_character(
127
- character_name: str,
128
- ) -> None:
129
- """
130
- Unloads a previously loaded character model to free up resources.
131
-
132
- Args:
133
- character_name (str): The name of the character to unload.
134
- """
135
- model_manager.remove_character(
136
- character_name=character_name,
137
- )
138
-
139
-
140
- def set_reference_audio(
141
- character_name: str,
142
- audio_path: Union[str, PathLike],
143
- audio_text: str,
144
- language: str = None,
145
- ) -> None:
146
- """
147
- Sets the reference audio for a character to be used for voice cloning.
148
-
149
- This must be called for a character before using 'tts' or 'tts_async'.
150
-
151
- Args:
152
- character_name (str): The name of the character.
153
- audio_path (str | PathLike): The file path to the reference audio (e.g., a WAV file).
154
- audio_text (str): The transcript of the reference audio.
155
- language (str): The language of the reference audio.
156
- """
157
- audio_path: str = os.fspath(audio_path)
158
-
159
- # 检查文件后缀是否支持
160
- ext = os.path.splitext(audio_path)[1].lower()
161
- if ext not in SUPPORTED_AUDIO_EXTS:
162
- logger.error(
163
- f"Audio format '{ext}' is not supported. Only the following formats are supported: {SUPPORTED_AUDIO_EXTS}"
164
- )
165
- return
166
-
167
- if language is None:
168
- gsv_model = model_manager.get(character_name)
169
- if gsv_model:
170
- language = gsv_model.LANGUAGE
171
- else:
172
- raise ValueError('No language specified')
173
- language = normalize_language(language)
174
- if language not in ['Japanese', 'English', 'Chinese']:
175
- raise ValueError('Unknown language')
176
-
177
- _reference_audios[character_name] = {
178
- 'audio_path': audio_path,
179
- 'audio_text': audio_text,
180
- 'language': language,
181
- }
182
- # print(_reference_audios[character_name])
183
- context.current_prompt_audio = ReferenceAudio(
184
- prompt_wav=audio_path,
185
- prompt_text=audio_text,
186
- language=language,
187
- )
188
-
189
-
190
- async def tts_async(
191
- character_name: str,
192
- text: str,
193
- play: bool = False,
194
- split_sentence: bool = False,
195
- save_path: Union[str, PathLike, None] = None,
196
- ) -> AsyncIterator[bytes]:
197
- """
198
- Asynchronously generates speech from text and yields audio chunks.
199
-
200
- This function returns an async iterator that provides the audio data in
201
- real-time as it's being generated.
202
-
203
- Args:
204
- character_name (str): The name of the character to use for synthesis.
205
- text (str): The text to be synthesized into speech.
206
- play (bool, optional): If True, plays the audio as it's generated. Defaults to False.
207
- split_sentence (bool, optional): If True, splits the text into sentences for synthesis. Defaults to False.
208
- save_path (str | PathLike | None, optional): If provided, saves the generated audio to this file path. Defaults to None.
209
-
210
- Yields:
211
- bytes: A chunk of the generated audio data.
212
-
213
- Raises:
214
- ValueError: If 'set_reference_audio' has not been called for the character.
215
- """
216
- if character_name not in _reference_audios:
217
- raise ValueError("Please call 'set_reference_audio' first to set the reference audio.")
218
-
219
- if save_path:
220
- save_path = os.fspath(save_path)
221
- parent_dir = os.path.dirname(save_path)
222
- if parent_dir:
223
- os.makedirs(parent_dir, exist_ok=True)
224
-
225
- # 1. 创建 asyncio 队列和获取当前事件循环
226
- stream_queue: asyncio.Queue[Union[bytes, None]] = asyncio.Queue()
227
- loop = asyncio.get_running_loop()
228
-
229
- # 2. 定义回调函数,用于在线程和 asyncio 之间安全地传递数据
230
- def tts_chunk_callback(c: Optional[bytes]):
231
- """This callback is called from the TTS worker thread."""
232
- loop.call_soon_threadsafe(stream_queue.put_nowait, c)
233
-
234
- # 设置 TTS 上下文
235
- context.current_speaker = character_name
236
- context.current_prompt_audio = ReferenceAudio(
237
- prompt_wav=_reference_audios[character_name]['audio_path'],
238
- prompt_text=_reference_audios[character_name]['audio_text'],
239
- language=_reference_audios[character_name]['language'],
240
- )
241
-
242
- # 3. 使用新的回调接口启动 TTS 会话
243
- tts_player.start_session(
244
- play=play,
245
- split=split_sentence,
246
- save_path=save_path,
247
- chunk_callback=tts_chunk_callback,
248
- )
249
-
250
- # 馈送文本并通知会话结束
251
- tts_player.feed(text)
252
- tts_player.end_session()
253
-
254
- # 4. 从队列中异步读取数据产生
255
- while True:
256
- chunk = await stream_queue.get()
257
- if chunk is None:
258
- break
259
- yield chunk
260
-
261
-
262
- def tts(
263
- character_name: str,
264
- text: str,
265
- play: bool = False,
266
- split_sentence: bool = True,
267
- save_path: Union[str, PathLike, None] = None,
268
- ) -> None:
269
- """
270
- Synchronously generates speech from text.
271
-
272
- This is a blocking function that will not return until the entire TTS
273
- process is complete.
274
-
275
- Args:
276
- character_name (str): The name of the character to use for synthesis.
277
- text (str): The text to be synthesized into speech.
278
- play (bool, optional): If True, plays the audio.
279
- split_sentence (bool, optional): If True, splits the text into sentences for synthesis.
280
- save_path (str | PathLike | None, optional): If provided, saves the generated audio to this file path. Defaults to None.
281
- """
282
- if character_name not in _reference_audios:
283
- logger.error("Please call 'set_reference_audio' first to set the reference audio.")
284
- return
285
-
286
- if save_path:
287
- save_path = os.fspath(save_path)
288
- parent_dir = os.path.dirname(save_path)
289
- if parent_dir:
290
- os.makedirs(parent_dir, exist_ok=True)
291
-
292
- context.current_speaker = character_name
293
- context.current_prompt_audio = ReferenceAudio(
294
- prompt_wav=_reference_audios[character_name]['audio_path'],
295
- prompt_text=_reference_audios[character_name]['audio_text'],
296
- language=_reference_audios[character_name]['language'],
297
- )
298
-
299
- tts_player.start_session(
300
- play=play,
301
- split=split_sentence,
302
- save_path=save_path,
303
- )
304
- tts_player.feed(text)
305
- tts_player.end_session()
306
- tts_player.wait_for_tts_completion()
307
-
308
-
309
- def wait_for_playback_done() -> None:
310
- """
311
- Wait until all TTS tasks have finished processing and playback has fully completed.
312
- """
313
- tts_player.wait_for_playback_done()
314
-
315
-
316
- def stop() -> None:
317
- """
318
- Stops the currently playing text-to-speech audio.
319
- """
320
- tts_player.stop()
321
-
322
-
323
- def convert_to_onnx(
324
- torch_ckpt_path: Union[str, PathLike],
325
- torch_pth_path: Union[str, PathLike],
326
- output_dir: Union[str, PathLike],
327
- ) -> None:
328
- """
329
- Converts PyTorch model checkpoints to the ONNX format.
330
-
331
- This function requires PyTorch to be installed.
332
-
333
- Args:
334
- torch_ckpt_path (str | PathLike): The path to the T2S model (.ckpt) file.
335
- torch_pth_path (str | PathLike): The path to the VITS model (.pth) file.
336
- output_dir (str | PathLike): The directory where the ONNX models will be saved.
337
- """
338
- try:
339
- import torch
340
- except ImportError:
341
- logger.error("❌ PyTorch is not installed. Please run `pip install torch` first.")
342
- return
343
-
344
- from .Converter.Converter import convert
345
-
346
- torch_ckpt_path = os.fspath(torch_ckpt_path)
347
- torch_pth_path = os.fspath(torch_pth_path)
348
- output_dir = os.fspath(output_dir)
349
-
350
- convert(
351
- torch_pth_path=torch_pth_path,
352
- torch_ckpt_path=torch_ckpt_path,
353
- output_dir=output_dir,
354
- )
355
-
356
-
357
- def clear_reference_audio_cache() -> None:
358
- """
359
- Clears the cache of reference audio data.
360
- """
361
- ReferenceAudio.clear_cache()
362
-
363
-
364
- def load_predefined_character(character_name: str) -> None:
365
- """
366
- Download and load a predefined character model for TTS inference.
367
- """
368
- character_name = character_name.lower().strip()
369
- if character_name not in CHARA_ALIAS_MAP:
370
- logger.error(f"No predefined character model found for {character_name}")
371
- return
372
- character_name = CHARA_ALIAS_MAP[character_name]
373
-
374
- save_path = download_chara(character_name)
375
- model_manager.load_character(
376
- character_name=character_name,
377
- model_dir=os.path.join(save_path, 'tts_models'),
378
- language=CHARA_LANG[character_name],
379
- )
380
-
381
- with open(os.path.join(save_path, "prompt_wav.json"), "r", encoding="utf-8") as f:
382
- prompt_wav_dict: Dict[str, Dict[str, str]] = json.load(f)
383
-
384
- audio_text = prompt_wav_dict["Normal"]["text"]
385
- audio_path = os.path.join(save_path, "prompt_wav", prompt_wav_dict["Normal"]["wav"])
386
- _reference_audios[character_name] = {
387
- 'audio_path': audio_path,
388
- 'audio_text': audio_text,
389
- 'language': CHARA_LANG[character_name],
390
- }
391
- context.current_prompt_audio = ReferenceAudio(
392
- prompt_wav=audio_path,
393
- prompt_text=audio_text,
394
- language=CHARA_LANG[character_name],
395
- )
 
 
 
 
 
 
 
 
 
1
+ # 请严格遵循导入顺序。
2
+ # 1、环境变量。
3
+ import os
4
+ from os import PathLike
5
+
6
+ os.environ["HF_HUB_ENABLE_PROGRESS_BAR"] = "1"
7
+
8
+ # 2、Logging & Warnings。
9
+ import logging
10
+ import warnings
11
+
12
+ warnings.filterwarnings("ignore", category=UserWarning, module="jieba_fast._compat")
13
+ logging.basicConfig(level=logging.INFO, format="%(message)s", datefmt="[%X]")
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # 3、ONNX。
17
+ import onnxruntime
18
+
19
+ onnxruntime.set_default_logger_severity(3)
20
+
21
+ # 导入剩余库。
22
+
23
+ from pathlib import Path
24
+ import json
25
+ import asyncio
26
+ from typing import AsyncIterator, Optional, Union, Dict
27
+
28
+ from .Audio.ReferenceAudio import ReferenceAudio
29
+ from .Core.Resources import ensure_exists, Chinese_G2P_DIR, English_G2P_DIR
30
+ from .Core.TTSPlayer import tts_player
31
+ from .ModelManager import model_manager
32
+ from .Utils.Shared import context
33
+ from .Utils.Language import normalize_language
34
+ from .PredefinedCharacter import download_chara, CHARA_LANG, CHARA_ALIAS_MAP
35
+
36
+ # A module-level private dictionary to store reference audio configurations.
37
+ _reference_audios: Dict[str, dict] = {}
38
+ SUPPORTED_AUDIO_EXTS = {'.wav', '.flac', '.ogg', '.aiff', '.aif'}
39
+
40
+
41
def check_onnx_model_dir(onnx_model_dir: Union[str, os.PathLike]) -> None:
    """
    Validate that *onnx_model_dir* holds the ONNX files Genie TTS needs (v2 or v2ProPlus).

    Raises:
        FileNotFoundError: If the directory is missing, or if any required base
            file is absent; the message lists the missing files and the full
            expected layout.
    """
    dir_path = Path(onnx_model_dir)

    # The directory itself must exist before its contents can be inspected.
    if not (dir_path.exists() and dir_path.is_dir()):
        raise FileNotFoundError(f"The model directory '{onnx_model_dir}' does not exist or is not a directory.")

    # Files shared by every supported model variant (v2 and v2ProPlus alike).
    required = {
        "t2s_encoder_fp32.bin",
        "t2s_encoder_fp32.onnx",
        "t2s_first_stage_decoder_fp32.onnx",
        "t2s_shared_fp16.bin",
        "t2s_stage_decoder_fp32.onnx",
        "vits_fp16.bin",
        "vits_fp32.onnx",
    }

    present = {entry.name for entry in dir_path.iterdir() if entry.is_file()}

    # Without the base files the model is unusable regardless of variant, so
    # only the base set is enforced here; v2ProPlus extras are merely listed.
    missing = required - present
    if missing:
        error_msg = (
            f"\n\n[Genie Error] Invalid ONNX model directory: '{dir_path}'\n"
            "===============================================================\n"
            f"Missing base files: {', '.join(missing)}\n"
            "A valid model folder must contain at least the following files.\n"
            "1. [v2 Base] (Required for all models):\n"
            "   - t2s_encoder_fp32.bin\n"
            "   - t2s_encoder_fp32.onnx\n"
            "   - t2s_first_stage_decoder_fp32.onnx\n"
            "   - t2s_shared_fp16.bin\n"
            "   - t2s_stage_decoder_fp32.onnx\n"
            "   - vits_fp16.bin\n"
            "   - vits_fp32.onnx\n"
            "2. [v2ProPlus Additions] (Required for v2pp features):\n"
            "   - prompt_encoder_fp16.bin\n"
            "   - prompt_encoder_fp32.onnx\n"
            "===============================================================\n"
        )
        raise FileNotFoundError(error_msg)
92
+
93
+
94
def load_character(
    character_name: str,
    onnx_model_dir: Union[str, PathLike],
    language: str,
) -> None:
    """
    Loads a character model from an ONNX model directory.

    Args:
        character_name (str): The name to assign to the loaded character.
        onnx_model_dir (str | PathLike): The directory path containing the ONNX model files.
        language (str): The language of the character model.

    Raises:
        FileNotFoundError: If the model directory fails validation.
        ValueError: If the language does not normalize to Japanese/English/Chinese.
    """
    check_onnx_model_dir(onnx_model_dir)

    normalized = normalize_language(language)
    if normalized not in ('Japanese', 'English', 'Chinese'):
        raise ValueError('Unknown language')

    # Chinese and English require their on-disk G2P resources to be present.
    g2p_resources = {
        'Chinese': (Chinese_G2P_DIR, "Chinese_G2P_DIR"),
        'English': (English_G2P_DIR, "English_G2P_DIR"),
    }
    if normalized in g2p_resources:
        ensure_exists(*g2p_resources[normalized])

    model_manager.load_character(
        character_name=character_name,
        model_dir=os.fspath(onnx_model_dir),
        language=normalized,
    )
124
+
125
+
126
def unload_character(
    character_name: str,
) -> None:
    """
    Release a previously loaded character model and free its resources.

    Args:
        character_name (str): The name under which the character was loaded.
    """
    model_manager.remove_character(character_name=character_name)
138
+
139
+
140
def set_reference_audio(
    character_name: str,
    audio_path: Union[str, PathLike],
    audio_text: str,
    language: Optional[str] = None,
) -> None:
    """
    Sets the reference audio for a character to be used for voice cloning.

    This must be called for a character before using 'tts' or 'tts_async'.

    Args:
        character_name (str): The name of the character.
        audio_path (str | PathLike): The file path to the reference audio (e.g., a WAV file).
        audio_text (str): The transcript of the reference audio.
        language (str, optional): The language of the reference audio. If None,
            falls back to the language of the character's loaded model.

    Raises:
        ValueError: If no language is given and the character's model is not
            loaded, or if the language is not Japanese/English/Chinese.
    """
    audio_path: str = os.fspath(audio_path)

    # Reject file formats the audio backend does not support.
    ext = os.path.splitext(audio_path)[1].lower()
    if ext not in SUPPORTED_AUDIO_EXTS:
        logger.error(
            f"Audio format '{ext}' is not supported. Only the following formats are supported: {SUPPORTED_AUDIO_EXTS}"
        )
        return

    # Fall back to the loaded model's language when none is given explicitly.
    if language is None:
        gsv_model = model_manager.get(character_name)
        if gsv_model:
            language = gsv_model.LANGUAGE
        else:
            raise ValueError('No language specified')
    language = normalize_language(language)
    if language not in ['Japanese', 'English', 'Chinese']:
        raise ValueError('Unknown language')

    _reference_audios[character_name] = {
        'audio_path': audio_path,
        'audio_text': audio_text,
        'language': language,
    }
    # Make this reference the active prompt for subsequent synthesis calls.
    context.current_prompt_audio = ReferenceAudio(
        prompt_wav=audio_path,
        prompt_text=audio_text,
        language=language,
    )
188
+
189
+
190
async def tts_async(
    character_name: str,
    text: str,
    play: bool = False,
    split_sentence: bool = False,
    save_path: Union[str, PathLike, None] = None,
    text_language: Optional[str] = None,  # Target text language, for cross-lingual TTS.
) -> AsyncIterator[bytes]:
    """
    Asynchronously generates speech from text and yields audio chunks.

    This function returns an async iterator that provides the audio data in
    real-time as it's being generated.

    Args:
        character_name (str): The name of the character to use for synthesis.
        text (str): The text to be synthesized into speech.
        play (bool, optional): If True, plays the audio as it's generated. Defaults to False.
        split_sentence (bool, optional): If True, splits the text into sentences for synthesis. Defaults to False.
        save_path (str | PathLike | None, optional): If provided, saves the generated audio to this file path. Defaults to None.
        text_language (str, optional): Language of the target text. If None, uses the reference audio language.

    Yields:
        bytes: A chunk of the generated audio data.

    Raises:
        ValueError: If 'set_reference_audio' has not been called for the character.
    """
    if character_name not in _reference_audios:
        raise ValueError("Please call 'set_reference_audio' first to set the reference audio.")

    if save_path:
        save_path = os.fspath(save_path)
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # 1. Create the asyncio queue and capture the running event loop; a None
    #    item on the queue acts as the end-of-stream sentinel.
    stream_queue: asyncio.Queue[Union[bytes, None]] = asyncio.Queue()
    loop = asyncio.get_running_loop()

    # 2. Callback that safely hands data from the TTS worker thread to asyncio.
    def tts_chunk_callback(c: Optional[bytes]):
        """This callback is called from the TTS worker thread."""
        # call_soon_threadsafe is required: this runs off the event-loop thread.
        loop.call_soon_threadsafe(stream_queue.put_nowait, c)

    # Set up the TTS context for this request.
    context.current_speaker = character_name
    context.current_prompt_audio = ReferenceAudio(
        prompt_wav=_reference_audios[character_name]['audio_path'],
        prompt_text=_reference_audios[character_name]['audio_text'],
        language=_reference_audios[character_name]['language'],
    )
    # Target text language for cross-lingual TTS (None = use the reference language).
    context.current_text_language = normalize_language(text_language) if text_language else None

    # 3. Start the TTS session through the chunk-callback interface.
    tts_player.start_session(
        play=play,
        split=split_sentence,
        save_path=save_path,
        chunk_callback=tts_chunk_callback,
    )

    # Feed the text and signal the end of the session.
    tts_player.feed(text)
    tts_player.end_session()

    # 4. Asynchronously drain the queue, yielding chunks until the sentinel.
    while True:
        chunk = await stream_queue.get()
        if chunk is None:
            break
        yield chunk
264
+
265
+
266
def tts(
    character_name: str,
    text: str,
    play: bool = False,
    split_sentence: bool = True,
    save_path: Union[str, PathLike, None] = None,
    text_language: Optional[str] = None,  # Target text language, for cross-lingual TTS.
) -> None:
    """
    Synchronously generates speech from text.

    This is a blocking function that will not return until the entire TTS
    process is complete.

    Args:
        character_name (str): The name of the character to use for synthesis.
        text (str): The text to be synthesized into speech.
        play (bool, optional): If True, plays the audio.
        split_sentence (bool, optional): If True, splits the text into sentences for synthesis.
        save_path (str | PathLike | None, optional): If provided, saves the generated audio to this file path. Defaults to None.
        text_language (str, optional): Language of the target text. If None, uses the reference audio language.
    """
    if character_name not in _reference_audios:
        logger.error("Please call 'set_reference_audio' first to set the reference audio.")
        return

    if save_path:
        save_path = os.fspath(save_path)
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # Set up the TTS context for this request.
    context.current_speaker = character_name
    context.current_prompt_audio = ReferenceAudio(
        prompt_wav=_reference_audios[character_name]['audio_path'],
        prompt_text=_reference_audios[character_name]['audio_text'],
        language=_reference_audios[character_name]['language'],
    )
    # Target text language for cross-lingual TTS (None = use the reference language).
    context.current_text_language = normalize_language(text_language) if text_language else None

    tts_player.start_session(
        play=play,
        split=split_sentence,
        save_path=save_path,
    )
    tts_player.feed(text)
    tts_player.end_session()
    # Block until all queued synthesis work has finished.
    tts_player.wait_for_tts_completion()
315
+
316
+
317
def wait_for_playback_done() -> None:
    """Block until every pending TTS task has finished and playback is fully complete."""
    tts_player.wait_for_playback_done()
322
+
323
+
324
def stop() -> None:
    """Interrupt and stop any text-to-speech audio that is currently playing."""
    tts_player.stop()
329
+
330
+
331
def convert_to_onnx(
    torch_ckpt_path: Union[str, PathLike],
    torch_pth_path: Union[str, PathLike],
    output_dir: Union[str, PathLike],
) -> None:
    """
    Converts PyTorch model checkpoints to the ONNX format.

    This function requires PyTorch to be installed.

    Args:
        torch_ckpt_path (str | PathLike): The path to the T2S model (.ckpt) file.
        torch_pth_path (str | PathLike): The path to the VITS model (.pth) file.
        output_dir (str | PathLike): The directory where the ONNX models will be saved.
    """
    # torch is an optional dependency; bail out with a clear message if absent.
    try:
        import torch  # noqa: F401
    except ImportError:
        logger.error("❌ PyTorch is not installed. Please run `pip install torch` first.")
        return

    # Imported lazily so the package works without the converter's dependencies.
    from .Converter.Converter import convert

    convert(
        torch_pth_path=os.fspath(torch_pth_path),
        torch_ckpt_path=os.fspath(torch_ckpt_path),
        output_dir=os.fspath(output_dir),
    )
363
+
364
+
365
def clear_reference_audio_cache() -> None:
    """Drop all cached reference-audio data held by ReferenceAudio."""
    ReferenceAudio.clear_cache()
370
+
371
+
372
def load_predefined_character(character_name: str) -> None:
    """
    Download and load a predefined character model for TTS inference.

    Resolves *character_name* through the alias table, downloads the model
    bundle, loads it, and registers the bundled "Normal" prompt recording as
    the character's reference audio.
    """
    key = character_name.lower().strip()
    if key not in CHARA_ALIAS_MAP:
        logger.error(f"No predefined character model found for {key}")
        return
    canonical = CHARA_ALIAS_MAP[key]
    lang = CHARA_LANG[canonical]

    save_path = download_chara(canonical)
    model_manager.load_character(
        character_name=canonical,
        model_dir=os.path.join(save_path, 'tts_models'),
        language=lang,
    )

    # The bundle ships a manifest mapping prompt styles to wav files and transcripts.
    with open(os.path.join(save_path, "prompt_wav.json"), "r", encoding="utf-8") as f:
        prompt_wav_dict: Dict[str, Dict[str, str]] = json.load(f)

    normal_entry = prompt_wav_dict["Normal"]
    audio_path = os.path.join(save_path, "prompt_wav", normal_entry["wav"])
    _reference_audios[canonical] = {
        'audio_path': audio_path,
        'audio_text': normal_entry["text"],
        'language': lang,
    }
    # Make this character's prompt the active one for subsequent synthesis.
    context.current_prompt_audio = ReferenceAudio(
        prompt_wav=audio_path,
        prompt_text=normal_entry["text"],
        language=lang,
    )
genie_tts/Utils/Shared.py CHANGED
@@ -1,13 +1,14 @@
1
- from typing import TYPE_CHECKING, Optional
2
-
3
- if TYPE_CHECKING:
4
- from ..Audio.ReferenceAudio import ReferenceAudio
5
-
6
-
7
- class Context:
8
- def __init__(self):
9
- self.current_speaker: str = ''
10
- self.current_prompt_audio: Optional['ReferenceAudio'] = None
11
-
12
-
13
- context: Context = Context()
 
 
1
+ from typing import TYPE_CHECKING, Optional
2
+
3
+ if TYPE_CHECKING:
4
+ from ..Audio.ReferenceAudio import ReferenceAudio
5
+
6
+
7
class Context:
    """Mutable per-process TTS state shared across the public API functions."""

    def __init__(self):
        # Name of the character currently selected for synthesis.
        self.current_speaker: str = ''
        # Reference audio driving voice cloning; None until one is set.
        self.current_prompt_audio: Optional['ReferenceAudio'] = None
        # Target text language for cross-lingual TTS; None means "use the reference language".
        self.current_text_language: Optional[str] = None


# Module-level singleton shared by the whole package.
context: Context = Context()