hzeng412 Claude Fable 5 commited on
Commit
417ffa1
·
1 Parent(s): 3c3f610

ASR: add Qwen3-ASR engine, default zh/en to it on this branch

Browse files

- New QwenASRClient (Qwen/Qwen3-ASR-1.7B via qwen-asr, transformers backend, MPS bf16)
- Language mapping zh/en -> qwen; VOICE_DIALOGUE_ASR=legacy restores funasr/whisper
- ASRService now passes session language to transcribe()

Requires: uv pip install qwen-asr (upgrades transformers to 4.57)

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

src/voice_dialogue/asr/manager.py CHANGED
@@ -1,5 +1,6 @@
1
  import importlib.util
2
  import inspect
 
3
  import re
4
  from dataclasses import dataclass
5
  from typing import Dict, Type, List, Literal, Optional
@@ -92,11 +93,17 @@ class ASRManager:
92
 
93
  def __init__(self):
94
  self._asr_instances: Dict[str, ASRInterface] = {}
95
- self._language_to_asr_mapping = {
96
- 'zh': 'funasr', # 中文优先使用FunASR
97
- 'en': 'whisper', # 英文优先使用Whisper
98
- # 'auto': 'whisper', # 自动检测默认使用Whisper
99
- }
 
 
 
 
 
 
100
 
101
  def create_asr(self, language: Literal['auto', 'zh', 'en']) -> ASRInterface:
102
  """
 
1
  import importlib.util
2
  import inspect
3
+ import os
4
  import re
5
  from dataclasses import dataclass
6
  from typing import Dict, Type, List, Literal, Optional
 
93
 
94
  def __init__(self):
95
  self._asr_instances: Dict[str, ASRInterface] = {}
96
+ # 本分支默认使用 Qwen3-ASR;设置 VOICE_DIALOGUE_ASR=legacy 可切回原引擎做 A/B 对比
97
+ if os.environ.get('VOICE_DIALOGUE_ASR', 'qwen') == 'legacy':
98
+ self._language_to_asr_mapping = {
99
+ 'zh': 'funasr', # 中文优先使用FunASR
100
+ 'en': 'whisper', # 英文优先使用Whisper
101
+ }
102
+ else:
103
+ self._language_to_asr_mapping = {
104
+ 'zh': 'qwen',
105
+ 'en': 'qwen',
106
+ }
107
 
108
  def create_asr(self, language: Literal['auto', 'zh', 'en']) -> ASRInterface:
109
  """
src/voice_dialogue/asr/models/__init__.py CHANGED
@@ -19,3 +19,12 @@ except ImportError as e:
19
  from voice_dialogue.utils.logger import logger
20
 
21
  logger.warning(f"Failed to import some Whisper implementations: {e}")
 
 
 
 
 
 
 
 
 
 
19
  from voice_dialogue.utils.logger import logger
20
 
21
  logger.warning(f"Failed to import some Whisper implementations: {e}")
22
+
23
+ try:
24
+ from .qwen import QwenASRClient
25
+
26
+ __all__.append('QwenASRClient')
27
+ except ImportError as e:
28
+ from voice_dialogue.utils.logger import logger
29
+
30
+ logger.warning(f"Failed to import some Qwen ASR implementations: {e}")
src/voice_dialogue/asr/models/qwen.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import typing
3
+
4
+ import numpy as np
5
+ import torch
6
+ from qwen_asr import Qwen3ASRModel
7
+
8
+ from voice_dialogue.asr.manager import asr_tables
9
+ from voice_dialogue.asr.models.base import ASRInterface
10
+ from voice_dialogue.asr.utils import ensure_minimum_audio_duration
11
+ from voice_dialogue.utils.logger import logger
12
+
13
+ # Qwen3-ASR 的 language 参数使用语言全名
14
+ LANGUAGE_NAME_MAPPING = {
15
+ 'zh': 'Chinese',
16
+ 'en': 'English',
17
+ }
18
+
19
+ DEFAULT_MODEL = os.environ.get('QWEN_ASR_MODEL', 'Qwen/Qwen3-ASR-1.7B')
20
+
21
+ TARGET_SAMPLE_RATE = 16000
22
+
23
+
24
+ @asr_tables.register('asr_classes', 'qwen')
25
+ class QwenASRClient(ASRInterface):
26
+ """Qwen3-ASR 客户端(transformers 后端,macOS 上使用 MPS 加速)"""
27
+ supported_langs = ['zh', 'en']
28
+
29
+ def __init__(self):
30
+ super().__init__()
31
+ self.model: typing.Optional[Qwen3ASRModel] = None
32
+
33
+ def setup(self, **kwargs) -> None:
34
+ model_name = kwargs.get('model', DEFAULT_MODEL)
35
+
36
+ if torch.backends.mps.is_available():
37
+ device_map, dtype = 'mps', torch.bfloat16
38
+ elif torch.cuda.is_available():
39
+ device_map, dtype = 'cuda:0', torch.bfloat16
40
+ else:
41
+ device_map, dtype = 'cpu', torch.float32
42
+
43
+ logger.info(f'[INFO] Loading Qwen3-ASR model: {model_name} (device={device_map}, dtype={dtype})')
44
+ self.model = Qwen3ASRModel.from_pretrained(
45
+ model_name,
46
+ dtype=dtype,
47
+ device_map=device_map,
48
+ max_inference_batch_size=1,
49
+ max_new_tokens=256,
50
+ )
51
+
52
+ def warmup(self) -> None:
53
+ logger.info('[INFO] Warming up Qwen3-ASR model...')
54
+ try:
55
+ self.transcribe(self.warmup_audiodata)
56
+ logger.info('[INFO] Qwen3-ASR model warmed up.')
57
+ except Exception as e:
58
+ logger.warning(f'[WARNING] Qwen3-ASR model warmup failed: {e}')
59
+
60
+ def transcribe(self, audio_array: np.ndarray, language: str = None) -> str:
61
+ audio_array = ensure_minimum_audio_duration(audio_array)
62
+
63
+ # 未指定语言时交给模型自动检测(Qwen3-ASR 自带语种识别)
64
+ qwen_language = LANGUAGE_NAME_MAPPING.get(language)
65
+
66
+ results = self.model.transcribe(
67
+ audio=(audio_array, TARGET_SAMPLE_RATE),
68
+ language=qwen_language,
69
+ )
70
+ return ' '.join(result.text for result in results).strip()
src/voice_dialogue/services/asr_service.py CHANGED
@@ -42,7 +42,7 @@ class ASRService(BaseThread, PerformanceLogMixin):
42
  voice_task.whisper_start_time = time.time()
43
 
44
  user_voice: np.array = voice_task.user_voice
45
- transcribed_text = self.client.transcribe(user_voice)
46
  if not transcribed_text.strip():
47
  voice_state_manager.reset_task_id()
48
  continue
 
42
  voice_task.whisper_start_time = time.time()
43
 
44
  user_voice: np.array = voice_task.user_voice
45
+ transcribed_text = self.client.transcribe(user_voice, language=self.language)
46
  if not transcribed_text.strip():
47
  voice_state_manager.reset_task_id()
48
  continue