COldish-Ayako committed on
Commit fcd58ee · unverified · 1 Parent(s): ae2f86a

Upload 19 files
.gitattributes CHANGED
@@ -33,4 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/jfk.flac filter=lfs diff=lfs merge=lfs -text
 *.icns filter=lfs diff=lfs merge=lfs -text
assets/jfk.flac ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:63a4b1e4c1dc655ac70961ffbf518acd249df237e5a0152faae9a4a836949715
+size 1152693
moyoyo_asr_models/ggml-medium-encoder.mlmodelc/analytics/coremldata.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:adbe456375e7eb3407732a426ecb65bbda86860e4aa801f3a696b70b8a533cdd
+size 207
moyoyo_asr_models/ggml-medium-encoder.mlmodelc/coremldata.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:05fe28591b40616fa0c34ad7b853133623f5300923ec812acb11459c411acf3b
+size 149
moyoyo_asr_models/ggml-medium-encoder.mlmodelc/metadata.json ADDED
@@ -0,0 +1,64 @@
+[
+  {
+    "metadataOutputVersion" : "3.0",
+    "storagePrecision" : "Float16",
+    "outputSchema" : [
+      {
+        "hasShapeFlexibility" : "0",
+        "isOptional" : "0",
+        "dataType" : "Float32",
+        "formattedType" : "MultiArray (Float32)",
+        "shortDescription" : "",
+        "shape" : "[]",
+        "name" : "output",
+        "type" : "MultiArray"
+      }
+    ],
+    "modelParameters" : [
+
+    ],
+    "specificationVersion" : 6,
+    "mlProgramOperationTypeHistogram" : {
+      "Linear" : 144,
+      "Matmul" : 48,
+      "Cast" : 2,
+      "Conv" : 2,
+      "Softmax" : 24,
+      "Add" : 49,
+      "LayerNorm" : 49,
+      "Mul" : 48,
+      "Transpose" : 97,
+      "Gelu" : 26,
+      "Reshape" : 96
+    },
+    "computePrecision" : "Mixed (Float16, Float32, Int32)",
+    "isUpdatable" : "0",
+    "availability" : {
+      "macOS" : "12.0",
+      "tvOS" : "15.0",
+      "watchOS" : "8.0",
+      "iOS" : "15.0",
+      "macCatalyst" : "15.0"
+    },
+    "modelType" : {
+      "name" : "MLModelType_mlProgram"
+    },
+    "userDefinedMetadata" : {
+
+    },
+    "inputSchema" : [
+      {
+        "hasShapeFlexibility" : "0",
+        "isOptional" : "0",
+        "dataType" : "Float32",
+        "formattedType" : "MultiArray (Float32 1 × 80 × 3000)",
+        "shortDescription" : "",
+        "shape" : "[1, 80, 3000]",
+        "name" : "logmel_data",
+        "type" : "MultiArray"
+      }
+    ],
+    "generatedClassName" : "coreml_encoder_medium",
+    "method" : "predict"
+  }
+]
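
The `inputSchema` above pins the Core ML encoder to a single Float32 `logmel_data` tensor of shape (1, 80, 3000): 80 mel bins over 3000 frames, i.e. 30 s of 16 kHz audio at a 160-sample hop. A rough sketch of producing an input of that shape with librosa (already a dependency of transcribe/server.py); Whisper's exact mel filterbank and normalization differ slightly, so treat this as illustrative only:

```python
import librosa
import numpy as np

def logmel_data(path, sr=16000, n_mels=80, n_frames=3000):
    """Illustrative (1, 80, 3000) log-mel input for the encoder."""
    y, _ = librosa.load(path, sr=sr, mono=True)
    # Whisper-style framing: 25 ms window (400 samples), 10 ms hop (160 samples).
    mel = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=400,
                                         hop_length=160, n_mels=n_mels)
    logmel = np.log10(np.maximum(mel, 1e-10))
    # Pad with zeros or trim to exactly 3000 frames (30 s).
    if logmel.shape[1] < n_frames:
        logmel = np.pad(logmel, ((0, 0), (0, n_frames - logmel.shape[1])))
    else:
        logmel = logmel[:, :n_frames]
    return logmel[np.newaxis].astype(np.float32)
```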
moyoyo_asr_models/ggml-medium-encoder.mlmodelc/model.mil ADDED
The diff for this file is too large to render. See raw diff
 
moyoyo_asr_models/ggml-medium-encoder.mlmodelc/weights/weight.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6a188b0e4e3109f28f38f1f47ea2497ffe623923419df8e1ae12cb5f809a1815
+size 614507008
moyoyo_asr_models/ggml-medium-q5_0.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:19fea4b380c3a618ec4723c3eef2eb785ffba0d0538cf43f8f235e7b3b34220f
+size 539212467
run_client.py ADDED
@@ -0,0 +1,15 @@
+from transcribe.client import TranscriptionClient
+
+client = TranscriptionClient(
+    "localhost",
+    9000,  # must match the port the server listens on (run_server.py defaults to 9090)
+    lang="zh",
+    save_output_recording=False,  # only used for microphone input; False by default
+    output_recording_filename="./output_recording.wav",  # only used for microphone input
+    max_clients=4,
+    max_connection_time=600,
+    mute_audio_playback=False,  # only used for file input; False by default
+)
+
+if __name__ == '__main__':
+    client()
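
Calling `client()` with no arguments records from the microphone. Per `TranscriptionTeeClient.__call__` below, a file path, RTSP URL, or HLS URL can be passed instead; for example, to stream the bundled sample clip (path assumed relative to the repository root):

```python
# Transcribe a pre-recorded file instead of live microphone audio.
client(audio="assets/jfk.flac")

# Or transcribe a network stream (URL is illustrative).
# client(rtsp_url="rtsp://example.com/live")
```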
run_server.py ADDED
@@ -0,0 +1,31 @@
+import argparse
+import os
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--port', '-p',
+                        type=int,
+                        default=9090,
+                        help="WebSocket port to run the server on.")
+    parser.add_argument('--backend', '-b',
+                        type=str,
+                        default='pywhispercpp',
+                        help='Backend, one of ["pywhispercpp"].')
+
+    parser.add_argument('--omp_num_threads', '-omp',
+                        type=int,
+                        default=1,
+                        help="Number of threads to use for OpenMP.")
+
+    args = parser.parse_args()
+
+    if "OMP_NUM_THREADS" not in os.environ:
+        os.environ["OMP_NUM_THREADS"] = str(args.omp_num_threads)
+
+    from transcribe.server import TranscriptionServer  # import after OMP_NUM_THREADS is set
+    server = TranscriptionServer()
+    server.run(
+        "0.0.0.0",
+        port=args.port,
+        backend=args.backend,
+    )
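
For reference, the wire protocol implied by `handle_new_connection` and `get_audio_from_websocket` in transcribe/server.py: the first client message is a JSON config (uid, language, connection limits), followed by binary little-endian float32 PCM frames at 16 kHz, terminated by the literal bytes `END_OF_AUDIO`. A minimal hand-rolled sender, sketched with the same `websocket-client` package the bundled client uses (the one second of silence is just a placeholder payload):

```python
import json
import uuid

import numpy as np
import websocket  # websocket-client, as used by transcribe/client.py

ws = websocket.create_connection("ws://localhost:9090")
ws.send(json.dumps({
    "uid": str(uuid.uuid4()),
    "language": "zh",
    "max_clients": 4,
    "max_connection_time": 600,
}))
print(ws.recv())  # expect the SERVER_READY message first

# One second of 16 kHz silence as float32 -- the format the server expects.
ws.send_binary(np.zeros(16000, dtype=np.float32).tobytes())
ws.send_binary(b"END_OF_AUDIO")
ws.close()
```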
transcribe/__init__.py ADDED
File without changes
transcribe/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (183 Bytes).
transcribe/__pycache__/client.cpython-311.pyc ADDED
Binary file (39 kB).
transcribe/__pycache__/server.cpython-311.pyc ADDED
Binary file (36 kB).
transcribe/__pycache__/utils.cpython-311.pyc ADDED
Binary file (4.64 kB).
transcribe/__pycache__/vad.cpython-311.pyc ADDED
Binary file (9.36 kB).
transcribe/client.py ADDED
@@ -0,0 +1,675 @@
+import json
+import os
+import shutil
+import threading
+import time
+import uuid
+import wave
+
+import av
+import numpy as np
+import pyaudio
+import websocket
+
+import transcribe.utils as utils
+
+
+class Client:
+    """
+    Handles communication with a server using WebSocket.
+    """
+    INSTANCES = {}
+    END_OF_AUDIO = "END_OF_AUDIO"
+
+    def __init__(
+        self,
+        host=None,
+        port=None,
+        lang=None,
+        log_transcription=True,
+        max_clients=4,
+        max_connection_time=600,
+    ):
+        """
+        Initializes a Client instance for audio recording and streaming to a server.
+
+        If host and port are not provided, the WebSocket connection will not be established.
+        Audio recording starts immediately upon initialization.
+
+        Args:
+            host (str): The hostname or IP address of the server.
+            port (int): The port number for the WebSocket server.
+            lang (str, optional): The selected language for transcription. Default is None.
+            log_transcription (bool, optional): Whether to log transcription output to the console. Default is True.
+            max_clients (int, optional): Maximum number of client connections allowed. Default is 4.
+            max_connection_time (int, optional): Maximum allowed connection time in seconds. Default is 600.
+        """
+        self.recording = False
+        self.uid = str(uuid.uuid4())
+        self.waiting = False
+        self.last_response_received = None
+        self.disconnect_if_no_response_for = 15
+        self.language = lang
+        self.server_error = False
+        self.last_segment = None
+        self.last_received_segment = None
+        self.log_transcription = log_transcription
+        self.max_clients = max_clients
+        self.max_connection_time = max_connection_time
+
+        self.audio_bytes = None
+
+        if host is not None and port is not None:
+            socket_url = f"ws://{host}:{port}"
+            self.client_socket = websocket.WebSocketApp(
+                socket_url,
+                on_open=lambda ws: self.on_open(ws),
+                on_message=lambda ws, message: self.on_message(ws, message),
+                on_error=lambda ws, error: self.on_error(ws, error),
+                on_close=lambda ws, close_status_code, close_msg: self.on_close(
+                    ws, close_status_code, close_msg
+                ),
+            )
+        else:
+            print("[ERROR]: No host or port specified.")
+            return
+
+        Client.INSTANCES[self.uid] = self
+
+        # start websocket client in a thread
+        self.ws_thread = threading.Thread(target=self.client_socket.run_forever)
+        self.ws_thread.daemon = True
+        self.ws_thread.start()
+
+        self.transcript = []
+        print("[INFO]: * recording")
+
+    def handle_status_messages(self, message_data):
+        """Handles server status messages."""
+        status = message_data["status"]
+        if status == "WAIT":
+            self.waiting = True
+            print(f"[INFO]: Server is full. Estimated wait time {round(message_data['message'])} minutes.")
+        elif status == "ERROR":
+            print(f"Message from Server: {message_data['message']}")
+            self.server_error = True
+        elif status == "WARNING":
+            print(f"Message from Server: {message_data['message']}")
+
+    def process_segments(self, segments):
+        """Processes transcript segments."""
+        text = []
+        for i, seg in enumerate(segments):
+            if not text or text[-1] != seg["text"]:
+                text.append(seg["text"])
+                if i == len(segments) - 1 and not seg.get("completed", False):
+                    self.last_segment = seg
+
+        # update last received segment and last valid response time
+        if self.last_received_segment is None or self.last_received_segment != segments[-1]["text"]:
+            self.last_response_received = time.time()
+            self.last_received_segment = segments[-1]["text"]
+
+        if self.log_transcription:
+            # Truncate to last 3 entries for brevity.
+            text = text[-3:]
+            utils.clear_screen()
+            utils.print_transcript(text)
+
+    def on_message(self, ws, message):
+        """
+        Callback function called when a message is received from the server.
+
+        It updates various attributes of the client based on the received message, including
+        recording status, language detection, and server messages. If a disconnect message
+        is received, it sets the recording status to False.
+
+        Args:
+            ws (websocket.WebSocketApp): The WebSocket client instance.
+            message (str): The received message from the server.
+        """
+        message = json.loads(message)
+
+        if self.uid != message.get("uid"):
+            print("[ERROR]: invalid client uid")
+            return
+
+        if "status" in message.keys():
+            self.handle_status_messages(message)
+            return
+
+        if "message" in message.keys() and message["message"] == "DISCONNECT":
+            print("[INFO]: Server disconnected due to overtime.")
+            self.recording = False
+
+        if "message" in message.keys() and message["message"] == "SERVER_READY":
+            self.last_response_received = time.time()
+            self.recording = True
+            self.server_backend = message["backend"]
+            print(f"[INFO]: Server Running with backend {self.server_backend}")
+            return
+
+        if "language" in message.keys():
+            self.language = message.get("language")
+            lang_prob = message.get("language_prob")
+            print(
+                f"[INFO]: Server detected language {self.language} with probability {lang_prob}"
+            )
+            return
+
+        if "segments" in message.keys():
+            self.process_segments(message["segments"])
+
+    def on_error(self, ws, error):
+        print(f"[ERROR] WebSocket Error: {error}")
+        self.server_error = True
+        self.error_message = error
+
+    def on_close(self, ws, close_status_code, close_msg):
+        print(f"[INFO]: WebSocket connection closed: {close_status_code}: {close_msg}")
+        self.recording = False
+        self.waiting = False
+
+    def on_open(self, ws):
+        """
+        Callback function called when the WebSocket connection is successfully opened.
+
+        Sends an initial configuration message to the server, including client UID,
+        language selection, and connection limits.
+
+        Args:
+            ws (websocket.WebSocketApp): The WebSocket client instance.
+        """
+        print("[INFO]: Opened connection")
+        ws.send(
+            json.dumps(
+                {
+                    "uid": self.uid,
+                    "language": self.language,
+                    "max_clients": self.max_clients,
+                    "max_connection_time": self.max_connection_time,
+                }
+            )
+        )
+
+    def send_packet_to_server(self, message):
+        """
+        Send an audio packet to the server using WebSocket.
+
+        Args:
+            message (bytes): The audio data packet in bytes to be sent to the server.
+        """
+        try:
+            self.client_socket.send(message, websocket.ABNF.OPCODE_BINARY)
+        except Exception as e:
+            print(e)
+
+    def close_websocket(self):
+        """
+        Close the WebSocket connection and join the WebSocket thread.
+
+        First attempts to close the WebSocket connection using `self.client_socket.close()`. After
+        closing the connection, it joins the WebSocket thread to ensure proper termination.
+        """
+        try:
+            self.client_socket.close()
+        except Exception as e:
+            print("[ERROR]: Error closing WebSocket:", e)
+
+        try:
+            self.ws_thread.join()
+        except Exception as e:
+            print("[ERROR]: Error joining WebSocket thread:", e)
+
+    def get_client_socket(self):
+        """
+        Get the WebSocket client socket instance.
+
+        Returns:
+            WebSocketApp: The WebSocket client socket instance currently in use by the client.
+        """
+        return self.client_socket
+
+    def wait_before_disconnect(self):
+        """Waits a bit before disconnecting in order to process pending responses."""
+        assert self.last_response_received
+        while time.time() - self.last_response_received < self.disconnect_if_no_response_for:
+            continue
+
+
+class TranscriptionTeeClient:
+    """
+    Client for handling audio recording, streaming, and transcription tasks via one or more
+    WebSocket connections.
+
+    Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used
+    to send audio data for transcription to one or more servers, and receive transcribed text segments.
+
+    Args:
+        clients (list): one or more previously initialized Client instances
+
+    Attributes:
+        clients (list): the underlying Client instances responsible for handling WebSocket connections.
+    """
+
+    def __init__(self, clients, save_output_recording=False, output_recording_filename="./output_recording.wav",
+                 mute_audio_playback=False):
+        self.clients = clients
+        if not self.clients:
+            raise Exception("At least one client is required.")
+        self.chunk = 4096
+        self.format = pyaudio.paInt16
+        self.channels = 1
+        self.rate = 16000
+        self.record_seconds = 60000
+        self.save_output_recording = save_output_recording
+        self.output_recording_filename = output_recording_filename
+        self.mute_audio_playback = mute_audio_playback
+        self.frames = b""
+        self.p = pyaudio.PyAudio()
+        try:
+            self.stream = self.p.open(
+                format=self.format,
+                channels=self.channels,
+                rate=self.rate,
+                input=True,
+                frames_per_buffer=self.chunk,
+            )
+        except OSError as error:
+            print(f"[WARN]: Unable to access microphone. {error}")
+            self.stream = None
+
+    def __call__(self, audio=None, rtsp_url=None, hls_url=None, save_file=None):
+        """
+        Start the transcription process.
+
+        Initiates the transcription process by connecting to the server via a WebSocket. It waits for the server
+        to be ready to receive audio data and then sends audio for transcription. If an audio file is provided, it
+        will be played and streamed to the server; otherwise, it will perform live recording.
+
+        Args:
+            audio (str, optional): Path to an audio file for transcription. Default is None, which triggers live recording.
+            rtsp_url (str, optional): URL of an RTSP stream to transcribe. Default is None.
+            hls_url (str, optional): URL of an HLS stream to transcribe. Default is None.
+            save_file (str, optional): Local path to save an HLS stream to. Default is None.
+        """
+        assert sum(
+            source is not None for source in [audio, rtsp_url, hls_url]
+        ) <= 1, 'Please provide at most one audio source'
+
+        print("[INFO]: Waiting for server ready ...")
+        for client in self.clients:
+            while not client.recording:
+                if client.waiting or client.server_error:
+                    self.close_all_clients()
+                    return
+
+        print("[INFO]: Server Ready!")
+        if hls_url is not None:
+            self.process_hls_stream(hls_url, save_file)
+        elif audio is not None:
+            resampled_file = utils.resample(audio)
+            self.play_file(resampled_file)
+        elif rtsp_url is not None:
+            self.process_rtsp_stream(rtsp_url)
+        else:
+            self.record()
+
+    def close_all_clients(self):
+        """Closes all client websockets."""
+        for client in self.clients:
+            client.close_websocket()
+
+    def multicast_packet(self, packet, unconditional=False):
+        """
+        Sends an identical packet via all clients.
+
+        Args:
+            packet (bytes): The audio data packet in bytes to be sent.
+            unconditional (bool, optional): If true, send regardless of whether clients are recording. Default is False.
+        """
+        for client in self.clients:
+            if unconditional or client.recording:
+                client.send_packet_to_server(packet)
+
+    def play_file(self, filename):
+        """
+        Play an audio file and send it to the server for processing.
+
+        Reads an audio file, plays it through the audio output, and simultaneously sends
+        the audio data to the server for processing. It uses PyAudio to create an audio
+        stream for playback. The audio data is read from the file in chunks, converted to
+        floating-point format, and sent to the server using WebSocket communication.
+        This method is typically used when you want to process pre-recorded audio and send it
+        to the server in real time.
+
+        Args:
+            filename (str): The path to the audio file to be played and sent to the server.
+        """
+        # read audio and create pyaudio stream
+        with wave.open(filename, "rb") as wavfile:
+            self.stream = self.p.open(
+                format=self.p.get_format_from_width(wavfile.getsampwidth()),
+                channels=wavfile.getnchannels(),
+                rate=wavfile.getframerate(),
+                input=True,
+                output=True,
+                frames_per_buffer=self.chunk,
+            )
+            chunk_duration = self.chunk / float(wavfile.getframerate())
+            try:
+                while any(client.recording for client in self.clients):
+                    data = wavfile.readframes(self.chunk)
+                    if data == b"":
+                        break
+
+                    audio_array = self.bytes_to_float_array(data)
+                    self.multicast_packet(audio_array.tobytes())
+                    if self.mute_audio_playback:
+                        time.sleep(chunk_duration)
+                    else:
+                        self.stream.write(data)
+
+                wavfile.close()
+
+                for client in self.clients:
+                    client.wait_before_disconnect()
+                self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
+                self.stream.close()
+                self.close_all_clients()
+
+            except KeyboardInterrupt:
+                wavfile.close()
+                self.stream.stop_stream()
+                self.stream.close()
+                self.p.terminate()
+                self.close_all_clients()
+                print("[INFO]: Keyboard interrupt.")
+
+    def process_rtsp_stream(self, rtsp_url):
+        """
+        Connect to an RTSP source, process the audio stream, and send it for transcription.
+
+        Args:
+            rtsp_url (str): The URL of the RTSP stream source.
+        """
+        print("[INFO]: Connecting to RTSP stream...")
+        try:
+            container = av.open(rtsp_url, format="rtsp", options={"rtsp_transport": "tcp"})
+            self.process_av_stream(container, stream_type="RTSP")
+        except Exception as e:
+            print(f"[ERROR]: Failed to process RTSP stream: {e}")
+        finally:
+            for client in self.clients:
+                client.wait_before_disconnect()
+            self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
+            self.close_all_clients()
+        print("[INFO]: RTSP stream processing finished.")
+
+    def process_hls_stream(self, hls_url, save_file=None):
+        """
+        Connect to an HLS source, process the audio stream, and send it for transcription.
+
+        Args:
+            hls_url (str): The URL of the HLS stream source.
+            save_file (str, optional): Local path to save the network stream.
+        """
+        print("[INFO]: Connecting to HLS stream...")
+        try:
+            container = av.open(hls_url, format="hls")
+            self.process_av_stream(container, stream_type="HLS", save_file=save_file)
+        except Exception as e:
+            print(f"[ERROR]: Failed to process HLS stream: {e}")
+        finally:
+            for client in self.clients:
+                client.wait_before_disconnect()
+            self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
+            self.close_all_clients()
+        print("[INFO]: HLS stream processing finished.")
+
+    def process_av_stream(self, container, stream_type, save_file=None):
+        """
+        Process an AV container stream and send audio packets to the server.
+
+        Args:
+            container (av.container.InputContainer): The input container to process.
+            stream_type (str): The type of stream being processed ("RTSP" or "HLS").
+            save_file (str, optional): Local path to save the stream. Default is None.
+        """
+        audio_stream = next((s for s in container.streams if s.type == "audio"), None)
+        if not audio_stream:
+            print(f"[ERROR]: No audio stream found in {stream_type} source.")
+            return
+
+        output_container = None
+        if save_file:
+            output_container = av.open(save_file, mode="w")
+            output_audio_stream = output_container.add_stream(codec_name="pcm_s16le", rate=self.rate)
+
+        try:
+            for packet in container.demux(audio_stream):
+                for frame in packet.decode():
+                    audio_data = frame.to_ndarray().tobytes()
+                    self.multicast_packet(audio_data)
+
+                    if save_file:
+                        output_container.mux(frame)
+        except Exception as e:
+            print(f"[ERROR]: Error during {stream_type} stream processing: {e}")
+        finally:
+            # Wait for server to send any leftover transcription.
+            time.sleep(5)
+            self.multicast_packet(Client.END_OF_AUDIO.encode('utf-8'), True)
+            if output_container:
+                output_container.close()
+            container.close()
+
+    def save_chunk(self, n_audio_file):
+        """
+        Saves the current audio frames to a WAV file in a separate thread.
+
+        Args:
+            n_audio_file (int): The index of the audio file which determines the filename.
+                This helps in maintaining the order and uniqueness of each chunk.
+        """
+        t = threading.Thread(
+            target=self.write_audio_frames_to_file,
+            args=(self.frames[:], f"chunks/{n_audio_file}.wav",),
+        )
+        t.start()
+
+    def finalize_recording(self, n_audio_file):
+        """
+        Finalizes the recording process by saving any remaining audio frames,
+        closing the audio stream, and terminating the PyAudio instance.
+
+        Args:
+            n_audio_file (int): The file index to be used if there are remaining audio frames to be saved.
+                This index is incremented before use if the last chunk is saved.
+        """
+        if self.save_output_recording and len(self.frames):
+            self.write_audio_frames_to_file(
+                self.frames[:], f"chunks/{n_audio_file}.wav"
+            )
+            n_audio_file += 1
+        self.stream.stop_stream()
+        self.stream.close()
+        self.p.terminate()
+        self.close_all_clients()
+        if self.save_output_recording:
+            self.write_output_recording(n_audio_file)
+
+    def record(self):
+        """
+        Record audio data from the input stream and save it to a WAV file.
+
+        Continuously records audio data from the input stream, sends it to the server via a WebSocket
+        connection, and simultaneously saves it to multiple WAV files in chunks. It stops recording when
+        the `record_seconds` duration is reached or when all clients stop recording.
+
+        Audio data is saved in chunks to the "chunks" directory, each chunk as a separate WAV file.
+        The recording process can be interrupted with a KeyboardInterrupt (e.g., pressing Ctrl+C). After recording,
+        the method combines all the saved audio chunks into `output_recording_filename`.
+        """
+        n_audio_file = 0
+        if self.save_output_recording:
+            if os.path.exists("chunks"):
+                shutil.rmtree("chunks")
+            os.makedirs("chunks")
+        try:
+            for _ in range(0, int(self.rate / self.chunk * self.record_seconds)):
+                if not any(client.recording for client in self.clients):
+                    break
+                data = self.stream.read(self.chunk, exception_on_overflow=False)
+                self.frames += data
+
+                audio_array = self.bytes_to_float_array(data)
+
+                self.multicast_packet(audio_array.tobytes())
+
+                # save frames if more than a minute
+                if len(self.frames) > 60 * self.rate:
+                    if self.save_output_recording:
+                        self.save_chunk(n_audio_file)
+                        n_audio_file += 1
+                    self.frames = b""
+
+        except KeyboardInterrupt:
+            self.finalize_recording(n_audio_file)
+
+    def write_audio_frames_to_file(self, frames, file_name):
+        """
+        Write audio frames to a WAV file.
+
+        The WAV file is created or overwritten with the specified name. The audio frames should be
+        in the correct format and match the specified channel count, sample width, and sample rate.
+
+        Args:
+            frames (bytes): The audio frames to be written to the file.
+            file_name (str): The name of the WAV file to which the frames will be written.
+        """
+        with wave.open(file_name, "wb") as wavfile:
+            wavfile: wave.Wave_write
+            wavfile.setnchannels(self.channels)
+            wavfile.setsampwidth(2)
+            wavfile.setframerate(self.rate)
+            wavfile.writeframes(frames)
+
+    def write_output_recording(self, n_audio_file):
+        """
+        Combine and save recorded audio chunks into a single WAV file.
+
+        The individual audio chunk files are expected to be located in the "chunks" directory. Reads each chunk
+        file, appends its audio data to the final recording, and then deletes the chunk file. After combining
+        and saving, the final recording is stored at `self.output_recording_filename`.
+
+        Args:
+            n_audio_file (int): The number of audio chunk files to combine.
+        """
+        input_files = [
+            f"chunks/{i}.wav"
+            for i in range(n_audio_file)
+            if os.path.exists(f"chunks/{i}.wav")
+        ]
+        with wave.open(self.output_recording_filename, "wb") as wavfile:
+            wavfile: wave.Wave_write
+            wavfile.setnchannels(self.channels)
+            wavfile.setsampwidth(2)
+            wavfile.setframerate(self.rate)
+            for in_file in input_files:
+                with wave.open(in_file, "rb") as wav_in:
+                    while True:
+                        data = wav_in.readframes(self.chunk)
+                        if data == b"":
+                            break
+                        wavfile.writeframes(data)
+                # remove this file
+                os.remove(in_file)
+        wavfile.close()
+        # clean up temporary directory used to store chunks
+        if os.path.exists("chunks"):
+            shutil.rmtree("chunks")
+
+    @staticmethod
+    def bytes_to_float_array(audio_bytes):
+        """
+        Convert audio data from bytes to a NumPy float array.
+
+        It assumes that the audio data is in 16-bit PCM format. The audio data is normalized to
+        have values between -1 and 1.
+
+        Args:
+            audio_bytes (bytes): Audio data in bytes.
+
+        Returns:
+            np.ndarray: A NumPy array containing the audio data as float values normalized between -1 and 1.
+        """
+        raw_data = np.frombuffer(buffer=audio_bytes, dtype=np.int16)
+        return raw_data.astype(np.float32) / 32768.0
+
+
+class TranscriptionClient(TranscriptionTeeClient):
+    """
+    Client for handling audio transcription tasks via a single WebSocket connection.
+
+    Acts as a high-level client for audio transcription tasks using a WebSocket connection. It can be used
+    to send audio data for transcription to a server and receive transcribed text segments.
+
+    Args:
+        host (str): The hostname or IP address of the server.
+        port (int): The port number to connect to on the server.
+        lang (str, optional): The primary language for transcription. Default is None, which defaults to English ('en').
+        save_output_recording (bool, optional): Whether to save the microphone recording. Default is False.
+        output_recording_filename (str, optional): Path to save the output recording WAV file. Default is "./output_recording.wav".
+        log_transcription (bool, optional): Whether to log transcription output to the console. Default is True.
+        max_clients (int, optional): Maximum number of client connections allowed. Default is 4.
+        max_connection_time (int, optional): Maximum allowed connection time in seconds. Default is 600.
+        mute_audio_playback (bool, optional): If True, mutes audio playback during file playback. Default is False.
+
+    Attributes:
+        client (Client): An instance of the underlying Client class responsible for handling the WebSocket connection.
+
+    Example:
+        To create a TranscriptionClient and start transcription on microphone audio:
+        ```python
+        transcription_client = TranscriptionClient(host="localhost", port=9090)
+        transcription_client()
+        ```
+    """
+
+    def __init__(
+        self,
+        host,
+        port,
+        lang=None,
+        save_output_recording=False,
+        output_recording_filename="./output_recording.wav",
+        log_transcription=True,
+        max_clients=4,
+        max_connection_time=600,
+        mute_audio_playback=False,
+    ):
+        self.client = Client(
+            host, port, lang, log_transcription=log_transcription, max_clients=max_clients,
+            max_connection_time=max_connection_time
+        )
+
+        if save_output_recording and not output_recording_filename.endswith(".wav"):
+            raise ValueError(f"Please provide a valid `output_recording_filename`: {output_recording_filename}")
+
+        TranscriptionTeeClient.__init__(
+            self,
+            [self.client],
+            save_output_recording=save_output_recording,
+            output_recording_filename=output_recording_filename,
+            mute_audio_playback=mute_audio_playback
+        )
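
`TranscriptionTeeClient` fans a single audio stream out to several servers at once: every packet is multicast to all underlying `Client` connections. A hypothetical two-server setup (hosts and ports are illustrative):

```python
from transcribe.client import Client, TranscriptionTeeClient

# Two independent WebSocket connections that will receive identical audio.
c1 = Client("localhost", 9090, lang="zh")
c2 = Client("192.168.1.20", 9090, lang="zh")

tee = TranscriptionTeeClient([c1, c2])
tee()  # record from the microphone and multicast to both servers
```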
transcribe/server.py ADDED
@@ -0,0 +1,684 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import json
3
+ import logging
4
+ import pathlib
5
+ import threading
6
+ import time
7
+ from enum import Enum
8
+ from typing import List, Optional
9
+
10
+ import librosa
11
+ import numpy as np
12
+ import soundfile
13
+ from pywhispercpp.model import Model
14
+ from websockets.exceptions import ConnectionClosed
15
+ from websockets.sync.server import serve
16
+
17
+ from transcribe.vad import VoiceActivityDetector
18
+
19
+ logging.basicConfig(level=logging.INFO)
20
+
21
+
22
+ class ClientManager:
23
+ def __init__(self, max_clients=4, max_connection_time=600):
24
+ """
25
+ Initializes the ClientManager with specified limits on client connections and connection durations.
26
+
27
+ Args:
28
+ max_clients (int, optional): The maximum number of simultaneous client connections allowed. Defaults to 4.
29
+ max_connection_time (int, optional): The maximum duration (in seconds) a client can stay connected. Defaults
30
+ to 600 seconds (10 minutes).
31
+ """
32
+ self.clients = {}
33
+ self.start_times = {}
34
+ self.max_clients = max_clients
35
+ self.max_connection_time = max_connection_time
36
+
37
+ def add_client(self, websocket, client):
38
+ """
39
+ Adds a client and their connection start time to the tracking dictionaries.
40
+
41
+ Args:
42
+ websocket: The websocket associated with the client to add.
43
+ client: The client object to be added and tracked.
44
+ """
45
+ self.clients[websocket] = client
46
+ self.start_times[websocket] = time.time()
47
+
48
+ def get_client(self, websocket):
49
+ """
50
+ Retrieves a client associated with the given websocket.
51
+
52
+ Args:
53
+ websocket: The websocket associated with the client to retrieve.
54
+
55
+ Returns:
56
+ The client object if found, False otherwise.
57
+ """
58
+ if websocket in self.clients:
59
+ return self.clients[websocket]
60
+ return False
61
+
62
+ def remove_client(self, websocket):
63
+ """
64
+ Removes a client and their connection start time from the tracking dictionaries. Performs cleanup on the
65
+ client if necessary.
66
+
67
+ Args:
68
+ websocket: The websocket associated with the client to be removed.
69
+ """
70
+ client = self.clients.pop(websocket, None)
71
+ if client:
72
+ client.cleanup()
73
+ self.start_times.pop(websocket, None)
74
+
75
+ def get_wait_time(self):
76
+ """
77
+ Calculates the estimated wait time for new clients based on the remaining connection times of current clients.
78
+
79
+ Returns:
80
+ The estimated wait time in minutes for new clients to connect. Returns 0 if there are available slots.
81
+ """
82
+ wait_time = None
83
+ for start_time in self.start_times.values():
84
+ current_client_time_remaining = self.max_connection_time - (time.time() - start_time)
85
+ if wait_time is None or current_client_time_remaining < wait_time:
86
+ wait_time = current_client_time_remaining
87
+ return wait_time / 60 if wait_time is not None else 0
88
+
89
+ def is_server_full(self, websocket, options):
90
+ """
91
+ Checks if the server is at its maximum client capacity and sends a wait message to the client if necessary.
92
+
93
+ Args:
94
+ websocket: The websocket of the client attempting to connect.
95
+ options: A dictionary of options that may include the client's unique identifier.
96
+
97
+ Returns:
98
+ True if the server is full, False otherwise.
99
+ """
100
+ if len(self.clients) >= self.max_clients:
101
+ wait_time = self.get_wait_time()
102
+ response = {"uid": options["uid"], "status": "WAIT", "message": wait_time}
103
+ websocket.send(json.dumps(response))
104
+ return True
105
+ return False
106
+
107
+ def is_client_timeout(self, websocket):
108
+ """
109
+ Checks if a client has exceeded the maximum allowed connection time and disconnects them if so, issuing a warning.
110
+
111
+ Args:
112
+ websocket: The websocket associated with the client to check.
113
+
114
+ Returns:
115
+ True if the client's connection time has exceeded the maximum limit, False otherwise.
116
+ """
117
+ elapsed_time = time.time() - self.start_times[websocket]
118
+ if elapsed_time >= self.max_connection_time:
119
+ self.clients[websocket].disconnect()
120
+ logging.warning(f"Client with uid '{self.clients[websocket].client_uid}' disconnected due to overtime.")
121
+ return True
122
+ return False
123
+
124
+
125
+ class BackendType(Enum):
126
+ PYWHISPERCPP = "pywhispercpp"
127
+
128
+ @staticmethod
129
+ def valid_types() -> List[str]:
130
+ return [backend_type.value for backend_type in BackendType]
131
+
132
+ @staticmethod
133
+ def is_valid(backend: str) -> bool:
134
+ return backend in BackendType.valid_types()
135
+
136
+ def is_pywhispercpp(self) -> bool:
137
+ return self == BackendType.PYWHISPERCPP
138
+
139
+
140
+ class TranscriptionServer:
141
+ RATE = 16000
142
+
143
+ def __init__(self):
144
+ self.client_manager = None
145
+ self.no_voice_activity_chunks = 0
146
+ self.single_model = False
147
+
148
+ def initialize_client(
149
+ self, websocket, options
150
+ ):
151
+ client: Optional[ServeClientBase] = None
152
+
153
+ if self.backend.is_pywhispercpp():
154
+ client = ServeClientWhisperCPP(
155
+ websocket,
156
+ language=options["language"],
157
+ client_uid=options["uid"],
158
+ single_model=self.single_model,
159
+ )
160
+ logging.info("Running pywhispercpp backend.")
161
+
162
+ if client is None:
163
+ raise ValueError(f"Backend type {self.backend.value} not recognised or not handled.")
164
+
165
+ self.client_manager.add_client(websocket, client)
166
+
167
+ def get_audio_from_websocket(self, websocket):
168
+ """
169
+ Receives audio buffer from websocket and creates a numpy array out of it.
170
+
171
+ Args:
172
+ websocket: The websocket to receive audio from.
173
+
174
+ Returns:
175
+ A numpy array containing the audio.
176
+ """
177
+ frame_data = websocket.recv()
178
+ if frame_data == b"END_OF_AUDIO":
179
+ return False
180
+ return np.frombuffer(frame_data, dtype=np.float32)
181
+
182
+ def handle_new_connection(self, websocket):
183
+ try:
184
+ logging.info("New client connected")
185
+ options = websocket.recv()
186
+ options = json.loads(options)
187
+
188
+ if self.client_manager is None:
189
+ max_clients = options.get('max_clients', 4)
190
+ max_connection_time = options.get('max_connection_time', 600)
191
+ self.client_manager = ClientManager(max_clients, max_connection_time)
192
+
193
+ if self.client_manager.is_server_full(websocket, options):
194
+ websocket.close()
195
+ return False # Indicates that the connection should not continue
196
+
197
+ if self.backend.is_pywhispercpp():
198
+ self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE)
199
+
200
+ self.initialize_client(websocket, options)
201
+
202
+ return True
203
+ except json.JSONDecodeError:
204
+ logging.error("Failed to decode JSON from client")
205
+ return False
206
+ except ConnectionClosed:
207
+ logging.info("Connection closed by client")
208
+ return False
209
+ except Exception as e:
210
+ logging.error(f"Error during new connection initialization: {str(e)}")
211
+ return False
212
+
213
+ def process_audio_frames(self, websocket):
214
+ frame_np = self.get_audio_from_websocket(websocket)
215
+ client = self.client_manager.get_client(websocket)
216
+
217
+ # TODO Vad has some problem, it will be blocking process loop
218
+ # if frame_np is False:
219
+ # if self.backend.is_pywhispercpp():
220
+ # client.set_eos(True)
221
+ # return False
222
+
223
+ # if self.backend.is_pywhispercpp():
224
+ # voice_active = self.voice_activity(websocket, frame_np)
225
+ # if voice_active:
226
+ # self.no_voice_activity_chunks = 0
227
+ # client.set_eos(False)
228
+ # if self.use_vad and not voice_active:
229
+ # return True
230
+
231
+ client.add_frames(frame_np)
232
+ return True
233
+
234
+ def recv_audio(self,
235
+ websocket,
236
+ backend: BackendType = BackendType.PYWHISPERCPP):
237
+
238
+ self.backend = backend
239
+ if not self.handle_new_connection(websocket):
240
+ return
241
+
242
+ try:
243
+ while not self.client_manager.is_client_timeout(websocket):
244
+ if not self.process_audio_frames(websocket):
245
+ break
246
+ except ConnectionClosed:
247
+ logging.info("Connection closed by client")
248
+ except Exception as e:
249
+ logging.error(f"Unexpected error: {str(e)}")
250
+ finally:
251
+ if self.client_manager.get_client(websocket):
252
+ self.cleanup(websocket)
253
+ websocket.close()
254
+ del websocket
255
+
256
+ def run(self,
257
+ host,
258
+ port=9090,
259
+ backend="pywhispercpp"):
260
+ """
261
+ Run the transcription server.
262
+
263
+ Args:
264
+ host (str): The host address to bind the server.
265
+ port (int): The port number to bind the server.
266
+ """
267
+
268
+ if not BackendType.is_valid(backend):
269
+ raise ValueError(f"{backend} is not a valid backend type. Choose backend from {BackendType.valid_types()}")
270
+
271
+ with serve(
272
+ functools.partial(
273
+ self.recv_audio,
274
+ backend=BackendType(backend),
275
+ ),
276
+ host,
277
+ port
278
+ ) as server:
279
+ server.serve_forever()
280
+
281
+ def voice_activity(self, websocket, frame_np):
282
+ """
283
+ Evaluates the voice activity in a given audio frame and manages the state of voice activity detection.
284
+
285
+ This method uses the configured voice activity detection (VAD) model to assess whether the given audio frame
286
+ contains speech. If the VAD model detects no voice activity for more than three consecutive frames,
287
+ it sets an end-of-speech (EOS) flag for the associated client. This method aims to efficiently manage
288
+ speech detection to improve subsequent processing steps.
289
+
290
+ Args:
291
+ websocket: The websocket associated with the current client. Used to retrieve the client object
292
+ from the client manager for state management.
293
+ frame_np (numpy.ndarray): The audio frame to be analyzed. This should be a NumPy array containing
294
+ the audio data for the current frame.
295
+
296
+ Returns:
297
+ bool: True if voice activity is detected in the current frame, False otherwise. When returning False
298
+ after detecting no voice activity for more than three consecutive frames, it also triggers the
299
+ end-of-speech (EOS) flag for the client.
300
+ """
301
+ if not self.vad_detector(frame_np):
302
+ self.no_voice_activity_chunks += 1
303
+ if self.no_voice_activity_chunks > 3:
304
+ client = self.client_manager.get_client(websocket)
305
+ if not client.eos:
306
+ client.set_eos(True)
307
+ time.sleep(0.1) # Sleep 100m; wait some voice activity.
308
+ return False
309
+ return True
310
+
311
+ def cleanup(self, websocket):
312
+ """
313
+ Cleans up resources associated with a given client's websocket.
314
+
315
+ Args:
316
+ websocket: The websocket associated with the client to be cleaned up.
317
+ """
318
+ if self.client_manager.get_client(websocket):
319
+ self.client_manager.remove_client(websocket)
320
+
321
+
322
+ class ServeClientBase(object):
323
+ RATE = 16000
324
+ SERVER_READY = "SERVER_READY"
325
+ DISCONNECT = "DISCONNECT"
326
+
327
+ def __init__(self, client_uid, websocket):
328
+ self.client_uid = client_uid
329
+ self.websocket = websocket
330
+ self.frames = b""
331
+ self.timestamp_offset = 0.0
332
+ self.frames_np = None
333
+ self.frames_offset = 0.0
334
+ self.text = []
335
+ self.current_out = ''
336
+ self.prev_out = ''
337
+ self.t_start = None
338
+ self.exit = False
339
+ self.same_output_count = 0
340
+ self.show_prev_out_thresh = 5 # if pause(no output from whisper) show previous output for 5 seconds
341
+ self.add_pause_thresh = 3 # add a blank to segment list as a pause(no speech) for 3 seconds
342
+ self.transcript = []
343
+ self.send_last_n_segments = 10
344
+
345
+ # text formatting
346
+ self.pick_previous_segments = 2
347
+
348
+ # threading
349
+ self.lock = threading.Lock()
350
+
351
+ def speech_to_text(self):
352
+ raise NotImplementedError
353
+
354
+ def transcribe_audio(self):
355
+ raise NotImplementedError
356
+
357
+ def handle_transcription_output(self):
358
+ raise NotImplementedError
359
+
360
+ def add_frames(self, frame_np):
361
+ """
362
+ Add audio frames to the ongoing audio stream buffer.
363
+
364
+ This method is responsible for maintaining the audio stream buffer, allowing the continuous addition
365
+ of audio frames as they are received. It also ensures that the buffer does not exceed a specified size
366
+ to prevent excessive memory usage.
367
+
368
+ If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds
369
+ of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided
370
+ audio frame. The audio stream buffer is used for real-time processing of audio data for transcription.
371
+
372
+ Args:
373
+ frame_np (numpy.ndarray): The audio frame data as a NumPy array.
374
+
375
+ """
376
+ self.lock.acquire()
377
+ if self.frames_np is not None and self.frames_np.shape[0] > 45 * self.RATE:
378
+ self.frames_offset += 30.0
379
+ self.frames_np = self.frames_np[int(30 * self.RATE):]
380
+ # check timestamp offset(should be >= self.frame_offset)
381
+ # this basically means that there is no speech as timestamp offset hasnt updated
382
+ # and is less than frame_offset
383
+ if self.timestamp_offset < self.frames_offset:
384
+ self.timestamp_offset = self.frames_offset
385
+ if self.frames_np is None:
386
+ self.frames_np = frame_np.copy()
387
+ else:
388
+ self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)
389
+ self.lock.release()
390
+
391
+ def clip_audio_if_no_valid_segment(self):
392
+ """
393
+ Update the timestamp offset based on audio buffer status.
394
+ Clip audio if the current chunk exceeds 30 seconds, this basically implies that
395
+ no valid segment for the last 30 seconds from whisper
396
+ """
397
+ with self.lock:
398
+ if self.frames_np[int((self.timestamp_offset - self.frames_offset) * self.RATE):].shape[0] > 25 * self.RATE:
399
+ duration = self.frames_np.shape[0] / self.RATE
400
+ self.timestamp_offset = self.frames_offset + duration - 5
401
+
402
+ def get_audio_chunk_for_processing(self):
403
+ """
404
+ Retrieves the next chunk of audio data for processing based on the current offsets.
405
+
406
+ Calculates which part of the audio data should be processed next, based on
407
+ the difference between the current timestamp offset and the frame's offset, scaled by
408
+ the audio sample rate (RATE). It then returns this chunk of audio data along with its
409
+ duration in seconds.
410
+
411
+ Returns:
412
+ tuple: A tuple containing:
413
+ - input_bytes (np.ndarray): The next chunk of audio data to be processed.
414
+ - duration (float): The duration of the audio chunk in seconds.
415
+ """
416
+ with self.lock:
417
+ samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE)
418
+ input_bytes = self.frames_np[int(samples_take):].copy()
419
+ duration = input_bytes.shape[0] / self.RATE
420
+ return input_bytes, duration
421
+
422
+ def prepare_segments(self, last_segment=None):
423
+ """
424
+ Prepares the segments of transcribed text to be sent to the client.
425
+
426
+ This method compiles the recent segments of transcribed text, ensuring that only the
427
+ specified number of the most recent segments are included. It also appends the most
428
+ recent segment of text if provided (which is considered incomplete because of the possibility
429
+ of the last word being truncated in the audio chunk).
430
+
431
+ Args:
432
+ last_segment (str, optional): The most recent segment of transcribed text to be added
433
+ to the list of segments. Defaults to None.
434
+
435
+ Returns:
436
+ list: A list of transcribed text segments to be sent to the client.
437
+ """
438
+ segments = []
439
+ if len(self.transcript) >= self.send_last_n_segments:
440
+ segments = self.transcript[-self.send_last_n_segments:].copy()
441
+ else:
442
+ segments = self.transcript.copy()
443
+ if last_segment is not None:
444
+ segments = segments + [last_segment]
445
+ return segments
446
+
447
+ def get_audio_chunk_duration(self, input_bytes):
448
+ """
449
+ Calculates the duration of the provided audio chunk.
450
+
451
+ Args:
452
+ input_bytes (numpy.ndarray): The audio chunk for which to calculate the duration.
453
+
454
+ Returns:
455
+ float: The duration of the audio chunk in seconds.
456
+ """
457
+ return input_bytes.shape[0] / self.RATE
458
+
459
+ def send_transcription_to_client(self, segments):
460
+ """
461
+ Sends the specified transcription segments to the client over the websocket connection.
462
+
463
+ This method formats the transcription segments into a JSON object and attempts to send
464
+ this object to the client. If an error occurs during the send operation, it logs the error.
465
+
466
+ Returns:
467
+ segments (list): A list of transcription segments to be sent to the client.
468
+ """
469
+ try:
470
+ self.websocket.send(
471
+ json.dumps({
472
+ "uid": self.client_uid,
473
+ "segments": segments,
474
+ })
475
+ )
476
+ except Exception as e:
477
+ logging.error(f"[ERROR]: Sending data to client: {e}")
478
+
479
+ def disconnect(self):
480
+ """
481
+ Notify the client of disconnection and send a disconnect message.
482
+
483
+ This method sends a disconnect message to the client via the WebSocket connection to notify them
484
+ that the transcription service is disconnecting gracefully.
485
+
486
+ """
487
+ self.websocket.send(json.dumps({
488
+ "uid": self.client_uid,
489
+ "message": self.DISCONNECT
490
+ }))
491
+
492
+ def cleanup(self):
493
+ """
494
+ Perform cleanup tasks before exiting the transcription service.
495
+
496
+ This method performs necessary cleanup tasks, including stopping the transcription thread, marking
497
+ the exit flag to indicate the transcription thread should exit gracefully, and destroying resources
498
+ associated with the transcription process.
499
+
500
+ """
501
+ logging.info("Cleaning up.")
502
+ self.exit = True
503
+
504
+
505
+ class ServeClientWhisperCPP(ServeClientBase):
506
+ SINGLE_MODEL = None
507
+ SINGLE_MODEL_LOCK = threading.Lock()
508
+
509
+ def __init__(self, websocket, language=None, client_uid=None,
510
+ single_model=False):
511
+ """
512
+ Initialize a ServeClient instance.
513
+ The Whisper model is initialized based on the client's language and device availability.
514
+ The transcription thread is started upon initialization. A "SERVER_READY" message is sent
515
+ to the client to indicate that the server is ready.
516
+
517
+ Args:
518
+ websocket (WebSocket): The WebSocket connection for the client.
519
+ language (str, optional): The language for transcription. Defaults to None.
520
+ client_uid (str, optional): A unique identifier for the client. Defaults to None.
521
+ single_model (bool, optional): Whether to instantiate a new model for each client connection. Defaults to False.
522
+
523
+ """
524
+ super().__init__(client_uid, websocket)
525
+ self.language = language
526
+ self.eos = False
527
+
528
+ if single_model:
529
+ if ServeClientWhisperCPP.SINGLE_MODEL is None:
530
+ self.create_model()
531
+ ServeClientWhisperCPP.SINGLE_MODEL = self.transcriber
532
+ else:
533
+ self.transcriber = ServeClientWhisperCPP.SINGLE_MODEL
534
+ else:
535
+ self.create_model()
536
+
537
+ # threading
538
+ logging.info('Create a thread to process audio.')
539
+ self.trans_thread = threading.Thread(target=self.speech_to_text)
540
+ self.trans_thread.start()
541
+
542
+ self.websocket.send(json.dumps({
543
+ "uid": self.client_uid,
544
+ "message": self.SERVER_READY,
545
+ "backend": "pywhispercpp"
546
+ }))
547
+
548
+ def create_model(self, warmup=True):
549
+ """
550
+ Instantiates a new model, sets it as the transcriber and does warmup if desired.
551
+ """
552
+ model = 'medium-q5_0'
553
+ here = pathlib.Path(__file__)
554
+ models_dir = f'{here.parent.parent / "moyoyo_asr_models"}'
555
+ self.transcriber = Model(model=model, models_dir=models_dir)
556
+ if warmup:
557
+ self.warmup()
558
+
559
+ def warmup(self, warmup_steps=1):
560
+ """
561
+ Warmup TensorRT since first few inferences are slow.
562
+
563
+ Args:
564
+ warmup_steps (int): Number of steps to warm up the model for.
565
+ """
566
+         logging.info("[INFO:] Warming up whisper.cpp engine...")
+         audio, _ = soundfile.read("assets/jfk.flac")
+         for _ in range(warmup_steps):
+             self.transcriber.transcribe(audio, print_progress=False)
+
+     def set_eos(self, eos):
+         """
+         Set the End of Speech (EOS) flag.
+
+         Args:
+             eos (bool): The value to set for the EOS flag.
+         """
+         with self.lock:
+             self.eos = eos
+
+     def handle_transcription_output(self, last_segment, duration):
+         """
+         Handle the transcription output, updating the transcript and sending data to the client.
+
+         Args:
+             last_segment (str): The last segment from the whisper output, considered incomplete
+                 because the final word may be truncated.
+             duration (float): Duration of the transcribed audio chunk.
+         """
+         segments = self.prepare_segments({"text": last_segment})
+         self.send_transcription_to_client(segments)
+         if self.eos:
+             self.update_timestamp_offset(last_segment, duration)
+
+     def transcribe_audio(self, input_bytes):
+         """
+         Transcribe the audio chunk and send the results to the client.
+
+         Args:
+             input_bytes (np.array): The audio chunk to transcribe.
+         """
+         if ServeClientWhisperCPP.SINGLE_MODEL:
+             ServeClientWhisperCPP.SINGLE_MODEL_LOCK.acquire()
+         logging.info(f"[pywhispercpp:] Processing audio with duration: {input_bytes.shape[0] / self.RATE}")
+         audio = input_bytes
+         duration = librosa.get_duration(y=input_bytes, sr=self.RATE)
+
+         if self.language == "zh":
+             # "The following is a sentence in Simplified Chinese Mandarin."
+             prompt = '以下是简体中文普通话的句子。'
+         else:
+             prompt = 'The following is an English sentence.'
+
+         # Decode in the client's language, with a matching initial prompt.
+         segments = self.transcriber.transcribe(audio, language=self.language,
+                                                initial_prompt=prompt, print_progress=False)
+         text = []
+         for segment in segments:
+             text.append(segment.text)
+         last_segment = ' '.join(text)
+
+         logging.info(f"[pywhispercpp:] Last segment: {last_segment}")
+
+         if ServeClientWhisperCPP.SINGLE_MODEL:
+             ServeClientWhisperCPP.SINGLE_MODEL_LOCK.release()
+         if last_segment:
+             self.handle_transcription_output(last_segment, duration)
+
+     def update_timestamp_offset(self, last_segment, duration):
+         """
+         Update the timestamp offset and the transcript.
+
+         Args:
+             last_segment (str): Last transcribed audio from the whisper model.
+             duration (float): Duration of the last audio chunk.
+         """
+         if not len(self.transcript):
+             self.transcript.append({"text": last_segment + " "})
+         elif self.transcript[-1]["text"].strip() != last_segment:
+             self.transcript.append({"text": last_segment + " "})
+
+         logging.info(f'Transcript list context: {self.transcript}')
+
+         with self.lock:
+             self.timestamp_offset += duration
+
+     def speech_to_text(self):
+         """
+         Process an audio stream in an infinite loop, continuously transcribing the speech.
+
+         This method continuously receives audio frames, performs real-time transcription, and sends
+         transcribed segments to the client via a WebSocket connection.
+
+         If the client's language is not detected, it waits for 30 seconds of audio input to make a
+         language prediction. It uses the Whisper ASR model to transcribe the audio, continuously
+         processing and streaming results. Segments are sent to the client in real time, and a history
+         of segments is maintained to provide context. Pauses in speech (no output from Whisper) are
+         handled by showing the previous output for a set duration. A blank segment is added if there
+         is no speech for a specified duration, to indicate a pause.
+
+         Raises:
+             Exception: If there is an issue with audio processing or WebSocket communication.
+         """
+         while True:
+             if self.exit:
+                 logging.info("Exiting speech to text thread")
+                 break
+
+             if self.frames_np is None:
+                 time.sleep(0.02)  # wait for any audio to arrive
+                 continue
+
+             self.clip_audio_if_no_valid_segment()
+
+             input_bytes, duration = self.get_audio_chunk_for_processing()
+             if duration < 1:
+                 continue
+
+             try:
+                 input_sample = input_bytes.copy()
+                 logging.info(f"[pywhispercpp:] Processing audio with duration: {duration}")
+                 self.transcribe_audio(input_sample)
+
+             except Exception as e:
+                 logging.error(f"[ERROR]: {e}")
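A minimal usage sketch for the loop above (hedged: `client` stands for an already-constructed ServeClientWhisperCPP instance, whose constructor and WebSocket wiring are not shown in this diff). The loop polls `self.exit`, so a caller stops it by setting that flag:

    import threading

    worker = threading.Thread(target=client.speech_to_text, daemon=True)
    worker.start()
    # ... the WebSocket handler appends incoming audio to client.frames_np ...
    client.exit = True  # ask the loop to finish
    worker.join()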
transcribe/utils.py ADDED
@@ -0,0 +1,81 @@
+ import os
+ import textwrap
+ from pathlib import Path
+
+ import av
+
+
+ def clear_screen():
+     """Clears the console screen."""
+     os.system("cls" if os.name == "nt" else "clear")
+
+
+ def print_transcript(text):
+     """Prints formatted transcript text."""
+     wrapper = textwrap.TextWrapper(width=60)
+     for line in wrapper.wrap(text="".join(text)):
+         print(line)
+
+
+ def format_time(s):
+     """Convert seconds (float) to SRT time format, e.g. 3661.25 -> "01:01:01,250"."""
+     hours = int(s // 3600)
+     minutes = int((s % 3600) // 60)
+     seconds = int(s % 60)
+     milliseconds = int((s - int(s)) * 1000)
+     return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}"
+
+
+ def create_srt_file(segments, output_path):
+     """Write segments (dicts with "start", "end" and "text" keys) to an SRT file at output_path."""
+     with open(output_path, 'w', encoding='utf-8') as srt_file:
+         segment_number = 1
+         for segment in segments:
+             start_time = format_time(float(segment['start']))
+             end_time = format_time(float(segment['end']))
+             text = segment['text']
+
+             srt_file.write(f"{segment_number}\n")
+             srt_file.write(f"{start_time} --> {end_time}\n")
+             srt_file.write(f"{text}\n\n")
+
+             segment_number += 1
+
+
+ def resample(file: str, sr: int = 16000):
+     """
+     Resample the audio file to 16 kHz mono PCM.
+
+     Args:
+         file (str): The audio file to open.
+         sr (int): The sample rate to resample the audio to, if necessary.
+
+     Returns:
+         resampled_file (str): Path of the resampled audio file.
+     """
+     container = av.open(file)
+     stream = next(s for s in container.streams if s.type == 'audio')
+
+     resampler = av.AudioResampler(
+         format='s16',
+         layout='mono',
+         rate=sr,
+     )
+
+     resampled_file = Path(file).stem + "_resampled.wav"
+     output_container = av.open(resampled_file, mode='w')
+     output_stream = output_container.add_stream('pcm_s16le', rate=sr)
+     output_stream.layout = 'mono'
+
+     for frame in container.decode(audio=0):
+         frame.pts = None
+         resampled_frames = resampler.resample(frame)
+         if resampled_frames is not None:
+             for resampled_frame in resampled_frames:
+                 for packet in output_stream.encode(resampled_frame):
+                     output_container.mux(packet)
+
+     # Flush any packets still buffered in the encoder.
+     for packet in output_stream.encode(None):
+         output_container.mux(packet)
+
+     output_container.close()
+     return resampled_file
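A quick sketch of how these helpers compose (hedged: "talk.mp4" and the segment values are made up; create_srt_file only assumes dicts with "start", "end" and "text" keys):

    wav_path = resample("talk.mp4")  # writes "talk_resampled.wav" (16 kHz mono PCM)
    segments = [
        {"start": 0.0, "end": 2.5, "text": "Hello there."},
        {"start": 2.5, "end": 5.0, "text": "Welcome to the demo."},
    ]
    create_srt_file(segments, "talk.srt")
    print(format_time(3661.25))  # -> "01:01:01,250"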
transcribe/vad.py ADDED
@@ -0,0 +1,160 @@
+ import os
+ import subprocess
+ import warnings
+
+ import numpy as np
+ import onnxruntime
+ import torch
+
+
+ class VoiceActivityDetection:
+
+     def __init__(self, force_onnx_cpu=True):
+         path = self.download()
+
+         opts = onnxruntime.SessionOptions()
+         opts.log_severity_level = 3
+
+         opts.inter_op_num_threads = 1
+         opts.intra_op_num_threads = 1
+
+         if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers():
+             self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts)
+         else:
+             self.session = onnxruntime.InferenceSession(path, providers=['CUDAExecutionProvider'], sess_options=opts)
+
+         self.reset_states()
+         if '16k' in path:
+             warnings.warn('This model supports only a 16000 Hz sampling rate!')
+             self.sample_rates = [16000]
+         else:
+             self.sample_rates = [8000, 16000]
+
+     def _validate_input(self, x, sr: int):
+         if x.dim() == 1:
+             x = x.unsqueeze(0)
+         if x.dim() > 2:
+             raise ValueError(f"Too many dimensions for input audio chunk: {x.dim()}")
+
+         # Downsample any multiple of 16 kHz by simple decimation.
+         if sr != 16000 and (sr % 16000 == 0):
+             step = sr // 16000
+             x = x[:, ::step]
+             sr = 16000
+
+         if sr not in self.sample_rates:
+             raise ValueError(f"Supported sampling rates: {self.sample_rates} (or a multiple of 16000)")
+         if sr / x.shape[1] > 31.25:
+             raise ValueError("Input audio chunk is too short")
+
+         return x, sr
+
+     def reset_states(self, batch_size=1):
+         self._state = torch.zeros((2, batch_size, 128)).float()
+         self._context = torch.zeros(0)
+         self._last_sr = 0
+         self._last_batch_size = 0
+
+     def __call__(self, x, sr: int):
+         x, sr = self._validate_input(x, sr)
+         num_samples = 512 if sr == 16000 else 256
+
+         if x.shape[-1] != num_samples:
+             raise ValueError(
+                 f"Provided number of samples is {x.shape[-1]} (supported values: 256 for an 8000 Hz sample rate, 512 for 16000 Hz)")
+
+         batch_size = x.shape[0]
+         context_size = 64 if sr == 16000 else 32
+
+         # Reset the recurrent state whenever the batch size or sample rate changes.
+         if not self._last_batch_size:
+             self.reset_states(batch_size)
+         if (self._last_sr) and (self._last_sr != sr):
+             self.reset_states(batch_size)
+         if (self._last_batch_size) and (self._last_batch_size != batch_size):
+             self.reset_states(batch_size)
+
+         if not len(self._context):
+             self._context = torch.zeros(batch_size, context_size)
+
+         # Prepend the trailing context from the previous chunk.
+         x = torch.cat([self._context, x], dim=1)
+         if sr in [8000, 16000]:
+             ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
+             ort_outs = self.session.run(None, ort_inputs)
+             out, state = ort_outs
+             self._state = torch.from_numpy(state)
+         else:
+             raise ValueError(f"Unsupported sampling rate: {sr}")
+
+         self._context = x[..., -context_size:]
+         self._last_sr = sr
+         self._last_batch_size = batch_size
+
+         out = torch.from_numpy(out)
+         return out
+
+     def audio_forward(self, x, sr: int):
+         outs = []
+         x, sr = self._validate_input(x, sr)
+         self.reset_states()
+         num_samples = 512 if sr == 16000 else 256
+
+         # Zero-pad so the signal splits evenly into fixed-size windows.
+         if x.shape[1] % num_samples:
+             pad_num = num_samples - (x.shape[1] % num_samples)
+             x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)
+
+         for i in range(0, x.shape[1], num_samples):
+             wavs_batch = x[:, i:i + num_samples]
+             out_chunk = self.__call__(wavs_batch, sr)
+             outs.append(out_chunk)
+
+         stacked = torch.cat(outs, dim=1)
+         return stacked.cpu()
+
+     @staticmethod
+     def download(model_url="https://github.com/snakers4/silero-vad/raw/v5.0/files/silero_vad.onnx"):
+         target_dir = os.path.expanduser("~/.cache/silero-vad/")
+
+         # Ensure the target directory exists
+         os.makedirs(target_dir, exist_ok=True)
+
+         # Define the target file path
+         model_filename = os.path.join(target_dir, "silero_vad.onnx")
+
+         # Download the model with curl if it is not cached yet
+         if not os.path.exists(model_filename):
+             try:
+                 subprocess.run(["curl", "-sL", "-o", model_filename, model_url], check=True)
+             except subprocess.CalledProcessError:
+                 print("Failed to download the model using curl.")
+         return model_filename
+
+
+ class VoiceActivityDetector:
+     def __init__(self, threshold=0.5, frame_rate=16000):
+         """
+         Initializes the VoiceActivityDetector with a voice activity detection model and a threshold.
+
+         Args:
+             threshold (float, optional): The probability threshold for detecting voice activity. Defaults to 0.5.
+             frame_rate (int, optional): The sample rate of the incoming audio frames. Defaults to 16000.
+         """
+         self.model = VoiceActivityDetection()
+         self.threshold = threshold
+         self.frame_rate = frame_rate
+
+     def __call__(self, audio_frame):
+         """
+         Determines if the given audio frame contains speech by comparing the detected speech probability
+         against the threshold.
+
+         Args:
+             audio_frame (np.ndarray): The audio frame to be analyzed for voice activity. It is expected
+                 to be a NumPy array of audio samples.
+
+         Returns:
+             bool: True if the speech probability exceeds the threshold, indicating the presence of voice
+                 activity; False otherwise.
+         """
+         speech_probs = self.model.audio_forward(torch.from_numpy(audio_frame.copy()), self.frame_rate)[0]
+         return torch.any(speech_probs > self.threshold).item()
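A minimal sketch of the high-level detector (hedged: the frame is synthetic silence, so the call should return False; on first use, download() fetches the Silero VAD ONNX model into ~/.cache/silero-vad/):

    import numpy as np

    vad = VoiceActivityDetector(threshold=0.5, frame_rate=16000)
    frame = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
    print(vad(frame))  # expected: False (no speech detected)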