liumaolin
refactor(core): Architecturally decouple Audio, ASR, and TTS modules
60f8238
raw
history blame
5.59 kB
import importlib.util
import inspect
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Type
from voice_dialogue.utils.logger import logger
from .models.base import BaseTTSConfig
from .runtime.interface import TTSInterface
@dataclass
class TTSRegistryTables:
"""TTS注册表系统,用于管理不同的TTS实现"""
tts_classes: Dict[str, Type[TTSInterface]] = None
def __post_init__(self):
if self.tts_classes is None:
self.tts_classes = {}
def print(self, key: str = None) -> None:
"""打印已注册的TTS类"""
logger.info("\nTTS Registry Tables: \n")
headers = ["register name", "class name", "class location"]
if self.tts_classes and (key is None or "tts_classes" in key):
logger.info(f"----------- ** tts_classes ** --------------")
metas = []
for register_key, tts_class in self.tts_classes.items():
class_file = inspect.getfile(tts_class)
class_line = inspect.getsourcelines(tts_class)[1]
# 简化路径显示
pattern = r"^.+/VoiceDialogue/"
class_file = re.sub(pattern, "VoiceDialogue/", class_file)
meta_data = [
register_key,
tts_class.__name__,
f"{class_file}:{class_line}",
]
metas.append(meta_data)
metas.sort(key=lambda x: x[0])
data = [headers] + metas
col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
for row in data:
logger.info(
"| "
+ " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths))
+ " |"
)
logger.info("\n")
def register(self, register_table_key: str, key: str = None) -> callable:
"""装饰器,用于注册TTS类"""
def decorator(target_class):
if not hasattr(self, register_table_key):
setattr(self, register_table_key, {})
logger.debug(f"New TTS registry table added: {register_table_key}")
registry = getattr(self, register_table_key)
registry_key = key if key is not None else target_class.__name__
if registry_key in registry:
logger.debug(
f"Key {registry_key} already exists in {register_table_key}, re-register"
)
registry[registry_key] = target_class
logger.info(f"Registered TTS class: {registry_key} -> {target_class.__name__}")
return target_class
return decorator
# 全局TTS注册表实例
tts_tables = TTSRegistryTables()
class TTSManager:
"""TTS管理器,负责管理和创建TTS实例"""
def __init__(self):
self._tts_instances: Dict[str, TTSInterface] = {}
def create_tts(self, config: BaseTTSConfig) -> TTSInterface:
"""
根据配置创建TTS实例
Args:
config: TTS配置对象
Returns:
TTSInterface: TTS实例
Raises:
ValueError: 如果TTS类型未注册
"""
tts_type = config.tts_type.value
if tts_type not in tts_tables.tts_classes:
raise ValueError(f"未注册的TTS类型: {tts_type}. 可用类型: {list(tts_tables.tts_classes.keys())}")
tts_class = tts_tables.tts_classes[tts_type]
return tts_class(config)
def get_or_create_tts(self, config: BaseTTSConfig) -> TTSInterface:
"""
获取或创建TTS实例(单例模式)
Args:
config: TTS配置对象
Returns:
TTSInterface: TTS实例
"""
instance_key = f"{config.tts_type.value}:{config.character_name}"
if instance_key not in self._tts_instances:
self._tts_instances[instance_key] = self.create_tts(config)
return self._tts_instances[instance_key]
def list_registered_tts(self) -> Dict[str, Type[TTSInterface]]:
"""列出所有已注册的TTS类"""
return tts_tables.tts_classes.copy()
def is_tts_registered(self, tts_type: str) -> bool:
"""检查指定TTS类型是否已注册"""
return tts_type in tts_tables.tts_classes
def print_registry(self):
"""打印注册表信息"""
tts_tables.print()
# 全局TTS管理器实例
tts_manager = TTSManager()
def register_all_tts():
"""自动发现并注册runtime目录中的所有TTS实现"""
# 获取runtime目录路径
runtime_dir = Path(__file__).parent / "runtime"
# 扫描runtime目录中的Python文件
for py_file in runtime_dir.glob("*.py"):
if py_file.name in ["__init__.py", "interface.py"]:
continue
module_name = py_file.stem
try:
spec = importlib.util.spec_from_file_location(
module_name,
py_file
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
logger.info(f"Successfully imported TTS module: {module_name}")
except ImportError as e:
logger.warning(f"Failed to import TTS module {module_name}: {e}")
except Exception as e:
logger.error(f"Unexpected error importing TTS module {module_name}: {e}")
# 在模块导入时自动注册所有TTS
register_all_tts()