Text-to-Speech
ONNX
GGUF
speech-translation
streaming-speech-translation
speech
audio
speech-recognition
automatic-speech-recognition
streaming-asr
ASR
NeMo
ONNX
cache-aware ASR
FastConformer
RNNT
Parakeet
neural-machine-translation
NMT
gemma3
llama-cpp
GGUF
conversational
TTS
xtts
xttsv2
voice-clone
gpt2
hifigan
multilingual
vq
perceiver-encoder
websocket
| #!/usr/bin/env python3 | |
| # License: CC-BY-NC-ND-4.0 | |
| # Created by: Patrick Lumbantobing, Vertox-AI | |
| # Copyright (c) 2026 Vertox-AI. All rights reserved. | |
| # | |
| # This work is licensed under the Creative Commons | |
| # Attribution-NonCommercial-NoDerivatives 4.0 International License. | |
| # To view a copy of this license, visit | |
| # http://creativecommons.org/licenses/by-nc-nd/4.0/ | |
| """ | |
| WebSocket server for streaming speech translation. | |
| Protocol | |
| -------- | |
| Client → Server: | |
| - Binary frames: raw PCM16 audio at the declared source sample rate. | |
| - Text frames (JSON): | |
| {"action": "start", "sample_rate": 48000} | |
| {"action": "stop"} | |
| Server → Client: | |
| - Binary frames: PCM16 audio at 24 kHz (synthesized translation). | |
| - Text frames (JSON): | |
| {"type": "status", "status": "started", "sample_rate": sr} | |
| {"type": "status", "status": "stopped"} | |
| {"type": "transcript", "text": "..."} | |
| {"type": "translation", "text": "..."} | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| from concurrent.futures import ThreadPoolExecutor | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import websockets | |
| from websockets.server import WebSocketServerProtocol | |
| from src.asr.modules import ASRModelPackage | |
| from src.asr.streaming_asr import StreamingASR | |
| from src.nmt.streaming_nmt import StreamingNMT | |
| from src.nmt.translator_module import StreamingTranslator | |
| from src.pipeline.config import PipelineConfig | |
| from src.pipeline.orchestrator import PipelineOrchestrator | |
| from src.tts.streaming_tts import StreamingTTS | |
| from src.tts.xtts_streaming_pipeline import StreamingTTSPipeline | |
# Module-level logger used by TranslationServer for connection/session events.
log: logging.Logger = logging.getLogger(__name__)
class TranslationServer:
    """
    WebSocket server for streaming speech translation.

    Models are loaded once at startup and shared across all sessions.
    Each WebSocket connection gets its own :class:`PipelineOrchestrator`
    with per-session state (buffers, KV caches, etc.).

    Parameters
    ----------
    config :
        Server and pipeline configuration.
    """

    def __init__(self, config: PipelineConfig) -> None:
        self.config = config
        self._asr_model: Optional[ASRModelPackage] = None
        self._nmt_model: Optional[StreamingTranslator] = None
        self._tts_pipeline: Optional[StreamingTTSPipeline] = None
        # Shared executor across all pipeline sessions for NMT/TTS.
        # If os.cpu_count() returns None, ThreadPoolExecutor picks its own default.
        self._executor = ThreadPoolExecutor(
            max_workers=os.cpu_count(),
            thread_name_prefix="pipeline",
        )

    # ── Model loading ──────────────────────────────────────────────────────

    def _load_models(self) -> None:
        """
        Load all models once at server startup.

        This includes:
        - ASR ONNX model package.
        - NMT GGUF model (with KV cache warmup).
        - TTS XTTSv2 ONNX pipeline.

        Loading is synchronous and happens before the server starts
        accepting connections (see :meth:`start`).
        """
        log.info("Loading ASR ONNX models...")
        self._asr_model = ASRModelPackage.load_model(path=self.config.asr_onnx_path)
        log.info("ASR models loaded")

        log.info("Loading NMT GGUF model...")
        self._nmt_model = StreamingTranslator(
            model_path=self.config.nmt_gguf_path,
            n_threads=self.config.nmt_n_threads,
        )
        log.info("NMT model loaded")
        log.info("Warming up NMT KV cache...")
        self._nmt_model.warmup_cache()
        log.info("NMT cache warm-up done")

        log.info("Loading TTS ONNX models...")
        self._tts_pipeline = StreamingTTSPipeline(
            model_dir=self.config.tts_model_dir,
            vocab_path=self.config.tts_vocab_path,
            mel_norms_path=self.config.tts_mel_norms_path,
            use_int8_gpt=self.config.tts_use_int8_gpt,
            num_threads_gpt=self.config.tts_num_threads_gpt,
        )
        log.info("TTS models loaded")

    def _create_session_orchestrator(self) -> PipelineOrchestrator:
        """
        Create a new orchestrator for a WebSocket session.

        Models (ASR, NMT, TTS) are shared across sessions; only
        per-connection state lives in the orchestrator and its
        StreamingASR/StreamingNMT/StreamingTTS wrappers.

        Raises
        ------
        RuntimeError
            If :meth:`_load_models` has not populated all three models.
        """
        if self._asr_model is None or self._nmt_model is None or self._tts_pipeline is None:
            raise RuntimeError("Models must be loaded before creating a session orchestrator")

        asr = StreamingASR(
            asr_model_package=self._asr_model,
            chunk_duration_ms=self.config.asr_chunk_duration_ms,
            sample_rate=self.config.asr_sample_rate,
            audio_queue_maxsize=self.config.audio_queue_maxsize,
            resampler_rates=self.config.resampler_rates,
        )
        nmt = StreamingNMT(translator=self._nmt_model)
        tts = StreamingTTS(
            tts_pipeline=self._tts_pipeline,
            ref_audio_path=self.config.tts_ref_audio_path,
            language=self.config.tts_language,
            stream_chunk_size=self.config.tts_stream_chunk_size,
        )
        return PipelineOrchestrator(
            config=self.config,
            asr=asr,
            nmt=nmt,
            tts=tts,
            executor=self._executor,
        )

    # ── Per-client handling ────────────────────────────────────────────────

    @staticmethod
    def _parse_sample_rate(value: Any) -> Optional[int]:
        """
        Convert a client-supplied ``sample_rate`` to a positive int.

        Anything ``int()`` accepts is converted the same way ``int()`` would
        (floats are truncated). Returns None for booleans, unconvertible
        values (None, non-numeric strings, inf/NaN, containers) and
        non-positive rates.
        """
        # bool is a subclass of int; ``"sample_rate": true`` is not a rate.
        if isinstance(value, bool):
            return None
        try:
            rate = int(value)
        except (TypeError, ValueError, OverflowError):
            return None
        return rate if rate > 0 else None

    @staticmethod
    async def _cancel_tasks(tasks: List[asyncio.Task]) -> None:
        """Cancel every task in *tasks* and wait until all of them have finished."""
        for task in tasks:
            task.cancel()
        if tasks:
            # return_exceptions=True: collect CancelledError/other errors
            # instead of raising the first one here.
            await asyncio.gather(*tasks, return_exceptions=True)

    def _start_sender_tasks(
        self,
        websocket: WebSocketServerProtocol,
        orchestrator: PipelineOrchestrator,
    ) -> List[asyncio.Task]:
        """Spawn the outbound audio and text loops for *orchestrator*."""
        return [
            asyncio.create_task(
                self._send_audio_loop(websocket, orchestrator),
                name="send-audio",
            ),
            asyncio.create_task(
                self._send_text_loop(websocket, orchestrator),
                name="send-text",
            ),
        ]

    async def handle_client(self, websocket: WebSocketServerProtocol) -> None:
        """
        Handle a single WebSocket client connection.

        The client must send a ``{"action": "start"}`` control message
        before sending audio, and can send ``{"action": "stop"}`` to end
        the session while keeping the WebSocket open.

        Malformed control messages (invalid JSON, JSON that is not an
        object, or an invalid ``sample_rate``) are logged and ignored
        instead of terminating the connection.
        """
        # Typically (host, port) for IPv4 or a 4-tuple for IPv6.
        client_addr: Any = websocket.remote_address
        log.info("Client connected: %s", client_addr)

        orchestrator: Optional[PipelineOrchestrator] = None
        source_sr: int = self.config.asr_sample_rate
        send_tasks: List[asyncio.Task] = []

        try:
            async for message in websocket:
                if isinstance(message, str):
                    # JSON control message.
                    try:
                        data: Any = json.loads(message)
                    except json.JSONDecodeError:
                        log.warning("Invalid JSON from %s", client_addr)
                        continue
                    if not isinstance(data, dict):
                        log.warning("Control message from %s is not a JSON object", client_addr)
                        continue

                    action = data.get("action")
                    if action == "start":
                        requested_sr = self._parse_sample_rate(
                            data.get("sample_rate", self.config.asr_sample_rate)
                        )
                        if requested_sr is None:
                            log.warning(
                                "Invalid sample_rate %r from %s; ignoring 'start'",
                                data.get("sample_rate"),
                                client_addr,
                            )
                            continue
                        source_sr = requested_sr
                        log.info("Starting session for %s, sample_rate=%s", client_addr, source_sr)

                        # Restart session if it already exists.
                        if orchestrator is not None:
                            await orchestrator.stop_session()
                        orchestrator = self._create_session_orchestrator()
                        log.info("server.handle_client orchestrator %s", orchestrator)
                        await orchestrator.start_session()

                        # (Re)start sender tasks. The old ones are fully
                        # finished before the new ones begin writing to the socket.
                        await self._cancel_tasks(send_tasks)
                        send_tasks = self._start_sender_tasks(websocket, orchestrator)

                        await websocket.send(
                            json.dumps(
                                {
                                    "type": "status",
                                    "status": "started",
                                    "sample_rate": source_sr,
                                }
                            )
                        )
                    elif action == "stop":
                        log.info("Stopping session for %s", client_addr)
                        if orchestrator is not None:
                            await orchestrator.stop_session()
                            orchestrator = None
                        # NOTE(review): sender tasks are intentionally left
                        # running (as before) and keep polling the stopped
                        # orchestrator until the next 'start' or disconnect —
                        # presumably so output emitted while stopping still
                        # reaches the client; confirm against
                        # PipelineOrchestrator.stop_session.
                        await websocket.send(json.dumps({"type": "status", "status": "stopped"}))
                    else:
                        log.warning("Unknown action '%s' from %s", action, client_addr)

                elif isinstance(message, bytes):
                    # Binary audio frame.
                    if orchestrator is None:
                        # Audio without a started session; ignore to avoid errors.
                        log.debug("Received audio from %s before 'start'; ignoring", client_addr)
                    elif source_sr != self.config.asr_sample_rate:
                        await orchestrator.push_audio_resampled(message, source_sr)
                    else:
                        await orchestrator.push_audio(message)

        except websockets.exceptions.ConnectionClosed:
            log.info("Client disconnected: %s", client_addr)
        except Exception as e:
            log.error("Error handling client %s: %s", client_addr, e, exc_info=True)
        finally:
            await self._cancel_tasks(send_tasks)
            if orchestrator is not None:
                await orchestrator.stop_session()
            log.info("Session cleanup done for %s", client_addr)

    # ── Outbound loops ─────────────────────────────────────────────────────

    async def _send_audio_loop(
        self,
        websocket: WebSocketServerProtocol,
        orchestrator: PipelineOrchestrator,
    ) -> None:
        """
        Continuously send synthesized audio chunks to the client.

        Polls ``orchestrator.get_audio_output`` with a 0.2 s timeout and
        forwards every non-None chunk as a binary frame. Terminates when
        cancelled, when the WebSocket closes, or on an unexpected error
        (which is logged).
        """
        try:
            while True:
                audio_bytes = await orchestrator.get_audio_output(timeout=0.2)
                if audio_bytes is not None:
                    await websocket.send(audio_bytes)
        except asyncio.CancelledError:
            pass
        except websockets.exceptions.ConnectionClosed:
            pass
        except Exception:
            # Nothing inspects this task's result, so log rather than lose the error.
            log.exception("Audio sender loop failed")

    async def _send_text_loop(
        self,
        websocket: WebSocketServerProtocol,
        orchestrator: PipelineOrchestrator,
    ) -> None:
        """
        Continuously send transcript and translation updates to the client.

        Each iteration polls the transcript queue and then the translation
        queue, each with a 0.1 s timeout, so neither output can starve the
        other and the loop never busy-waits. Every update is sent as a JSON
        text frame ``{"type": "transcript" | "translation", "text": ...}``.
        Terminates when cancelled, when the WebSocket closes, or on an
        unexpected error (which is logged).
        """
        try:
            while True:
                # Check for transcript.
                transcript = await orchestrator.get_transcript(timeout=0.1)
                if transcript is not None:
                    await websocket.send(
                        json.dumps(
                            {
                                "type": "transcript",
                                "text": transcript,
                            }
                        )
                    )
                # Check for translation.
                translation = await orchestrator.get_translation(timeout=0.1)
                if translation is not None:
                    await websocket.send(
                        json.dumps(
                            {
                                "type": "translation",
                                "text": translation,
                            }
                        )
                    )
        except asyncio.CancelledError:
            pass
        except websockets.exceptions.ConnectionClosed:
            pass
        except Exception:
            # Nothing inspects this task's result, so log rather than lose the error.
            log.exception("Text sender loop failed")

    # ── Server entrypoint ──────────────────────────────────────────────────

    async def start(self) -> None:
        """
        Load models and start the WebSocket server.

        This coroutine awaits a never-resolving future, so it only returns
        when it is cancelled (e.g. on process shutdown) or an error is raised.
        """
        self._load_models()
        log.info("Starting WebSocket server on %s:%s", self.config.host, self.config.port)
        async with websockets.serve(
            self.handle_client,
            self.config.host,
            self.config.port,
            max_size=2**20,  # 1 MB max message size.
            ping_interval=30,
            ping_timeout=10,
        ):
            log.info("Server listening on ws://%s:%s", self.config.host, self.config.port)
            await asyncio.Future()  # Run forever.