import importlib.util
import inspect
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Dict, Optional, Type

from voice_dialogue.utils.logger import logger
from .models.base import BaseTTSConfig
from .runtime.interface import TTSInterface
| |
|
| |
|
@dataclass
class TTSRegistryTables:
    """Registry of TTS implementation classes.

    ``tts_classes`` maps a registration name to a ``TTSInterface`` subclass.
    ``register`` can also create additional named tables (extra dict
    attributes on this object) on demand.
    """

    tts_classes: Dict[str, Type[TTSInterface]] = field(default_factory=dict)

    # Trims absolute paths down to the project root so logged class locations
    # are short. Not annotated, so dataclasses treats it as a plain class
    # attribute rather than a field.
    _PROJECT_PATH_PREFIX = re.compile(r"^.+/VoiceDialogue/")

    def __post_init__(self) -> None:
        # Backward compatibility: callers may still pass ``tts_classes=None``
        # explicitly (the old default); normalise it to an empty table.
        if self.tts_classes is None:
            self.tts_classes = {}

    @classmethod
    def _class_location(cls, tts_class: type) -> str:
        """Return ``"<file>:<line>"`` for *tts_class*, or a best-effort fallback.

        The file path is converted to POSIX form first so the project-prefix
        regex also matches Windows paths, then trimmed to start at
        ``VoiceDialogue/``.
        """
        try:
            class_file = inspect.getfile(tts_class)
        except TypeError:
            # Built-in classes have no source file.
            return "<built-in>"

        class_file = cls._PROJECT_PATH_PREFIX.sub(
            "VoiceDialogue/", Path(class_file).as_posix()
        )
        try:
            class_line = inspect.getsourcelines(tts_class)[1]
        except (OSError, TypeError):
            # Source text unavailable (e.g. bytecode-only install or a
            # dynamically created class): still report the file rather than
            # aborting the whole table.
            return class_file
        return f"{class_file}:{class_line}"

    def print(self, key: Optional[str] = None) -> None:
        """Log the registered TTS classes as an aligned text table.

        Columns are the registration name, the class name and the class's
        source location; rows are sorted by registration name.

        Args:
            key: Optional filter. The ``tts_classes`` table is logged only when
                ``key`` is None or contains the substring ``"tts_classes"``.
        """
        logger.info("\nTTS Registry Tables: \n")
        if not self.tts_classes or (key is not None and "tts_classes" not in key):
            return

        logger.info("----------- ** tts_classes ** --------------")
        headers = ["register name", "class name", "class location"]
        rows = sorted(
            (
                [register_key, tts_class.__name__, self._class_location(tts_class)]
                for register_key, tts_class in self.tts_classes.items()
            ),
            key=lambda row: row[0],
        )
        data = [headers] + rows
        # Width of each column = longest cell in that column (header included).
        col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]

        for row in data:
            logger.info(
                "| "
                + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths))
                + " |"
            )
        logger.info("\n")

    def register(
        self, register_table_key: str, key: Optional[str] = None
    ) -> Callable[[type], type]:
        """Return a class decorator that stores the decorated class in a table.

        Args:
            register_table_key: Name of the attribute on this object holding the
                target table (e.g. ``"tts_classes"``). If no such attribute
                exists, a new empty dict is created under that name.
            key: Registry key; defaults to the decorated class's ``__name__``.
                An existing entry with the same key is overwritten.

        Returns:
            A decorator that registers the class and returns it unchanged.

        Raises:
            TypeError: When the decorator is applied, if ``register_table_key``
                names an existing attribute that is not a dict.
        """

        def decorator(target_class: type) -> type:
            if not hasattr(self, register_table_key):
                setattr(self, register_table_key, {})
                logger.debug(f"New TTS registry table added: {register_table_key}")

            registry = getattr(self, register_table_key)
            # Names such as "print" or "register" resolve to methods, not
            # tables; fail clearly instead of with "argument of type 'method'
            # is not iterable" further down.
            if not isinstance(registry, dict):
                raise TypeError(
                    f"Registry table '{register_table_key}' is not a dict "
                    f"(got {type(registry).__name__})"
                )

            registry_key = key if key is not None else target_class.__name__
            if registry_key in registry:
                logger.debug(
                    f"Key {registry_key} already exists in {register_table_key}, re-register"
                )

            registry[registry_key] = target_class
            logger.info(f"Registered TTS class: {registry_key} -> {target_class.__name__}")
            return target_class

        return decorator


# Module-level registry instance shared by this module's TTSManager and loader.
tts_tables = TTSRegistryTables()
| |
|
| |
|
class TTSManager:
    """Build TTS engines from config objects and cache them for reuse.

    Engine classes are resolved through the module-level ``tts_tables``
    registry. Cached engines are keyed by TTS type and character name.
    """

    def __init__(self):
        # "<tts_type value>:<character_name>" -> engine instance.
        self._tts_instances: Dict[str, TTSInterface] = {}

    def create_tts(self, config: BaseTTSConfig) -> TTSInterface:
        """Instantiate a new TTS engine for *config* (the result is not cached).

        Args:
            config: TTS configuration; ``config.tts_type.value`` selects the
                registered engine class.

        Returns:
            TTSInterface: the result of ``engine_class(config)``.

        Raises:
            ValueError: If no engine class is registered for the TTS type.
        """
        wanted_type = config.tts_type.value
        registered = tts_tables.tts_classes

        if wanted_type in registered:
            engine_class = registered[wanted_type]
            return engine_class(config)

        raise ValueError(f"未注册的TTS类型: {wanted_type}. 可用类型: {list(registered.keys())}")

    def get_or_create_tts(self, config: BaseTTSConfig) -> TTSInterface:
        """Return the cached engine for *config*, creating it on first use.

        NOTE(review): the cache key uses only the TTS type and
        ``config.character_name``. Two configs that agree on both share one
        instance even if their other settings differ. Confirm this is intended.

        Args:
            config: TTS configuration object.

        Returns:
            TTSInterface: the shared engine instance for this key.

        Raises:
            ValueError: Propagated from ``create_tts`` for unregistered types.
        """
        cache_key = f"{config.tts_type.value}:{config.character_name}"
        cache = self._tts_instances

        if cache_key not in cache:
            cache[cache_key] = self.create_tts(config)
        return cache[cache_key]

    def list_registered_tts(self) -> Dict[str, Type[TTSInterface]]:
        """Return a shallow copy of the registration-name -> class mapping."""
        return dict(tts_tables.tts_classes)

    def is_tts_registered(self, tts_type: str) -> bool:
        """Return True if an engine class is registered under *tts_type*."""
        registered = tts_tables.tts_classes
        return tts_type in registered

    def print_registry(self):
        """Log the registry contents via ``tts_tables.print()``."""
        tts_tables.print()


# Module-level shared manager instance.
tts_manager = TTSManager()
| |
|
| |
|
def register_all_tts():
    """Import every TTS implementation module in the ``runtime`` sub-package.

    Each ``runtime/*.py`` file except ``__init__.py`` and ``interface.py`` is
    imported as ``<this package>.runtime.<stem>``. Importing a module is what
    registers its engine class(es). This presumably happens via
    ``tts_tables.register`` decorators in those modules, which are not visible
    here; confirm.

    Modules go through the normal package import machinery rather than being
    executed straight from their file path. This means:

    * relative imports inside runtime modules resolve;
    * each module executes at most once (it is cached in ``sys.modules``), so
      classes are not registered twice as distinct objects;
    * classes get an importable ``__module__``, which pickling relies on.

    Import failures are logged and skipped, so one broken engine does not stop
    the others from registering.
    """
    runtime_dir = Path(__file__).parent / "runtime"
    runtime_package = f"{__package__}.runtime"

    # Sort so the import (and therefore registration) order is deterministic
    # across filesystems; a later registration under the same key wins.
    for py_file in sorted(runtime_dir.glob("*.py")):
        if py_file.name in ("__init__.py", "interface.py"):
            continue

        module_name = py_file.stem
        try:
            importlib.import_module(f"{runtime_package}.{module_name}")
        except ImportError as e:
            logger.warning(f"Failed to import TTS module {module_name}: {e}")
        except Exception as e:
            logger.error(f"Unexpected error importing TTS module {module_name}: {e}")
        else:
            logger.info(f"Successfully imported TTS module: {module_name}")


# Populate the registry when this module is imported.
register_all_tts()
| |
|