File size: 10,585 Bytes
8acaad0 59603db 851495c 59603db 851495c 59603db 851495c 59603db 851495c 59603db 851495c 59603db 851495c 59603db 851495c 59603db 851495c 59603db 8acaad0 59603db 851495c 59603db 851495c 59603db 851495c 59603db 8acaad0 59603db 8acaad0 59603db 851495c 59603db 851495c 59603db 851495c 59603db 8acaad0 f08ef5f 8acaad0 851495c 59603db 851495c 59603db 851495c 59603db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 |
import importlib.util
import inspect
import re
from dataclasses import dataclass
from typing import Dict, Type, List, Literal, Optional
from voice_dialogue.utils.logger import logger
from .models import ASRInterface
@dataclass
class ASRRegistryTables:
    """Registry tables mapping registration names to ASR implementation classes.

    ``asr_classes`` is the primary table. :meth:`register` can also create
    additional tables on demand as new attributes of this instance.
    """

    # Registered name -> ASR implementation class. ``None`` is replaced by a
    # fresh empty dict in ``__post_init__`` so instances never share one dict.
    asr_classes: Optional[Dict[str, Type[ASRInterface]]] = None

    def __post_init__(self):
        if self.asr_classes is None:
            self.asr_classes = {}

    @staticmethod
    def _describe_asr_class(register_key: str, asr_class: type) -> List[str]:
        """Build one table row (name, class, location, languages) for ``print``."""
        try:
            class_file = inspect.getfile(asr_class)
        except TypeError:
            # Built-in / C-implemented classes have no source file.
            class_file = getattr(asr_class, "__module__", "?")
        try:
            class_line = inspect.getsourcelines(asr_class)[1]
        except (TypeError, OSError):
            # Source can be missing (frozen or bytecode-only installs); show a
            # placeholder instead of aborting the whole table.
            class_line = "?"

        # Shorten the path so it starts at the project root; accept both "/"
        # and "\" separators so this also works on Windows.
        class_file = re.sub(r"^.+[\\/]VoiceDialogue[\\/]", "VoiceDialogue/", class_file)

        try:
            langs = asr_class.supported_langs
            if isinstance(langs, str):
                # A bare string would otherwise be joined character by character.
                langs = [langs]
            langs_str = ', '.join(langs) if langs else 'unknown'
        except Exception:
            # Missing attribute or non-string entries: report as unknown.
            langs_str = 'unknown'

        return [
            register_key,
            asr_class.__name__,
            f"{class_file}:{class_line}",
            langs_str,
        ]

    def print(self, key: str = None) -> None:
        """Log the registered ASR classes as an aligned text table.

        Args:
            key: Optional table filter. The ``asr_classes`` table is logged only
                when ``key`` is ``None`` or contains the substring ``"asr_classes"``.
        """
        logger.info("\nASR Registry Tables: \n")
        headers = ["register name", "class name", "class location", "supported languages"]
        if self.asr_classes and (key is None or "asr_classes" in key):
            logger.info("----------- ** asr_classes ** --------------")
            rows = sorted(
                (
                    self._describe_asr_class(register_key, asr_class)
                    for register_key, asr_class in self.asr_classes.items()
                ),
                key=lambda row: row[0],
            )
            data = [headers] + rows
            # Each column is as wide as its longest cell (header included).
            col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
            for row in data:
                logger.info(
                    "| "
                    + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths))
                    + " |"
                )
        logger.info("\n")

    def register(self, register_table_key: str, key: str = None) -> callable:
        """Return a class decorator that records the class in a registry table.

        Args:
            register_table_key: Attribute name of the table to register into
                (e.g. ``"asr_classes"``). If this instance has no such
                attribute yet, an empty dict table is created for it.
            key: Name to register under; defaults to the class's ``__name__``.

        Returns:
            A decorator that stores its target class in the table and returns
            the class unchanged. Registering an existing name replaces the
            previous entry (logged at debug level).

        Raises:
            TypeError: raised by the decorator if ``register_table_key`` names
                an existing attribute that is not a dict.
        """

        def decorator(target_class):
            if not hasattr(self, register_table_key):
                setattr(self, register_table_key, {})
                logger.debug(f"New ASR registry table added: {register_table_key}")
            registry = getattr(self, register_table_key)
            if not isinstance(registry, dict):
                # Guard against names like "print" that resolve to a
                # non-table attribute of this instance.
                raise TypeError(
                    f"Registry table '{register_table_key}' is not a dict "
                    f"(got {type(registry).__name__})"
                )
            registry_key = key if key is not None else target_class.__name__
            if registry_key in registry:
                logger.debug(
                    f"Key {registry_key} already exists in {register_table_key}, re-register"
                )
            registry[registry_key] = target_class
            logger.info(f"Registered ASR class: {registry_key} -> {target_class.__name__}")
            return target_class

        return decorator
# Module-wide ASR registry instance. ASRManager resolves engine names through
# ``asr_tables.asr_classes``. Implementations are presumably added via the
# ``asr_tables.register(...)`` decorator in the model modules (NOTE(review):
# confirm against the files under models/).
asr_tables = ASRRegistryTables()
class ASRManager:
    """Create and cache ASR engine instances selected by language.

    Engine classes are looked up by name in the module-level ``asr_tables``
    registry. ``_language_to_asr_mapping`` decides which engine name serves
    each language code.
    """

    def __init__(self):
        # Instance cache used by get_or_create_asr, keyed by "<asr_type>_<language>".
        self._asr_instances: Dict[str, ASRInterface] = {}
        # Language code -> registered ASR engine name.
        # NOTE: 'auto' appears in the type hints of create_asr/get_or_create_asr
        # but has no default mapping, so it raises ValueError unless it is added
        # via set_language_mapping.
        self._language_to_asr_mapping = {
            'zh': 'funasr',  # Chinese: prefer FunASR
            'en': 'whisper',  # English: prefer Whisper
            # 'auto': 'whisper',  # auto-detect would default to Whisper (disabled)
        }

    def create_asr(self, language: Literal['auto', 'zh', 'en']) -> ASRInterface:
        """Construct a new (uncached) ASR instance for ``language``.

        Args:
            language: Language code. It must have an entry in the language
                mapping (by default only ``'zh'`` and ``'en'``).

        Returns:
            ASRInterface: a freshly constructed engine instance.

        Raises:
            ValueError: if the language has no mapping or the mapped engine is
                not registered. Exceptions from the engine constructor
                propagate unchanged. Every failure is logged before re-raising.
        """
        try:
            asr_type = self._get_asr_type_for_language(language)
            if asr_type not in asr_tables.asr_classes:
                raise ValueError(f"ASR类型 '{asr_type}' 未注册")
            instance = asr_tables.asr_classes[asr_type]()
            logger.info(f"成功创建ASR实例: {asr_type} for language: {language}")
            return instance
        except Exception as e:
            logger.error(f"创建ASR实例失败: {e}")
            raise

    def get_or_create_asr(self, language: Literal['auto', 'zh', 'en']) -> ASRInterface:
        """Return the cached ASR instance for ``language``, creating it on first use.

        The cache key combines engine name and language, so two languages that
        map to the same engine still get separate instances.

        Args:
            language: Language code (see :meth:`create_asr`).

        Returns:
            ASRInterface: the cached engine instance.

        Raises:
            ValueError: as for :meth:`create_asr`.
        """
        asr_type = self._get_asr_type_for_language(language)
        instance_key = f"{asr_type}_{language}"
        if instance_key not in self._asr_instances:
            self._asr_instances[instance_key] = self.create_asr(language)
        return self._asr_instances[instance_key]

    def _get_asr_type_for_language(self, language: str) -> str:
        """Return the engine name mapped to ``language``; raise ValueError if none."""
        asr_type = self._language_to_asr_mapping.get(language)
        if not asr_type:
            raise ValueError(f"不支持的语言类型: {language}")
        return asr_type

    def set_language_mapping(self, language: str, asr_type: str) -> None:
        """Map ``language`` to the registered engine ``asr_type``.

        Already-cached instances are not affected. Their cache keys include the
        old engine name, so the next get_or_create_asr call creates a new
        instance for the new engine.

        Args:
            language: Language code.
            asr_type: Registered ASR engine name.

        Raises:
            ValueError: if ``asr_type`` is not registered.
        """
        if asr_type not in asr_tables.asr_classes:
            raise ValueError(f"ASR类型 '{asr_type}' 未注册")
        self._language_to_asr_mapping[language] = asr_type
        logger.info(f"更新语言映射: {language} -> {asr_type}")

    def list_registered_asr(self) -> Dict[str, Type[ASRInterface]]:
        """Return a shallow copy of the registry (engine name -> class)."""
        return asr_tables.asr_classes.copy()

    def is_asr_registered(self, asr_type: str) -> bool:
        """Return whether ``asr_type`` is a registered engine name."""
        return asr_type in asr_tables.asr_classes

    def get_supported_languages(self) -> Dict[str, List[str]]:
        """Map every registered engine name to the languages it declares.

        Languages come from each class's ``supported_langs`` attribute. If the
        attribute is missing, ``None``, or not iterable, a warning is logged
        and ``['unknown']`` is reported for that engine.

        Returns:
            Dict[str, List[str]]: engine name -> new list of language codes
            (a copy, so callers cannot mutate the class attribute).
        """
        supported_languages = {}
        for asr_key, asr_class in asr_tables.asr_classes.items():
            try:
                langs = asr_class.supported_langs
                if langs is None:
                    # None would later crash set.update()/`in` checks in callers.
                    raise ValueError("supported_langs is None")
                # Wrap a bare string so it isn't treated as a sequence of chars.
                supported_languages[asr_key] = [langs] if isinstance(langs, str) else list(langs)
            except Exception as e:
                logger.warning(f"获取ASR引擎 '{asr_key}' 支持的语言失败: {e}")
                supported_languages[asr_key] = ['unknown']
        return supported_languages

    def get_available_languages(self) -> List[str]:
        """Return the sorted union of languages supported by all registered engines.

        The ``'unknown'`` placeholder is excluded.

        Returns:
            List[str]: sorted language codes.
        """
        all_languages = set()
        for languages in self.get_supported_languages().values():
            all_languages.update(languages)
        all_languages.discard('unknown')
        return sorted(all_languages)

    def validate_language_support(self, language: str) -> bool:
        """Return whether any registered engine declares support for ``language``.

        Args:
            language: Language code.

        Returns:
            bool: True if ``language`` is in :meth:`get_available_languages`.
        """
        return language in self.get_available_languages()

    def get_optimal_asr_for_language(self, language: str) -> Optional[str]:
        """Pick an engine name for ``language``.

        An explicit mapping wins if its engine is registered. Otherwise the
        first registered engine (in registry order) that declares the language
        is returned.

        Args:
            language: Language code.

        Returns:
            Optional[str]: engine name, or ``None`` if no engine supports it.
        """
        mapped = self._language_to_asr_mapping.get(language)
        if mapped is not None and self.is_asr_registered(mapped):
            return mapped
        for asr_key, languages in self.get_supported_languages().items():
            if language in languages:
                return asr_key
        return None

    def cleanup(self) -> None:
        """Drop all cached ASR instances.

        This only clears the cache; it does not call any shutdown method on
        the instances.
        """
        logger.info("清理ASR实例...")
        self._asr_instances.clear()
        logger.info("ASR实例清理完成")

    def print_registry(self) -> None:
        """Log the registry table via ``asr_tables.print()``."""
        asr_tables.print()

    def get_asr_statistics(self) -> Dict:
        """Return a summary of the registry, instance cache and language mappings.

        Returns:
            Dict: counts of registered engines and cached instances, available
            languages, a copy of the language mapping, and registered names.
        """
        return {
            'registered_asr_count': len(asr_tables.asr_classes),
            'active_instances_count': len(self._asr_instances),
            'supported_languages': self.get_available_languages(),
            'language_mappings': self._language_to_asr_mapping.copy(),
            'registered_asr_types': list(asr_tables.asr_classes.keys()),
        }
# Module-level ASRManager instance; every importer of ``asr_manager`` from this
# module shares this one object (and therefore its instance cache).
asr_manager = ASRManager()
def register_all_asr():
    """Import every ASR implementation module found in the ``models/`` directory.

    Importing a model module is what runs its registration code. The modules
    are presumably decorated with ``asr_tables.register(...)``; that code lives
    in those files, not here. ``__init__.py`` and ``base.py`` are skipped. A
    module that fails to import is logged and skipped, and discovery continues
    with the remaining modules.
    """
    import importlib
    import importlib.util
    from pathlib import Path

    models_dir = Path(__file__).parent / "models"
    # Import through the real package path when it is known. That way model
    # files get a parent package (so relative imports work) and are cached in
    # sys.modules (so dataclasses/pickling work). It also means a module that
    # was already imported elsewhere is not executed a second time.
    package = __package__

    # Sort so registration order (and therefore registry dict order) is
    # deterministic across filesystems.
    for py_file in sorted(models_dir.glob("*.py")):
        if py_file.name in ("__init__.py", "base.py"):
            continue
        module_name = py_file.stem
        try:
            if package:
                importlib.import_module(f"{package}.models.{module_name}")
            else:
                # Fallback when this file is not loaded as part of a package.
                spec = importlib.util.spec_from_file_location(module_name, py_file)
                if spec is None or spec.loader is None:
                    raise ImportError(f"cannot create import spec for {py_file}")
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
            logger.info(f"Successfully imported ASR module: {module_name}")
        except ImportError as e:
            logger.warning(f"Failed to import ASR module {module_name}: {e}")
        except Exception as e:
            logger.error(f"Unexpected error importing ASR module {module_name}: {e}")
# Import-time side effect: scan models/ and import each ASR module, so that
# any registrations those modules perform happen as soon as this module loads.
register_all_asr()
|