File size: 5,587 Bytes
ee4663a ef0d09e ee4663a 851495c ee4663a ef0d09e f08ef5f ef0d09e 851495c ef0d09e f08ef5f ef0d09e 851495c ef0d09e f08ef5f ef0d09e 851495c ef0d09e 851495c ef0d09e 851495c ef0d09e 851495c ef0d09e 851495c ef0d09e f08ef5f ef0d09e f08ef5f ef0d09e f08ef5f ef0d09e f08ef5f ef0d09e f08ef5f ef0d09e f08ef5f ef0d09e f08ef5f ef0d09e f08ef5f ef0d09e f08ef5f ef0d09e f08ef5f ef0d09e ee4663a ef0d09e f08ef5f ef0d09e f08ef5f ef0d09e 9bfcd24 f08ef5f 9bfcd24 851495c ef0d09e 851495c ef0d09e 851495c ef0d09e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import importlib.util
import inspect
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, Optional, Type

from voice_dialogue.utils.logger import logger

from .models.base import BaseTTSConfig
from .runtime.interface import TTSInterface
@dataclass
class TTSRegistryTables:
    """Registry tables used to manage the available TTS implementations.

    Classes are added through the :meth:`register` decorator. The default
    table is ``tts_classes``; :meth:`register` creates any other table on
    demand as an extra instance attribute.
    """

    # Maps a registration key to the class registered under it. ``None`` is
    # replaced with a fresh dict in ``__post_init__`` so that instances never
    # share one mutable default.
    tts_classes: Optional[Dict[str, Type[TTSInterface]]] = None

    def __post_init__(self):
        if self.tts_classes is None:
            self.tts_classes = {}

    @staticmethod
    def _describe_location(tts_class: type) -> str:
        """Return ``file:line`` for *tts_class*, degrading gracefully.

        ``inspect.getfile`` raises ``TypeError``/``OSError`` when the class's
        module cannot be found in ``sys.modules`` (e.g. modules executed from
        a spec without being registered), and ``inspect.getsourcelines``
        raises when the source is unavailable. Printing the registry must not
        crash in those cases, so placeholders are used instead.
        """
        try:
            class_file = inspect.getfile(tts_class)
        except (TypeError, OSError):
            return "<unknown>"

        # Normalise separators first so the prefix is also trimmed on
        # platforms whose native separator is not "/".
        class_file = re.sub(
            r"^.+/VoiceDialogue/", "VoiceDialogue/", Path(class_file).as_posix()
        )

        try:
            class_line = inspect.getsourcelines(tts_class)[1]
        except (TypeError, OSError):
            return f"{class_file}:?"
        return f"{class_file}:{class_line}"

    def print(self, key: Optional[str] = None) -> None:
        """Log the registered TTS classes as an aligned text table.

        Args:
            key: Optional filter. The ``tts_classes`` table is logged when
                ``key`` is ``None`` or when ``"tts_classes" in key`` is true.
        """
        logger.info("\nTTS Registry Tables: \n")
        headers = ["register name", "class name", "class location"]

        if self.tts_classes and (key is None or "tts_classes" in key):
            logger.info("----------- ** tts_classes ** --------------")
            # One row per registered class, ordered by registration key.
            metas = sorted(
                (
                    [
                        register_key,
                        tts_class.__name__,
                        self._describe_location(tts_class),
                    ]
                    for register_key, tts_class in self.tts_classes.items()
                ),
                key=lambda row: row[0],
            )
            data = [headers] + metas
            # Each column is padded to the width of its widest cell.
            col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
            for row in data:
                logger.info(
                    "| "
                    + " | ".join(
                        str(item).ljust(width) for item, width in zip(row, col_widths)
                    )
                    + " |"
                )
        logger.info("\n")

    def register(
        self, register_table_key: str, key: Optional[str] = None
    ) -> Callable[[type], type]:
        """Return a class decorator that stores the class in a registry table.

        Args:
            register_table_key: Name of the table (an attribute of this
                object) to register into. A missing table is created.
            key: Registration key. Defaults to the decorated class's
                ``__name__``.

        Returns:
            A decorator that registers the class and returns it unchanged.

        Raises:
            TypeError: Raised by the decorator when ``register_table_key``
                names an existing attribute that is not a dict (for example
                a method name such as ``"print"``).
        """

        def decorator(target_class):
            if not hasattr(self, register_table_key):
                setattr(self, register_table_key, {})
                logger.debug(f"New TTS registry table added: {register_table_key}")

            registry = getattr(self, register_table_key)
            # Guard against keys that collide with methods or other non-table
            # attributes; without this the membership test below fails with
            # an opaque TypeError.
            if not isinstance(registry, dict):
                raise TypeError(
                    f"{register_table_key!r} is not a registry table "
                    f"(found {type(registry).__name__})"
                )

            registry_key = key if key is not None else target_class.__name__
            if registry_key in registry:
                logger.debug(
                    f"Key {registry_key} already exists in {register_table_key}, re-register"
                )

            registry[registry_key] = target_class
            logger.info(f"Registered TTS class: {registry_key} -> {target_class.__name__}")
            return target_class

        return decorator
# Global TTS registry instance, created once at module import time.
tts_tables = TTSRegistryTables()
class TTSManager:
    """Creates TTS instances from configs and caches them per character."""

    def __init__(self):
        # Cache of created instances, keyed by "<tts_type value>:<character_name>".
        self._tts_instances: Dict[str, TTSInterface] = dict()

    def create_tts(self, config: BaseTTSConfig) -> TTSInterface:
        """
        Build a new TTS instance for the given configuration.

        Args:
            config: TTS configuration object; ``config.tts_type.value`` selects
                the registered class.

        Returns:
            TTSInterface: A freshly constructed TTS instance.

        Raises:
            ValueError: If the TTS type is not registered.
        """
        type_name = config.tts_type.value
        registered = tts_tables.tts_classes
        if type_name in registered:
            implementation = registered[type_name]
            return implementation(config)
        raise ValueError(f"未注册的TTS类型: {type_name}. 可用类型: {list(registered.keys())}")

    def get_or_create_tts(self, config: BaseTTSConfig) -> TTSInterface:
        """
        Return the cached TTS instance for this config, creating it on first use.

        Args:
            config: TTS configuration object.

        Returns:
            TTSInterface: The instance cached for the config's type and character.
        """
        cache_key = f"{config.tts_type.value}:{config.character_name}"
        if cache_key in self._tts_instances:
            return self._tts_instances[cache_key]
        self._tts_instances[cache_key] = self.create_tts(config)
        return self._tts_instances[cache_key]

    def list_registered_tts(self) -> Dict[str, Type[TTSInterface]]:
        """Return a shallow copy of the table of registered TTS classes."""
        registered = tts_tables.tts_classes
        return registered.copy()

    def is_tts_registered(self, tts_type: str) -> bool:
        """Report whether ``tts_type`` is a key in the TTS class registry."""
        registered = tts_tables.tts_classes
        return tts_type in registered

    def print_registry(self):
        """Log the contents of the global registry."""
        tts_tables.print()
# Global TTS manager instance, created once at module import time.
tts_manager = TTSManager()
def register_all_tts():
    """Discover and import every TTS implementation module in ``runtime/``.

    Importing a module triggers whatever registration it performs at import
    time. Every ``*.py`` file in the ``runtime`` directory next to this file
    is imported, except ``__init__.py`` and ``interface.py``. Failures are
    logged and do not stop the remaining modules from loading.

    Modules are imported under their package-qualified name through the
    regular import system rather than executed from a bare file spec. This
    registers them in ``sys.modules``, so ``inspect``, pickling and relative
    imports work for the loaded classes, and a module that is also imported
    normally elsewhere is executed only once.
    """
    runtime_dir = Path(__file__).parent / "runtime"
    # This module uses relative imports itself, so it is always loaded as
    # part of a package and ``__package__`` is set.
    runtime_package = f"{__package__}.runtime"
    skipped = ("__init__.py", "interface.py")

    # Sorted so the import (and therefore registration) order does not depend
    # on filesystem enumeration order.
    for py_file in sorted(runtime_dir.glob("*.py")):
        if py_file.name in skipped:
            continue

        module_name = py_file.stem
        try:
            importlib.import_module(f"{runtime_package}.{module_name}")
        except ImportError as e:
            # Typically an optional backend dependency that is not installed.
            logger.warning(f"Failed to import TTS module {module_name}: {e}")
        except Exception as e:
            logger.error(f"Unexpected error importing TTS module {module_name}: {e}")
        else:
            logger.info(f"Successfully imported TTS module: {module_name}")
# Automatically register all TTS implementations when this module is imported.
register_all_tts()
|