liumaolin
commited on
Commit
·
f08ef5f
1
Parent(s):
8acaad0
Remove trailing whitespace in `audio_generator/manager.py` and `asr/manager.py` for improved code cleanliness and consistency.
Browse files
src/VoiceDialogue/services/audio/audio_generator/manager.py
CHANGED
|
@@ -15,7 +15,7 @@ class TTSRegistryTables:
|
|
| 15 |
"""TTS注册表系统,用于管理不同的TTS实现"""
|
| 16 |
|
| 17 |
tts_classes: Dict[str, Type[TTSInterface]] = None
|
| 18 |
-
|
| 19 |
def __post_init__(self):
|
| 20 |
if self.tts_classes is None:
|
| 21 |
self.tts_classes = {}
|
|
@@ -24,7 +24,7 @@ class TTSRegistryTables:
|
|
| 24 |
"""打印已注册的TTS类"""
|
| 25 |
print("\nTTS Registry Tables: \n")
|
| 26 |
headers = ["register name", "class name", "class location"]
|
| 27 |
-
|
| 28 |
if self.tts_classes and (key is None or "tts_classes" in key):
|
| 29 |
print(f"----------- ** tts_classes ** --------------")
|
| 30 |
metas = []
|
|
@@ -40,7 +40,7 @@ class TTSRegistryTables:
|
|
| 40 |
f"{class_file}:{class_line}",
|
| 41 |
]
|
| 42 |
metas.append(meta_data)
|
| 43 |
-
|
| 44 |
metas.sort(key=lambda x: x[0])
|
| 45 |
data = [headers] + metas
|
| 46 |
col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
|
|
@@ -82,10 +82,10 @@ tts_tables = TTSRegistryTables()
|
|
| 82 |
|
| 83 |
class TTSManager:
|
| 84 |
"""TTS管理器,负责管理和创建TTS实例"""
|
| 85 |
-
|
| 86 |
def __init__(self):
|
| 87 |
self._tts_instances: Dict[str, TTSInterface] = {}
|
| 88 |
-
|
| 89 |
def create_tts(self, config: BaseTTSConfig) -> TTSInterface:
|
| 90 |
"""
|
| 91 |
根据配置创建TTS实例
|
|
@@ -100,13 +100,13 @@ class TTSManager:
|
|
| 100 |
ValueError: 如果TTS类型未注册
|
| 101 |
"""
|
| 102 |
tts_type = config.tts_type.value
|
| 103 |
-
|
| 104 |
if tts_type not in tts_tables.tts_classes:
|
| 105 |
raise ValueError(f"未注册的TTS类型: {tts_type}. 可用类型: {list(tts_tables.tts_classes.keys())}")
|
| 106 |
-
|
| 107 |
tts_class = tts_tables.tts_classes[tts_type]
|
| 108 |
return tts_class(config)
|
| 109 |
-
|
| 110 |
def get_or_create_tts(self, config: BaseTTSConfig) -> TTSInterface:
|
| 111 |
"""
|
| 112 |
获取或创建TTS实例(单例模式)
|
|
@@ -118,20 +118,20 @@ class TTSManager:
|
|
| 118 |
TTSInterface: TTS实例
|
| 119 |
"""
|
| 120 |
instance_key = f"{config.tts_type.value}:{config.character_name}"
|
| 121 |
-
|
| 122 |
if instance_key not in self._tts_instances:
|
| 123 |
self._tts_instances[instance_key] = self.create_tts(config)
|
| 124 |
-
|
| 125 |
return self._tts_instances[instance_key]
|
| 126 |
-
|
| 127 |
def list_registered_tts(self) -> Dict[str, Type[TTSInterface]]:
|
| 128 |
"""列出所有已注册的TTS类"""
|
| 129 |
return tts_tables.tts_classes.copy()
|
| 130 |
-
|
| 131 |
def is_tts_registered(self, tts_type: str) -> bool:
|
| 132 |
"""检查指定TTS类型是否已注册"""
|
| 133 |
return tts_type in tts_tables.tts_classes
|
| 134 |
-
|
| 135 |
def print_registry(self):
|
| 136 |
"""打印注册表信息"""
|
| 137 |
tts_tables.print()
|
|
@@ -146,16 +146,16 @@ def register_all_tts():
|
|
| 146 |
|
| 147 |
# 获取runtime目录路径
|
| 148 |
runtime_dir = Path(__file__).parent / "runtime"
|
| 149 |
-
|
| 150 |
# 扫描runtime目录中的Python文件
|
| 151 |
for py_file in runtime_dir.glob("*.py"):
|
| 152 |
if py_file.name in ["__init__.py", "interface.py"]:
|
| 153 |
continue
|
| 154 |
-
|
| 155 |
module_name = py_file.stem
|
| 156 |
try:
|
| 157 |
spec = importlib.util.spec_from_file_location(
|
| 158 |
-
|
| 159 |
py_file
|
| 160 |
)
|
| 161 |
module = importlib.util.module_from_spec(spec)
|
|
|
|
| 15 |
"""TTS注册表系统,用于管理不同的TTS实现"""
|
| 16 |
|
| 17 |
tts_classes: Dict[str, Type[TTSInterface]] = None
|
| 18 |
+
|
| 19 |
def __post_init__(self):
|
| 20 |
if self.tts_classes is None:
|
| 21 |
self.tts_classes = {}
|
|
|
|
| 24 |
"""打印已注册的TTS类"""
|
| 25 |
print("\nTTS Registry Tables: \n")
|
| 26 |
headers = ["register name", "class name", "class location"]
|
| 27 |
+
|
| 28 |
if self.tts_classes and (key is None or "tts_classes" in key):
|
| 29 |
print(f"----------- ** tts_classes ** --------------")
|
| 30 |
metas = []
|
|
|
|
| 40 |
f"{class_file}:{class_line}",
|
| 41 |
]
|
| 42 |
metas.append(meta_data)
|
| 43 |
+
|
| 44 |
metas.sort(key=lambda x: x[0])
|
| 45 |
data = [headers] + metas
|
| 46 |
col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
|
|
|
|
| 82 |
|
| 83 |
class TTSManager:
|
| 84 |
"""TTS管理器,负责管理和创建TTS实例"""
|
| 85 |
+
|
| 86 |
def __init__(self):
|
| 87 |
self._tts_instances: Dict[str, TTSInterface] = {}
|
| 88 |
+
|
| 89 |
def create_tts(self, config: BaseTTSConfig) -> TTSInterface:
|
| 90 |
"""
|
| 91 |
根据配置创建TTS实例
|
|
|
|
| 100 |
ValueError: 如果TTS类型未注册
|
| 101 |
"""
|
| 102 |
tts_type = config.tts_type.value
|
| 103 |
+
|
| 104 |
if tts_type not in tts_tables.tts_classes:
|
| 105 |
raise ValueError(f"未注册的TTS类型: {tts_type}. 可用类型: {list(tts_tables.tts_classes.keys())}")
|
| 106 |
+
|
| 107 |
tts_class = tts_tables.tts_classes[tts_type]
|
| 108 |
return tts_class(config)
|
| 109 |
+
|
| 110 |
def get_or_create_tts(self, config: BaseTTSConfig) -> TTSInterface:
|
| 111 |
"""
|
| 112 |
获取或创建TTS实例(单例模式)
|
|
|
|
| 118 |
TTSInterface: TTS实例
|
| 119 |
"""
|
| 120 |
instance_key = f"{config.tts_type.value}:{config.character_name}"
|
| 121 |
+
|
| 122 |
if instance_key not in self._tts_instances:
|
| 123 |
self._tts_instances[instance_key] = self.create_tts(config)
|
| 124 |
+
|
| 125 |
return self._tts_instances[instance_key]
|
| 126 |
+
|
| 127 |
def list_registered_tts(self) -> Dict[str, Type[TTSInterface]]:
|
| 128 |
"""列出所有已注册的TTS类"""
|
| 129 |
return tts_tables.tts_classes.copy()
|
| 130 |
+
|
| 131 |
def is_tts_registered(self, tts_type: str) -> bool:
|
| 132 |
"""检查指定TTS类型是否已注册"""
|
| 133 |
return tts_type in tts_tables.tts_classes
|
| 134 |
+
|
| 135 |
def print_registry(self):
|
| 136 |
"""打印注册表信息"""
|
| 137 |
tts_tables.print()
|
|
|
|
| 146 |
|
| 147 |
# 获取runtime目录路径
|
| 148 |
runtime_dir = Path(__file__).parent / "runtime"
|
| 149 |
+
|
| 150 |
# 扫描runtime目录中的Python文件
|
| 151 |
for py_file in runtime_dir.glob("*.py"):
|
| 152 |
if py_file.name in ["__init__.py", "interface.py"]:
|
| 153 |
continue
|
| 154 |
+
|
| 155 |
module_name = py_file.stem
|
| 156 |
try:
|
| 157 |
spec = importlib.util.spec_from_file_location(
|
| 158 |
+
module_name,
|
| 159 |
py_file
|
| 160 |
)
|
| 161 |
module = importlib.util.module_from_spec(spec)
|
src/VoiceDialogue/services/speech/asr/manager.py
CHANGED
|
@@ -296,7 +296,7 @@ def register_all_asr():
|
|
| 296 |
try:
|
| 297 |
# 动态导入模块
|
| 298 |
spec = importlib.util.spec_from_file_location(
|
| 299 |
-
|
| 300 |
py_file
|
| 301 |
)
|
| 302 |
module = importlib.util.module_from_spec(spec)
|
|
|
|
| 296 |
try:
|
| 297 |
# 动态导入模块
|
| 298 |
spec = importlib.util.spec_from_file_location(
|
| 299 |
+
module_name,
|
| 300 |
py_file
|
| 301 |
)
|
| 302 |
module = importlib.util.module_from_spec(spec)
|