ground-zero / app_minimal.py
jefffffff9
Stage 4: split translate/reply UI + CPU-safe TTS + reply-not-translate prompt
9e99c2c
Raw
History Blame Contribute Delete
24.2 kB
"""Minimal baseline Gradio entry point for the Month 1-3 rebuild.
Wires the simplest possible slice: Whisper (zero-shot) -> Aya-Expanse -> MMS-TTS.
No LoRA adapters, no memory loop, no speaker ID, no voice cloning, no IoT,
no phrase matcher. Used for field testing and building a real-user eval set.
See docs/baseline_rebuild.md for the plan this fits into.
Run locally:
HF_TOKEN=hf_xxx python app_minimal.py
Environment variables (all optional except HF_TOKEN, which is needed for the
HF Serverless LLM call):
HF_TOKEN — HuggingFace token with read access
LLM_MODEL_ID — default "CohereLabs/aya-expanse-32b"
(23-language multilingual, strong African-language coverage)
DEVICE — "cuda" or "cpu" (auto if unset)
LOG_LEVEL — default "INFO"
"""
from __future__ import annotations
import logging
import os
from typing import Optional, Tuple
import numpy as np
# Load .env (HF_TOKEN etc.) before reading os.environ below. Silent no-op if
# python-dotenv is not installed or no .env is present.
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
pass
# Local imports — the only project modules this baseline uses (six at present).
# Everything else in src/ is intentionally unused here.
from src.data.bam_normalize import normalize as bam_normalize
from src.engine.turn_logger import TurnLogger
from src.engine.whisper_base import WhisperBackbone
from src.llm.minimal_client import MinimalClient
from src.llm.phrasebook import lookup as phrasebook_lookup, top_k as phrasebook_top_k
from src.tts.mms_tts import MMSTTSEngine
# Configure root logging once at import time.
# The stdlib only recognises upper-case level names, so "debug" or an empty
# LOG_LEVEL would make basicConfig raise ValueError before the app starts.
# Normalise the value and treat an empty string as "unset".
logging.basicConfig(
    level=(os.getenv("LOG_LEVEL") or "INFO").strip().upper(),
    format="%(asctime)s %(name)-30s %(levelname)-7s %(message)s",
)
logger = logging.getLogger(__name__)
# ── Environment ──────────────────────────────────────────────────────────────
# All three values are read once, at import time.
HF_TOKEN = os.environ.get("HF_TOKEN")
LLM_MODEL_ID = os.environ.get("LLM_MODEL_ID", "CohereLabs/aya-expanse-32b")
_REQUESTED_DEVICE = os.environ.get("DEVICE")  # optional override
# (label, value) pairs fed to both language dropdowns in build_ui().
LANG_CHOICES = [("Bambara", "bam"), ("Fula", "ful"), ("French", "fr"), ("English", "en")]
# Language code -> display name; same four languages as LANG_CHOICES.
LANG_NAMES = {"bam": "Bambara", "ful": "Fula", "fr": "French", "en": "English"}
# Language code -> value passed as Whisper's `language` generate kwarg.
# None means "do not pass a hint" (see transcribe()).
LANG_TO_WHISPER_HINT = {
    # Whisper large-v3-turbo does not know Bambara/Fula as first-class
    # languages. We leave `language` unset for those so Whisper auto-detects;
    # fr/en are explicit hints for clean decoding.
    "bam": None,
    "ful": None,
    "fr": "french",
    "en": "english",
}
# Reply-language steering is handled inside MinimalClient via a dialect-anchored
# system prompt (see src/llm/minimal_client.py). No per-turn directive needed.
# ── Service singletons (lazy-loaded) ────────────────────────────────────────
# Populated on first use by get_backbone() / get_llm() / get_tts().
_backbone: Optional[WhisperBackbone] = None
_llm: Optional[MinimalClient] = None
_tts: Optional[MMSTTSEngine] = None
# The turn logger is the one service created eagerly at import time.
_turn_logger: TurnLogger = TurnLogger()
def _resolve_device() -> str:
    """Pick the device string used for the Whisper and TTS models.

    Resolution order:
      1. The DEVICE environment override, trimmed and lower-cased, if set.
         Lower-casing matters because ``transcribe`` compares the device
         against the literal ``"cuda"``.
      2. ``"cuda"`` when torch reports CUDA available and at least one device.
      3. ``"cpu"`` otherwise — including when torch cannot be imported or the
         CUDA probe raises.

    Some torch builds (CPU-only wheels) report `cuda.is_available() == True`
    in error states; we additionally probe device_count and fall back to cpu
    on any exception to keep the app usable on CPU-only laptops.

    Returns:
        A device string such as ``"cuda"`` or ``"cpu"``.
    """
    # Honour the override first so it never depends on torch being importable.
    if _REQUESTED_DEVICE:
        override = _REQUESTED_DEVICE.strip().lower()
        if override:
            return override
    try:
        import torch  # lazy; inside the try so an import failure also means cpu

        if torch.cuda.is_available() and torch.cuda.device_count() > 0:
            return "cuda"
    except Exception:
        # Deliberate best-effort: record why we fell back instead of hiding it.
        logger.debug("CUDA probe failed — falling back to cpu", exc_info=True)
    return "cpu"
def get_backbone() -> WhisperBackbone:
    """Load the Whisper backbone once and cache. Zero-shot — no adapters.

    The instance is stored in the module-level cache only after ``load``
    succeeds. If loading raises, the cache stays empty and the next call
    retries, instead of handing out a constructed-but-unloaded backbone.

    Returns:
        The shared, loaded ``WhisperBackbone``.

    Raises:
        Whatever ``WhisperBackbone(...)`` or ``load`` raises; not caught here.
    """
    global _backbone
    if _backbone is None:
        backbone = WhisperBackbone(config_path="configs/base_config.yaml")
        backbone.load(device=_resolve_device(), hf_token=HF_TOKEN)
        logger.info("Whisper backbone ready: %s on %s",
                    backbone.model_id, backbone.device)
        # Publish only now that the model is actually usable.
        _backbone = backbone
    return _backbone
def get_llm() -> MinimalClient:
    """Return the shared MinimalClient, building it on first use."""
    global _llm
    if _llm is not None:
        return _llm
    _llm = MinimalClient(model_id=LLM_MODEL_ID, hf_token=HF_TOKEN)
    logger.info("Minimal LLM client configured: %s", LLM_MODEL_ID)
    return _llm
def get_tts() -> MMSTTSEngine:
    """Return the shared MMSTTSEngine, building it on first use."""
    global _tts
    if _tts is not None:
        return _tts
    _tts = MMSTTSEngine()
    logger.info("MMS-TTS engine ready (lazy per-language load)")
    return _tts
# ── Core pipeline ────────────────────────────────────────────────────────────
def _to_mono_float32(audio_np: np.ndarray) -> np.ndarray:
    """Down-mix to mono and rescale to float32 with a nominal [-1, 1] range."""
    source_dtype = audio_np.dtype  # remember it before the float cast below
    if audio_np.ndim == 2:
        # Averages over axis 1 — assumes a (samples, channels) layout.
        audio_np = audio_np.mean(axis=1)
    audio_np = audio_np.astype(np.float32)
    if np.issubdtype(source_dtype, np.signedinteger):
        # Signed PCM: divide by the dtype's own full scale (32768 for int16,
        # 2**31 for int32). A magnitude heuristic would mis-scale int32 and
        # skip near-silent int16 clips whose peak is 0 or 1.
        return audio_np / (float(np.iinfo(source_dtype).max) + 1.0)
    # Float (or unsigned) input: keep the magnitude heuristic.
    peak = np.max(np.abs(audio_np)) if audio_np.size else 0.0
    if peak > 1.5:  # looks like raw int16 cast to float
        audio_np = audio_np / 32768.0
    return audio_np


def transcribe(audio_np: np.ndarray, sample_rate: int, input_lang: str) -> str:
    """Run zero-shot Whisper on a numpy audio array. Returns the raw transcript.

    `input_lang` drives two things only: the Whisper language hint (for fr/en)
    and whether bam_normalize is applied. It has no effect on the TTS voice or
    on the LLM reply language — those are driven by the separate output-language
    dropdown in the UI.

    Args:
        audio_np: Audio samples, 1-D or 2-D; integer PCM or float.
        sample_rate: Sample rate of ``audio_np`` in Hz; resampled to 16 kHz
            when different.
        input_lang: Language code; looked up in ``LANG_TO_WHISPER_HINT``.

    Returns:
        The stripped transcript, passed through ``bam_normalize`` when
        ``input_lang == "bam"`` and the transcript is non-empty.
    """
    import torch  # lazy
    import librosa  # lazy — resample if the mic gave us something non-16k

    backbone = get_backbone()
    target_sr = 16_000

    # Ensure mono float32 at a sane amplitude scale.
    audio_np = _to_mono_float32(audio_np)
    if sample_rate != target_sr:
        audio_np = librosa.resample(audio_np, orig_sr=sample_rate, target_sr=target_sr)

    inputs = backbone.processor(
        audio_np, sampling_rate=target_sr, return_tensors="pt"
    )
    input_features = inputs.input_features.to(backbone.device)
    if backbone.device == "cuda":
        # NOTE(review): presumably the model runs in fp16 on GPU — confirm
        # against WhisperBackbone.load.
        input_features = input_features.half()

    gen_kwargs: dict = {"max_new_tokens": 128}
    hint = LANG_TO_WHISPER_HINT.get(input_lang)
    if hint:
        # Only pin language/task when we have an explicit hint; otherwise
        # Whisper is left to auto-detect.
        gen_kwargs["language"] = hint
        gen_kwargs["task"] = "transcribe"

    with torch.no_grad():
        output_ids = backbone.model.generate(input_features, **gen_kwargs)
    transcript = backbone.processor.batch_decode(
        output_ids, skip_special_tokens=True
    )[0].strip()

    if input_lang == "bam" and transcript:
        transcript = bam_normalize(transcript)
    return transcript
# Placeholder text returned by _translate_only() when the phrasebook has no
# match and the output language is not en/fr.
NO_TRANSLATION = "(no curated translation — try Generate reply)"
def _synthesize(text: str, output_lang: str
                ) -> Tuple[Optional[Tuple[int, np.ndarray]], Optional[int], Optional[str]]:
    """Run TTS on `text` in `output_lang`. Returns (audio_or_None, tts_ms, error).

    Synthesis is tried on the resolved device first. If that device is not
    the CPU and the attempt fails with an ``AssertionError`` or
    ``RuntimeError``, it is retried once on the CPU.

    Args:
        text: Text to speak. Empty text short-circuits to ``(None, None, None)``.
        output_lang: Language code passed to the TTS engine.

    Returns:
        ``((sample_rate, waveform), elapsed_ms, None)`` on success, where
        ``elapsed_ms`` includes any failed first attempt, or
        ``(None, None, "tts: <message>")`` on failure. Never raises.
    """
    import time

    if not text:
        return None, None, None

    started = time.perf_counter()
    device = _resolve_device()
    try:
        wav, sr = get_tts().synthesize(text, language=output_lang, device=device)
        return (sr, wav), int((time.perf_counter() - started) * 1000), None
    except (AssertionError, RuntimeError) as exc:
        # AssertionError: "Torch not compiled with CUDA enabled" on CPU-only
        # boxes where is_available() lied. RuntimeError: CUDA initialisation
        # and out-of-memory failures.
        if device == "cpu":
            logger.exception("TTS failed")
            return None, None, f"tts: {exc}"
        logger.warning("TTS failed on %s (%s) — retrying on cpu", device, exc)
    except Exception as exc:  # pragma: no cover
        logger.exception("TTS failed")
        return None, None, f"tts: {exc}"

    # Reached only after an accelerator failure. Retrying outside the except
    # block keeps the fallback's traceback free of the first error.
    try:
        wav, sr = get_tts().synthesize(text, language=output_lang, device="cpu")
        return (sr, wav), int((time.perf_counter() - started) * 1000), None
    except Exception as exc:  # pragma: no cover
        logger.exception("TTS failed on cpu fallback")
        return None, None, f"tts: {exc}"
def _translate_only(user_text: str, output_lang: str
                    ) -> Tuple[str, Optional[Tuple[int, np.ndarray]], Optional[dict], Optional[int]]:
    """Translate via the curated phrasebook only; the LLM is never called.

    Returns (translation_text, translation_audio, hit_or_None, tts_ms).

    Outcomes:
      * blank input                -> ("", None, None, None)
      * phrasebook hit             -> curated target text plus its TTS audio
      * miss, output is en/fr      -> the input echoed back and spoken, so the
                                      app still works as a plain TTS box
      * miss, any other output     -> NO_TRANSLATION and no audio; the user
                                      can click "Generate reply" instead
    """
    cleaned = (user_text or "").strip()
    if not cleaned:
        return "", None, None, None

    match = phrasebook_lookup(cleaned, output_lang)
    if not match:
        if output_lang not in ("en", "fr"):
            return NO_TRANSLATION, None, None, None
        spoken, elapsed_ms, _ = _synthesize(cleaned, output_lang)
        return cleaned, spoken, None, elapsed_ms

    logger.info(
        "Phrasebook hit (%s, score=%.2f): %r → %r [cat=%s]",
        match["match"], match["score"], cleaned, match["target"], match["category"],
    )
    translated = match["target"] or ""
    spoken, elapsed_ms, _ = _synthesize(translated, output_lang)
    return translated, spoken, match, elapsed_ms
def _generate_reply(user_text: str, output_lang: str
                    ) -> Tuple[str, Optional[Tuple[int, np.ndarray]], Optional[int], Optional[int], Optional[str]]:
    """Dialect-anchored LLM reply (with RAG top-3 few-shot) + TTS.

    Returns (reply_text, reply_audio, llm_ms, tts_ms, error).

    Always returns a usable text string — even on LLM failure it returns a
    short parenthetical so the UI never goes blank. Parenthetical
    placeholders are never sent to TTS: they are UI messages, not speech in
    the output language.

    Args:
        user_text: The message to reply to; blank input short-circuits.
        output_lang: Language code for the reply and the TTS voice.

    Returns:
        * blank input: ``("(nothing to reply to)", None, None, None, None)``
        * LLM raised:  ``("(LLM error: …)", None, llm_ms, None, "llm: …")``
        * empty reply: ``("(empty reply)", None, llm_ms, None, "llm: empty reply")``
        * otherwise:   the reply, its audio (or None), timings, and any TTS error
    """
    import time

    text = (user_text or "").strip()
    if not text:
        return "(nothing to reply to)", None, None, None, None

    # Nearest phrasebook entries become few-shot examples; an empty result is
    # passed as None.
    extras = phrasebook_top_k(text, output_lang, k=3) or None
    if extras:
        logger.info(
            "RAG-injecting top-%d nearest phrasebook entries (top score=%.2f)",
            len(extras), extras[0]["score"],
        )

    t_llm = time.perf_counter()
    try:
        reply = get_llm().chat(
            text, target_lang=output_lang, extra_examples=extras,
        )
    except Exception as exc:  # pragma: no cover
        logger.exception("LLM call failed")
        llm_ms = int((time.perf_counter() - t_llm) * 1000)
        return f"(LLM error: {exc})", None, llm_ms, None, f"llm: {exc}"
    llm_ms = int((time.perf_counter() - t_llm) * 1000)

    reply = (reply or "").strip()
    if not reply:
        # Show the placeholder but do not speak it, and surface the condition
        # in the error slot so the turn log records it.
        logger.warning("LLM returned an empty reply")
        return "(empty reply)", None, llm_ms, None, "llm: empty reply"

    audio, tts_ms, tts_error = _synthesize(reply, output_lang)
    return reply, audio, llm_ms, tts_ms, tts_error
# ── Tab handlers ─────────────────────────────────────────────────────────────
def run_text_translate(
    text: str,
    output_lang: str,
) -> Tuple[str, Optional[Tuple[int, np.ndarray]], str]:
    """Text tab → Send: phrasebook-only translation. Always-on, no LLM.

    Returns (translation_text, translation_audio, transcript_state).

    `transcript_state` is the stripped input, handed on to the Generate-reply
    button so it does not have to re-read the textbox. Blank input returns a
    placeholder message and is not logged.
    """
    import time

    started = time.perf_counter()
    cleaned = (text or "").strip()
    if not cleaned:
        return "(no text entered)", None, ""

    translation, audio, hit, tts_ms = _translate_only(cleaned, output_lang)
    elapsed_ms = int((time.perf_counter() - started) * 1000)
    _turn_logger.log(
        phase="translate", tab="text",
        input_lang=None, output_lang=output_lang,
        user_text=cleaned, transcript=None, transcribe_ms=None,
        phrasebook=hit, llm_model=None, llm_ms=None,
        reply_text=translation, tts_ms=tts_ms,
        total_ms=elapsed_ms,
        error=None,
    )
    return translation, audio, cleaned
def run_text_reply(
    transcript_state: str,
    output_lang: str,
) -> Tuple[str, Optional[Tuple[int, np.ndarray]]]:
    """Text tab → Generate reply: dialect-anchored LLM + TTS.

    Returns (reply_text, reply_audio). A blank `transcript_state` returns a
    prompt to send a message first and is not logged.
    """
    import time

    started = time.perf_counter()
    has_input = bool((transcript_state or "").strip())
    if not has_input:
        return "(send a message first)", None

    reply, audio, llm_ms, tts_ms, error = _generate_reply(
        transcript_state, output_lang
    )
    elapsed_ms = int((time.perf_counter() - started) * 1000)
    _turn_logger.log(
        phase="reply", tab="text",
        input_lang=None, output_lang=output_lang,
        user_text=transcript_state, transcript=None, transcribe_ms=None,
        phrasebook=None, llm_model=LLM_MODEL_ID, llm_ms=llm_ms,
        reply_text=reply, tts_ms=tts_ms,
        total_ms=elapsed_ms,
        error=error,
    )
    return reply, audio
def run_voice_translate(
    audio: Optional[Tuple[int, np.ndarray]],
    input_lang: str,
    output_lang: str,
) -> Tuple[str, str, Optional[Tuple[int, np.ndarray]], str]:
    """Voice tab → Submit: Whisper transcribe + phrasebook-only translation.

    Returns (transcript, translation_text, translation_audio, transcript_state).

    Missing or empty audio returns a placeholder without logging. An STT
    failure or an empty transcript is logged and also returns a placeholder.
    """
    import time

    started = time.perf_counter()

    def _record(**fields) -> None:
        # One turn-log entry: shared defaults, per-call overrides, and the
        # elapsed time measured at the moment of logging.
        entry = dict(
            phase="translate", tab="voice",
            input_lang=input_lang, output_lang=output_lang,
            user_text=None, transcript=None, transcribe_ms=None,
            phrasebook=None, llm_model=None, llm_ms=None,
            reply_text=None, tts_ms=None,
            total_ms=None,
            error=None,
        )
        entry.update(fields)
        entry["total_ms"] = int((time.perf_counter() - started) * 1000)
        _turn_logger.log(**entry)

    if audio is None:
        return "", "(no audio received)", None, ""
    sample_rate, audio_np = audio
    if audio_np.size == 0:
        return "", "(empty audio)", None, ""

    stt_started = time.perf_counter()
    try:
        transcript = transcribe(audio_np, sample_rate, input_lang)
    except Exception as exc:  # pragma: no cover
        logger.exception("Transcription failed")
        _record(error=f"stt: {exc}")
        return "", f"(STT error: {exc})", None, ""
    transcribe_ms = int((time.perf_counter() - stt_started) * 1000)

    if not transcript:
        _record(transcript="", transcribe_ms=transcribe_ms, error="no_speech")
        return "", "(no speech detected)", None, ""

    translation, t_audio, hit, tts_ms = _translate_only(transcript, output_lang)
    _record(
        user_text=transcript, transcript=transcript,
        transcribe_ms=transcribe_ms,
        phrasebook=hit,
        reply_text=translation, tts_ms=tts_ms,
    )
    return transcript, translation, t_audio, transcript
def run_voice_reply(
    transcript_state: str,
    output_lang: str,
) -> Tuple[str, Optional[Tuple[int, np.ndarray]]]:
    """Voice tab → Generate reply: uses the stored transcript, no re-Whisper.

    Returns (reply_text, reply_audio). A blank `transcript_state` returns a
    prompt to record and submit first and is not logged.
    """
    import time

    started = time.perf_counter()
    has_transcript = bool((transcript_state or "").strip())
    if not has_transcript:
        return "(record audio and submit first)", None

    reply, audio, llm_ms, tts_ms, error = _generate_reply(
        transcript_state, output_lang
    )
    elapsed_ms = int((time.perf_counter() - started) * 1000)
    _turn_logger.log(
        phase="reply", tab="voice",
        input_lang=None, output_lang=output_lang,
        user_text=transcript_state, transcript=transcript_state,
        transcribe_ms=None,
        phrasebook=None, llm_model=LLM_MODEL_ID, llm_ms=llm_ms,
        reply_text=reply, tts_ms=tts_ms,
        total_ms=elapsed_ms,
        error=error,
    )
    return reply, audio
# ── Gradio UI ────────────────────────────────────────────────────────────────
def build_ui():
    """Construct and return the Gradio Blocks app.

    Layout, top to bottom: a header, a row with the shared input/output
    language dropdowns, a Voice tab and a Text tab, and a closing notes
    section. Each tab has a primary button wired to the phrasebook-only
    translate handler and a secondary button wired to the LLM reply handler.
    Both tabs share one ``transcript_state``, written by the translate
    handlers and read by the reply handlers.

    NOTE(review): block nesting was reconstructed from a source whose
    indentation was lost — confirm which column each component belongs to.
    """
    import gradio as gr  # lazy — keeps module importable without gradio installed
    with gr.Blocks(title="Sahel-Voice — Minimal Baseline") as demo:
        gr.Markdown(
            "# 🌾 Sahel-Voice — Minimal Baseline\n"
            f"Zero-shot Whisper → {LLM_MODEL_ID} → MMS-TTS, with a curated "
            "Bambara/Pular phrasebook short-circuit in front of the LLM. "
            "No adapters, no memory, no polish. This is the field-test "
            "baseline — see `docs/baseline_rebuild.md`."
        )
        # Shared across tabs. Split into two so input and output language
        # are never conflated — the Voice tab cares about both; the Text tab
        # only uses output_lang (it doesn't feed Whisper).
        with gr.Row():
            input_lang = gr.Dropdown(
                choices=LANG_CHOICES, value="bam", label="Input language",
                info="Language you're speaking/typing. Drives Whisper hint "
                "(fr/en only) and bam_normalize (bam only).",
            )
            output_lang = gr.Dropdown(
                choices=LANG_CHOICES, value="bam", label="Output language",
                info="Language the LLM should reply in. Also picks the TTS voice.",
            )
        # Carries the canonical input (typed text, or Whisper transcript) from
        # Submit/Send into the Generate-reply button so we don't re-transcribe
        # or re-read the textbox.
        transcript_state = gr.State("")
        with gr.Tabs():
            # ── Voice tab — the actual baseline the field test measures ─────
            with gr.Tab("🎤 Voice (full STT → translation + optional reply)"):
                with gr.Row():
                    # Left column: audio capture and the raw transcript.
                    with gr.Column():
                        audio_in = gr.Audio(
                            sources=["microphone", "upload"],
                            type="numpy",
                            label="Speak (or upload a .wav)",
                        )
                        voice_submit = gr.Button(
                            "Transcribe + translate", variant="primary"
                        )
                        voice_transcript_out = gr.Textbox(
                            label="Transcript (zero-shot Whisper)",
                            lines=2, interactive=False,
                        )
                    # Right column: phrasebook result on top, LLM reply below.
                    with gr.Column():
                        voice_translation_out = gr.Textbox(
                            label="Phrasebook translation",
                            lines=3, interactive=False,
                        )
                        voice_translation_audio = gr.Audio(
                            label="Translation audio",
                            type="numpy", autoplay=False,
                        )
                        voice_reply_btn = gr.Button(
                            "Generate reply (LLM)", variant="secondary"
                        )
                        voice_reply_out = gr.Textbox(
                            label="LLM reply", lines=4, interactive=False,
                        )
                        voice_reply_audio = gr.Audio(
                            label="Reply audio", type="numpy", autoplay=False,
                        )
                # Output order must match run_voice_translate's return tuple.
                voice_submit.click(
                    fn=run_voice_translate,
                    inputs=[audio_in, input_lang, output_lang],
                    outputs=[
                        voice_transcript_out,
                        voice_translation_out,
                        voice_translation_audio,
                        transcript_state,
                    ],
                )
                voice_reply_btn.click(
                    fn=run_voice_reply,
                    inputs=[transcript_state, output_lang],
                    outputs=[voice_reply_out, voice_reply_audio],
                )
            # ── Text tab — dev loop, skips Whisper ──────────────────────────
            with gr.Tab("⌨️ Text (translation + optional reply, dev loop)"):
                with gr.Row():
                    with gr.Column():
                        text_in = gr.Textbox(
                            label="Type your message",
                            lines=3,
                            placeholder="e.g. Good morning, how are you?",
                        )
                        text_submit = gr.Button("Send", variant="primary")
                    with gr.Column():
                        text_translation_out = gr.Textbox(
                            label="Phrasebook translation",
                            lines=3, interactive=False,
                        )
                        text_translation_audio = gr.Audio(
                            label="Translation audio",
                            type="numpy", autoplay=False,
                        )
                        text_reply_btn = gr.Button(
                            "Generate reply (LLM)", variant="secondary"
                        )
                        text_reply_out = gr.Textbox(
                            label="LLM reply", lines=4, interactive=False,
                        )
                        text_reply_audio = gr.Audio(
                            label="Reply audio", type="numpy", autoplay=False,
                        )
                # Text tab only uses output_lang — input_lang is a no-op here.
                text_submit.click(
                    fn=run_text_translate,
                    inputs=[text_in, output_lang],
                    outputs=[
                        text_translation_out,
                        text_translation_audio,
                        transcript_state,
                    ],
                )
                # Pressing Enter in the textbox also submits.
                text_in.submit(
                    fn=run_text_translate,
                    inputs=[text_in, output_lang],
                    outputs=[
                        text_translation_out,
                        text_translation_audio,
                        transcript_state,
                    ],
                )
                text_reply_btn.click(
                    fn=run_text_reply,
                    inputs=[transcript_state, output_lang],
                    outputs=[text_reply_out, text_reply_audio],
                )
        gr.Markdown(
            "---\n"
            "**What's intentionally missing:** LoRA adapters, memory/vocabulary "
            "persistence, speaker ID, Waxal/F5 TTS, IoT sensor integration, "
            "phrase-matcher shortcuts. All of those live in `app.py` — this is the "
            "stripped-down baseline used to measure what Whisper zero-shot does on "
            "real Bambara/Fula recordings and to collect a real-user eval set.\n\n"
            "The **Text** tab skips Whisper — it's for fast iteration on the "
            "LLM + TTS path, not for field-test measurement.\n\n"
            "**How the two boxes differ:** the top pair is a phrasebook lookup "
            "(no LLM, instant, gold-curated translation). If your input isn't "
            "in the curated list you'll see *(no curated translation)* — click "
            "**Generate reply** to get a dialect-anchored LLM response in the "
            "bottom pair."
        )
    return demo
def main() -> None:
    """Warn if HF_TOKEN is missing, then build the UI and launch it with queueing."""
    token_missing = not HF_TOKEN
    if token_missing:
        logger.warning(
            "HF_TOKEN is not set — the LLM call will fail. "
            "Export HF_TOKEN before launching for the pipeline to work end-to-end."
        )
    app = build_ui()
    queued_app = app.queue()
    queued_app.launch()


if __name__ == "__main__":
    main()