Xin Zhang committed on
Commit
b67c020
·
1 Parent(s): e03f21e

[fix]: update vad threshold.

Browse files
main.py CHANGED
@@ -57,16 +57,17 @@ async def root():
57
  async def translate(websocket: WebSocket):
58
  query_parameters_dict = websocket.query_params
59
  from_lang, to_lang = query_parameters_dict.get('from'), query_parameters_dict.get('to')
60
-
61
  client = WhisperTranscriptionService(
62
  websocket,
63
  pipe,
64
- language="en",
 
65
  client_uid=f"{uuid1()}",
66
  )
67
 
68
 
69
- if from_lang and to_lang:
70
  client.set_language(from_lang, to_lang)
71
  logger.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
72
  await websocket.accept()
@@ -75,7 +76,7 @@ async def translate(websocket: WebSocket):
75
  frame_data = await get_audio_from_websocket(websocket)
76
  client.add_frames(frame_data)
77
  except WebSocketDisconnect:
78
- return
79
 
80
  if __name__ == '__main__':
81
  freeze_support()
 
57
  async def translate(websocket: WebSocket):
58
  query_parameters_dict = websocket.query_params
59
  from_lang, to_lang = query_parameters_dict.get('from'), query_parameters_dict.get('to')
60
+
61
  client = WhisperTranscriptionService(
62
  websocket,
63
  pipe,
64
+ language=from_lang,
65
+ dst_lang=to_lang,
66
  client_uid=f"{uuid1()}",
67
  )
68
 
69
 
70
+ if from_lang and to_lang and client:
71
  client.set_language(from_lang, to_lang)
72
  logger.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
73
  await websocket.accept()
 
76
  frame_data = await get_audio_from_websocket(websocket)
77
  client.add_frames(frame_data)
78
  except WebSocketDisconnect:
79
+ return
80
 
81
  if __name__ == '__main__':
82
  freeze_support()
transcribe/helpers/vadprocessor.py CHANGED
@@ -278,12 +278,12 @@ class VadProcessor:
278
  cache_s=0.15,
279
  sr=16000
280
  ):
281
- self.prob_thres = prob_threshold
282
  self.cache_s = cache_s
283
  self.sr = sr
284
  self.silence_s = silence_s
285
 
286
- self.vad = VadV2(self.prob_thres, self.sr, self.silence_s * 1000, self.cache_s * 1000, max_speech_duration_s=15)
287
 
288
 
289
  def process_audio(self, audio_buffer: np.ndarray):
 
278
  cache_s=0.15,
279
  sr=16000
280
  ):
281
+ self.prob_threshold = prob_threshold
282
  self.cache_s = cache_s
283
  self.sr = sr
284
  self.silence_s = silence_s
285
 
286
+ self.vad = VadV2(self.prob_threshold, self.sr, self.silence_s * 1000, self.cache_s * 1000, max_speech_duration_s=15)
287
 
288
 
289
  def process_audio(self, audio_buffer: np.ndarray):
transcribe/whisper_llm_serve.py CHANGED
@@ -30,7 +30,8 @@ class WhisperTranscriptionService:
30
  DISCONNECT = "DISCONNECT"
31
 
32
  def __init__(self, websocket, pipe: TranslatePipes, language=None, dst_lang=None, client_uid=None):
33
-
 
34
  self.source_language = language # 源语言
35
  self.target_language = dst_lang # 目标翻译语言
36
  self.client_uid = client_uid
@@ -40,7 +41,7 @@ class WhisperTranscriptionService:
40
 
41
  # 音频处理相关
42
  self.sample_rate = 16000
43
-
44
  self.lock = threading.Lock()
45
  self._frame_queue = queue.Queue()
46
  self._vad_frame_queue = queue.Queue()
@@ -49,7 +50,7 @@ class WhisperTranscriptionService:
49
  self.text_separator = self._get_text_separator(language)
50
  self.loop = asyncio.get_event_loop()
51
  # 发送就绪状态
52
-
53
  self._transcrible_analysis = None
54
  # 启动处理线程
55
  self._translate_thread_stop = threading.Event()
@@ -57,7 +58,10 @@ class WhisperTranscriptionService:
57
 
58
  self.translate_thread = self._start_thread(self._transcription_processing_loop)
59
  self.frame_processing_thread = self._start_thread(self._frame_processing_loop)
60
- self._vad = VadProcessor()
 
 
 
61
  self.row_number = 0
62
  # for test
63
  self._transcrible_time_cost = 0.
@@ -66,9 +70,9 @@ class WhisperTranscriptionService:
66
  self._test_task_stop = threading.Event()
67
  self._test_queue = queue.Queue()
68
  self._test_thread = self._start_thread(self.test_data_loop)
69
-
70
  # self._c = 0
71
-
72
  def test_data_loop(self):
73
  writer = TestDataWriter()
74
  while not self._test_task_stop.is_set():
@@ -179,7 +183,7 @@ class WhisperTranscriptionService:
179
  if audio_buffer is None or len(audio_buffer) < int(self.sample_rate):
180
  time.sleep(0.2)
181
  continue
182
-
183
  logger.debug(f"audio buffer size: {len(audio_buffer) / self.sample_rate:.2f}s")
184
  # try:
185
  meta_item = self._transcribe_audio(audio_buffer)
 
30
  DISCONNECT = "DISCONNECT"
31
 
32
  def __init__(self, websocket, pipe: TranslatePipes, language=None, dst_lang=None, client_uid=None):
33
+ print('>>>>>>>>>>>>>>>> init service >>>>>>>>>>>>>>>>>>>>>>')
34
+ print('src_lang:', language)
35
  self.source_language = language # 源语言
36
  self.target_language = dst_lang # 目标翻译语言
37
  self.client_uid = client_uid
 
41
 
42
  # 音频处理相关
43
  self.sample_rate = 16000
44
+
45
  self.lock = threading.Lock()
46
  self._frame_queue = queue.Queue()
47
  self._vad_frame_queue = queue.Queue()
 
50
  self.text_separator = self._get_text_separator(language)
51
  self.loop = asyncio.get_event_loop()
52
  # 发送就绪状态
53
+
54
  self._transcrible_analysis = None
55
  # 启动处理线程
56
  self._translate_thread_stop = threading.Event()
 
58
 
59
  self.translate_thread = self._start_thread(self._transcription_processing_loop)
60
  self.frame_processing_thread = self._start_thread(self._frame_processing_loop)
61
+ if language == "zh":
62
+ self._vad = VadProcessor(prob_threshold=0.8, silence_s=0.2, cache_s=0.15)
63
+ else:
64
+ self._vad = VadProcessor(prob_threshold=0.7, silence_s=0.2, cache_s=0.15)
65
  self.row_number = 0
66
  # for test
67
  self._transcrible_time_cost = 0.
 
70
  self._test_task_stop = threading.Event()
71
  self._test_queue = queue.Queue()
72
  self._test_thread = self._start_thread(self.test_data_loop)
73
+
74
  # self._c = 0
75
+
76
  def test_data_loop(self):
77
  writer = TestDataWriter()
78
  while not self._test_task_stop.is_set():
 
183
  if audio_buffer is None or len(audio_buffer) < int(self.sample_rate):
184
  time.sleep(0.2)
185
  continue
186
+
187
  logger.debug(f"audio buffer size: {len(audio_buffer) / self.sample_rate:.2f}s")
188
  # try:
189
  meta_item = self._transcribe_audio(audio_buffer)