# streaming-speech-translation/src/nmt/translator_module.py
# (Hugging Face page residue converted to comments so the module parses:
#  uploader pltobing, commit 0c397a9 "Formatting black, isort, flake8".)
#!/usr/bin/env python3
# License: CC-BY-NC-ND-4.0
# Created by: Patrick Lumbantobing, Vertox-AI
# Copyright (c) 2026 Vertox-AI. All rights reserved.
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-NoDerivatives 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-nd/4.0/
"""
Streaming NMT integration for the ASR–NMT–TTS pipeline.
Implements a low-latency (<700 ms) multi-turn translator based on
TranslateGemma (GGUF) via ``llama-cpp`` with KV-cache warmup.
"""
from __future__ import annotations
import logging
import time
from typing import Dict, List
from llama_cpp import Llama
log = logging.getLogger(__name__)
# Fixed translation direction for this deployment. StreamingTranslator reads
# these module-level names by exactly these identifiers when filling its
# prompt templates, so they are deliberately not renamed to UPPER_SNAKE_CASE.
source_lang = "English"
src_lang_code = "en"
target_lang = "Russian"
tgt_lang_code = "ru"
# Gemma-style end-of-turn marker: generation stops as soon as it is emitted.
STOP_TOKENS: List[str] = ["<end_of_turn>"]
# Incremental phrases used to prime the shared KV-cache prefix (see
# StreamingTranslator.warmup_cache). The leading spaces mimic how streamed
# ASR parts arrive mid-sentence.
DEFAULT_WARMUP_TEXTS: List[str] = [
    "Hello,",
    " how are you doing?",
    " It's a good day,",
    " what time is it now?",
    " Seems to be early.",
]
# Two adjacent string literals concatenate implicitly (black line-split
# artifact); the result is a single absolute path to the GGUF model file.
DEFAULT_MODEL_PATH = "/home/ubuntu/vertox/streaming-speech-translation/models/nmt/" "translategemma-4b-it-q8_0.gguf"
class StreamingTranslator:
    """
    Multi-turn streaming translator using TranslateGemma via ``llama-cpp``.

    Warmup primes a global KV cache prefix shared by all sessions.
    Per-session state (``prev_translation``, ``prev_query``) is passed
    in/out of :meth:`translate`, not stored on the instance, which makes
    this class thread-safe with no mutable state after initialization.

    Parameters
    ----------
    model_path :
        Path to the GGUF model file.
    n_ctx :
        Context window size (2048 recommended for translations).
    n_threads :
        CPU threads (e.g., 4 for m8a.xlarge, 8 for m8a.2xlarge).
    n_batch :
        Batch size for prompt processing (higher = faster prompt).
    verbose :
        Enable detailed logging from ``llama-cpp``.
    """

    def __init__(
        self,
        model_path: str = DEFAULT_MODEL_PATH,
        n_ctx: int = 2048,
        n_threads: int = 4,
        n_batch: int = 512,
        verbose: bool = False,
    ) -> None:
        log.info("Loading NMT model: %s", model_path)
        start_time = time.time()
        self.llm: Llama = Llama(
            model_path=model_path,
            n_ctx=n_ctx,
            n_threads=n_threads,
            n_batch=n_batch,
            verbose=verbose,
            # Optimize for low latency.
            n_gpu_layers=0,  # CPU only.
            use_mmap=True,  # Memory-map model for faster loading.
            use_mlock=False,  # Let OS manage memory.
        )
        load_time = time.time() - start_time
        log.info("NMT model loaded in %.2fs", load_time)
        # Prompt templates (Perplexity-optimized). Every language field is a
        # `.format()` placeholder. (The EXAMPLE line used to be an eagerly
        # evaluated f-string that froze the module-level language constants
        # at construction time while sibling lines were filled at format
        # time; all lines are now filled the same way.) The Cyrillic example
        # text was double-encoded mojibake ("ПослСдниС…") and has been
        # restored to proper UTF-8 Russian, as have the arrow/dash glyphs.
        self.init_prompt_template: str = (
            "<start_of_turn>user\n"
            "You are a professional {source_lang} ({src_lang_code}) to "
            "{target_lang} ({tgt_lang_code}) streaming translator.\n\n"
            "TASK: Translate a continuous {source_lang} text that arrives in "
            "sequential parts.\n"
            "After each part, produce ONLY the {target_lang} translation of that part,\n"
            "informed by the full context of all previous parts.\n\n"
            "RULES:\n"
            "1. Output ONLY the {target_lang} translation. No commentary, no source text "
            "echo, no duplications, no newlines.\n"
            "2. If the current part completes or changes the meaning of a previous part,\n"
            " REWRITE only the affected previous part(s) followed by the new translation.\n"
            "3. If the current part is an independent continuation, output ONLY the new\n"
            " translation without repeating prior translations.\n"
            "4. Do NOT add ellipsis (...) or any placeholder markers for incomplete "
            "thoughts.\n"
            " Translate exactly what is given.\n"
            "5. Maintain grammatical agreement and natural {target_lang} word order "
            "across parts.\n\n"
            "EXAMPLE ({source_lang} → {target_lang} streaming sequence):\n\n"
            '[Part 1] "The latest research in"\n'
            "[Translation 1] Последние исследования в области\n"
            '[Part 2] " artificial intelligence shows"\n'
            "[Translation 2] Последние исследования в области искусственного "
            "интеллекта показывают,\n"
            '[Part 3] " promising results for"\n'
            "[Translation 3] многообещающие результаты для\n"
            '[Part 4] " healthcare applications."\n'
            "[Translation 4] применения в здравоохранении.\n\n"
            'NOTICE: In Translation 2, the translator revised Translation 1 because "in" '
            'was completed by "artificial intelligence shows" — changing the sentence '
            "structure.\n"
            "In Translation 3, no revision was needed — it is a pure continuation.\n\n"
            "Now begin. The following is the first part of the {source_lang} "
            "text sequence:"
            # NOTE(review): no separator between "sequence:" and the first
            # part — looks intentional (matches the warmup texts' leading
            # spaces) but confirm against the model's expected format.
            "{source_text}<end_of_turn>\n"
            "<start_of_turn>model> \n"
        )
        # Continuation prompt. Uses consistent `> \n` after `model`.
        # NOTE(review): `<start_of_turn>model> ` is not canonical Gemma
        # formatting (plain `model\n` is) — kept byte-identical because the
        # warmed KV-cache prefix depends on the exact token sequence.
        self.cont_prompt_template: str = (
            "{prev_query}{prev_translation}<end_of_turn>\n"
            "<start_of_turn>user\n\n\n"
            "{source_text}<end_of_turn>\n"
            "<start_of_turn>model> \n"
        )

    # ─── Internal helpers ───────────────────────────────────────────────────
    def _build_prompt(
        self, source_text: str, prev_translation: str, prev_query: str
    ) -> str:
        """Build the multi-turn prompt used by :meth:`translate`.

        Returns the continuation prompt when both pieces of prior context
        are present, otherwise the full initial prompt.

        NOTE(review): if the previous turn produced an empty translation,
        ``prev_translation`` is falsy and the context in ``prev_query`` is
        discarded (we restart from the initial prompt) — preserved from the
        original behavior; confirm this is intended.
        """
        if prev_translation and prev_query:
            return self.cont_prompt_template.format(
                source_text=source_text,
                prev_translation=prev_translation,
                prev_query=prev_query,
            )
        return self.init_prompt_template.format(
            source_lang=source_lang,
            src_lang_code=src_lang_code,
            target_lang=target_lang,
            tgt_lang_code=tgt_lang_code,
            source_text=source_text,
        )

    def _generate(
        self, prompt: str, max_tokens: int = 100, temperature: float = 0.3
    ) -> Dict:
        """Run one non-streaming completion with the shared sampling
        settings (previously duplicated verbatim in ``warmup_cache`` and
        ``translate``; centralized so the two paths cannot drift)."""
        return self.llm(
            prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_k=64,
            top_p=0.95,
            repeat_penalty=1.1,
            stop=STOP_TOKENS,
            stream=False,
            echo=False,
        )

    # ─── KV-cache warmup ────────────────────────────────────────────────────
    def warmup_cache(self, source_texts: List[str] | None = None) -> None:
        """
        Pre-compute and cache KV states for the fixed prefix prompt.

        This eliminates roughly 60–80 ms of processing time per translation.
        Uses local variables only; does not touch any instance state.

        Parameters
        ----------
        source_texts :
            Warmup text sequence; ``None`` (the default) selects
            ``DEFAULT_WARMUP_TEXTS``. (The module list was previously used
            directly as a mutable default argument.)
        """
        texts = DEFAULT_WARMUP_TEXTS if source_texts is None else source_texts
        prev_translation = ""
        prev_query = ""
        for i, source_text in enumerate(texts):
            # First turn always uses the initial prompt; later turns chain
            # the previous prompt + translation. Warmup branches on the turn
            # index (even an empty translation continues the chain), unlike
            # translate(), which branches on context truthiness.
            if i > 0:
                warmup_prompt = self.cont_prompt_template.format(
                    source_text=source_text,
                    prev_translation=prev_translation,
                    prev_query=prev_query,
                )
            else:
                warmup_prompt = self.init_prompt_template.format(
                    source_lang=source_lang,
                    src_lang_code=src_lang_code,
                    target_lang=target_lang,
                    tgt_lang_code=tgt_lang_code,
                    source_text=source_text,
                )
            log.debug("warmup_prompt:\n%s\n", warmup_prompt)
            start_time = time.time()
            output: Dict = self._generate(warmup_prompt)
            latency_s = time.time() - start_time
            translation = output["choices"][0]["text"].strip()
            log.debug("warmup output: %s\n", output)
            log.debug("warmup source_text: %s\n", source_text)
            log.debug("warmup translation: %s\n", translation)
            log.debug("warmup latency: %.2f ms\n", latency_s * 1000)
            prev_translation = translation
            prev_query = warmup_prompt

    # ─── Main translation API ───────────────────────────────────────────────
    def translate(
        self,
        source_text: str,
        prev_translation: str = "",
        prev_query: str = "",
        max_tokens: int = 100,
        temperature: float = 0.3,
    ) -> Dict[str, object]:
        """
        Non-streaming translation with KV cache for a fixed prefix.

        Thread-safe: no internal mutable state is modified. Callers supply
        and receive ``prev_translation`` and ``prev_query`` to maintain
        per-session multi-turn context. Multiple :class:`StreamingNMT`
        sessions can share this translator concurrently.

        Parameters
        ----------
        source_text :
            English phrase to translate.
        prev_translation :
            Previous translation result for multi-turn context.
        prev_query :
            Previous prompt for multi-turn context.
        max_tokens :
            Maximum tokens to generate (default 100).
        temperature :
            Sampling temperature (default 0.3).

        Returns
        -------
        dict
            Keys: ``translation``, ``query``, ``prev_translation``,
            ``prev_query``, ``latency_ms``, ``tokens_generated``,
            ``tokens_per_sec``. ``prev_translation``/``prev_query``
            duplicate ``translation``/``query`` so the result can be fed
            straight back into the next call.
        """
        prompt = self._build_prompt(source_text, prev_translation, prev_query)
        log.debug("NMT prompt: %s...", prompt[:200])
        start_time = time.time()
        output: Dict = self._generate(
            prompt, max_tokens=max_tokens, temperature=temperature
        )
        latency_ms = (time.time() - start_time) * 1000.0
        translation = output["choices"][0]["text"].strip()
        tokens_generated = int(output["usage"]["completion_tokens"])
        log.info(
            "NMT: '%s' -> '%s' (%.1f ms, %d tokens)",
            source_text,
            translation,
            latency_ms,
            tokens_generated,
        )
        return {
            "translation": translation,
            "query": prompt,
            "prev_translation": translation,
            "prev_query": prompt,
            "latency_ms": latency_ms,
            "tokens_generated": tokens_generated,
            # Guard against zero elapsed time (clock granularity).
            "tokens_per_sec": (tokens_generated / (latency_ms / 1000.0)) if latency_ms > 0 else 0.0,
        }