| """ |
| S2S model interface + a runnable cascaded streaming translator for SN59. |
| |
| The server feeds the source utterance as 80 ms PCM frames (float32, 24 kHz, |
| mono, 1920 samples/frame) and expects target-language (English) PCM frames |
| back, ASAP. Score = latency once you clear the accuracy (>=0.65) and |
| speech-rate gates, so EMIT EARLY and keep output tight. |
| |
| Implement `S2SModel`: |
| start_session(language, sample_rate_hz, channels) -> state |
| push(state, pcm: np.ndarray, is_final: bool) -> (out_pcm: np.ndarray, done: bool) |
| |
| `push` is called once per input frame, then (during drain) possibly several |
| more times with an empty array and is_final=True until you return done=True. |
| |
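Models:
  echo         -> EchoDelayModel (debug; proves protocol, scores ~0)
  cascade      -> CascadedStreamingS2S (faster-whisper -> NLLB/Marian -> TTS, simultaneous)
  streamspeech -> streamspeech_model.load_streamspeech() (external module)
  gist         -> gist_model.load_gist() (external module)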
| """ |
| from __future__ import annotations |
|
|
| import os |
| import subprocess |
| import tempfile |
| import wave |
| from dataclasses import dataclass, field |
| from typing import Any, List, Tuple |
|
|
| import numpy as np |
|
|
# Rate (Hz) of the PCM exchanged with the server: incoming frames are buffered
# at this rate and every TTS backend resamples its output to it.
TARGET_SAMPLE_RATE_HZ: int = 24000
# Rate (Hz) the buffered input is resampled to before it is handed to ASR.
ASR_SAMPLE_RATE_HZ: int = 16000
|
|
|
|
| |
| |
| |
def resample_mono(x: np.ndarray, src_hz: int, dst_hz: int) -> np.ndarray:
    """Resample a mono signal from ``src_hz`` to ``dst_hz`` by linear interpolation.

    The input is flattened and cast to float32 first. Empty input, or equal
    rates, short-circuit and return that flattened float32 array. Otherwise
    the output has ``round(len(x) * dst_hz / src_hz)`` samples (never fewer
    than one), spaced evenly from the first to the last input sample. No
    anti-aliasing filter is applied.
    """
    samples = np.asarray(x, dtype=np.float32).reshape(-1)
    n_in = samples.size
    if n_in == 0 or src_hz == dst_hz:
        return samples

    n_out = int(round(n_in * dst_hz / src_hz))
    if n_out < 1:
        n_out = 1

    last = n_in - 1
    in_positions = np.linspace(0.0, last, num=n_in)
    out_positions = np.linspace(0.0, last, num=n_out)
    resampled = np.interp(out_positions, in_positions, samples)
    return resampled.astype(np.float32)
|
|
|
|
def read_wav_float32_mono(path: str) -> Tuple[np.ndarray, int]:
    """Read a PCM WAV file as mono float32 samples.

    Supports 8-bit unsigned and 16/24/32-bit signed little-endian PCM. Signed
    widths are scaled by 2**(bits-1); 8-bit by (x - 128) / 128. Multi-channel
    files are down-mixed by averaging the channels of each frame.

    Args:
        path: filesystem path of the WAV file.

    Returns:
        ``(samples, sample_rate_hz)`` where ``samples`` is a 1-D float32 array.

    Raises:
        ValueError: if the sample width is not 1, 2, 3 or 4 bytes.
    """
    with wave.open(path, "rb") as w:
        sr = w.getframerate()
        ch = w.getnchannels()
        sw = w.getsampwidth()
        raw = w.readframes(w.getnframes())
    if sw == 2:
        a = np.frombuffer(raw, dtype="<i2").astype(np.float32) / 32768.0
    elif sw == 3:
        # NumPy has no 24-bit dtype: copy each 3-byte little-endian sample into
        # the top three bytes of an int32 (low byte zero) so the sign bit lands
        # in the int32 sign position, then scale exactly like the 32-bit case.
        triplets = np.frombuffer(raw, dtype=np.uint8).reshape(-1, 3)
        padded = np.zeros((triplets.shape[0], 4), dtype=np.uint8)
        padded[:, 1:] = triplets
        a = padded.view("<i4").reshape(-1).astype(np.float32) / 2147483648.0
    elif sw == 4:
        a = np.frombuffer(raw, dtype="<i4").astype(np.float32) / 2147483648.0
    elif sw == 1:
        a = (np.frombuffer(raw, dtype="u1").astype(np.float32) - 128.0) / 128.0
    else:
        raise ValueError(f"unsupported sample width {sw}")
    if ch > 1:
        # Interleaved frames -> (n_frames, channels) -> channel mean.
        a = a.reshape(-1, ch).mean(axis=1)
    return a, sr
|
|
|
|
| def _common_prefix_words(a: List[str], b: List[str]) -> List[str]: |
| out: List[str] = [] |
| for wa, wb in zip(a, b): |
| if wa == wb: |
| out.append(wa) |
| else: |
| break |
| return out |
|
|
|
|
| |
| |
| |
class S2SModel:
    """Streaming speech-to-speech interface driven by the server.

    ``start_session`` builds per-utterance state; ``push`` consumes it one
    input frame at a time. Both are left unimplemented here and raise
    ``NotImplementedError``; concrete models override them.
    """

    def start_session(self, *, language: str | None, sample_rate_hz: int, channels: int) -> Any:
        """Create the opaque state object that subsequent ``push`` calls receive."""
        raise NotImplementedError

    def push(self, state: Any, pcm: np.ndarray, is_final: bool) -> Tuple[np.ndarray, bool]:
        """Feed one input frame; return ``(output_pcm, done)``."""
        raise NotImplementedError
|
|
|
|
| |
| |
| |
| @dataclass |
| class _EchoState: |
| chunks: list = field(default_factory=list) |
|
|
|
|
class EchoDelayModel(S2SModel):
    """Debug model: buffers the whole utterance and echoes it back at the end.

    Non-final pushes return an empty array. The first final push returns the
    concatenated input (or one zero sample if nothing was received) together
    with ``done=True``.
    """

    def start_session(self, *, language, sample_rate_hz, channels):
        # Session parameters are irrelevant for an echo; start with no frames.
        return _EchoState()

    def push(self, state: _EchoState, pcm: np.ndarray, is_final: bool):
        if pcm.size:
            state.chunks.append(pcm.astype(np.float32, copy=False))
        if not is_final:
            return np.zeros(0, np.float32), False
        if not state.chunks:
            return np.zeros(1, np.float32), True
        return np.concatenate(state.chunks), True
|
|
|
|
| |
| |
| |
class TTSBackend:
    """Text-to-speech interface used by the cascade to voice English text."""

    def synth(self, text: str) -> np.ndarray:
        """Return float32 mono PCM for ``text``; subclasses must implement this."""
        raise NotImplementedError
|
|
|
|
class EspeakTTS(TTSBackend):
    """Ultra-low-latency, no downloads. Robotic but intelligible -> usually
    clears the 0.65 accuracy gate. Great for getting latency to overshoot<=0
    first; swap to PiperTTS once the policy is tuned."""

    def __init__(self, wpm: int = 175, voice: str = "en"):
        # Speaking rate (words per minute) and espeak-ng voice name.
        self.wpm = int(wpm)
        self.voice = voice

    def synth(self, text: str) -> np.ndarray:
        """Synthesize ``text`` with espeak-ng.

        Returns float32 mono PCM resampled to TARGET_SAMPLE_RATE_HZ; blank
        text returns an empty array without running espeak-ng.

        The text is written to espeak-ng's stdin instead of being passed as a
        command-line argument, so input that starts with ``-`` (a dialogue
        dash, "-5 degrees", ...) can never be mistaken for an option.

        Raises:
            subprocess.CalledProcessError: if espeak-ng exits non-zero.
        """
        if not text.strip():
            return np.zeros(0, np.float32)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as f:
            subprocess.run(
                ["espeak-ng", "--stdin", "-v", self.voice, "-s", str(self.wpm), "-w", f.name],
                input=text.encode("utf-8"),
                check=True,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            pcm, sr = read_wav_float32_mono(f.name)
        return resample_mono(pcm, sr, TARGET_SAMPLE_RATE_HZ)
|
|
|
|
class PiperTTS(TTSBackend):
    """Higher quality, still fast. Needs a voice model:
         pip install piper-tts
         download e.g. en_US-lessac-medium.onnx (+ .json) from rhasspy/piper-voices
         export BB_PIPER_VOICE=/path/en_US-lessac-medium.onnx

    The voice is loaded lazily, on the first ``synth`` call with non-blank text.
    """

    def __init__(self, voice_path: str | None = None):
        # An explicit path wins; otherwise fall back to BB_PIPER_VOICE.
        self.voice_path = voice_path or os.getenv("BB_PIPER_VOICE", "")
        self._voice = None

    def _ensure(self):
        """Load the Piper voice once and return it.

        Raises:
            FileNotFoundError: if no voice path was given and BB_PIPER_VOICE
                is unset or empty (reported here with an actionable message
                instead of as an opaque failure inside piper).
        """
        if self._voice is None:
            if not self.voice_path:
                raise FileNotFoundError(
                    "PiperTTS needs a voice model: pass voice_path or set "
                    "BB_PIPER_VOICE=/path/to/voice.onnx"
                )
            from piper import PiperVoice
            self._voice = PiperVoice.load(self.voice_path)
        return self._voice

    def synth(self, text: str) -> np.ndarray:
        """Synthesize ``text``; return float32 mono PCM at TARGET_SAMPLE_RATE_HZ.

        Blank text returns an empty array without loading the voice.
        """
        if not text.strip():
            return np.zeros(0, np.float32)
        voice = self._ensure()
        # Native voice rate; 22050 Hz is assumed when the config lacks one.
        sr = int(getattr(voice.config, "sample_rate", 22050))
        pcm = bytearray()
        for chunk in voice.synthesize_stream_raw(text):
            pcm.extend(chunk)
        # Raw stream is decoded as little-endian int16 samples.
        a = np.frombuffer(bytes(pcm), dtype="<i2").astype(np.float32) / 32768.0
        return resample_mono(a, sr, TARGET_SAMPLE_RATE_HZ)
|
|
|
|
def _load_tts() -> TTSBackend:
    """Build the TTS backend chosen by BB_TTS.

    "piper" (case-insensitive) selects PiperTTS; any other value, including
    the default "espeak", selects EspeakTTS with its speaking rate taken from
    BB_TTS_WPM (default 175).
    """
    backend = os.getenv("BB_TTS", "espeak").lower()
    if backend != "piper":
        wpm = int(os.getenv("BB_TTS_WPM", "175"))
        return EspeakTTS(wpm=wpm)
    return PiperTTS()
|
|
|
|
| |
| |
| |
| |
| NLLB_CODE = { |
| "af": "afr_Latn", "am": "amh_Ethi", "ar": "arb_Arab", "az": "azj_Latn", |
| "be": "bel_Cyrl", "bg": "bul_Cyrl", "bn": "ben_Beng", "bs": "bos_Latn", |
| "ca": "cat_Latn", "cs": "ces_Latn", "cy": "cym_Latn", "da": "dan_Latn", |
| "de": "deu_Latn", "el": "ell_Grek", "en": "eng_Latn", "es": "spa_Latn", |
| "et": "est_Latn", "eu": "eus_Latn", "fa": "pes_Arab", "fi": "fin_Latn", |
| "fr": "fra_Latn", "gl": "glg_Latn", "gu": "guj_Gujr", "he": "heb_Hebr", |
| "hi": "hin_Deva", "hr": "hrv_Latn", "hu": "hun_Latn", "hy": "hye_Armn", |
| "id": "ind_Latn", "is": "isl_Latn", "it": "ita_Latn", "ja": "jpn_Jpan", |
| "ka": "kat_Geor", "kk": "kaz_Cyrl", "km": "khm_Khmr", "kn": "kan_Knda", |
| "ko": "kor_Hang", "lt": "lit_Latn", "lv": "lvs_Latn", "mk": "mkd_Cyrl", |
| "ml": "mal_Mlym", "mn": "khk_Cyrl", "mr": "mar_Deva", "ms": "zsm_Latn", |
| "my": "mya_Mymr", "ne": "npi_Deva", "nl": "nld_Latn", "no": "nob_Latn", |
| "pa": "pan_Guru", "pl": "pol_Latn", "ps": "pbt_Arab", "pt": "por_Latn", |
| "ro": "ron_Latn", "ru": "rus_Cyrl", "si": "sin_Sinh", "sk": "slk_Latn", |
| "sl": "slv_Latn", "sq": "als_Latn", "sr": "srp_Cyrl", "sv": "swe_Latn", |
| "sw": "swh_Latn", "ta": "tam_Taml", "te": "tel_Telu", "th": "tha_Thai", |
| "tl": "tgl_Latn", "tr": "tur_Latn", "uk": "ukr_Cyrl", "ur": "urd_Arab", |
| "uz": "uzn_Latn", "vi": "vie_Latn", "zh": "zho_Hans", |
| } |
|
|
|
|
class WhisperASR:
    """faster-whisper transcriber the cascade re-runs on its audio buffer.

    Model size comes from BB_ASR_MODEL (default "small"). Uses float16 on
    CUDA when torch reports a GPU, otherwise int8 on CPU.
    """

    def __init__(self):
        import torch
        from faster_whisper import WhisperModel
        device = "cuda" if torch.cuda.is_available() else "cpu"
        compute = "float16" if device == "cuda" else "int8"
        size = os.getenv("BB_ASR_MODEL", "small")
        self.model = WhisperModel(size, device=device, compute_type=compute)

    def transcribe(self, pcm16k: np.ndarray, language: str | None) -> str:
        """Transcribe a 16 kHz mono buffer and return the joined, stripped text.

        Buffers shorter than ASR_SAMPLE_RATE_HZ // 4 samples (0.25 s at
        16 kHz) return "" without running the model. ``language`` is stripped
        and lower-cased, matching how the MT backends normalize the same code,
        so e.g. "FR" works for both ASR and MT; an empty or missing code is
        passed to the model as None. Beam width comes from BB_ASR_BEAMS
        (default 1).
        """
        if pcm16k.size < ASR_SAMPLE_RATE_HZ // 4:
            return ""
        lang = (language or "").strip().lower() or None
        segments, _ = self.model.transcribe(
            pcm16k,
            language=lang,
            beam_size=int(os.getenv("BB_ASR_BEAMS", "1")),
            vad_filter=False,
            without_timestamps=True,
            # Each call re-transcribes the whole buffer from scratch.
            condition_on_previous_text=False,
        )
        return "".join(seg.text for seg in segments).strip()
|
|
|
|
class NllbMT:
    """Multilingual (one model, ~2.4 GB) — best on the GPU. BB_MT_BACKEND=nllb.

    Translates from the session's source language (looked up in NLLB_CODE,
    falling back to "fra_Latn") into English. The checkpoint comes from
    BB_MT_MODEL and the beam width from BB_MT_BEAMS (default 1).
    """

    def __init__(self):
        import torch
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
        checkpoint = os.getenv("BB_MT_MODEL", "facebook/nllb-200-distilled-600M")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tok = AutoTokenizer.from_pretrained(checkpoint)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(self.device)
        # Forced as the first generated token so decoding targets English.
        self._eng_id = self.tok.convert_tokens_to_ids("eng_Latn")

    def translate(self, text: str, src_lang: str | None) -> str:
        """Translate ``text`` to English; blank input returns ""."""
        import torch
        if not text.strip():
            return ""
        lang_key = (src_lang or "").lower()
        self.tok.src_lang = NLLB_CODE.get(lang_key, "fra_Latn")
        encoded = self.tok(text, return_tensors="pt", truncation=True, max_length=256)
        encoded = encoded.to(self.device)
        beams = int(os.getenv("BB_MT_BEAMS", "1"))
        with torch.inference_mode():
            generated = self.model.generate(
                **encoded,
                forced_bos_token_id=self._eng_id,
                num_beams=beams,
                max_new_tokens=200,
            )
        decoded = self.tok.batch_decode(generated, skip_special_tokens=True)
        return decoded[0].strip()
|
|
|
|
| class _IdentityMT: |
| """Passthrough — returns source text unchanged (for English round-trip tests).""" |
|
|
| def translate(self, text: str, src_lang: str | None = None) -> str: |
| return text.strip() |
|
|
|
|
class MarianMT:
    """Per-pair, light (~300 MB), CPU-friendly. BB_MT_BACKEND=marian.
    Loads Helsinki-NLP/opus-mt-<src>-en for the session's source language.

    BB_MARIAN_MODEL overrides the checkpoint name. The beam width comes from
    BB_MT_BEAMS (default 1), the same variable NllbMT reads, so one knob
    controls decoding for every MT backend.
    """

    def __init__(self, src_lang: str | None):
        import torch
        from transformers import MarianMTModel, MarianTokenizer
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Missing source language defaults to French, as in the original pairing.
        src = (src_lang or "fr").lower()
        name = os.getenv("BB_MARIAN_MODEL", f"Helsinki-NLP/opus-mt-{src}-en")
        self.tok = MarianTokenizer.from_pretrained(name)
        self.model = MarianMTModel.from_pretrained(name).to(self.device)

    def translate(self, text: str, src_lang: str | None = None) -> str:
        """Translate ``text`` to English; blank input returns "".

        ``src_lang`` is ignored: the language pair is fixed by the checkpoint
        loaded in ``__init__``.
        """
        if not text.strip():
            return ""
        import torch
        inputs = self.tok(
            [text], return_tensors="pt", padding=True, truncation=True, max_length=256
        ).to(self.device)
        with torch.inference_mode():
            gen = self.model.generate(
                **inputs,
                num_beams=int(os.getenv("BB_MT_BEAMS", "1")),
                max_new_tokens=128,
            )
        return self.tok.batch_decode(gen, skip_special_tokens=True)[0].strip()
|
|
|
|
| |
| |
| |
| @dataclass |
| class _CascadeState: |
| language: str | None |
| buf: List[np.ndarray] = field(default_factory=list) |
| n_samples: int = 0 |
| last_asr_samples: int = 0 |
| prev_src_hyp: str = "" |
| committed_src: str = "" |
| spoken_en: str = "" |
| finished: bool = False |
|
|
|
|
class CascadedStreamingS2S(S2SModel):
    """
    Simultaneous policy = wait-k + double local-agreement:
      1. read WAIT_K_SEC of audio before the first emit;
      2. every STEP_SEC of new audio, re-transcribe the buffer and commit the
         source prefix that is STABLE across two consecutive ASR runs;
      3. translate the committed source, hold back MARGIN_WORDS of English
         (which may still change), and synthesize only the newly-stable tail;
      4. on is_final, commit + speak everything remaining.

    Tune WAIT_K_SEC / STEP_SEC / MARGIN_WORDS against ../selfscore to push
    overshoot<=0 while keeping accuracy>=0.65.

    Knobs are read from the environment at construction: BB_WAIT_K_SEC,
    BB_STEP_SEC, BB_MARGIN_WORDS, BB_EMIT_MIN_WORDS and BB_MT_BACKEND
    ("marian", "identity", anything else selects NLLB). Input PCM is treated
    as mono at TARGET_SAMPLE_RATE_HZ; start_session's sample_rate_hz and
    channels arguments are not used.
    """

    def __init__(self):
        self._asr: WhisperASR | None = None
        self._tts: TTSBackend | None = None
        # MT backends keyed by backend name (plus source language for Marian).
        self._mt_cache: dict = {}
        self._mt_backend = os.getenv("BB_MT_BACKEND", "nllb").lower()
        self.wait_k_sec = float(os.getenv("BB_WAIT_K_SEC", "1.0"))
        self.step_sec = float(os.getenv("BB_STEP_SEC", "0.48"))
        self.margin_words = int(os.getenv("BB_MARGIN_WORDS", "2"))
        # Minimum number of new English words per non-final emission
        # (0 = emit as soon as any word is stable).
        self.min_emit_words = int(os.getenv("BB_EMIT_MIN_WORDS", "0"))

    def _ensure_loaded(self):
        """Load the ASR model and TTS backend once, on first use."""
        if self._asr is None:
            self._asr = WhisperASR()
            self._tts = _load_tts()

    def _get_mt(self, src_lang: str | None):
        """Return the cached MT backend for ``src_lang``, creating it if needed.

        identity and NLLB share one instance across languages; Marian loads a
        separate model per source language (default "fr").
        """
        if self._mt_backend == "identity":
            key = "identity"
        elif self._mt_backend == "marian":
            key = f"marian:{(src_lang or 'fr').lower()}"
        else:
            key = "nllb"
        if key not in self._mt_cache:
            if self._mt_backend == "identity":
                self._mt_cache[key] = _IdentityMT()
            elif self._mt_backend == "marian":
                self._mt_cache[key] = MarianMT(src_lang)
            else:
                self._mt_cache[key] = NllbMT()
        return self._mt_cache[key]

    def start_session(self, *, language, sample_rate_hz, channels):
        """Create per-session state after making sure every model is loaded.

        The MT backend is warmed here as well. Otherwise it would be built
        lazily by the first translate call inside ``push``, putting a full
        model load on the latency-critical path of the first emission.
        """
        self._ensure_loaded()
        self._get_mt(language)
        return _CascadeState(language=language)

    def push(self, state: _CascadeState, pcm: np.ndarray, is_final: bool):
        """Consume one input frame and return ``(pcm_out, done)``.

        Non-final calls return an empty array until the wait-k/step schedule
        triggers an ASR pass that yields newly stable English words. A final
        call flushes all remaining words and returns ``done=True``; when
        there is nothing left to speak it returns a single zero sample. Once
        finished, further calls return ``(empty, True)``.
        """
        if state.finished:
            return np.zeros(0, np.float32), True
        if pcm.size:
            state.buf.append(pcm.astype(np.float32, copy=False))
            state.n_samples += pcm.size

        # Before the final frame, run ASR only once WAIT_K of audio has been
        # read and at least STEP of new audio has arrived since the last run.
        wait_k = int(self.wait_k_sec * TARGET_SAMPLE_RATE_HZ)
        step = int(self.step_sec * TARGET_SAMPLE_RATE_HZ)
        ready = state.n_samples >= wait_k
        due = (state.n_samples - state.last_asr_samples) >= step
        if not is_final and (not ready or not due):
            return np.zeros(0, np.float32), False

        # Re-transcribe the whole buffer received so far.
        state.last_asr_samples = state.n_samples
        audio24 = np.concatenate(state.buf) if state.buf else np.zeros(0, np.float32)
        audio16 = resample_mono(audio24, TARGET_SAMPLE_RATE_HZ, ASR_SAMPLE_RATE_HZ)
        hyp = self._asr.transcribe(audio16, state.language)

        if is_final:
            # Flush: trust the last hypothesis, or keep the committed text if
            # ASR returned nothing this time.
            stable_src = hyp or state.committed_src
        else:
            # Local agreement: commit only words both recent hypotheses share.
            stable_words = _common_prefix_words(state.prev_src_hyp.split(), hyp.split())
            stable_src = " ".join(stable_words)
            state.prev_src_hyp = hyp

        if not stable_src:
            if is_final:
                # Nothing was ever recognized: end the session with the same
                # one-sample terminal frame the other final paths return.
                state.finished = True
                return np.zeros(1, np.float32), True
            return np.zeros(0, np.float32), False
        if stable_src == state.committed_src and not is_final:
            return np.zeros(0, np.float32), False
        state.committed_src = stable_src

        en_words = self._get_mt(state.language).translate(state.committed_src, state.language).split()
        if is_final:
            target_words = en_words
        else:
            # Hold back the last MARGIN_WORDS: they may still change.
            target_words = en_words[: max(0, len(en_words) - self.margin_words)]

        # Audio already played cannot be retracted, so the new tail is simply
        # whatever lies beyond the number of words spoken so far — even when
        # the retranslated prefix no longer matches what was said.
        n_spoken = len(state.spoken_en.split())
        new_words = target_words[n_spoken:]

        if not new_words or (len(new_words) < self.min_emit_words and not is_final):
            if is_final:
                state.finished = True
                return np.zeros(1, np.float32), True
            return np.zeros(0, np.float32), False

        out = self._tts.synth(" ".join(new_words))
        state.spoken_en = " ".join(target_words)
        if is_final:
            state.finished = True
        return out, is_final
|
|
|
|
def load_model(name: str = "echo") -> S2SModel:
    """Instantiate a model by name (case-insensitive; empty/None means "echo").

    "streamspeech" and "gist" import their modules only when selected.

    Raises:
        ValueError: for an unrecognized name (reported lower-cased).
    """
    key = (name or "echo").lower()

    def _streamspeech():
        from streamspeech_model import load_streamspeech
        return load_streamspeech()

    def _gist():
        from gist_model import load_gist
        return load_gist()

    factories = {
        "echo": lambda: EchoDelayModel(),
        "cascade": lambda: CascadedStreamingS2S(),
        "streamspeech": _streamspeech,
        "gist": _gist,
    }
    factory = factories.get(key)
    if factory is None:
        raise ValueError(f"unknown model '{key}'")
    return factory()
|
|