liumaolin commited on
Commit
59603db
·
1 Parent(s): 89f7f05

Refactor ASR module: introduce modular structure with ASR interface, implement FunASR and Whisper clients, add registry, and consolidate utility functions for enhanced maintainability and extensibility.

Browse files
src/VoiceDialogue/services/speech/asr/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ASR Module
3
+
4
+ 提供自动语音识别(ASR)功能的完整解决方案,包括:
5
+ - ASR管理器和注册系统
6
+ - 多种ASR引擎支持
7
+ - 配置管理
8
+ - 运行时接口
9
+ """
10
+
11
+ from .models import (
12
+ ASRInterface,
13
+ )
14
+ from .manager import (
15
+ ASRManager,
16
+ ASRRegistryTables,
17
+ asr_manager,
18
+ asr_tables,
19
+ register_all_asr
20
+ )
21
+
22
+ __version__ = "1.0.0"
23
+
24
+ __all__ = [
25
+ # 管理器和注册表
26
+ 'ASRManager',
27
+ 'ASRRegistryTables',
28
+ 'asr_manager',
29
+ 'asr_tables',
30
+ 'register_all_asr',
31
+
32
+ # 配置模型
33
+
34
+ # 运行时接口
35
+ 'ASRInterface',
36
+ ]
37
+
38
+ # 模块初始化时自动注册所有ASR实现
39
+ # register_all_asr() 已在 asr_manager 模块中自动调用
src/VoiceDialogue/services/speech/asr/manager.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import inspect
import logging
import re
from dataclasses import dataclass, field
from typing import Dict, Type, List, Literal, Optional

from .models import ASRInterface
8
+
9
+
10
@dataclass
class ASRRegistryTables:
    """Registry of ASR implementation classes, keyed by register name."""

    # register name -> implementation class; each registry instance gets
    # its own dict (default_factory avoids the shared-mutable pitfall the
    # previous `= None` + __post_init__ workaround papered over).
    asr_classes: Dict[str, Type[ASRInterface]] = field(default_factory=dict)

    def print(self, key: str = None) -> None:
        """Pretty-print the registered ASR classes as an aligned table.

        Args:
            key: Optional filter; the asr_classes table is shown only when
                `key` is None or contains the substring "asr_classes".
        """
        print("\nASR Registry Tables: \n")
        headers = ["register name", "class name", "class location", "supported languages"]

        if self.asr_classes and (key is None or "asr_classes" in key):
            print("----------- ** asr_classes ** --------------")
            metas = []
            for register_key, asr_class in self.asr_classes.items():
                class_file = inspect.getfile(asr_class)
                class_line = inspect.getsourcelines(asr_class)[1]
                # Show a project-relative path instead of the absolute one.
                pattern = r"^.+/VoiceDialogue/"
                class_file = re.sub(pattern, "VoiceDialogue/", class_file)

                # Read the languages the implementation declares; classes
                # without the attribute show up as 'unknown'.
                try:
                    supported_langs = asr_class.supported_langs
                    supported_langs_str = ', '.join(supported_langs) if supported_langs else 'unknown'
                except AttributeError:
                    supported_langs_str = 'unknown'

                meta_data = [
                    register_key,
                    asr_class.__name__,
                    f"{class_file}:{class_line}",
                    supported_langs_str,
                ]
                metas.append(meta_data)

            metas.sort(key=lambda row: row[0])
            data = [headers] + metas
            col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]

            for row in data:
                print(
                    "| "
                    + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths))
                    + " |"
                )
            print("\n")

    def _get_asr_supported_languages(self, asr_key: str) -> List[str]:
        """Return the language codes supported by the named ASR engine.

        NOTE(review): this table is hardcoded and may drift from each
        class's own `supported_langs` attribute — consider unifying.
        """
        language_mapping = {
            'funasr': ['zh', 'auto'],
            'whisper': ['en', 'zh', 'auto'],
        }
        return language_mapping.get(asr_key, ['auto'])

    def register(self, register_table_key: str, key: str = None) -> callable:
        """Class decorator that records the target class in a registry table.

        Args:
            register_table_key: Attribute name of the registry dict
                (created on the fly if missing).
            key: Registration name; defaults to the class's __name__.

        Returns:
            The decorator, which returns the class unchanged.
        """

        def decorator(target_class):
            if not hasattr(self, register_table_key):
                setattr(self, register_table_key, {})
                logging.debug(f"New ASR registry table added: {register_table_key}")

            registry = getattr(self, register_table_key)
            registry_key = key if key is not None else target_class.__name__

            if registry_key in registry:
                logging.debug(
                    f"Key {registry_key} already exists in {register_table_key}, re-register"
                )

            registry[registry_key] = target_class
            logging.info(f"Registered ASR class: {registry_key} -> {target_class.__name__}")
            return target_class

        return decorator
92
+
93
+
94
# Global ASR registry instance shared across the package.
asr_tables = ASRRegistryTables()
96
+
97
+
98
class ASRManager:
    """Manages ASR engine instances and language-to-engine routing."""

    def __init__(self):
        # Cache of created ASR instances, keyed by "<asr_type>_<language>".
        self._asr_instances: Dict[str, ASRInterface] = {}
        # Default routing from language code to registered engine name.
        self._language_to_asr_mapping = {
            'zh': 'funasr',     # Chinese: prefer FunASR
            'en': 'whisper',    # English: prefer Whisper
            'auto': 'whisper',  # Auto-detect: default to Whisper
        }

    def create_asr(self, language: Literal['auto', 'zh', 'en']) -> ASRInterface:
        """Create a new ASR instance for the given language.

        Args:
            language: Language code to transcribe.

        Returns:
            ASRInterface: A freshly constructed ASR client.

        Raises:
            ValueError: If the language is unsupported or the mapped ASR
                type is not registered.
        """
        try:
            # Pick the engine configured for this language.
            asr_type = self._get_asr_type_for_language(language)

            if asr_type not in asr_tables.asr_classes:
                raise ValueError(f"ASR类型 '{asr_type}' 未注册")

            asr_class = asr_tables.asr_classes[asr_type]
            instance = asr_class()

            logging.info(f"成功创建ASR实例: {asr_type} for language: {language}")
            return instance

        except Exception as e:
            logging.error(f"创建ASR实例失败: {e}")
            raise

    def get_or_create_asr(self, language: Literal['auto', 'zh', 'en']) -> ASRInterface:
        """Get a cached ASR instance for the language, creating it on first use.

        Args:
            language: Language code to transcribe.

        Returns:
            ASRInterface: The shared instance for this engine/language pair.
        """
        asr_type = self._get_asr_type_for_language(language)
        instance_key = f"{asr_type}_{language}"

        if instance_key not in self._asr_instances:
            self._asr_instances[instance_key] = self.create_asr(language)

        return self._asr_instances[instance_key]

    def _get_asr_type_for_language(self, language: str) -> str:
        """Map a language code to its configured ASR engine name."""
        asr_type = self._language_to_asr_mapping.get(language)
        if not asr_type:
            raise ValueError(f"不支持的语言类型: {language}")
        return asr_type

    def set_language_mapping(self, language: str, asr_type: str) -> None:
        """Route a language code to a specific registered ASR engine.

        Args:
            language: Language code.
            asr_type: Registered ASR engine name.

        Raises:
            ValueError: If `asr_type` is not registered.
        """
        if asr_type not in asr_tables.asr_classes:
            raise ValueError(f"ASR类型 '{asr_type}' 未注册")

        self._language_to_asr_mapping[language] = asr_type
        logging.info(f"更新语言映射: {language} -> {asr_type}")

    def list_registered_asr(self) -> Dict[str, Type[ASRInterface]]:
        """Return a copy of the registered ASR classes."""
        return asr_tables.asr_classes.copy()

    def is_asr_registered(self, asr_type: str) -> bool:
        """Return True if the named ASR engine is registered."""
        return asr_type in asr_tables.asr_classes

    def get_supported_languages(self) -> Dict[str, List[str]]:
        """Return, per registered engine, the languages it supports.

        Returns:
            Dict[str, List[str]]: Engine name -> supported language codes
            (['unknown'] when the lookup fails).
        """
        supported_languages = {}

        for asr_key in asr_tables.asr_classes.keys():
            try:
                # NOTE(review): reads the registry's private, hardcoded
                # table rather than each class's `supported_langs`.
                languages = asr_tables._get_asr_supported_languages(asr_key)
                supported_languages[asr_key] = languages
            except Exception as e:
                logging.warning(f"获取ASR引擎 '{asr_key}' 支持的语言失败: {e}")
                supported_languages[asr_key] = ['unknown']

        return supported_languages

    def get_available_languages(self) -> List[str]:
        """Return the sorted union of languages supported by all engines.

        Returns:
            List[str]: Available language codes.
        """
        all_languages = set()
        for languages in self.get_supported_languages().values():
            all_languages.update(languages)

        # Drop the placeholder used for engines whose languages are unknown.
        all_languages.discard('unknown')
        return sorted(all_languages)

    def validate_language_support(self, language: str) -> bool:
        """Return True if any registered engine supports the language.

        Args:
            language: Language code.
        """
        return language in self.get_available_languages()

    def get_optimal_asr_for_language(self, language: str) -> Optional[str]:
        """Pick the best available engine for a language.

        Prefers the configured mapping when its engine is registered,
        otherwise falls back to any engine that lists the language.

        Args:
            language: Language code.

        Returns:
            Optional[str]: Engine name, or None when nothing supports it.
        """
        if language in self._language_to_asr_mapping:
            asr_type = self._language_to_asr_mapping[language]
            if self.is_asr_registered(asr_type):
                return asr_type

        # Fall back to the first engine that declares support.
        supported_langs = self.get_supported_languages()
        for asr_key, languages in supported_langs.items():
            if language in languages:
                return asr_key

        return None

    def cleanup(self) -> None:
        """Drop all cached ASR instances."""
        logging.info("清理ASR实例...")
        self._asr_instances.clear()
        logging.info("ASR实例清理完成")

    def print_registry(self) -> None:
        """Print the global ASR registry table."""
        asr_tables.print()

    def get_asr_statistics(self) -> Dict:
        """Return a snapshot of manager state for debugging/monitoring.

        Returns:
            Dict: Counts, mappings and language availability.
        """
        return {
            'registered_asr_count': len(asr_tables.asr_classes),
            'active_instances_count': len(self._asr_instances),
            'supported_languages': self.get_available_languages(),
            'language_mappings': self._language_to_asr_mapping.copy(),
            'registered_asr_types': list(asr_tables.asr_classes.keys())
        }
283
+
284
+
285
# Global ASR manager instance shared across the package.
asr_manager = ASRManager()
287
+
288
+
289
def register_all_asr():
    """Discover and import every ASR implementation under ./models.

    Importing each module triggers its `@asr_tables.register(...)`
    decorator, which is what actually registers the class. Failures are
    logged and skipped so one broken backend cannot break the rest.
    """
    import importlib
    from pathlib import Path

    # Directory holding the concrete ASR implementations.
    models_dir = Path(__file__).parent / "models"

    for py_file in models_dir.glob("*.py"):
        # Skip the package initializer and the abstract base definitions.
        if py_file.name in ["__init__.py", "base.py"]:
            continue

        module_name = py_file.stem
        try:
            # Resolve the relative import against this module's own package
            # instead of a hardcoded "VoiceDialogue.services.speech.asr":
            # sibling modules import `from models.voice_task import ...` /
            # `from config import paths`, i.e. the project is rooted below
            # src/VoiceDialogue, so the hardcoded path would not resolve.
            importlib.import_module(f".models.{module_name}", package=__package__)
            logging.info(f"Successfully imported ASR module: {module_name}")
        except ImportError as e:
            logging.warning(f"Failed to import ASR module {module_name}: {e}")
        except Exception as e:
            logging.error(f"Unexpected error importing ASR module {module_name}: {e}")


# Register every ASR implementation as soon as this module is imported.
register_all_asr()
src/VoiceDialogue/services/speech/asr/models/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import logging

from .base import ASRInterface

__all__ = ['ASRInterface']

# Optional engine imports: each backend pulls in heavy third-party
# dependencies, so a missing one is logged and skipped instead of
# breaking the whole package. (The duplicated in-except `import logging`
# statements are hoisted to a single top-level import.)
try:
    from .funasr import FunASRClient

    __all__.append('FunASRClient')
except ImportError as e:
    logging.warning(f"Failed to import some FunASR implementations: {e}")

try:
    from .whisper import WhisperCppClient

    __all__.append('WhisperCppClient')
except ImportError as e:
    logging.warning(f"Failed to import some Whisper implementations: {e}")
src/VoiceDialogue/services/speech/asr/models/base.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from enum import Enum
3
+
4
+ import librosa
5
+ import numpy as np
6
+
7
+ from config import paths
8
+
9
+
10
class ASRConfigType(Enum):
    """Identifiers for the available ASR engine backends."""
    FUNASR = 'funasr'
    WHISPER_CPP = 'whisper_cpp'
14
+
15
+
16
class Language(Enum):
    """Language codes accepted by the ASR services."""
    AUTO = 'auto'
    CHINESE = 'zh'
    ENGLISH = 'en'
21
+
22
+
23
class ASRInterface(ABC):
    """Abstract interface every ASR engine implementation must satisfy."""
    # Language codes the concrete engine supports; subclasses override this.
    supported_langs = []

    def __init__(self):
        # Load a short reference clip used by warmup(); fall back to
        # synthetic noise when the resource file is missing.
        warmup_audiofile = paths.RESOURCES_PATH / 'audio' / 'jfk.flac'
        if warmup_audiofile.exists():
            audiodata, _ = librosa.load(warmup_audiofile, sr=16000, mono=True)
        else:
            # Synthesize test audio instead
            audiodata = np.random.randn(16000).astype(np.float32) * 0.1  # 1 second of noise
        self.warmup_audiodata = audiodata

    @abstractmethod
    def setup(self, **kwargs) -> None:
        """
        Initialize the ASR engine (load models, allocate resources).

        Args:
            **kwargs: Implementation-specific initialization options.
        """
        pass

    @abstractmethod
    def warmup(self) -> None:
        """Run a throwaway transcription so later calls start fast."""
        pass

    @abstractmethod
    def transcribe(self, audio_array: np.ndarray, language: str = None) -> str:
        """
        Convert audio samples to text.

        Args:
            audio_array: Audio samples (warmup provides 16 kHz mono).
            language: Language code; None lets the implementation decide.

        Returns:
            str: The recognized text.
        """
        pass
src/VoiceDialogue/services/speech/asr/models/funasr.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import typing
3
+
4
+ import numpy as np
5
+ from funasr_onnx import SeacoParaformer, CT_Transformer
6
+
7
+ from config import paths
8
+ from .base import ASRInterface
9
+ from ..manager import asr_tables
10
+ from ..utils import ensure_minimum_audio_duration
11
+
12
+
13
@asr_tables.register('asr_classes', 'funasr')
class FunASRClient(ASRInterface):
    """FunASR-based ASR client (SeacoParaformer + punctuation restoration)."""
    supported_langs = ['zh']

    def __init__(self):
        super().__init__()
        self.funasr_model: typing.Optional[SeacoParaformer] = None
        self.punc_model: typing.Optional[CT_Transformer] = None

    def setup(self, **kwargs) -> None:
        """Load the quantized ASR and punctuation models from MODELS_PATH.

        Args:
            **kwargs: Unused; accepted for interface compatibility.
        """
        # Model cache directory
        models_dir = paths.MODELS_PATH / "asr"
        asr_model_path = models_dir / "speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
        punc_model_path = models_dir / "punc_ct-transformer_cn-en-common-vocab471067-large"
        # NOTE(review): Path objects are handed straight to the model
        # constructors — confirm funasr_onnx accepts pathlib paths.
        self.funasr_model = SeacoParaformer(asr_model_path, quantize=True)
        self.punc_model = CT_Transformer(punc_model_path, quantize=True)

    def warmup(self) -> None:
        """Run one transcription on the warmup clip to pre-load the models."""
        print('[INFO] Warming up FunASR model...')
        try:
            self.transcribe(self.warmup_audiodata)
            print('[INFO] FunASR model warmed up.')
        except Exception as e:
            print(f'[WARNING] FunASR model warmup failed: {e}')

    def _fix_spaced_uppercase(self, text: str) -> str:
        """
        Collapse spaced-out capitals such as " G N O M E " into "GNOME".
        """
        # Two or more uppercase letters separated only by whitespace.
        pattern = r'([A-Z])\s+([A-Z](?:\s+[A-Z])*)'

        def replace_func(match):
            # Drop every space inside the matched run.
            return match.group(0).replace(' ', '')

        return re.sub(pattern, replace_func, text)

    def transcribe(self, audio_array: np.ndarray, language="auto"):
        """Transcribe Chinese speech to punctuated text.

        Args:
            audio_array: Audio samples (warmup provides 16 kHz mono).
            language: Ignored by FunASR; kept for interface compatibility.

        Returns:
            str: The recognized text with punctuation restored.
        """
        audio_array = ensure_minimum_audio_duration(audio_array)

        segments = self.funasr_model(wav_content=audio_array, hotwords='')

        transcribed_texts = []  # fixed local-name typo ("transcibed")
        for segment in segments:
            content = segment.get("preds", "")
            content, _ = self.punc_model(content)
            content = self._fix_spaced_uppercase(content)
            transcribed_texts.append(content)
        return " ".join(transcribed_texts)
src/VoiceDialogue/services/speech/asr/models/whisper.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing
2
+
3
+ import numpy as np
4
+ from pywhispercpp.model import Model
5
+
6
+ from config import paths
7
+ from .base import ASRInterface
8
+ from ..manager import asr_tables
9
+ from ..utils import ensure_minimum_audio_duration
10
+
11
+
12
@asr_tables.register('asr_classes', 'whisper')
class WhisperCppClient(ASRInterface):
    """whisper.cpp ASR client (via pywhispercpp)."""
    supported_langs = ['en', 'zh', 'auto']

    def __init__(self):
        super().__init__()
        self.whisper: typing.Optional[Model] = None
        # Default language; transcribe() accepts a per-call override.
        self.language = "en"

    def setup(self, **kwargs) -> None:
        """Load the whisper.cpp model.

        Args:
            **kwargs: `model` selects the checkpoint: "medium" maps to the
                quantized medium model; any other value (including the
                default fall-through) maps to large-v3-turbo-q5_0.
        """
        model = kwargs.get('model', 'medium')
        if model == "medium":
            model = "medium-q5_0"
        else:
            model = "large-v3-turbo-q5_0"

        models_dir = paths.MODELS_PATH / "asr"
        # NOTE(review): models_dir is a Path — confirm pywhispercpp accepts
        # pathlib paths for `models_dir`.
        self.whisper = Model(model=model, models_dir=models_dir)

    def warmup(self) -> None:
        """Run one transcription on the warmup clip to pre-load the model."""
        print('[INFO] Warming up Whisper model...')
        try:
            self.transcribe(self.warmup_audiodata)
            print('[INFO] Whisper model warmed up.')
        except Exception as e:
            print(f'[WARNING] Whisper model warmup failed: {e}')

    def transcribe(self, audio_array: np.ndarray, language="en"):
        """Transcribe speech to text.

        Args:
            audio_array: Audio samples (warmup provides 16 kHz mono).
            language: Language code passed to whisper.cpp; also selects
                the initial prompt that biases decoding.

        Returns:
            str: The recognized text, segment texts joined by spaces.
        """
        if language == "zh":
            prompt = "以下是简体中文普通话的句子。"
        else:
            prompt = "The following is an English sentence."

        audio_array = ensure_minimum_audio_duration(audio_array)

        segments = self.whisper.transcribe(
            audio_array, language=language, initial_prompt=prompt, print_progress=False
        )
        # (Removed commented-out punctuation patch-up dead code.)
        return " ".join(segment.text for segment in segments)
src/VoiceDialogue/services/speech/asr/utils.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ASR模块的工具函数
3
+ 包含音频预处理、格式转换等工具函数
4
+ """
5
+
6
+ import numpy as np
7
+
8
+
9
def ensure_minimum_audio_duration(
    audio_array: np.ndarray, min_duration: float = 1.0, sample_rate: int = 16000
) -> np.ndarray:
    """Pad *audio_array* so it lasts at least *min_duration* seconds.

    Args:
        audio_array: Input audio samples.
        min_duration: Required minimum length in seconds (default 1.0).
        sample_rate: Sample rate in Hz (default 16000).

    Returns:
        The original array if already long enough, otherwise the array
        with padding appended by padding_silence().
    """
    current_duration = audio_array.shape[-1] / sample_rate
    shortfall = min_duration - current_duration

    if shortfall > 0:
        return padding_silence(audio_array, shortfall, sample_rate)

    return audio_array
30
+
31
+
32
def padding_silence(
    audio_data: np.ndarray, duration_seconds: float, sample_rate: int = 16000
) -> np.ndarray:
    """Append silence to the end of an audio buffer.

    Fix: the previous implementation appended a 0.5-amplitude 440 Hz sine
    tone — audible content that the ASR engine can transcribe — despite
    being named and documented as silence padding; it also built the time
    axis with the input's dtype, which breaks for integer audio. True
    silence (zeros, matching the input dtype) is appended instead.

    Args:
        audio_data: Original audio samples.
        duration_seconds: Length of padding to append, in seconds.
        sample_rate: Sample rate in Hz.

    Returns:
        A new array consisting of ``audio_data`` followed by silence.
    """
    # Pad slightly more than requested so integer truncation can never
    # leave the result under the target duration.
    duration = duration_seconds + 0.1
    silence = np.zeros(int(sample_rate * duration), dtype=audio_data.dtype)
    return np.concatenate([audio_data, silence])
54
+
55
+
56
def validate_audio_array(audio_array: np.ndarray) -> bool:
    """Return True if *audio_array* is a usable audio buffer.

    A valid buffer is a non-empty ``np.ndarray`` with at most two
    dimensions; anything else — including ``None`` — is rejected.

    Args:
        audio_array: Candidate audio buffer.

    Returns:
        bool: Whether the value can be processed as audio.
    """
    return (
        isinstance(audio_array, np.ndarray)
        and audio_array.size > 0
        and audio_array.ndim <= 2
    )
79
+
80
+
81
def normalize_audio(audio_array: np.ndarray, target_peak: float = 0.95) -> np.ndarray:
    """Rescale audio so its absolute peak equals *target_peak*.

    Args:
        audio_array: Input audio samples.
        target_peak: Desired peak amplitude (default 0.95).

    Returns:
        The rescaled audio; an all-zero signal is returned unchanged.

    Raises:
        ValueError: If *audio_array* fails validation.
    """
    if not validate_audio_array(audio_array):
        raise ValueError("Invalid audio array")

    peak = np.max(np.abs(audio_array))

    # Nothing to scale — avoid division by zero on pure silence.
    if peak == 0:
        return audio_array

    return audio_array * (target_peak / peak)
108
+
109
+
110
def convert_sample_rate(
    audio_array: np.ndarray,
    source_rate: int,
    target_rate: int
) -> np.ndarray:
    """Resample audio from *source_rate* to *target_rate*.

    Uses librosa's resampler when available and falls back to plain
    linear interpolation otherwise.

    Args:
        audio_array: Input audio samples.
        source_rate: Sample rate of the input, in Hz.
        target_rate: Desired output sample rate, in Hz.

    Returns:
        The resampled audio (the input itself when the rates match).
    """
    if source_rate == target_rate:
        return audio_array

    try:
        import librosa
    except ImportError:
        # Dependency-free fallback: linear interpolation onto the new grid.
        resample_ratio = target_rate / source_rate
        output_length = int(len(audio_array) * resample_ratio)
        sample_positions = np.linspace(0, len(audio_array) - 1, output_length)
        return np.interp(sample_positions, np.arange(len(audio_array)), audio_array)

    return librosa.resample(audio_array, orig_sr=source_rate, target_sr=target_rate)
138
+
139
+
140
def trim_silence(
    audio_array: np.ndarray,
    threshold: float = 0.01,
    sample_rate: int = 16000
) -> np.ndarray:
    """Strip leading and trailing samples at or below *threshold* magnitude.

    Args:
        audio_array: Input audio samples.
        threshold: Amplitude at/below which a sample counts as silence.
        sample_rate: Sample rate in Hz (used only for the all-silent case).

    Returns:
        The trimmed audio. Invalid input is returned untouched; fully
        silent input is cut down to at most 100 ms.
    """
    if not validate_audio_array(audio_array):
        return audio_array

    loud_enough = np.abs(audio_array) > threshold

    if not loud_enough.any():
        # Entirely silent: keep no more than 100 ms of it.
        cap = int(0.1 * sample_rate)
        return audio_array[:cap] if len(audio_array) > cap else audio_array

    # First and one-past-last indices of the non-silent region.
    first = np.argmax(loud_enough)
    last = len(loud_enough) - np.argmax(loud_enough[::-1])
    return audio_array[first:last]
175
+
176
+
177
def get_audio_duration(audio_array: np.ndarray, sample_rate: int = 16000) -> float:
    """Return the duration of *audio_array* in seconds.

    Args:
        audio_array: Audio samples (last axis is time).
        sample_rate: Sample rate in Hz.

    Returns:
        Duration in seconds; 0.0 when the input is not a valid audio array.
    """
    return (
        audio_array.shape[-1] / sample_rate
        if validate_audio_array(audio_array)
        else 0.0
    )
192
+
193
+
194
def create_silence(duration_seconds: float, sample_rate: int = 16000) -> np.ndarray:
    """Build a buffer of digital silence.

    Args:
        duration_seconds: Length of the silence in seconds.
        sample_rate: Sample rate in Hz.

    Returns:
        A float32 zero array of ``int(duration_seconds * sample_rate)``
        samples.
    """
    sample_count = int(duration_seconds * sample_rate)
    return np.zeros(sample_count, dtype=np.float32)
src/VoiceDialogue/services/speech/asr_service.py CHANGED
@@ -1,168 +1,14 @@
1
- import re
2
  import time
3
  import typing
4
  from queue import Queue
5
 
6
- import librosa
7
  import numpy as np
8
- from funasr_onnx import SeacoParaformer, CT_Transformer
9
- from pywhispercpp.model import Model
10
 
11
- from config import paths
12
  from models.voice_task import VoiceTask
13
  from services.core.base import BaseThread
14
  from services.core.constants import user_still_speaking_event, voice_state_manager, dropped_audio_cache
15
  from utils.cache import LRUCacheDict
16
-
17
-
18
- def ensure_minimum_audio_duration(
19
- audio_array: np.ndarray, min_duration: float = 1.0, sample_rate: int = 16000
20
- ) -> np.ndarray:
21
- """
22
- 确保音频数组满足最小时长要求,如果不足则用静音填充
23
-
24
- Args:
25
- audio_array: 输入音频数组
26
- min_duration: 最小时长要求(秒),默认1秒
27
- sample_rate: 采样率,默认16000Hz
28
-
29
- Returns:
30
- 处理后的音频数组
31
- """
32
- audio_duration = audio_array.shape[-1] / sample_rate
33
-
34
- if audio_duration < min_duration:
35
- padding_seconds = min_duration - audio_duration
36
- audio_array = padding_silence(audio_array, padding_seconds, sample_rate)
37
-
38
- return audio_array
39
-
40
-
41
- def padding_silence(
42
- audio_data: np.ndarray, duration_seconds: float, sample_rate: int = 16000
43
- ) -> np.ndarray:
44
- """
45
- 为音频数据添加静音填充
46
-
47
- Args:
48
- audio_data: 原始音频数据
49
- duration_seconds: 需要填充的时长(秒)
50
- sample_rate: 采样率
51
-
52
- Returns:
53
- 填充后的音频数据
54
- """
55
- frequency = 440.0
56
- duration = duration_seconds + 0.1
57
- t = np.linspace(
58
- 0, duration, int(sample_rate * duration), endpoint=False, dtype=audio_data.dtype
59
- )
60
- silence = 0.5 * np.sin(2 * np.pi * frequency * t)
61
- audio_data = np.concatenate([audio_data, silence])
62
- return audio_data
63
-
64
-
65
- class WhisperCppClient:
66
- """Whisper C++ API客户端"""
67
-
68
- def __init__(self, model: typing.Literal["medium", "large"] = "medium"):
69
- if model == "medium":
70
- model = "medium-q5_0"
71
- else:
72
- model = "large-v3-turbo-q5_0"
73
-
74
- models_dir = paths.MODELS_PATH / "asr"
75
- self.whisper = Model(model=model, models_dir=models_dir)
76
-
77
- def transcribe(self, audio_array: np.ndarray, language="en"):
78
- if language == "zh":
79
- prompt = "以下是简体中文普通话的句子。"
80
- else:
81
- prompt = "The following is an English sentence."
82
-
83
- audio_array = ensure_minimum_audio_duration(audio_array)
84
-
85
- # print('............... language:', language)
86
- segments = self.whisper.transcribe(
87
- audio_array, language=language, initial_prompt=prompt, print_progress=False
88
- )
89
- text = []
90
- for segment in segments:
91
- content = segment.text
92
- # if not content.endswith(()):
93
- # content += ','
94
- text.append(content)
95
- text = " ".join(text)
96
- return text
97
-
98
-
99
- class FunASRClient:
100
- """FunASR API客户端"""
101
-
102
- def __init__(self):
103
- # 设置模型缓存目录
104
- models_dir = paths.MODELS_PATH / "asr"
105
- asr_model_path = models_dir / "speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
106
- punc_model_path = models_dir / "punc_ct-transformer_cn-en-common-vocab471067-large"
107
- self.funasr_model = SeacoParaformer(asr_model_path, quantize=True)
108
- self.punc_model = CT_Transformer(punc_model_path, quantize=True)
109
-
110
- def _fix_spaced_uppercase(self, text: str) -> str:
111
- """
112
- 修复类似 " G N O M E " 这样的大写字母间有空格的字符串,将其替换为 "GNOME"
113
- """
114
- # 匹配大写字母之间的空格模式,至少2个大写字母
115
- pattern = r'([A-Z])\s+([A-Z](?:\s+[A-Z])*)'
116
-
117
- def replace_func(match):
118
- # 移除所有空格
119
- return match.group(0).replace(' ', '')
120
-
121
- return re.sub(pattern, replace_func, text)
122
-
123
- def transcribe(self, audio_array: np.ndarray, language="auto"):
124
- audio_array = ensure_minimum_audio_duration(audio_array)
125
-
126
- segments = self.funasr_model(wav_content=audio_array, hotwords='')
127
-
128
- transcibed_texts = []
129
- for segment in segments:
130
- content = segment.get("preds", "")
131
- content, _ = self.punc_model(content)
132
- content = self._fix_spaced_uppercase(content)
133
- transcibed_texts.append(content)
134
- return " ".join(transcibed_texts)
135
-
136
-
137
- class UnifiedASRClient:
138
- """统一的语音识别客户端,根据语言自动选择FunASR或Whisper"""
139
-
140
- def __init__(self, language: typing.Literal["auto", "zh", "en"] = "zh"):
141
- self.language = language
142
-
143
- if language == "zh":
144
- self.client = FunASRClient()
145
- else:
146
- self.client = WhisperCppClient()
147
-
148
- def warmup(self):
149
- """预热模型"""
150
- print('[INFO] 预热语音识别模型...')
151
- try:
152
- warmup_audiofile = paths.RESOURCES_PATH / 'audio' / 'jfk.flac'
153
- if warmup_audiofile.exists():
154
- data, sr = librosa.load(warmup_audiofile, sr=16000, mono=True)
155
- self.client.transcribe(data, language=self.language)
156
- else:
157
- # 创建测试音频
158
- test_audio = np.random.randn(16000).astype(np.float32) * 0.1 # 1秒的噪声
159
- self.client.transcribe(test_audio, language=self.language)
160
- print('[INFO] ASR模型预热完成')
161
- except Exception as e:
162
- print(f'[WARNING] ASR模型预热失败: {e}')
163
-
164
- def transcribe(self, audio_array: np.ndarray) -> str:
165
- return self.client.transcribe(audio_array, language=self.language)
166
 
167
 
168
  class ASRWorker(BaseThread):
@@ -179,7 +25,8 @@ class ASRWorker(BaseThread):
179
  self.cached_user_questions = LRUCacheDict(maxsize=10)
180
 
181
  def run(self):
182
- self.client = UnifiedASRClient(self.language)
 
183
  self.client.warmup()
184
 
185
  self.is_ready = True
 
 
1
  import time
2
  import typing
3
  from queue import Queue
4
 
 
5
  import numpy as np
 
 
6
 
 
7
  from models.voice_task import VoiceTask
8
  from services.core.base import BaseThread
9
  from services.core.constants import user_still_speaking_event, voice_state_manager, dropped_audio_cache
10
  from utils.cache import LRUCacheDict
11
+ from .asr import asr_manager
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
 
14
  class ASRWorker(BaseThread):
 
25
  self.cached_user_questions = LRUCacheDict(maxsize=10)
26
 
27
  def run(self):
28
+ self.client = asr_manager.create_asr(self.language)
29
+ self.client.setup()
30
  self.client.warmup()
31
 
32
  self.is_ready = True