# liumaolin
# refactor(core): Architecturally decouple Audio, ASR, and TTS modules
# 60f8238
import importlib.util
import inspect
import re
from dataclasses import dataclass
from typing import Dict, Type, List, Literal, Optional
from voice_dialogue.utils.logger import logger
from .models import ASRInterface
@dataclass
class ASRRegistryTables:
    """Registry tables used to manage the available ASR implementations.

    Each attribute is a registry table (a plain dict). ``asr_classes`` maps a
    registration name to the ASR class registered under that name;
    ``register`` creates further tables on demand.
    """

    # None is the default so that every instance gets its own dict in
    # __post_init__ instead of sharing one mutable default.
    asr_classes: Dict[str, Type[ASRInterface]] = None

    def __post_init__(self):
        if self.asr_classes is None:
            self.asr_classes = {}

    @staticmethod
    def _describe_location(asr_class) -> str:
        """Return ``"<file>:<line>"`` for ``asr_class``, or ``"unknown"``.

        ``inspect`` raises ``TypeError`` for built-in classes and ``OSError``
        when the source cannot be retrieved; neither should abort printing
        the rest of the table.
        """
        try:
            class_file = inspect.getfile(asr_class)
            class_line = inspect.getsourcelines(asr_class)[1]
        except (OSError, TypeError):
            return "unknown"
        # Shorten the path so it starts at the project directory.
        class_file = re.sub(r"^.+/VoiceDialogue/", "VoiceDialogue/", class_file)
        return f"{class_file}:{class_line}"

    @staticmethod
    def _describe_languages(asr_class) -> str:
        """Return the class's ``supported_langs`` joined by ", ", or ``"unknown"``."""
        try:
            supported_langs = asr_class.supported_langs
            return ', '.join(supported_langs) if supported_langs else 'unknown'
        except Exception:
            # Missing attribute or a value that cannot be joined as strings.
            return 'unknown'

    def print(self, key: str = None) -> None:
        """Log a formatted table of the registered ASR classes.

        Args:
            key: Optional filter. The ``asr_classes`` table is printed when
                ``key`` is None or contains the substring ``"asr_classes"``.
        """
        logger.info("\nASR Registry Tables: \n")

        headers = ["register name", "class name", "class location", "supported languages"]

        if self.asr_classes and (key is None or "asr_classes" in key):
            logger.info("----------- ** asr_classes ** --------------")

            metas = [
                [
                    register_key,
                    asr_class.__name__,
                    self._describe_location(asr_class),
                    self._describe_languages(asr_class),
                ]
                for register_key, asr_class in self.asr_classes.items()
            ]
            metas.sort(key=lambda x: x[0])

            data = [headers] + metas
            # Width of each column is the longest cell in that column.
            col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]

            for row in data:
                logger.info(
                    "| "
                    + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths))
                    + " |"
                )

        logger.info("\n")

    def register(self, register_table_key: str, key: str = None) -> callable:
        """Return a class decorator that registers the decorated class.

        Args:
            register_table_key: Name of the registry table (an attribute of
                this object) to register into. A new empty table is created
                when the attribute does not exist yet.
            key: Registration name. Defaults to the decorated class's
                ``__name__``.

        Returns:
            A decorator that stores the class in the table and returns the
            class unchanged. An existing entry under the same name is
            overwritten.
        """

        def decorator(target_class):
            if not hasattr(self, register_table_key):
                setattr(self, register_table_key, {})
                logger.debug(f"New ASR registry table added: {register_table_key}")

            registry = getattr(self, register_table_key)
            registry_key = key if key is not None else target_class.__name__

            if registry_key in registry:
                logger.debug(
                    f"Key {registry_key} already exists in {register_table_key}, re-register"
                )

            registry[registry_key] = target_class
            logger.info(f"Registered ASR class: {registry_key} -> {target_class.__name__}")
            return target_class

        return decorator
# Global ASR registry instance
asr_tables = ASRRegistryTables()
class ASRManager:
    """Manager that selects, creates and caches ASR instances.

    Engines are looked up in the module-level ``asr_tables.asr_classes``
    registry; a per-manager mapping decides which engine serves which
    language.
    """

    def __init__(self):
        # Cache of created instances, keyed by "<asr_type>_<language>".
        self._asr_instances: Dict[str, ASRInterface] = {}
        self._language_to_asr_mapping = {
            'zh': 'funasr',  # Chinese: prefer FunASR
            'en': 'whisper',  # English: prefer Whisper
            # 'auto': 'whisper',  # auto-detect would default to Whisper
        }

    def create_asr(self, language: Literal['auto', 'zh', 'en']) -> ASRInterface:
        """Create a new ASR instance for the given language.

        Args:
            language: Language code. Note that ``'auto'`` has no entry in the
                default mapping, so it raises ``ValueError`` unless a mapping
                was added with ``set_language_mapping``.

        Returns:
            ASRInterface: A freshly constructed instance of the mapped engine.

        Raises:
            ValueError: If the language has no mapping or the mapped engine
                is not registered.
            Exception: Anything raised by the engine's constructor is logged
                and re-raised unchanged.
        """
        try:
            # Pick the engine configured for this language.
            asr_type = self._get_asr_type_for_language(language)

            if asr_type not in asr_tables.asr_classes:
                raise ValueError(f"ASR类型 '{asr_type}' 未注册")

            asr_class = asr_tables.asr_classes[asr_type]
            instance = asr_class()

            logger.info(f"成功创建ASR实例: {asr_type} for language: {language}")
            return instance
        except Exception as e:
            logger.error(f"创建ASR实例失败: {e}")
            raise

    def get_or_create_asr(self, language: Literal['auto', 'zh', 'en']) -> ASRInterface:
        """Return the cached ASR instance for a language, creating it if needed.

        Instances are cached per ``"<asr_type>_<language>"`` key, so changing
        the mapping for a language yields a new instance on the next call.

        Args:
            language: Language code.

        Returns:
            ASRInterface: The cached or newly created instance.

        Raises:
            ValueError: If the language has no mapping or the mapped engine
                is not registered.
        """
        asr_type = self._get_asr_type_for_language(language)
        instance_key = f"{asr_type}_{language}"

        if instance_key not in self._asr_instances:
            self._asr_instances[instance_key] = self.create_asr(language)

        return self._asr_instances[instance_key]

    def _get_asr_type_for_language(self, language: str) -> str:
        """Return the engine name mapped to ``language``.

        Raises:
            ValueError: If no (non-empty) engine is mapped to the language.
        """
        asr_type = self._language_to_asr_mapping.get(language)
        if not asr_type:
            raise ValueError(f"不支持的语言类型: {language}")
        return asr_type

    def set_language_mapping(self, language: str, asr_type: str) -> None:
        """Map a language code to a registered ASR engine.

        Args:
            language: Language code.
            asr_type: Registration name of the ASR engine.

        Raises:
            ValueError: If ``asr_type`` is not registered.
        """
        if asr_type not in asr_tables.asr_classes:
            raise ValueError(f"ASR类型 '{asr_type}' 未注册")

        self._language_to_asr_mapping[language] = asr_type
        logger.info(f"更新语言映射: {language} -> {asr_type}")

    def list_registered_asr(self) -> Dict[str, Type[ASRInterface]]:
        """Return a shallow copy of the registered ASR classes."""
        return asr_tables.asr_classes.copy()

    def is_asr_registered(self, asr_type: str) -> bool:
        """Return True if ``asr_type`` is a registered engine name."""
        return asr_type in asr_tables.asr_classes

    def get_supported_languages(self) -> Dict[str, List[str]]:
        """Return the languages declared by every registered ASR engine.

        An engine whose ``supported_langs`` is missing, unreadable, ``None``
        or empty is reported as ``['unknown']`` so that callers can always
        iterate the value.

        Returns:
            Dict[str, List[str]]: Engine name -> supported language codes.
        """
        supported_languages = {}

        for asr_key, asr_class in asr_tables.asr_classes.items():
            try:
                languages = asr_class.supported_langs
            except Exception as e:
                logger.warning(f"获取ASR引擎 '{asr_key}' 支持的语言失败: {e}")
                languages = None
            # A falsy declaration would break set.update() / "in" downstream.
            supported_languages[asr_key] = languages if languages else ['unknown']

        return supported_languages

    def get_available_languages(self) -> List[str]:
        """Return the sorted union of languages supported by any engine.

        Returns:
            List[str]: Language codes, without the ``'unknown'`` placeholder.
        """
        all_languages = set()
        for languages in self.get_supported_languages().values():
            all_languages.update(languages)

        # Drop the placeholder used for engines with no declared languages.
        all_languages.discard('unknown')

        return sorted(all_languages)

    def validate_language_support(self, language: str) -> bool:
        """Return True if any registered engine declares ``language``.

        Args:
            language: Language code.

        Returns:
            bool: Whether the language appears in ``get_available_languages``.
        """
        return language in self.get_available_languages()

    def get_optimal_asr_for_language(self, language: str) -> Optional[str]:
        """Return the preferred engine name for a language.

        The configured mapping wins when its engine is registered; otherwise
        the first registered engine that declares the language is used.

        Args:
            language: Language code.

        Returns:
            Optional[str]: Engine name, or None when no engine supports it.
        """
        # Prefer the explicitly configured mapping.
        if language in self._language_to_asr_mapping:
            asr_type = self._language_to_asr_mapping[language]
            if self.is_asr_registered(asr_type):
                return asr_type

        # Fall back to any engine that declares support for the language.
        for asr_key, languages in self.get_supported_languages().items():
            if language in languages:
                return asr_key

        return None

    def cleanup(self) -> None:
        """Drop all cached ASR instances.

        Only the cache is cleared; no teardown method is called on the
        instances themselves.
        """
        logger.info("清理ASR实例...")
        self._asr_instances.clear()
        logger.info("ASR实例清理完成")

    def print_registry(self) -> None:
        """Log the contents of the global ASR registry."""
        asr_tables.print()

    def get_asr_statistics(self) -> Dict:
        """Return summary statistics about the registry and this manager.

        Returns:
            Dict: Registered engine count and names, number of cached
            instances, available languages, and a copy of the language
            mapping.
        """
        return {
            'registered_asr_count': len(asr_tables.asr_classes),
            'active_instances_count': len(self._asr_instances),
            'supported_languages': self.get_available_languages(),
            'language_mappings': self._language_to_asr_mapping.copy(),
            'registered_asr_types': list(asr_tables.asr_classes.keys())
        }
# Global ASR manager instance
asr_manager = ASRManager()
def register_all_asr():
    """Discover and import every ASR implementation module under ``models/``.

    Every ``*.py`` file in the ``models`` directory next to this file, except
    ``__init__.py`` and ``base.py``, is executed as a module named after the
    file stem. Registration is presumably a side effect of executing those
    modules (e.g. through ``asr_tables.register`` decorators) -- confirm
    against the model modules.

    A module that fails to import is logged and skipped so the remaining
    modules are still loaded.
    """
    # Import the submodule explicitly rather than relying on it having been
    # loaded elsewhere.
    import importlib.util
    from pathlib import Path

    # Directory that holds the ASR implementation modules.
    models_dir = Path(__file__).parent / "models"

    # glob() order depends on the filesystem; sort so that import order (and
    # therefore which duplicate registration wins) is reproducible.
    for py_file in sorted(models_dir.glob("*.py")):
        if py_file.name in ("__init__.py", "base.py"):
            continue

        module_name = py_file.stem
        try:
            # Load the file directly by path. NOTE: the module is executed
            # but not inserted into sys.modules.
            spec = importlib.util.spec_from_file_location(module_name, py_file)
            if spec is None or spec.loader is None:
                raise ImportError(f"cannot create an import spec for {py_file}")
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            logger.info(f"Successfully imported ASR module: {module_name}")
        except ImportError as e:
            logger.warning(f"Failed to import ASR module {module_name}: {e}")
        except Exception as e:
            logger.error(f"Unexpected error importing ASR module {module_name}: {e}")
# Automatically register all ASR implementations when this module is imported
register_all_asr()