Toadoum commited on
Commit
e8796cb
·
verified ·
1 Parent(s): 70de5c0

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -509
app.py DELETED
@@ -1,509 +0,0 @@
1
- """
2
- PlotWeaver Voice Agent — HuggingFace Space
3
- ============================================
4
- Gradio app demonstrating a Hausa-first conversational AI for
5
- African banks, telecoms, and delivery services.
6
-
7
- Pipeline: ASR (Whisper-small) → NLU (rule-based) → Dialogue FSM →
8
- TTS (facebook/mms-tts-hau).
9
-
10
- Runs on CPU. First turn triggers model download (~500MB), subsequent turns
11
- are ~2-4s end-to-end.
12
- """
13
- from __future__ import annotations
14
- import time
15
- import uuid
16
- import html as html_lib
17
- from typing import Optional
18
-
19
- import gradio as gr
20
- import numpy as np
21
- import torch
22
- from transformers import (
23
- VitsModel, AutoTokenizer,
24
- WhisperProcessor, WhisperForConditionalGeneration,
25
- )
26
-
27
- from dialogue import (
28
- DialogueState, SCENARIOS,
29
- get_prompt, get_expected_slot, transition,
30
- )
31
- from nlu import parse as nlu_parse
32
-
33
-
34
- # ---------------------------------------------------------------------------
35
- # Model loading (lazy, cached)
36
- # ---------------------------------------------------------------------------
37
- _asr_model = None
38
- _asr_processor = None
39
- _tts_model = None
40
- _tts_tokenizer = None
41
-
42
-
43
- def load_asr():
44
- global _asr_model, _asr_processor
45
- if _asr_model is None:
46
- print("Loading Whisper-small…")
47
- _asr_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
48
- _asr_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
49
- _asr_model.eval()
50
- print("Whisper-small ready.")
51
- return _asr_model, _asr_processor
52
-
53
-
54
- def load_tts():
55
- global _tts_model, _tts_tokenizer
56
- if _tts_model is None:
57
- print("Loading MMS-TTS Hausa…")
58
- _tts_model = VitsModel.from_pretrained("facebook/mms-tts-hau")
59
- _tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-hau")
60
- _tts_model.eval()
61
- print("MMS-TTS Hausa ready.")
62
- return _tts_model, _tts_tokenizer
63
-
64
-
65
- def transcribe_hausa(audio_tuple) -> str:
66
- """audio_tuple is (sample_rate, np.ndarray) from Gradio."""
67
- if audio_tuple is None:
68
- return ""
69
- sample_rate, audio_array = audio_tuple
70
- if audio_array is None or len(audio_array) == 0:
71
- return ""
72
- # Convert to float32 mono
73
- if audio_array.dtype != np.float32:
74
- audio_array = audio_array.astype(np.float32) / np.iinfo(audio_array.dtype).max
75
- if audio_array.ndim > 1:
76
- audio_array = audio_array.mean(axis=1)
77
- # Cap at 30s — Whisper-small is trained on 30s chunks; longer audio
78
- # would need windowing which slows the demo
79
- max_samples = sample_rate * 30
80
- if len(audio_array) > max_samples:
81
- audio_array = audio_array[:max_samples]
82
- # Resample to 16 kHz
83
- if sample_rate != 16000:
84
- import scipy.signal
85
- num_samples = int(len(audio_array) * 16000 / sample_rate)
86
- audio_array = scipy.signal.resample(audio_array, num_samples).astype(np.float32)
87
-
88
- model, processor = load_asr()
89
- inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
90
- forced_ids = processor.get_decoder_prompt_ids(language="hausa", task="transcribe")
91
- with torch.no_grad():
92
- ids = model.generate(inputs.input_features, forced_decoder_ids=forced_ids, max_new_tokens=128)
93
- text = processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
94
- return text
95
-
96
-
97
- def synthesize_hausa(text: str) -> Optional[tuple]:
98
- """Return (sample_rate, np.ndarray) or None."""
99
- if not text.strip():
100
- return None
101
- model, tokenizer = load_tts()
102
- inputs = tokenizer(text, return_tensors="pt")
103
- with torch.no_grad():
104
- out = model(**inputs).waveform
105
- audio = out.squeeze().cpu().numpy().astype(np.float32)
106
- return (model.config.sampling_rate, audio)
107
-
108
-
109
- # ---------------------------------------------------------------------------
110
- # Core turn handler
111
- # ---------------------------------------------------------------------------
112
- def run_turn(user_text: str, session: dict, trace: list, asr_ms: int = 0) -> tuple:
113
- """
114
- Executes one turn. Returns (bot_prompt_dict, updated_session, trace, tts_audio).
115
- `session` is a serialized dict stored in gr.State.
116
- """
117
- state = DialogueState.from_dict(session) if session else None
118
- if state is None:
119
- state = DialogueState(session_id="sess_" + uuid.uuid4().hex[:8], vertical="bank")
120
-
121
- turn_trace = []
122
- if asr_ms:
123
- turn_trace.append({"stage": "asr (whisper-small)", "ms": asr_ms,
124
- "detail": f'→ "{user_text}"'})
125
-
126
- t0 = time.time()
127
- expected = get_expected_slot(state.vertical, state.current_state)
128
- intent, entities, nlu_source = nlu_parse(user_text, expected)
129
- nlu_stage_label = {
130
- "rule": "nlu (rule-based)",
131
- "llm": "nlu (qwen2.5-1.5b)",
132
- "rule_fallback": "nlu (rule + llm fallback)",
133
- }.get(nlu_source, "nlu")
134
- turn_trace.append({
135
- "stage": nlu_stage_label,
136
- "ms": max(1, int((time.time() - t0) * 1000)),
137
- "detail": f"intent={intent} entities={entities}",
138
- })
139
-
140
- t1 = time.time()
141
- prev_state = state.current_state
142
- state = transition(state, intent, entities)
143
- turn_trace.append({
144
- "stage": "dialogue_manager",
145
- "ms": max(1, int((time.time() - t1) * 1000)),
146
- "detail": f"{prev_state} → {state.current_state}",
147
- })
148
-
149
- t2 = time.time()
150
- prompt = get_prompt(state.vertical, state.current_state)
151
- turn_trace.append({"stage": "response_gen", "ms": max(1, int((time.time() - t2) * 1000))})
152
-
153
- t3 = time.time()
154
- audio = synthesize_hausa(prompt["ha"])
155
- turn_trace.append({"stage": "tts (mms-tts-hau)", "ms": int((time.time() - t3) * 1000)})
156
-
157
- state.history.append({"role": "user", "text": user_text})
158
- state.history.append({"role": "bot", "text_ha": prompt["ha"], "text_en": prompt["en"]})
159
-
160
- return prompt, state.to_dict(), turn_trace, audio
161
-
162
-
163
- # ---------------------------------------------------------------------------
164
- # WhatsApp-style HTML renderer
165
- # ---------------------------------------------------------------------------
166
- def render_whatsapp(session: dict, pending_user: Optional[str] = None,
167
- pending_is_voice: bool = False) -> str:
168
- vertical = session.get("vertical", "bank") if session else "bank"
169
- name = SCENARIOS[vertical]["name"]
170
- avatar = {"bank": "PB", "telecom": "PT", "ecommerce": "PD"}[vertical]
171
- escalated = session.get("escalate_to_human", False) if session else False
172
-
173
- bubbles = []
174
- history = session.get("history", []) if session else []
175
- for msg in history:
176
- if msg["role"] == "user":
177
- is_voice = msg.get("is_voice", False)
178
- bubbles.append(_user_bubble(msg["text"], is_voice))
179
- else:
180
- bubbles.append(_bot_bubble(msg.get("text_ha", ""), msg.get("text_en", "")))
181
- if pending_user:
182
- bubbles.append(_user_bubble(pending_user, pending_is_voice))
183
-
184
- banner = ('<div class="pw-esc-banner">Session escalated to human agent</div>'
185
- if escalated else "")
186
-
187
- return f"""
188
- <div class="pw-phone">
189
- <div class="pw-ph-header">
190
- <div class="pw-ph-avatar">{avatar}</div>
191
- <div>
192
- <div class="pw-ph-name">{html_lib.escape(name)}</div>
193
- <div class="pw-ph-status">online • voice agent</div>
194
- </div>
195
- </div>
196
- <div class="pw-ph-messages">
197
- {banner}
198
- {"".join(bubbles) if bubbles else '<div style="text-align:center; color:#667781; font-size:12px; padding:40px 0;">Waiting for first message…</div>'}
199
- </div>
200
- </div>
201
- <style>
202
- .pw-phone {{ max-width: 440px; margin: 0 auto; background: #ECE5DD; border-radius: 14px; overflow: hidden; border: 1px solid #ccc; display: flex; flex-direction: column; min-height: 520px; font-family: -apple-system, "Segoe UI", Roboto, sans-serif; }}
203
- .pw-ph-header {{ background: #075E54; color: #fff; padding: 10px 14px; display: flex; align-items: center; gap: 10px; }}
204
- .pw-ph-avatar {{ width: 36px; height: 36px; border-radius: 50%; background: #128C7E; display: flex; align-items: center; justify-content: center; font-weight: 500; font-size: 13px; color: #fff; }}
205
- .pw-ph-name {{ font-size: 14px; font-weight: 500; line-height: 1.2; }}
206
- .pw-ph-status {{ font-size: 11px; color: #D4EDE8; }}
207
- .pw-ph-messages {{ flex: 1; padding: 14px 10px; background: #ECE5DD; background-image: radial-gradient(#D8CFC2 1px, transparent 1px); background-size: 18px 18px; max-height: 460px; overflow-y: auto; min-height: 400px; }}
208
- .pw-b {{ max-width: 80%; padding: 7px 10px 5px; border-radius: 8px; margin-bottom: 6px; font-size: 13.5px; line-height: 1.4; color: #1f2d1f; word-wrap: break-word; }}
209
- .pw-b.user {{ background: #DCF8C6; margin-left: auto; border-bottom-right-radius: 2px; }}
210
- .pw-b.bot {{ background: #fff; margin-right: auto; border-bottom-left-radius: 2px; }}
211
- .pw-b-meta {{ font-size: 10px; color: #667781; margin-top: 3px; text-align: right; }}
212
- .pw-b-trans {{ font-size: 11px; color: #667781; font-style: italic; margin-top: 3px; border-top: 1px solid #E5E5E5; padding-top: 3px; }}
213
- .pw-voice-row {{ display: flex; align-items: center; gap: 8px; }}
214
- .pw-voice-icon {{ width: 22px; height: 22px; border-radius: 50%; background: #128C7E; color: #fff; font-size: 10px; display: flex; align-items: center; justify-content: center; }}
215
- .pw-voice-bars {{ flex: 1; height: 14px; display: flex; align-items: center; gap: 2px; }}
216
- .pw-voice-bars span {{ flex: 1; background: #8D9A9F; border-radius: 1px; }}
217
- .pw-esc-banner {{ background: #FAEEDA; color: #854F0B; font-size: 12px; padding: 8px 12px; border-radius: 8px; margin-bottom: 10px; border: 1px solid #EF9F27; text-align: center; }}
218
- </style>
219
- """
220
-
221
-
222
- def _now() -> str:
223
- return time.strftime("%H:%M")
224
-
225
-
226
- def _user_bubble(text: str, is_voice: bool) -> str:
227
- text_safe = html_lib.escape(text)
228
- if is_voice:
229
- bars = "".join(
230
- f'<span style="height:{4 + int(8 * abs(np.sin(i * 0.7)))}px;"></span>'
231
- for i in range(20)
232
- )
233
- return f'''<div class="pw-b user">
234
- <div class="pw-voice-row">
235
- <div class="pw-voice-icon">▶</div>
236
- <div class="pw-voice-bars">{bars}</div>
237
- </div>
238
- <div style="font-size:12px; color:#667781; margin-top:3px;">"{text_safe}"</div>
239
- <div class="pw-b-meta">{_now()} ✓✓</div>
240
- </div>'''
241
- return f'<div class="pw-b user">{text_safe}<div class="pw-b-meta">{_now()} ✓✓</div></div>'
242
-
243
-
244
- def _bot_bubble(text_ha: str, text_en: str) -> str:
245
- ha_safe = html_lib.escape(text_ha)
246
- en_safe = html_lib.escape(text_en)
247
- return f'''<div class="pw-b bot">
248
- <div>{ha_safe}</div>
249
- <div class="pw-b-trans">{en_safe}</div>
250
- <div class="pw-b-meta">{_now()} ✓✓</div>
251
- </div>'''
252
-
253
-
254
- def render_trace(trace: list) -> str:
255
- if not trace:
256
- return '<div style="color:#888; font-size:13px;">Send a message to see the pipeline trace.</div>'
257
- rows = []
258
- for r in trace:
259
- row = f'<div style="display:flex; justify-content:space-between; padding:5px 0; border-bottom:1px solid #eee;"><span style="color:#5f5e5a;">{html_lib.escape(r["stage"])}</span><span style="color:#0C447C; font-weight:500;">{r["ms"]}ms</span></div>'
260
- rows.append(row)
261
- if r.get("detail"):
262
- rows.append(f'<div style="font-size:11px; color:#888; padding:0 0 5px; font-family:monospace;">{html_lib.escape(str(r["detail"]))}</div>')
263
- return f'<div style="font-family:monospace; font-size:12px;">{"".join(rows)}</div>'
264
-
265
-
266
- def render_metrics(session: dict) -> str:
267
- if not session:
268
- return ""
269
- sid = session.get("session_id", "—")
270
- turn = session.get("turn_count", 0)
271
- state = session.get("current_state", "greeting")
272
- slots = session.get("slots", {})
273
- slots_html = ", ".join(f"<code>{k}={v}</code>" for k, v in slots.items()) or "—"
274
- return f'''
275
- <div style="display:grid; grid-template-columns:1fr 1fr; gap:8px; font-size:13px;">
276
- <div><div style="color:#888; font-size:11px; text-transform:uppercase;">Session</div><div style="font-family:monospace;">{sid}</div></div>
277
- <div><div style="color:#888; font-size:11px; text-transform:uppercase;">Turn</div><div style="font-weight:500;">{turn}</div></div>
278
- <div><div style="color:#888; font-size:11px; text-transform:uppercase;">State</div><div style="font-family:monospace;">{state}</div></div>
279
- <div><div style="color:#888; font-size:11px; text-transform:uppercase;">Slots</div><div>{slots_html}</div></div>
280
- </div>'''
281
-
282
-
283
- # ---------------------------------------------------------------------------
284
- # Gradio event handlers
285
- # ---------------------------------------------------------------------------
286
- def on_vertical_change(vertical: str, synth_greeting: bool = False):
287
- """Reset session when vertical changes. TTS the greeting only on first real
288
- user interaction — keeps initial page load fast (avoids MMS-TTS cold-start)."""
289
- state = DialogueState(session_id="sess_" + uuid.uuid4().hex[:8], vertical=vertical)
290
- greet = get_prompt(vertical, "greeting")
291
- state.history.append({"role": "bot", "text_ha": greet["ha"], "text_en": greet["en"]})
292
- session = state.to_dict()
293
- audio = None
294
- if synth_greeting:
295
- try:
296
- audio = synthesize_hausa(greet["ha"])
297
- except Exception as e:
298
- print(f"TTS failed on greeting: {e}")
299
- return (
300
- session,
301
- render_whatsapp(session),
302
- render_trace([]),
303
- render_metrics(session),
304
- audio,
305
- )
306
-
307
-
308
- def on_text_submit(text: str, session: dict):
309
- if not text or not text.strip():
310
- return session, render_whatsapp(session), render_trace([]), render_metrics(session), None, ""
311
- prompt, new_session, trace, audio = run_turn(text, session, [], asr_ms=0)
312
- return (
313
- new_session,
314
- render_whatsapp(new_session),
315
- render_trace(trace),
316
- render_metrics(new_session),
317
- audio,
318
- "", # clear input
319
- )
320
-
321
-
322
- def on_audio_submit(audio_data, session: dict):
323
- if audio_data is None:
324
- return session, render_whatsapp(session), render_trace([]), render_metrics(session), None
325
- t0 = time.time()
326
- try:
327
- text = transcribe_hausa(audio_data)
328
- except Exception as e:
329
- print(f"ASR failed: {e}")
330
- return session, render_whatsapp(session), render_trace([{"stage": "asr error", "ms": 0, "detail": str(e)}]), render_metrics(session), None
331
- asr_ms = int((time.time() - t0) * 1000)
332
- if not text:
333
- return session, render_whatsapp(session), render_trace([{"stage": "asr", "ms": asr_ms, "detail": "(no speech detected)"}]), render_metrics(session), None
334
- # Mark last user message as voice after appending
335
- prompt, new_session, trace, audio = run_turn(text, session, [], asr_ms=asr_ms)
336
- # Tag the last user entry as voice
337
- if new_session.get("history"):
338
- for i in range(len(new_session["history"]) - 1, -1, -1):
339
- if new_session["history"][i]["role"] == "user":
340
- new_session["history"][i]["is_voice"] = True
341
- break
342
- return (
343
- new_session,
344
- render_whatsapp(new_session),
345
- render_trace(trace),
346
- render_metrics(new_session),
347
- audio,
348
- )
349
-
350
-
351
- def on_reset(session: dict):
352
- vertical = session.get("vertical", "bank") if session else "bank"
353
- return on_vertical_change(vertical)
354
-
355
-
356
- def on_escalate(session: dict):
357
- return on_text_submit("Ina son wakili mutum", session)
358
-
359
-
360
- # ---------------------------------------------------------------------------
361
- # Preset phrases for quick-click demo
362
- # ---------------------------------------------------------------------------
363
- PRESETS = {
364
- "bank": ["duba ma'auni", "toshe kati", "canjin kuɗi", "1234", "Aisha", "dubu biyar", "i"],
365
- "telecom": ["saya airtime", "saya bundle", "korafi", "1000", "rana", "Intanet bai aiki"],
366
- "ecommerce": ["bincika oda", "sake tsara", "mayar da kaya", "10234", "jumma'a", "Ya lalace"],
367
- }
368
-
369
-
370
- # ---------------------------------------------------------------------------
371
- # Gradio UI
372
- # ---------------------------------------------------------------------------
373
- CUSTOM_CSS = """
374
- .gradio-container { max-width: 1200px !important; }
375
- #vertical-selector { background: #fff; border-radius: 10px; padding: 12px; }
376
- #whatsapp-html { background: #f5f4ef; border-radius: 12px; padding: 20px; }
377
- #trace-box, #metrics-box { background: #fff; border-radius: 10px; padding: 12px; border: 1px solid #e5e5e5; }
378
- h1 { font-size: 22px !important; font-weight: 500 !important; }
379
- .header-sub { color: #5f5e5a; font-size: 14px; margin-top: -8px; margin-bottom: 16px; }
380
- """
381
-
382
- with gr.Blocks(css=CUSTOM_CSS, title="PlotWeaver Voice Agent") as demo:
383
- gr.HTML("""
384
- <h1 style="margin-bottom:4px;">PlotWeaver Voice Agent</h1>
385
- <p class="header-sub">Hausa-first conversational AI for African banks, telecoms, and delivery services. Real Whisper-small ASR and MMS-TTS Hausa running on CPU.</p>
386
- """)
387
-
388
- session_state = gr.State({})
389
-
390
- with gr.Row():
391
- # Left column: controls + trace
392
- with gr.Column(scale=1):
393
- gr.Markdown("### Select vertical")
394
- vertical_radio = gr.Radio(
395
- choices=[("PlotWeaver Bank", "bank"),
396
- ("PlotWeaver Telecom", "telecom"),
397
- ("PlotWeaver Delivery", "ecommerce")],
398
- value="bank",
399
- label="",
400
- elem_id="vertical-selector",
401
- )
402
-
403
- with gr.Row():
404
- reset_btn = gr.Button("Reset session", size="sm")
405
- escalate_btn = gr.Button("Force escalate", size="sm")
406
-
407
- gr.Markdown("### Session metrics")
408
- metrics_html = gr.HTML(elem_id="metrics-box")
409
-
410
- gr.Markdown("### Pipeline trace (last turn)")
411
- trace_html = gr.HTML(elem_id="trace-box")
412
-
413
- # Middle column: WhatsApp mockup
414
- with gr.Column(scale=2):
415
- whatsapp_html = gr.HTML(elem_id="whatsapp-html")
416
-
417
- with gr.Row():
418
- text_input = gr.Textbox(
419
- placeholder="Type in Hausa… e.g. 'duba ma'auni'",
420
- label="",
421
- scale=4,
422
- container=False,
423
- )
424
- send_btn = gr.Button("Send", scale=1, variant="primary")
425
-
426
- gr.Markdown("**Or speak / upload audio in Hausa:**")
427
- audio_input = gr.Audio(
428
- sources=["microphone", "upload"],
429
- type="numpy",
430
- label="Record or upload a Hausa audio file (.wav, .mp3, .ogg)",
431
- show_download_button=False,
432
- )
433
- with gr.Row():
434
- transcribe_btn = gr.Button("Transcribe & send", variant="secondary", size="sm")
435
- clear_audio_btn = gr.Button("Clear", size="sm")
436
-
437
- bot_audio = gr.Audio(
438
- label="Bot response (Hausa TTS)",
439
- autoplay=True,
440
- interactive=False,
441
- )
442
-
443
- # Preset quick-clicks
444
- gr.Markdown("### Quick phrases (Hausa)")
445
- preset_btns = []
446
- with gr.Row():
447
- for p in PRESETS["bank"]:
448
- preset_btns.append(gr.Button(p, size="sm"))
449
-
450
- # -----------------------------------------------------------------------
451
- # Event wiring
452
- # -----------------------------------------------------------------------
453
- outputs = [session_state, whatsapp_html, trace_html, metrics_html, bot_audio]
454
-
455
- demo.load(
456
- fn=lambda: on_vertical_change("bank"),
457
- outputs=outputs,
458
- )
459
-
460
- vertical_radio.change(
461
- fn=on_vertical_change,
462
- inputs=[vertical_radio],
463
- outputs=outputs,
464
- )
465
-
466
- send_btn.click(
467
- fn=on_text_submit,
468
- inputs=[text_input, session_state],
469
- outputs=outputs + [text_input],
470
- )
471
- text_input.submit(
472
- fn=on_text_submit,
473
- inputs=[text_input, session_state],
474
- outputs=outputs + [text_input],
475
- )
476
-
477
- audio_input.stop_recording(
478
- fn=on_audio_submit,
479
- inputs=[audio_input, session_state],
480
- outputs=outputs,
481
- )
482
- transcribe_btn.click(
483
- fn=on_audio_submit,
484
- inputs=[audio_input, session_state],
485
- outputs=outputs,
486
- )
487
- clear_audio_btn.click(
488
- fn=lambda: None,
489
- outputs=[audio_input],
490
- )
491
-
492
- reset_btn.click(fn=on_reset, inputs=[session_state], outputs=outputs)
493
- escalate_btn.click(
494
- fn=on_escalate,
495
- inputs=[session_state],
496
- outputs=outputs + [text_input],
497
- )
498
-
499
- # Preset buttons submit their own text
500
- for btn, phrase in zip(preset_btns, PRESETS["bank"]):
501
- btn.click(
502
- fn=lambda s, _phrase=phrase: on_text_submit(_phrase, s),
503
- inputs=[session_state],
504
- outputs=outputs + [text_input],
505
- )
506
-
507
-
508
- if __name__ == "__main__":
509
- demo.launch()