# Voice-AI-Agent / app.py — PlotWeaver Voice Agent (HuggingFace Space)
"""
PlotWeaver Voice Agent — HuggingFace Space
============================================
Gradio app demonstrating a Hausa-first conversational AI for
African banks, telecoms, and delivery services.
Pipeline: ASR (Whisper-small) → NLU (rule-based) → Dialogue FSM →
TTS (facebook/mms-tts-hau).
Runs on CPU. First turn triggers model download (~500MB), subsequent turns
are ~2-4s end-to-end.
"""
from __future__ import annotations
# ---------------------------------------------------------------------------
# Monkey-patch for a known gradio_client bug on Python 3.13 + gradio 4.44.1:
# gradio_client/utils.py:get_type() does `"const" in schema` where schema is
# sometimes a bool (False), triggering:
# TypeError: argument of type 'bool' is not iterable
# See: https://github.com/gradio-app/gradio/issues/11722
# We patch the two affected functions to handle bool schemas defensively.
# This MUST run before `import gradio`.
# ---------------------------------------------------------------------------
def _patch_gradio_client_schema_bug():
try:
from gradio_client import utils as _gcu
_orig_get_type = _gcu.get_type
_orig_json_to_py = _gcu._json_schema_to_python_type
def _safe_get_type(schema):
if isinstance(schema, bool):
return "Any"
return _orig_get_type(schema)
def _safe_json_to_py(schema, defs=None):
if isinstance(schema, bool):
return "Any"
return _orig_json_to_py(schema, defs)
_gcu.get_type = _safe_get_type
_gcu._json_schema_to_python_type = _safe_json_to_py
except Exception as _e:
# If the patch fails, we fall back to show_api=False in launch()
print(f"[plotweaver] gradio_client patch failed: {_e}")
_patch_gradio_client_schema_bug()
import time
import uuid
import html as html_lib
from typing import Optional
import gradio as gr
import numpy as np
import torch
from transformers import (
VitsModel, AutoTokenizer,
WhisperProcessor, WhisperForConditionalGeneration,
)
from dialogue import (
DialogueState, SCENARIOS,
get_prompt, get_expected_slot, transition,
)
from nlu import parse as nlu_parse
# ---------------------------------------------------------------------------
# Model loading (lazy, cached)
# ---------------------------------------------------------------------------
# Module-level singletons: each model/processor pair is populated on first
# use by load_asr() / load_tts() and then reused on every later turn, so the
# heavy download/initialization cost is paid once per process.
_asr_model = None
_asr_processor = None
_tts_model = None
_tts_tokenizer = None
def load_asr():
    """Return the (model, processor) pair for Whisper-small, loading lazily.

    The pair is cached in module globals so the download/initialization
    happens only on the first call.
    """
    global _asr_model, _asr_processor
    if _asr_model is not None:
        return _asr_model, _asr_processor
    print("Loading Whisper-small…")
    repo = "openai/whisper-small"
    _asr_processor = WhisperProcessor.from_pretrained(repo)
    model = WhisperForConditionalGeneration.from_pretrained(repo)
    model.eval()  # inference only — disable dropout etc.
    _asr_model = model
    print("Whisper-small ready.")
    return _asr_model, _asr_processor
def load_tts():
    """Return the (model, tokenizer) pair for MMS-TTS Hausa, loading lazily.

    Cached in module globals; only the first call pays the load cost.
    """
    global _tts_model, _tts_tokenizer
    if _tts_model is not None:
        return _tts_model, _tts_tokenizer
    print("Loading MMS-TTS Hausa…")
    repo = "facebook/mms-tts-hau"
    vits = VitsModel.from_pretrained(repo)
    vits.eval()  # inference only
    _tts_model = vits
    _tts_tokenizer = AutoTokenizer.from_pretrained(repo)
    print("MMS-TTS Hausa ready.")
    return _tts_model, _tts_tokenizer
def transcribe_hausa(audio_tuple) -> str:
"""audio_tuple is (sample_rate, np.ndarray) from Gradio."""
if audio_tuple is None:
return ""
sample_rate, audio_array = audio_tuple
if audio_array is None or len(audio_array) == 0:
return ""
# Convert to float32 mono
if audio_array.dtype != np.float32:
audio_array = audio_array.astype(np.float32) / np.iinfo(audio_array.dtype).max
if audio_array.ndim > 1:
audio_array = audio_array.mean(axis=1)
# Cap at 30s — Whisper-small is trained on 30s chunks; longer audio
# would need windowing which slows the demo
max_samples = sample_rate * 30
if len(audio_array) > max_samples:
audio_array = audio_array[:max_samples]
# Resample to 16 kHz
if sample_rate != 16000:
import scipy.signal
num_samples = int(len(audio_array) * 16000 / sample_rate)
audio_array = scipy.signal.resample(audio_array, num_samples).astype(np.float32)
model, processor = load_asr()
inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
forced_ids = processor.get_decoder_prompt_ids(language="hausa", task="transcribe")
with torch.no_grad():
ids = model.generate(inputs.input_features, forced_decoder_ids=forced_ids, max_new_tokens=128)
text = processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
return text
def synthesize_hausa(text: str) -> Optional[tuple]:
    """Synthesize Hausa speech for `text`.

    Returns (sample_rate, float32 waveform ndarray), or None for blank input.
    """
    if not text.strip():
        return None
    model, tokenizer = load_tts()
    encoded = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        waveform = model(**encoded).waveform
    samples = waveform.squeeze().cpu().numpy().astype(np.float32)
    return (model.config.sampling_rate, samples)
# ---------------------------------------------------------------------------
# Core turn handler
# ---------------------------------------------------------------------------
def run_turn(user_text: str, session: dict, trace: list, asr_ms: int = 0) -> tuple:
    """
    Execute one dialogue turn: NLU → FSM transition → response → TTS.

    `session` is the serialized DialogueState dict held in gr.State.
    Returns (bot_prompt_dict, updated_session_dict, turn_trace, tts_audio);
    `turn_trace` is a list of per-stage {"stage", "ms", "detail"} rows.
    """
    state = DialogueState.from_dict(session) if session else None
    if state is None:
        # First message without a session: start a fresh banking session.
        state = DialogueState(session_id="sess_" + uuid.uuid4().hex[:8], vertical="bank")

    turn_trace = []
    if asr_ms:
        # Upstream ASR already ran; record its latency and transcript.
        turn_trace.append({"stage": "asr (whisper-small)", "ms": asr_ms,
                           "detail": f'→ "{user_text}"'})

    # NLU: map raw text to intent + entities, labelling which engine answered.
    nlu_start = time.time()
    expected = get_expected_slot(state.vertical, state.current_state)
    intent, entities, nlu_source = nlu_parse(user_text, expected)
    stage_names = {
        "rule": "nlu (rule-based)",
        "llm": "nlu (qwen2.5-1.5b)",
        "rule_fallback": "nlu (rule + llm fallback)",
    }
    turn_trace.append({
        "stage": stage_names.get(nlu_source, "nlu"),
        "ms": max(1, int((time.time() - nlu_start) * 1000)),
        "detail": f"intent={intent} entities={entities}",
    })

    # Dialogue FSM transition.
    dm_start = time.time()
    prev_state = state.current_state
    state = transition(state, intent, entities)
    turn_trace.append({
        "stage": "dialogue_manager",
        "ms": max(1, int((time.time() - dm_start) * 1000)),
        "detail": f"{prev_state}{state.current_state}",
    })

    # Response generation: template lookup for the new FSM state.
    rg_start = time.time()
    prompt = get_prompt(state.vertical, state.current_state)
    turn_trace.append({"stage": "response_gen", "ms": max(1, int((time.time() - rg_start) * 1000))})

    # TTS on the Hausa half of the prompt.
    tts_start = time.time()
    audio = synthesize_hausa(prompt["ha"])
    turn_trace.append({"stage": "tts (mms-tts-hau)", "ms": int((time.time() - tts_start) * 1000)})

    state.history.append({"role": "user", "text": user_text})
    state.history.append({"role": "bot", "text_ha": prompt["ha"], "text_en": prompt["en"]})
    return prompt, state.to_dict(), turn_trace, audio
# ---------------------------------------------------------------------------
# WhatsApp-style HTML renderer
# ---------------------------------------------------------------------------
def render_whatsapp(session: dict, pending_user: Optional[str] = None,
                    pending_is_voice: bool = False) -> str:
    """Render the conversation as a self-styled WhatsApp-like phone mockup.

    session: serialized DialogueState dict (may be empty on first load).
    pending_user: optional user message shown as an extra bubble before the
        bot has replied; pending_is_voice renders it as a voice note.
    Returns an HTML string (styles are scoped inline below).
    """
    # Default to the bank vertical until a session exists.
    vertical = session.get("vertical", "bank") if session else "bank"
    name = SCENARIOS[vertical]["name"]
    # Two-letter avatar monogram per vertical.
    avatar = {"bank": "PB", "telecom": "PT", "ecommerce": "PD"}[vertical]
    escalated = session.get("escalate_to_human", False) if session else False
    bubbles = []
    history = session.get("history", []) if session else []
    for msg in history:
        if msg["role"] == "user":
            is_voice = msg.get("is_voice", False)
            bubbles.append(_user_bubble(msg["text"], is_voice))
        else:
            # Bot entries carry Hausa text plus an English translation.
            bubbles.append(_bot_bubble(msg.get("text_ha", ""), msg.get("text_en", "")))
    if pending_user:
        bubbles.append(_user_bubble(pending_user, pending_is_voice))
    banner = ('<div class="pw-esc-banner">Session escalated to human agent</div>'
              if escalated else "")
    # NOTE: doubled braces below are literal CSS braces inside the f-string.
    return f"""
<div class="pw-phone">
<div class="pw-ph-header">
<div class="pw-ph-avatar">{avatar}</div>
<div>
<div class="pw-ph-name">{html_lib.escape(name)}</div>
<div class="pw-ph-status">online • voice agent</div>
</div>
</div>
<div class="pw-ph-messages">
{banner}
{"".join(bubbles) if bubbles else '<div style="text-align:center; color:#667781; font-size:12px; padding:40px 0;">Waiting for first message…</div>'}
</div>
</div>
<style>
.pw-phone {{ max-width: 440px; margin: 0 auto; background: #ECE5DD; border-radius: 14px; overflow: hidden; border: 1px solid #ccc; display: flex; flex-direction: column; min-height: 520px; font-family: -apple-system, "Segoe UI", Roboto, sans-serif; }}
.pw-ph-header {{ background: #075E54; color: #fff; padding: 10px 14px; display: flex; align-items: center; gap: 10px; }}
.pw-ph-avatar {{ width: 36px; height: 36px; border-radius: 50%; background: #128C7E; display: flex; align-items: center; justify-content: center; font-weight: 500; font-size: 13px; color: #fff; }}
.pw-ph-name {{ font-size: 14px; font-weight: 500; line-height: 1.2; }}
.pw-ph-status {{ font-size: 11px; color: #D4EDE8; }}
.pw-ph-messages {{ flex: 1; padding: 14px 10px; background: #ECE5DD; background-image: radial-gradient(#D8CFC2 1px, transparent 1px); background-size: 18px 18px; max-height: 460px; overflow-y: auto; min-height: 400px; }}
.pw-b {{ max-width: 80%; padding: 7px 10px 5px; border-radius: 8px; margin-bottom: 6px; font-size: 13.5px; line-height: 1.4; color: #1f2d1f; word-wrap: break-word; }}
.pw-b.user {{ background: #DCF8C6; margin-left: auto; border-bottom-right-radius: 2px; }}
.pw-b.bot {{ background: #fff; margin-right: auto; border-bottom-left-radius: 2px; }}
.pw-b-meta {{ font-size: 10px; color: #667781; margin-top: 3px; text-align: right; }}
.pw-b-trans {{ font-size: 11px; color: #667781; font-style: italic; margin-top: 3px; border-top: 1px solid #E5E5E5; padding-top: 3px; }}
.pw-voice-row {{ display: flex; align-items: center; gap: 8px; }}
.pw-voice-icon {{ width: 22px; height: 22px; border-radius: 50%; background: #128C7E; color: #fff; font-size: 10px; display: flex; align-items: center; justify-content: center; }}
.pw-voice-bars {{ flex: 1; height: 14px; display: flex; align-items: center; gap: 2px; }}
.pw-voice-bars span {{ flex: 1; background: #8D9A9F; border-radius: 1px; }}
.pw-esc-banner {{ background: #FAEEDA; color: #854F0B; font-size: 12px; padding: 8px 12px; border-radius: 8px; margin-bottom: 10px; border: 1px solid #EF9F27; text-align: center; }}
</style>
"""
def _now() -> str:
return time.strftime("%H:%M")
def _user_bubble(text: str, is_voice: bool) -> str:
    """HTML for a right-aligned user bubble; voice notes get a fake waveform."""
    text_safe = html_lib.escape(text)
    if not is_voice:
        # Plain text bubble with a timestamp/read-receipt footer.
        return f'<div class="pw-b user">{text_safe}<div class="pw-b-meta">{_now()} ✓✓</div></div>'
    # Deterministic pseudo-waveform: 20 bars whose heights trace |sin|.
    bars = "".join(
        f'<span style="height:{4 + int(8 * abs(np.sin(k * 0.7)))}px;"></span>'
        for k in range(20)
    )
    return f'''<div class="pw-b user">
<div class="pw-voice-row">
<div class="pw-voice-icon">▶</div>
<div class="pw-voice-bars">{bars}</div>
</div>
<div style="font-size:12px; color:#667781; margin-top:3px;">"{text_safe}"</div>
<div class="pw-b-meta">{_now()} ✓✓</div>
</div>'''
def _bot_bubble(text_ha: str, text_en: str) -> str:
    """HTML for a left-aligned bot bubble: Hausa text over an italic English line."""
    parts = [
        '<div class="pw-b bot">',
        f"<div>{html_lib.escape(text_ha)}</div>",
        f'<div class="pw-b-trans">{html_lib.escape(text_en)}</div>',
        f'<div class="pw-b-meta">{_now()} ✓✓</div>',
        "</div>",
    ]
    return "\n".join(parts)
def render_trace(trace: list) -> str:
    """Render per-stage latency rows (plus optional detail lines) as HTML.

    `trace` is a list of {"stage", "ms", optional "detail"} dicts; an empty
    trace produces a placeholder hint.
    """
    if not trace:
        return '<div style="color:#888; font-size:13px;">Send a message to see the pipeline trace.</div>'
    pieces = []
    for entry in trace:
        stage_safe = html_lib.escape(entry["stage"])
        pieces.append(
            f'<div style="display:flex; justify-content:space-between; padding:5px 0; border-bottom:1px solid #eee;"><span style="color:#5f5e5a;">{stage_safe}</span><span style="color:#0C447C; font-weight:500;">{entry["ms"]}ms</span></div>'
        )
        detail = entry.get("detail")
        if detail:
            # Detail rows (transcripts, intents, state moves) in monospace.
            pieces.append(f'<div style="font-size:11px; color:#888; padding:0 0 5px; font-family:monospace;">{html_lib.escape(str(detail))}</div>')
    return f'<div style="font-family:monospace; font-size:12px;">{"".join(pieces)}</div>'
def render_metrics(session: dict) -> str:
    """Render session id / turn count / FSM state / filled slots as a 2x2 grid.

    Returns "" for a missing or empty session dict.
    """
    if not session:
        return ""
    sid = session.get("session_id", "—")
    turns = session.get("turn_count", 0)
    fsm_state = session.get("current_state", "greeting")
    slot_map = session.get("slots", {})
    # Em-dash placeholder when no slots have been filled yet.
    slots_html = ", ".join(f"<code>{k}={v}</code>" for k, v in slot_map.items()) or "—"
    return f'''
<div style="display:grid; grid-template-columns:1fr 1fr; gap:8px; font-size:13px;">
<div><div style="color:#888; font-size:11px; text-transform:uppercase;">Session</div><div style="font-family:monospace;">{sid}</div></div>
<div><div style="color:#888; font-size:11px; text-transform:uppercase;">Turn</div><div style="font-weight:500;">{turns}</div></div>
<div><div style="color:#888; font-size:11px; text-transform:uppercase;">State</div><div style="font-family:monospace;">{fsm_state}</div></div>
<div><div style="color:#888; font-size:11px; text-transform:uppercase;">Slots</div><div>{slots_html}</div></div>
</div>'''
# ---------------------------------------------------------------------------
# Gradio event handlers
# ---------------------------------------------------------------------------
def on_vertical_change(vertical: str, synth_greeting: bool = False):
    """Start a fresh session for `vertical` and render its greeting.

    The greeting is only synthesized to audio when `synth_greeting` is True —
    keeps initial page load fast (avoids the MMS-TTS cold start).
    Returns (session_dict, chat_html, trace_html, metrics_html, audio).
    """
    fresh = DialogueState(session_id="sess_" + uuid.uuid4().hex[:8], vertical=vertical)
    greeting = get_prompt(vertical, "greeting")
    fresh.history.append({"role": "bot", "text_ha": greeting["ha"], "text_en": greeting["en"]})
    session = fresh.to_dict()
    audio = None
    if synth_greeting:
        try:
            audio = synthesize_hausa(greeting["ha"])
        except Exception as e:
            # Best-effort: a TTS failure should not block the reset itself.
            print(f"TTS failed on greeting: {e}")
    return session, render_whatsapp(session), render_trace([]), render_metrics(session), audio
def on_text_submit(text: str, session: dict):
    """Handle a typed message: run one turn, re-render everything, clear input.

    Returns (session, chat_html, trace_html, metrics_html, audio, textbox_value).
    """
    if not text or not text.strip():
        # Blank input: leave the session untouched and keep the textbox empty.
        return session, render_whatsapp(session), render_trace([]), render_metrics(session), None, ""
    _, new_session, trace, audio = run_turn(text, session, [], asr_ms=0)
    return (new_session, render_whatsapp(new_session), render_trace(trace),
            render_metrics(new_session), audio, "")
def on_audio_submit(audio_data, session: dict):
    """Handle recorded/uploaded audio: ASR → one dialogue turn → re-render.

    Returns (session, chat_html, trace_html, metrics_html, audio).
    """
    if audio_data is None:
        return session, render_whatsapp(session), render_trace([]), render_metrics(session), None
    asr_start = time.time()
    try:
        text = transcribe_hausa(audio_data)
    except Exception as e:
        # Surface the failure in the trace panel instead of crashing the UI.
        print(f"ASR failed: {e}")
        error_trace = [{"stage": "asr error", "ms": 0, "detail": str(e)}]
        return session, render_whatsapp(session), render_trace(error_trace), render_metrics(session), None
    asr_ms = int((time.time() - asr_start) * 1000)
    if not text:
        empty_trace = [{"stage": "asr", "ms": asr_ms, "detail": "(no speech detected)"}]
        return session, render_whatsapp(session), render_trace(empty_trace), render_metrics(session), None
    _, new_session, trace, audio = run_turn(text, session, [], asr_ms=asr_ms)
    # Flag the newest user history entry as voice so the UI renders a
    # waveform bubble instead of plain text.
    for entry in reversed(new_session.get("history", [])):
        if entry["role"] == "user":
            entry["is_voice"] = True
            break
    return new_session, render_whatsapp(new_session), render_trace(trace), render_metrics(new_session), audio
def on_reset(session: dict):
    """Restart the conversation, keeping the currently selected vertical."""
    current_vertical = session.get("vertical", "bank") if session else "bank"
    return on_vertical_change(current_vertical)
def on_escalate(session: dict):
    """Simulate the user asking for a human agent (Hausa: "Ina son wakili mutum")."""
    return on_text_submit("Ina son wakili mutum", session)
# ---------------------------------------------------------------------------
# Preset phrases for quick-click demo
# ---------------------------------------------------------------------------
# One list of Hausa quick-reply phrases per vertical, mixing intent phrases
# with slot values (PINs, names, amounts, days).
# NOTE(review): only the "bank" list is wired to buttons below; phrase
# meanings are assumed from context — confirm with a Hausa speaker.
PRESETS = {
    "bank": ["duba ma'auni", "toshe kati", "canjin kuɗi", "1234", "Aisha", "dubu biyar", "i"],
    "telecom": ["saya airtime", "saya bundle", "korafi", "1000", "rana", "Intanet bai aiki"],
    "ecommerce": ["bincika oda", "sake tsara", "mayar da kaya", "10234", "jumma'a", "Ya lalace"],
}
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Light layout polish on top of Gradio's defaults; the WhatsApp mockup ships
# its own scoped <style> block inside render_whatsapp().
CUSTOM_CSS = """
.gradio-container { max-width: 1200px !important; }
#vertical-selector { background: #fff; border-radius: 10px; padding: 12px; }
#whatsapp-html { background: #f5f4ef; border-radius: 12px; padding: 20px; }
#trace-box, #metrics-box { background: #fff; border-radius: 10px; padding: 12px; border: 1px solid #e5e5e5; }
h1 { font-size: 22px !important; font-weight: 500 !important; }
.header-sub { color: #5f5e5a; font-size: 14px; margin-top: -8px; margin-bottom: 16px; }
"""
with gr.Blocks(css=CUSTOM_CSS, title="PlotWeaver Voice Agent") as demo:
    # Page header.
    gr.HTML("""
<h1 style="margin-bottom:4px;">PlotWeaver Voice Agent</h1>
<p class="header-sub">Hausa-first conversational AI for African banks, telecoms, and delivery services. Real Whisper-small ASR and MMS-TTS Hausa running on CPU.</p>
""")
    # Per-browser-session serialized DialogueState dict.
    session_state = gr.State({})
    with gr.Row():
        # Left column: controls + trace
        with gr.Column(scale=1):
            gr.Markdown("### Select vertical")
            vertical_radio = gr.Radio(
                choices=[("PlotWeaver Bank", "bank"),
                         ("PlotWeaver Telecom", "telecom"),
                         ("PlotWeaver Delivery", "ecommerce")],
                value="bank",
                label="",
                elem_id="vertical-selector",
            )
            with gr.Row():
                reset_btn = gr.Button("Reset session", size="sm")
                escalate_btn = gr.Button("Force escalate", size="sm")
            gr.Markdown("### Session metrics")
            metrics_html = gr.HTML(elem_id="metrics-box")
            gr.Markdown("### Pipeline trace (last turn)")
            trace_html = gr.HTML(elem_id="trace-box")
        # Middle column: WhatsApp mockup
        with gr.Column(scale=2):
            whatsapp_html = gr.HTML(elem_id="whatsapp-html")
            with gr.Row():
                text_input = gr.Textbox(
                    placeholder="Type in Hausa… e.g. 'duba ma'auni'",
                    label="",
                    scale=4,
                    container=False,
                )
                send_btn = gr.Button("Send", scale=1, variant="primary")
            gr.Markdown("**Or speak / upload audio in Hausa:**")
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="numpy",
                label="Record or upload a Hausa audio file (.wav, .mp3, .ogg)",
                show_download_button=False,
            )
            with gr.Row():
                transcribe_btn = gr.Button("Transcribe & send", variant="secondary", size="sm")
                clear_audio_btn = gr.Button("Clear", size="sm")
            # Player for the bot's synthesized Hausa reply; autoplays on update.
            bot_audio = gr.Audio(
                label="Bot response (Hausa TTS)",
                autoplay=True,
                interactive=False,
            )
    # Preset quick-clicks
    gr.Markdown("### Quick phrases (Hausa)")
    preset_btns = []
    with gr.Row():
        for p in PRESETS["bank"]:
            preset_btns.append(gr.Button(p, size="sm"))
    # -----------------------------------------------------------------------
    # Event wiring
    # -----------------------------------------------------------------------
    # Shared output list: most handlers update session, chat HTML, trace,
    # metrics, and the bot-audio player — in this exact order.
    outputs = [session_state, whatsapp_html, trace_html, metrics_html, bot_audio]
    # Initial page render: fresh bank session, no greeting TTS (fast load).
    demo.load(
        fn=lambda: on_vertical_change("bank"),
        outputs=outputs,
    )
    vertical_radio.change(
        fn=on_vertical_change,
        inputs=[vertical_radio],
        outputs=outputs,
    )
    # Text path: both the Send button and Enter-in-textbox submit; the extra
    # text_input output slot clears the textbox after each turn.
    send_btn.click(
        fn=on_text_submit,
        inputs=[text_input, session_state],
        outputs=outputs + [text_input],
    )
    text_input.submit(
        fn=on_text_submit,
        inputs=[text_input, session_state],
        outputs=outputs + [text_input],
    )
    # Voice path: auto-submit when a mic recording stops, or explicitly via
    # the Transcribe button (covers uploaded files).
    audio_input.stop_recording(
        fn=on_audio_submit,
        inputs=[audio_input, session_state],
        outputs=outputs,
    )
    transcribe_btn.click(
        fn=on_audio_submit,
        inputs=[audio_input, session_state],
        outputs=outputs,
    )
    clear_audio_btn.click(
        fn=lambda: None,
        outputs=[audio_input],
    )
    reset_btn.click(fn=on_reset, inputs=[session_state], outputs=outputs)
    escalate_btn.click(
        fn=on_escalate,
        inputs=[session_state],
        outputs=outputs + [text_input],
    )
    # Preset buttons submit their own text; the phrase is bound as a lambda
    # default argument to avoid the late-binding closure pitfall.
    for btn, phrase in zip(preset_btns, PRESETS["bank"]):
        btn.click(
            fn=lambda s, _phrase=phrase: on_text_submit(_phrase, s),
            inputs=[session_state],
            outputs=outputs + [text_input],
        )
if __name__ == "__main__":
    # show_api=False avoids a known gradio_client JSON-schema bug on
    # certain Gradio/Python 3.13 combinations where get_api_info() crashes
    # with "TypeError: argument of type 'bool' is not iterable".
    # We don't need the /?view=api endpoint for this demo anyway.
    # 0.0.0.0:7860 is the standard bind address/port for HuggingFace Spaces.
    demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)