liumaolin
committed on
Commit
·
59603db
1
Parent(s):
89f7f05
Refactor ASR module: introduce modular structure with ASR interface, implement FunASR and Whisper clients, add registry, and consolidate utility functions for enhanced maintainability and extensibility.
Browse files- src/VoiceDialogue/services/speech/asr/__init__.py +39 -0
- src/VoiceDialogue/services/speech/asr/manager.py +315 -0
- src/VoiceDialogue/services/speech/asr/models/__init__.py +21 -0
- src/VoiceDialogue/services/speech/asr/models/base.py +63 -0
- src/VoiceDialogue/services/speech/asr/models/funasr.py +63 -0
- src/VoiceDialogue/services/speech/asr/models/whisper.py +59 -0
- src/VoiceDialogue/services/speech/asr/utils.py +206 -0
- src/VoiceDialogue/services/speech/asr_service.py +3 -156
src/VoiceDialogue/services/speech/asr/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ASR Module
|
| 3 |
+
|
| 4 |
+
提供自动语音识别(ASR)功能的完整解决方案,包括:
|
| 5 |
+
- ASR管理器和注册系统
|
| 6 |
+
- 多种ASR引擎支持
|
| 7 |
+
- 配置管理
|
| 8 |
+
- 运行时接口
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from .models import (
|
| 12 |
+
ASRInterface,
|
| 13 |
+
)
|
| 14 |
+
from .manager import (
|
| 15 |
+
ASRManager,
|
| 16 |
+
ASRRegistryTables,
|
| 17 |
+
asr_manager,
|
| 18 |
+
asr_tables,
|
| 19 |
+
register_all_asr
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
__version__ = "1.0.0"
|
| 23 |
+
|
| 24 |
+
__all__ = [
|
| 25 |
+
# 管理器和注册表
|
| 26 |
+
'ASRManager',
|
| 27 |
+
'ASRRegistryTables',
|
| 28 |
+
'asr_manager',
|
| 29 |
+
'asr_tables',
|
| 30 |
+
'register_all_asr',
|
| 31 |
+
|
| 32 |
+
# 配置模型
|
| 33 |
+
|
| 34 |
+
# 运行时接口
|
| 35 |
+
'ASRInterface',
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
# 模块初始化时自动注册所有ASR实现
|
| 39 |
+
# register_all_asr() 已在 asr_manager 模块中自动调用
|
src/VoiceDialogue/services/speech/asr/manager.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import logging
|
| 3 |
+
import re
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Dict, Type, List, Literal, Optional
|
| 6 |
+
|
| 7 |
+
from .models import ASRInterface
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
|
| 11 |
+
class ASRRegistryTables:
|
| 12 |
+
"""ASR注册表系统,用于管理不同的ASR实现"""
|
| 13 |
+
|
| 14 |
+
asr_classes: Dict[str, Type[ASRInterface]] = None
|
| 15 |
+
|
| 16 |
+
def __post_init__(self):
|
| 17 |
+
if self.asr_classes is None:
|
| 18 |
+
self.asr_classes = {}
|
| 19 |
+
|
| 20 |
+
def print(self, key: str = None) -> None:
|
| 21 |
+
"""打印已注册的ASR类"""
|
| 22 |
+
print("\nASR Registry Tables: \n")
|
| 23 |
+
headers = ["register name", "class name", "class location", "supported languages"]
|
| 24 |
+
|
| 25 |
+
if self.asr_classes and (key is None or "asr_classes" in key):
|
| 26 |
+
print(f"----------- ** asr_classes ** --------------")
|
| 27 |
+
metas = []
|
| 28 |
+
for register_key, asr_class in self.asr_classes.items():
|
| 29 |
+
class_file = inspect.getfile(asr_class)
|
| 30 |
+
class_line = inspect.getsourcelines(asr_class)[1]
|
| 31 |
+
# 简化路径显示
|
| 32 |
+
pattern = r"^.+/VoiceDialogue/"
|
| 33 |
+
class_file = re.sub(pattern, "VoiceDialogue/", class_file)
|
| 34 |
+
|
| 35 |
+
# 获取支持的语言
|
| 36 |
+
try:
|
| 37 |
+
supported_langs = asr_class.supported_langs
|
| 38 |
+
supported_langs_str = ', '.join(supported_langs) if supported_langs else 'unknown'
|
| 39 |
+
except:
|
| 40 |
+
supported_langs_str = 'unknown'
|
| 41 |
+
|
| 42 |
+
meta_data = [
|
| 43 |
+
register_key,
|
| 44 |
+
asr_class.__name__,
|
| 45 |
+
f"{class_file}:{class_line}",
|
| 46 |
+
supported_langs_str,
|
| 47 |
+
]
|
| 48 |
+
metas.append(meta_data)
|
| 49 |
+
|
| 50 |
+
metas.sort(key=lambda x: x[0])
|
| 51 |
+
data = [headers] + metas
|
| 52 |
+
col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
|
| 53 |
+
|
| 54 |
+
for row in data:
|
| 55 |
+
print(
|
| 56 |
+
"| "
|
| 57 |
+
+ " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths))
|
| 58 |
+
+ " |"
|
| 59 |
+
)
|
| 60 |
+
print("\n")
|
| 61 |
+
|
| 62 |
+
def _get_asr_supported_languages(self, asr_key: str) -> List[str]:
|
| 63 |
+
"""获取特定ASR引擎支持的语言列表"""
|
| 64 |
+
# 根据ASR类型返回支持的语言
|
| 65 |
+
language_mapping = {
|
| 66 |
+
'funasr': ['zh', 'auto'],
|
| 67 |
+
'whisper': ['en', 'zh', 'auto'],
|
| 68 |
+
}
|
| 69 |
+
return language_mapping.get(asr_key, ['auto'])
|
| 70 |
+
|
| 71 |
+
def register(self, register_table_key: str, key: str = None) -> callable:
|
| 72 |
+
"""装饰器,用于注册ASR类"""
|
| 73 |
+
|
| 74 |
+
def decorator(target_class):
|
| 75 |
+
if not hasattr(self, register_table_key):
|
| 76 |
+
setattr(self, register_table_key, {})
|
| 77 |
+
logging.debug(f"New ASR registry table added: {register_table_key}")
|
| 78 |
+
|
| 79 |
+
registry = getattr(self, register_table_key)
|
| 80 |
+
registry_key = key if key is not None else target_class.__name__
|
| 81 |
+
|
| 82 |
+
if registry_key in registry:
|
| 83 |
+
logging.debug(
|
| 84 |
+
f"Key {registry_key} already exists in {register_table_key}, re-register"
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
registry[registry_key] = target_class
|
| 88 |
+
logging.info(f"Registered ASR class: {registry_key} -> {target_class.__name__}")
|
| 89 |
+
return target_class
|
| 90 |
+
|
| 91 |
+
return decorator
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# 全局ASR注册表实例
|
| 95 |
+
asr_tables = ASRRegistryTables()
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class ASRManager:
    """Creates, caches, and routes per-language requests to ASR engines."""

    def __init__(self):
        # Cache of live engine instances, keyed by "{asr_type}_{language}".
        self._asr_instances: Dict[str, ASRInterface] = {}
        # Preferred engine for each language code.
        self._language_to_asr_mapping = {
            'zh': 'funasr',     # Chinese: prefer FunASR
            'en': 'whisper',    # English: prefer Whisper
            'auto': 'whisper',  # Auto-detection: default to Whisper
        }

    def create_asr(self, language: Literal['auto', 'zh', 'en']) -> ASRInterface:
        """Build a new ASR instance for *language*.

        Args:
            language: Language code.

        Returns:
            ASRInterface: The freshly constructed ASR instance.

        Raises:
            ValueError: If the mapped ASR type is unregistered or the
                language is unsupported.
        """
        try:
            # Pick the engine configured for this language.
            asr_type = self._get_asr_type_for_language(language)

            if asr_type not in asr_tables.asr_classes:
                raise ValueError(f"ASR类型 '{asr_type}' 未注册")

            asr_class = asr_tables.asr_classes[asr_type]
            instance = asr_class()

            logging.info(f"成功创建ASR实例: {asr_type} for language: {language}")
            return instance

        except Exception as e:
            logging.error(f"创建ASR实例失败: {e}")
            raise

    def get_or_create_asr(self, language: Literal['auto', 'zh', 'en']) -> ASRInterface:
        """Return the cached ASR instance for *language*, creating it on first use.

        Args:
            language: Language code.

        Returns:
            ASRInterface: The cached (one-per-key) ASR instance.
        """
        asr_type = self._get_asr_type_for_language(language)
        instance_key = f"{asr_type}_{language}"

        if instance_key not in self._asr_instances:
            self._asr_instances[instance_key] = self.create_asr(language)

        return self._asr_instances[instance_key]

    def _get_asr_type_for_language(self, language: str) -> str:
        """Map a language code to its configured ASR type.

        Raises:
            ValueError: If no engine is mapped to *language*.
        """
        asr_type = self._language_to_asr_mapping.get(language)
        if not asr_type:
            raise ValueError(f"不支持的语言类型: {language}")
        return asr_type

    def set_language_mapping(self, language: str, asr_type: str) -> None:
        """Point *language* at a different registered ASR engine.

        Args:
            language: Language code.
            asr_type: Registered ASR engine key.

        Raises:
            ValueError: If *asr_type* has not been registered.
        """
        if asr_type not in asr_tables.asr_classes:
            raise ValueError(f"ASR类型 '{asr_type}' 未注册")

        self._language_to_asr_mapping[language] = asr_type
        logging.info(f"更新语言映射: {language} -> {asr_type}")

    def list_registered_asr(self) -> Dict[str, Type[ASRInterface]]:
        """Return a shallow copy of the registered ASR classes."""
        return asr_tables.asr_classes.copy()

    def is_asr_registered(self, asr_type: str) -> bool:
        """Return True when *asr_type* is present in the registry."""
        return asr_type in asr_tables.asr_classes

    def get_supported_languages(self) -> Dict[str, List[str]]:
        """Collect the supported-language list of every registered engine.

        Returns:
            Dict[str, List[str]]: Engine name -> language codes
                (['unknown'] when the lookup fails).
        """
        supported_languages = {}

        for asr_key in asr_tables.asr_classes.keys():
            try:
                languages = asr_tables._get_asr_supported_languages(asr_key)
                supported_languages[asr_key] = languages
            except Exception as e:
                logging.warning(f"获取ASR引擎 '{asr_key}' 支持的语言失败: {e}")
                supported_languages[asr_key] = ['unknown']

        return supported_languages

    def get_available_languages(self) -> List[str]:
        """Return the sorted union of languages across all engines.

        Returns:
            List[str]: Available language codes ('unknown' markers removed).
        """
        all_languages = set()
        # Only the language lists matter here, not the engine names.
        for languages in self.get_supported_languages().values():
            all_languages.update(languages)

        all_languages.discard('unknown')
        return sorted(all_languages)

    def validate_language_support(self, language: str) -> bool:
        """Return True when at least one registered engine supports *language*.

        Args:
            language: Language code.

        Returns:
            bool: Whether the language is supported.
        """
        return language in self.get_available_languages()

    def get_optimal_asr_for_language(self, language: str) -> Optional[str]:
        """Pick the best ASR engine for *language*.

        The explicit language mapping wins when its engine is registered;
        otherwise the first engine advertising the language is used.

        Args:
            language: Language code.

        Returns:
            Optional[str]: Engine name, or None when nothing supports it.
        """
        if language in self._language_to_asr_mapping:
            asr_type = self._language_to_asr_mapping[language]
            if self.is_asr_registered(asr_type):
                return asr_type

        supported_langs = self.get_supported_languages()
        for asr_key, languages in supported_langs.items():
            if language in languages:
                return asr_key

        return None

    def cleanup(self) -> None:
        """Drop every cached ASR instance."""
        logging.info("清理ASR实例...")
        self._asr_instances.clear()
        logging.info("ASR实例清理完成")

    def print_registry(self) -> None:
        """Print the registry tables to stdout."""
        asr_tables.print()

    def get_asr_statistics(self) -> Dict:
        """Summarize the manager state.

        Returns:
            Dict: Counts, available languages, and current mappings.
        """
        return {
            'registered_asr_count': len(asr_tables.asr_classes),
            'active_instances_count': len(self._asr_instances),
            'supported_languages': self.get_available_languages(),
            'language_mappings': self._language_to_asr_mapping.copy(),
            'registered_asr_types': list(asr_tables.asr_classes.keys())
        }


# Global ASR manager instance.
asr_manager = ASRManager()
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def register_all_asr():
    """Discover and import every ASR implementation under models/.

    Importing a model module triggers its @asr_tables.register decorator,
    which is what actually populates the registry; the module object itself
    is not needed afterwards.
    """
    import importlib
    from pathlib import Path

    # Locate the sibling models/ directory next to this file.
    models_dir = Path(__file__).parent / "models"

    for py_file in models_dir.glob("*.py"):
        # base.py only defines the interface, and __init__.py must not be
        # re-imported through this path.
        if py_file.name in ["__init__.py", "base.py"]:
            continue

        module_name = py_file.stem
        try:
            # Resolve the relative import against this module's own package
            # instead of a hard-coded dotted path: the previous literal
            # "VoiceDialogue.services.speech.asr" silently failed whenever
            # the project was imported without the src/ prefix on sys.path
            # (the rest of the codebase uses absolute imports like
            # `from config import paths`, implying exactly that layout).
            importlib.import_module(f".models.{module_name}",
                                    package=__package__)
            logging.info(f"Successfully imported ASR module: {module_name}")
        except ImportError as e:
            logging.warning(f"Failed to import ASR module {module_name}: {e}")
        except Exception as e:
            logging.error(f"Unexpected error importing ASR module {module_name}: {e}")


# Register every ASR implementation as soon as this module is imported.
register_all_asr()
|
src/VoiceDialogue/services/speech/asr/models/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# The abstract interface is always exported; the concrete clients below are
# optional because their third-party backends may not be installed.
from .base import ASRInterface

__all__ = ['ASRInterface']

# FunASR client requires funasr_onnx; degrade gracefully (warn) when missing.
try:
    from .funasr import FunASRClient

    __all__.append('FunASRClient')
except ImportError as e:
    import logging

    logging.warning(f"Failed to import some FunASR implementations: {e}")

# Whisper client requires pywhispercpp; degrade gracefully (warn) when missing.
try:
    from .whisper import WhisperCppClient

    __all__.append('WhisperCppClient')
except ImportError as e:
    import logging

    logging.warning(f"Failed to import some Whisper implementations: {e}")
|
src/VoiceDialogue/services/speech/asr/models/base.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from enum import Enum
|
| 3 |
+
|
| 4 |
+
import librosa
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from config import paths
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ASRConfigType(Enum):
    """Identifiers for the available ASR engine backends."""

    FUNASR = 'funasr'            # funasr_onnx Paraformer stack
    WHISPER_CPP = 'whisper_cpp'  # whisper.cpp (pywhispercpp) backend
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Language(Enum):
    """Language codes accepted by the ASR layer."""

    AUTO = 'auto'   # let the engine detect the spoken language
    CHINESE = 'zh'
    ENGLISH = 'en'
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ASRInterface(ABC):
    """Abstract contract that every concrete ASR backend implements."""

    # Language codes the concrete engine accepts; subclasses override this.
    supported_langs = []

    def __init__(self):
        # Load the bundled sample clip for warmup when it exists; otherwise
        # synthesize one second of quiet noise as a stand-in.
        sample_path = paths.RESOURCES_PATH / 'audio' / 'jfk.flac'
        if sample_path.exists():
            self.warmup_audiodata, _ = librosa.load(sample_path, sr=16000, mono=True)
        else:
            self.warmup_audiodata = np.random.randn(16000).astype(np.float32) * 0.1

    @abstractmethod
    def setup(self, **kwargs) -> None:
        """Initialize the ASR engine.

        Args:
            **kwargs: Engine-specific initialization options.
        """

    @abstractmethod
    def warmup(self) -> None:
        """Run a throwaway inference so the first real call is fast."""

    @abstractmethod
    def transcribe(self, audio_array: np.ndarray, language: str = None) -> str:
        """Convert audio samples to text.

        Args:
            audio_array: Audio data to recognize.
            language: Language override; None means use the configured language.

        Returns:
            str: The recognized text.
        """
|
src/VoiceDialogue/services/speech/asr/models/funasr.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import typing
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from funasr_onnx import SeacoParaformer, CT_Transformer
|
| 6 |
+
|
| 7 |
+
from config import paths
|
| 8 |
+
from .base import ASRInterface
|
| 9 |
+
from ..manager import asr_tables
|
| 10 |
+
from ..utils import ensure_minimum_audio_duration
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@asr_tables.register('asr_classes', 'funasr')
class FunASRClient(ASRInterface):
    """FunASR (funasr_onnx) speech-recognition client for Mandarin Chinese."""

    supported_langs = ['zh']

    def __init__(self):
        super().__init__()
        # Both models are loaded lazily in setup(), not at construction time.
        self.funasr_model: typing.Optional[SeacoParaformer] = None
        self.punc_model: typing.Optional[CT_Transformer] = None

    def setup(self, **kwargs) -> None:
        """Load the Paraformer ASR model and the punctuation-restoration model.

        Args:
            **kwargs: Unused; accepted for interface compatibility.
        """
        # Shared on-disk model cache directory.
        models_dir = paths.MODELS_PATH / "asr"
        asr_model_path = models_dir / "speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
        punc_model_path = models_dir / "punc_ct-transformer_cn-en-common-vocab471067-large"
        self.funasr_model = SeacoParaformer(asr_model_path, quantize=True)
        self.punc_model = CT_Transformer(punc_model_path, quantize=True)

    def warmup(self) -> None:
        """Run one transcription on the warmup clip so later calls are fast."""
        print('[INFO] Warming up FunASR model...')
        try:
            self.transcribe(self.warmup_audiodata)
            print('[INFO] FunASR model warmed up.')
        except Exception as e:
            print(f'[WARNING] FunASR model warmup failed: {e}')

    def _fix_spaced_uppercase(self, text: str) -> str:
        """Collapse spaced-out capitals such as " G N O M E " into "GNOME"."""
        # Two or more uppercase letters separated only by whitespace.
        pattern = r'([A-Z])\s+([A-Z](?:\s+[A-Z])*)'

        def replace_func(match):
            # Drop every space inside the matched run.
            return match.group(0).replace(' ', '')

        return re.sub(pattern, replace_func, text)

    def transcribe(self, audio_array: np.ndarray, language="auto"):
        """Transcribe *audio_array* into punctuated text.

        Args:
            audio_array: Mono audio samples (padded to a minimum duration
                before recognition).
            language: Ignored by this backend; kept for interface
                compatibility with ASRInterface.

        Returns:
            str: Space-joined transcription of all recognized segments.
        """
        audio_array = ensure_minimum_audio_duration(audio_array)

        segments = self.funasr_model(wav_content=audio_array, hotwords='')

        transcribed_texts = []
        for segment in segments:
            content = segment.get("preds", "")
            # Restore punctuation, then merge spaced-out acronyms.
            content, _ = self.punc_model(content)
            content = self._fix_spaced_uppercase(content)
            transcribed_texts.append(content)
        return " ".join(transcribed_texts)
|
src/VoiceDialogue/services/speech/asr/models/whisper.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import typing
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from pywhispercpp.model import Model
|
| 5 |
+
|
| 6 |
+
from config import paths
|
| 7 |
+
from .base import ASRInterface
|
| 8 |
+
from ..manager import asr_tables
|
| 9 |
+
from ..utils import ensure_minimum_audio_duration
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@asr_tables.register('asr_classes', 'whisper')
class WhisperCppClient(ASRInterface):
    """whisper.cpp-backed speech-recognition client (pywhispercpp bindings)."""

    supported_langs = ['en', 'zh', 'auto']

    def __init__(self):
        super().__init__()
        # The model is created in setup(); default language is English.
        self.whisper: typing.Optional[Model] = None
        self.language = "en"

    def setup(self, **kwargs) -> None:
        """Instantiate the whisper.cpp model.

        Args:
            **kwargs: 'model' selects the checkpoint; "medium" maps to the
                quantized medium model, anything else to the large-v3 turbo.
        """
        requested = kwargs.get('model', 'medium')
        model_name = "medium-q5_0" if requested == "medium" else "large-v3-turbo-q5_0"

        models_dir = paths.MODELS_PATH / "asr"
        self.whisper = Model(model=model_name, models_dir=models_dir)

    def warmup(self) -> None:
        """Run one transcription on the warmup clip so later calls are fast."""
        print('[INFO] Warming up Whisper model...')
        try:
            self.transcribe(self.warmup_audiodata)
            print('[INFO] Whisper model warmed up.')
        except Exception as e:
            print(f'[WARNING] Whisper model warmup failed: {e}')

    def transcribe(self, audio_array: np.ndarray, language="en"):
        """Transcribe audio, steering decoding with a language-matched prompt.

        Args:
            audio_array: Audio samples (padded to a minimum duration first).
            language: Language code passed through to whisper.cpp.

        Returns:
            str: Space-joined text of all recognized segments.
        """
        prompt = (
            "以下是简体中文普通话的句子。"
            if language == "zh"
            else "The following is an English sentence."
        )

        audio_array = ensure_minimum_audio_duration(audio_array)

        segments = self.whisper.transcribe(
            audio_array, language=language, initial_prompt=prompt, print_progress=False
        )
        return " ".join(segment.text for segment in segments)
|
src/VoiceDialogue/services/speech/asr/utils.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ASR模块的工具函数
|
| 3 |
+
包含音频预处理、格式转换等工具函数
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def ensure_minimum_audio_duration(
    audio_array: np.ndarray, min_duration: float = 1.0, sample_rate: int = 16000
) -> np.ndarray:
    """Pad *audio_array* so it lasts at least *min_duration* seconds.

    Audio shorter than the minimum is extended via padding_silence (which,
    despite its name, appends a faint tone — see its docstring); audio that
    is already long enough is returned unchanged.

    Args:
        audio_array: Input audio samples.
        min_duration: Minimum required duration in seconds (default 1 s).
        sample_rate: Sampling rate in Hz (default 16000).

    Returns:
        The (possibly padded) audio array.
    """
    audio_duration = audio_array.shape[-1] / sample_rate

    if audio_duration < min_duration:
        padding_seconds = min_duration - audio_duration
        audio_array = padding_silence(audio_array, padding_seconds, sample_rate)

    return audio_array


def padding_silence(
    audio_data: np.ndarray, duration_seconds: float, sample_rate: int = 16000
) -> np.ndarray:
    """Append filler audio of *duration_seconds* (+0.1 s margin) to *audio_data*.

    NOTE(review): despite the name, the filler is NOT silence — it is a
    0.5-amplitude 440 Hz sine tone (the original docstring claimed silence).
    Presumably the tone keeps downstream ASR engines stable; confirm before
    replacing it with np.zeros.

    Args:
        audio_data: Original audio samples.
        duration_seconds: Amount of padding to add, in seconds.
        sample_rate: Sampling rate in Hz.

    Returns:
        The original audio with the filler tone appended.
    """
    frequency = 440.0
    # The extra 0.1 s guarantees the minimum-duration check passes despite
    # float rounding in the sample count.
    duration = duration_seconds + 0.1
    # NOTE(review): the time axis inherits the input dtype; an integer dtype
    # would truncate it to whole numbers — this assumes float audio input.
    t = np.linspace(
        0, duration, int(sample_rate * duration), endpoint=False, dtype=audio_data.dtype
    )
    silence = 0.5 * np.sin(2 * np.pi * frequency * t)
    audio_data = np.concatenate([audio_data, silence])
    return audio_data
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def validate_audio_array(audio_array: np.ndarray) -> bool:
    """Report whether *audio_array* is a usable audio buffer.

    A valid buffer is a non-empty numpy array with at most two dimensions
    (mono or multi-channel). None and non-array objects are rejected.

    Args:
        audio_array: Candidate audio buffer.

    Returns:
        bool: True when the buffer is valid.
    """
    # isinstance also rejects None, covering the original explicit check.
    if not isinstance(audio_array, np.ndarray):
        return False
    return audio_array.size > 0 and audio_array.ndim <= 2
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def normalize_audio(audio_array: np.ndarray, target_peak: float = 0.95) -> np.ndarray:
    """Scale *audio_array* so its absolute peak equals *target_peak*.

    Args:
        audio_array: Input audio samples.
        target_peak: Desired peak amplitude (default 0.95).

    Returns:
        The rescaled audio array; an all-zero input is returned untouched.

    Raises:
        ValueError: If the input is not a valid audio array.
    """
    if not validate_audio_array(audio_array):
        raise ValueError("Invalid audio array")

    peak = np.max(np.abs(audio_array))
    if peak == 0:
        # Silence cannot be normalized — return it as-is.
        return audio_array

    scale_factor = target_peak / peak
    return audio_array * scale_factor
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def convert_sample_rate(
    audio_array: np.ndarray,
    source_rate: int,
    target_rate: int
) -> np.ndarray:
    """Resample *audio_array* from *source_rate* to *target_rate*.

    Uses librosa's resampler when available and falls back to simple linear
    interpolation otherwise. Equal rates are a no-op.

    Args:
        audio_array: Input audio samples.
        source_rate: Current sampling rate in Hz.
        target_rate: Desired sampling rate in Hz.

    Returns:
        The resampled audio array.
    """
    if source_rate == target_rate:
        return audio_array

    try:
        import librosa
    except ImportError:
        # librosa unavailable — approximate with linear interpolation.
        ratio = target_rate / source_rate
        resampled_len = int(len(audio_array) * ratio)
        sample_points = np.linspace(0, len(audio_array) - 1, resampled_len)
        return np.interp(sample_points, np.arange(len(audio_array)), audio_array)

    return librosa.resample(audio_array, orig_sr=source_rate, target_sr=target_rate)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def trim_silence(
    audio_array: np.ndarray,
    threshold: float = 0.01,
    sample_rate: int = 16000
) -> np.ndarray:
    """Strip leading and trailing silence from an audio buffer.

    Args:
        audio_array: Input audio samples.
        threshold: Absolute amplitude below which a sample counts as silence.
        sample_rate: Sampling rate in Hz (used for the all-silent fallback).

    Returns:
        The trimmed audio array; an invalid input is returned unchanged,
        and an entirely silent input is truncated to at most 100 ms.
    """
    if not validate_audio_array(audio_array):
        return audio_array

    # Boolean mask of samples loud enough to keep.
    loud_enough = np.abs(audio_array) > threshold

    if not np.any(loud_enough):
        # Entirely silent: keep at most 100 ms of it.
        min_samples = int(0.1 * sample_rate)
        if len(audio_array) > min_samples:
            return audio_array[:min_samples]
        return audio_array

    # First loud sample, and one past the last loud sample.
    first = np.argmax(loud_enough)
    last = len(loud_enough) - np.argmax(loud_enough[::-1])
    return audio_array[first:last]
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def get_audio_duration(audio_array: np.ndarray, sample_rate: int = 16000) -> float:
    """Return the duration of *audio_array* in seconds.

    Args:
        audio_array: Audio samples.
        sample_rate: Sampling rate in Hz.

    Returns:
        Duration in seconds, or 0.0 for an invalid buffer.
    """
    if not validate_audio_array(audio_array):
        return 0.0
    # The last axis holds the samples (works for mono and multi-channel).
    sample_count = audio_array.shape[-1]
    return sample_count / sample_rate
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def create_silence(duration_seconds: float, sample_rate: int = 16000) -> np.ndarray:
    """Build a zero-filled float32 buffer of the requested duration.

    Args:
        duration_seconds: Desired silence length in seconds.
        sample_rate: Sampling rate in Hz.

    Returns:
        A float32 numpy array of zeros.
    """
    sample_count = int(duration_seconds * sample_rate)
    return np.zeros(sample_count, dtype=np.float32)
|
src/VoiceDialogue/services/speech/asr_service.py
CHANGED
|
@@ -1,168 +1,14 @@
|
|
| 1 |
-
import re
|
| 2 |
import time
|
| 3 |
import typing
|
| 4 |
from queue import Queue
|
| 5 |
|
| 6 |
-
import librosa
|
| 7 |
import numpy as np
|
| 8 |
-
from funasr_onnx import SeacoParaformer, CT_Transformer
|
| 9 |
-
from pywhispercpp.model import Model
|
| 10 |
|
| 11 |
-
from config import paths
|
| 12 |
from models.voice_task import VoiceTask
|
| 13 |
from services.core.base import BaseThread
|
| 14 |
from services.core.constants import user_still_speaking_event, voice_state_manager, dropped_audio_cache
|
| 15 |
from utils.cache import LRUCacheDict
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def ensure_minimum_audio_duration(
|
| 19 |
-
audio_array: np.ndarray, min_duration: float = 1.0, sample_rate: int = 16000
|
| 20 |
-
) -> np.ndarray:
|
| 21 |
-
"""
|
| 22 |
-
确保音频数组满足最小时长要求,如果不足则用静音填充
|
| 23 |
-
|
| 24 |
-
Args:
|
| 25 |
-
audio_array: 输入音频数组
|
| 26 |
-
min_duration: 最小时长要求(秒),默认1秒
|
| 27 |
-
sample_rate: 采样率,默认16000Hz
|
| 28 |
-
|
| 29 |
-
Returns:
|
| 30 |
-
处理后的音频数组
|
| 31 |
-
"""
|
| 32 |
-
audio_duration = audio_array.shape[-1] / sample_rate
|
| 33 |
-
|
| 34 |
-
if audio_duration < min_duration:
|
| 35 |
-
padding_seconds = min_duration - audio_duration
|
| 36 |
-
audio_array = padding_silence(audio_array, padding_seconds, sample_rate)
|
| 37 |
-
|
| 38 |
-
return audio_array
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def padding_silence(
|
| 42 |
-
audio_data: np.ndarray, duration_seconds: float, sample_rate: int = 16000
|
| 43 |
-
) -> np.ndarray:
|
| 44 |
-
"""
|
| 45 |
-
为音频数据添加静音填充
|
| 46 |
-
|
| 47 |
-
Args:
|
| 48 |
-
audio_data: 原始音频数据
|
| 49 |
-
duration_seconds: 需要填充的时长(秒)
|
| 50 |
-
sample_rate: 采样率
|
| 51 |
-
|
| 52 |
-
Returns:
|
| 53 |
-
填充后的音频数据
|
| 54 |
-
"""
|
| 55 |
-
frequency = 440.0
|
| 56 |
-
duration = duration_seconds + 0.1
|
| 57 |
-
t = np.linspace(
|
| 58 |
-
0, duration, int(sample_rate * duration), endpoint=False, dtype=audio_data.dtype
|
| 59 |
-
)
|
| 60 |
-
silence = 0.5 * np.sin(2 * np.pi * frequency * t)
|
| 61 |
-
audio_data = np.concatenate([audio_data, silence])
|
| 62 |
-
return audio_data
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
class WhisperCppClient:
|
| 66 |
-
"""Whisper C++ API客户端"""
|
| 67 |
-
|
| 68 |
-
def __init__(self, model: typing.Literal["medium", "large"] = "medium"):
|
| 69 |
-
if model == "medium":
|
| 70 |
-
model = "medium-q5_0"
|
| 71 |
-
else:
|
| 72 |
-
model = "large-v3-turbo-q5_0"
|
| 73 |
-
|
| 74 |
-
models_dir = paths.MODELS_PATH / "asr"
|
| 75 |
-
self.whisper = Model(model=model, models_dir=models_dir)
|
| 76 |
-
|
| 77 |
-
def transcribe(self, audio_array: np.ndarray, language="en"):
|
| 78 |
-
if language == "zh":
|
| 79 |
-
prompt = "以下是简体中文普通话的句子。"
|
| 80 |
-
else:
|
| 81 |
-
prompt = "The following is an English sentence."
|
| 82 |
-
|
| 83 |
-
audio_array = ensure_minimum_audio_duration(audio_array)
|
| 84 |
-
|
| 85 |
-
# print('............... language:', language)
|
| 86 |
-
segments = self.whisper.transcribe(
|
| 87 |
-
audio_array, language=language, initial_prompt=prompt, print_progress=False
|
| 88 |
-
)
|
| 89 |
-
text = []
|
| 90 |
-
for segment in segments:
|
| 91 |
-
content = segment.text
|
| 92 |
-
# if not content.endswith(()):
|
| 93 |
-
# content += ','
|
| 94 |
-
text.append(content)
|
| 95 |
-
text = " ".join(text)
|
| 96 |
-
return text
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
class FunASRClient:
|
| 100 |
-
"""FunASR API客户端"""
|
| 101 |
-
|
| 102 |
-
def __init__(self):
|
| 103 |
-
# 设置模型缓存目录
|
| 104 |
-
models_dir = paths.MODELS_PATH / "asr"
|
| 105 |
-
asr_model_path = models_dir / "speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
| 106 |
-
punc_model_path = models_dir / "punc_ct-transformer_cn-en-common-vocab471067-large"
|
| 107 |
-
self.funasr_model = SeacoParaformer(asr_model_path, quantize=True)
|
| 108 |
-
self.punc_model = CT_Transformer(punc_model_path, quantize=True)
|
| 109 |
-
|
| 110 |
-
def _fix_spaced_uppercase(self, text: str) -> str:
|
| 111 |
-
"""
|
| 112 |
-
修复类似 " G N O M E " 这样的大写字母间有空格的字符串,将其替换为 "GNOME"
|
| 113 |
-
"""
|
| 114 |
-
# 匹配大写字母之间的空格模式,至少2个大写字母
|
| 115 |
-
pattern = r'([A-Z])\s+([A-Z](?:\s+[A-Z])*)'
|
| 116 |
-
|
| 117 |
-
def replace_func(match):
|
| 118 |
-
# 移除所有空格
|
| 119 |
-
return match.group(0).replace(' ', '')
|
| 120 |
-
|
| 121 |
-
return re.sub(pattern, replace_func, text)
|
| 122 |
-
|
| 123 |
-
def transcribe(self, audio_array: np.ndarray, language="auto"):
|
| 124 |
-
audio_array = ensure_minimum_audio_duration(audio_array)
|
| 125 |
-
|
| 126 |
-
segments = self.funasr_model(wav_content=audio_array, hotwords='')
|
| 127 |
-
|
| 128 |
-
transcibed_texts = []
|
| 129 |
-
for segment in segments:
|
| 130 |
-
content = segment.get("preds", "")
|
| 131 |
-
content, _ = self.punc_model(content)
|
| 132 |
-
content = self._fix_spaced_uppercase(content)
|
| 133 |
-
transcibed_texts.append(content)
|
| 134 |
-
return " ".join(transcibed_texts)
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
class UnifiedASRClient:
|
| 138 |
-
"""统一的语音识别客户端,根据语言自动选择FunASR或Whisper"""
|
| 139 |
-
|
| 140 |
-
def __init__(self, language: typing.Literal["auto", "zh", "en"] = "zh"):
|
| 141 |
-
self.language = language
|
| 142 |
-
|
| 143 |
-
if language == "zh":
|
| 144 |
-
self.client = FunASRClient()
|
| 145 |
-
else:
|
| 146 |
-
self.client = WhisperCppClient()
|
| 147 |
-
|
| 148 |
-
def warmup(self):
|
| 149 |
-
"""预热模型"""
|
| 150 |
-
print('[INFO] 预热语音识别模型...')
|
| 151 |
-
try:
|
| 152 |
-
warmup_audiofile = paths.RESOURCES_PATH / 'audio' / 'jfk.flac'
|
| 153 |
-
if warmup_audiofile.exists():
|
| 154 |
-
data, sr = librosa.load(warmup_audiofile, sr=16000, mono=True)
|
| 155 |
-
self.client.transcribe(data, language=self.language)
|
| 156 |
-
else:
|
| 157 |
-
# 创建测试音频
|
| 158 |
-
test_audio = np.random.randn(16000).astype(np.float32) * 0.1 # 1秒的噪声
|
| 159 |
-
self.client.transcribe(test_audio, language=self.language)
|
| 160 |
-
print('[INFO] ASR模型预热完成')
|
| 161 |
-
except Exception as e:
|
| 162 |
-
print(f'[WARNING] ASR模型预热失败: {e}')
|
| 163 |
-
|
| 164 |
-
def transcribe(self, audio_array: np.ndarray) -> str:
|
| 165 |
-
return self.client.transcribe(audio_array, language=self.language)
|
| 166 |
|
| 167 |
|
| 168 |
class ASRWorker(BaseThread):
|
|
@@ -179,7 +25,8 @@ class ASRWorker(BaseThread):
|
|
| 179 |
self.cached_user_questions = LRUCacheDict(maxsize=10)
|
| 180 |
|
| 181 |
def run(self):
|
| 182 |
-
self.client =
|
|
|
|
| 183 |
self.client.warmup()
|
| 184 |
|
| 185 |
self.is_ready = True
|
|
|
|
|
|
|
| 1 |
import time
|
| 2 |
import typing
|
| 3 |
from queue import Queue
|
| 4 |
|
|
|
|
| 5 |
import numpy as np
|
|
|
|
|
|
|
| 6 |
|
|
|
|
| 7 |
from models.voice_task import VoiceTask
|
| 8 |
from services.core.base import BaseThread
|
| 9 |
from services.core.constants import user_still_speaking_event, voice_state_manager, dropped_audio_cache
|
| 10 |
from utils.cache import LRUCacheDict
|
| 11 |
+
from .asr import asr_manager
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class ASRWorker(BaseThread):
|
|
|
|
| 25 |
self.cached_user_questions = LRUCacheDict(maxsize=10)
|
| 26 |
|
| 27 |
def run(self):
|
| 28 |
+
self.client = asr_manager.create_asr(self.language)
|
| 29 |
+
self.client.setup()
|
| 30 |
self.client.warmup()
|
| 31 |
|
| 32 |
self.is_ready = True
|