Spaces:
Running
Running
| import onnxruntime as ort | |
| import numpy as np | |
| from typing import List, Optional | |
| import threading | |
| from ..Audio.ReferenceAudio import ReferenceAudio | |
| from ..GetPhonesAndBert import get_phones_and_bert | |
# Length limit for T2S decoding, judging by the name.
# NOTE(review): nothing in the visible part of this file reads this constant;
# the decode loop in GENIE.t2s_cpu is capped at 500 steps instead. Confirm
# whether it is still needed or should drive that loop.
MAX_T2S_LEN = 1000
def stretch_semantic_tokens(tokens: np.ndarray, speed: float) -> np.ndarray:
    """Resample semantic tokens along time to change the speaking rate.

    The input is flattened into one sequence of length ``T`` and resampled by
    index truncation: output position ``i`` takes the input token at
    ``int(i * speed)`` (clamped to the last token). A ``speed`` above 1
    therefore drops tokens and a ``speed`` below 1 repeats them.

    Args:
        tokens: Semantic tokens, documented by the caller as ``[1, 1, T]``.
            Any shape is accepted and treated as one flat sequence.
        speed: Speed factor; ``> 1`` speeds up, ``< 1`` slows down.

    Returns:
        ``tokens`` itself (no copy) when it is ``None`` or empty, when
        ``speed`` is within 0.01 of 1.0, or when ``speed`` is not a positive
        finite number. Otherwise a new array of shape ``[1, 1, new_len]`` with
        the input dtype, where ``new_len = max(1, round(T / speed))``.
    """
    if tokens is None or tokens.size == 0:
        return tokens
    # A zero, negative, NaN or infinite factor has no meaningful target length
    # (zero would divide by zero below), so leave the tokens untouched rather
    # than crash or collapse the sequence to a single token.
    if not np.isfinite(speed) or speed <= 0:
        return tokens
    if abs(speed - 1.0) < 0.01:
        return tokens

    # Work on the flat token sequence; the batch dimensions are restored below.
    flat = tokens.flatten()
    src_len = flat.size
    dst_len = max(1, int(round(src_len / speed)))

    # Map every output position to its source position in one vectorized step.
    # astype() truncates toward zero, matching int(i * speed); positions past
    # the end are clamped to the last token.
    src_idx = (np.arange(dst_len) * speed).astype(np.int64)
    np.minimum(src_idx, src_len - 1, out=src_idx)

    # Fancy indexing keeps the dtype of the input tokens.
    return flat[src_idx].reshape(1, 1, -1)
class GENIE:
    """Text-to-speech pipeline driven by caller-supplied ONNX Runtime sessions.

    ``tts`` converts text to phonemes/BERT features, generates semantic tokens
    with ``t2s_cpu`` (encoder -> first-stage decoder -> iterative stage
    decoder) and feeds them to a vocoder session.

    Setting ``stop_event`` cancels generation: ``t2s_cpu`` checks it before
    every decoding attempt and before every decoder step, and returns ``None``.
    """

    def __init__(self):
        # Cancellation flag polled by t2s_cpu.
        self.stop_event: threading.Event = threading.Event()

    def tts(
            self,
            text: str,
            prompt_audio: ReferenceAudio,
            encoder: ort.InferenceSession,
            first_stage_decoder: ort.InferenceSession,
            stage_decoder: ort.InferenceSession,
            vocoder: ort.InferenceSession,
            prompt_encoder: Optional[ort.InferenceSession],
            language: str = 'japanese',
            text_language: Optional[str] = None,
            speed: float = 1.0,
    ) -> Optional[np.ndarray]:
        """Synthesize ``text`` using ``prompt_audio`` as the reference.

        Args:
            text: Text to synthesize.
            prompt_audio: Reference audio; its ``phonemes_seq``, ``text_bert``
                and ``ssl_content`` condition token generation, and either its
                ``audio_32k`` or its global embeddings condition the vocoder.
            encoder: Session forwarded to ``t2s_cpu``.
            first_stage_decoder: Session forwarded to ``t2s_cpu``.
            stage_decoder: Session forwarded to ``t2s_cpu``.
            vocoder: Session that turns ``text_seq`` plus semantic tokens into
                the returned output.
            prompt_encoder: When not ``None``, it is used to refresh the global
                embeddings on ``prompt_audio`` and the vocoder receives
                ``ge`` / ``ge_advanced`` instead of ``ref_audio``.
            language: Language of the reference audio; used for ``text`` when
                ``text_language`` is not given.
            text_language: Language of ``text``; falls back to ``language``
                when empty or ``None``.
            speed: Speed factor applied to the semantic tokens before the
                vocoder runs.

        Returns:
            The first output of the vocoder session, or ``None`` when
            generation was cancelled through ``stop_event``.
        """
        # Fall back to the reference audio's language when none is given.
        actual_text_language = text_language if text_language else language

        text = '。' + text  # Leading stop keeps the first sentence from being dropped.
        text_seq, text_bert = get_phones_and_bert(text, language=actual_text_language)

        semantic_tokens: Optional[np.ndarray] = self.t2s_cpu(
            ref_seq=prompt_audio.phonemes_seq,
            ref_bert=prompt_audio.text_bert,
            text_seq=text_seq,
            text_bert=text_bert,
            ssl_content=prompt_audio.ssl_content,
            encoder=encoder,
            first_stage_decoder=first_stage_decoder,
            stage_decoder=stage_decoder,
        )
        # t2s_cpu returns None when stop_event was set; propagate the
        # cancellation instead of failing on the comparison below.
        if semantic_tokens is None:
            return None

        # Cut the sequence at the first out-of-range token (e.g. the EOS token).
        eos_indices = np.where(semantic_tokens >= 1024)
        if len(eos_indices[0]) > 0:
            first_eos_index = eos_indices[-1][0]
            semantic_tokens = semantic_tokens[..., :first_eos_index]

        # Speed control: resample the semantic tokens before the vocoder.
        semantic_tokens = stretch_semantic_tokens(semantic_tokens, speed)

        if prompt_encoder is None:
            return vocoder.run(None, {
                "text_seq": text_seq,
                "pred_semantic": semantic_tokens,
                "ref_audio": prompt_audio.audio_32k
            })[0]

        # Added for V2ProPlus: condition the vocoder on global embeddings.
        prompt_audio.update_global_emb(prompt_encoder=prompt_encoder)
        audio_chunk = vocoder.run(None, {
            "text_seq": text_seq,
            "pred_semantic": semantic_tokens,
            "ge": prompt_audio.global_emb,
            "ge_advanced": prompt_audio.global_emb_advanced,
        })[0]
        return audio_chunk

    def t2s_cpu(
            self,
            ref_seq: np.ndarray,
            ref_bert: np.ndarray,
            text_seq: np.ndarray,
            text_bert: np.ndarray,
            ssl_content: np.ndarray,
            encoder: ort.InferenceSession,
            first_stage_decoder: ort.InferenceSession,
            stage_decoder: ort.InferenceSession,
    ) -> Optional[np.ndarray]:
        """Generate semantic tokens, retrying when decoding stops too early.

        The encoder runs once. Each attempt then runs the first-stage decoder
        followed by up to 500 stage-decoder steps, ending early when the
        decoder reports its stop condition. An attempt is accepted once its
        step index reaches the expected minimum; otherwise decoding is retried
        (at most 5 attempts) and the longest attempt is used.

        Args:
            ref_seq: Phoneme sequence of the reference audio.
            ref_bert: BERT features of the reference text.
            text_seq: Phoneme sequence of the text to synthesize.
            text_bert: BERT features of the text to synthesize.
            ssl_content: ``ssl_content`` input of the encoder.
            encoder: Session producing ``x`` and ``prompts``.
            first_stage_decoder: Session producing the initial decoder state.
            stage_decoder: Session run once per decoding step.

        Returns:
            The trailing ``best_idx`` columns of the chosen token array with a
            leading axis added, or ``None`` when ``stop_event`` was set.
        """
        max_decode_steps = 500
        max_retries = 5
        # Dynamic threshold: minimum step index expected for this text length.
        # It is clamped to the highest index the decode loop can reach;
        # otherwise long inputs could never satisfy it and every attempt
        # would be repeated even after hitting the step cap.
        min_expected_tokens = min(
            max(8, text_seq.shape[-1] * 2),
            max_decode_steps - 1,
        )

        # The encoder only needs to run once; its outputs feed every attempt.
        x, prompts = encoder.run(
            None,
            {
                "ref_seq": ref_seq,
                "text_seq": text_seq,
                "ref_bert": ref_bert,
                "text_bert": text_bert,
                "ssl_content": ssl_content,
            },
        )

        input_names: List[str] = [inp.name for inp in stage_decoder.get_inputs()]

        best_y = None
        best_idx = 0

        for retry in range(max_retries):
            if self.stop_event.is_set():
                return None

            # Re-run the first-stage decoder on every attempt so each retry
            # starts from a fresh sampling state.
            y, y_emb, *present_key_values = first_stage_decoder.run(
                None, {"x": x, "prompts": prompts}
            )

            # Stage-decoder loop: outputs of one step are the next step's inputs.
            idx: int = 0
            for idx in range(0, max_decode_steps):
                if self.stop_event.is_set():
                    return None
                input_feed = dict(zip(input_names, [y, y_emb, *present_key_values]))
                outputs = stage_decoder.run(None, input_feed)
                y, y_emb, stop_condition_tensor, *present_key_values = outputs

                if stop_condition_tensor:
                    break

            # Remember the attempt that ran for the most steps.
            if idx > best_idx:
                best_idx = idx
                best_y = y.copy()

            # Enough steps were decoded: accept this attempt and stop retrying.
            if idx >= min_expected_tokens:
                break

        # Every attempt stopped on step 0, so nothing was recorded; fall back
        # to the last attempt.
        # NOTE(review): best_idx is 0 here, and the slice below (-0:) then
        # selects the whole array rather than an empty one - confirm that this
        # is the intended result for this case.
        if best_y is None:
            best_y = y
            best_idx = idx

        # Overwrite the final token (presumably the EOS written on the
        # stopping step - confirm) and keep only the trailing best_idx tokens.
        best_y[0, -1] = 0
        return np.expand_dims(best_y[:, -best_idx:], axis=0)
# Module-level GENIE instance, created once when this module is imported.
tts_client: GENIE = GENIE()