""" PlotWeaver Voice Agent — HuggingFace Space ============================================ Gradio app demonstrating a Hausa-first conversational AI for African banks, telecoms, and delivery services. Pipeline: ASR (Whisper-small) → NLU (rule-based) → Dialogue FSM → TTS (facebook/mms-tts-hau). Runs on CPU. First turn triggers model download (~500MB), subsequent turns are ~2-4s end-to-end. """ from __future__ import annotations # --------------------------------------------------------------------------- # Monkey-patch for a known gradio_client bug on Python 3.13 + gradio 4.44.1: # gradio_client/utils.py:get_type() does `"const" in schema` where schema is # sometimes a bool (False), triggering: # TypeError: argument of type 'bool' is not iterable # See: https://github.com/gradio-app/gradio/issues/11722 # We patch the two affected functions to handle bool schemas defensively. # This MUST run before `import gradio`. # --------------------------------------------------------------------------- def _patch_gradio_client_schema_bug(): try: from gradio_client import utils as _gcu _orig_get_type = _gcu.get_type _orig_json_to_py = _gcu._json_schema_to_python_type def _safe_get_type(schema): if isinstance(schema, bool): return "Any" return _orig_get_type(schema) def _safe_json_to_py(schema, defs=None): if isinstance(schema, bool): return "Any" return _orig_json_to_py(schema, defs) _gcu.get_type = _safe_get_type _gcu._json_schema_to_python_type = _safe_json_to_py except Exception as _e: # If the patch fails, we fall back to show_api=False in launch() print(f"[plotweaver] gradio_client patch failed: {_e}") _patch_gradio_client_schema_bug() import time import uuid import html as html_lib from typing import Optional import gradio as gr import numpy as np import torch from transformers import ( VitsModel, AutoTokenizer, WhisperProcessor, WhisperForConditionalGeneration, ) from dialogue import ( DialogueState, SCENARIOS, get_prompt, get_expected_slot, transition, ) from nlu import 
parse as nlu_parse # --------------------------------------------------------------------------- # Model loading (lazy, cached) # --------------------------------------------------------------------------- _asr_model = None _asr_processor = None _tts_model = None _tts_tokenizer = None def load_asr(): global _asr_model, _asr_processor if _asr_model is None: print("Loading Whisper-small…") _asr_processor = WhisperProcessor.from_pretrained("openai/whisper-small") _asr_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small") _asr_model.eval() print("Whisper-small ready.") return _asr_model, _asr_processor def load_tts(): global _tts_model, _tts_tokenizer if _tts_model is None: print("Loading MMS-TTS Hausa…") _tts_model = VitsModel.from_pretrained("facebook/mms-tts-hau") _tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-hau") _tts_model.eval() print("MMS-TTS Hausa ready.") return _tts_model, _tts_tokenizer def transcribe_hausa(audio_tuple) -> str: """audio_tuple is (sample_rate, np.ndarray) from Gradio.""" if audio_tuple is None: return "" sample_rate, audio_array = audio_tuple if audio_array is None or len(audio_array) == 0: return "" # Convert to float32 mono if audio_array.dtype != np.float32: audio_array = audio_array.astype(np.float32) / np.iinfo(audio_array.dtype).max if audio_array.ndim > 1: audio_array = audio_array.mean(axis=1) # Cap at 30s — Whisper-small is trained on 30s chunks; longer audio # would need windowing which slows the demo max_samples = sample_rate * 30 if len(audio_array) > max_samples: audio_array = audio_array[:max_samples] # Resample to 16 kHz if sample_rate != 16000: import scipy.signal num_samples = int(len(audio_array) * 16000 / sample_rate) audio_array = scipy.signal.resample(audio_array, num_samples).astype(np.float32) model, processor = load_asr() inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt") forced_ids = processor.get_decoder_prompt_ids(language="hausa", 
task="transcribe") with torch.no_grad(): ids = model.generate(inputs.input_features, forced_decoder_ids=forced_ids, max_new_tokens=128) text = processor.batch_decode(ids, skip_special_tokens=True)[0].strip() return text def synthesize_hausa(text: str) -> Optional[tuple]: """Return (sample_rate, np.ndarray) or None.""" if not text.strip(): return None model, tokenizer = load_tts() inputs = tokenizer(text, return_tensors="pt") with torch.no_grad(): out = model(**inputs).waveform audio = out.squeeze().cpu().numpy().astype(np.float32) return (model.config.sampling_rate, audio) # --------------------------------------------------------------------------- # Core turn handler # --------------------------------------------------------------------------- def run_turn(user_text: str, session: dict, trace: list, asr_ms: int = 0) -> tuple: """ Executes one turn. Returns (bot_prompt_dict, updated_session, trace, tts_audio). `session` is a serialized dict stored in gr.State. """ state = DialogueState.from_dict(session) if session else None if state is None: state = DialogueState(session_id="sess_" + uuid.uuid4().hex[:8], vertical="bank") turn_trace = [] if asr_ms: turn_trace.append({"stage": "asr (whisper-small)", "ms": asr_ms, "detail": f'→ "{user_text}"'}) t0 = time.time() expected = get_expected_slot(state.vertical, state.current_state) intent, entities, nlu_source = nlu_parse(user_text, expected) nlu_stage_label = { "rule": "nlu (rule-based)", "llm": "nlu (qwen2.5-1.5b)", "rule_fallback": "nlu (rule + llm fallback)", }.get(nlu_source, "nlu") turn_trace.append({ "stage": nlu_stage_label, "ms": max(1, int((time.time() - t0) * 1000)), "detail": f"intent={intent} entities={entities}", }) t1 = time.time() prev_state = state.current_state state = transition(state, intent, entities) turn_trace.append({ "stage": "dialogue_manager", "ms": max(1, int((time.time() - t1) * 1000)), "detail": f"{prev_state} → {state.current_state}", }) t2 = time.time() prompt = get_prompt(state.vertical, 
state.current_state) turn_trace.append({"stage": "response_gen", "ms": max(1, int((time.time() - t2) * 1000))}) t3 = time.time() audio = synthesize_hausa(prompt["ha"]) turn_trace.append({"stage": "tts (mms-tts-hau)", "ms": int((time.time() - t3) * 1000)}) state.history.append({"role": "user", "text": user_text}) state.history.append({"role": "bot", "text_ha": prompt["ha"], "text_en": prompt["en"]}) return prompt, state.to_dict(), turn_trace, audio # --------------------------------------------------------------------------- # WhatsApp-style HTML renderer # --------------------------------------------------------------------------- def render_whatsapp(session: dict, pending_user: Optional[str] = None, pending_is_voice: bool = False) -> str: vertical = session.get("vertical", "bank") if session else "bank" name = SCENARIOS[vertical]["name"] avatar = {"bank": "PB", "telecom": "PT", "ecommerce": "PD"}[vertical] escalated = session.get("escalate_to_human", False) if session else False bubbles = [] history = session.get("history", []) if session else [] for msg in history: if msg["role"] == "user": is_voice = msg.get("is_voice", False) bubbles.append(_user_bubble(msg["text"], is_voice)) else: bubbles.append(_bot_bubble(msg.get("text_ha", ""), msg.get("text_en", ""))) if pending_user: bubbles.append(_user_bubble(pending_user, pending_is_voice)) banner = ('
' if escalated else "") return f"""{k}={v}" for k, v in slots.items()) or "—"
return f'''
Hausa-first conversational AI for African banks, telecoms, and delivery services. Real Whisper-small ASR and MMS-TTS Hausa running on CPU.
    # Per-browser-session serialized DialogueState (a plain dict, see run_turn).
    session_state = gr.State({})

    with gr.Row():
        # Left column: controls + trace
        with gr.Column(scale=1):
            gr.Markdown("### Select vertical")
            vertical_radio = gr.Radio(
                choices=[("PlotWeaver Bank", "bank"),
                         ("PlotWeaver Telecom", "telecom"),
                         ("PlotWeaver Delivery", "ecommerce")],
                value="bank",
                label="",
                elem_id="vertical-selector",
            )
            with gr.Row():
                reset_btn = gr.Button("Reset session", size="sm")
                escalate_btn = gr.Button("Force escalate", size="sm")
            gr.Markdown("### Session metrics")
            metrics_html = gr.HTML(elem_id="metrics-box")
            gr.Markdown("### Pipeline trace (last turn)")
            trace_html = gr.HTML(elem_id="trace-box")

        # Middle column: WhatsApp mockup
        with gr.Column(scale=2):
            whatsapp_html = gr.HTML(elem_id="whatsapp-html")
            with gr.Row():
                text_input = gr.Textbox(
                    placeholder="Type in Hausa… e.g. 'duba ma'auni'",
                    label="",
                    scale=4,
                    container=False,
                )
                send_btn = gr.Button("Send", scale=1, variant="primary")
            gr.Markdown("**Or speak / upload audio in Hausa:**")
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="numpy",
                label="Record or upload a Hausa audio file (.wav, .mp3, .ogg)",
                show_download_button=False,
            )
            with gr.Row():
                transcribe_btn = gr.Button("Transcribe & send", variant="secondary", size="sm")
                clear_audio_btn = gr.Button("Clear", size="sm")
            # autoplay so the Hausa TTS reply is heard without a click.
            bot_audio = gr.Audio(
                label="Bot response (Hausa TTS)",
                autoplay=True,
                interactive=False,
            )

            # Preset quick-clicks
            gr.Markdown("### Quick phrases (Hausa)")
            preset_btns = []
            with gr.Row():
                for p in PRESETS["bank"]:
                    preset_btns.append(gr.Button(p, size="sm"))

    # -----------------------------------------------------------------------
    # Event wiring
    # -----------------------------------------------------------------------
    # Every handler refreshes the same five outputs.
    outputs = [session_state, whatsapp_html, trace_html, metrics_html, bot_audio]

    # Initial render: populate the UI with the default "bank" vertical.
    demo.load(
        fn=lambda: on_vertical_change("bank"),
        outputs=outputs,
    )

    vertical_radio.change(
        fn=on_vertical_change,
        inputs=[vertical_radio],
        outputs=outputs,
    )

    # Text path: button click and textbox Enter share one handler; both also
    # clear the textbox (the extra text_input output).
    send_btn.click(
        fn=on_text_submit,
        inputs=[text_input, session_state],
        outputs=outputs + [text_input],
    )
    text_input.submit(
        fn=on_text_submit,
        inputs=[text_input, session_state],
        outputs=outputs + [text_input],
    )

    # Voice path: auto-submit when the mic recording stops, or manually via
    # the "Transcribe & send" button for uploaded files.
    audio_input.stop_recording(
        fn=on_audio_submit,
        inputs=[audio_input, session_state],
        outputs=outputs,
    )
    transcribe_btn.click(
        fn=on_audio_submit,
        inputs=[audio_input, session_state],
        outputs=outputs,
    )
    # Returning None clears the audio widget.
    clear_audio_btn.click(
        fn=lambda: None,
        outputs=[audio_input],
    )

    reset_btn.click(fn=on_reset, inputs=[session_state], outputs=outputs)
    escalate_btn.click(
        fn=on_escalate,
        inputs=[session_state],
        outputs=outputs + [text_input],
    )

    # Preset buttons submit their own text
    # (`_phrase=phrase` binds the loop variable eagerly — avoids the
    # late-binding-closure pitfall where every button would send the last phrase).
    for btn, phrase in zip(preset_btns, PRESETS["bank"]):
        btn.click(
            fn=lambda s, _phrase=phrase: on_text_submit(_phrase, s),
            inputs=[session_state],
            outputs=outputs + [text_input],
        )


if __name__ == "__main__":
    # show_api=False avoids a known gradio_client JSON-schema bug on
    # certain Gradio/Python 3.13 combinations where get_api_info() crashes
    # with "TypeError: argument of type 'bool' is not iterable".
    # We don't need the /?view=api endpoint for this demo anyway.
    demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)