# pltobing's picture
# Formatting black, isort, flake8
# 0c397a9
#!/usr/bin/env python3
# License: CC-BY-NC-ND-4.0
# Created by: Patrick Lumbantobing, Vertox-AI
# Copyright (c) 2026 Vertox-AI. All rights reserved.
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-NoDerivatives 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-nd/4.0/
"""
Main entry point for the streaming speech translation server.

This script starts a WebSocket server that performs real-time
speech-to-speech translation using:

- ASR: ONNX-exported NVIDIA NeMo Conformer RNN-T
- NMT: TranslateGemma GGUF via llama-cpp
- TTS: XTTSv2 ONNX (GPT-2 AR + HiFi-GAN vocoder)

Typical usage (simplified):

    python app.py \\
        --asr-onnx-path /path/to/asr_dir \\
        --nmt-gguf-path /path/to/translategemma.gguf \\
        --tts-model-dir /path/to/xtts_dir \\
        --tts-vocab-path /path/to/vocab.json \\
        --tts-mel-norms-path /path/to/mel_stats.npy \\
        --tts-ref-audio-path /path/to/reference.wav
"""
from __future__ import annotations
import argparse
import asyncio
import logging
from src.pipeline.config import PipelineConfig
from src.server.websocket_server import TranslationServer
# Set up root logging when this module is loaded: INFO threshold and
# timestamped "<time> [<level>] <logger>: <message>" records.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
def _build_arg_parser() -> argparse.ArgumentParser:
    """
    Create and configure the command-line argument parser.

    Six path options are required (``--asr-onnx-path``, ``--nmt-gguf-path``,
    ``--tts-model-dir``, ``--tts-vocab-path``, ``--tts-mel-norms-path`` and
    ``--tts-ref-audio-path``); every other option has a default, which
    ``ArgumentDefaultsHelpFormatter`` prints in ``--help``.

    Returns
    -------
    argparse.ArgumentParser
        Parser configured with ASR, NMT, TTS, queue, and server options.
    """
    parser = argparse.ArgumentParser(
        description="Streaming Speech Translation Server (ASR → NMT → TTS)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # ASR
    parser.add_argument(
        "--asr-onnx-path",
        required=True,
        help="Path to ASR ONNX model directory",
    )
    parser.add_argument(
        "--asr-chunk-ms",
        type=int,
        default=10,
        help="ASR audio chunk duration in milliseconds",
    )
    parser.add_argument(
        "--asr-sample-rate",
        type=int,
        default=16000,
        help="ASR expected input sample rate (Hz)",
    )

    # NMT
    parser.add_argument(
        "--nmt-gguf-path",
        required=True,
        help="Path to NMT GGUF model file",
    )
    parser.add_argument(
        "--nmt-n-threads",
        type=int,
        default=4,
        help="Number of CPU threads for NMT (llama-cpp)",
    )

    # TTS
    parser.add_argument(
        "--tts-model-dir",
        required=True,
        help="Path to TTS ONNX model directory (XTTSv2)",
    )
    parser.add_argument(
        "--tts-vocab-path",
        required=True,
        help="Path to TTS BPE vocab.json",
    )
    parser.add_argument(
        "--tts-mel-norms-path",
        required=True,
        help="Path to TTS mel_stats.npy (mel normalization statistics)",
    )
    parser.add_argument(
        "--tts-ref-audio-path",
        required=True,
        help="Path to TTS reference speaker audio file",
    )
    parser.add_argument(
        "--tts-language",
        default="ru",
        help="Target language code for TTS output (e.g., 'ru')",
    )
    # Boolean switch: absent -> False, present -> True.
    parser.add_argument(
        "--tts-int8-gpt",
        action="store_true",
        help="Use INT8 quantized GPT",
    )
    parser.add_argument(
        "--tts-threads-gpt",
        type=int,
        default=2,
        help="Number of threads for TTS GPT ONNX inference",
    )
    parser.add_argument(
        "--tts-chunk-size",
        type=int,
        default=20,
        help="Number of AR tokens per vocoder chunk in streaming TTS",
    )

    # Pipeline queues
    parser.add_argument(
        "--audio-queue-max",
        type=int,
        default=256,
        help="Maximum size of the raw audio input queue",
    )
    parser.add_argument(
        "--text-queue-max",
        type=int,
        default=64,
        help="Maximum size of the ASR→NMT text queue",
    )
    parser.add_argument(
        "--tts-queue-max",
        type=int,
        default=16,
        help="Maximum size of the NMT→TTS text queue",
    )
    parser.add_argument(
        "--audio-out-queue-max",
        type=int,
        default=32,
        help="Maximum size of the synthesized audio output queue",
    )

    # Server
    parser.add_argument(
        "--host",
        default="0.0.0.0",
        help="Server bind host",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8765,
        help="Server port",
    )

    return parser
def main(argv: list[str] | None = None) -> None:
    """
    Parse CLI arguments, construct the pipeline configuration, and start the server.

    This function:

    1. Parses command-line options for ASR, NMT, TTS, queue, and server settings.
    2. Instantiates a `PipelineConfig` from the parsed arguments.
    3. Creates a `TranslationServer` and runs its ``start()`` coroutine with
       ``asyncio.run``, which blocks until that coroutine finishes.

    Parameters
    ----------
    argv : list[str] | None, optional
        Argument strings to parse. ``None`` (the default) makes argparse read
        ``sys.argv[1:]``, i.e. the process command line.
    """
    parser = _build_arg_parser()
    args = parser.parse_args(argv)

    # Map CLI option names onto the PipelineConfig field names.
    config = PipelineConfig(
        asr_onnx_path=args.asr_onnx_path,
        asr_chunk_duration_ms=args.asr_chunk_ms,
        asr_sample_rate=args.asr_sample_rate,
        nmt_gguf_path=args.nmt_gguf_path,
        nmt_n_threads=args.nmt_n_threads,
        tts_model_dir=args.tts_model_dir,
        tts_vocab_path=args.tts_vocab_path,
        tts_mel_norms_path=args.tts_mel_norms_path,
        tts_ref_audio_path=args.tts_ref_audio_path,
        tts_language=args.tts_language,
        tts_use_int8_gpt=args.tts_int8_gpt,
        tts_num_threads_gpt=args.tts_threads_gpt,
        tts_stream_chunk_size=args.tts_chunk_size,
        audio_queue_maxsize=args.audio_queue_max,
        text_queue_maxsize=args.text_queue_max,
        tts_queue_maxsize=args.tts_queue_max,
        audio_out_queue_maxsize=args.audio_out_queue_max,
        host=args.host,
        port=args.port,
    )

    server = TranslationServer(config)
    asyncio.run(server.start())
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()