# Hugging Face Space entry point.
# (The "Spaces: Sleeping" status banner above was page residue from the Space UI, not code.)
| """ | |
| Twilio Media Streams (bidirectional) + Vosk + OpenAI Answer + Piper -> Twilio playback | |
| + Live UI (web_demo) showing STT/LLM in realtime | |
| + Multi-call UI support (separate calls by streamSid) | |
| """ | |
| import asyncio | |
| import base64 | |
| import json | |
| import logging | |
| import os | |
| import tempfile | |
| import time | |
| import audioop | |
| import subprocess | |
| from dataclasses import dataclass, field | |
| from typing import Optional, List, Dict | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request | |
| from fastapi.responses import PlainTextResponse, Response, FileResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from vosk import Model, KaldiRecognizer | |
| from openai import OpenAI | |
# ----------------------------
# Logging
# ----------------------------
# Structured stdlib logging for tracebacks; the P() helper below is used for
# flush-immediate, tag-prefixed console lines.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger("app")
def P(tag: str, msg: str):
    """Emit a tag-prefixed console line, flushed immediately (container-log friendly)."""
    print(tag, msg, flush=True)
# ----------------------------
# Env
# ----------------------------
VOSK_MODEL_PATH = os.getenv("VOSK_MODEL_PATH", "/app/models/vosk-model-en-us-0.22-lgraph").strip()  # STT model directory
TWILIO_STREAM_URL = os.getenv("TWILIO_STREAM_URL", "").strip()  # explicit wss:// URL; if empty, derived from the Host header in voice()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "").strip()  # required at answer time (openai_client raises if unset)
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini").strip()
PIPER_BIN = os.getenv("PIPER_BIN", "piper").strip()  # piper executable name/path
PIPER_MODEL_PATH = os.getenv("PIPER_MODEL_PATH", "").strip()  # required for TTS (checked in piper_tts_to_mulaw)
HOST = "0.0.0.0"
PORT = int(os.getenv("PORT", "7860"))  # HF uses 7860
# ----------------------------
# FastAPI
# ----------------------------
app = FastAPI()
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers under the CORS spec — confirm the dashboard actually
# needs credentialed requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ----------------------------
# Frontend serving (web_demo/dist)
# ----------------------------
FRONTEND_DIR = os.path.join(os.getcwd(), "web_demo", "dist")  # built SPA bundle (cwd-relative)
ASSETS_DIR = os.path.join(FRONTEND_DIR, "assets")
# Mount static assets only when the build output exists, so the API still
# starts on a source checkout without a frontend build.
if os.path.isdir(ASSETS_DIR):
    app.mount("/assets", StaticFiles(directory=ASSETS_DIR), name="assets")
async def serve_index():
    """Serve the built SPA entry page, or a 404 hint when the bundle is missing.

    NOTE(review): no route decorator (e.g. @app.get("/")) is visible on this
    handler — confirm how it is registered with the app.
    """
    entry = os.path.join(FRONTEND_DIR, "index.html")
    if not os.path.isfile(entry):
        return PlainTextResponse("UI not built. Ensure web_demo/dist exists.", status_code=404)
    return FileResponse(entry)
# SPA fallback (optional): if you later add routes like /dashboard
async def serve_spa_fallback(path: str):
    """Serve a real file under FRONTEND_DIR if present, else index.html, else 404."""
    for candidate in (os.path.join(FRONTEND_DIR, path),
                      os.path.join(FRONTEND_DIR, "index.html")):
        if os.path.isfile(candidate):
            return FileResponse(candidate)
    return PlainTextResponse("Not Found", status_code=404)
# ----------------------------
# Audio / Twilio
# ----------------------------
FRAME_MS = 20  # Twilio media frame duration used for outbound pacing
INPUT_RATE = 8000  # inbound mu-law sample rate (decoded with audioop.ulaw2lin)
STT_RATE = 16000  # recognizer rate; inbound audio is upsampled to this for Vosk
BYTES_PER_20MS_MULAW = int(INPUT_RATE * (FRAME_MS / 1000.0))  # 160 bytes @ 8kHz, 20ms
# ----------------------------
# VAD settings
# ----------------------------
RMS_SPEECH_THRESHOLD = 550  # RMS of 16-bit PCM above which a frame counts as speech
SPEECH_START_FRAMES = 3  # consecutive speech frames (~60ms) required to open an utterance
SPEECH_END_SILENCE_FRAMES = 40  # 800ms
MAX_UTTERANCE_MS = 12000  # hard cap before forced finalization
PARTIAL_EMIT_EVERY_MS = 250  # throttle for partial-transcript polling/UI events
# ----------------------------
# LLM prompt
# ----------------------------
# Seeded as history[0] in stream() and also prepended in openai_answer_blocking().
SYSTEM_PROMPT = (
    "You are a phone-call assistant. "
    "Reply in 1 short sentence (max 15 words). "
    "No filler. No greetings unless user greets first."
)
# ----------------------------
# Cached Vosk model
# ----------------------------
# Lazily loaded on the first Twilio stream (see stream()); shared by all calls.
_VOSK_MODEL = None
def now_ms() -> int:
    """Current wall-clock time in whole milliseconds.

    Uses time.time_ns() so the value is derived from the integer nanosecond
    clock instead of truncating a float (time.time() * 1000 loses sub-ms
    precision at current epoch magnitudes).
    """
    return time.time_ns() // 1_000_000
def build_twiml(stream_url: str) -> str:
    """Return TwiML that connects this call's media to `stream_url`.

    The trailing <Pause> keeps the call leg alive while the stream runs.
    """
    return (
        '<?xml version="1.0" encoding="UTF-8"?>\n'
        "<Response>\n"
        "  <Connect>\n"
        f'    <Stream url="{stream_url}" />\n'
        "  </Connect>\n"
        '  <Pause length="600"/>\n'
        "</Response>\n"
    )
def split_mulaw_frames(mulaw_bytes: bytes, frame_size: Optional[int] = None) -> List[bytes]:
    """Split raw mu-law audio into fixed-size frames, padding the tail with silence.

    Args:
        mulaw_bytes: raw mu-law byte stream (may be empty).
        frame_size: bytes per frame; defaults to BYTES_PER_20MS_MULAW
            (20 ms @ 8 kHz). Added as a backward-compatible keyword so the
            splitter is reusable for other frame durations.

    Returns:
        List of frames, each exactly `frame_size` bytes. The final partial
        frame is padded with 0xFF (mu-law silence). Empty input -> empty list.
    """
    size = BYTES_PER_20MS_MULAW if frame_size is None else frame_size
    frames: List[bytes] = []
    for i in range(0, len(mulaw_bytes), size):
        chunk = mulaw_bytes[i:i + size]
        if len(chunk) < size:
            chunk += b"\xFF" * (size - len(chunk))
        frames.append(chunk)
    return frames
| async def drain_queue(q: asyncio.Queue): | |
| try: | |
| while True: | |
| q.get_nowait() | |
| q.task_done() | |
| except asyncio.QueueEmpty: | |
| return | |
# ----------------------------
# UI live dashboard (multi-call)
# ----------------------------
_UI_CLIENTS = set()  # connected dashboard WebSockets (registered in ui_ws)
_UI_LOCK = asyncio.Lock()  # guards _UI_CLIENTS
ACTIVE_CALLS: Dict[str, Dict] = {}  # key: streamSid
ACTIVE_LOCK = asyncio.Lock()  # guards ACTIVE_CALLS
async def ui_broadcast(event: str, data: dict):
    """Push one event envelope to every dashboard client; drop clients whose send fails."""
    payload = json.dumps({"event": event, "data": data, "ts_ms": now_ms()})
    stale = []
    async with _UI_LOCK:
        for client in list(_UI_CLIENTS):
            try:
                await client.send_text(payload)
            except Exception:
                stale.append(client)
    for client in stale:
        _UI_CLIENTS.discard(client)
async def upsert_call(stream_sid: str, **fields):
    """Merge `fields` into the ACTIVE_CALLS row for `stream_sid`, creating it if absent."""
    if not stream_sid:
        return
    async with ACTIVE_LOCK:
        merged = {**ACTIVE_CALLS.get(stream_sid, {}), **fields}
        ACTIVE_CALLS[stream_sid] = merged
async def remove_call(stream_sid: str):
    """Forget a call row; no-op for empty or unknown stream SIDs."""
    if stream_sid:
        async with ACTIVE_LOCK:
            ACTIVE_CALLS.pop(stream_sid, None)
async def ui_ws(ws: WebSocket):
    """Dashboard WebSocket: register the client to receive ui_broadcast pushes.

    The server never reads from this socket; it idles until the client drops,
    then unregisters. NOTE(review): no @app.websocket(...) decorator is
    visible on this handler — confirm how it is registered with the app.
    """
    await ws.accept()
    async with _UI_LOCK:
        _UI_CLIENTS.add(ws)
    try:
        # Keep the coroutine alive without reading; disconnect surfaces as an exception.
        while True:
            await asyncio.sleep(60)
    except WebSocketDisconnect:
        pass
    finally:
        async with _UI_LOCK:
            _UI_CLIENTS.discard(ws)
async def ui_calls():
    """Snapshot ACTIVE_CALLS as plain per-call dicts (safe to use after the lock is released)."""
    async with ACTIVE_LOCK:
        snapshot = {sid: dict(row) for sid, row in ACTIVE_CALLS.items()}
    return snapshot
# ----------------------------
# OpenAI
# ----------------------------
def openai_client() -> OpenAI:
    """Build an OpenAI client from the environment key; fail fast when unset."""
    if OPENAI_API_KEY:
        return OpenAI(api_key=OPENAI_API_KEY)
    raise RuntimeError("OPENAI_API_KEY not set")
def openai_answer_blocking(history: List[Dict], user_text: str) -> str:
    """Blocking chat completion: SYSTEM_PROMPT + recent turns + the new user message.

    Args:
        history: prior conversation; stream() seeds it with a system message
            at index 0, which is filtered out here.
        user_text: latest transcribed caller utterance.

    Returns:
        The assistant reply text ("" if the model returned nothing).
    """
    client = openai_client()
    msgs = [{"role": "system", "content": SYSTEM_PROMPT}]
    # BUGFIX: history[0] is the seeded system message, so the old
    # `history[-6:]` tail could re-append it and send two system prompts.
    # Keep only non-system turns from the recent tail.
    tail = [m for m in history[-6:] if m.get("role") != "system"]
    msgs.extend(tail)
    msgs.append({"role": "user", "content": user_text})
    resp = client.chat.completions.create(
        model=OPENAI_MODEL,
        messages=msgs,
        temperature=0.3,
        max_tokens=80,
    )
    return (resp.choices[0].message.content or "").strip()
# ----------------------------
# Piper TTS -> 8k mulaw
# ----------------------------
def piper_tts_to_mulaw(text: str) -> bytes:
    """Synthesize `text` with Piper, then transcode to 8 kHz mono mu-law for Twilio.

    Pipeline: piper -> wav tempfile -> ffmpeg (band-limit + compand) -> raw mu-law.
    Returns b"" for blank input; raises RuntimeError when either tool fails.
    Both temp files are always removed.
    """
    if not PIPER_MODEL_PATH:
        raise RuntimeError("PIPER_MODEL_PATH not set")
    text = (text or "").strip()
    if not text:
        return b""
    # Create the temp paths up front; the tools write to them by name.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as wav_tmp:
        wav_path = wav_tmp.name
    with tempfile.NamedTemporaryFile(suffix=".mulaw", delete=False) as mulaw_tmp:
        mulaw_path = mulaw_tmp.name
    try:
        piper_proc = subprocess.run(
            [PIPER_BIN, "--model", PIPER_MODEL_PATH, "--output_file", wav_path],
            input=text.encode("utf-8"),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        if piper_proc.returncode != 0:
            raise RuntimeError(f"piper rc={piper_proc.returncode} stderr={piper_proc.stderr.decode('utf-8','ignore')[:500]}")
        # Telephone-band EQ plus companding before mu-law encoding.
        af = "highpass=f=200,lowpass=f=3400,compand=attacks=0:decays=0.3:points=-80/-80|-20/-10|0/-3"
        ffmpeg_proc = subprocess.run(
            ["ffmpeg", "-y", "-i", wav_path,
             "-ac", "1", "-ar", "8000",
             "-af", af,
             "-f", "mulaw", mulaw_path],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        if ffmpeg_proc.returncode != 0:
            raise RuntimeError(f"ffmpeg rc={ffmpeg_proc.returncode} stderr={ffmpeg_proc.stderr.decode('utf-8','ignore')[:500]}")
        with open(mulaw_path, "rb") as f:
            return f.read()
    finally:
        for p in (wav_path, mulaw_path):
            try:
                os.unlink(p)
            except Exception:
                pass
# ----------------------------
# Call state
# ----------------------------
@dataclass
class CancelFlag:
    """Mutable one-way flag used to abandon an in-flight LLM/TTS generation."""

    is_set: bool = False

    def set(self):
        # One-way: there is deliberately no clear(); callers create a fresh flag.
        self.is_set = True


@dataclass
class CallState:
    """Per-call mutable state for one Twilio media-stream WebSocket.

    BUGFIX: the @dataclass decorators were missing, so the
    `field(default_factory=...)` declarations and the call sites
    `CallState(call_id=...)` / `CancelFlag(False)` could not work; restored.
    """

    call_id: str
    stream_sid: str = ""
    # vad
    in_speech: bool = False
    speech_start_count: int = 0
    silence_count: int = 0
    utter_start_ms: int = 0
    # Quoted annotation: vosk types are only needed at runtime, not at class creation.
    rec: "Optional[KaldiRecognizer]" = None
    # partials
    last_partial: str = ""
    last_partial_emit_ms: int = 0
    # outbound
    outbound_q: asyncio.Queue = field(default_factory=lambda: asyncio.Queue(maxsize=50000))
    outbound_task: Optional[asyncio.Task] = None
    keepalive_task: Optional[asyncio.Task] = None
    mark_i: int = 0
    # speaking / generation
    bot_speaking: bool = False
    cancel_llm: CancelFlag = field(default_factory=CancelFlag)
    tts_generation_id: int = 0
    # conversation history
    history: List[Dict] = field(default_factory=list)
    bot_lock: asyncio.Lock = field(default_factory=asyncio.Lock)

    def bump_tts_generation(self) -> int:
        """Advance and return the TTS generation id; older queued audio becomes stale."""
        self.tts_generation_id += 1
        return self.tts_generation_id
# ----------------------------
# Keepalive marks
# ----------------------------
async def twilio_keepalive(ws: WebSocket, st: CallState):
    """Send a named `mark` event every 10 s so the Twilio stream stays active.

    Runs until cancelled; any send error is logged once and ends the task.
    """
    try:
        while True:
            await asyncio.sleep(10)
            if not st.stream_sid:
                continue
            st.mark_i += 1
            frame = {
                "event": "mark",
                "streamSid": st.stream_sid,
                "mark": {"name": f"ka_{st.mark_i}"},
            }
            await ws.send_text(json.dumps(frame))
    except asyncio.CancelledError:
        return
    except Exception as e:
        P("SYS>", f"keepalive_error={e}")
# ----------------------------
# HTTP
# ----------------------------
async def health():
    """Liveness probe payload.

    NOTE(review): no route decorator is visible on this handler — confirm how
    it is registered with the app.
    """
    return {"ok": True}
async def voice(request: Request):
    """Twilio voice webhook: return TwiML pointing the call at our /stream WebSocket.

    Prefers TWILIO_STREAM_URL; otherwise derives wss://<host>/stream from the
    request's Host header. Responds 500 when neither is available.
    """
    stream_url = TWILIO_STREAM_URL
    if not stream_url:
        host = request.headers.get("host")
        if host:
            stream_url = f"wss://{host}/stream"
            P("SYS>", f"auto_stream_url={stream_url}")
    if not stream_url:
        return PlainTextResponse("TWILIO_STREAM_URL not set and host not found", status_code=500)
    twiml = build_twiml(stream_url)
    return Response(content=twiml, media_type="application/xml")
async def voice_get(request: Request):
    # GET variant of the Twilio voice webhook; delegates to the POST handler above.
    return await voice(request)
# ----------------------------
# WebSocket /stream (Twilio)
# ----------------------------
async def stream(ws: WebSocket):
    """Handle one Twilio Media Streams connection end-to-end.

    Lifecycle: accept -> lazily load the shared Vosk model -> spawn the
    outbound audio sender -> loop over Twilio events (start/media/mark/stop)
    -> cleanup (UI row, keepalive task, sender task).

    NOTE(review): no @app.websocket("/stream") decorator is visible here even
    though voice() points the TwiML at /stream — confirm how this handler is
    registered with the app.
    """
    await ws.accept()
    st = CallState(call_id=str(id(ws)))
    st.history = [{"role": "system", "content": SYSTEM_PROMPT}]
    P("SYS>", f"ws_open call_id={st.call_id}")
    # Load the (large) Vosk model once and share it across all calls.
    global _VOSK_MODEL
    if _VOSK_MODEL is None:
        P("SYS>", f"loading_vosk={VOSK_MODEL_PATH}")
        _VOSK_MODEL = Model(VOSK_MODEL_PATH)
        P("SYS>", "vosk_loaded")
    st.rec = KaldiRecognizer(_VOSK_MODEL, STT_RATE)
    st.rec.SetWords(False)
    st.outbound_task = asyncio.create_task(outbound_sender(ws, st))
    try:
        while True:
            raw = await ws.receive_text()
            msg = json.loads(raw)
            event = msg.get("event")
            if event == "start":
                st.stream_sid = msg["start"]["streamSid"]
                P("TWILIO>", f"start streamSid={st.stream_sid}")
                await upsert_call(
                    st.stream_sid,
                    call_id=st.call_id,
                    started_ms=now_ms(),
                    last_seen_ms=now_ms(),
                    last_event="start",
                )
                await ui_broadcast("call_start", {"streamSid": st.stream_sid, "call_id": st.call_id})
                if st.keepalive_task is None:
                    st.keepalive_task = asyncio.create_task(twilio_keepalive(ws, st))
                # greeting (optional)
                # NOTE(review): fire-and-forget task with no saved reference —
                # the loop only keeps weak refs to tasks; confirm this is acceptable.
                asyncio.create_task(speak_text(ws, st, "Hi! How can I help?"))
            elif event == "media":
                # Twilio sends base64 mu-law @ 8 kHz; decode then upsample to
                # 16 kHz PCM for the recognizer, and compute RMS for the VAD.
                mulaw = base64.b64decode(msg["media"]["payload"])
                pcm16_8k = audioop.ulaw2lin(mulaw, 2)
                pcm16_16k, _ = audioop.ratecv(pcm16_8k, 2, 1, INPUT_RATE, STT_RATE, None)
                rms = audioop.rms(pcm16_16k, 2)
                is_speech = rms >= RMS_SPEECH_THRESHOLD
                # barge-in: cancel current bot audio if caller speaks
                if st.bot_speaking and is_speech:
                    await barge_in(ws, st)
                await vad_and_stt(ws, st, pcm16_16k, is_speech)
            elif event == "mark":
                # ignore; keepalive
                pass
            elif event == "stop":
                P("TWILIO>", "stop")
                break
    except WebSocketDisconnect:
        P("SYS>", "ws_disconnect")
    except Exception as e:
        P("SYS>", f"ws_error={e}")
        log.exception("ws_error")
    finally:
        # Best-effort cleanup: drop the dashboard row and stop helper tasks.
        if st.stream_sid:
            await remove_call(st.stream_sid)
            await ui_broadcast("call_end", {"streamSid": st.stream_sid})
        if st.keepalive_task:
            st.keepalive_task.cancel()
        if st.outbound_task:
            st.outbound_task.cancel()
        P("SYS>", "ws_closed")
# ----------------------------
# VAD + STT
# ----------------------------
async def vad_and_stt(ws: WebSocket, st: CallState, pcm16_16k: bytes, is_speech: bool):
    """Per-frame VAD state machine plus incremental Vosk decoding.

    Not-in-speech: require SPEECH_START_FRAMES consecutive speech frames
    before opening an utterance (with a fresh recognizer). In-speech: feed
    the recognizer, emit throttled partial transcripts, and finalize on
    either max duration or SPEECH_END_SILENCE_FRAMES of trailing silence.
    """
    t = now_ms()
    if not st.in_speech:
        if is_speech:
            st.speech_start_count += 1
            if st.speech_start_count >= SPEECH_START_FRAMES:
                # Utterance opens: reset counters and start a fresh recognizer
                # so earlier audio cannot bleed into this transcript.
                st.in_speech = True
                st.silence_count = 0
                st.utter_start_ms = t
                st.speech_start_count = 0
                st.last_partial = ""
                st.last_partial_emit_ms = 0
                st.rec = KaldiRecognizer(_VOSK_MODEL, STT_RATE)
                st.rec.SetWords(False)
        else:
            st.speech_start_count = 0
        return
    st.rec.AcceptWaveform(pcm16_16k)
    # Poll partial results at most every PARTIAL_EMIT_EVERY_MS.
    if t - st.last_partial_emit_ms >= PARTIAL_EMIT_EVERY_MS:
        st.last_partial_emit_ms = t
        try:
            pj = json.loads(st.rec.PartialResult() or "{}")
            partial = (pj.get("partial") or "").strip()
        except Exception:
            partial = ""
        # Only emit when the partial actually changed.
        if partial and partial != st.last_partial:
            st.last_partial = partial
            P("STT_PART>", partial)
            await upsert_call(st.stream_sid, last_seen_ms=t, last_event="stt_partial")
            await ui_broadcast("stt_partial", {"streamSid": st.stream_sid, "text": partial})
    # Hard cap on utterance length.
    if (t - st.utter_start_ms) > MAX_UTTERANCE_MS:
        await finalize_utterance(ws, st, "max_utterance")
        return
    if is_speech:
        st.silence_count = 0
        return
    # Silence frame: finalize once enough trailing silence accumulates.
    st.silence_count += 1
    if st.silence_count >= SPEECH_END_SILENCE_FRAMES:
        await finalize_utterance(ws, st, f"vad_silence_{SPEECH_END_SILENCE_FRAMES*FRAME_MS}ms")
async def finalize_utterance(ws: WebSocket, st: CallState, reason: str):
    """Close the current utterance, pull Vosk's final text, and kick off the bot reply.

    Args:
        reason: why we finalized ("max_utterance" or a vad_silence_* tag);
            surfaced in the log line and the UI event.
    """
    if not st.in_speech:
        return
    st.in_speech = False
    st.silence_count = 0
    st.speech_start_count = 0
    st.last_partial = ""
    try:
        j = json.loads(st.rec.FinalResult() or "{}")
    except Exception:
        j = {}
    user_text = (j.get("text") or "").strip()
    if not user_text:
        return
    P("STT_FINAL>", f"{user_text} ({reason})")
    await upsert_call(st.stream_sid, last_seen_ms=now_ms(), last_event="stt_final", last_user_text=user_text)
    await ui_broadcast("stt_final", {"streamSid": st.stream_sid, "text": user_text, "reason": reason})
    async def bot_job():
        # Serialize replies per call so overlapping finals can't interleave audio.
        async with st.bot_lock:
            await answer_and_speak(ws, st, user_text)
    # BUGFIX: keep a strong reference to the fire-and-forget task. The event
    # loop only holds weak refs, so an unreferenced task can be garbage
    # collected before it finishes.
    task = asyncio.create_task(bot_job())
    _BOT_TASKS.add(task)
    task.add_done_callback(_BOT_TASKS.discard)

# Strong references for in-flight bot reply tasks (see finalize_utterance).
_BOT_TASKS: set = set()
# ----------------------------
# LLM Answer -> Speak
# ----------------------------
async def answer_and_speak(ws: WebSocket, st: CallState, user_text: str):
    """Run the LLM off-loop, record the turn in history, and speak the answer.

    History layout: index 0 is the seeded system message; we keep it plus at
    most the 8 most recent conversation turns.
    """
    st.cancel_llm = CancelFlag(False)
    st.history.append({"role": "user", "content": user_text})
    # BUGFIX: the old trim `history[:1] + history[-8:]` re-included the system
    # message (and grew the list by one per call) whenever len(history) <= 9.
    # Slicing the tail from index 1 keeps only real turns.
    st.history = st.history[:1] + st.history[1:][-8:]
    loop = asyncio.get_running_loop()
    def worker():
        # Blocking OpenAI call; must run off the event loop.
        return openai_answer_blocking(st.history, user_text)
    ans = await loop.run_in_executor(None, worker)
    ans = (ans or "").strip()
    if not ans:
        ans = "Sorry, I didn’t catch that."
    P("LLM_ANS>", ans)
    await upsert_call(st.stream_sid, last_seen_ms=now_ms(), last_event="llm_ans", last_bot_text=ans)
    await ui_broadcast("llm_ans", {"streamSid": st.stream_sid, "text": ans})
    st.history.append({"role": "assistant", "content": ans})
    st.history = st.history[:1] + st.history[1:][-8:]
    await speak_text(ws, st, ans)
# ----------------------------
# Barge-in (clear + drain)
# ----------------------------
async def barge_in(ws: WebSocket, st: CallState):
    """Caller talked over the bot: cancel generation, clear Twilio's buffer, flush ours.

    Order matters: flag/generation first (stops new frames being queued),
    then Twilio-side clear, then local queue drain.
    """
    st.cancel_llm.set()
    st.bump_tts_generation()  # invalidate older audio
    if st.stream_sid:
        clear_msg = json.dumps({"event": "clear", "streamSid": st.stream_sid})
        try:
            await ws.send_text(clear_msg)
            P("TWILIO>", "sent_clear")
        except Exception:
            pass
    await drain_queue(st.outbound_q)
    st.bot_speaking = False
# ----------------------------
# Speak / TTS (no clear here; clear only on barge-in)
# ----------------------------
async def speak_text(ws: WebSocket, st: CallState, text: str):
    """Speak `text` under a freshly bumped TTS generation id."""
    await tts_enqueue(st, text, st.bump_tts_generation())
async def tts_enqueue(st: CallState, text: str, gen: int):
    """Synthesize `text` and queue its 20 ms mu-law frames for the outbound sender.

    `gen` is the TTS generation id captured when speaking started; if
    st.tts_generation_id has moved on (barge-in or a newer utterance), the
    audio is silently dropped — checked both after synthesis and per frame.
    """
    my_gen = gen
    st.bot_speaking = True
    P("TTS>", f"text={text} gen={my_gen}")
    await ui_broadcast("tts", {"streamSid": st.stream_sid, "text": text, "gen": my_gen})
    loop = asyncio.get_running_loop()
    try:
        # Piper + ffmpeg are blocking subprocess calls; run in the thread pool.
        mulaw_bytes = await loop.run_in_executor(None, piper_tts_to_mulaw, text)
    except Exception as e:
        P("TTS_ERR>", str(e))
        st.bot_speaking = False
        return
    if my_gen != st.tts_generation_id:
        # Superseded while synthesizing; drop the whole clip.
        return
    # enqueue audio frames
    for fr in split_mulaw_frames(mulaw_bytes):
        if my_gen != st.tts_generation_id:
            # Superseded mid-enqueue (barge-in drains the queue separately).
            return
        await st.outbound_q.put(base64.b64encode(fr).decode("ascii"))
    # add a short silence tail to prevent cutoff
    silence = base64.b64encode(b"\xFF" * BYTES_PER_20MS_MULAW).decode("ascii")
    for _ in range(6):  # ~120ms
        await st.outbound_q.put(silence)
    await st.outbound_q.put("__END_CHUNK__")
async def outbound_sender(ws: WebSocket, st: CallState):
    """Drain st.outbound_q to Twilio at roughly real-time pace (one frame per FRAME_MS).

    "__END_CHUNK__" sentinels mark the end of a TTS clip; when one is seen and
    the queue is empty, the bot is considered done speaking. Runs until the
    task is cancelled by stream() cleanup.

    NOTE(review): pacing via sleep(FRAME_MS/1000) after each send drifts
    slowly over long clips (send latency accumulates) — confirm acceptable.
    """
    try:
        while True:
            item = await st.outbound_q.get()
            if item == "__END_CHUNK__":
                await asyncio.sleep(0.02)
                if st.outbound_q.empty():
                    st.bot_speaking = False
                st.outbound_q.task_done()
                continue
            if not st.stream_sid:
                # No streamSid yet (pre-"start" event); drop the frame.
                st.outbound_q.task_done()
                continue
            await ws.send_text(json.dumps({
                "event": "media",
                "streamSid": st.stream_sid,
                "media": {"payload": item},
            }))
            st.outbound_q.task_done()
            await asyncio.sleep(FRAME_MS / 1000.0)
    except asyncio.CancelledError:
        return
    except Exception as e:
        P("SYS>", f"outbound_sender_error={e}")
        log.exception("outbound_sender_error")
# ----------------------------
# main
# ----------------------------
# Direct/dev entry point; in production the app object may be served by an
# external runner instead.
if __name__ == "__main__":
    import uvicorn
    P("SYS>", f"starting {HOST}:{PORT}")
    uvicorn.run(app, host=HOST, port=PORT)