import asyncio
import logging
from typing import Callable, Awaitable
from src.config import WAKE_WORDS, WAKE_WORD_ENABLED
from src.stt.deepgram_client import DeepgramStreamer
from src.tts.cartesia_client import synthesize_stream as cartesia_synthesize_stream
from src.tts.gemini_client import synthesize_stream as gemini_synthesize_stream, GEMINI_SAMPLE_RATE
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Async callback that receives one chunk of synthesized audio (raw bytes).
# VoicePipeline awaits it once per chunk yielded by the TTS stream.
SendAudioCallback = Callable[[bytes], Awaitable[None]]
# Async callback that receives one event dict (e.g. {"event": "transcript", ...}).
# VoicePipeline awaits it for transcript, tts_end, error and interrupted events.
SendEventCallback = Callable[[dict], Awaitable[None]]
class VoicePipeline:
    """
    Voice pipeline: audio in → STT → wake word check → transcript event,
    and text → TTS → audio out.

    Incoming audio chunks are handed to the configured STT streamer. Each
    final transcript is optionally filtered by wake word and forwarded to the
    client as a ``transcript`` event. Speech is produced only when
    ``speak()`` is called; the synthesized audio is streamed through
    ``send_audio`` and followed by a ``tts_end`` event.

    NOTE(review): an earlier version of this docstring said the text after
    the wake word is sent to the chatbot service (POST /api/agents/call) and
    that answer.audio_text (falling back to answer.content) is used for TTS.
    No code in this class performs that call - presumably the caller does;
    confirm.
    """

    def __init__(
        self,
        send_audio: SendAudioCallback,
        send_event: SendEventCallback,
        stt_provider: str = "deepgram",
        tts_provider: str = "cartesia",
        wake_word_enabled: bool = WAKE_WORD_ENABLED,
        **kwargs,
    ):
        """Wire up the callbacks and build the STT streamer.

        Args:
            send_audio: Awaited with every audio chunk produced by TTS.
            send_event: Awaited with every event dict emitted by the pipeline.
            stt_provider: ``"gemini"`` or ``"chirp3"`` select those streamers;
                any other value uses Deepgram.
            tts_provider: ``"gemini"`` selects the Gemini TTS stream; any
                other value uses Cartesia.
            wake_word_enabled: When true, transcripts without a wake word are
                ignored and only the text after the wake word is forwarded.
            **kwargs: Accepted and ignored.
        """
        self._send_audio = send_audio
        self._send_event = send_event
        # Prefer the running loop; asyncio.get_event_loop() is deprecated when
        # no loop is running. Fall back to it to keep the previous behaviour
        # for callers that construct the pipeline outside a coroutine.
        try:
            self._loop = asyncio.get_running_loop()
        except RuntimeError:
            self._loop = asyncio.get_event_loop()
        self._tts_provider = tts_provider
        self._wake_word_enabled = wake_word_enabled

        # The Gemini / Chirp3 streamers are imported only when selected.
        if stt_provider == "gemini":
            from src.stt.gemini_stt import GeminiSTTStreamer

            self._stt = GeminiSTTStreamer(
                on_final_transcript=self._on_final_transcript,
                loop=self._loop,
            )
        elif stt_provider == "chirp3":
            from src.stt.chirp3_client import Chirp3STTStreamer

            self._stt = Chirp3STTStreamer(
                on_final_transcript=self._on_final_transcript,
                loop=self._loop,
            )
        else:
            self._stt = DeepgramStreamer(
                on_final_transcript=self._on_final_transcript,
                loop=self._loop,
            )

        # Serializes utterances so audio from two speak() calls never interleaves.
        self._tts_lock = asyncio.Lock()
        # Most recently started TTS task (kept for backward compatibility).
        self._tts_task: asyncio.Task | None = None
        # Every TTS task that has not finished yet. Holding them here keeps a
        # strong reference to each task and lets interrupt()/stop() cancel
        # queued utterances as well as the one currently streaming.
        self._tts_tasks: set[asyncio.Task] = set()

    def start(self) -> None:
        """Start the STT streamer."""
        self._stt.start()

    def feed_audio(self, chunk: bytes) -> None:
        """Forward one chunk of incoming audio to the STT streamer."""
        self._stt.send_audio(chunk)

    async def _on_final_transcript(self, text: str) -> None:
        """Handle a final transcript from STT and emit a ``transcript`` event.

        With wake-word filtering enabled, transcripts that contain no wake
        word (or nothing after it) are dropped and only the text following
        the wake word is sent. Otherwise any non-blank transcript is sent
        unchanged.
        """
        if self._wake_word_enabled:
            question = self._extract_question(text)
            if question is None:
                logger.debug("No wake word detected, ignoring: %s", text)
                return
            logger.info("Wake word detected, question: %s", question)
            await self._send_event({"event": "transcript", "text": question})
        else:
            if not text.strip():
                return
            logger.info("Transcript: %s", text)
            await self._send_event({"event": "transcript", "text": text})

    def _extract_question(self, transcript: str) -> str | None:
        """Return the text that follows the earliest wake word, or ``None``.

        Matching is done against a lower-cased copy of ``transcript``, so the
        entries of ``WAKE_WORDS`` are expected to be lower-case. If several
        wake words start at the same position the longest one wins, so that
        no part of the wake phrase leaks into the question. Empty wake words
        are ignored. Surrounding spaces, commas and periods are stripped;
        ``None`` is returned when nothing remains.
        """
        lower = transcript.lower()
        candidates = []
        for wake_word in WAKE_WORDS:
            if not wake_word:
                # An empty string would match every transcript at offset 0.
                continue
            pos = lower.find(wake_word)
            if pos != -1:
                # Sort key: earliest position first, then longest wake word.
                candidates.append((pos, -len(wake_word)))
        if not candidates:
            return None
        pos, neg_len = min(candidates)
        end = pos - neg_len
        # Offsets come from the lower-cased copy. They are valid for the
        # original only if lower-casing kept the length (it can change for a
        # few characters such as U+0130); otherwise slice the copy instead.
        source = transcript if len(lower) == len(transcript) else lower
        question = source[end:].strip(" ,.")
        return question if question else None

    async def speak(self, text: str) -> None:
        """Called by the WebSocket handler when the FE sends a speak action.

        Schedules synthesis as a background task and returns immediately;
        utterances are played one after another in the order they acquire
        the TTS lock.
        """
        task = asyncio.create_task(self._speak(text))
        self._tts_tasks.add(task)
        # Drop the reference once the task is finished (or cancelled).
        task.add_done_callback(self._tts_tasks.discard)
        self._tts_task = task

    async def _speak(self, text: str) -> None:
        """Synthesize ``text`` and stream the audio to the client.

        Sends each audio chunk through ``send_audio`` and a ``tts_end`` event
        once the stream is exhausted. Any failure other than cancellation is
        logged and reported to the client as a ``TTS_ERROR`` event.
        """
        logger.info("TTS speak: provider=%s, text_len=%d, text=%r", self._tts_provider, len(text), text)
        async with self._tts_lock:
            try:
                stream = (
                    gemini_synthesize_stream(text)
                    if self._tts_provider == "gemini"
                    else cartesia_synthesize_stream(text)
                )
                async for audio_chunk in stream:
                    await self._send_audio(audio_chunk)
                await self._send_event({"event": "tts_end"})
            except asyncio.CancelledError:
                # Cancellation must propagate so interrupt()/stop() take effect.
                raise
            except Exception:
                logger.exception("TTS error for text: %s", text)
                await self._send_event({
                    "event": "error",
                    "code": "TTS_ERROR",
                    "message": "TTS generation failed",
                })

    def _cancel_tts_tasks(self) -> list[asyncio.Task]:
        """Request cancellation of every unfinished TTS task and return them."""
        pending = [task for task in self._tts_tasks if not task.done()]
        for task in pending:
            task.cancel()
        return pending

    async def interrupt(self) -> None:
        """Cancel all in-flight TTS and acknowledge with an ``interrupted`` event.

        Waits until the cancelled tasks have actually finished, so no further
        audio chunk is sent after the acknowledgement.
        """
        pending = self._cancel_tts_tasks()
        if pending:
            # return_exceptions=True collects each task's CancelledError as a
            # result, while cancellation of interrupt() itself still propagates.
            results = await asyncio.gather(*pending, return_exceptions=True)
            for result in results:
                if isinstance(result, Exception):
                    logger.error("TTS task failed during interrupt", exc_info=result)
        await self._send_event({"event": "interrupted"})
        logger.info("TTS interrupted by client")

    def stop(self) -> None:
        """Cancel any in-flight TTS and stop the STT streamer."""
        self._cancel_tts_tasks()
        self._stt.stop()

    async def stop_async(self) -> None:
        """Cancel any in-flight TTS and shut the STT streamer down.

        Uses the streamer's ``stop_and_transcribe()`` coroutine when it
        provides one, otherwise falls back to the synchronous ``stop()``.
        """
        self._cancel_tts_tasks()
        if hasattr(self._stt, "stop_and_transcribe"):
            await self._stt.stop_and_transcribe()
        else:
            self._stt.stop()
# Backward-compatible alias: ``EchoPipeline`` is bound to the same class object
# as ``VoicePipeline``, so code that still uses the older name keeps working.
EchoPipeline = VoicePipeline