# webapp/model.py — uploaded by jenkins1122 with huggingface_hub
# (commit 347efb4, 17.3 kB).
"""
S2S model interface + a runnable cascaded streaming translator for SN59.
The server feeds the source utterance as 80 ms PCM frames (float32, 24 kHz,
mono, 1920 samples/frame) and expects target-language (English) PCM frames
back, ASAP. Score = latency once you clear the accuracy (>=0.65) and
speech-rate gates, so EMIT EARLY and keep output tight.
Implement `S2SModel`:
start_session(language, sample_rate_hz, channels) -> state
push(state, pcm: np.ndarray, is_final: bool) -> (out_pcm: np.ndarray, done: bool)
`push` is called once per input frame, then (during drain) possibly several
more times with an empty array and is_final=True until you return done=True.
Models:
echo -> EchoDelayModel (debug; proves protocol, scores ~0)
cascade -> CascadedStreamingS2S (faster-whisper -> NLLB -> TTS, simultaneous)
"""
from __future__ import annotations
import os
import subprocess
import tempfile
import wave
from dataclasses import dataclass, field
from typing import Any, List, Tuple
import numpy as np
# Sample rate (Hz) of the PCM the server feeds in, and of the audio the TTS
# backends below return.
TARGET_SAMPLE_RATE_HZ = 24_000
# Sample rate (Hz) that buffered audio is resampled to before WhisperASR.transcribe.
ASR_SAMPLE_RATE_HZ = 16_000
# --------------------------------------------------------------------------- #
# audio helpers
# --------------------------------------------------------------------------- #
def resample_mono(x: np.ndarray, src_hz: int, dst_hz: int) -> np.ndarray:
    """Resample a signal from ``src_hz`` to ``dst_hz`` by linear interpolation.

    The input is flattened to 1-D float32 first. Empty input, or equal rates,
    is returned right after that flatten/cast. The output length is
    ``round(len * dst_hz / src_hz)`` (at least 1). The first and last samples
    map onto the first and last output samples. There is no anti-aliasing
    filter, only ``np.interp``.
    """
    samples = np.asarray(x, dtype=np.float32).reshape(-1)
    count = samples.size
    if count == 0 or src_hz == dst_hz:
        return samples
    out_len = max(1, int(round(count * dst_hz / src_hz)))
    last = count - 1
    src_grid = np.linspace(0.0, last, num=count)
    dst_grid = np.linspace(0.0, last, num=out_len)
    resampled = np.interp(dst_grid, src_grid, samples)
    return resampled.astype(np.float32)
def read_wav_float32_mono(path: str) -> Tuple[np.ndarray, int]:
    """Read a PCM WAV file and return ``(samples, sample_rate_hz)``.

    Samples are float32 scaled by the integer full-scale value, so they fall
    in [-1, 1). Multi-channel audio is down-mixed by averaging the channels
    of each frame.

    Supported sample widths are 8-bit unsigned and 16/24/32-bit signed
    little-endian PCM.

    Raises:
        ValueError: for any other sample width.
    """
    with wave.open(path, "rb") as w:
        sr = w.getframerate()
        ch = w.getnchannels()
        sw = w.getsampwidth()
        raw = w.readframes(w.getnframes())
    if sw == 1:
        # 8-bit WAV data is unsigned with a 128 offset.
        a = (np.frombuffer(raw, dtype="u1").astype(np.float32) - 128.0) / 128.0
    elif sw == 2:
        a = np.frombuffer(raw, dtype="<i2").astype(np.float32) / 32768.0
    elif sw == 3:
        # NumPy has no 24-bit dtype, so assemble each little-endian 3-byte
        # sample into an int32 and sign-extend from bit 23.
        b = np.frombuffer(raw, dtype="u1").reshape(-1, 3).astype(np.int32)
        v = b[:, 0] | (b[:, 1] << 8) | (b[:, 2] << 16)
        v = np.where(v >= (1 << 23), v - (1 << 24), v)
        a = v.astype(np.float32) / 8388608.0
    elif sw == 4:
        a = np.frombuffer(raw, dtype="<i4").astype(np.float32) / 2147483648.0
    else:
        raise ValueError(f"unsupported sample width {sw}")
    if ch > 1:
        # Interleaved frames -> one row per frame, then average the channels.
        a = a.reshape(-1, ch).mean(axis=1)
    return a, sr
def _common_prefix_words(a: List[str], b: List[str]) -> List[str]:
out: List[str] = []
for wa, wb in zip(a, b):
if wa == wb:
out.append(wa)
else:
break
return out
# --------------------------------------------------------------------------- #
# S2S interface
# --------------------------------------------------------------------------- #
class S2SModel:
def start_session(self, *, language: str | None, sample_rate_hz: int, channels: int) -> Any:
raise NotImplementedError
def push(self, state: Any, pcm: np.ndarray, is_final: bool) -> Tuple[np.ndarray, bool]:
raise NotImplementedError
# --------------------------------------------------------------------------- #
# debug echo model
# --------------------------------------------------------------------------- #
@dataclass
class _EchoState:
chunks: list = field(default_factory=list)
class EchoDelayModel(S2SModel):
    """Debug model: buffers the whole input and returns it unchanged on the final push.

    Non-final pushes return an empty array and ``done=False``. The final push
    returns all buffered samples (or a single zero sample if nothing was
    received) with ``done=True``.
    """

    def start_session(self, *, language, sample_rate_hz, channels):
        return _EchoState()

    def push(self, state: _EchoState, pcm: np.ndarray, is_final: bool):
        if pcm.size:
            state.chunks.append(pcm.astype(np.float32, copy=False))
        if not is_final:
            return np.zeros(0, np.float32), False
        if state.chunks:
            return np.concatenate(state.chunks), True
        return np.zeros(1, np.float32), True
# --------------------------------------------------------------------------- #
# TTS backends (English text -> 24 kHz float32 mono)
# --------------------------------------------------------------------------- #
class TTSBackend:
    """Text-to-speech backend interface: English text in, PCM samples out."""

    def synth(self, text: str) -> np.ndarray:
        """Synthesize ``text``. Subclasses must override this method."""
        raise NotImplementedError
class EspeakTTS(TTSBackend):
    """Ultra-low-latency, no downloads. Robotic but intelligible -> usually
    clears the 0.65 accuracy gate. Great for getting latency to overshoot<=0
    first; swap to PiperTTS once the policy is tuned."""

    def __init__(self, wpm: int = 175, voice: str = "en"):
        self.wpm = int(wpm)  # speaking rate passed to espeak-ng's -s
        self.voice = voice  # voice name passed to espeak-ng's -v

    def synth(self, text: str) -> np.ndarray:
        """Synthesize ``text`` with the ``espeak-ng`` CLI.

        Returns float32 mono at TARGET_SAMPLE_RATE_HZ. Blank text returns an
        empty array without running espeak-ng.

        Raises:
            subprocess.CalledProcessError: if espeak-ng exits non-zero.
            FileNotFoundError: if the espeak-ng executable is not found.
        """
        if not text.strip():
            return np.zeros(0, np.float32)
        # A fresh directory (instead of an open NamedTemporaryFile) lets
        # espeak-ng create the WAV itself. This also works on platforms that
        # forbid a second open of a file Python still holds.
        with tempfile.TemporaryDirectory() as tmp_dir:
            wav_path = os.path.join(tmp_dir, "tts.wav")
            subprocess.run(
                # "--" ends option parsing, so text beginning with "-"
                # (e.g. "-5 degrees", a dialogue dash) is spoken, not parsed
                # as an unknown option.
                ["espeak-ng", "-v", self.voice, "-s", str(self.wpm), "-w", wav_path, "--", text],
                check=True,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            pcm, sr = read_wav_float32_mono(wav_path)
        return resample_mono(pcm, sr, TARGET_SAMPLE_RATE_HZ)
class PiperTTS(TTSBackend):
    """Higher quality, still fast. Needs a voice model:
    pip install piper-tts
    download e.g. en_US-lessac-medium.onnx (+ .json) from rhasspy/piper-voices
    export BB_PIPER_VOICE=/path/en_US-lessac-medium.onnx
    """

    def __init__(self, voice_path: str | None = None):
        # An explicit argument wins; otherwise use $BB_PIPER_VOICE (may be "").
        self.voice_path = voice_path or os.getenv("BB_PIPER_VOICE", "")
        self._voice = None  # loaded lazily on the first synth() call

    def _ensure(self):
        """Load the Piper voice once and return the cached instance."""
        if self._voice is not None:
            return self._voice
        from piper import PiperVoice  # type: ignore
        self._voice = PiperVoice.load(self.voice_path)
        return self._voice

    def synth(self, text: str) -> np.ndarray:
        """Synthesize ``text`` and return float32 mono at TARGET_SAMPLE_RATE_HZ."""
        if not text.strip():
            return np.zeros(0, np.float32)
        voice = self._ensure()
        native_rate = int(getattr(voice.config, "sample_rate", 22050))
        collected = bytearray()
        for piece in voice.synthesize_stream_raw(text):  # int16 LE bytes
            collected += piece
        samples = np.frombuffer(bytes(collected), dtype="<i2").astype(np.float32) / 32768.0
        return resample_mono(samples, native_rate, TARGET_SAMPLE_RATE_HZ)
def _load_tts() -> TTSBackend:
    """Build the TTS backend selected by $BB_TTS.

    "piper" (case-insensitive) selects PiperTTS. Anything else, including
    unset, selects EspeakTTS with $BB_TTS_WPM (default 175).
    """
    if os.getenv("BB_TTS", "espeak").lower() == "piper":
        return PiperTTS()
    wpm = int(os.getenv("BB_TTS_WPM", "175"))
    return EspeakTTS(wpm=wpm)
# --------------------------------------------------------------------------- #
# ASR (faster-whisper) and MT (NLLB) wrappers
# --------------------------------------------------------------------------- #
# Whisper language code (ISO-639-1, plus Whisper's own "jw" and "yue") ->
# NLLB FLORES-200 code. Codes missing from this map are handled by the caller.
NLLB_CODE = {
    "af": "afr_Latn", "am": "amh_Ethi", "ar": "arb_Arab", "az": "azj_Latn",
    "be": "bel_Cyrl", "bg": "bul_Cyrl", "bn": "ben_Beng", "bs": "bos_Latn",
    "ca": "cat_Latn", "cs": "ces_Latn", "cy": "cym_Latn", "da": "dan_Latn",
    "de": "deu_Latn", "el": "ell_Grek", "en": "eng_Latn", "es": "spa_Latn",
    "et": "est_Latn", "eu": "eus_Latn", "fa": "pes_Arab", "fi": "fin_Latn",
    "fr": "fra_Latn", "gl": "glg_Latn", "gu": "guj_Gujr", "he": "heb_Hebr",
    "hi": "hin_Deva", "hr": "hrv_Latn", "hu": "hun_Latn", "hy": "hye_Armn",
    "id": "ind_Latn", "is": "isl_Latn", "it": "ita_Latn", "ja": "jpn_Jpan",
    "ka": "kat_Geor", "kk": "kaz_Cyrl", "km": "khm_Khmr", "kn": "kan_Knda",
    "ko": "kor_Hang", "lt": "lit_Latn", "lv": "lvs_Latn", "mk": "mkd_Cyrl",
    "ml": "mal_Mlym", "mn": "khk_Cyrl", "mr": "mar_Deva", "ms": "zsm_Latn",
    "my": "mya_Mymr", "ne": "npi_Deva", "nl": "nld_Latn", "no": "nob_Latn",
    "pa": "pan_Guru", "pl": "pol_Latn", "ps": "pbt_Arab", "pt": "por_Latn",
    "ro": "ron_Latn", "ru": "rus_Cyrl", "si": "sin_Sinh", "sk": "slk_Latn",
    "sl": "slv_Latn", "sq": "als_Latn", "sr": "srp_Cyrl", "sv": "swe_Latn",
    "sw": "swh_Latn", "ta": "tam_Taml", "te": "tel_Telu", "th": "tha_Thai",
    "tl": "tgl_Latn", "tr": "tur_Latn", "uk": "ukr_Cyrl", "ur": "urd_Arab",
    "uz": "uzn_Latn", "vi": "vie_Latn", "zh": "zho_Hans",
    # Remaining Whisper languages that NLLB-200 also covers (Whisper emits
    # "jw" for Javanese; "jv" is the ISO-639-1 code).
    "as": "asm_Beng", "ba": "bak_Cyrl", "bo": "bod_Tibt", "fo": "fao_Latn",
    "ha": "hau_Latn", "ht": "hat_Latn", "jv": "jav_Latn", "jw": "jav_Latn",
    "lb": "ltz_Latn", "ln": "lin_Latn", "lo": "lao_Laoo", "mg": "plt_Latn",
    "mi": "mri_Latn", "mt": "mlt_Latn", "nn": "nno_Latn", "oc": "oci_Latn",
    "sa": "san_Deva", "sd": "snd_Arab", "sn": "sna_Latn", "so": "som_Latn",
    "su": "sun_Latn", "tg": "tgk_Cyrl", "tk": "tuk_Latn", "tt": "tat_Cyrl",
    "yi": "ydd_Hebr", "yo": "yor_Latn", "yue": "yue_Hant",
}
class WhisperASR:
    """faster-whisper transcriber.

    The model size comes from $BB_ASR_MODEL (default "small"). It runs on CUDA
    with float16 when torch reports a GPU, otherwise on CPU with int8.
    """

    def __init__(self):
        import torch
        from faster_whisper import WhisperModel

        on_gpu = torch.cuda.is_available()
        self.model = WhisperModel(
            os.getenv("BB_ASR_MODEL", "small"),
            device="cuda" if on_gpu else "cpu",
            compute_type="float16" if on_gpu else "int8",
        )

    def transcribe(self, pcm16k: np.ndarray, language: str | None) -> str:
        """Transcribe the whole buffer and return the joined, stripped text.

        Returns "" without calling the model when fewer than 0.25 s of 16 kHz
        samples are available. A falsy ``language`` is passed as None.
        """
        if pcm16k.size < ASR_SAMPLE_RATE_HZ // 4:
            return ""
        beams = int(os.getenv("BB_ASR_BEAMS", "1"))
        segments, _info = self.model.transcribe(
            pcm16k,
            language=language or None,
            beam_size=beams,
            vad_filter=False,
            without_timestamps=True,
            condition_on_previous_text=False,
        )
        pieces = [segment.text for segment in segments]
        return "".join(pieces).strip()
class NllbMT:
"""Multilingual (one model, ~2.4 GB) — best on the GPU. BB_MT_BACKEND=nllb."""
def __init__(self):
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
name = os.getenv("BB_MT_MODEL", "facebook/nllb-200-distilled-600M")
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.tok = AutoTokenizer.from_pretrained(name)
self.model = AutoModelForSeq2SeqLM.from_pretrained(name).to(self.device)
self._eng_id = self.tok.convert_tokens_to_ids("eng_Latn")
def translate(self, text: str, src_lang: str | None) -> str:
import torch
if not text.strip():
return ""
self.tok.src_lang = NLLB_CODE.get((src_lang or "").lower(), "fra_Latn")
inputs = self.tok(text, return_tensors="pt", truncation=True, max_length=256).to(self.device)
with torch.inference_mode():
gen = self.model.generate(
**inputs,
forced_bos_token_id=self._eng_id,
num_beams=int(os.getenv("BB_MT_BEAMS", "1")),
max_new_tokens=200,
)
return self.tok.batch_decode(gen, skip_special_tokens=True)[0].strip()
class _IdentityMT:
"""Passthrough — returns source text unchanged (for English round-trip tests)."""
def translate(self, text: str, src_lang: str | None = None) -> str:
return text.strip()
class MarianMT:
"""Per-pair, light (~300 MB), CPU-friendly. BB_MT_BACKEND=marian.
Loads Helsinki-NLP/opus-mt-<src>-en for the session's source language."""
def __init__(self, src_lang: str | None):
import torch
from transformers import MarianMTModel, MarianTokenizer
self.device = "cuda" if torch.cuda.is_available() else "cpu"
src = (src_lang or "fr").lower()
name = os.getenv("BB_MARIAN_MODEL", f"Helsinki-NLP/opus-mt-{src}-en")
self.tok = MarianTokenizer.from_pretrained(name)
self.model = MarianMTModel.from_pretrained(name).to(self.device)
def translate(self, text: str, src_lang: str | None = None) -> str:
import torch
if not text.strip():
return ""
inputs = self.tok([text], return_tensors="pt", padding=True, truncation=True, max_length=256).to(self.device)
with torch.inference_mode():
gen = self.model.generate(**inputs, num_beams=1, max_new_tokens=128)
return self.tok.batch_decode(gen, skip_special_tokens=True)[0].strip()
# --------------------------------------------------------------------------- #
# cascaded simultaneous translator
# --------------------------------------------------------------------------- #
@dataclass
class _CascadeState:
    """Per-session state for the cascaded streaming translator."""

    language: str | None
    buf: List[np.ndarray] = field(default_factory=list)  # input frames received so far
    n_samples: int = 0  # total samples appended to ``buf``
    last_asr_samples: int = 0  # ``n_samples`` when ASR last ran
    prev_src_hyp: str = ""  # for ASR local agreement
    committed_src: str = ""  # stabilized source text
    spoken_en: str = ""  # English already synthesized/emitted
    finished: bool = False  # set once done=True has been returned
    # Input sample rate declared in start_session (defaults to the 24 kHz contract).
    sample_rate_hz: int = TARGET_SAMPLE_RATE_HZ


class CascadedStreamingS2S(S2SModel):
    """
    Simultaneous policy = wait-k + double local-agreement:
      1. read BB_WAIT_K_SEC of audio before the first emit;
      2. every BB_STEP_SEC of new audio, re-transcribe the whole buffer and
         commit the source prefix that is STABLE across two consecutive ASR runs;
      3. translate the committed source, hold back BB_MARGIN_WORDS of English
         (which may still change), and synthesize only the newly-stable tail
         (optionally waiting for >= BB_EMIT_MIN_WORDS new words);
      4. on is_final, commit + speak everything remaining.
    Tune these against ../selfscore to push overshoot<=0 while keeping
    accuracy>=0.65.

    Input timing and ASR resampling use the ``sample_rate_hz`` given to
    ``start_session``. Emitted audio is whatever the TTS backend returns.
    """

    def __init__(self):
        self._asr: WhisperASR | None = None
        self._tts: TTSBackend | None = None
        self._mt_cache: dict = {}
        self._mt_backend = os.getenv("BB_MT_BACKEND", "nllb").lower()
        self.wait_k_sec = float(os.getenv("BB_WAIT_K_SEC", "1.0"))
        self.step_sec = float(os.getenv("BB_STEP_SEC", "0.48"))
        self.margin_words = int(os.getenv("BB_MARGIN_WORDS", "2"))
        # Only synthesize once >= this many new words are ready (fewer, more fluent
        # chunks -> better Whisper intelligibility). 0 = emit every step.
        self.min_emit_words = int(os.getenv("BB_EMIT_MIN_WORDS", "0"))

    def _ensure_loaded(self):
        """Lazily construct the shared ASR and TTS backends."""
        if self._asr is None:
            self._asr = WhisperASR()
            self._tts = _load_tts()

    def _get_mt(self, src_lang: str | None):
        """Return a cached MT backend for ``src_lang``.

        identity: passthrough (English round-trip test); marian: one model per
        source language; nllb: shared multilingual model.
        """
        if self._mt_backend == "identity":
            key = "identity"
        elif self._mt_backend == "marian":
            key = f"marian:{(src_lang or 'fr').lower()}"
        else:
            key = "nllb"
        if key not in self._mt_cache:
            if self._mt_backend == "identity":
                self._mt_cache[key] = _IdentityMT()
            elif self._mt_backend == "marian":
                self._mt_cache[key] = MarianMT(src_lang)
            else:
                self._mt_cache[key] = NllbMT()
        return self._mt_cache[key]

    def start_session(self, *, language, sample_rate_hz, channels):
        self._ensure_loaded()
        # Honour the declared input rate; a missing/zero rate falls back to 24 kHz.
        # NOTE(review): ``channels`` is not used; input is assumed mono per the
        # module docstring.
        rate = int(sample_rate_hz) if sample_rate_hz else TARGET_SAMPLE_RATE_HZ
        return _CascadeState(language=language, sample_rate_hz=rate)

    def _asr_due(self, state: _CascadeState) -> bool:
        """True once wait-k audio has arrived and a full step has accrued since the last ASR run."""
        wait_k = int(self.wait_k_sec * state.sample_rate_hz)
        step = int(self.step_sec * state.sample_rate_hz)
        ready = state.n_samples >= wait_k
        due = (state.n_samples - state.last_asr_samples) >= step
        return ready and due

    def _transcribe(self, state: _CascadeState) -> str:
        """Run ASR over the entire buffered utterance, resampled to ASR_SAMPLE_RATE_HZ."""
        audio = np.concatenate(state.buf) if state.buf else np.zeros(0, np.float32)
        audio16 = resample_mono(audio, state.sample_rate_hz, ASR_SAMPLE_RATE_HZ)
        return self._asr.transcribe(audio16, state.language)

    @staticmethod
    def _stable_source(state: _CascadeState, hyp: str, is_final: bool) -> str:
        """Return the source text that is safe to commit.

        On the final push this is the full hypothesis, or the previous commit
        if ASR returned nothing. Otherwise it is the word prefix shared with
        the previous hypothesis (local agreement). This also records ``hyp``
        for the next comparison.
        """
        if is_final:
            return hyp or state.committed_src
        stable_words = _common_prefix_words(state.prev_src_hyp.split(), hyp.split())
        state.prev_src_hyp = hyp
        return " ".join(stable_words)

    def push(self, state: _CascadeState, pcm: np.ndarray, is_final: bool):
        if state.finished:
            return np.zeros(0, np.float32), True
        if pcm.size:
            state.buf.append(pcm.astype(np.float32, copy=False))
            state.n_samples += pcm.size
        if not is_final and not self._asr_due(state):
            return np.zeros(0, np.float32), False
        state.last_asr_samples = state.n_samples

        stable_src = self._stable_source(state, self._transcribe(state), is_final)
        if not stable_src or (stable_src == state.committed_src and not is_final):
            if is_final:
                state.finished = True  # done=True below; later pushes are no-ops
            return np.zeros(0, np.float32), bool(is_final)
        state.committed_src = stable_src

        en_words = self._get_mt(state.language).translate(stable_src, state.language).split()
        # Hold back the last margin_words of English until the final push.
        keep = len(en_words) if is_final else max(0, len(en_words) - self.margin_words)
        target_words = en_words[:keep]
        # Re-translation can revise words already spoken. We never un-speak, so
        # continue after however many words were emitted so far (minor drift in
        # the spoken prefix is accepted).
        new_words = target_words[len(state.spoken_en.split()):]

        if not new_words or (len(new_words) < self.min_emit_words and not is_final):
            if is_final:
                state.finished = True
                return np.zeros(1, np.float32), True
            return np.zeros(0, np.float32), False

        out = self._tts.synth(" ".join(new_words))
        state.spoken_en = " ".join(target_words)
        if is_final:
            state.finished = True
        return out, is_final
def load_model(name: str = "echo") -> S2SModel:
    """Construct the model named by ``name`` (case-insensitive; falsy -> "echo").

    "streamspeech" and "gist" import their modules lazily (GPU-box only).

    Raises:
        ValueError: for an unrecognised name.
    """

    def _streamspeech():
        from streamspeech_model import load_streamspeech  # lazy: GPU-box only
        return load_streamspeech()

    def _gist():
        from gist_model import load_gist  # lazy: GPU-box only (whisper-large+Qwen-7B+Kokoro)
        return load_gist()

    key = (name or "echo").lower()
    factories = {
        "echo": EchoDelayModel,
        "cascade": CascadedStreamingS2S,
        "streamspeech": _streamspeech,
        "gist": _gist,
    }
    factory = factories.get(key)
    if factory is None:
        raise ValueError(f"unknown model '{key}'")
    return factory()