|
|
| import json |
| import logging |
| import threading |
| import time |
| import config |
| import librosa |
| import numpy as np |
| import soundfile |
| from pywhispercpp.model import Model |
|
|
# Configure the root logger once at import time so INFO-level messages
# from this module's serve clients are emitted.
logging.basicConfig(level=logging.INFO)
|
|
class ServeClientBase(object):
    """
    Base class for a per-client transcription session.

    Holds the rolling audio buffer shared between the websocket receiver
    thread (producer, via ``add_frames``) and the transcription thread
    (consumer), plus the accumulated transcript streamed back to the
    client. Subclasses implement the model-specific methods
    ``speech_to_text``, ``transcribe_audio`` and
    ``handle_transcription_output``.
    """

    RATE = 16000  # expected audio sample rate in Hz
    SERVER_READY = "SERVER_READY"  # handshake message sent once the backend is ready
    DISCONNECT = "DISCONNECT"  # message sent to the client on graceful disconnect

    def __init__(self, client_uid, websocket):
        """
        Initialize per-session state.

        Args:
            client_uid (str): Unique identifier of the connected client.
            websocket: Open websocket connection used to send results.
        """
        self.client_uid = client_uid
        self.websocket = websocket
        self.frames = b""
        self.timestamp_offset = 0.0  # start (seconds) of the not-yet-finalized audio
        self.frames_np = None  # rolling 1-D float audio buffer (np.ndarray) or None
        self.frames_offset = 0.0  # seconds of audio already discarded from buffer start
        self.text = []
        self.current_out = ''
        self.prev_out = ''
        self.t_start = None
        self.exit = False  # set by cleanup() to stop the transcription thread
        self.same_output_count = 0
        self.show_prev_out_thresh = 5  # seconds to keep showing previous output during silence
        self.add_pause_thresh = 3  # seconds of silence before adding a blank segment
        self.transcript = []  # finalized segments, each a dict like {"text": ...}
        self.send_last_n_segments = 10  # max finalized segments sent per update

        # number of previous segments used as context for the next inference
        self.pick_previous_segments = 2

        # Guards frames_np / frames_offset / timestamp_offset, which are
        # accessed from both the receiver and transcription threads.
        self.lock = threading.Lock()

    def speech_to_text(self):
        """Run the transcription loop. Must be implemented by subclasses."""
        raise NotImplementedError

    def transcribe_audio(self):
        """Transcribe one audio chunk. Must be implemented by subclasses."""
        raise NotImplementedError

    def handle_transcription_output(self):
        """Post-process model output. Must be implemented by subclasses."""
        raise NotImplementedError

    def add_frames(self, frame_np):
        """
        Add audio frames to the ongoing audio stream buffer.

        Maintains the audio stream buffer, allowing the continuous addition of
        audio frames as they are received, while ensuring the buffer does not
        exceed a bounded size. If the buffer exceeds 45 seconds of audio, the
        oldest 30 seconds are discarded. If the buffer is empty, it is
        initialized with the provided frame.

        Args:
            frame_np (numpy.ndarray): The audio frame data as a NumPy array.
        """
        # `with` guarantees the lock is released even if concatenation raises;
        # the previous bare acquire()/release() pair could leak the lock.
        with self.lock:
            if self.frames_np is not None and self.frames_np.shape[0] > 45 * self.RATE:
                self.frames_offset += 30.0
                self.frames_np = self.frames_np[int(30 * self.RATE):]
                # Never point the processing cursor at audio that has been
                # discarded from the front of the buffer.
                if self.timestamp_offset < self.frames_offset:
                    self.timestamp_offset = self.frames_offset
            if self.frames_np is None:
                self.frames_np = frame_np.copy()
            else:
                self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0)

    def clip_audio_if_no_valid_segment(self):
        """
        Update the timestamp offset based on audio buffer status.

        Clip audio if the current pending chunk exceeds 25 seconds, which
        implies whisper produced no valid segment for a long stretch; jump
        the cursor forward so only the most recent 5 seconds remain pending.
        """
        with self.lock:
            pending = self.frames_np[int((self.timestamp_offset - self.frames_offset) * self.RATE):]
            if pending.shape[0] > 25 * self.RATE:
                duration = self.frames_np.shape[0] / self.RATE
                self.timestamp_offset = self.frames_offset + duration - 5

    def get_audio_chunk_for_processing(self):
        """
        Retrieve the next chunk of audio data for processing.

        Computes which part of the buffer is still unprocessed from the
        difference between the timestamp offset and the frame offset, scaled
        by the sample rate (RATE).

        Returns:
            tuple: A tuple containing:
                - input_bytes (np.ndarray): The next chunk of audio data to be processed.
                - duration (float): The duration of the audio chunk in seconds.
        """
        with self.lock:
            samples_take = max(0, (self.timestamp_offset - self.frames_offset) * self.RATE)
            input_bytes = self.frames_np[int(samples_take):].copy()
        duration = input_bytes.shape[0] / self.RATE
        return input_bytes, duration

    def prepare_segments(self, last_segment=None):
        """
        Prepare the transcribed-text segments to be sent to the client.

        Compiles at most ``send_last_n_segments`` of the most recent
        finalized segments, then appends the provided in-progress segment
        (considered incomplete because the last word may be truncated in
        the audio chunk).

        Args:
            last_segment (dict, optional): The most recent (incomplete) segment
                to append. Defaults to None.

        Returns:
            list: Transcribed text segments to be sent to the client.
        """
        # Slicing already returns a fresh list and handles transcripts
        # shorter than send_last_n_segments, so one branch suffices.
        segments = self.transcript[-self.send_last_n_segments:]
        if last_segment is not None:
            segments = segments + [last_segment]
        logging.info("%s", segments)  # lazy formatting: built only if INFO is enabled
        return segments

    def get_audio_chunk_duration(self, input_bytes):
        """
        Calculate the duration of the provided audio chunk.

        Args:
            input_bytes (numpy.ndarray): The audio chunk for which to calculate the duration.

        Returns:
            float: The duration of the audio chunk in seconds.
        """
        return input_bytes.shape[0] / self.RATE

    def send_transcription_to_client(self, segments):
        """
        Send the given transcription segments to the client over the websocket.

        Formats the segments into a JSON object and attempts to send it.
        Errors are logged rather than raised so a single failed send does
        not kill the transcription loop.

        Args:
            segments (list): Transcription segments to send to the client.
        """
        try:
            self.websocket.send(
                json.dumps({
                    "uid": self.client_uid,
                    "segments": segments,
                })
            )
        except Exception as e:
            logging.error("[ERROR]: Sending data to client: %s", e)

    def disconnect(self):
        """
        Notify the client of disconnection.

        Sends a DISCONNECT message over the websocket so the client knows
        the transcription service is shutting down gracefully.
        """
        self.websocket.send(json.dumps({
            "uid": self.client_uid,
            "message": self.DISCONNECT
        }))

    def cleanup(self):
        """
        Perform cleanup tasks before exiting the transcription service.

        Sets the exit flag so the transcription thread terminates its loop
        gracefully.
        """
        logging.info("Cleaning up.")
        self.exit = True
|
|
|
|
class ServeClientWhisperCPP(ServeClientBase):
    """
    Transcription session backed by a whisper.cpp model via pywhispercpp.

    Can optionally share one process-wide model instance between all
    clients (``single_model=True``); access to the shared model is
    serialized with ``SINGLE_MODEL_LOCK``.
    """

    SINGLE_MODEL = None  # lazily created shared Model when single_model=True
    SINGLE_MODEL_LOCK = threading.Lock()  # serializes access to the shared model

    def __init__(self, websocket, language=None, client_uid=None,
                 single_model=False):
        """
        Initialize a ServeClient instance.

        The whisper.cpp model is initialized (or reused in single-model
        mode), the transcription thread is started, and a "SERVER_READY"
        message is sent to the client to indicate the server is ready.

        Args:
            websocket (WebSocket): The WebSocket connection for the client.
            language (str, optional): The language for transcription. Defaults to None.
            client_uid (str, optional): A unique identifier for the client. Defaults to None.
            single_model (bool, optional): Share a single model across all client
                connections instead of creating one per client. Defaults to False.
        """
        super().__init__(client_uid, websocket)
        self.language = language
        self.eos = False  # end-of-speech flag; toggled via set_eos()

        if single_model:
            # First client creates the shared model; later clients reuse it.
            if ServeClientWhisperCPP.SINGLE_MODEL is None:
                self.create_model()
                ServeClientWhisperCPP.SINGLE_MODEL = self.transcriber
            else:
                self.transcriber = ServeClientWhisperCPP.SINGLE_MODEL
        else:
            self.create_model()

        logging.info('Create a thread to process audio.')
        self.trans_thread = threading.Thread(target=self.speech_to_text)
        self.trans_thread.start()

        self.websocket.send(json.dumps({
            "uid": self.client_uid,
            "message": self.SERVER_READY,
            "backend": "pywhispercpp"
        }))

    def create_model(self, warmup=True):
        """
        Instantiate a new model, set it as the transcriber and warm it up if desired.

        Args:
            warmup (bool): Run a warmup transcription after loading. Defaults to True.
        """
        self.transcriber = Model(model=config.WHISPER_MODEL, models_dir=config.MODEL_DIR)
        if warmup:
            self.warmup()

    def warmup(self, warmup_steps=1):
        """
        Warm up the whisper.cpp engine, since the first few inferences are slow.

        Args:
            warmup_steps (int): Number of warmup transcriptions to run.
        """
        logging.info("[INFO:] Warming up whisper.cpp engine..")
        audio, _ = soundfile.read("assets/jfk.flac")
        for _ in range(warmup_steps):
            self.transcriber.transcribe(audio, print_progress=False)

    def set_eos(self, eos):
        """
        Set the End of Speech (EOS) flag.

        Args:
            eos (bool): The value to set for the EOS flag.
        """
        # `with` releases the lock even if an exception occurs.
        with self.lock:
            self.eos = eos

    def handle_transcription_output(self, last_segment, duration):
        """
        Handle the transcription output, updating the transcript and sending data to the client.

        Args:
            last_segment (str): The last segment from the whisper output, considered
                incomplete because the final word may be truncated.
            duration (float): Duration of the transcribed audio chunk.
        """
        segments = self.prepare_segments({"text": last_segment})
        self.send_transcription_to_client(segments)
        # Only finalize (advance the cursor) once end-of-speech is signalled.
        if self.eos:
            self.update_timestamp_offset(last_segment, duration)

    def transcribe_audio(self, input_bytes):
        """
        Transcribe the audio chunk and send the results to the client.

        Args:
            input_bytes (np.ndarray): The audio chunk to transcribe.
        """
        if ServeClientWhisperCPP.SINGLE_MODEL:
            ServeClientWhisperCPP.SINGLE_MODEL_LOCK.acquire()
        try:
            logging.info(f"[pywhispercpp:] Processing audio with duration: {input_bytes.shape[0] / self.RATE}")
            duration = librosa.get_duration(y=input_bytes, sr=self.RATE)

            if self.language == "zh":
                # Simplified-Chinese prompt biases decoding toward simplified output.
                prompt = '以下是简体中文普通话的句子。'
            else:
                prompt = 'The following is an English sentence.'

            segments = self.transcriber.transcribe(
                input_bytes,
                language=self.language,
                initial_prompt=prompt,
                token_timestamps=True,
                print_progress=False
            )
            last_segment = ' '.join(segment.text for segment in segments)
            logging.info(f"[pywhispercpp:] Last segment: {last_segment}")
        finally:
            # Release even when transcribe() raises; otherwise every other
            # client sharing the model deadlocks on SINGLE_MODEL_LOCK.
            if ServeClientWhisperCPP.SINGLE_MODEL:
                ServeClientWhisperCPP.SINGLE_MODEL_LOCK.release()
        if last_segment:
            self.handle_transcription_output(last_segment, duration)

    def update_timestamp_offset(self, last_segment, duration):
        """
        Update timestamp offset and transcript.

        Args:
            last_segment (str): Last transcribed audio from the whisper model.
            duration (float): Duration of the last audio chunk.
        """
        # Append unless it duplicates the previous finalized segment.
        if not self.transcript or self.transcript[-1]["text"].strip() != last_segment:
            self.transcript.append({"text": last_segment + " "})

        logging.info(f'Transcript list context: {self.transcript}')

        with self.lock:
            self.timestamp_offset += duration

    def speech_to_text(self):
        """
        Process the audio stream in an infinite loop, continuously transcribing speech.

        Continuously pulls pending audio from the shared buffer, runs
        transcription, and streams segments to the client over the
        websocket. The loop exits when the ``exit`` flag is set by
        ``cleanup()``. Errors during a single chunk are logged and the
        loop keeps running.

        Raises:
            Exception: If there is an issue with audio processing or WebSocket communication.
        """
        while True:
            if self.exit:
                logging.info("Exiting speech to text thread")
                break

            if self.frames_np is None:
                time.sleep(0.02)
                continue

            self.clip_audio_if_no_valid_segment()

            input_bytes, duration = self.get_audio_chunk_for_processing()
            if duration < 1:
                # Not enough new audio yet. Sleep briefly before retrying;
                # the original bare `continue` busy-spun at 100% CPU.
                time.sleep(0.02)
                continue

            try:
                input_sample = input_bytes.copy()
                logging.info(f"[pywhispercpp:] Processing audio with duration: {duration}")
                self.transcribe_audio(input_sample)
            except Exception as e:
                # logging.exception keeps the original message but also
                # records the traceback instead of discarding it.
                logging.exception(f"[ERROR]: {e}")
|
|