| import logging |
| import time |
| import functools |
| import json |
| import logging |
| import time |
| from enum import Enum |
| from typing import List, Optional |
| import numpy as np |
| from .server import ServeClientBase |
| from .whisper_llm_serve import PyWhiperCppServe |
| from .vad import VoiceActivityDetector |
| from urllib.parse import urlparse, parse_qsl |
| from websockets.exceptions import ConnectionClosed |
| from websockets.sync.server import serve |
| from uuid import uuid1 |
|
|
|
|
# Configure the root logger at import time so the server's INFO-level messages
# are emitted (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level="INFO")
|
|
|
|
class ClientManager:
    """Tracks connected clients and enforces capacity and connection-time limits."""

    def __init__(self, max_clients=4, max_connection_time=600):
        """
        Initializes the ClientManager with specified limits on client connections and connection durations.

        Args:
            max_clients (int, optional): The maximum number of simultaneous client connections allowed. Defaults to 4.
            max_connection_time (int, optional): The maximum duration (in seconds) a client can stay connected.
                Defaults to 600 seconds (10 minutes).
        """
        # websocket -> client object
        self.clients = {}
        # websocket -> connection start timestamp (seconds, from time.time())
        self.start_times = {}
        self.max_clients = max_clients
        self.max_connection_time = max_connection_time

    def add_client(self, websocket, client):
        """
        Adds a client and their connection start time to the tracking dictionaries.

        Args:
            websocket: The websocket associated with the client to add.
            client: The client object to be added and tracked.
        """
        self.clients[websocket] = client
        self.start_times[websocket] = time.time()

    def get_client(self, websocket):
        """
        Retrieves a client associated with the given websocket.

        Args:
            websocket: The websocket associated with the client to retrieve.

        Returns:
            The client object if found, False otherwise.
        """
        return self.clients.get(websocket, False)

    def remove_client(self, websocket):
        """
        Removes a client and their connection start time from the tracking dictionaries, then calls the
        client's ``cleanup()``.

        Removing an unknown websocket is a no-op.

        Args:
            websocket: The websocket associated with the client to be removed.
        """
        # Drop the start time first so a failing client.cleanup() cannot leave a
        # phantom entry that would keep skewing get_wait_time().
        self.start_times.pop(websocket, None)
        client = self.clients.pop(websocket, None)
        if client is not None:
            client.cleanup()

    def get_wait_time(self):
        """
        Estimates how long a new client must wait for a free slot.

        The estimate is the smallest remaining connection time among current clients.

        Returns:
            The estimated wait time in minutes (never negative). Returns 0 when no clients are connected.
        """
        # Snapshot the values: connections may be handled concurrently, and
        # another handler adding/removing a client must not change the dict
        # while we iterate it.
        start_times = list(self.start_times.values())
        if not start_times:
            return 0
        now = time.time()
        soonest = min(self.max_connection_time - (now - start) for start in start_times)
        # A client past its limit may not have been disconnected yet (timeouts are
        # only checked when the caller polls is_client_timeout), so clamp at 0.
        return max(soonest, 0) / 60

    def is_server_full(self, websocket, options):
        """
        Checks if the server is at its maximum client capacity and sends a wait message to the client if so.

        Args:
            websocket: The websocket of the client attempting to connect.
            options: A dictionary of options that may include the client's unique identifier under "uid".

        Returns:
            True if the server is full (a WAIT message was sent), False otherwise.
        """
        if len(self.clients) < self.max_clients:
            return False
        response = {
            # "uid" is optional; a missing key must not abort the capacity check.
            "uid": options.get("uid"),
            "status": "WAIT",
            "message": self.get_wait_time(),
        }
        websocket.send(json.dumps(response))
        return True

    def is_client_timeout(self, websocket):
        """
        Checks if a client has exceeded the maximum allowed connection time; if so, calls the client's
        ``disconnect()`` and logs a warning.

        Args:
            websocket: The websocket associated with the client to check. Must be currently tracked
                (an unknown websocket raises KeyError).

        Returns:
            True if the client's connection time has exceeded the maximum limit, False otherwise.
        """
        elapsed_time = time.time() - self.start_times[websocket]
        if elapsed_time < self.max_connection_time:
            return False
        client = self.clients[websocket]
        client.disconnect()
        logging.warning("Client with uid '%s' disconnected due to overtime.", client.client_uid)
        return True
|
|
|
|
class BackendType(Enum):
    """Transcription backends this server knows how to run."""

    PYWHISPERCPP = "pywhispercpp"

    @staticmethod
    def valid_types() -> List[str]:
        """Return the string value of every backend member, in declaration order."""
        values = []
        for member in BackendType:
            values.append(member.value)
        return values

    @staticmethod
    def is_valid(backend: str) -> bool:
        """Return True if ``backend`` equals the string value of some member."""
        allowed = BackendType.valid_types()
        return backend in allowed

    def is_pywhispercpp(self) -> bool:
        """Return True if this member is the pywhispercpp backend."""
        # Enum members are singletons, so identity is the same test as equality.
        return self is BackendType.PYWHISPERCPP
|
|
|
|
class TranscriptionServer:
    """
    WebSocket server that receives raw audio from clients and forwards it to a
    per-connection transcription client created for the selected backend.
    """

    # Sample rate (Hz) handed to the voice activity detector.
    RATE = 16000

    def __init__(self):
        # Created lazily by the first connection, whose options may set the limits.
        self.client_manager = None
        # Number of consecutive frames the VAD judged silent (see voice_activity).
        # NOTE(review): this counter and vad_detector live on the server instance,
        # so they are shared by every connection handled via recv_audio.
        self.no_voice_activity_chunks = 0
        self.single_model = False
        # Assigned per connection in recv_audio / handle_new_connection.
        self.backend = None
        self.vad_detector = None

    def initialize_client(self, websocket, options):
        """
        Create the backend-specific client for a connection and register it.

        Args:
            websocket: The client's websocket.
            options (dict): Connection options; must contain "language" and "uid".

        Raises:
            ValueError: If ``self.backend`` is not a handled backend type.
        """
        client: Optional[ServeClientBase] = None

        if self.backend.is_pywhispercpp():
            client = PyWhiperCppServe(
                websocket,
                language=options["language"],
                client_uid=options["uid"],
            )
            logging.info("Running pywhispercpp backend.")

        if client is None:
            raise ValueError(f"Backend type {self.backend.value} not recognised or not handled.")

        self.client_manager.add_client(websocket, client)

    def get_audio_from_websocket(self, websocket):
        """
        Receives audio buffer from websocket and creates a numpy array out of it.

        The payload is interpreted as native-endian 16-bit integer samples and
        scaled to float32 in [-1, 1).

        Args:
            websocket: The websocket to receive audio from.

        Returns:
            A float32 numpy array containing the audio, or False if the client
            sent the ``b"END_OF_AUDIO"`` marker.
        """
        frame_data = websocket.recv()
        if frame_data == b"END_OF_AUDIO":
            return False
        return np.frombuffer(frame_data, dtype=np.int16).astype(np.float32) / 32768.0

    def handle_new_connection(self, websocket):
        """
        Perform the connection handshake: read the client's options, enforce
        capacity, and create the client.

        The optional ``from`` / ``to`` query parameters on the request path select
        the source and target languages. The first message is expected to be a
        JSON object of options; anything else falls back to default options.

        Args:
            websocket: The newly connected websocket.

        Returns:
            bool: True if the client was registered, False otherwise.
        """
        query_parameters_dict = dict(parse_qsl(urlparse(websocket.request.path).query))
        from_lang, to_lang = query_parameters_dict.get('from'), query_parameters_dict.get('to')

        try:
            logging.info("New client connected")
            raw_options = websocket.recv()
            try:
                options = json.loads(raw_options)
            except (ValueError, TypeError):
                # Not valid JSON (JSONDecodeError and UnicodeDecodeError are ValueErrors).
                options = None
            if not isinstance(options, dict):
                options = {"language": from_lang, "uid": str(uuid1())}

            if self.client_manager is None:
                max_clients = options.get('max_clients', 4)
                max_connection_time = options.get('max_connection_time', 600)
                self.client_manager = ClientManager(max_clients, max_connection_time)

            if self.client_manager.is_server_full(websocket, options):
                websocket.close()
                return False

            if self.backend.is_pywhispercpp():
                self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE)

            self.initialize_client(websocket, options)
            if from_lang and to_lang:
                self.set_lang(websocket, from_lang, to_lang)
                logging.info(f"Source lange: {from_lang} -> Dst lange: {to_lang}")
            return True
        except ConnectionClosed:
            logging.info("Connection closed by client")
            return False
        except Exception as e:
            logging.error("Error during new connection initialization: %s", e)
            return False

    def process_audio_frames(self, websocket):
        """
        Receive one audio frame and hand it to the connection's client.

        Args:
            websocket: The client's websocket.

        Returns:
            bool: False when the client sent END_OF_AUDIO (the receive loop should
            stop), True otherwise.
        """
        frame_np = self.get_audio_from_websocket(websocket)
        if frame_np is False:
            # END_OF_AUDIO: there is no audio to add; signal recv_audio to stop.
            # NOTE(review): if the client needs an explicit end-of-stream signal
            # (e.g. to flush a final transcript), send it here before returning.
            return False
        client = self.client_manager.get_client(websocket)
        client.add_frames(frame_np)
        return True

    def set_lang(self, websocket, src_lang, dst_lang):
        """Forward the source/target languages to the client if it is a PyWhiperCppServe."""
        client = self.client_manager.get_client(websocket)
        if isinstance(client, PyWhiperCppServe):
            client.set_lang(src_lang, dst_lang)

    def recv_audio(self,
                   websocket,
                   backend: BackendType = BackendType.PYWHISPERCPP):
        """
        Per-connection handler: run the handshake, then feed audio frames to the
        client until END_OF_AUDIO, client disconnect, connection timeout, or error.
        The client is always cleaned up before returning.

        Args:
            websocket: The connected websocket.
            backend (BackendType): The transcription backend to use.
        """
        self.backend = backend
        if not self.handle_new_connection(websocket):
            # The handshake can fail after the client was registered (e.g. if
            # set_lang raises); release it so it does not hold a slot forever.
            if self.client_manager is not None and self.client_manager.get_client(websocket):
                self.cleanup(websocket)
            return

        try:
            while not self.client_manager.is_client_timeout(websocket):
                if not self.process_audio_frames(websocket):
                    break
        except ConnectionClosed:
            logging.info("Connection closed by client")
        except Exception as e:
            logging.error("Unexpected error: %s", e)
        finally:
            if self.client_manager.get_client(websocket):
                self.cleanup(websocket)
                websocket.close()

    def run(self,
            host,
            port=9090,
            backend="pywhispercpp"):
        """
        Run the transcription server until the process is stopped.

        Args:
            host (str): The host address to bind the server.
            port (int): The port number to bind the server.
            backend (str): Backend name; must be one of ``BackendType.valid_types()``.

        Raises:
            ValueError: If ``backend`` is not a valid backend type.
        """
        if not BackendType.is_valid(backend):
            raise ValueError(f"{backend} is not a valid backend type. Choose backend from {BackendType.valid_types()}")

        with serve(
            functools.partial(
                self.recv_audio,
                backend=BackendType(backend),
            ),
            host,
            port
        ) as server:
            server.serve_forever()

    def voice_activity(self, websocket, frame_np):
        """
        Evaluates the voice activity in a given audio frame and manages the state of voice activity detection.

        If the VAD reports no voice for more than three consecutive frames, the
        client's end-of-speech (EOS) flag is set (if not already) and the call
        sleeps for 0.1 s. A frame with voice resets the silence streak.

        Args:
            websocket: The websocket associated with the current client, used to look up the client.
            frame_np (numpy.ndarray): The audio frame to be analyzed.

        Returns:
            bool: True if voice activity is detected in the current frame, False otherwise.
        """
        if self.vad_detector(frame_np):
            # Speech breaks the silence streak so only consecutive silent frames count.
            self.no_voice_activity_chunks = 0
            return True

        self.no_voice_activity_chunks += 1
        if self.no_voice_activity_chunks > 3:
            client = self.client_manager.get_client(websocket)
            if not client.eos:
                client.set_eos(True)
            time.sleep(0.1)
        return False

    def cleanup(self, websocket):
        """
        Cleans up resources associated with a given client's websocket.

        Args:
            websocket: The websocket associated with the client to be cleaned up.
        """
        if self.client_manager.get_client(websocket):
            self.client_manager.remove_client(websocket)
|
|
|