# streaming-speech-translation / src/server/websocket_server.py
# (web-page residue, kept as a comment: uploaded by pltobing —
#  "Pin TTS pause threshold with 1.2s for now, and revert max_workers to cpu_count."
#  — commit 1f1fc5c)
#!/usr/bin/env python3
# License: CC-BY-NC-ND-4.0
# Created by: Patrick Lumbantobing, Vertox-AI
# Copyright (c) 2026 Vertox-AI. All rights reserved.
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-NoDerivatives 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-nd/4.0/
"""
WebSocket server for streaming speech translation.
Protocol
--------
Client β†’ Server:
- Binary frames: raw PCM16 audio at the declared source sample rate.
- Text frames (JSON):
{"action": "start", "sample_rate": 48000}
{"action": "stop"}
Server β†’ Client:
- Binary frames: PCM16 audio at 24 kHz (synthesized translation).
- Text frames (JSON):
{"type": "status", "status": "started" | "stopped", "sample_rate": sr}
{"type": "transcript", "text": "..."}
{"type": "translation", "text": "..."}
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple
import websockets
from websockets.server import WebSocketServerProtocol
from src.asr.modules import ASRModelPackage
from src.asr.streaming_asr import StreamingASR
from src.nmt.streaming_nmt import StreamingNMT
from src.nmt.translator_module import StreamingTranslator
from src.pipeline.config import PipelineConfig
from src.pipeline.orchestrator import PipelineOrchestrator
from src.tts.streaming_tts import StreamingTTS
from src.tts.xtts_streaming_pipeline import StreamingTTSPipeline
log = logging.getLogger(__name__)
class TranslationServer:
    """
    WebSocket server for streaming speech translation.

    Models are loaded once at startup and shared across all sessions.
    Each WebSocket connection gets its own :class:`PipelineOrchestrator`
    with per-session state (buffers, KV caches, etc.).

    Parameters
    ----------
    config :
        Server and pipeline configuration.
    """

    def __init__(self, config: PipelineConfig) -> None:
        self.config = config
        # Shared model handles; populated once by _load_models() at startup.
        self._asr_model: Optional[ASRModelPackage] = None
        self._nmt_model: Optional[StreamingTranslator] = None
        self._tts_pipeline: Optional[StreamingTTSPipeline] = None
        # Shared executor across all pipeline sessions for NMT/TTS.
        self._executor = ThreadPoolExecutor(
            max_workers=os.cpu_count(),
            thread_name_prefix="pipeline",
        )

    # ─── Model loading ──────────────────────────────────────────────────────
    def _load_models(self) -> None:
        """
        Load all models once at server startup.

        This includes:
        - ASR ONNX model package.
        - NMT GGUF model via llama-cpp (with KV cache warmup).
        - TTS XTTSv2 ONNX pipeline.

        Must be called before any session orchestrator is created; loading
        order matters only in that the NMT warmup requires the NMT model.
        """
        log.info("Loading ASR ONNX models...")
        self._asr_model = ASRModelPackage.load_model(path=self.config.asr_onnx_path)
        log.info("ASR models loaded")

        log.info("Loading NMT GGUF model...")
        self._nmt_model = StreamingTranslator(
            model_path=self.config.nmt_gguf_path,
            n_threads=self.config.nmt_n_threads,
        )
        log.info("NMT model loaded")

        # Warm the KV cache up-front so the first real translation request
        # does not pay the one-time prompt-processing latency.
        log.info("Warming up NMT KV cache...")
        self._nmt_model.warmup_cache()
        log.info("NMT cache warm-up done")

        log.info("Loading TTS ONNX models...")
        self._tts_pipeline = StreamingTTSPipeline(
            model_dir=self.config.tts_model_dir,
            vocab_path=self.config.tts_vocab_path,
            mel_norms_path=self.config.tts_mel_norms_path,
            use_int8_gpt=self.config.tts_use_int8_gpt,
            num_threads_gpt=self.config.tts_num_threads_gpt,
        )
        log.info("TTS models loaded")

    def _create_session_orchestrator(self) -> PipelineOrchestrator:
        """
        Create a new orchestrator for a WebSocket session.

        Models (ASR, NMT, TTS) are shared across sessions; only
        per-connection state lives in the orchestrator and its
        StreamingASR/StreamingNMT/StreamingTTS wrappers.

        Raises
        ------
        RuntimeError
            If called before :meth:`_load_models` has populated all models.
        """
        if self._asr_model is None or self._nmt_model is None or self._tts_pipeline is None:
            raise RuntimeError("Models must be loaded before creating a session orchestrator")

        asr = StreamingASR(
            asr_model_package=self._asr_model,
            chunk_duration_ms=self.config.asr_chunk_duration_ms,
            sample_rate=self.config.asr_sample_rate,
            audio_queue_maxsize=self.config.audio_queue_maxsize,
            resampler_rates=self.config.resampler_rates,
        )
        nmt = StreamingNMT(translator=self._nmt_model)
        tts = StreamingTTS(
            tts_pipeline=self._tts_pipeline,
            ref_audio_path=self.config.tts_ref_audio_path,
            language=self.config.tts_language,
            stream_chunk_size=self.config.tts_stream_chunk_size,
        )
        return PipelineOrchestrator(
            config=self.config,
            asr=asr,
            nmt=nmt,
            tts=tts,
            executor=self._executor,
        )

    # ─── Per-client handling ────────────────────────────────────────────────
    async def handle_client(self, websocket: WebSocketServerProtocol) -> None:
        """
        Handle a single WebSocket client connection.

        The client must send a ``{"action": "start"}`` control message
        before sending audio, and can send ``{"action": "stop"}`` to end
        the session while keeping the WebSocket open. Binary frames received
        outside a started session are ignored.
        """
        client_addr: Optional[Tuple[str, int]] = websocket.remote_address
        log.info("Client connected: %s", client_addr)

        orchestrator: Optional[PipelineOrchestrator] = None
        source_sr: int = self.config.asr_sample_rate
        send_tasks: List[asyncio.Task] = []

        try:
            async for message in websocket:
                if isinstance(message, str):
                    # JSON control message.
                    try:
                        data: Dict[str, Any] = json.loads(message)
                    except json.JSONDecodeError:
                        log.warning("Invalid JSON from %s", client_addr)
                        continue

                    action = data.get("action")
                    if action == "start":
                        source_sr = int(data.get("sample_rate", self.config.asr_sample_rate))
                        log.info(
                            "Starting session for %s, sample_rate=%s", client_addr, source_sr
                        )
                        # Restart session if it already exists.
                        if orchestrator is not None:
                            await orchestrator.stop_session()
                        orchestrator = self._create_session_orchestrator()
                        # Fix: the original log call was missing its f-prefix
                        # and logged the literal text "{orchestrator}".
                        log.info("server.handle_client orchestrator %s", orchestrator)
                        await orchestrator.start_session()

                        # (Re)start sender tasks, dropping any left over from
                        # a previous session on this same connection.
                        for task in send_tasks:
                            task.cancel()
                        send_tasks = [
                            asyncio.create_task(
                                self._send_audio_loop(websocket, orchestrator),
                                name="send-audio",
                            ),
                            asyncio.create_task(
                                self._send_text_loop(websocket, orchestrator),
                                name="send-text",
                            ),
                        ]
                        await websocket.send(
                            json.dumps(
                                {
                                    "type": "status",
                                    "status": "started",
                                    "sample_rate": source_sr,
                                }
                            )
                        )
                    elif action == "stop":
                        log.info("Stopping session for %s", client_addr)
                        if orchestrator is not None:
                            await orchestrator.stop_session()
                            orchestrator = None
                        # Sender tasks are deliberately left running so any
                        # queued tail audio/text still drains; they are
                        # cancelled on the next "start" or at disconnect.
                        await websocket.send(json.dumps({"type": "status", "status": "stopped"}))
                    else:
                        log.warning("Unknown action '%s' from %s", action, client_addr)
                elif isinstance(message, bytes):
                    # Binary audio frame; only accepted inside a session.
                    if orchestrator is not None:
                        if source_sr != self.config.asr_sample_rate:
                            # Client captures at a different rate than ASR
                            # expects; resample on the way in.
                            await orchestrator.push_audio_resampled(message, source_sr)
                        else:
                            await orchestrator.push_audio(message)
                    else:
                        # Audio without a started session; ignore to avoid errors.
                        log.debug("Received audio from %s before 'start'; ignoring", client_addr)
        except websockets.exceptions.ConnectionClosed:
            log.info("Client disconnected: %s", client_addr)
        except Exception as e:
            # logger.exception == error level + traceback (same as exc_info=True).
            log.exception("Error handling client %s: %s", client_addr, e)
        finally:
            for task in send_tasks:
                task.cancel()
            if send_tasks:
                # Await the cancelled tasks so they finish their except/
                # cleanup paths instead of dying with a "Task was destroyed
                # but it is pending!" warning at event-loop shutdown.
                await asyncio.gather(*send_tasks, return_exceptions=True)
            if orchestrator is not None:
                await orchestrator.stop_session()
            log.info("Session cleanup done for %s", client_addr)

    # ─── Outbound loops ─────────────────────────────────────────────────────
    async def _send_audio_loop(
        self,
        websocket: WebSocketServerProtocol,
        orchestrator: PipelineOrchestrator,
    ) -> None:
        """
        Continuously send synthesized audio chunks to the client.

        This loop terminates when cancelled or when the WebSocket closes.
        The short poll timeout keeps cancellation responsive without
        busy-waiting.
        """
        try:
            while True:
                audio_bytes = await orchestrator.get_audio_output(timeout=0.2)
                if audio_bytes is not None:
                    await websocket.send(audio_bytes)
        except asyncio.CancelledError:
            pass
        except websockets.exceptions.ConnectionClosed:
            pass

    async def _send_text_loop(
        self,
        websocket: WebSocketServerProtocol,
        orchestrator: PipelineOrchestrator,
    ) -> None:
        """
        Continuously send transcript and translation updates to the client.

        Transcript and translation messages are polled independently with
        short timeouts to keep latency low while avoiding busy-waiting.
        A send failure (closed socket) ends the loop.
        """
        try:
            while True:
                # Check for transcript.
                transcript = await orchestrator.get_transcript(timeout=0.1)
                if transcript is not None:
                    try:
                        await websocket.send(
                            json.dumps(
                                {
                                    "type": "transcript",
                                    "text": transcript,
                                }
                            )
                        )
                    except websockets.exceptions.ConnectionClosed:
                        break
                # Check for translation.
                translation = await orchestrator.get_translation(timeout=0.1)
                if translation is not None:
                    try:
                        await websocket.send(
                            json.dumps(
                                {
                                    "type": "translation",
                                    "text": translation,
                                }
                            )
                        )
                    except websockets.exceptions.ConnectionClosed:
                        break
        except asyncio.CancelledError:
            pass
        except websockets.exceptions.ConnectionClosed:
            pass

    # ─── Server entrypoint ──────────────────────────────────────────────────
    async def start(self) -> None:
        """
        Load models and start the WebSocket server.

        This call blocks (by awaiting a never-resolving future) until the
        process is terminated.
        """
        self._load_models()
        log.info("Starting WebSocket server on %s:%s", self.config.host, self.config.port)
        async with websockets.serve(
            self.handle_client,
            self.config.host,
            self.config.port,
            max_size=2**20,  # 1 MB max message size.
            ping_interval=30,
            ping_timeout=10,
        ):
            log.info("Server listening on ws://%s:%s", self.config.host, self.config.port)
            await asyncio.Future()  # Run forever.