import onnxruntime as ort
import numpy as np
from typing import List, Optional
import threading
from ..Audio.ReferenceAudio import ReferenceAudio
from ..GetPhonesAndBert import get_phones_and_bert
MAX_T2S_LEN = 1000
def stretch_semantic_tokens(tokens: np.ndarray, speed: float) -> np.ndarray:
    """
    Nearest-neighbour interpolation of semantic tokens, used to adjust speaking speed.
    Adapted from the StretchSemanticTokens algorithm in AstraTTS.
    Args:
        tokens: original semantic tokens [1, 1, T]
        speed: speed factor; >1 speeds up, <1 slows down
    Returns:
        the interpolated tokens
    """
    if tokens is None or tokens.size == 0:
        return tokens
    if abs(speed - 1.0) < 0.01:
        return tokens
    # Extract the raw token sequence (drop the batch dimensions).
    original = tokens.flatten()
    original_len = len(original)
    # Compute the new length.
    new_len = int(round(original_len / speed))
    if new_len < 1:
        new_len = 1
    # Nearest-neighbour interpolation.
    result = np.zeros(new_len, dtype=original.dtype)
    for i in range(new_len):
        old_idx = int(i * speed)
        if old_idx >= original_len:
            old_idx = original_len - 1
        result[i] = original[old_idx]
    # Restore the original shape [1, 1, new_len].
    return result.reshape(1, 1, -1)
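# Usage sketch (illustrative only; the values below follow from the
# nearest-neighbour index mapping above, not from any model output):
#   >>> demo = np.arange(6, dtype=np.int64).reshape(1, 1, -1)
#   >>> stretch_semantic_tokens(demo, 2.0).flatten().tolist()
#   [0, 2, 4]
#   >>> stretch_semantic_tokens(demo, 0.5).flatten().tolist()
#   [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5]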
class GENIE:
    def __init__(self):
        self.stop_event: threading.Event = threading.Event()
    def tts(
        self,
        text: str,
        prompt_audio: ReferenceAudio,
        encoder: ort.InferenceSession,
        first_stage_decoder: ort.InferenceSession,
        stage_decoder: ort.InferenceSession,
        vocoder: ort.InferenceSession,
        prompt_encoder: Optional[ort.InferenceSession],
        language: str = 'japanese',
        text_language: Optional[str] = None,
        speed: float = 1.0,  # speaking-speed factor
    ) -> Optional[np.ndarray]:
        # If text_language is not specified, fall back to the reference audio's language.
        actual_text_language = text_language if text_language else language
        text = '。' + text  # Prepend a pause so the first sentence is not swallowed.
        text_seq, text_bert = get_phones_and_bert(text, language=actual_text_language)
        semantic_tokens: np.ndarray = self.t2s_cpu(
            ref_seq=prompt_audio.phonemes_seq,
            ref_bert=prompt_audio.text_bert,
            text_seq=text_seq,
            text_bert=text_bert,
            ssl_content=prompt_audio.ssl_content,
            encoder=encoder,
            first_stage_decoder=first_stage_decoder,
            stage_decoder=stage_decoder,
        )
        if semantic_tokens is None:
            # t2s_cpu returns None when stop_event was set mid-generation.
            return None
        # Strip invalid elements such as the EOS token (ids >= 1024).
        eos_indices = np.where(semantic_tokens >= 1024)
        if len(eos_indices[0]) > 0:
            first_eos_index = eos_indices[-1][0]
            semantic_tokens = semantic_tokens[..., :first_eos_index]
        # Speed control: interpolate the semantic tokens before running the vocoder.
        semantic_tokens = stretch_semantic_tokens(semantic_tokens, speed)
        if prompt_encoder is None:
            return vocoder.run(None, {
                "text_seq": text_seq,
                "pred_semantic": semantic_tokens,
                "ref_audio": prompt_audio.audio_32k,
            })[0]
        else:
            # Added for V2ProPlus.
            prompt_audio.update_global_emb(prompt_encoder=prompt_encoder)
            audio_chunk = vocoder.run(None, {
                "text_seq": text_seq,
                "pred_semantic": semantic_tokens,
                "ge": prompt_audio.global_emb,
                "ge_advanced": prompt_audio.global_emb_advanced,
            })[0]
            return audio_chunk
    def t2s_cpu(
        self,
        ref_seq: np.ndarray,
        ref_bert: np.ndarray,
        text_seq: np.ndarray,
        text_bert: np.ndarray,
        ssl_content: np.ndarray,
        encoder: ort.InferenceSession,
        first_stage_decoder: ort.InferenceSession,
        stage_decoder: ort.InferenceSession,
    ) -> Optional[np.ndarray]:
        """Run the T2S (text-to-semantic) model on the CPU."""
        # Encoder: encode phoneme sequences, BERT features, and SSL content into x and prompts.
        x, prompts = encoder.run(
            None,
            {
                "ref_seq": ref_seq,
                "text_seq": text_seq,
                "ref_bert": ref_bert,
                "text_bert": text_bert,
                "ssl_content": ssl_content,
            },
        )
        # First-stage decoder: produces the initial token sequence, its embedding, and the KV cache.
        y, y_emb, *present_key_values = first_stage_decoder.run(
            None, {"x": x, "prompts": prompts}
        )
        # Stage decoder: autoregressive loop, feeding the KV cache back in at each step.
        input_names: List[str] = [inp.name for inp in stage_decoder.get_inputs()]
        idx: int = 0
        for idx in range(0, 500):
            if self.stop_event.is_set():
                return None
            input_feed = {
                name: data
                for name, data in zip(input_names, [y, y_emb, *present_key_values])
            }
            outputs = stage_decoder.run(None, input_feed)
            y, y_emb, stop_condition_tensor, *present_key_values = outputs
            if stop_condition_tensor:
                break
        # Zero the trailing EOS token and keep only the newly generated tokens.
        y[0, -1] = 0
        return np.expand_dims(y[:, -idx:], axis=0)
tts_client: GENIE = GENIE()
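# Usage sketch (illustrative only): the model file names below are assumptions,
# not part of this module. The caller is expected to create the ONNX Runtime
# sessions and a ReferenceAudio, then call tts_client.tts():
#
#   encoder = ort.InferenceSession("t2s_encoder.onnx")            # hypothetical paths
#   first_stage_decoder = ort.InferenceSession("t2s_first_stage_decoder.onnx")
#   stage_decoder = ort.InferenceSession("t2s_stage_decoder.onnx")
#   vocoder = ort.InferenceSession("vocoder.onnx")
#   ref = ReferenceAudio(...)  # constructor arguments depend on ReferenceAudio
#   audio = tts_client.tts(
#       text="こんにちは",
#       prompt_audio=ref,
#       encoder=encoder,
#       first_stage_decoder=first_stage_decoder,
#       stage_decoder=stage_decoder,
#       vocoder=vocoder,
#       prompt_encoder=None,   # None -> the non-V2ProPlus vocoder path
#       speed=1.2,             # ~20% faster via stretch_semantic_tokens
#   )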