import asyncio
import logging
import re
from typing import Callable, Awaitable

from src.config import WAKE_WORDS, WAKE_WORD_ENABLED
from src.stt.deepgram_client import DeepgramStreamer
from src.tts.cartesia_client import synthesize_stream as cartesia_synthesize_stream
from src.tts.gemini_client import synthesize_stream as gemini_synthesize_stream, GEMINI_SAMPLE_RATE

logger = logging.getLogger(__name__)

SendAudioCallback = Callable[[bytes], Awaitable[None]]
SendEventCallback = Callable[[dict], Awaitable[None]]


class VoicePipeline:
    """
    Voice pipeline: Audio in → STT → wake word check → transcript event,
    and text → TTS → audio out.

    Incoming audio is fed to the selected STT streamer. Each final transcript
    is (optionally) filtered by wake word, and the resulting text is emitted
    to the client as ``{"event": "transcript", "text": ...}`` via
    ``send_event``. This class does not call the chatbot itself; text to be
    spoken arrives through :meth:`speak`.

    NOTE(review): per earlier design notes, the text after the wake word is
    sent to the chatbot service (POST /api/agents/call), and
    ``answer.audio_text`` (falling back to ``answer.content`` when null) is
    what gets spoken — presumably done by the client / WebSocket handler
    before calling :meth:`speak`; verify against the caller.
    """

    def __init__(
        self,
        send_audio: SendAudioCallback,
        send_event: SendEventCallback,
        stt_provider: str = "deepgram",
        tts_provider: str = "cartesia",
        wake_word_enabled: bool = WAKE_WORD_ENABLED,
        **kwargs,
    ):
        """Create the pipeline and its STT streamer.

        Args:
            send_audio: Coroutine called with each synthesized audio chunk.
            send_event: Coroutine called with JSON-able event dicts.
            stt_provider: "gemini", "chirp3", or anything else for Deepgram.
            tts_provider: "gemini" selects Gemini TTS; anything else Cartesia.
            wake_word_enabled: When True, transcripts without a wake word are
                dropped and only the text after the wake word is emitted.
            **kwargs: Accepted and ignored (presumably for older call sites).
        """
        self._send_audio = send_audio
        self._send_event = send_event
        self._loop = asyncio.get_event_loop()
        self._tts_provider = tts_provider
        self._wake_word_enabled = wake_word_enabled

        # Provider-specific streamers are imported lazily so their
        # dependencies are only needed when selected.
        if stt_provider == "gemini":
            from src.stt.gemini_stt import GeminiSTTStreamer
            self._stt = GeminiSTTStreamer(
                on_final_transcript=self._on_final_transcript,
                loop=self._loop,
            )
        elif stt_provider == "chirp3":
            from src.stt.chirp3_client import Chirp3STTStreamer
            self._stt = Chirp3STTStreamer(
                on_final_transcript=self._on_final_transcript,
                loop=self._loop,
            )
        else:
            self._stt = DeepgramStreamer(
                on_final_transcript=self._on_final_transcript,
                loop=self._loop,
            )

        # Serializes utterances: a speak() issued while another is playing
        # waits for it to finish.
        self._tts_lock = asyncio.Lock()
        # Most recently scheduled TTS task (kept for backward compatibility).
        self._tts_task: asyncio.Task | None = None
        # Every TTS task that has not finished yet. Queued tasks block on
        # _tts_lock, so interrupt/stop must cancel all of them, not only the
        # latest. Holding the references also prevents the event loop from
        # garbage-collecting a running task.
        self._tts_tasks: set[asyncio.Task] = set()

    def start(self) -> None:
        """Start the underlying STT streamer."""
        self._stt.start()

    def feed_audio(self, chunk: bytes) -> None:
        """Forward one chunk of incoming audio to the STT streamer."""
        self._stt.send_audio(chunk)

    async def _on_final_transcript(self, text: str) -> None:
        """Handle a final STT transcript and emit a ``transcript`` event.

        With wake words enabled, only the text after the wake word is emitted,
        and transcripts without one are ignored. Otherwise, any non-blank
        transcript is emitted unchanged.
        """
        if self._wake_word_enabled:
            question = self._extract_question(text)
            if question is None:
                logger.debug("No wake word detected, ignoring: %s", text)
                return
            logger.info("Wake word detected, question: %s", question)
            await self._send_event({"event": "transcript", "text": question})
        else:
            if not text.strip():
                return
            logger.info("Transcript: %s", text)
            await self._send_event({"event": "transcript", "text": text})

    def _extract_question(self, transcript: str) -> str | None:
        """Return the text following the earliest wake word, or None.

        Matching is case-insensitive and positions are taken on the original
        string, so slicing stays aligned even when lowercasing would change
        its length. If several wake words start at the same position (e.g.
        "hey" and "hey bot"), the longest match wins so its tail is not
        mistaken for part of the question. Returns None when no wake word
        occurs or only spaces, commas and periods follow it.
        """
        spans = (
            found.span()
            for found in (
                re.search(re.escape(word), transcript, re.IGNORECASE)
                for word in WAKE_WORDS
            )
            if found is not None
        )
        # Earliest start first; on ties prefer the match that ends latest.
        match = min(spans, key=lambda span: (span[0], -span[1]), default=None)
        if match is None:
            return None
        question = transcript[match[1]:].strip(" ,.")
        return question if question else None

    async def speak(self, text: str) -> None:
        """Called by the WebSocket handler when the FE sends a speak action.

        Schedules synthesis of *text* in a background task and returns
        immediately; the task waits for any utterance already playing.
        """
        task = asyncio.create_task(self._speak(text))
        self._tts_task = task
        self._tts_tasks.add(task)
        task.add_done_callback(self._tts_tasks.discard)

    async def _speak(self, text: str) -> None:
        """Synthesize *text* and stream the audio, then send ``tts_end``.

        Non-cancellation failures are logged and reported to the client as a
        ``TTS_ERROR`` event. Cancellation is re-raised so interrupt/stop
        take effect.
        """
        logger.info("TTS speak: provider=%s, text_len=%d, text=%r", self._tts_provider, len(text), text)
        async with self._tts_lock:
            try:
                stream = (
                    gemini_synthesize_stream(text)
                    if self._tts_provider == "gemini"
                    else cartesia_synthesize_stream(text)
                )
                async for audio_chunk in stream:
                    await self._send_audio(audio_chunk)
                await self._send_event({"event": "tts_end"})
            except asyncio.CancelledError:
                raise
            except Exception:
                logger.exception("TTS error for text: %s", text)
                await self._send_event({
                    "event": "error",
                    "code": "TTS_ERROR",
                    "message": "TTS generation failed",
                })

    def _cancel_pending_tts(self) -> list[asyncio.Task]:
        """Request cancellation of every unfinished TTS task; return them."""
        pending = [task for task in self._tts_tasks if not task.done()]
        for task in pending:
            task.cancel()
        return pending

    async def interrupt(self) -> None:
        """Stop the current utterance and drop any queued ones.

        Sends an ``interrupted`` event only when TTS was actually cancelled.
        """
        pending = self._cancel_pending_tts()
        if not pending:
            return
        # Wait for the cancellations to land so no further audio is sent after
        # the "interrupted" event. return_exceptions=True absorbs the tasks'
        # CancelledError, while a cancellation of interrupt() itself still
        # propagates instead of being swallowed.
        await asyncio.gather(*pending, return_exceptions=True)
        await self._send_event({"event": "interrupted"})
        logger.info("TTS interrupted by client")

    def stop(self) -> None:
        """Cancel all pending TTS and stop the STT streamer (non-blocking)."""
        self._cancel_pending_tts()
        self._stt.stop()

    async def stop_async(self) -> None:
        """Cancel all pending TTS and stop STT.

        When the streamer provides ``stop_and_transcribe``, that coroutine is
        awaited instead of ``stop()``.
        """
        self._cancel_pending_tts()
        if hasattr(self._stt, "stop_and_transcribe"):
            await self._stt.stop_and_transcribe()
        else:
            self._stt.stop()


# Backward-compatible alias
EchoPipeline = VoicePipeline