Spaces:
Sleeping
Sleeping
| """ | |
| PlotWeaver Voice Agent — HuggingFace Space (Gradio 6 + Python 3.13) | |
| ==================================================================== | |
| Hausa-first conversational AI for African banks, telecoms, and delivery. | |
| Pipeline (all real, running on CPU): | |
| ASR (openai/whisper-small) | |
| → NLU (rule-based + Qwen2.5-1.5B-Instruct fallback, see nlu.py) | |
| → Dialogue FSM (see dialogue.py) | |
| → TTS (facebook/mms-tts-hau) | |
| First turn: ~30-60s model downloads. Subsequent turns: ~5-10s on CPU. | |
| """ | |
| from __future__ import annotations | |
| import time | |
| import uuid | |
| import html as html_lib | |
| from typing import Optional | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from transformers import ( | |
| VitsModel, AutoTokenizer, | |
| WhisperProcessor, WhisperForConditionalGeneration, | |
| ) | |
| from dialogue import ( | |
| DialogueState, SCENARIOS, | |
| get_prompt, get_expected_slot, transition, | |
| ) | |
| from nlu import parse as nlu_parse | |
| # --------------------------------------------------------------------------- | |
| # Model loading (lazy, cached) | |
| # --------------------------------------------------------------------------- | |
# Lazily populated model caches: load_asr() fills the Whisper pair and
# load_tts() fills the MMS-TTS pair the first time each is needed.
_asr_model = _asr_processor = None
_tts_model = _tts_tokenizer = None
def load_asr():
    """Return the cached ``(model, processor)`` pair for Whisper-small, loading it on first use.

    Both objects are loaded into locals first. They are published to the
    module-level cache only after both downloads and ``eval()`` have
    succeeded, and the cache counts as warm only when BOTH halves are set.
    A failed or interrupted load therefore triggers a reload on the next
    call; it is never handed back as a pair containing ``None``.
    """
    global _asr_model, _asr_processor
    if _asr_model is None or _asr_processor is None:
        print("Loading Whisper-small…")
        processor = WhisperProcessor.from_pretrained("openai/whisper-small")
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
        model.eval()  # inference only: put the module in eval mode before anyone can use it
        _asr_model, _asr_processor = model, processor
        print("Whisper-small ready.")
    return _asr_model, _asr_processor
def load_tts():
    """Return the cached ``(model, tokenizer)`` pair for MMS-TTS Hausa, loading it on first use.

    The model and tokenizer are loaded into locals. They are published to
    the module-level cache only after both downloads and ``eval()`` have
    succeeded. Storing the model before the tokenizer had loaded would let
    a tokenizer failure (or a caller arriving mid-load) see a warm model
    with a ``None`` tokenizer forever. Checking both halves means a
    partially filled cache is reloaded rather than returned.
    """
    global _tts_model, _tts_tokenizer
    if _tts_model is None or _tts_tokenizer is None:
        print("Loading MMS-TTS Hausa…")
        model = VitsModel.from_pretrained("facebook/mms-tts-hau")
        tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-hau")
        model.eval()  # inference only
        _tts_model, _tts_tokenizer = model, tokenizer
        print("MMS-TTS Hausa ready.")
    return _tts_model, _tts_tokenizer
def _prepare_whisper_audio(
    sample_rate: int,
    audio_array: np.ndarray,
    target_rate: int = 16000,
    max_seconds: int = 30,
) -> np.ndarray:
    """Turn a raw recording into mono float32 audio at *target_rate*, at most *max_seconds* long.

    Integer PCM is divided by its dtype's maximum. Floating-point input only
    has its dtype changed. Multi-channel input is averaged down to mono.
    Returns an empty float32 array when resampling would leave no samples.
    """
    audio = np.asarray(audio_array)
    if np.issubdtype(audio.dtype, np.integer):
        # Integer PCM (e.g. int16 microphone capture): scale into roughly [-1, 1].
        audio = audio.astype(np.float32) / np.iinfo(audio.dtype).max
    else:
        # Already floating point (float64, float16, ...). np.iinfo() raises
        # for float dtypes, so only normalise the dtype here.
        audio = audio.astype(np.float32)
    if audio.ndim > 1:
        # Downmix along axis 1; assumes (samples, channels) layout — TODO confirm per source.
        audio = audio.mean(axis=1)
    # Cap at 30s (Whisper training chunk size)
    max_samples = sample_rate * max_seconds
    if len(audio) > max_samples:
        audio = audio[:max_samples]
    if sample_rate != target_rate:
        import scipy.signal

        num_samples = int(len(audio) * target_rate / sample_rate)
        if num_samples < 1:
            # Clip is too short to survive resampling: nothing to transcribe.
            return np.zeros(0, dtype=np.float32)
        audio = scipy.signal.resample(audio, num_samples).astype(np.float32)
    return audio


def transcribe_hausa(audio_tuple) -> str:
    """Transcribe a ``(sample_rate, samples)`` recording to Hausa text with Whisper-small.

    Returns "" when there is nothing to transcribe: ``None`` input, a
    ``None`` or empty sample array, or a clip that is too short to survive
    resampling to 16 kHz. Errors from model loading or generation propagate
    to the caller.
    """
    if audio_tuple is None:
        return ""
    sample_rate, audio_array = audio_tuple
    if audio_array is None or len(audio_array) == 0:
        return ""
    audio = _prepare_whisper_audio(sample_rate, audio_array)
    if audio.size == 0:
        return ""
    model, processor = load_asr()
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        # Select language/task via generate() arguments, the documented way in
        # current transformers, instead of hand-built forced_decoder_ids.
        ids = model.generate(
            inputs.input_features,
            language="hausa",
            task="transcribe",
            max_new_tokens=128,
        )
    return processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
def synthesize_hausa(text: str) -> Optional[tuple]:
    """Speak *text* with MMS-TTS Hausa.

    Returns ``(sampling_rate, waveform)``, where the waveform is a float32
    numpy array. Returns ``None`` for blank or whitespace-only text.
    """
    if text.strip() == "":
        return None
    tts_model, tts_tokenizer = load_tts()
    encoded = tts_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        waveform = tts_model(**encoded).waveform
    samples = waveform.squeeze().cpu().numpy().astype(np.float32)
    rate = tts_model.config.sampling_rate
    return (rate, samples)
| # --------------------------------------------------------------------------- | |
| # WhatsApp-style HTML rendering | |
| # --------------------------------------------------------------------------- | |
| def _now() -> str: | |
| return time.strftime("%H:%M") | |
def _user_bubble(text: str, is_voice: bool) -> str:
    """Return the HTML for one outgoing (user) chat bubble.

    *text* is HTML-escaped. A voice message is drawn as a play icon plus a
    fixed 20-bar pseudo-waveform, with the transcript quoted underneath.
    The bar heights come from |sin|, not from the actual audio.
    """
    escaped = html_lib.escape(text)
    if not is_voice:
        return f'<div class="pw-b user">{escaped}<div class="pw-b-meta">{_now()} ✓✓</div></div>'

    spans = []
    for idx in range(20):
        height = 4 + int(8 * abs(np.sin(idx * 0.7)))
        spans.append(f'<span style="height:{height}px;"></span>')
    bars = "".join(spans)
    return (
        '<div class="pw-b user">\n'
        '<div class="pw-voice-row">\n'
        '<div class="pw-voice-icon">▶</div>\n'
        f'<div class="pw-voice-bars">{bars}</div>\n'
        '</div>\n'
        f'<div style="font-size:12px;color:#667781;margin-top:3px;">"{escaped}"</div>\n'
        f'<div class="pw-b-meta">{_now()} ✓✓</div>\n'
        '</div>'
    )
def _bot_bubble(text_ha: str, text_en: str) -> str:
    """Return the HTML for one incoming (bot) bubble.

    The bubble shows the Hausa line, the English translation beneath it, and
    a timestamp. Both texts are HTML-escaped.
    """
    lines = [
        '<div class="pw-b bot">',
        f"<div>{html_lib.escape(text_ha)}</div>",
        f'<div class="pw-b-trans">{html_lib.escape(text_en)}</div>',
        f'<div class="pw-b-meta">{_now()} ✓✓</div>',
        "</div>",
    ]
    return "\n".join(lines)
def render_whatsapp(session: dict) -> str:
    """Render the session transcript as a self-contained WhatsApp-style phone mock-up.

    A falsy *session* renders as the "bank" vertical with no messages. The
    markup embeds its own ``<style>`` block. It shows an escalation banner
    when ``escalate_to_human`` is truthy, and a placeholder when the history
    is empty. Raises KeyError when the session's vertical is missing from
    SCENARIOS or from the avatar table.
    """
    if session:
        vertical = session.get("vertical", "bank")
        escalated = session.get("escalate_to_human", False)
        history = session.get("history", [])
    else:
        vertical, escalated, history = "bank", False, []

    avatars = {"bank": "PB", "telecom": "PT", "ecommerce": "PD"}
    display_name = html_lib.escape(SCENARIOS[vertical]["name"])
    avatar = avatars[vertical]

    bubble_html = [
        _user_bubble(msg["text"], msg.get("is_voice", False))
        if msg["role"] == "user"
        else _bot_bubble(msg.get("text_ha", ""), msg.get("text_en", ""))
        for msg in history
    ]

    banner = ""
    if escalated:
        banner = '<div class="pw-esc-banner">Session escalated to human agent</div>'

    if bubble_html:
        body = "".join(bubble_html)
    else:
        body = '<div style="text-align:center;color:#667781;font-size:12px;padding:40px 0;">Send a message to begin…</div>'

    return f"""
<div class="pw-phone">
<div class="pw-ph-header">
<div class="pw-ph-avatar">{avatar}</div>
<div>
<div class="pw-ph-name">{display_name}</div>
<div class="pw-ph-status">online • voice agent</div>
</div>
</div>
<div class="pw-ph-messages">
{banner}
{body}
</div>
</div>
<style>
.pw-phone {{ max-width: 480px; margin: 0 auto; background: #ECE5DD; border-radius: 14px; overflow: hidden; border: 1px solid #ccc; display: flex; flex-direction: column; min-height: 540px; font-family: -apple-system, "Segoe UI", Roboto, sans-serif; }}
.pw-ph-header {{ background: #075E54; color: #fff; padding: 10px 14px; display: flex; align-items: center; gap: 10px; }}
.pw-ph-avatar {{ width: 36px; height: 36px; border-radius: 50%; background: #128C7E; display: flex; align-items: center; justify-content: center; font-weight: 500; font-size: 13px; color: #fff; }}
.pw-ph-name {{ font-size: 14px; font-weight: 500; line-height: 1.2; }}
.pw-ph-status {{ font-size: 11px; color: #D4EDE8; }}
.pw-ph-messages {{ flex: 1; padding: 14px 10px; background: #ECE5DD; background-image: radial-gradient(#D8CFC2 1px, transparent 1px); background-size: 18px 18px; max-height: 480px; overflow-y: auto; min-height: 420px; }}
.pw-b {{ max-width: 80%; padding: 7px 10px 5px; border-radius: 8px; margin-bottom: 6px; font-size: 13.5px; line-height: 1.4; color: #1f2d1f; word-wrap: break-word; }}
.pw-b.user {{ background: #DCF8C6; margin-left: auto; border-bottom-right-radius: 2px; }}
.pw-b.bot {{ background: #fff; margin-right: auto; border-bottom-left-radius: 2px; }}
.pw-b-meta {{ font-size: 10px; color: #667781; margin-top: 3px; text-align: right; }}
.pw-b-trans {{ font-size: 11px; color: #667781; font-style: italic; margin-top: 3px; border-top: 1px solid #E5E5E5; padding-top: 3px; }}
.pw-voice-row {{ display: flex; align-items: center; gap: 8px; }}
.pw-voice-icon {{ width: 22px; height: 22px; border-radius: 50%; background: #128C7E; color: #fff; font-size: 10px; display: flex; align-items: center; justify-content: center; }}
.pw-voice-bars {{ flex: 1; height: 14px; display: flex; align-items: center; gap: 2px; }}
.pw-voice-bars span {{ flex: 1; background: #8D9A9F; border-radius: 1px; }}
.pw-esc-banner {{ background: #FAEEDA; color: #854F0B; font-size: 12px; padding: 8px 12px; border-radius: 8px; margin-bottom: 10px; border: 1px solid #EF9F27; text-align: center; }}
</style>
"""
| # --------------------------------------------------------------------------- | |
| # Core turn handler | |
| # --------------------------------------------------------------------------- | |
def run_turn(user_text: str, session: dict, is_voice: bool = False):
    """Advance the dialogue by one user utterance.

    The session is rebuilt from *session*, or a fresh "bank" session is
    started when it is empty. The text is run through NLU, the FSM is
    stepped, and the user message and bot reply are appended to the
    history. The bot's Hausa reply is then voiced. A TTS failure is logged,
    not raised.

    Returns (updated_session_dict, bot_audio). bot_audio is None when TTS failed.
    """
    state = None
    if session:
        state = DialogueState.from_dict(session)
    if state is None:
        state = DialogueState(session_id=f"sess_{uuid.uuid4().hex[:8]}", vertical="bank")

    slot = get_expected_slot(state.vertical, state.current_state)
    intent, entities, _unused = nlu_parse(user_text, slot)
    state = transition(state, intent, entities)
    reply = get_prompt(state.vertical, state.current_state)

    state.history.append({"role": "user", "text": user_text, "is_voice": is_voice})
    state.history.append({"role": "bot", "text_ha": reply["ha"], "text_en": reply["en"]})

    audio = None
    try:
        audio = synthesize_hausa(reply["ha"])
    except Exception as exc:
        # Best effort: the text reply is still shown even if speech fails.
        print(f"TTS failed: {exc}")
    return state.to_dict(), audio
| # --------------------------------------------------------------------------- | |
| # Gradio event handlers | |
| # --------------------------------------------------------------------------- | |
def on_vertical_change(vertical: str):
    """Start a fresh session for *vertical*, seeded with its greeting prompt.

    Returns (session_dict, rendered_html, None). The None goes to the
    bot-audio output.
    """
    fresh = DialogueState(session_id="sess_" + uuid.uuid4().hex[:8], vertical=vertical)
    greeting = get_prompt(vertical, "greeting")
    fresh.history.append({"role": "bot", "text_ha": greeting["ha"], "text_en": greeting["en"]})
    snapshot = fresh.to_dict()
    page = render_whatsapp(snapshot)
    return snapshot, page, None
def on_text_submit(text: str, session: dict):
    """Handle a typed message.

    Blank input leaves the session unchanged. Always returns
    (session, html, bot_audio, "") so the textbox is cleared.
    """
    updated, reply_audio = session, None
    if text and text.strip():
        updated, reply_audio = run_turn(text, session, is_voice=False)
    return updated, render_whatsapp(updated), reply_audio, ""
def on_audio_submit(audio_data, session: dict):
    """Handle a recorded or uploaded clip.

    The clip is transcribed and then treated as a user turn. Missing audio,
    an ASR error (logged), or an empty transcript all leave the session
    unchanged. Returns (session, html, bot_audio).
    """
    def _no_change():
        return session, render_whatsapp(session), None

    if audio_data is None:
        return _no_change()
    try:
        heard = transcribe_hausa(audio_data)
    except Exception as exc:
        print(f"ASR failed: {exc}")
        return _no_change()
    if not heard:
        return _no_change()
    updated, reply_audio = run_turn(heard, session, is_voice=True)
    return updated, render_whatsapp(updated), reply_audio
def on_reset(session: dict):
    """Restart the conversation in the session's current vertical ("bank" if unknown)."""
    current = (session or {}).get("vertical", "bank")
    return on_vertical_change(current)
| # --------------------------------------------------------------------------- | |
| # Gradio UI (chat-only, minimal components) | |
| # --------------------------------------------------------------------------- | |
# Page-level CSS passed to gr.Blocks: centre the app at 720px and pad the chat mock-up.
CUSTOM_CSS = "\n" + "\n".join((
    ".gradio-container { max-width: 720px !important; margin: 0 auto !important; }",
    "#whatsapp-container { padding: 20px 0; }",
)) + "\n"
with gr.Blocks(css=CUSTOM_CSS, title="PlotWeaver Voice Agent") as demo:
    # Page header.
    gr.HTML("""
<div style="text-align:center; padding: 0 0 12px;">
<h1 style="margin:0 0 4px; font-size: 22px; font-weight: 500;">PlotWeaver Voice Agent</h1>
<p style="margin:0; color: #5f5e5a; font-size: 14px;">Hausa-first conversational AI — pick a vertical, type or speak in Hausa</p>
</div>
""")
    # Session dict (DialogueState.to_dict()); starts empty and is seeded by demo.load below.
    session_state = gr.State({})
    vertical_radio = gr.Radio(
        choices=[("PlotWeaver Bank", "bank"),
                 ("PlotWeaver Telecom", "telecom"),
                 ("PlotWeaver Delivery", "ecommerce")],
        value="bank",
        label="Vertical",
        container=False,
    )
    whatsapp_html = gr.HTML(elem_id="whatsapp-container")
    with gr.Row():
        text_input = gr.Textbox(
            placeholder="Type in Hausa… e.g. 'duba ma'auni'",
            label="",
            scale=4,
            container=False,
        )
        send_btn = gr.Button("Send", scale=1, variant="primary")
        reset_btn = gr.Button("Reset", scale=1)
    audio_input = gr.Audio(
        sources=["microphone", "upload"],
        type="numpy",
        label="Record or upload Hausa audio (click Stop when done recording)",
    )
    bot_audio = gr.Audio(
        label="Bot response (Hausa TTS)",
        autoplay=True,
        interactive=False,
    )
    # Events
    demo.load(
        fn=lambda: on_vertical_change("bank"),
        outputs=[session_state, whatsapp_html, bot_audio],
    )
    vertical_radio.change(
        fn=on_vertical_change,
        inputs=[vertical_radio],
        outputs=[session_state, whatsapp_html, bot_audio],
    )
    send_btn.click(
        fn=on_text_submit,
        inputs=[text_input, session_state],
        outputs=[session_state, whatsapp_html, bot_audio, text_input],
    )
    text_input.submit(
        fn=on_text_submit,
        inputs=[text_input, session_state],
        outputs=[session_state, whatsapp_html, bot_audio, text_input],
    )
    # Microphone recordings are submitted when the user presses Stop...
    audio_input.stop_recording(
        fn=on_audio_submit,
        inputs=[audio_input, session_state],
        outputs=[session_state, whatsapp_html, bot_audio],
    )
    # ...and files from the "upload" source need their own listener, because
    # stop_recording only fires for microphone capture. Without this, uploads
    # were accepted by the widget but never transcribed.
    audio_input.upload(
        fn=on_audio_submit,
        inputs=[audio_input, session_state],
        outputs=[session_state, whatsapp_html, bot_audio],
    )
    reset_btn.click(
        fn=on_reset,
        inputs=[session_state],
        outputs=[session_state, whatsapp_html, bot_audio],
    )
if __name__ == "__main__":
    # Bind on every interface (0.0.0.0), port 7860, so the app is reachable
    # from other hosts rather than localhost only.
    demo.launch(server_port=7860, server_name="0.0.0.0")