"""S2S model interface + a runnable cascaded streaming translator for SN59.

The server feeds the source utterance as 80 ms PCM frames (float32, 24 kHz,
mono, 1920 samples/frame) and expects target-language (English) PCM frames
back, ASAP. Score = latency once you clear the accuracy (>=0.65) and
speech-rate gates, so EMIT EARLY and keep output tight.

Implement `S2SModel`:
    start_session(language, sample_rate_hz, channels) -> state
    push(state, pcm: np.ndarray, is_final: bool) -> (out_pcm: np.ndarray, done: bool)

`push` is called once per input frame, then (during drain) possibly several
more times with an empty array and is_final=True until you return done=True.

Models:
    echo    -> EchoDelayModel (debug; proves protocol, scores ~0)
    cascade -> CascadedStreamingS2S (faster-whisper -> NLLB -> TTS, simultaneous)
"""
from __future__ import annotations

import os
import subprocess
import tempfile
import wave
from dataclasses import dataclass, field
from typing import Any, List, Tuple

import numpy as np

TARGET_SAMPLE_RATE_HZ = 24_000
ASR_SAMPLE_RATE_HZ = 16_000


# --------------------------------------------------------------------------- #
# audio helpers
# --------------------------------------------------------------------------- #
def resample_mono(x: np.ndarray, src_hz: int, dst_hz: int) -> np.ndarray:
    """Resample a mono signal from `src_hz` to `dst_hz` by linear interpolation.

    The input is flattened and converted to float32. Empty input, or equal
    rates, return the (flattened float32) input unchanged. No anti-alias
    filtering is applied.
    """
    signal = np.asarray(x, dtype=np.float32).reshape(-1)
    if signal.size == 0 or src_hz == dst_hz:
        return signal
    n_out = max(1, int(round(signal.size * dst_hz / src_hz)))
    # Source samples sit at integer positions 0..size-1; the output samples
    # are spread evenly over that same span.
    src_pos = np.arange(signal.size, dtype=np.float64)
    dst_pos = np.linspace(0.0, signal.size - 1, num=n_out)
    return np.interp(dst_pos, src_pos, signal).astype(np.float32)


def read_wav_float32_mono(path: str) -> Tuple[np.ndarray, int]:
    """Read a PCM WAV file and return `(samples, sample_rate)`.

    Samples are float32, scaled to roughly [-1, 1), and multi-channel audio is
    averaged down to mono. Supports 8-bit unsigned and 16/32-bit signed
    little-endian PCM.

    Raises:
        ValueError: if the file uses an unsupported sample width.
    """
    with wave.open(path, "rb") as w:
        sr = w.getframerate()
        ch = w.getnchannels()
        sw = w.getsampwidth()
        raw = w.readframes(w.getnframes())

    if sw == 2:
        a = np.frombuffer(raw, dtype="<i2").astype(np.float32) / 32768.0
    elif sw == 4:
        a = np.frombuffer(raw, dtype="<i4").astype(np.float32) / 2147483648.0
    elif sw == 1:
        # 8-bit WAV PCM is unsigned with a 128 offset.
        a = (np.frombuffer(raw, dtype=np.uint8).astype(np.float32) - 128.0) / 128.0
    else:
        raise ValueError(f"unsupported WAV sample width: {sw} bytes")

    if ch > 1:
        # Drop a trailing partial frame (if any) so the reshape cannot fail.
        usable = a.size - (a.size % ch)
        a = a[:usable].reshape(-1, ch).mean(axis=1)
    return a, sr


def _common_prefix_words(a: List[str], b: List[str]) -> List[str]:
    """Return the longest common leading run of words shared by `a` and `b`."""
    out: List[str] = []
    for wa, wb in zip(a, b):
        if wa != wb:
            break
        out.append(wa)
    return out


# --------------------------------------------------------------------------- #
# S2S interface
# --------------------------------------------------------------------------- #
class S2SModel:
    """Abstract streaming speech-to-speech model (see module docstring)."""

    def start_session(self, *, language: str | None, sample_rate_hz: int, channels: int) -> Any:
        """Create and return the per-utterance state object."""
        raise NotImplementedError

    def push(self, state: Any, pcm: np.ndarray, is_final: bool) -> Tuple[np.ndarray, bool]:
        """Consume one input frame; return `(output_pcm, done)`."""
        raise NotImplementedError


# --------------------------------------------------------------------------- #
# debug echo model
# --------------------------------------------------------------------------- #
@dataclass
class _EchoState:
    # Input frames received so far, in arrival order.
    chunks: list = field(default_factory=list)


class EchoDelayModel(S2SModel):
    """Debug model: buffers all input and plays it back once `is_final` arrives."""

    def start_session(self, *, language, sample_rate_hz, channels):
        return _EchoState()

    def push(self, state: _EchoState, pcm: np.ndarray, is_final: bool):
        if pcm.size:
            state.chunks.append(pcm.astype(np.float32, copy=False))
        if not is_final:
            return np.zeros(0, np.float32), False
        # With nothing buffered, a single zero sample is returned instead of
        # an empty array.
        out = np.concatenate(state.chunks) if state.chunks else np.zeros(1, np.float32)
        return out, True


# --------------------------------------------------------------------------- #
# TTS backends (English text -> 24 kHz float32 mono)
# --------------------------------------------------------------------------- #
class TTSBackend:
    """Interface for text-to-speech backends."""

    def synth(self, text: str) -> np.ndarray:
        """Synthesize `text` and return float32 mono PCM at TARGET_SAMPLE_RATE_HZ."""
        raise NotImplementedError


class EspeakTTS(TTSBackend):
    """Ultra-low-latency, no downloads. Robotic but intelligible -> usually
    clears the 0.65 accuracy gate. Great for getting latency to overshoot<=0
    first; swap to PiperTTS once the policy is tuned."""

    def __init__(self, wpm: int = 175, voice: str = "en"):
        self.wpm = int(wpm)
        self.voice = voice

    def synth(self, text: str) -> np.ndarray:
        """Run `espeak-ng` on `text` and return the audio resampled to 24 kHz.

        Blank text yields an empty array without spawning a process.

        Raises:
            subprocess.CalledProcessError: if `espeak-ng` exits non-zero.
            FileNotFoundError: if the `espeak-ng` binary is not installed.
        """
        if not text.strip():
            return np.zeros(0, np.float32)
        # Write into a private temp directory rather than an open
        # NamedTemporaryFile: a second process cannot reliably reopen a
        # still-open temp file by name on every platform.
        with tempfile.TemporaryDirectory() as tmp_dir:
            wav_path = os.path.join(tmp_dir, "tts.wav")
            # NOTE(review): `text` is passed as a positional argument; text
            # beginning with "-" may be parsed as an option by espeak-ng.
            subprocess.run(
                ["espeak-ng", "-v", self.voice, "-s", str(self.wpm), "-w", wav_path, text],
                check=True,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            pcm, sr = read_wav_float32_mono(wav_path)
        return resample_mono(pcm, sr, TARGET_SAMPLE_RATE_HZ)


class PiperTTS(TTSBackend):
    """Higher quality, still fast. Needs a voice model:

        pip install piper-tts
        download e.g. en_US-lessac-medium.onnx (+ .json) from rhasspy/piper-voices
        export BB_PIPER_VOICE=/path/en_US-lessac-medium.onnx
    """

    def __init__(self, voice_path: str | None = None):
        self.voice_path = voice_path or os.getenv("BB_PIPER_VOICE", "")
        self._voice = None  # loaded lazily on first synth()

    def _ensure(self):
        """Load the Piper voice on first use and return it.

        Raises:
            RuntimeError: if no voice path was given and BB_PIPER_VOICE is unset.
        """
        if self._voice is None:
            if not self.voice_path:
                raise RuntimeError(
                    "PiperTTS needs a voice model: pass voice_path or set BB_PIPER_VOICE"
                )
            from piper import PiperVoice  # type: ignore

            self._voice = PiperVoice.load(self.voice_path)
        return self._voice

    def synth(self, text: str) -> np.ndarray:
        """Synthesize `text` with Piper and return audio resampled to 24 kHz."""
        if not text.strip():
            return np.zeros(0, np.float32)
        voice = self._ensure()
        sr = int(getattr(voice.config, "sample_rate", 22050))
        pcm = bytearray()
        for chunk in voice.synthesize_stream_raw(text):  # int16 LE bytes
            pcm.extend(chunk)
        a = np.frombuffer(bytes(pcm), dtype="<i2").astype(np.float32) / 32768.0
        return resample_mono(a, sr, TARGET_SAMPLE_RATE_HZ)


def _load_tts() -> TTSBackend:
    """Build the TTS backend selected by BB_TTS ("espeak" by default, or "piper")."""
    name = os.getenv("BB_TTS", "espeak").lower()
    if name == "piper":
        return PiperTTS()
    return EspeakTTS(wpm=int(os.getenv("BB_TTS_WPM", "175")))


# --------------------------------------------------------------------------- #
# ASR (faster-whisper) and MT (NLLB) wrappers
# --------------------------------------------------------------------------- #

# Whisper ISO-639-1 -> NLLB FLORES-200 code map (covers the common languages).
NLLB_CODE = {
    "af": "afr_Latn", "am": "amh_Ethi", "ar": "arb_Arab", "az": "azj_Latn",
    "be": "bel_Cyrl", "bg": "bul_Cyrl", "bn": "ben_Beng", "bs": "bos_Latn",
    "ca": "cat_Latn", "cs": "ces_Latn", "cy": "cym_Latn", "da": "dan_Latn",
    "de": "deu_Latn", "el": "ell_Grek", "en": "eng_Latn", "es": "spa_Latn",
    "et": "est_Latn", "eu": "eus_Latn", "fa": "pes_Arab", "fi": "fin_Latn",
    "fr": "fra_Latn", "gl": "glg_Latn", "gu": "guj_Gujr", "he": "heb_Hebr",
    "hi": "hin_Deva", "hr": "hrv_Latn", "hu": "hun_Latn", "hy": "hye_Armn",
    "id": "ind_Latn", "is": "isl_Latn", "it": "ita_Latn", "ja": "jpn_Jpan",
    "ka": "kat_Geor", "kk": "kaz_Cyrl", "km": "khm_Khmr", "kn": "kan_Knda",
    "ko": "kor_Hang", "lt": "lit_Latn", "lv": "lvs_Latn", "mk": "mkd_Cyrl",
    "ml": "mal_Mlym", "mn": "khk_Cyrl", "mr": "mar_Deva", "ms": "zsm_Latn",
    "my": "mya_Mymr", "ne": "npi_Deva", "nl": "nld_Latn", "no": "nob_Latn",
    "pa": "pan_Guru", "pl": "pol_Latn", "ps": "pbt_Arab", "pt": "por_Latn",
    "ro": "ron_Latn", "ru": "rus_Cyrl", "si": "sin_Sinh", "sk": "slk_Latn",
    "sl": "slv_Latn", "sq": "als_Latn", "sr": "srp_Cyrl", "sv": "swe_Latn",
    "sw": "swh_Latn", "ta": "tam_Taml", "te": "tel_Telu", "th": "tha_Thai",
    "tl": "tgl_Latn", "tr": "tur_Latn", "uk": "ukr_Cyrl", "ur": "urd_Arab",
    "uz": "uzn_Latn", "vi": "vie_Latn", "zh": "zho_Hans",
}


class WhisperASR:
    """faster-whisper wrapper. Model size comes from BB_ASR_MODEL (default "small")."""

    def __init__(self):
        import torch
        from faster_whisper import WhisperModel

        device = "cuda" if torch.cuda.is_available() else "cpu"
        compute = "float16" if device == "cuda" else "int8"
        size = os.getenv("BB_ASR_MODEL", "small")
        self.model = WhisperModel(size, device=device, compute_type=compute)

    def transcribe(self, pcm16k: np.ndarray, language: str | None) -> str:
        """Transcribe 16 kHz mono audio; return "" for clips shorter than 0.25 s."""
        if pcm16k.size < ASR_SAMPLE_RATE_HZ // 4:  # <0.25s: not enough yet
            return ""
        segments, _ = self.model.transcribe(
            pcm16k,
            language=language or None,
            beam_size=int(os.getenv("BB_ASR_BEAMS", "1")),
            vad_filter=False,
            without_timestamps=True,
            condition_on_previous_text=False,
        )
        return "".join(seg.text for seg in segments).strip()


class NllbMT:
    """Multilingual (one model, ~2.4 GB) — best on the GPU. BB_MT_BACKEND=nllb."""

    def __init__(self):
        import torch
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

        name = os.getenv("BB_MT_MODEL", "facebook/nllb-200-distilled-600M")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tok = AutoTokenizer.from_pretrained(name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(name).to(self.device)
        self._eng_id = self.tok.convert_tokens_to_ids("eng_Latn")

    def translate(self, text: str, src_lang: str | None) -> str:
        """Translate `text` from `src_lang` into English; blank text gives ""."""
        import torch

        if not text.strip():
            return ""
        # Reduce region-tagged codes ("pt-BR", "zh_CN") to their primary
        # subtag so they hit the map instead of the French fallback.
        primary = (src_lang or "").lower().replace("_", "-").split("-", 1)[0]
        self.tok.src_lang = NLLB_CODE.get(primary, "fra_Latn")
        inputs = self.tok(
            text, return_tensors="pt", truncation=True, max_length=256
        ).to(self.device)
        with torch.inference_mode():
            gen = self.model.generate(
                **inputs,
                forced_bos_token_id=self._eng_id,
                num_beams=int(os.getenv("BB_MT_BEAMS", "1")),
                max_new_tokens=200,
            )
        return self.tok.batch_decode(gen, skip_special_tokens=True)[0].strip()


class _IdentityMT:
    """Passthrough — returns source text unchanged (for English round-trip tests)."""

    def translate(self, text: str, src_lang: str | None = None) -> str:
        return text.strip()


class MarianMT:
    """Per-pair, light (~300 MB), CPU-friendly. BB_MT_BACKEND=marian.

    Loads Helsinki-NLP/opus-mt-<src>-en for the session's source language
    (or the model named by BB_MARIAN_MODEL when that is set).
    """

    def __init__(self, src_lang: str | None):
        import torch
        from transformers import MarianMTModel, MarianTokenizer

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        src = (src_lang or "fr").lower()
        name = os.getenv("BB_MARIAN_MODEL", f"Helsinki-NLP/opus-mt-{src}-en")
        self.tok = MarianTokenizer.from_pretrained(name)
        self.model = MarianMTModel.from_pretrained(name).to(self.device)

    def translate(self, text: str, src_lang: str | None = None) -> str:
        """Translate `text` into English; `src_lang` is unused (fixed at load time)."""
        import torch

        if not text.strip():
            return ""
        inputs = self.tok(
            [text], return_tensors="pt", padding=True, truncation=True, max_length=256
        ).to(self.device)
        with torch.inference_mode():
            gen = self.model.generate(**inputs, num_beams=1, max_new_tokens=128)
        return self.tok.batch_decode(gen, skip_special_tokens=True)[0].strip()


# --------------------------------------------------------------------------- #
# cascaded simultaneous translator
# --------------------------------------------------------------------------- #
@dataclass
class _CascadeState:
    language: str | None
    buf: List[np.ndarray] = field(default_factory=list)
    n_samples: int = 0
    last_asr_samples: int = 0
    prev_src_hyp: str = ""  # for ASR local agreement
    committed_src: str = ""  # stabilized source text
    spoken_en: str = ""  # English already synthesized/emitted
    finished: bool = False
    # Sample rate of the incoming PCM frames, as declared at session start.
    sample_rate_hz: int = TARGET_SAMPLE_RATE_HZ


class CascadedStreamingS2S(S2SModel):
    """
    Simultaneous policy = wait-k + double local-agreement:

      1. read WAIT_K_SEC of audio before the first emit;
      2. every STEP_SEC of new audio, re-transcribe the buffer and commit the
         source prefix that is STABLE across two consecutive ASR runs;
      3. translate the committed source, hold back MARGIN_WORDS of English
         (which may still change), and synthesize only the newly-stable tail;
      4. on is_final, commit + speak everything remaining.

    Tune WAIT_K_SEC / STEP_SEC / MARGIN_WORDS against ../selfscore to push
    overshoot<=0 while keeping accuracy>=0.65.
    """

    def __init__(self):
        self._asr: WhisperASR | None = None
        self._tts: TTSBackend | None = None
        self._mt_cache: dict = {}
        self._mt_backend = os.getenv("BB_MT_BACKEND", "nllb").lower()
        self.wait_k_sec = float(os.getenv("BB_WAIT_K_SEC", "1.0"))
        self.step_sec = float(os.getenv("BB_STEP_SEC", "0.48"))
        self.margin_words = int(os.getenv("BB_MARGIN_WORDS", "2"))
        # Only synthesize once >= this many new words are ready (fewer, more fluent
        # chunks -> better Whisper intelligibility). 0 = emit every step.
        self.min_emit_words = int(os.getenv("BB_EMIT_MIN_WORDS", "0"))

    def _ensure_loaded(self):
        """Lazily construct the ASR and TTS backends on first use."""
        if self._asr is None:
            self._asr = WhisperASR()
            self._tts = _load_tts()

    def _get_mt(self, src_lang: str | None):
        """Return the (cached) MT backend for `src_lang`."""
        # identity: passthrough (English round-trip test); marian: one model per
        # source language; nllb: shared multilingual model.
        backend = self._mt_backend
        if backend == "identity":
            key = "identity"
        elif backend == "marian":
            key = f"marian:{(src_lang or 'fr').lower()}"
        else:
            key = "nllb"

        if key not in self._mt_cache:
            if backend == "identity":
                self._mt_cache[key] = _IdentityMT()
            elif backend == "marian":
                self._mt_cache[key] = MarianMT(src_lang)
            else:
                self._mt_cache[key] = NllbMT()
        return self._mt_cache[key]

    def start_session(self, *, language, sample_rate_hz, channels):
        """Load models if needed and return a fresh per-utterance state."""
        self._ensure_loaded()
        # Remember the declared input rate so timing thresholds and ASR
        # resampling follow the actual stream; fall back to 24 kHz if unset.
        rate = int(sample_rate_hz or TARGET_SAMPLE_RATE_HZ)
        return _CascadeState(language=language, sample_rate_hz=rate)

    def push(self, state: _CascadeState, pcm: np.ndarray, is_final: bool):
        """Consume one input frame and return `(english_pcm, done)`.

        Non-final calls return an empty array until both the wait-k threshold
        and the step interval have elapsed. The final call flushes everything
        that is still unspoken and marks the session finished.
        """
        silence = np.zeros(0, np.float32)
        if state.finished:
            return silence, True

        if pcm.size:
            state.buf.append(pcm.astype(np.float32, copy=False))
            state.n_samples += pcm.size

        rate = state.sample_rate_hz
        wait_k = int(self.wait_k_sec * rate)
        step = int(self.step_sec * rate)
        ready = state.n_samples >= wait_k
        due = (state.n_samples - state.last_asr_samples) >= step
        if not is_final and not (ready and due):
            return silence, False
        state.last_asr_samples = state.n_samples

        audio_in = np.concatenate(state.buf) if state.buf else silence
        if state.buf:
            # Collapse the chunk list so the next step does not re-walk every
            # small frame again.
            state.buf = [audio_in]
        audio16 = resample_mono(audio_in, rate, ASR_SAMPLE_RATE_HZ)
        hyp = self._asr.transcribe(audio16, state.language)

        # ASR local agreement -> stable source prefix
        if is_final:
            stable_src = hyp or state.committed_src
        else:
            stable_words = _common_prefix_words(state.prev_src_hyp.split(), hyp.split())
            stable_src = " ".join(stable_words)
            state.prev_src_hyp = hyp

        if not stable_src:
            # Nothing to say. On the final call this ends the session, so
            # record that for any later drain calls.
            if is_final:
                state.finished = True
            return silence, bool(is_final)
        if stable_src == state.committed_src and not is_final:
            return silence, False

        state.committed_src = stable_src
        en_full = self._get_mt(state.language).translate(state.committed_src, state.language)
        en_words = en_full.split()
        if is_final:
            target_words = en_words
        else:
            target_words = en_words[: max(0, len(en_words) - self.margin_words)]

        # Re-translation can revise words that were already spoken; we never
        # un-speak, so continue from the same word offset and accept minor
        # drift between what was spoken and the new translation.
        n_spoken = len(state.spoken_en.split())
        new_words = target_words[n_spoken:]

        if not new_words or (len(new_words) < self.min_emit_words and not is_final):
            if is_final:
                state.finished = True
                # A single zero sample (not an empty array) is returned here.
                return np.zeros(1, np.float32), True
            return silence, False

        out = self._tts.synth(" ".join(new_words))
        state.spoken_en = " ".join(target_words)
        if is_final:
            state.finished = True
        return out, is_final


def load_model(name: str = "echo") -> S2SModel:
    """Instantiate the model registered under `name` (case-insensitive).

    Raises:
        ValueError: if `name` is not one of echo / cascade / streamspeech / gist.
    """
    name = (name or "echo").lower()
    if name == "echo":
        return EchoDelayModel()
    if name == "cascade":
        return CascadedStreamingS2S()
    if name == "streamspeech":
        from streamspeech_model import load_streamspeech  # lazy: GPU-box only

        return load_streamspeech()
    if name == "gist":
        from gist_model import load_gist  # lazy: GPU-box only (whisper-large+Qwen-7B+Kokoro)

        return load_gist()
    raise ValueError(f"unknown model '{name}'")