liumaolin commited on
Commit
4e2e3d8
·
1 Parent(s): d41c6db

Integrate `SileroVAD` into `SpeechMonitor` for optional voice activity detection. Add `_detect_speech()` method and update queue handling logic. Implement `SileroVAD` as a singleton for efficient model management.

Browse files
src/voice_dialogue/services/audio/vad.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import numpy as np
4
+ import torch
5
+ from silero_vad import load_silero_vad
6
+
7
+ from voice_dialogue.utils.logger import logger
8
+
9
+
10
+ class SileroVAD:
11
+ """
12
+ 一个线程安全的、基于单例模式的Silero VAD模型包装器。
13
+
14
+ 该类在首次实例化时加载 Silero VAD 模型,并提供一个方法来检测音频帧中的语音活动。
15
+ 设计为单例可以避免在应用中重复加载这个较为消耗资源模型。
16
+ """
17
+ _instance: Optional['SileroVAD'] = None
18
+ _model = None
19
+
20
+ def __new__(cls, *args, **kwargs):
21
+ if cls._instance is None:
22
+ cls._instance = super().__new__(cls)
23
+ return cls._instance
24
+
25
+ def __init__(self, threshold: float = 0.7):
26
+ """
27
+ 初始化 Silero VAD 模型。模型只会在首次创建实例时加载。
28
+
29
+ Args:
30
+ threshold (float): 用于判定语音活动的置信度阈值 (范围 0.0 到 1.0)。
31
+ """
32
+ if self._model is None:
33
+ logger.info("正在首次初始化 Silero VAD 模型...")
34
+ try:
35
+ self._model = load_silero_vad()
36
+ self._model.reset_states()
37
+ self.threshold = threshold
38
+ logger.info("Silero VAD 模型初始化成功。")
39
+ except Exception as e:
40
+ logger.error(f"初始化 Silero VAD 模型失败: {e}", exc_info=True)
41
+ # 如果失败,重置实例,以便下次可以重试
42
+ SileroVAD._instance = None
43
+ raise
44
+
45
+ def is_voice_active(self, audio_frame: np.ndarray, sample_rate: int = 16000) -> bool:
46
+ """
47
+ 检测给定的音频帧中是否包含语音活动。
48
+
49
+ Args:
50
+ audio_frame (np.ndarray): 一个一维的 float32 numpy 数组,代表音频数据。
51
+ 其数值范围应为 [-1.0, 1.0]。
52
+ 对于16kHz采样率,帧大小必须为 [512, 1024, 1536] 之一。
53
+ sample_rate (int): 音频的采样率,必须是 8000 或 16000。
54
+
55
+ Returns:
56
+ bool: 如果检测到语音活动,返回 True,否则返回 False。
57
+ """
58
+ if self._model is None:
59
+ logger.error("VAD 模型未初始化,无法执行检测。")
60
+ return False
61
+
62
+ if not isinstance(audio_frame, np.ndarray):
63
+ logger.warning("VAD 检测的输入必须是一个 numpy 数组。")
64
+ return False
65
+
66
+ # Silero VAD 模型要求 float32 类型
67
+ if audio_frame.dtype != np.float32:
68
+ audio_frame = audio_frame.astype(np.float32)
69
+
70
+ window_size = 512 if sample_rate == 16000 else 256
71
+
72
+ audio_tensor = torch.from_numpy(audio_frame)
73
+
74
+ try:
75
+ probs = []
76
+ for i in range(0, len(audio_tensor), window_size):
77
+ audio_slice = audio_tensor[i:i + window_size]
78
+ if len(audio_slice) < window_size:
79
+ audio_slice = audio_tensor[-window_size:]
80
+
81
+ # 模型会返回一个包含语音可能性的张量
82
+ prob = self._model(audio_slice, sample_rate).item()
83
+ probs.append(prob)
84
+
85
+ return np.max(probs) > self.threshold
86
+ except Exception as e:
87
+ logger.error(f"VAD 检测过程中发生错误: {e}")
88
+ return False
src/voice_dialogue/services/speech/monitor.py CHANGED
@@ -19,6 +19,7 @@ from voice_dialogue.core.constants import (
19
  )
20
  from voice_dialogue.core.enums import AudioState
21
  from voice_dialogue.models.voice_task import VoiceTask
 
22
  from voice_dialogue.utils.logger import logger
23
 
24
 
@@ -49,6 +50,7 @@ class SpeechStateMonitor(BaseThread):
49
  self, group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None,
50
  audio_frame_queue: Queue,
51
  user_voice_queue: Queue,
 
52
  ):
53
  """
54
  初始化语音状态监控器
@@ -56,12 +58,18 @@ class SpeechStateMonitor(BaseThread):
56
  Args:
57
  audio_frame_queue: 音频帧队列
58
  user_voice_queue: 用户语音队列
 
59
  """
60
  super().__init__(group, target, name, args, kwargs, daemon=daemon)
61
 
62
  self.audio_frame_queue = audio_frame_queue
63
  self.user_voice_queue = user_voice_queue
64
  self.sample_rate = 16000
 
 
 
 
 
65
 
66
  # 配置参数
67
  self.config = SpeechMonitorConfig()
@@ -104,11 +112,19 @@ class SpeechStateMonitor(BaseThread):
104
  """将 int16 格式的音频字节数据转换为 [-1.0, 1.0] 范围的 numpy 浮点数组。"""
105
  return np.frombuffer(data, dtype=np.int16).astype(np.float32) / np.iinfo(np.int16).max
106
 
 
 
 
107
  def _get_audio_frame_from_queue(self):
108
  """从队列获取音频帧"""
109
  try:
110
- data, is_voice_active = self.audio_frame_queue.get(block=False, timeout=self.config.QUEUE_TIMEOUT)
111
- audio_frame = self._normalize_audio_frame(data)
 
 
 
 
 
112
  return audio_frame, is_voice_active
113
  except Empty:
114
  return None, None
 
19
  )
20
  from voice_dialogue.core.enums import AudioState
21
  from voice_dialogue.models.voice_task import VoiceTask
22
+ from voice_dialogue.services.audio.vad import SileroVAD
23
  from voice_dialogue.utils.logger import logger
24
 
25
 
 
50
  self, group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None,
51
  audio_frame_queue: Queue,
52
  user_voice_queue: Queue,
53
+ enable_vad: bool = False,
54
  ):
55
  """
56
  初始化语音状态监控器
 
58
  Args:
59
  audio_frame_queue: 音频帧队列
60
  user_voice_queue: 用户语音队列
61
+ enable_vad: 是否启用语音活动检测
62
  """
63
  super().__init__(group, target, name, args, kwargs, daemon=daemon)
64
 
65
  self.audio_frame_queue = audio_frame_queue
66
  self.user_voice_queue = user_voice_queue
67
  self.sample_rate = 16000
68
+ self._enable_vad = enable_vad
69
+
70
+ self._vad_instance = None
71
+ if self._enable_vad:
72
+ self._vad_instance = SileroVAD()
73
 
74
  # 配置参数
75
  self.config = SpeechMonitorConfig()
 
112
  """将 int16 格式的音频字节数据转换为 [-1.0, 1.0] 范围的 numpy 浮点数组。"""
113
  return np.frombuffer(data, dtype=np.int16).astype(np.float32) / np.iinfo(np.int16).max
114
 
115
+ def _detect_speech(self, audio_frame: np.ndarray) -> bool:
116
+ return self._vad_instance.is_voice_active(audio_frame, self.sample_rate)
117
+
118
  def _get_audio_frame_from_queue(self):
119
  """从队列获取音频帧"""
120
  try:
121
+ if self._enable_vad:
122
+ data = self.audio_frame_queue.get(block=False, timeout=self.config.QUEUE_TIMEOUT)
123
+ audio_frame = self._normalize_audio_frame(data)
124
+ is_voice_active = self._detect_speech(audio_frame)
125
+ else:
126
+ data, is_voice_active = self.audio_frame_queue.get(block=False, timeout=self.config.QUEUE_TIMEOUT)
127
+ audio_frame = self._normalize_audio_frame(data)
128
  return audio_frame, is_voice_active
129
  except Empty:
130
  return None, None