Spaces:
Sleeping
Sleeping
| """ | |
| PlotWeaver Voice Agent — HuggingFace Space | |
| ============================================ | |
| Gradio app demonstrating a Hausa-first conversational AI for | |
| African banks, telecoms, and delivery services. | |
| Pipeline: ASR (Whisper-small) → NLU (rule-based) → Dialogue FSM → | |
| TTS (facebook/mms-tts-hau). | |
| Runs on CPU. First turn triggers model download (~500MB), subsequent turns | |
| are ~2-4s end-to-end. | |
| """ | |
| from __future__ import annotations | |
| # --------------------------------------------------------------------------- | |
| # Monkey-patch for a known gradio_client bug on Python 3.13 + gradio 4.44.1: | |
| # gradio_client/utils.py:get_type() does `"const" in schema` where schema is | |
| # sometimes a bool (False), triggering: | |
| # TypeError: argument of type 'bool' is not iterable | |
| # See: https://github.com/gradio-app/gradio/issues/11722 | |
| # We patch the two affected functions to handle bool schemas defensively. | |
| # This MUST run before `import gradio`. | |
| # --------------------------------------------------------------------------- | |
def _patch_gradio_client_schema_bug():
    """Make gradio_client's schema helpers tolerate boolean JSON schemas.

    Wraps ``gradio_client.utils.get_type`` and
    ``gradio_client.utils._json_schema_to_python_type`` so that a ``bool``
    schema yields ``"Any"`` instead of being handed to the original
    implementation. Any exception raised while patching is printed and
    otherwise ignored.
    """
    try:
        from gradio_client import utils as client_utils

        original_get_type = client_utils.get_type
        original_to_python_type = client_utils._json_schema_to_python_type

        def tolerant_get_type(schema):
            # A bare True/False schema carries no type information.
            return "Any" if isinstance(schema, bool) else original_get_type(schema)

        def tolerant_to_python_type(schema, defs=None):
            return (
                "Any"
                if isinstance(schema, bool)
                else original_to_python_type(schema, defs)
            )

        client_utils.get_type = tolerant_get_type
        client_utils._json_schema_to_python_type = tolerant_to_python_type
    except Exception as patch_error:
        # If the patch fails, we fall back to show_api=False in launch()
        print(f"[plotweaver] gradio_client patch failed: {patch_error}")


# Applied at import time so the wrappers are in place before gradio loads.
_patch_gradio_client_schema_bug()
| import time | |
| import uuid | |
| import html as html_lib | |
| from typing import Optional | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from transformers import ( | |
| VitsModel, AutoTokenizer, | |
| WhisperProcessor, WhisperForConditionalGeneration, | |
| ) | |
| from dialogue import ( | |
| DialogueState, SCENARIOS, | |
| get_prompt, get_expected_slot, transition, | |
| ) | |
| from nlu import parse as nlu_parse | |
| # --------------------------------------------------------------------------- | |
| # Model loading (lazy, cached) | |
| # --------------------------------------------------------------------------- | |
# Module-level cache slots; load_asr() / load_tts() populate them on first use.
_asr_model = _asr_processor = None
_tts_model = _tts_tokenizer = None
def load_asr():
    """Return the cached ``(model, processor)`` pair for openai/whisper-small.

    The first call builds both objects with ``from_pretrained`` and switches
    the model to eval mode; later calls reuse the module-level instances.
    """
    global _asr_model, _asr_processor
    if _asr_model is not None:
        return _asr_model, _asr_processor

    checkpoint = "openai/whisper-small"
    print("Loading Whisper-small…")
    _asr_processor = WhisperProcessor.from_pretrained(checkpoint)
    _asr_model = WhisperForConditionalGeneration.from_pretrained(checkpoint)
    _asr_model.eval()
    print("Whisper-small ready.")
    return _asr_model, _asr_processor
def load_tts():
    """Return the cached ``(model, tokenizer)`` pair for facebook/mms-tts-hau.

    The first call builds both objects with ``from_pretrained`` and switches
    the model to eval mode; later calls reuse the module-level instances.
    """
    global _tts_model, _tts_tokenizer
    if _tts_model is not None:
        return _tts_model, _tts_tokenizer

    checkpoint = "facebook/mms-tts-hau"
    print("Loading MMS-TTS Hausa…")
    _tts_model = VitsModel.from_pretrained(checkpoint)
    _tts_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    _tts_model.eval()
    print("MMS-TTS Hausa ready.")
    return _tts_model, _tts_tokenizer
def transcribe_hausa(audio_tuple) -> str:
    """Transcribe a Gradio audio clip to Hausa text with Whisper-small.

    Args:
        audio_tuple: ``(sample_rate, np.ndarray)`` from Gradio, or ``None``.

    Returns:
        The stripped transcription, or ``""`` when no audio was supplied.
    """
    if audio_tuple is None:
        return ""
    sample_rate, audio_array = audio_tuple
    if audio_array is None or len(audio_array) == 0:
        return ""

    # Convert to float32. Only integer PCM is rescaled by its dtype's maximum:
    # np.iinfo() raises ValueError for floating dtypes, so float input is
    # simply cast (assumes float samples are already normalised — TODO confirm).
    if np.issubdtype(audio_array.dtype, np.integer):
        audio_array = audio_array.astype(np.float32) / np.iinfo(audio_array.dtype).max
    elif audio_array.dtype != np.float32:
        audio_array = audio_array.astype(np.float32)

    # Down-mix multi-channel audio by averaging across axis 1.
    if audio_array.ndim > 1:
        audio_array = audio_array.mean(axis=1)

    # Cap at 30s — Whisper-small is trained on 30s chunks; longer audio
    # would need windowing which slows the demo
    max_samples = sample_rate * 30
    if len(audio_array) > max_samples:
        audio_array = audio_array[:max_samples]

    # Resample to 16 kHz
    if sample_rate != 16000:
        import scipy.signal

        num_samples = int(len(audio_array) * 16000 / sample_rate)
        audio_array = scipy.signal.resample(audio_array, num_samples).astype(np.float32)

    model, processor = load_asr()
    inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
    # NOTE(review): newer transformers releases may prefer passing
    # language/task directly to generate() — confirm against the pinned version.
    forced_ids = processor.get_decoder_prompt_ids(language="hausa", task="transcribe")
    with torch.no_grad():
        ids = model.generate(
            inputs.input_features,
            forced_decoder_ids=forced_ids,
            max_new_tokens=128,
        )
    return processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
def synthesize_hausa(text: str) -> Optional[tuple]:
    """Synthesize Hausa speech for ``text`` with the MMS-TTS model.

    Args:
        text: Text to speak. ``None``, empty, or whitespace-only input is
            treated as "nothing to say".

    Returns:
        ``(sample_rate, np.ndarray)`` with float32 samples, or ``None`` when
        there is no text to synthesize.
    """
    # Guard against None as well as blank strings so callers never hit
    # AttributeError on .strip().
    if not text or not text.strip():
        return None
    model, tokenizer = load_tts()
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs).waveform
    audio = out.squeeze().cpu().numpy().astype(np.float32)
    return (model.config.sampling_rate, audio)
| # --------------------------------------------------------------------------- | |
| # Core turn handler | |
| # --------------------------------------------------------------------------- | |
def run_turn(user_text: str, session: dict, trace: list, asr_ms: int = 0) -> tuple:
    """
    Run one dialogue turn: NLU, state transition, prompt lookup, then TTS.

    `session` is the serialized dict stored in gr.State; a falsy value starts
    a fresh "bank" session. `trace` is accepted but not read here.
    Returns (bot_prompt_dict, updated_session, turn_trace, tts_audio).
    """
    def elapsed_ms(started: float) -> int:
        # Whole milliseconds since `started` (a time.time() reading).
        return int((time.time() - started) * 1000)

    state = None
    if session:
        state = DialogueState.from_dict(session)
    if state is None:
        state = DialogueState(session_id="sess_" + uuid.uuid4().hex[:8], vertical="bank")

    turn_trace = []
    if asr_ms:
        turn_trace.append({
            "stage": "asr (whisper-small)",
            "ms": asr_ms,
            "detail": f'→ "{user_text}"',
        })

    # --- NLU ---------------------------------------------------------------
    nlu_started = time.time()
    expected = get_expected_slot(state.vertical, state.current_state)
    intent, entities, nlu_source = nlu_parse(user_text, expected)
    stage_labels = {
        "rule": "nlu (rule-based)",
        "llm": "nlu (qwen2.5-1.5b)",
        "rule_fallback": "nlu (rule + llm fallback)",
    }
    nlu_entry = {
        "stage": stage_labels.get(nlu_source, "nlu"),
        "ms": max(1, elapsed_ms(nlu_started)),
        "detail": f"intent={intent} entities={entities}",
    }
    turn_trace.append(nlu_entry)

    # --- Dialogue manager --------------------------------------------------
    dm_started = time.time()
    prev_state = state.current_state
    state = transition(state, intent, entities)
    dm_entry = {
        "stage": "dialogue_manager",
        "ms": max(1, elapsed_ms(dm_started)),
        "detail": f"{prev_state} → {state.current_state}",
    }
    turn_trace.append(dm_entry)

    # --- Response lookup ---------------------------------------------------
    gen_started = time.time()
    prompt = get_prompt(state.vertical, state.current_state)
    turn_trace.append({"stage": "response_gen", "ms": max(1, elapsed_ms(gen_started))})

    # --- TTS (reported without the 1 ms floor used above) ------------------
    tts_started = time.time()
    audio = synthesize_hausa(prompt["ha"])
    turn_trace.append({"stage": "tts (mms-tts-hau)", "ms": elapsed_ms(tts_started)})

    state.history.extend([
        {"role": "user", "text": user_text},
        {"role": "bot", "text_ha": prompt["ha"], "text_en": prompt["en"]},
    ])
    return prompt, state.to_dict(), turn_trace, audio
| # --------------------------------------------------------------------------- | |
| # WhatsApp-style HTML renderer | |
| # --------------------------------------------------------------------------- | |
def render_whatsapp(session: dict, pending_user: Optional[str] = None,
                    pending_is_voice: bool = False) -> str:
    """Render the conversation as a WhatsApp-style HTML phone mockup.

    Args:
        session: Serialized session dict; ``vertical``, ``escalate_to_human``
            and ``history`` are read from it. A falsy session renders the
            "bank" vertical with no messages.
        pending_user: Optional user text shown as an extra trailing bubble
            after the stored history.
        pending_is_voice: Whether that pending bubble is drawn as a voice note.

    Returns:
        An HTML string containing the phone markup followed by an inline
        ``<style>`` block with the ``pw-*`` CSS rules.
    """
    vertical = session.get("vertical", "bank") if session else "bank"
    name = SCENARIOS[vertical]["name"]
    # Two-letter avatar text per vertical; an unknown vertical raises KeyError.
    avatar = {"bank": "PB", "telecom": "PT", "ecommerce": "PD"}[vertical]
    escalated = session.get("escalate_to_human", False) if session else False
    bubbles = []
    history = session.get("history", []) if session else []
    # Any history entry whose role is not "user" is rendered as a bot bubble.
    for msg in history:
        if msg["role"] == "user":
            is_voice = msg.get("is_voice", False)
            bubbles.append(_user_bubble(msg["text"], is_voice))
        else:
            bubbles.append(_bot_bubble(msg.get("text_ha", ""), msg.get("text_en", "")))
    if pending_user:
        bubbles.append(_user_bubble(pending_user, pending_is_voice))
    banner = ('<div class="pw-esc-banner">Session escalated to human agent</div>'
              if escalated else "")
    # Doubled braces in the CSS below are literal braces inside the f-string.
    return f"""
<div class="pw-phone">
<div class="pw-ph-header">
<div class="pw-ph-avatar">{avatar}</div>
<div>
<div class="pw-ph-name">{html_lib.escape(name)}</div>
<div class="pw-ph-status">online • voice agent</div>
</div>
</div>
<div class="pw-ph-messages">
{banner}
{"".join(bubbles) if bubbles else '<div style="text-align:center; color:#667781; font-size:12px; padding:40px 0;">Waiting for first message…</div>'}
</div>
</div>
<style>
.pw-phone {{ max-width: 440px; margin: 0 auto; background: #ECE5DD; border-radius: 14px; overflow: hidden; border: 1px solid #ccc; display: flex; flex-direction: column; min-height: 520px; font-family: -apple-system, "Segoe UI", Roboto, sans-serif; }}
.pw-ph-header {{ background: #075E54; color: #fff; padding: 10px 14px; display: flex; align-items: center; gap: 10px; }}
.pw-ph-avatar {{ width: 36px; height: 36px; border-radius: 50%; background: #128C7E; display: flex; align-items: center; justify-content: center; font-weight: 500; font-size: 13px; color: #fff; }}
.pw-ph-name {{ font-size: 14px; font-weight: 500; line-height: 1.2; }}
.pw-ph-status {{ font-size: 11px; color: #D4EDE8; }}
.pw-ph-messages {{ flex: 1; padding: 14px 10px; background: #ECE5DD; background-image: radial-gradient(#D8CFC2 1px, transparent 1px); background-size: 18px 18px; max-height: 460px; overflow-y: auto; min-height: 400px; }}
.pw-b {{ max-width: 80%; padding: 7px 10px 5px; border-radius: 8px; margin-bottom: 6px; font-size: 13.5px; line-height: 1.4; color: #1f2d1f; word-wrap: break-word; }}
.pw-b.user {{ background: #DCF8C6; margin-left: auto; border-bottom-right-radius: 2px; }}
.pw-b.bot {{ background: #fff; margin-right: auto; border-bottom-left-radius: 2px; }}
.pw-b-meta {{ font-size: 10px; color: #667781; margin-top: 3px; text-align: right; }}
.pw-b-trans {{ font-size: 11px; color: #667781; font-style: italic; margin-top: 3px; border-top: 1px solid #E5E5E5; padding-top: 3px; }}
.pw-voice-row {{ display: flex; align-items: center; gap: 8px; }}
.pw-voice-icon {{ width: 22px; height: 22px; border-radius: 50%; background: #128C7E; color: #fff; font-size: 10px; display: flex; align-items: center; justify-content: center; }}
.pw-voice-bars {{ flex: 1; height: 14px; display: flex; align-items: center; gap: 2px; }}
.pw-voice-bars span {{ flex: 1; background: #8D9A9F; border-radius: 1px; }}
.pw-esc-banner {{ background: #FAEEDA; color: #854F0B; font-size: 12px; padding: 8px 12px; border-radius: 8px; margin-bottom: 10px; border: 1px solid #EF9F27; text-align: center; }}
</style>
"""
def _now() -> str:
    """Return the current local time formatted as ``HH:MM``."""
    return time.strftime("%H:%M", time.localtime())
def _user_bubble(text: str, is_voice: bool) -> str:
    """Build the HTML for one user message bubble.

    ``text`` is HTML-escaped. With ``is_voice`` set, the bubble shows a play
    icon and a 20-bar pseudo waveform above the quoted text; otherwise the
    text is shown directly. Both variants end with the current time.
    """
    escaped = html_lib.escape(text)
    if not is_voice:
        return f'<div class="pw-b user">{escaped}<div class="pw-b-meta">{_now()} ✓✓</div></div>'

    # Bar heights follow |sin| so the waveform is identical on every render.
    heights = [4 + int(8 * abs(np.sin(idx * 0.7))) for idx in range(20)]
    bars = "".join(f'<span style="height:{height}px;"></span>' for height in heights)
    return f'''<div class="pw-b user">
<div class="pw-voice-row">
<div class="pw-voice-icon">▶</div>
<div class="pw-voice-bars">{bars}</div>
</div>
<div style="font-size:12px; color:#667781; margin-top:3px;">"{escaped}"</div>
<div class="pw-b-meta">{_now()} ✓✓</div>
</div>'''
def _bot_bubble(text_ha: str, text_en: str) -> str:
    """Build the HTML for one bot bubble: Hausa text, English gloss, time."""
    hausa, english = html_lib.escape(text_ha), html_lib.escape(text_en)
    stamp = _now()
    return f'''<div class="pw-b bot">
<div>{hausa}</div>
<div class="pw-b-trans">{english}</div>
<div class="pw-b-meta">{stamp} ✓✓</div>
</div>'''
def render_trace(trace: list) -> str:
    """Render the per-stage pipeline trace as an HTML fragment.

    Each entry needs ``stage`` and ``ms``; a truthy ``detail`` adds a second,
    smaller line. An empty trace yields a placeholder hint instead.
    """
    if not trace:
        return '<div style="color:#888; font-size:13px;">Send a message to see the pipeline trace.</div>'

    parts = []
    for step in trace:
        stage = html_lib.escape(step["stage"])
        parts.append(
            '<div style="display:flex; justify-content:space-between; padding:5px 0; border-bottom:1px solid #eee;">'
            f'<span style="color:#5f5e5a;">{stage}</span>'
            f'<span style="color:#0C447C; font-weight:500;">{step["ms"]}ms</span>'
            '</div>'
        )
        detail = step.get("detail")
        if detail:
            parts.append(
                '<div style="font-size:11px; color:#888; padding:0 0 5px; font-family:monospace;">'
                + html_lib.escape(str(detail))
                + '</div>'
            )
    return '<div style="font-family:monospace; font-size:12px;">' + "".join(parts) + '</div>'
def render_metrics(session: dict) -> str:
    """Render session id, turn count, dialogue state and slots as HTML.

    Args:
        session: Serialized session dict; ``session_id``, ``turn_count``,
            ``current_state`` and ``slots`` are read with defaults.

    Returns:
        A four-cell grid as an HTML string, or ``""`` for a falsy session.
    """
    if not session:
        return ""

    def esc(value) -> str:
        # Every value ends up inside markup, so escape it. Slot values in
        # particular may originate from what the user typed or said.
        return html_lib.escape(str(value))

    sid = esc(session.get("session_id", "—"))
    turn = esc(session.get("turn_count", 0))
    state = esc(session.get("current_state", "greeting"))
    slots = session.get("slots", {})
    slots_html = ", ".join(f"<code>{esc(k)}={esc(v)}</code>" for k, v in slots.items()) or "—"
    return f'''
<div style="display:grid; grid-template-columns:1fr 1fr; gap:8px; font-size:13px;">
<div><div style="color:#888; font-size:11px; text-transform:uppercase;">Session</div><div style="font-family:monospace;">{sid}</div></div>
<div><div style="color:#888; font-size:11px; text-transform:uppercase;">Turn</div><div style="font-weight:500;">{turn}</div></div>
<div><div style="color:#888; font-size:11px; text-transform:uppercase;">State</div><div style="font-family:monospace;">{state}</div></div>
<div><div style="color:#888; font-size:11px; text-transform:uppercase;">Slots</div><div>{slots_html}</div></div>
</div>'''
| # --------------------------------------------------------------------------- | |
| # Gradio event handlers | |
| # --------------------------------------------------------------------------- | |
def on_vertical_change(vertical: str, synth_greeting: bool = False):
    """Start a fresh session for ``vertical`` and render its initial views.

    The greeting prompt is stored as the first bot message. It is only spoken
    when ``synth_greeting`` is true, which keeps the initial page load fast
    (avoids the MMS-TTS cold start); a TTS failure is printed and ignored.

    Returns (session, whatsapp_html, trace_html, metrics_html, audio).
    """
    state = DialogueState(session_id="sess_" + uuid.uuid4().hex[:8], vertical=vertical)
    greeting = get_prompt(vertical, "greeting")
    state.history.append(
        {"role": "bot", "text_ha": greeting["ha"], "text_en": greeting["en"]}
    )
    session = state.to_dict()

    greeting_audio = None
    if synth_greeting:
        try:
            greeting_audio = synthesize_hausa(greeting["ha"])
        except Exception as tts_error:
            print(f"TTS failed on greeting: {tts_error}")

    chat_html = render_whatsapp(session)
    trace_html = render_trace([])
    metrics_html = render_metrics(session)
    return session, chat_html, trace_html, metrics_html, greeting_audio
def on_text_submit(text: str, session: dict):
    """Handle a typed message and return refreshed views.

    Blank input leaves the session untouched and shows an empty trace.
    Returns (session, whatsapp_html, trace_html, metrics_html, audio, "");
    the trailing empty string clears the text input.
    """
    if text and text.strip():
        _prompt, session, trace, audio = run_turn(text, session, [], asr_ms=0)
    else:
        trace, audio = [], None
    return (
        session,
        render_whatsapp(session),
        render_trace(trace),
        render_metrics(session),
        audio,
        "",  # clear input
    )
def on_audio_submit(audio_data, session: dict):
    """Transcribe recorded or uploaded audio and run it as a user turn.

    When there is no audio, ASR raises, or ASR returns no text, the session is
    returned unchanged with a trace describing what happened. Otherwise the
    turn is executed and the newest user history entry is tagged as voice.

    Returns (session, whatsapp_html, trace_html, metrics_html, audio).
    """
    def keep_session(trace_rows: list) -> tuple:
        # Response used whenever no dialogue turn is executed.
        return (
            session,
            render_whatsapp(session),
            render_trace(trace_rows),
            render_metrics(session),
            None,
        )

    if audio_data is None:
        return keep_session([])

    started = time.time()
    try:
        text = transcribe_hausa(audio_data)
    except Exception as asr_error:
        print(f"ASR failed: {asr_error}")
        return keep_session([{"stage": "asr error", "ms": 0, "detail": str(asr_error)}])
    asr_ms = int((time.time() - started) * 1000)

    if not text:
        return keep_session([{"stage": "asr", "ms": asr_ms, "detail": "(no speech detected)"}])

    _prompt, new_session, trace, audio = run_turn(text, session, [], asr_ms=asr_ms)

    # Tag the last user entry as voice
    for entry in reversed(new_session.get("history") or []):
        if entry["role"] == "user":
            entry["is_voice"] = True
            break

    return (
        new_session,
        render_whatsapp(new_session),
        render_trace(trace),
        render_metrics(new_session),
        audio,
    )
def on_reset(session: dict):
    """Restart the conversation, keeping the session's vertical ("bank" if unset)."""
    vertical = "bank"
    if session:
        vertical = session.get("vertical", "bank")
    return on_vertical_change(vertical)
def on_escalate(session: dict):
    """Submit a fixed Hausa phrase as if the user had typed it.

    Presumably the phrase asks for a human agent — verify against the NLU rules.
    """
    escalation_phrase = "Ina son wakili mutum"
    return on_text_submit(escalation_phrase, session)
| # --------------------------------------------------------------------------- | |
| # Preset phrases for quick-click demo | |
| # --------------------------------------------------------------------------- | |
# Quick-click phrases keyed by vertical id ("bank", "telecom", "ecommerce").
# Each value is the list of utterances offered as preset buttons.
PRESETS = {
    "bank": ["duba ma'auni", "toshe kati", "canjin kuɗi", "1234", "Aisha", "dubu biyar", "i"],
    "telecom": ["saya airtime", "saya bundle", "korafi", "1000", "rana", "Intanet bai aiki"],
    "ecommerce": ["bincika oda", "sake tsara", "mayar da kaya", "10234", "jumma'a", "Ya lalace"],
}
| # --------------------------------------------------------------------------- | |
| # Gradio UI | |
| # --------------------------------------------------------------------------- | |
# Page-level stylesheet handed to gr.Blocks(css=...). The #id selectors match
# the elem_id values assigned to components in the UI definition.
CUSTOM_CSS = """
.gradio-container { max-width: 1200px !important; }
#vertical-selector { background: #fff; border-radius: 10px; padding: 12px; }
#whatsapp-html { background: #f5f4ef; border-radius: 12px; padding: 20px; }
#trace-box, #metrics-box { background: #fff; border-radius: 10px; padding: 12px; border: 1px solid #e5e5e5; }
h1 { font-size: 22px !important; font-weight: 500 !important; }
.header-sub { color: #5f5e5a; font-size: 14px; margin-top: -8px; margin-bottom: 16px; }
"""
# Builds the whole page: header, controls column, chat column, preset buttons,
# and the event wiring that connects components to the on_* handlers.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from the order of the `with` statements — confirm against the deployed app.
with gr.Blocks(css=CUSTOM_CSS, title="PlotWeaver Voice Agent") as demo:
    gr.HTML("""
<h1 style="margin-bottom:4px;">PlotWeaver Voice Agent</h1>
<p class="header-sub">Hausa-first conversational AI for African banks, telecoms, and delivery services. Real Whisper-small ASR and MMS-TTS Hausa running on CPU.</p>
""")
    # Serialized session dict; starts empty and is filled by the load handler.
    session_state = gr.State({})
    with gr.Row():
        # Left column: controls + trace
        with gr.Column(scale=1):
            gr.Markdown("### Select vertical")
            vertical_radio = gr.Radio(
                choices=[("PlotWeaver Bank", "bank"),
                         ("PlotWeaver Telecom", "telecom"),
                         ("PlotWeaver Delivery", "ecommerce")],
                value="bank",
                label="",
                elem_id="vertical-selector",
            )
            with gr.Row():
                reset_btn = gr.Button("Reset session", size="sm")
                escalate_btn = gr.Button("Force escalate", size="sm")
            gr.Markdown("### Session metrics")
            metrics_html = gr.HTML(elem_id="metrics-box")
            gr.Markdown("### Pipeline trace (last turn)")
            trace_html = gr.HTML(elem_id="trace-box")
        # Middle column: WhatsApp mockup
        with gr.Column(scale=2):
            whatsapp_html = gr.HTML(elem_id="whatsapp-html")
            with gr.Row():
                text_input = gr.Textbox(
                    placeholder="Type in Hausa… e.g. 'duba ma'auni'",
                    label="",
                    scale=4,
                    container=False,
                )
                send_btn = gr.Button("Send", scale=1, variant="primary")
            gr.Markdown("**Or speak / upload audio in Hausa:**")
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="numpy",
                label="Record or upload a Hausa audio file (.wav, .mp3, .ogg)",
                show_download_button=False,
            )
            with gr.Row():
                transcribe_btn = gr.Button("Transcribe & send", variant="secondary", size="sm")
                clear_audio_btn = gr.Button("Clear", size="sm")
            bot_audio = gr.Audio(
                label="Bot response (Hausa TTS)",
                autoplay=True,
                interactive=False,
            )
    # Preset quick-clicks
    # Buttons are always built from the "bank" presets, whichever vertical
    # the radio currently selects.
    gr.Markdown("### Quick phrases (Hausa)")
    preset_btns = []
    with gr.Row():
        for p in PRESETS["bank"]:
            preset_btns.append(gr.Button(p, size="sm"))
    # -----------------------------------------------------------------------
    # Event wiring
    # -----------------------------------------------------------------------
    # Order matches the 5-tuples returned by the handlers; text handlers
    # return a sixth value that is routed to text_input.
    outputs = [session_state, whatsapp_html, trace_html, metrics_html, bot_audio]
    # Initial page load starts a "bank" session without synthesizing audio.
    demo.load(
        fn=lambda: on_vertical_change("bank"),
        outputs=outputs,
    )
    vertical_radio.change(
        fn=on_vertical_change,
        inputs=[vertical_radio],
        outputs=outputs,
    )
    send_btn.click(
        fn=on_text_submit,
        inputs=[text_input, session_state],
        outputs=outputs + [text_input],
    )
    text_input.submit(
        fn=on_text_submit,
        inputs=[text_input, session_state],
        outputs=outputs + [text_input],
    )
    # Both finishing a recording and the explicit button run on_audio_submit.
    audio_input.stop_recording(
        fn=on_audio_submit,
        inputs=[audio_input, session_state],
        outputs=outputs,
    )
    transcribe_btn.click(
        fn=on_audio_submit,
        inputs=[audio_input, session_state],
        outputs=outputs,
    )
    # Returning None sets the audio input's value to None.
    clear_audio_btn.click(
        fn=lambda: None,
        outputs=[audio_input],
    )
    reset_btn.click(fn=on_reset, inputs=[session_state], outputs=outputs)
    escalate_btn.click(
        fn=on_escalate,
        inputs=[session_state],
        outputs=outputs + [text_input],
    )
    # Preset buttons submit their own text
    # `_phrase=phrase` binds the phrase per button at definition time,
    # avoiding the late-binding closure pitfall.
    for btn, phrase in zip(preset_btns, PRESETS["bank"]):
        btn.click(
            fn=lambda s, _phrase=phrase: on_text_submit(_phrase, s),
            inputs=[session_state],
            outputs=outputs + [text_input],
        )
# Script entry point: serve the app on all interfaces, port 7860.
if __name__ == "__main__":
    # show_api=False avoids a known gradio_client JSON-schema bug on
    # certain Gradio/Python 3.13 combinations where get_api_info() crashes
    # with "TypeError: argument of type 'bool' is not iterable".
    # We don't need the /?view=api endpoint for this demo anyway.
    demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)