Genie-TTS-testing / genie_tts /Audio /ReferenceAudio.py
antigravity
fix: add sentence interval padding and increase ref audio padding for cross-lang TTS
05123cf
from ..Utils.Utils import LRUCacheDict
from ..GetPhonesAndBert import get_phones_and_bert
from ..Audio.Audio import load_audio
from ..ModelManager import model_manager
from onnxruntime import InferenceSession
import os
import numpy as np
import soxr
from typing import Optional, Dict
class ReferenceAudio:
    """Reference (prompt) audio plus the features derived from it.

    Instances are cached per ``prompt_wav`` path in a class-level LRU cache,
    so constructing ``ReferenceAudio`` twice with the same path returns the
    same object. On a cache hit the audio features are reused; only the
    text-derived features are recomputed, and only when the prompt text or
    its language changed.

    Attributes set by ``__init__``:
        text / language: prompt text and the language it was processed with.
        phonemes_seq / text_bert: the two values returned by
            ``get_phones_and_bert`` for the prompt text.
        audio_32k / audio_16k: the loaded waveform at 32 kHz and its 16 kHz
            resample, each with a leading batch axis added.
        ssl_content: first output of the CN-HuBERT session, run on the 16 kHz
            audio with 0.5 s of trailing silence appended.
        global_emb / global_emb_advanced: ``None`` until
            ``update_global_emb`` fills them in.
    """

    # Capacity is read once, at class-definition time, from the environment.
    _prompt_cache: Dict[str, 'ReferenceAudio'] = LRUCacheDict(
        capacity=int(os.getenv('Max_Cached_Reference_Audio', '10')))

    def __new__(cls, prompt_wav: str, prompt_text: str, language: str):
        """Return the cached instance for ``prompt_wav`` or allocate a new one.

        A cached instance has its text features refreshed when either the
        prompt text or the language differs from what it was built with.
        Comparing the text alone would keep phonemes computed for the previous
        language when only ``language`` changes.
        """
        if prompt_wav in cls._prompt_cache:
            instance = cls._prompt_cache[prompt_wav]
            text_changed = instance.text != prompt_text
            language_changed = getattr(instance, 'language', None) != language
            if text_changed or language_changed:
                instance.set_text(prompt_text, language=language)
            return instance

        instance = super().__new__(cls)
        cls._prompt_cache[prompt_wav] = instance
        return instance

    def __init__(self, prompt_wav: str, prompt_text: str, language: str):
        """Load the audio and compute text and SSL features.

        Python calls ``__init__`` after every ``__new__``, including cache
        hits, so an already-initialized instance returns immediately.
        """
        if hasattr(self, '_initialized'):
            return

        # Text-related state.
        self.text: str = prompt_text
        self.language: str = language
        self.phonemes_seq: Optional[np.ndarray] = None
        self.text_bert: Optional[np.ndarray] = None
        self.set_text(prompt_text, language=language)

        # Audio-related state.
        self.audio_32k: Optional[np.ndarray] = load_audio(
            audio_path=prompt_wav,
            target_sampling_rate=32000
        )
        self.audio_16k: np.ndarray = soxr.resample(self.audio_32k, 32000, 16000, quality='hq')

        # Append 0.5 s of silence to the 16 kHz signal to help cross-language
        # TTS separate the reference audio from the target content. Only the
        # SSL input uses the padded copy; self.audio_16k stays unpadded.
        zero_padding_16k = np.zeros(int(16000 * 0.5), dtype=self.audio_16k.dtype)
        audio_16k_padded = np.concatenate([self.audio_16k, zero_padding_16k])

        # Add the batch-size axis expected by the ONNX sessions.
        self.audio_32k = np.expand_dims(self.audio_32k, axis=0)
        self.audio_16k = np.expand_dims(self.audio_16k, axis=0)

        if not model_manager.cn_hubert:
            model_manager.load_cn_hubert()

        # Extract SSL features from the silence-padded audio.
        self.ssl_content: Optional[np.ndarray] = model_manager.cn_hubert.run(
            None, {'input_values': np.expand_dims(audio_16k_padded, axis=0)}
        )[0]

        self.global_emb: Optional[np.ndarray] = None
        self.global_emb_advanced: Optional[np.ndarray] = None
        self._initialized = True

    def set_text(self, prompt_text: str, language: str) -> None:
        """Recompute the phoneme sequence and BERT features for new text.

        The features are computed before any attribute is assigned. If
        ``get_phones_and_bert`` raises, ``text`` and ``language`` keep their
        previous values, so a later call with the same arguments retries
        instead of being treated as already up to date.
        """
        phonemes_seq, text_bert = get_phones_and_bert(prompt_text, language=language)
        self.text = prompt_text
        self.language = language
        self.phonemes_seq, self.text_bert = phonemes_seq, text_bert

    @classmethod
    def clear_cache(cls) -> None:
        """Clear the ReferenceAudio cache."""
        cls._prompt_cache.clear()

    def update_global_emb(self, prompt_encoder: InferenceSession) -> None:
        """Compute ``global_emb`` and ``global_emb_advanced`` once.

        Does nothing when ``global_emb`` is already set. If
        ``model_manager.load_sv_model()`` returns a falsy value, both
        embeddings are left as ``None``.
        """
        if self.global_emb is not None:
            return

        # The encoder call sits inside this branch because sv_emb only exists
        # when the speaker-verification model is available.
        if model_manager.load_sv_model():
            sv_emb = model_manager.speaker_verification_model.run(None, {'waveform': self.audio_16k})[0]
            self.global_emb, self.global_emb_advanced = prompt_encoder.run(None, {
                'ref_audio': self.audio_32k,
                'sv_emb': sv_emb,
            })