File size: 5,587 Bytes
ee4663a
ef0d09e
 
ee4663a
 
 
 
851495c
ee4663a
ef0d09e
 
 
 
 
 
 
 
f08ef5f
ef0d09e
 
 
 
 
 
851495c
ef0d09e
f08ef5f
ef0d09e
851495c
ef0d09e
 
 
 
 
 
 
 
 
 
 
 
 
f08ef5f
ef0d09e
 
 
 
 
851495c
ef0d09e
 
 
 
851495c
ef0d09e
 
 
 
 
 
 
851495c
ef0d09e
 
 
 
 
851495c
ef0d09e
 
 
 
851495c
ef0d09e
 
 
 
 
 
 
 
 
 
 
f08ef5f
ef0d09e
 
f08ef5f
ef0d09e
 
 
 
 
 
 
 
 
 
 
 
 
 
f08ef5f
ef0d09e
 
f08ef5f
ef0d09e
 
f08ef5f
ef0d09e
 
 
 
 
 
 
 
 
 
 
f08ef5f
ef0d09e
 
f08ef5f
ef0d09e
f08ef5f
ef0d09e
 
 
f08ef5f
ef0d09e
 
 
f08ef5f
ef0d09e
 
 
 
 
 
 
 
 
 
 
ee4663a
ef0d09e
 
f08ef5f
ef0d09e
 
 
 
f08ef5f
ef0d09e
 
9bfcd24
f08ef5f
9bfcd24
 
 
 
851495c
ef0d09e
851495c
ef0d09e
851495c
ef0d09e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import importlib.util
import inspect
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Type

from voice_dialogue.utils.logger import logger
from .models.base import BaseTTSConfig
from .runtime.interface import TTSInterface


# Strips everything up to and including the project directory so logged paths
# stay short; accepts both "/" and "\" separators.
_PROJECT_PATH_PREFIX = re.compile(r"^.+[/\\]VoiceDialogue[/\\]")


def _class_location(cls: type) -> str:
    """Return ``"<file>:<line>"`` for *cls*, or ``"<unknown>"`` if it cannot be located.

    ``inspect.getfile`` raises ``TypeError`` for a class whose module is not in
    ``sys.modules`` (for example a module executed via
    ``importlib.util.spec_from_file_location`` + ``exec_module`` without being
    added there), and ``inspect.getsourcelines`` raises ``OSError`` when the
    source cannot be read. Neither should abort the registry dump.
    """
    try:
        class_file = inspect.getfile(cls)
        class_line = inspect.getsourcelines(cls)[1]
    except (TypeError, OSError):
        return "<unknown>"
    class_file = _PROJECT_PATH_PREFIX.sub("VoiceDialogue/", class_file)
    return f"{class_file}:{class_line}"


@dataclass
class TTSRegistryTables:
    """Registry of TTS implementation classes.

    Classes are added with the ``register`` decorator. Besides ``tts_classes``,
    ``register`` creates additional dict-valued tables on demand when it is
    given a table name that does not exist yet.
    """

    # Registration key -> TTS class. ``None`` is replaced by a fresh dict in
    # ``__post_init__`` (dataclasses reject a literal ``{}`` default).
    tts_classes: Dict[str, Type[TTSInterface]] = None

    def __post_init__(self):
        if self.tts_classes is None:
            self.tts_classes = {}

    def print(self, key: str = None) -> None:
        """Log the registered TTS classes as an aligned table.

        Args:
            key: When given, the ``tts_classes`` table is logged only if the
                string ``"tts_classes"`` occurs in *key* (substring match).
                ``None`` logs it unconditionally.
        """
        logger.info("\nTTS Registry Tables: \n")
        headers = ["register name", "class name", "class location"]

        if self.tts_classes and (key is None or "tts_classes" in key):
            logger.info("-----------    ** tts_classes **    --------------")
            metas = sorted(
                (
                    [register_key, tts_class.__name__, _class_location(tts_class)]
                    for register_key, tts_class in self.tts_classes.items()
                ),
                key=lambda row: row[0],
            )
            data = [headers] + metas
            # Each column is as wide as its longest cell (header included).
            col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]

            for row in data:
                logger.info(
                    "| "
                    + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths))
                    + " |"
                )
        logger.info("\n")

    def register(self, register_table_key: str, key: str = None) -> callable:
        """Return a class decorator that stores the decorated class in a registry table.

        Args:
            register_table_key: Name of the table attribute on this object
                (normally ``"tts_classes"``). A new empty dict attribute is
                created if no attribute with that name exists yet.
            key: Registration key; defaults to the decorated class's ``__name__``.

        Returns:
            A decorator that registers the class and returns it unchanged.
            Registering an existing key replaces the previously stored class.

        Raises:
            TypeError: When the decorator is applied, if ``register_table_key``
                names an existing attribute that is not a dict (e.g. a method).
        """

        def decorator(target_class):
            if not hasattr(self, register_table_key):
                setattr(self, register_table_key, {})
                logger.debug(f"New TTS registry table added: {register_table_key}")

            registry = getattr(self, register_table_key)
            if not isinstance(registry, dict):
                # Without this check a name such as "print" resolves to a bound
                # method and fails later with an obscure error.
                raise TypeError(
                    f"{register_table_key!r} is not a registry table "
                    f"(found {type(registry).__name__})"
                )

            registry_key = key if key is not None else target_class.__name__

            if registry_key in registry:
                logger.debug(
                    f"Key {registry_key} already exists in {register_table_key}, re-register"
                )

            registry[registry_key] = target_class
            logger.info(f"Registered TTS class: {registry_key} -> {target_class.__name__}")
            return target_class

        return decorator


# Global TTS registry instance shared by the whole application.
tts_tables = TTSRegistryTables()


class TTSManager:
    """Build TTS instances from the classes registered in ``tts_tables``.

    ``create_tts`` always constructs a new instance. ``get_or_create_tts``
    caches instances per (TTS type value, character name) pair.
    """

    def __init__(self):
        # "<tts_type value>:<character_name>" -> cached TTS instance.
        self._tts_instances: Dict[str, TTSInterface] = {}

    def create_tts(self, config: BaseTTSConfig) -> TTSInterface:
        """Construct a new, uncached TTS instance for *config*.

        Args:
            config: TTS configuration. ``config.tts_type.value`` selects the
                registered class, which is then called with ``config``.

        Returns:
            TTSInterface: The freshly constructed TTS instance.

        Raises:
            ValueError: If no class is registered for the configured TTS type.
        """
        requested = config.tts_type.value
        available = tts_tables.tts_classes

        if requested in available:
            return available[requested](config)

        raise ValueError(f"未注册的TTS类型: {requested}. 可用类型: {list(available.keys())}")

    def get_or_create_tts(self, config: BaseTTSConfig) -> TTSInterface:
        """Return the cached TTS instance for *config*, creating it on first use.

        The cache key combines only ``config.tts_type.value`` and
        ``config.character_name``. Two configs that agree on both share one
        instance even if their other fields differ.

        Args:
            config: TTS configuration.

        Returns:
            TTSInterface: The cached (or newly created) TTS instance.
        """
        cache = self._tts_instances
        cache_key = f"{config.tts_type.value}:{config.character_name}"

        if cache_key in cache:
            return cache[cache_key]

        fresh = self.create_tts(config)
        cache[cache_key] = fresh
        return fresh

    def list_registered_tts(self) -> Dict[str, Type[TTSInterface]]:
        """Return a shallow copy of the registration-key -> TTS class mapping."""
        return tts_tables.tts_classes.copy()

    def is_tts_registered(self, tts_type: str) -> bool:
        """Tell whether *tts_type* is a registration key in ``tts_tables``."""
        registered = tts_tables.tts_classes
        return tts_type in registered

    def print_registry(self):
        """Log the registry table via ``tts_tables.print()``."""
        tts_tables.print()


# Global TTS manager instance.
tts_manager = TTSManager()


def _import_runtime_module(py_file: Path):
    """Import the runtime module stored at *py_file* and return it.

    When this file belongs to a package (``__package__`` is non-empty), the
    module is imported under its fully qualified name with
    ``importlib.import_module``, so it is cached in ``sys.modules``. This has
    three effects:
    - A later regular import reuses it instead of re-executing it, which would
      register a second copy of each class.
    - Relative imports inside it resolve.
    - ``inspect`` can locate its classes.

    Outside a package, the file is executed directly under its bare stem, as
    the original loader did.

    Raises:
        ImportError: If the module cannot be found or no loader is available.
    """
    if __package__:
        return importlib.import_module(f"{__package__}.runtime.{py_file.stem}")

    spec = importlib.util.spec_from_file_location(py_file.stem, py_file)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create a module spec for {py_file}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def register_all_tts():
    """Import every TTS implementation module in the sibling ``runtime`` directory.

    Every ``*.py`` file except ``__init__.py`` and ``interface.py`` is imported.
    Presumably each module registers its classes via ``tts_tables.register``
    at import time (the runtime modules are not visible here; confirm).

    Import failures are logged and skipped, so one broken implementation does
    not prevent the others from loading.
    """
    runtime_dir = Path(__file__).parent / "runtime"
    skipped_files = {"__init__.py", "interface.py"}

    # Sorted so registration order (and which class wins a duplicate key) does
    # not depend on filesystem listing order.
    for py_file in sorted(runtime_dir.glob("*.py")):
        if py_file.name in skipped_files:
            continue

        module_name = py_file.stem
        try:
            _import_runtime_module(py_file)
            logger.info(f"Successfully imported TTS module: {module_name}")
        except ImportError as e:
            logger.warning(f"Failed to import TTS module {module_name}: {e}")
        except Exception as e:
            logger.error(f"Unexpected error importing TTS module {module_name}: {e}")


# Discover and register all TTS implementations when this module is imported.
register_all_tts()