Spaces:
Sleeping
Sleeping
| import asyncio | |
| import logging | |
| from typing import Callable, Awaitable | |
| from src.config import WAKE_WORDS, WAKE_WORD_ENABLED | |
| from src.stt.deepgram_client import DeepgramStreamer | |
| from src.tts.cartesia_client import synthesize_stream as cartesia_synthesize_stream | |
| from src.tts.gemini_client import synthesize_stream as gemini_synthesize_stream, GEMINI_SAMPLE_RATE | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Async callback the pipeline awaits with each chunk of synthesized TTS audio.
SendAudioCallback = Callable[[bytes], Awaitable[None]]
# Async callback the pipeline awaits with each event dict it emits
# (e.g. {"event": "transcript", ...}, {"event": "tts_end"}).
SendEventCallback = Callable[[dict], Awaitable[None]]
class VoicePipeline:
    """Voice pipeline: audio in → STT → optional wake-word check → transcript event.

    Incoming audio is forwarded to the configured STT streamer. The streamer is
    given ``_on_final_transcript`` as its final-transcript callback; that
    callback emits a ``{"event": "transcript", "text": ...}`` event, containing
    only the text after the wake word when wake-word gating is enabled.

    TTS is driven separately: ``speak()`` streams synthesized audio through
    ``send_audio`` and finishes with a ``{"event": "tts_end"}`` event. This
    class does not call any chatbot service itself; whoever owns the pipeline
    is expected to react to the transcript event and call ``speak()`` with the
    text to be read out.
    """

    def __init__(
        self,
        send_audio: SendAudioCallback,
        send_event: SendEventCallback,
        stt_provider: str = "deepgram",
        tts_provider: str = "cartesia",
        wake_word_enabled: bool = WAKE_WORD_ENABLED,
        **kwargs,
    ):
        """Create the pipeline and its STT streamer.

        Args:
            send_audio: Awaited with every synthesized audio chunk.
            send_event: Awaited with every event dict the pipeline emits.
            stt_provider: ``"gemini"`` or ``"chirp3"`` select those streamers;
                any other value falls back to Deepgram.
            tts_provider: ``"gemini"`` selects the Gemini synthesizer; any
                other value falls back to Cartesia.
            wake_word_enabled: When true, transcripts without a wake word are
                dropped and only the text after the wake word is reported.
            **kwargs: Accepted and ignored, so callers may pass extra options.
        """
        self._send_audio = send_audio
        self._send_event = send_event
        # Prefer the running loop: asyncio.get_event_loop() is deprecated when
        # no loop is running. Fall back to it so that constructing the
        # pipeline outside a coroutine keeps working as before.
        try:
            self._loop = asyncio.get_running_loop()
        except RuntimeError:
            self._loop = asyncio.get_event_loop()
        self._tts_provider = tts_provider
        self._wake_word_enabled = wake_word_enabled

        # Optional providers are imported lazily so their dependencies are
        # only needed when that provider is actually selected.
        if stt_provider == "gemini":
            from src.stt.gemini_stt import GeminiSTTStreamer

            self._stt = GeminiSTTStreamer(
                on_final_transcript=self._on_final_transcript,
                loop=self._loop,
            )
        elif stt_provider == "chirp3":
            from src.stt.chirp3_client import Chirp3STTStreamer

            self._stt = Chirp3STTStreamer(
                on_final_transcript=self._on_final_transcript,
                loop=self._loop,
            )
        else:
            self._stt = DeepgramStreamer(
                on_final_transcript=self._on_final_transcript,
                loop=self._loop,
            )

        # Serializes TTS output so two utterances never interleave their audio.
        self._tts_lock = asyncio.Lock()
        # Most recently scheduled TTS task (kept for backward compatibility).
        self._tts_task: asyncio.Task | None = None
        # Every TTS task that has not finished yet. A second speak() may be
        # queued on the lock while the first is still streaming, so all of
        # them must be reachable for interrupt()/stop() to silence playback.
        self._tts_tasks: set[asyncio.Task] = set()

    def start(self) -> None:
        """Start the STT streamer."""
        self._stt.start()

    def feed_audio(self, chunk: bytes) -> None:
        """Forward one chunk of incoming audio to the STT streamer."""
        self._stt.send_audio(chunk)

    async def _on_final_transcript(self, text: str) -> None:
        """Handle a final transcript: apply wake-word gating, then emit it.

        With wake-word gating enabled, transcripts without a wake word (or
        with nothing after it) are ignored. Without gating, blank transcripts
        are ignored and everything else is reported unchanged.
        """
        if self._wake_word_enabled:
            question = self._extract_question(text)
            if question is None:
                logger.debug("No wake word detected, ignoring: %s", text)
                return
            logger.info("Wake word detected, question: %s", question)
            await self._send_event({"event": "transcript", "text": question})
        else:
            if not text.strip():
                return
            logger.info("Transcript: %s", text)
            await self._send_event({"event": "transcript", "text": text})

    def _extract_question(self, transcript: str) -> str | None:
        """Return the text following the earliest wake word, or ``None``.

        Matching is a case-insensitive substring search. When several wake
        words start at the same position (e.g. "hey" and "hey assistant"),
        the longest one wins so no part of the wake phrase leaks into the
        question. Returns ``None`` when no wake word is present or nothing
        but spaces, commas and periods follows it.
        """
        lowered = transcript.lower()
        hits: list[tuple[int, int]] = []
        for wake_word in WAKE_WORDS:
            needle = wake_word.lower()
            if not needle:
                # An empty wake word would match every transcript.
                continue
            position = lowered.find(needle)
            if position != -1:
                # Negative length: min() then prefers the longest wake word
                # among those starting at the same position.
                hits.append((position, -len(needle)))
        if not hits:
            return None
        position, negative_length = min(hits)
        question = transcript[position - negative_length:].strip(" ,.")
        return question or None

    async def speak(self, text: str) -> None:
        """Called by the WebSocket handler when the FE sends a speak action.

        Schedules synthesis as a background task and returns immediately;
        utterances are played one after another in the order they acquire
        the TTS lock.
        """
        task = asyncio.create_task(self._speak(text))
        # Hold a reference until the task finishes so that it can always be
        # cancelled, even after a newer speak() replaced self._tts_task.
        self._tts_tasks.add(task)
        task.add_done_callback(self._tts_tasks.discard)
        self._tts_task = task

    async def _speak(self, text: str) -> None:
        """Synthesize ``text`` and stream it out, ending with ``tts_end``.

        Synthesis failures are logged and reported to the client as a
        ``TTS_ERROR`` event; cancellation is propagated untouched.
        """
        logger.info("TTS speak: provider=%s, text_len=%d, text=%r", self._tts_provider, len(text), text)
        async with self._tts_lock:
            try:
                stream = (
                    gemini_synthesize_stream(text)
                    if self._tts_provider == "gemini"
                    else cartesia_synthesize_stream(text)
                )
                async for audio_chunk in stream:
                    await self._send_audio(audio_chunk)
                await self._send_event({"event": "tts_end"})
            except asyncio.CancelledError:
                # Cancellation comes from interrupt()/stop(); never report it
                # as a TTS failure.
                raise
            except Exception:
                logger.exception("TTS error for text: %s", text)
                await self._send_event({
                    "event": "error",
                    "code": "TTS_ERROR",
                    "message": "TTS generation failed",
                })

    def _cancel_pending_tts(self) -> list[asyncio.Task]:
        """Cancel every unfinished TTS task and return the ones cancelled."""
        pending = [task for task in self._tts_tasks if not task.done()]
        for task in pending:
            task.cancel()
        return pending

    async def interrupt(self) -> None:
        """Stop all playing or queued TTS and emit an ``interrupted`` event.

        Does nothing (and emits nothing) when no TTS is in flight.
        """
        cancelled = self._cancel_pending_tts()
        if not cancelled:
            return
        # return_exceptions=True collects each task's CancelledError as a
        # result, while a cancellation of interrupt() itself still propagates
        # to the caller instead of being swallowed.
        await asyncio.gather(*cancelled, return_exceptions=True)
        await self._send_event({"event": "interrupted"})
        logger.info("TTS interrupted by client")

    def stop(self) -> None:
        """Cancel in-flight TTS and stop the STT streamer (synchronous)."""
        self._cancel_pending_tts()
        self._stt.stop()

    async def stop_async(self) -> None:
        """Cancel in-flight TTS and stop STT, awaiting a final transcription.

        Streamers that provide ``stop_and_transcribe`` are stopped through
        it; all others are stopped with their plain ``stop()``.
        """
        self._cancel_pending_tts()
        stop_and_transcribe = getattr(self._stt, "stop_and_transcribe", None)
        if stop_and_transcribe is not None:
            await stop_and_transcribe()
        else:
            self._stt.stop()
# Backward-compatible alias: code that still refers to `EchoPipeline`
# gets the same class as `VoicePipeline`.
EchoPipeline = VoicePipeline