import threading
from typing import List, Optional

import numpy as np
import onnxruntime as ort

from ..Audio.ReferenceAudio import ReferenceAudio
from ..GetPhonesAndBert import get_phones_and_bert

# Intended cap on T2S output length.
# NOTE(review): the stage-decoder loop below hard-codes 500 iterations instead
# of using this constant — confirm which limit is intended.
MAX_T2S_LEN = 1000


def stretch_semantic_tokens(tokens: np.ndarray, speed: float) -> np.ndarray:
    """Resample semantic tokens (nearest neighbour) to adjust speech rate.

    Adapted from AstraTTS's ``StretchSemanticTokens`` algorithm.

    Args:
        tokens: Original semantic tokens, shape ``[1, 1, T]``.
        speed: Speed factor; > 1 speeds up (fewer tokens), < 1 slows down.

    Returns:
        Resampled tokens reshaped to ``[1, 1, new_T]``. The input object is
        returned unchanged when it is None/empty or ``speed`` is within
        0.01 of 1.0.
    """
    if tokens is None or tokens.size == 0:
        return tokens
    if abs(speed - 1.0) < 0.01:
        return tokens

    # Work on the flat token sequence (drops the [1, 1, ...] batch dims).
    original = tokens.flatten()
    original_len = len(original)

    # Target length; never collapse below a single token.
    new_len = max(1, int(round(original_len / speed)))

    # Vectorized nearest-neighbour lookup: output position i maps to source
    # index floor(i * speed), clamped to the last valid index. For the
    # non-negative values involved, astype(int64) truncation is identical to
    # the original per-element int(i * speed).
    src_idx = np.minimum(
        (np.arange(new_len) * speed).astype(np.int64),
        original_len - 1,
    )
    result = original[src_idx]  # fancy indexing preserves the input dtype

    # Restore the [1, 1, new_len] layout expected by the vocoder.
    return result.reshape(1, 1, -1)


class GENIE:
    """ONNX-based GPT-SoVITS-style TTS pipeline driver.

    Holds a ``stop_event`` that callers may set to abort a generation that is
    in progress; both :meth:`tts` and :meth:`t2s_cpu` poll it and return
    ``None`` when it is set.
    """

    def __init__(self):
        # Cooperative cancellation flag checked inside the decode loops.
        self.stop_event: threading.Event = threading.Event()

    def tts(
        self,
        text: str,
        prompt_audio: ReferenceAudio,
        encoder: ort.InferenceSession,
        first_stage_decoder: ort.InferenceSession,
        stage_decoder: ort.InferenceSession,
        vocoder: ort.InferenceSession,
        prompt_encoder: Optional[ort.InferenceSession],
        language: str = 'japanese',
        text_language: Optional[str] = None,
        speed: float = 1.0,  # speech-rate multiplier, applied before vocoding
    ) -> Optional[np.ndarray]:
        """Synthesize ``text`` into audio conditioned on ``prompt_audio``.

        Args:
            text: Text to speak.
            prompt_audio: Reference audio providing timbre/prosody features.
            encoder / first_stage_decoder / stage_decoder: T2S ONNX sessions.
            vocoder: Session turning semantic tokens into a waveform.
            prompt_encoder: Optional V2ProPlus session producing global
                embeddings; when ``None`` the legacy vocoder inputs are used.
            language: Language of the reference audio.
            text_language: Language of ``text``; falls back to ``language``.
            speed: Speech-rate factor forwarded to
                :func:`stretch_semantic_tokens`.

        Returns:
            The synthesized audio array, or ``None`` if generation was
            stopped via ``stop_event``.
        """
        # Fall back to the reference-audio language when none is given.
        actual_text_language = text_language if text_language else language

        text = '。' + text  # guard against the first sentence being dropped

        text_seq, text_bert = get_phones_and_bert(text, language=actual_text_language)

        semantic_tokens: np.ndarray = self.t2s_cpu(
            ref_seq=prompt_audio.phonemes_seq,
            ref_bert=prompt_audio.text_bert,
            text_seq=text_seq,
            text_bert=text_bert,
            ssl_content=prompt_audio.ssl_content,
            encoder=encoder,
            first_stage_decoder=first_stage_decoder,
            stage_decoder=stage_decoder,
        )

        # Strip invalid elements (e.g. the EOS token, id >= 1024): truncate at
        # the first EOS occurrence along the last axis.
        eos_indices = np.where(semantic_tokens >= 1024)
        if len(eos_indices[0]) > 0:
            first_eos_index = eos_indices[-1][0]
            semantic_tokens = semantic_tokens[..., :first_eos_index]

        # Speed control: resample the semantic tokens before vocoding.
        semantic_tokens = stretch_semantic_tokens(semantic_tokens, speed)

        if prompt_encoder is None:
            return vocoder.run(None, {
                "text_seq": text_seq,
                "pred_semantic": semantic_tokens,
                "ref_audio": prompt_audio.audio_32k
            })[0]
        else:
            # V2ProPlus path: condition the vocoder on global embeddings.
            prompt_audio.update_global_emb(prompt_encoder=prompt_encoder)

            audio_chunk = vocoder.run(None, {
                "text_seq": text_seq,
                "pred_semantic": semantic_tokens,
                "ge": prompt_audio.global_emb,
                "ge_advanced": prompt_audio.global_emb_advanced,
            })[0]

            return audio_chunk

    def t2s_cpu(
        self,
        ref_seq: np.ndarray,
        ref_bert: np.ndarray,
        text_seq: np.ndarray,
        text_bert: np.ndarray,
        ssl_content: np.ndarray,
        encoder: ort.InferenceSession,
        first_stage_decoder: ort.InferenceSession,
        stage_decoder: ort.InferenceSession,
    ) -> Optional[np.ndarray]:
        """Run the T2S model on CPU, retrying when EOS fires too early.

        The autoregressive decoder occasionally emits EOS after only a few
        tokens; this wrapper reruns the sampling up to ``max_retries`` times
        and keeps the longest generation seen.

        Returns:
            Semantic tokens of shape ``[1, 1, best_idx]``, or ``None`` if
            ``stop_event`` was set mid-generation.
        """
        # Dynamic acceptance threshold (per AstraTTS): expect at least two
        # semantic tokens per input phoneme, with a floor of 8.
        min_expected_tokens = max(8, text_seq.shape[-1] * 2)
        max_retries = 5

        # The encoder is deterministic; run it once outside the retry loop.
        x, prompts = encoder.run(
            None,
            {
                "ref_seq": ref_seq,
                "text_seq": text_seq,
                "ref_bert": ref_bert,
                "text_bert": text_bert,
                "ssl_content": ssl_content,
            },
        )

        input_names: List[str] = [inp.name for inp in stage_decoder.get_inputs()]

        best_y = None
        best_idx = 0

        for retry in range(max_retries):
            if self.stop_event.is_set():
                return None

            # First-stage decoder is rerun each retry to draw a fresh random
            # sampling state.
            y, y_emb, *present_key_values = first_stage_decoder.run(
                None, {"x": x, "prompts": prompts}
            )

            # Autoregressive stage-decoder loop (hard cap of 500 steps).
            idx: int = 0
            for idx in range(0, 500):
                if self.stop_event.is_set():
                    return None

                input_feed = {
                    name: data
                    for name, data in zip(input_names, [y, y_emb, *present_key_values])
                }

                outputs = stage_decoder.run(None, input_feed)
                y, y_emb, stop_condition_tensor, *present_key_values = outputs

                if stop_condition_tensor:
                    break

            # Track the longest generation across retries.
            if idx > best_idx:
                best_idx = idx
                best_y = y.copy()

            # Accept and stop retrying once the generation is long enough.
            if idx >= min_expected_tokens:
                break
            # Otherwise fall through and retry.

        # Fall back to the last attempt if nothing was ever recorded as best.
        if best_y is None:
            best_y = y
            best_idx = idx

        # Zero out the final (EOS-ish) token before slicing.
        best_y[0, -1] = 0
        return np.expand_dims(best_y[:, -best_idx:], axis=0)


tts_client: GENIE = GENIE()