Xin Zhang commited on
Commit
391654f
·
2 Parent(s): a79d676 b67c020

Merge branch 'vad'

Browse files

* vad:
[fix]: update vad threshold.
Update model and processor files
fix segments missing error

main.py CHANGED
@@ -57,16 +57,17 @@ async def root():
57
  async def translate(websocket: WebSocket):
58
  query_parameters_dict = websocket.query_params
59
  from_lang, to_lang = query_parameters_dict.get('from'), query_parameters_dict.get('to')
60
-
61
  client = WhisperTranscriptionService(
62
  websocket,
63
  pipe,
64
- language="en",
 
65
  client_uid=f"{uuid1()}",
66
  )
67
 
68
 
69
- if from_lang and to_lang:
70
  client.set_language(from_lang, to_lang)
71
  logger.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
72
  await websocket.accept()
@@ -75,7 +76,7 @@ async def translate(websocket: WebSocket):
75
  frame_data = await get_audio_from_websocket(websocket)
76
  client.add_frames(frame_data)
77
  except WebSocketDisconnect:
78
- return
79
 
80
  if __name__ == '__main__':
81
  freeze_support()
 
57
  async def translate(websocket: WebSocket):
58
  query_parameters_dict = websocket.query_params
59
  from_lang, to_lang = query_parameters_dict.get('from'), query_parameters_dict.get('to')
60
+
61
  client = WhisperTranscriptionService(
62
  websocket,
63
  pipe,
64
+ language=from_lang,
65
+ dst_lang=to_lang,
66
  client_uid=f"{uuid1()}",
67
  )
68
 
69
 
70
+ if from_lang and to_lang and client:
71
  client.set_language(from_lang, to_lang)
72
  logger.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
73
  await websocket.accept()
 
76
  frame_data = await get_audio_from_websocket(websocket)
77
  client.add_frames(frame_data)
78
  except WebSocketDisconnect:
79
+ return
80
 
81
  if __name__ == '__main__':
82
  freeze_support()
moyoyo_asr_models/ggml-small.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3f6ef171491de375b741059400ba9a0aead023122b7a7db731b4943f9baa0f97
3
  size 487601984
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:951596a31b1c96a01b7a2b1bc511f665d900c679126134f6ec18db5ec4a485fe
3
  size 487601984
run_client.py DELETED
@@ -1,17 +0,0 @@
1
-
2
- from transcribe.client import TranscriptionClient
3
-
4
- client = TranscriptionClient(
5
- "localhost",
6
- 9090,
7
- lang="zh",
8
- dst_lang="en",
9
- save_output_recording=False, # Only used for microphone input, False by Default
10
- output_recording_filename="./output_recording.wav", # Only used for microphone input
11
- max_clients=4,
12
- max_connection_time=600,
13
- mute_audio_playback=False, # Only used for file input, False by Default
14
- )
15
-
16
- if __name__ == '__main__':
17
- client()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
run_server.py DELETED
@@ -1,31 +0,0 @@
1
- import argparse
2
- import os
3
-
4
- if __name__ == "__main__":
5
- parser = argparse.ArgumentParser()
6
- parser.add_argument('--port', '-p',
7
- type=int,
8
- default=9090,
9
- help="Websocket port to run the server on.")
10
- parser.add_argument('--backend', '-b',
11
- type=str,
12
- default='pywhispercpp',
13
- help='Backends from ["pywhispercpp"]')
14
-
15
- parser.add_argument('--omp_num_threads', '-omp',
16
- type=int,
17
- default=1,
18
- help="Number of threads to use for OpenMP")
19
-
20
- args = parser.parse_args()
21
-
22
- if "OMP_NUM_THREADS" not in os.environ:
23
- os.environ["OMP_NUM_THREADS"] = str(args.omp_num_threads)
24
-
25
- from transcribe.transcription import TranscriptionServer
26
- server = TranscriptionServer()
27
- server.run(
28
- "0.0.0.0",
29
- port=args.port,
30
- backend=args.backend,
31
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
transcribe/helpers/vadprocessor.py CHANGED
@@ -137,7 +137,7 @@ class VADIteratorOnnx:
137
  return_seconds: bool (default - False)
138
  whether return timestamps in seconds (default - samples)
139
  """
140
-
141
  window_size_samples = 512 if self.sampling_rate == 16000 else 256
142
  x = x[:window_size_samples]
143
  if len(x) < window_size_samples:
@@ -156,7 +156,7 @@ class VADIteratorOnnx:
156
  speech_start = max(0, self.current_sample - window_size_samples)
157
  self.start = speech_start
158
  return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
159
-
160
  if (speech_prob >= self.threshold) and self.current_sample - self.start >= self.max_speech_samples:
161
  if self.temp_end:
162
  self.temp_end = 0
@@ -175,7 +175,7 @@ class VADIteratorOnnx:
175
  return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
176
 
177
  return None
178
-
179
 
180
 
181
  class VadV2:
@@ -267,24 +267,24 @@ class VadV2:
267
 
268
  return result
269
  return None
270
-
271
 
272
-
 
273
  class VadProcessor:
274
  def __init__(
275
  self,
276
  prob_threshold=0.5,
277
- silence_s=0.3,
278
- cache_s=0.25,
279
  sr=16000
280
  ):
281
- self.prob_thres = prob_threshold
282
  self.cache_s = cache_s
283
  self.sr = sr
284
  self.silence_s = silence_s
285
 
286
- self.vad = VadV2(self.prob_thres, self.sr, self.silence_s * 1000, self.cache_s * 1000, max_speech_duration_s=15)
287
-
288
 
289
  def process_audio(self, audio_buffer: np.ndarray):
290
  audio = np.array([], np.float32)
 
137
  return_seconds: bool (default - False)
138
  whether return timestamps in seconds (default - samples)
139
  """
140
+
141
  window_size_samples = 512 if self.sampling_rate == 16000 else 256
142
  x = x[:window_size_samples]
143
  if len(x) < window_size_samples:
 
156
  speech_start = max(0, self.current_sample - window_size_samples)
157
  self.start = speech_start
158
  return {'start': int(speech_start) if not return_seconds else round(speech_start / self.sampling_rate, 1)}
159
+
160
  if (speech_prob >= self.threshold) and self.current_sample - self.start >= self.max_speech_samples:
161
  if self.temp_end:
162
  self.temp_end = 0
 
175
  return {'end': int(speech_end) if not return_seconds else round(speech_end / self.sampling_rate, 1)}
176
 
177
  return None
178
+
179
 
180
 
181
  class VadV2:
 
267
 
268
  return result
269
  return None
 
270
 
271
+
272
+
273
  class VadProcessor:
274
  def __init__(
275
  self,
276
  prob_threshold=0.5,
277
+ silence_s=0.2,
278
+ cache_s=0.15,
279
  sr=16000
280
  ):
281
+ self.prob_threshold = prob_threshold
282
  self.cache_s = cache_s
283
  self.sr = sr
284
  self.silence_s = silence_s
285
 
286
+ self.vad = VadV2(self.prob_threshold, self.sr, self.silence_s * 1000, self.cache_s * 1000, max_speech_duration_s=15)
287
+
288
 
289
  def process_audio(self, audio_buffer: np.ndarray):
290
  audio = np.array([], np.float32)
transcribe/vad.py DELETED
@@ -1,164 +0,0 @@
1
- import os
2
- import subprocess
3
- import warnings
4
-
5
- import numpy as np
6
- import onnxruntime
7
- import torch
8
- import logging
9
- from config import VAD_MODEL_PATH
10
-
11
- class VoiceActivityDetection():
12
-
13
- def __init__(self, force_onnx_cpu=True):
14
- # path = self.download()
15
- path = VAD_MODEL_PATH
16
- if not os.path.exists(path):
17
- raise FileNotFoundError(f"Model file not found at {path}. Please download the model.")
18
-
19
- opts = onnxruntime.SessionOptions()
20
- opts.log_severity_level = 3
21
-
22
- opts.inter_op_num_threads = 1
23
- opts.intra_op_num_threads = 1
24
-
25
- if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
26
- self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
27
- else:
28
- self.session = onnxruntime.InferenceSession(path, providers=['CUDAExecutionProvider'], sess_options=opts)
29
-
30
- self.reset_states()
31
- if '16k' in path:
32
- warnings.warn('This model support only 16000 sampling rate!')
33
- self.sample_rates = [16000]
34
- else:
35
- self.sample_rates = [8000, 16000]
36
-
37
- def _validate_input(self, x, sr: int):
38
- if x.dim() == 1:
39
- x = x.unsqueeze(0)
40
- if x.dim() > 2:
41
- raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
42
-
43
- if sr != 16000 and (sr % 16000 == 0):
44
- step = sr // 16000
45
- x = x[:, ::step]
46
- sr = 16000
47
-
48
- if sr not in self.sample_rates:
49
- raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")
50
- if sr / x.shape[1] > 31.25:
51
- raise ValueError("Input audio chunk is too short")
52
-
53
- return x, sr
54
-
55
- def reset_states(self, batch_size=1):
56
- self._state = torch.zeros((2, batch_size, 128)).float()
57
- self._context = torch.zeros(0)
58
- self._last_sr = 0
59
- self._last_batch_size = 0
60
-
61
- def __call__(self, x, sr: int):
62
-
63
- x, sr = self._validate_input(x, sr)
64
- num_samples = 512 if sr == 16000 else 256
65
-
66
- if x.shape[-1] != num_samples:
67
- raise ValueError(
68
- f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")
69
-
70
- batch_size = x.shape[0]
71
- context_size = 64 if sr == 16000 else 32
72
-
73
- if not self._last_batch_size:
74
- self.reset_states(batch_size)
75
- if (self._last_sr) and (self._last_sr != sr):
76
- self.reset_states(batch_size)
77
- if (self._last_batch_size) and (self._last_batch_size != batch_size):
78
- self.reset_states(batch_size)
79
-
80
- if not len(self._context):
81
- self._context = torch.zeros(batch_size, context_size)
82
-
83
- x = torch.cat([self._context, x], dim=1)
84
- if sr in [8000, 16000]:
85
- ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
86
- ort_outs = self.session.run(None, ort_inputs)
87
- out, state = ort_outs
88
- self._state = torch.from_numpy(state)
89
- else:
90
- raise ValueError()
91
-
92
- self._context = x[..., -context_size:]
93
- self._last_sr = sr
94
- self._last_batch_size = batch_size
95
-
96
- out = torch.from_numpy(out)
97
- return out
98
-
99
- def audio_forward(self, x, sr: int):
100
- outs = []
101
- x, sr = self._validate_input(x, sr)
102
- self.reset_states()
103
- num_samples = 512 if sr == 16000 else 256
104
-
105
- if x.shape[1] % num_samples:
106
- pad_num = num_samples - (x.shape[1] % num_samples)
107
- x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)
108
-
109
- for i in range(0, x.shape[1], num_samples):
110
- wavs_batch = x[:, i:i + num_samples]
111
- out_chunk = self.__call__(wavs_batch, sr)
112
- outs.append(out_chunk)
113
-
114
- stacked = torch.cat(outs, dim=1)
115
- return stacked.cpu()
116
-
117
- @staticmethod
118
- def download(model_url="https://github.com/snakers4/silero-vad/raw/v5.0/files/silero_vad.onnx"):
119
- target_dir = os.path.expanduser("~/.cache/silero-vad/")
120
-
121
- # Ensure the target directory exists
122
- os.makedirs(target_dir, exist_ok=True)
123
-
124
- # Define the target file path
125
- model_filename = os.path.join(target_dir, "silero_vad.onnx")
126
-
127
- # Check if the model file already exists
128
- if not os.path.exists(model_filename):
129
- # If it doesn't exist, download the model using wget
130
- try:
131
- # subprocess.run(["wget", "-O", model_filename, model_url], check=True)
132
- subprocess.run(["curl", "-sL", "-o", model_filename, model_url], check=True)
133
- except subprocess.CalledProcessError:
134
- print("Failed to download the model using wget.")
135
- return model_filename
136
-
137
-
138
- class VoiceActivityDetector:
139
- def __init__(self, threshold=0.5, frame_rate=16000):
140
- """
141
- Initializes the VoiceActivityDetector with a voice activity detection model and a threshold.
142
-
143
- Args:
144
- threshold (float, optional): The probability threshold for detecting voice activity. Defaults to 0.5.
145
- """
146
- self.model = VoiceActivityDetection()
147
- self.threshold = threshold
148
- self.frame_rate = frame_rate
149
-
150
- def __call__(self, audio_frame):
151
- """
152
- Determines if the given audio frame contains speech by comparing the detected speech probability against
153
- the threshold.
154
-
155
- Args:
156
- audio_frame (np.ndarray): The audio frame to be analyzed for voice activity. It is expected to be a
157
- NumPy array of audio samples.
158
-
159
- Returns:
160
- bool: True if the speech probability exceeds the threshold, indicating the presence of voice activity;
161
- False otherwise.
162
- """
163
- speech_probs = self.model.audio_forward(torch.from_numpy(audio_frame.copy()), self.frame_rate)[0]
164
- return torch.any(speech_probs > self.threshold).item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
transcribe/whisper_llm_serve.py CHANGED
@@ -30,7 +30,8 @@ class WhisperTranscriptionService:
30
  DISCONNECT = "DISCONNECT"
31
 
32
  def __init__(self, websocket, pipe: TranslatePipes, language=None, dst_lang=None, client_uid=None):
33
-
 
34
  self.source_language = language # 源语言
35
  self.target_language = dst_lang # 目标翻译语言
36
  self.client_uid = client_uid
@@ -40,7 +41,7 @@ class WhisperTranscriptionService:
40
 
41
  # 音频处理相关
42
  self.sample_rate = 16000
43
-
44
  self.lock = threading.Lock()
45
  self._frame_queue = queue.Queue()
46
  self._vad_frame_queue = queue.Queue()
@@ -49,7 +50,7 @@ class WhisperTranscriptionService:
49
  self.text_separator = self._get_text_separator(language)
50
  self.loop = asyncio.get_event_loop()
51
  # 发送就绪状态
52
-
53
  self._transcrible_analysis = None
54
  # 启动处理线程
55
  self._translate_thread_stop = threading.Event()
@@ -57,7 +58,10 @@ class WhisperTranscriptionService:
57
 
58
  self.translate_thread = self._start_thread(self._transcription_processing_loop)
59
  self.frame_processing_thread = self._start_thread(self._frame_processing_loop)
60
- self._vad = VadProcessor()
 
 
 
61
  self.row_number = 0
62
  # for test
63
  self._transcrible_time_cost = 0.
@@ -66,9 +70,9 @@ class WhisperTranscriptionService:
66
  self._test_task_stop = threading.Event()
67
  self._test_queue = queue.Queue()
68
  self._test_thread = self._start_thread(self.test_data_loop)
69
-
70
  # self._c = 0
71
-
72
  def test_data_loop(self):
73
  writer = TestDataWriter()
74
  while not self._test_task_stop.is_set():
@@ -99,7 +103,7 @@ class WhisperTranscriptionService:
99
  """设置源语言和目标语言"""
100
  self.source_language = source_lang
101
  self.target_language = target_lang
102
- # self.text_separator = self._get_text_separator(source_lang)
103
  # self._transcrible_analysis = TranscriptStabilityAnalyzer(self.source_language, self.text_separator)
104
 
105
  def add_frames(self, frame_np: np.ndarray) -> None:
@@ -179,7 +183,7 @@ class WhisperTranscriptionService:
179
  if audio_buffer is None or len(audio_buffer) < int(self.sample_rate):
180
  time.sleep(0.2)
181
  continue
182
-
183
  logger.debug(f"audio buffer size: {len(audio_buffer) / self.sample_rate:.2f}s")
184
  # try:
185
  meta_item = self._transcribe_audio(audio_buffer)
@@ -197,13 +201,13 @@ class WhisperTranscriptionService:
197
  # logger.error(f"Error processing audio: {e}")
198
 
199
  def _process_transcription_results_2(self, segments: List[TranscriptToken],):
200
- seg = segments[0]
201
  item = TransResult(
202
  seg_id=self.row_number,
203
- context=seg.text,
204
  from_=self.source_language,
205
  to=self.target_language,
206
- tran_content=self._translate_text_large(seg.text),
207
  partial=False
208
  )
209
  self.row_number += 1
 
30
  DISCONNECT = "DISCONNECT"
31
 
32
  def __init__(self, websocket, pipe: TranslatePipes, language=None, dst_lang=None, client_uid=None):
33
+ print('>>>>>>>>>>>>>>>> init service >>>>>>>>>>>>>>>>>>>>>>')
34
+ print('src_lang:', language)
35
  self.source_language = language # 源语言
36
  self.target_language = dst_lang # 目标翻译语言
37
  self.client_uid = client_uid
 
41
 
42
  # 音频处理相关
43
  self.sample_rate = 16000
44
+
45
  self.lock = threading.Lock()
46
  self._frame_queue = queue.Queue()
47
  self._vad_frame_queue = queue.Queue()
 
50
  self.text_separator = self._get_text_separator(language)
51
  self.loop = asyncio.get_event_loop()
52
  # 发送就绪状态
53
+
54
  self._transcrible_analysis = None
55
  # 启动处理线程
56
  self._translate_thread_stop = threading.Event()
 
58
 
59
  self.translate_thread = self._start_thread(self._transcription_processing_loop)
60
  self.frame_processing_thread = self._start_thread(self._frame_processing_loop)
61
+ if language == "zh":
62
+ self._vad = VadProcessor(prob_threshold=0.8, silence_s=0.2, cache_s=0.15)
63
+ else:
64
+ self._vad = VadProcessor(prob_threshold=0.7, silence_s=0.2, cache_s=0.15)
65
  self.row_number = 0
66
  # for test
67
  self._transcrible_time_cost = 0.
 
70
  self._test_task_stop = threading.Event()
71
  self._test_queue = queue.Queue()
72
  self._test_thread = self._start_thread(self.test_data_loop)
73
+
74
  # self._c = 0
75
+
76
  def test_data_loop(self):
77
  writer = TestDataWriter()
78
  while not self._test_task_stop.is_set():
 
103
  """设置源语言和目标语言"""
104
  self.source_language = source_lang
105
  self.target_language = target_lang
106
+ self.text_separator = self._get_text_separator(source_lang)
107
  # self._transcrible_analysis = TranscriptStabilityAnalyzer(self.source_language, self.text_separator)
108
 
109
  def add_frames(self, frame_np: np.ndarray) -> None:
 
183
  if audio_buffer is None or len(audio_buffer) < int(self.sample_rate):
184
  time.sleep(0.2)
185
  continue
186
+
187
  logger.debug(f"audio buffer size: {len(audio_buffer) / self.sample_rate:.2f}s")
188
  # try:
189
  meta_item = self._transcribe_audio(audio_buffer)
 
201
  # logger.error(f"Error processing audio: {e}")
202
 
203
  def _process_transcription_results_2(self, segments: List[TranscriptToken],):
204
+ seg_text = self.text_separator.join(seg.text for seg in segments)
205
  item = TransResult(
206
  seg_id=self.row_number,
207
+ context=seg_text,
208
  from_=self.source_language,
209
  to=self.target_language,
210
+ tran_content=self._translate_text_large(seg_text),
211
  partial=False
212
  )
213
  self.row_number += 1
transcribe/whispercpp_serve.py DELETED
@@ -1,383 +0,0 @@
1
-
2
- import json
3
- import logging
4
- import pathlib
5
- import threading
6
- import time
7
- import config
8
- import librosa
9
- import numpy as np
10
- import soundfile
11
- from pywhispercpp.model import Model
12
-
13
- logging.basicConfig(level=logging.INFO)
14
-
15
- class ServeClientBase(object):
16
- RATE = 16000
17
- SERVER_READY = "SERVER_READY"
18
- DISCONNECT = "DISCONNECT"
19
-
20
- def __init__(self, client_uid, websocket):
21
- self.client_uid = client_uid
22
- self.websocket = websocket
23
- self.frames = b""
24
- self.timestamp_offset = 0.0
25
- self.frames_np = None
26
- self.frames_offset = 0.0
27
- self.text = []
28
- self.current_out = ''
29
- self.prev_out = ''
30
- self.t_start = None
31
- self.exit = False
32
- self.same_output_count = 0
33
- self.show_prev_out_thresh = 5 # if pause(no output from whisper) show previous output for 5 seconds
34
- self.add_pause_thresh = 3 # add a blank to segment list as a pause(no speech) for 3 seconds
35
- self.transcript = []
36
- self.send_last_n_segments = 10
37
-
38
- # text formatting
39
- self.pick_previous_segments = 2
40
-
41
- # threading
42
- self.lock = threading.Lock()
43
-
44
- def speech_to_text(self):
45
- raise NotImplementedError
46
-
47
- def transcribe_audio(self):
48
- raise NotImplementedError
49
-
50
- def handle_transcription_output(self):
51
- raise NotImplementedError
52
-
53
- def add_frames(self, frame_np):
54
- """
55
- Add audio frames to the ongoing audio stream buffer.
56
-
57
- This method is responsible for maintaining the audio stream buffer, allowing the continuous addition
58
- of audio frames as they are received. It also ensures that the buffer does not exceed a specified size
59
- to prevent excessive memory usage.
60
-
61
- If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds
62
- of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided
63
- audio frame. The audio stream buffer is used for real-time processing of audio data for transcription.
64
-
65
- Args:
66
- frame_np (numpy.ndarray): The audio frame data as a NumPy array.
67
-
68
- """
69
- self.lock.acquire()
70
- if self.frames_np is not None and self.frames_np.shape[0] > 45 * self.RATE:
71
- self.frames_offset += 30.0
72
- self.frames_np = self.frames_np[int(30 * self.RATE):]
73
- # check timestamp offset(should be >= self.frame_offset)
74
- # this basically means that there is no speech as timestamp offset hasnt updated
75
- # and is less than frame_offset
76
- if self.timestamp_offset < self.frames_offset:
77
- self.timestamp_offset = self.frames_offset
78
- if self.frames_np is None:
79
- self.frames_np = frame_np.copy()
80
- else:
81
- self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)
82
- self.lock.release()
83
-
84
- def clip_audio_if_no_valid_segment(self):
85
- """
86
- Update the timestamp offset based on audio buffer status.
87
- Clip audio if the current chunk exceeds 30 seconds, this basically implies that
88
- no valid segment for the last 30 seconds from whisper
89
- """
90
- with self.lock:
91
- if self.frames_np[int((self.timestamp_offset - self.frames_offset) * self.RATE):].shape[0] > 25 * self.RATE:
92
- duration = self.frames_np.shape[0] / self.RATE
93
- self.timestamp_offset = self.frames_offset + duration - 5
94
-
95
- def get_audio_chunk_for_processing(self):
96
- """
97
- Retrieves the next chunk of audio data for processing based on the current offsets.
98
-
99
- Calculates which part of the audio data should be processed next, based on
100
- the difference between the current timestamp offset and the frame's offset, scaled by
101
- the audio sample rate (RATE). It then returns this chunk of audio data along with its
102
- duration in seconds.
103
-
104
- Returns:
105
- tuple: A tuple containing:
106
- - input_bytes (np.ndarray): The next chunk of audio data to be processed.
107
- - duration (float): The duration of the audio chunk in seconds.
108
- """
109
- with self.lock:
110
- samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE)
111
- input_bytes = self.frames_np[int(samples_take):].copy()
112
- duration = input_bytes.shape[0] / self.RATE
113
- return input_bytes, duration
114
-
115
- def prepare_segments(self, last_segment=None):
116
- """
117
- Prepares the segments of transcribed text to be sent to the client.
118
-
119
- This method compiles the recent segments of transcribed text, ensuring that only the
120
- specified number of the most recent segments are included. It also appends the most
121
- recent segment of text if provided (which is considered incomplete because of the possibility
122
- of the last word being truncated in the audio chunk).
123
-
124
- Args:
125
- last_segment (str, optional): The most recent segment of transcribed text to be added
126
- to the list of segments. Defaults to None.
127
-
128
- Returns:
129
- list: A list of transcribed text segments to be sent to the client.
130
- """
131
- segments = []
132
- if len(self.transcript) >= self.send_last_n_segments:
133
- segments = self.transcript[-self.send_last_n_segments:].copy()
134
- else:
135
- segments = self.transcript.copy()
136
- if last_segment is not None:
137
- segments = segments + [last_segment]
138
- logging.info(f"{segments}")
139
- return segments
140
-
141
- def get_audio_chunk_duration(self, input_bytes):
142
- """
143
- Calculates the duration of the provided audio chunk.
144
-
145
- Args:
146
- input_bytes (numpy.ndarray): The audio chunk for which to calculate the duration.
147
-
148
- Returns:
149
- float: The duration of the audio chunk in seconds.
150
- """
151
- return input_bytes.shape[0] / self.RATE
152
-
153
- def send_transcription_to_client(self, segments):
154
- """
155
- Sends the specified transcription segments to the client over the websocket connection.
156
-
157
- This method formats the transcription segments into a JSON object and attempts to send
158
- this object to the client. If an error occurs during the send operation, it logs the error.
159
-
160
- Returns:
161
- segments (list): A list of transcription segments to be sent to the client.
162
- """
163
- try:
164
- self.websocket.send(
165
- json.dumps({
166
- "uid": self.client_uid,
167
- "segments": segments,
168
- })
169
- )
170
- except Exception as e:
171
- logging.error(f"[ERROR]: Sending data to client: {e}")
172
-
173
- def disconnect(self):
174
- """
175
- Notify the client of disconnection and send a disconnect message.
176
-
177
- This method sends a disconnect message to the client via the WebSocket connection to notify them
178
- that the transcription service is disconnecting gracefully.
179
-
180
- """
181
- self.websocket.send(json.dumps({
182
- "uid": self.client_uid,
183
- "message": self.DISCONNECT
184
- }))
185
-
186
- def cleanup(self):
187
- """
188
- Perform cleanup tasks before exiting the transcription service.
189
-
190
- This method performs necessary cleanup tasks, including stopping the transcription thread, marking
191
- the exit flag to indicate the transcription thread should exit gracefully, and destroying resources
192
- associated with the transcription process.
193
-
194
- """
195
- logging.info("Cleaning up.")
196
- self.exit = True
197
-
198
-
199
- class ServeClientWhisperCPP(ServeClientBase):
200
- SINGLE_MODEL = None
201
- SINGLE_MODEL_LOCK = threading.Lock()
202
-
203
- def __init__(self, websocket, language=None, client_uid=None,
204
- single_model=False):
205
- """
206
- Initialize a ServeClient instance.
207
- The Whisper model is initialized based on the client's language and device availability.
208
- The transcription thread is started upon initialization. A "SERVER_READY" message is sent
209
- to the client to indicate that the server is ready.
210
-
211
- Args:
212
- websocket (WebSocket): The WebSocket connection for the client.
213
- language (str, optional): The language for transcription. Defaults to None.
214
- client_uid (str, optional): A unique identifier for the client. Defaults to None.
215
- single_model (bool, optional): Whether to instantiate a new model for each client connection. Defaults to False.
216
-
217
- """
218
- super().__init__(client_uid, websocket)
219
- self.language = language
220
- self.eos = False
221
-
222
- if single_model:
223
- if ServeClientWhisperCPP.SINGLE_MODEL is None:
224
- self.create_model()
225
- ServeClientWhisperCPP.SINGLE_MODEL = self.transcriber
226
- else:
227
- self.transcriber = ServeClientWhisperCPP.SINGLE_MODEL
228
- else:
229
- self.create_model()
230
-
231
- # threading
232
- logging.info('Create a thread to process audio.')
233
- self.trans_thread = threading.Thread(target=self.speech_to_text)
234
- self.trans_thread.start()
235
-
236
- self.websocket.send(json.dumps({
237
- "uid": self.client_uid,
238
- "message": self.SERVER_READY,
239
- "backend": "pywhispercpp"
240
- }))
241
-
242
- def create_model(self, warmup=True):
243
- """
244
- Instantiates a new model, sets it as the transcriber and does warmup if desired.
245
- """
246
-
247
- self.transcriber = Model(model=config.WHISPER_MODEL, models_dir=config.MODEL_DIR)
248
- if warmup:
249
- self.warmup()
250
-
251
- def warmup(self, warmup_steps=1):
252
- """
253
- Warmup TensorRT since first few inferences are slow.
254
-
255
- Args:
256
- warmup_steps (int): Number of steps to warm up the model for.
257
- """
258
- logging.info("[INFO:] Warming up whisper.cpp engine..")
259
- mel, _, = soundfile.read("assets/jfk.flac")
260
- for i in range(warmup_steps):
261
- self.transcriber.transcribe(mel, print_progress=False)
262
-
263
- def set_eos(self, eos):
264
- """
265
- Sets the End of Speech (EOS) flag.
266
-
267
- Args:
268
- eos (bool): The value to set for the EOS flag.
269
- """
270
- self.lock.acquire()
271
- self.eos = eos
272
- self.lock.release()
273
-
274
- def handle_transcription_output(self, last_segment, duration):
275
- """
276
- Handle the transcription output, updating the transcript and sending data to the client.
277
-
278
- Args:
279
- last_segment (str): The last segment from the whisper output which is considered to be incomplete because
280
- of the possibility of word being truncated.
281
- duration (float): Duration of the transcribed audio chunk.
282
- """
283
- segments = self.prepare_segments({"text": last_segment})
284
- self.send_transcription_to_client(segments)
285
- if self.eos:
286
- self.update_timestamp_offset(last_segment, duration)
287
-
288
- def transcribe_audio(self, input_bytes):
289
- """
290
- Transcribe the audio chunk and send the results to the client.
291
-
292
- Args:
293
- input_bytes (np.array): The audio chunk to transcribe.
294
- """
295
- if ServeClientWhisperCPP.SINGLE_MODEL:
296
- ServeClientWhisperCPP.SINGLE_MODEL_LOCK.acquire()
297
- logging.info(f"[pywhispercpp:] Processing audio with duration: {input_bytes.shape[0] / self.RATE}")
298
- mel = input_bytes
299
- duration = librosa.get_duration(y=input_bytes, sr=self.RATE)
300
-
301
- if self.language == "zh":
302
- prompt = '以下是简体中文普通话的句子。'
303
- else:
304
- prompt = ''
305
-
306
- segments = self.transcriber.transcribe(
307
- mel,
308
- language=self.language,
309
- initial_prompt=prompt,
310
- token_timestamps=True,
311
- # max_len=max_len,
312
- print_progress=False
313
- )
314
- text = []
315
- for segment in segments:
316
- content = segment.text
317
- text.append(content)
318
- last_segment = ' '.join(text)
319
-
320
- logging.info(f"[pywhispercpp:] Last segment: {last_segment}")
321
-
322
- if ServeClientWhisperCPP.SINGLE_MODEL:
323
- ServeClientWhisperCPP.SINGLE_MODEL_LOCK.release()
324
- if last_segment:
325
- self.handle_transcription_output(last_segment, duration)
326
-
327
- def update_timestamp_offset(self, last_segment, duration):
328
- """
329
- Update timestamp offset and transcript.
330
-
331
- Args:
332
- last_segment (str): Last transcribed audio from the whisper model.
333
- duration (float): Duration of the last audio chunk.
334
- """
335
- if not len(self.transcript):
336
- self.transcript.append({"text": last_segment + " "})
337
- elif self.transcript[-1]["text"].strip() != last_segment:
338
- self.transcript.append({"text": last_segment + " "})
339
-
340
- logging.info(f'Transcript list context: {self.transcript}')
341
-
342
- with self.lock:
343
- self.timestamp_offset += duration
344
-
345
- def speech_to_text(self):
346
- """
347
- Process an audio stream in an infinite loop, continuously transcribing the speech.
348
-
349
- This method continuously receives audio frames, performs real-time transcription, and sends
350
- transcribed segments to the client via a WebSocket connection.
351
-
352
- If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction.
353
- It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments
354
- are sent to the client in real-time, and a history of segments is maintained to provide context.Pauses in speech
355
- (no output from Whisper) are handled by showing the previous output for a set duration. A blank segment is added if
356
- there is no speech for a specified duration to indicate a pause.
357
-
358
- Raises:
359
- Exception: If there is an issue with audio processing or WebSocket communication.
360
-
361
- """
362
- while True:
363
- if self.exit:
364
- logging.info("Exiting speech to text thread")
365
- break
366
-
367
- if self.frames_np is None:
368
- time.sleep(0.02) # wait for any audio to arrive
369
- continue
370
-
371
- self.clip_audio_if_no_valid_segment()
372
-
373
- input_bytes, duration = self.get_audio_chunk_for_processing()
374
- if duration < 1:
375
- continue
376
-
377
- try:
378
- input_sample = input_bytes.copy()
379
- logging.info(f"[pywhispercpp:] Processing audio with duration: {duration}")
380
- self.transcribe_audio(input_sample)
381
-
382
- except Exception as e:
383
- logging.error(f"[ERROR]: {e}")