liumaolin
refactor(core): Architecturally decouple Audio, ASR, and TTS modules
60f8238
import importlib.util
import inspect
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Dict, Optional, Type

from voice_dialogue.utils.logger import logger

from .models.base import BaseTTSConfig
from .runtime.interface import TTSInterface
@dataclass
class TTSRegistryTables:
    """Registry tables used to manage the available TTS implementations.

    The built-in table is ``tts_classes``. :meth:`register` can also create
    additional tables on demand; they are stored as extra attributes on the
    instance under the requested table name.
    """

    # Maps a registration name to the TTS class registered under it.
    tts_classes: Dict[str, Type[TTSInterface]] = field(default_factory=dict)

    def __post_init__(self):
        """Normalise an explicitly passed ``None`` to an empty table."""
        if self.tts_classes is None:
            self.tts_classes = {}

    @staticmethod
    def _class_location(tts_class: type) -> str:
        """Return ``"<file>:<line>"`` for ``tts_class``, degrading gracefully.

        ``inspect`` raises ``TypeError``/``OSError`` when a class has no
        retrievable source (built-in classes, dynamically created classes, or
        classes whose module is not present in ``sys.modules``). Printing the
        registry is a diagnostic aid and must not fail because of that.
        """
        try:
            class_file = inspect.getfile(tts_class)
        except (TypeError, OSError):
            return "<unknown>"

        # Shorten the path for display. Separators are normalised first so the
        # prefix pattern also matches Windows-style paths.
        class_file = re.sub(
            r"^.+/VoiceDialogue/", "VoiceDialogue/", Path(class_file).as_posix()
        )

        try:
            class_line = inspect.getsourcelines(tts_class)[1]
        except (TypeError, OSError):
            return class_file
        return f"{class_file}:{class_line}"

    def print(self, key: Optional[str] = None) -> None:
        """Log the registered TTS classes as an aligned text table.

        Args:
            key: Optional filter. The table is logged when ``key`` is ``None``
                or when the string ``"tts_classes"`` occurs inside ``key``.
        """
        logger.info("\nTTS Registry Tables: \n")
        headers = ["register name", "class name", "class location"]

        if self.tts_classes and (key is None or "tts_classes" in key):
            logger.info("----------- ** tts_classes ** --------------")

            # One row per registered class, ordered by registration name.
            rows = sorted(
                (
                    [register_key, tts_class.__name__, self._class_location(tts_class)]
                    for register_key, tts_class in self.tts_classes.items()
                ),
                key=lambda row: row[0],
            )

            data = [headers] + rows
            # Each column is as wide as its longest cell (header included).
            col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]

            for row in data:
                cells = " | ".join(
                    str(item).ljust(width) for item, width in zip(row, col_widths)
                )
                logger.info("| " + cells + " |")

        logger.info("\n")

    def register(
        self, register_table_key: str, key: Optional[str] = None
    ) -> Callable[[type], type]:
        """Return a class decorator that records the class in a registry table.

        Args:
            register_table_key: Name of the table (an attribute of this
                instance) to register into. The table is created when no
                attribute of that name exists yet.
            key: Name to register the class under. Defaults to the class's
                ``__name__``.

        Returns:
            A decorator that stores the decorated class in the table and
            returns the class unchanged.

        Raises:
            TypeError: Raised by the decorator when ``register_table_key``
                names an existing attribute that is not a dict (for example a
                method of this class).
        """

        def decorator(target_class):
            if not hasattr(self, register_table_key):
                setattr(self, register_table_key, {})
                logger.debug(f"New TTS registry table added: {register_table_key}")

            registry = getattr(self, register_table_key)
            if not isinstance(registry, dict):
                # Fail with a clear message instead of an obscure
                # "argument of type ... is not iterable" further down.
                raise TypeError(
                    f"'{register_table_key}' is not a registry table "
                    f"(found {type(registry).__name__})"
                )

            registry_key = key if key is not None else target_class.__name__

            if registry_key in registry:
                # Re-registration is allowed; the newest class wins.
                logger.debug(
                    f"Key {registry_key} already exists in {register_table_key}, re-register"
                )

            registry[registry_key] = target_class
            logger.info(f"Registered TTS class: {registry_key} -> {target_class.__name__}")
            return target_class

        return decorator
# Global TTS registry instance.
tts_tables = TTSRegistryTables()
class TTSManager:
    """Creates TTS instances from the global registry and caches them."""

    def __init__(self):
        # Cache of created instances, keyed by "<tts type value>:<character name>".
        self._tts_instances: Dict[str, TTSInterface] = {}

    def create_tts(self, config: BaseTTSConfig) -> TTSInterface:
        """
        Build a new TTS instance for the given configuration.

        Args:
            config: TTS configuration; ``config.tts_type.value`` selects the
                registered class.

        Returns:
            TTSInterface: A freshly constructed instance of the registered class.

        Raises:
            ValueError: If the TTS type is not registered.
        """
        type_name = config.tts_type.value
        registered = tts_tables.tts_classes

        if type_name in registered:
            implementation = registered[type_name]
            return implementation(config)

        raise ValueError(f"未注册的TTS类型: {type_name}. 可用类型: {list(registered.keys())}")

    def get_or_create_tts(self, config: BaseTTSConfig) -> TTSInterface:
        """
        Return the cached TTS instance for this configuration, creating it on first use.

        Instances are cached per combination of TTS type value and character name.

        Args:
            config: TTS configuration object.

        Returns:
            TTSInterface: The cached (or newly created) TTS instance.
        """
        cache_key = f"{config.tts_type.value}:{config.character_name}"
        cache = self._tts_instances

        if cache_key in cache:
            return cache[cache_key]

        created = self.create_tts(config)
        cache[cache_key] = created
        return created

    def list_registered_tts(self) -> Dict[str, Type[TTSInterface]]:
        """Return a shallow copy of the table of registered TTS classes."""
        registered = tts_tables.tts_classes
        return registered.copy()

    def is_tts_registered(self, tts_type: str) -> bool:
        """Report whether the given TTS type name is registered."""
        registered = tts_tables.tts_classes
        return tts_type in registered

    def print_registry(self):
        """Log the contents of the global registry."""
        tts_tables.print()
# Global TTS manager instance.
tts_manager = TTSManager()
def register_all_tts():
    """Discover and import every TTS implementation module in ``runtime``.

    Every ``*.py`` file in the ``runtime`` directory next to this module is
    imported, except ``__init__.py`` and ``interface.py``. This function only
    imports the modules; any registration is expected to happen as a side
    effect of their top-level code (not visible from here).

    Modules are imported under their package-qualified name whenever this
    module belongs to a package. That keeps relative imports inside the
    runtime modules working, puts the modules in ``sys.modules`` so a later
    regular import does not execute them a second time, and lets ``inspect``
    locate the classes they define. Only when no parent package is known are
    the files executed directly from their path.

    Import failures are logged and skipped so that one broken backend does
    not prevent the remaining ones from loading.
    """
    runtime_dir = Path(__file__).parent / "runtime"
    if not runtime_dir.is_dir():
        # Nothing to discover.
        return

    # ``__spec__.parent`` is the package this module lives in; it is empty or
    # missing when the file is run as a plain script.
    module_spec = globals().get("__spec__")
    package = getattr(module_spec, "parent", None) or globals().get("__package__")

    skipped = {"__init__.py", "interface.py"}

    # Sorted so the import (and therefore registration) order is deterministic.
    for py_file in sorted(runtime_dir.glob("*.py")):
        if py_file.name in skipped:
            continue

        module_name = py_file.stem

        try:
            if package:
                importlib.import_module(f"{package}.runtime.{module_name}")
            else:
                spec = importlib.util.spec_from_file_location(module_name, py_file)
                if spec is None or spec.loader is None:
                    raise ImportError(f"cannot build an import spec for {py_file}")
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
        except ImportError as e:
            logger.warning(f"Failed to import TTS module {module_name}: {e}")
        except Exception as e:
            logger.error(f"Unexpected error importing TTS module {module_name}: {e}")
        else:
            logger.info(f"Successfully imported TTS module: {module_name}")
# Automatically register all TTS implementations when this module is imported.
register_all_tts()