Sbboss committed
Commit 2bb7b65 · 1 Parent(s): ad3c9dd

applied VAD

requirements.txt CHANGED
@@ -7,6 +7,7 @@ httpx>=0.27.0
 structlog>=24.0.0
 azure-identity>=1.15.0
 azure-ai-projects==1.0.0b10
+silero-vad-lite>=0.2.0
 streamlit>=1.35.0
 pytest>=8.0.0
 pytest-asyncio>=0.23.0
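
The new dependency wraps Silero VAD as an ONNX model. A minimal smoke test — a sketch only; the constructor and process() call mirror how the new src/app/services/vad.py drives the library, and the expected output is an assumption:

# pip install "silero-vad-lite>=0.2.0"
from silero_vad_lite import SileroVAD

vad = SileroVAD(sample_rate=16000)   # Silero VAD model (ONNX), 16 kHz mono
frame = [0.0] * 512                  # one 512-sample (32 ms) frame of silence, float32 in [-1, 1]
prob = vad.process(frame)            # per-frame speech probability
print(f"speech probability on silence: {prob:.3f}")  # expected to be near 0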
src/app/api/routes.py CHANGED
@@ -14,6 +14,7 @@ from ..core.errors import SpeechError, ValidationError
 from ..core.logging import get_logger
 from ..services.pipeline import VoicePipeline
 from ..services.stt import SpeechToTextService
+from ..services.vad import SileroVADStream
 from ..utils.audio import encode_base64
 
 router = APIRouter()
@@ -109,6 +110,91 @@ async def voice_stream(websocket: WebSocket) -> None:
     frames_sent: int | None = None
     avg_rms: float | None = None
     llm_provider: str | None = None
+    vad_stream: SileroVADStream | None = None
+    segment_processing = False
+
+    async def _finalize_segment() -> None:
+        nonlocal stt_session, segment_processing, vad_stream
+        if stt_session is None:
+            raise ValidationError(
+                code="stt_not_started", message="STT session not started."
+            )
+        if not buffer:
+            return
+        segment_processing = True
+        try:
+            stt_result = await anyio.to_thread.run_sync(stt_session.finish)
+        except SpeechError as exc:
+            if exc.code in {"stt_empty", "stt_no_match"}:
+                try:
+                    stt_result = await anyio.to_thread.run_sync(
+                        SpeechToTextService().transcribe,
+                        bytes(buffer),
+                        None,
+                        content_type,
+                    )
+                except SpeechError as exc_fallback:
+                    if exc_fallback.code in {"stt_empty", "stt_no_match"}:
+                        await websocket.send_json(
+                            {
+                                "event": "result",
+                                "transcript": "",
+                                "reply_text": NO_MATCH_REPLY,
+                                "audio_format": "wav",
+                                "reply_audio_base64": None,
+                                "timings_ms": {"stt": 0, "llm": 0, "tts": 0, "total": 0},
+                            }
+                        )
+                        buffer.clear()
+                        stt_session = SpeechToTextService().start_streaming(
+                            end_silence_ms=1200, initial_silence_ms=5000
+                        )
+                        vad_stream = SileroVADStream()
+                        return
+                    raise
+            else:
+                raise
+
+        await websocket.send_json(
+            {"event": "transcript", "transcript": stt_result.transcript}
+        )
+        pipeline = VoicePipeline()
+        result = await pipeline.run(
+            audio_bytes=bytes(buffer),
+            filename=None,
+            content_type=content_type,
+            prompt=prompt,
+            return_audio=return_audio,
+            transcript_override=stt_result.transcript,
+            language_override=stt_result.language,
+            llm_provider=llm_provider,
+        )
+        response_body = {
+            "event": "result",
+            "transcript": result.transcript,
+            "reply_text": result.reply_text,
+            "audio_format": "wav",
+            "reply_audio_base64": None,
+            "timings_ms": result.timings_ms,
+        }
+        log.info(
+            "voice_stream_complete",
+            bytes_received=len(buffer),
+            timings_ms=result.timings_ms,
+            return_audio=return_audio,
+            content_type=content_type,
+            frames_sent=frames_sent,
+            avg_rms=avg_rms,
+        )
+        await websocket.send_json(response_body)
+        if result.reply_audio and return_audio:
+            await websocket.send_bytes(result.reply_audio)
+        buffer.clear()
+        stt_session = SpeechToTextService().start_streaming(
+            end_silence_ms=1200, initial_silence_ms=5000
+        )
+        vad_stream = SileroVADStream()
+        segment_processing = False
 
     try:
         while True:
@@ -123,6 +209,10 @@ async def voice_stream(websocket: WebSocket) -> None:
             if stt_session is not None:
                 stt_session.write(chunk)
             buffer.extend(chunk)
+            if vad_stream is not None and not segment_processing:
+                decision = vad_stream.update(chunk)
+                if decision.speech_ended:
+                    await _finalize_segment()
             if len(buffer) > MAX_FILE_SIZE_BYTES:
                 raise ValidationError(
                     code="file_too_large", message="Stream exceeds 15MB limit."
@@ -156,6 +246,7 @@ async def voice_stream(websocket: WebSocket) -> None:
                 stt_session = SpeechToTextService().start_streaming(
                     end_silence_ms=1200, initial_silence_ms=5000
                 )
+                vad_stream = SileroVADStream()
                 continue
 
             if event == "stop":
@@ -183,81 +274,12 @@ async def voice_stream(websocket: WebSocket) -> None:
                         "LLM provider must be 'foundry_agent' or 'azure_openai'."
                     ),
                 )
-                try:
-                    stt_result = await anyio.to_thread.run_sync(
-                        stt_session.finish
-                    )
-                except SpeechError as exc:
-                    if exc.code in {"stt_empty", "stt_no_match"}:
-                        try:
-                            stt_result = await anyio.to_thread.run_sync(
-                                SpeechToTextService().transcribe,
-                                bytes(buffer),
-                                None,
-                                content_type,
-                            )
-                        except SpeechError as exc_fallback:
-                            if exc_fallback.code in {"stt_empty", "stt_no_match"}:
-                                await websocket.send_json(
-                                    {
-                                        "event": "result",
-                                        "transcript": "",
-                                        "reply_text": NO_MATCH_REPLY,
-                                        "audio_format": "wav",
-                                        "reply_audio_base64": None,
-                                        "timings_ms": {"stt": 0, "llm": 0, "tts": 0, "total": 0},
-                                    }
-                                )
-                                buffer.clear()
-                                break
-                            raise
-                    else:
-                        raise
-                await websocket.send_json(
-                    {"event": "transcript", "transcript": stt_result.transcript}
-                )
-                pipeline = VoicePipeline()
-                result = await pipeline.run(
-                    audio_bytes=bytes(buffer),
-                    filename=None,
-                    content_type=content_type,
-                    prompt=prompt,
-                    return_audio=return_audio,
-                    transcript_override=stt_result.transcript,
-                    language_override=stt_result.language,
-                    llm_provider=llm_provider,
-                )
-                response_body = {
-                    "event": "result",
-                    "transcript": result.transcript,
-                    "reply_text": result.reply_text,
-                    "audio_format": "wav",
-                    "reply_audio_base64": None,
-                    "timings_ms": result.timings_ms,
-                }
-                log.info(
-                    "voice_stream_complete",
-                    bytes_received=len(buffer),
-                    timings_ms=result.timings_ms,
-                    return_audio=return_audio,
-                    content_type=content_type,
-                    frames_sent=frames_sent,
-                    avg_rms=avg_rms,
-                )
-                await websocket.send_json(response_body)
-                if result.reply_audio and return_audio:
-                    await websocket.send_bytes(result.reply_audio)
-                buffer.clear()
+                await _finalize_segment()
                 break
 
             if event == "segment_end":
                 if not buffer:
                     continue
-                if stt_session is None:
-                    raise ValidationError(
-                        code="stt_not_started",
-                        message="STT session not started.",
-                    )
                 prompt = payload.get("prompt", prompt)
                 return_audio = payload.get("return_audio", return_audio)
                 llm_provider = payload.get("llm_provider", llm_provider)
@@ -273,77 +295,21 @@ async def voice_stream(websocket: WebSocket) -> None:
                         "LLM provider must be 'foundry_agent' or 'azure_openai'."
                     ),
                 )
-                try:
-                    stt_result = await anyio.to_thread.run_sync(
-                        stt_session.finish
-                    )
-                except SpeechError as exc:
-                    if exc.code in {"stt_empty", "stt_no_match"}:
-                        try:
-                            stt_result = await anyio.to_thread.run_sync(
-                                SpeechToTextService().transcribe,
-                                bytes(buffer),
-                                None,
-                                content_type,
-                            )
-                        except SpeechError as exc_fallback:
-                            if exc_fallback.code in {"stt_empty", "stt_no_match"}:
-                                await websocket.send_json(
-                                    {
-                                        "event": "result",
-                                        "transcript": "",
-                                        "reply_text": NO_MATCH_REPLY,
-                                        "audio_format": "wav",
-                                        "reply_audio_base64": None,
-                                        "timings_ms": {"stt": 0, "llm": 0, "tts": 0, "total": 0},
-                                    }
-                                )
-                                buffer.clear()
-                                stt_session = SpeechToTextService().start_streaming(
-                                    end_silence_ms=1200, initial_silence_ms=5000
-                                )
-                                continue
-                            raise
-                    else:
-                        raise
-                await websocket.send_json(
-                    {"event": "transcript", "transcript": stt_result.transcript}
-                )
-                pipeline = VoicePipeline()
-                result = await pipeline.run(
-                    audio_bytes=bytes(buffer),
-                    filename=None,
-                    content_type=content_type,
-                    prompt=prompt,
-                    return_audio=return_audio,
-                    transcript_override=stt_result.transcript,
-                    language_override=stt_result.language,
-                    llm_provider=llm_provider,
-                )
-                response_body = {
-                    "event": "result",
-                    "transcript": result.transcript,
-                    "reply_text": result.reply_text,
-                    "audio_format": "wav",
-                    "reply_audio_base64": None,
-                    "timings_ms": result.timings_ms,
-                }
-                log.info(
-                    "voice_stream_complete",
-                    bytes_received=len(buffer),
-                    timings_ms=result.timings_ms,
-                    return_audio=return_audio,
-                    content_type=content_type,
-                    frames_sent=frames_sent,
-                    avg_rms=avg_rms,
-                )
-                await websocket.send_json(response_body)
-                if result.reply_audio and return_audio:
-                    await websocket.send_bytes(result.reply_audio)
-                buffer.clear()
-                stt_session = SpeechToTextService().start_streaming(
-                    end_silence_ms=1200, initial_silence_ms=5000
-                )
+                if vad_stream is not None and not vad_stream.has_speech():
+                    await websocket.send_json(
+                        {
+                            "event": "result",
+                            "transcript": "",
+                            "reply_text": NO_MATCH_REPLY,
+                            "audio_format": "wav",
+                            "reply_audio_base64": None,
+                            "timings_ms": {"stt": 0, "llm": 0, "tts": 0, "total": 0},
+                        }
+                    )
+                    buffer.clear()
+                    vad_stream.reset()
+                    continue
+                await _finalize_segment()
                 continue
 
             raise ValidationError(
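
With this change the server finalizes a segment on its own once the VAD reports end of speech, so the client no longer has to send segment_end (see the ui/streamlit_app.py hunks below). A hypothetical client sketch of the resulting protocol — the segment_end/stop events and the transcript/result replies come from the diff, while the URL, the "start" event, and the use of the websockets package are assumptions:

import asyncio
import json

import websockets  # assumed client library; any WebSocket client works


async def talk(pcm_chunks: list[bytes]) -> None:
    # URL is a placeholder; the actual route is registered on `router` in routes.py.
    async with websockets.connect("ws://localhost:8000/voice/stream") as ws:
        await ws.send(json.dumps({"event": "start"}))  # assumed session-start event
        for chunk in pcm_chunks:  # 16 kHz mono int16 PCM
            await ws.send(chunk)  # server feeds each chunk to STT and the VAD
        # The server's VAD can end the segment itself; an explicit segment_end
        # remains as the client-driven fallback.
        await ws.send(json.dumps({"event": "segment_end", "prompt": "Answer briefly."}))
        while True:
            msg = await ws.recv()
            if isinstance(msg, bytes):  # optional TTS reply audio
                continue
            data = json.loads(msg)
            if data.get("event") == "transcript":
                print("user:", data["transcript"])
            elif data.get("event") == "result":
                print("assistant:", data["reply_text"])
                break


asyncio.run(talk([]))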
src/app/services/vad.py ADDED
@@ -0,0 +1,109 @@
+"""Voice activity detection using Silero VAD (ONNX)."""
+from __future__ import annotations
+
+from array import array
+from dataclasses import dataclass
+
+from silero_vad_lite import SileroVAD
+
+
+@dataclass
+class VADDecision:
+    speech_started: bool
+    speech_ended: bool
+    speech_ms: int
+    silence_ms: int
+
+
+class SileroVADStream:
+    """Streaming VAD state machine for 16kHz mono PCM."""
+
+    def __init__(
+        self,
+        sample_rate: int = 16000,
+        speech_threshold: float = 0.8,
+        min_speech_ms: int = 600,
+        end_silence_ms: int = 1400,
+        min_speech_frames: int = 2,
+        min_silence_frames: int = 3,
+        prob_smoothing: float = 0.5,
+    ) -> None:
+        self._sample_rate = sample_rate
+        self._frame_samples = 512  # 32ms @ 16kHz
+        self._frame_bytes = self._frame_samples * 2  # int16
+        self._vad = SileroVAD(sample_rate=sample_rate)
+        self._speech_threshold = speech_threshold
+        self._min_speech_ms = min_speech_ms
+        self._end_silence_ms = end_silence_ms
+        self._min_speech_frames = min_speech_frames
+        self._min_silence_frames = min_silence_frames
+        self._prob_smoothing = prob_smoothing
+
+        self._buffer = bytearray()
+        self._in_speech = False
+        self._speech_ms = 0
+        self._silence_ms = 0
+        self._speech_frames = 0
+        self._silence_frames = 0
+        self._prob_ema = 0.0
+
+    def reset(self) -> None:
+        self._buffer.clear()
+        self._in_speech = False
+        self._speech_ms = 0
+        self._silence_ms = 0
+        self._speech_frames = 0
+        self._silence_frames = 0
+        self._prob_ema = 0.0
+
+    def has_speech(self) -> bool:
+        return self._speech_ms >= self._min_speech_ms
+
+    def update(self, pcm_bytes: bytes) -> VADDecision:
+        """Feed PCM bytes and return VAD decision for the latest frames."""
+        self._buffer.extend(pcm_bytes)
+        speech_started = False
+        speech_ended = False
+
+        while len(self._buffer) >= self._frame_bytes:
+            frame = self._buffer[: self._frame_bytes]
+            del self._buffer[: self._frame_bytes]
+
+            samples = array("h", frame)
+            float32 = [s / 32768.0 for s in samples]
+            prob = self._vad.process(float32)
+            self._prob_ema = (
+                self._prob_ema * self._prob_smoothing
+                + prob * (1.0 - self._prob_smoothing)
+            )
+
+            if self._prob_ema >= self._speech_threshold:
+                self._speech_frames += 1
+                self._silence_frames = 0
+                if not self._in_speech and self._speech_frames >= self._min_speech_frames:
+                    speech_started = True
+                    self._in_speech = True
+                    self._speech_ms = 0
+                if self._in_speech:
+                    self._speech_ms += 32
+                    self._silence_ms = 0
+            else:
+                self._silence_frames += 1
+                self._speech_frames = 0
+                if self._in_speech:
+                    self._silence_ms += 32
+                    if (
+                        self._speech_ms >= self._min_speech_ms
+                        and self._silence_ms >= self._end_silence_ms
+                        and self._silence_frames >= self._min_silence_frames
+                    ):
+                        speech_ended = True
+                        self._in_speech = False
+                        self._silence_ms = 0
+
+        return VADDecision(
+            speech_started=speech_started,
+            speech_ended=speech_ended,
+            speech_ms=self._speech_ms,
+            silence_ms=self._silence_ms,
+        )
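
How the new stream class is meant to be driven — a sketch; the import path assumes the src/ package layout, and feed_pcm is a hypothetical stand-in for a microphone source:

from app.services.vad import SileroVADStream  # path assumed from src/app layout


def feed_pcm():
    # Stand-in source: two seconds of silence in 100 ms chunks
    # (16,000 samples/s x 2 bytes x 0.1 s = 3200 bytes per chunk).
    yield from (bytes(3200) for _ in range(20))


vad = SileroVADStream()  # defaults: 0.8 threshold, 600 ms min speech, 1400 ms end silence
for chunk in feed_pcm():  # raw PCM bytes; any chunk size works
    decision = vad.update(chunk)  # internally re-frames into 512-sample (32 ms) windows
    if decision.speech_started:
        print("speech started")
    if decision.speech_ended:  # enough speech followed by enough smoothed silence
        print(f"segment ended after {decision.speech_ms} ms of speech")
        vad.reset()  # clear counters before the next utterance
        break
# On pure silence neither branch fires; swap in real microphone audio to see events.

Smoothing the per-frame probability with an EMA and requiring consecutive frames on both the speech and silence edges keeps single noisy frames from toggling the state machine.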
ui/streamlit_app.py CHANGED
@@ -939,21 +939,7 @@ html = """
              performance.now() - lastVoiceAt > SILENCE_MS
            ) {
              segmentInFlight = true;
-             sendEnabled = false;
              setState('thinking');
-             const avgRms = rmsCount ? rmsSum / rmsCount : 0;
-             ws.send(JSON.stringify({
-               event: 'segment_end',
-               prompt: 'Answer briefly.',
-               frames_sent: framesSent,
-               avg_rms: avgRms,
-               llm_provider: llmProvider
-             }));
-             framesSent = 0;
-             rmsSum = 0;
-             rmsCount = 0;
-             hadVoice = false;
-             lastVoiceAt = performance.now();
            }
          };
          source.connect(processor);
@@ -993,6 +979,9 @@ html = """
            sendEnabled = !isMuted;
            hadVoice = false;
            lastVoiceAt = performance.now();
+           framesSent = 0;
+           rmsSum = 0;
+           rmsCount = 0;
            if (data.transcript) {
              const last = messages[messages.length - 1];
              if (!last || last.role !== 'user' || last.text !== data.transcript) {
@@ -1019,6 +1008,9 @@ html = """
            sendEnabled = !isMuted;
            hadVoice = false;
            lastVoiceAt = performance.now();
+           framesSent = 0;
+           rmsSum = 0;
+           rmsCount = 0;
            if (isMuted && ws) ws.close();
          }
        };