File size: 10,585 Bytes
8acaad0
59603db
 
 
 
 
851495c
59603db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
851495c
59603db
 
 
851495c
59603db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
851495c
59603db
 
 
 
851495c
59603db
 
 
 
 
 
 
851495c
59603db
 
 
 
 
851495c
59603db
 
 
 
851495c
59603db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8acaad0
59603db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
851495c
59603db
 
 
851495c
59603db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
851495c
59603db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8acaad0
59603db
8acaad0
 
 
59603db
851495c
59603db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
851495c
59603db
851495c
59603db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8acaad0
f08ef5f
8acaad0
 
 
 
851495c
59603db
851495c
59603db
851495c
59603db
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
import importlib.util
import inspect
import re
from dataclasses import dataclass
from typing import Dict, Type, List, Literal, Optional

from voice_dialogue.utils.logger import logger
from .models import ASRInterface


@dataclass
class ASRRegistryTables:
    """Registry of ASR implementation classes, keyed by registration name.

    Classes are added with the :meth:`register` decorator; ``ASRManager``
    looks them up through ``asr_classes``.
    """

    # Registration key -> ASR implementation class.
    asr_classes: Dict[str, Type[ASRInterface]] = None

    def __post_init__(self):
        # A mutable dict cannot be a plain dataclass default, so each instance
        # gets its own empty table here.
        if self.asr_classes is None:
            self.asr_classes = {}

    def print(self, key: str = None) -> None:
        """Log the registered ASR classes as an aligned text table.

        Args:
            key: Optional table filter. The ``asr_classes`` table is logged when
                ``key`` is None or contains the substring ``"asr_classes"``.
        """
        logger.info("\nASR Registry Tables: \n")
        headers = ["register name", "class name", "class location", "supported languages"]

        if self.asr_classes and (key is None or "asr_classes" in key):
            logger.info("-----------    ** asr_classes **    --------------")
            metas = [
                self._describe_class(register_key, asr_class)
                for register_key, asr_class in self.asr_classes.items()
            ]
            metas.sort(key=lambda row: row[0])

            data = [headers] + metas
            # Width of each column = longest cell (header included) in that column.
            col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]

            for row in data:
                logger.info(
                    "| "
                    + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths))
                    + " |"
                )
        logger.info("\n")

    @staticmethod
    def _describe_class(register_key: str, asr_class: type) -> List[str]:
        """Build one table row: register key, class name, source location, languages."""
        try:
            class_file = inspect.getfile(asr_class)
            class_line = inspect.getsourcelines(asr_class)[1]
            # Shorten the absolute path so it starts at the project root;
            # accept both '/' and '\\' separators so Windows paths are handled.
            class_file = re.sub(r"^.+[\\/]VoiceDialogue[\\/]", "VoiceDialogue/", class_file)
            location = f"{class_file}:{class_line}"
        except (OSError, TypeError):
            # inspect cannot locate source for built-in or dynamically created
            # classes; show a placeholder instead of aborting the whole table.
            location = "unknown"

        try:
            supported_langs = asr_class.supported_langs
            supported_langs_str = ', '.join(supported_langs) if supported_langs else 'unknown'
        except Exception:
            # Missing attribute or a non-joinable value: report as unknown.
            supported_langs_str = 'unknown'

        return [register_key, asr_class.__name__, location, supported_langs_str]

    def register(self, register_table_key: str, key: str = None) -> callable:
        """Return a class decorator that records the decorated class in a table.

        Args:
            register_table_key: Name of the table attribute on this registry
                (e.g. ``"asr_classes"``). A new dict attribute is created if it
                does not exist yet.
            key: Registration name; defaults to the decorated class's ``__name__``.

        Returns:
            A decorator that stores the class under the key (replacing any
            existing entry with the same key) and returns the class unchanged.
        """

        def decorator(target_class):
            if not hasattr(self, register_table_key):
                setattr(self, register_table_key, {})
                logger.debug(f"New ASR registry table added: {register_table_key}")

            registry = getattr(self, register_table_key)
            registry_key = key if key is not None else target_class.__name__

            if registry_key in registry:
                # Later registrations overwrite earlier ones.
                logger.debug(
                    f"Key {registry_key} already exists in {register_table_key}, re-register"
                )

            registry[registry_key] = target_class
            logger.info(f"Registered ASR class: {registry_key} -> {target_class.__name__}")
            return target_class

        return decorator


# Module-level registry instance; ASRManager reads registered classes from it.
asr_tables: ASRRegistryTables = ASRRegistryTables()


class ASRManager:
    """Maps languages to registered ASR engines and creates/caches their instances."""

    def __init__(self):
        # Cached instances, keyed by f"{asr_type}_{language}".
        self._asr_instances: Dict[str, ASRInterface] = {}
        # Preferred engine per language code; values are keys of asr_tables.asr_classes.
        # 'auto' is intentionally unmapped, so requesting it raises ValueError.
        self._language_to_asr_mapping = {
            'zh': 'funasr',  # Chinese: prefer FunASR
            'en': 'whisper',  # English: prefer Whisper
            # 'auto': 'whisper',  # auto-detect would default to Whisper
        }

    def create_asr(self, language: Literal['auto', 'zh', 'en']) -> ASRInterface:
        """
        Create a new ASR instance for the given language.

        Args:
            language: Language code; it must be present in the language mapping.

        Returns:
            ASRInterface: A freshly constructed instance of the mapped engine class.

        Raises:
            ValueError: If the language is unmapped or the mapped engine is not
                registered. Any error from the engine constructor is logged and
                re-raised unchanged.
        """
        try:
            asr_type = self._get_asr_type_for_language(language)

            if asr_type not in asr_tables.asr_classes:
                raise ValueError(f"ASR类型 '{asr_type}' 未注册")

            instance = asr_tables.asr_classes[asr_type]()

            logger.info(f"成功创建ASR实例: {asr_type} for language: {language}")
            return instance

        except Exception as e:
            logger.error(f"创建ASR实例失败: {e}")
            raise

    def get_or_create_asr(self, language: Literal['auto', 'zh', 'en']) -> ASRInterface:
        """
        Return the cached ASR instance for a language, creating it on first use.

        Args:
            language: Language code.

        Returns:
            ASRInterface: The instance cached under f"{asr_type}_{language}".
        """
        asr_type = self._get_asr_type_for_language(language)
        instance_key = f"{asr_type}_{language}"

        if instance_key not in self._asr_instances:
            self._asr_instances[instance_key] = self.create_asr(language)

        return self._asr_instances[instance_key]

    def _get_asr_type_for_language(self, language: str) -> str:
        """Return the engine key mapped to ``language``; raise ValueError if unmapped."""
        asr_type = self._language_to_asr_mapping.get(language)
        if not asr_type:
            raise ValueError(f"不支持的语言类型: {language}")
        return asr_type

    def set_language_mapping(self, language: str, asr_type: str) -> None:
        """
        Map a language code to a registered ASR engine.

        Instances already cached for the previous engine stay in the cache under
        their old key until :meth:`cleanup` is called.

        Args:
            language: Language code.
            asr_type: Registered engine key.

        Raises:
            ValueError: If ``asr_type`` is not registered.
        """
        if asr_type not in asr_tables.asr_classes:
            raise ValueError(f"ASR类型 '{asr_type}' 未注册")

        self._language_to_asr_mapping[language] = asr_type
        logger.info(f"更新语言映射: {language} -> {asr_type}")

    def list_registered_asr(self) -> Dict[str, Type[ASRInterface]]:
        """Return a shallow copy of the registered engine table."""
        return asr_tables.asr_classes.copy()

    def is_asr_registered(self, asr_type: str) -> bool:
        """Return True if ``asr_type`` is a registered engine key."""
        return asr_type in asr_tables.asr_classes

    def get_supported_languages(self) -> Dict[str, List[str]]:
        """
        Collect the supported languages declared by every registered engine.

        Returns:
            Dict[str, List[str]]: Engine key -> new list of language codes.
            Engines whose ``supported_langs`` is missing, empty, or unreadable
            are reported as ``['unknown']``.
        """
        supported_languages = {}

        for asr_key, asr_class in asr_tables.asr_classes.items():
            try:
                langs = asr_class.supported_langs
                if isinstance(langs, str):
                    # A single code given as a string is one language, not a
                    # sequence of characters.
                    langs = [langs]
                # Copy so callers cannot mutate the class attribute. A falsy
                # value (None/empty) would break set.update() and `in` checks
                # downstream, so it is reported like the error path below.
                supported_languages[asr_key] = list(langs) if langs else ['unknown']
            except Exception as e:
                logger.warning(f"获取ASR引擎 '{asr_key}' 支持的语言失败: {e}")
                supported_languages[asr_key] = ['unknown']

        return supported_languages

    def get_available_languages(self) -> List[str]:
        """
        Return the sorted union of languages supported by all registered engines.

        Returns:
            List[str]: Language codes, excluding the 'unknown' placeholder.
        """
        all_languages = set()
        for languages in self.get_supported_languages().values():
            all_languages.update(languages)

        # 'unknown' is a placeholder, not a real language code.
        all_languages.discard('unknown')
        return sorted(all_languages)

    def validate_language_support(self, language: str) -> bool:
        """
        Check whether any registered engine declares support for ``language``.

        Args:
            language: Language code.

        Returns:
            bool: True if the language appears in :meth:`get_available_languages`.
        """
        return language in self.get_available_languages()

    def get_optimal_asr_for_language(self, language: str) -> Optional[str]:
        """
        Pick the engine to use for a language.

        The explicit mapping wins when its engine is registered; otherwise the
        first registered engine that declares the language is returned.

        Args:
            language: Language code.

        Returns:
            Optional[str]: Engine key, or None if no engine supports the language.
        """
        asr_type = self._language_to_asr_mapping.get(language)
        if asr_type is not None and self.is_asr_registered(asr_type):
            return asr_type

        for asr_key, languages in self.get_supported_languages().items():
            if language in languages:
                return asr_key

        return None

    def cleanup(self) -> None:
        """Drop all cached ASR instances (no per-instance teardown is invoked)."""
        logger.info("清理ASR实例...")
        self._asr_instances.clear()
        logger.info("ASR实例清理完成")

    def print_registry(self) -> None:
        """Log the registry table via ``asr_tables.print()``."""
        asr_tables.print()

    def get_asr_statistics(self) -> Dict:
        """
        Summarize the manager's state.

        Returns:
            Dict: Registered engine count, cached instance count, available
            languages, a copy of the language mapping, and registered engine keys.
        """
        return {
            'registered_asr_count': len(asr_tables.asr_classes),
            'active_instances_count': len(self._asr_instances),
            'supported_languages': self.get_available_languages(),
            'language_mappings': self._language_to_asr_mapping.copy(),
            'registered_asr_types': list(asr_tables.asr_classes.keys())
        }


# Module-level ASR manager shared by every importer of this module.
asr_manager: ASRManager = ASRManager()


def register_all_asr():
    """Import every module in the sibling ``models`` directory so that the
    registration code inside those modules runs.

    ``__init__.py`` and ``base.py`` are skipped. Import failures are logged
    and do not stop the scan.
    """
    import importlib
    from pathlib import Path

    models_dir = Path(__file__).parent / "models"
    # Import through the package rather than loading files anonymously.
    # This registers each module in sys.modules under its real dotted name, so
    # relative imports inside model modules resolve. It also means a module that
    # was already imported via the package is reused instead of being executed
    # again as a duplicate copy with duplicate class objects.
    models_package = f"{__package__}.models" if __package__ else None

    # Sorted so the import (and therefore registration/override) order is
    # deterministic across filesystems.
    for py_file in sorted(models_dir.glob("*.py")):
        if py_file.name in ("__init__.py", "base.py"):
            continue

        module_name = py_file.stem
        try:
            if models_package:
                importlib.import_module(f"{models_package}.{module_name}")
            else:
                # Fallback when this file is not loaded as part of a package.
                spec = importlib.util.spec_from_file_location(module_name, py_file)
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
            logger.info(f"Successfully imported ASR module: {module_name}")
        except ImportError as e:
            logger.warning(f"Failed to import ASR module {module_name}: {e}")
        except Exception as e:
            logger.error(f"Unexpected error importing ASR module {module_name}: {e}")


# Side effect on import: discover and register all ASR implementations.
register_all_asr()