|
|
import importlib.util |
|
|
import inspect |
|
|
import re |
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import Dict, Type |
|
|
|
|
|
from voice_dialogue.utils.logger import logger |
|
|
from .models.base import BaseTTSConfig |
|
|
from .runtime.interface import TTSInterface |
|
|
|
|
|
|
|
|
@dataclass
class TTSRegistryTables:
    """Registry tables used to manage the different TTS implementations.

    Each table is a plain dict stored as an attribute of this object that maps
    a registration key to a class. ``tts_classes`` is the table declared here;
    ``register`` can create further tables on demand.
    """

    # Declared with a None default and replaced by a fresh dict in
    # __post_init__, so instances never share one mutable mapping.
    tts_classes: Dict[str, Type[TTSInterface]] = None

    def __post_init__(self):
        """Give the instance its own empty ``tts_classes`` table if none was passed."""
        if self.tts_classes is None:
            self.tts_classes = {}

    @staticmethod
    def _describe_location(tts_class) -> str:
        """Return ``"<file>:<line>"`` for a class, degrading gracefully.

        ``inspect.getfile`` and ``inspect.getsourcelines`` raise ``TypeError``
        or ``OSError`` when the defining file or its source text cannot be
        located. A diagnostic dump should not abort because of that, so a
        placeholder (or the bare file name) is returned instead.

        NOTE(review): ``inspect.getfile`` resolves a class through
        ``sys.modules``; classes whose module was executed without being
        inserted there appear to take the fallback path - confirm.
        """
        try:
            class_file = inspect.getfile(tts_class)
        except (OSError, TypeError):
            return "<source unavailable>"

        # Shorten absolute paths so they start at the project directory.
        class_file = re.sub(r"^.+/VoiceDialogue/", "VoiceDialogue/", class_file)

        try:
            class_line = inspect.getsourcelines(tts_class)[1]
        except (OSError, TypeError):
            # The file is known but the line number is not.
            return class_file
        return f"{class_file}:{class_line}"

    def print(self, key: str = None) -> None:
        """Log the registered TTS classes as an aligned text table.

        Args:
            key: Optional filter. The ``tts_classes`` table is printed when
                ``key`` is None or when the text ``"tts_classes"`` occurs in
                ``key``.

        Rows are sorted by registration key. Nothing but the banner and the
        trailing blank entry is logged when the table is empty or filtered out.
        """
        logger.info("\nTTS Registry Tables: \n")
        headers = ["register name", "class name", "class location"]

        if self.tts_classes and (key is None or "tts_classes" in key):
            logger.info("----------- ** tts_classes ** --------------")
            metas = [
                [register_key, tts_class.__name__, self._describe_location(tts_class)]
                for register_key, tts_class in self.tts_classes.items()
            ]
            metas.sort(key=lambda x: x[0])

            data = [headers] + metas
            # Width of each column is the longest cell in that column.
            col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]

            for row in data:
                logger.info(
                    "| "
                    + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths))
                    + " |"
                )
        # NOTE(review): indentation was lost in the source; the closing blank
        # entry is assumed to be logged unconditionally - confirm.
        logger.info("\n")

    def register(self, register_table_key: str, key: str = None) -> callable:
        """Build a class decorator that stores the class in a registry table.

        Args:
            register_table_key: Name of the attribute on this object that
                holds the target table (for example ``"tts_classes"``). A new
                empty table is created when the attribute does not exist.
            key: Key to store the class under. Defaults to the class's
                ``__name__`` when None.

        Returns:
            A decorator that registers the class it receives and returns that
            same class unchanged.
        """

        def decorator(target_class):
            if not hasattr(self, register_table_key):
                setattr(self, register_table_key, {})
                logger.debug(f"New TTS registry table added: {register_table_key}")

            registry = getattr(self, register_table_key)
            registry_key = key if key is not None else target_class.__name__

            # Re-registration is allowed: the newer class replaces the old one.
            if registry_key in registry:
                logger.debug(
                    f"Key {registry_key} already exists in {register_table_key}, re-register"
                )

            registry[registry_key] = target_class
            logger.info(f"Registered TTS class: {registry_key} -> {target_class.__name__}")
            return target_class

        return decorator


# Module-level registry instance that the rest of this module reads from.
tts_tables = TTSRegistryTables()
|
|
|
|
|
|
|
|
class TTSManager:
    """Builds TTS instances from configs and keeps them for reuse."""

    def __init__(self):
        # Cache of created instances, keyed by "<tts type value>:<character name>".
        self._tts_instances: Dict[str, TTSInterface] = {}

    def create_tts(self, config: BaseTTSConfig) -> TTSInterface:
        """Create a new TTS instance for the given configuration.

        The class is looked up in ``tts_tables.tts_classes`` under
        ``config.tts_type.value`` and called with ``config``.

        Args:
            config: TTS configuration object.

        Returns:
            TTSInterface: The newly constructed TTS instance.

        Raises:
            ValueError: If the TTS type is not registered.
        """
        tts_type = config.tts_type.value
        registered = tts_tables.tts_classes

        if tts_type in registered:
            return registered[tts_type](config)

        raise ValueError(f"未注册的TTS类型: {tts_type}. 可用类型: {list(tts_tables.tts_classes.keys())}")

    def get_or_create_tts(self, config: BaseTTSConfig) -> TTSInterface:
        """Return the cached TTS instance for this config, creating it on first use.

        One instance is kept per combination of ``config.tts_type.value`` and
        ``config.character_name``.

        Args:
            config: TTS configuration object.

        Returns:
            TTSInterface: The cached or newly created TTS instance.
        """
        instance_key = f"{config.tts_type.value}:{config.character_name}"
        cache = self._tts_instances

        if instance_key in cache:
            return cache[instance_key]

        instance = self.create_tts(config)
        cache[instance_key] = instance
        return instance

    def list_registered_tts(self) -> Dict[str, Type[TTSInterface]]:
        """Return a copy of the table of registered TTS classes."""
        registered = tts_tables.tts_classes
        return registered.copy()

    def is_tts_registered(self, tts_type: str) -> bool:
        """Report whether ``tts_type`` is a key of the registered TTS classes."""
        registered = tts_tables.tts_classes
        return tts_type in registered

    def print_registry(self):
        """Log the contents of the registry."""
        tts_tables.print()


# Module-level manager instance.
tts_manager = TTSManager()
|
|
|
|
|
|
|
|
def register_all_tts():
    """Discover and import every TTS implementation module in ``runtime``.

    Each ``*.py`` file directly inside the ``runtime`` directory next to this
    file is loaded from its path and executed, except ``__init__.py`` and
    ``interface.py``. Presumably the modules register their classes as a side
    effect of being executed; that happens inside those modules, not here.

    A module that fails to load is logged and skipped, so one broken module
    does not stop the remaining ones from being imported.
    """
    runtime_dir = Path(__file__).parent / "runtime"
    skipped_names = {"__init__.py", "interface.py"}

    # Directory enumeration order depends on the filesystem. Sorting makes the
    # import order - and therefore which class wins when two modules register
    # the same key - reproducible across machines.
    for py_file in sorted(runtime_dir.glob("*.py")):
        if py_file.name in skipped_names:
            continue

        module_name = py_file.stem
        try:
            spec = importlib.util.spec_from_file_location(module_name, py_file)
            # spec_from_file_location may return None, and a spec may carry no
            # loader; report that as an import failure instead of letting an
            # opaque AttributeError surface.
            if spec is None or spec.loader is None:
                raise ImportError(f"no import spec/loader available for {py_file}")

            # NOTE(review): the module runs under its bare file stem and is
            # not inserted into sys.modules, so relative imports inside
            # runtime modules would not resolve - confirm they import
            # absolutely.
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            logger.info(f"Successfully imported TTS module: {module_name}")
        except ImportError as e:
            logger.warning(f"Failed to import TTS module {module_name}: {e}")
        except Exception as e:
            # Boundary handler: any other failure while executing a module is
            # logged so that the remaining modules still get imported.
            logger.error(f"Unexpected error importing TTS module {module_name}: {e}")


# Run discovery once when this module is imported.
register_all_tts()
|
|
|