from ..Utils.Utils import LRUCacheDict from ..GetPhonesAndBert import get_phones_and_bert from ..Audio.Audio import load_audio from ..ModelManager import model_manager from onnxruntime import InferenceSession import os import numpy as np import soxr from typing import Optional, Dict class ReferenceAudio: _prompt_cache: Dict[str, 'ReferenceAudio'] = LRUCacheDict( capacity=int(os.getenv('Max_Cached_Reference_Audio', '10'))) def __new__(cls, prompt_wav: str, prompt_text: str, language: str): if prompt_wav in cls._prompt_cache: instance = cls._prompt_cache[prompt_wav] if instance.text != prompt_text: # 如果文本与缓存内记录的不同,则更新。 instance.set_text(prompt_text, language=language) return instance instance = super().__new__(cls) cls._prompt_cache[prompt_wav] = instance return instance def __init__(self, prompt_wav: str, prompt_text: str, language: str): if hasattr(self, '_initialized'): return # 文本相关。 self.text: str = prompt_text self.phonemes_seq: Optional[np.ndarray] = None self.text_bert: Optional[np.ndarray] = None self.set_text(prompt_text, language=language) # 音频相关。 self.audio_32k: Optional[np.ndarray] = load_audio( audio_path=prompt_wav, target_sampling_rate=32000 ) self.audio_16k: np.ndarray = soxr.resample(self.audio_32k, 32000, 16000, quality='hq') # 修复:增加静音填充到 0.5 秒,帮助跨语言 TTS 更好区分参考音频与目标内容边界 zero_padding_16k = np.zeros(int(16000 * 0.5), dtype=self.audio_16k.dtype) audio_16k_padded = np.concatenate([self.audio_16k, zero_padding_16k]) self.audio_32k = np.expand_dims(self.audio_32k, axis=0) self.audio_16k = np.expand_dims(self.audio_16k, axis=0) # 增加 Batch_Size 维度 if not model_manager.cn_hubert: model_manager.load_cn_hubert() # 使用添加了静音填充的音频提取 SSL 特征 self.ssl_content: Optional[np.ndarray] = model_manager.cn_hubert.run( None, {'input_values': np.expand_dims(audio_16k_padded, axis=0)} )[0] self.global_emb: Optional[np.ndarray] = None self.global_emb_advanced: Optional[np.ndarray] = None self._initialized = True def set_text(self, prompt_text: str, language: str) -> None: self.text = prompt_text self.phonemes_seq, self.text_bert = get_phones_and_bert(prompt_text, language=language) @classmethod def clear_cache(cls) -> None: """清空 ReferenceAudio 的缓存""" cls._prompt_cache.clear() def update_global_emb(self, prompt_encoder: InferenceSession) -> None: if self.global_emb is not None: return if model_manager.load_sv_model(): sv_emb = model_manager.speaker_verification_model.run(None, {'waveform': self.audio_16k})[0] self.global_emb, self.global_emb_advanced = prompt_encoder.run(None, { 'ref_audio': self.audio_32k, 'sv_emb': sv_emb, })