ishaq101's picture
fixing error stt and tts, empty chunk audio
c2e783d
import asyncio
import logging
from typing import Callable, Awaitable
from src.config import WAKE_WORDS, WAKE_WORD_ENABLED
from src.stt.deepgram_client import DeepgramStreamer
from src.tts.cartesia_client import synthesize_stream as cartesia_synthesize_stream
from src.tts.gemini_client import synthesize_stream as gemini_synthesize_stream, GEMINI_SAMPLE_RATE
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Async callback that is awaited with one chunk of audio bytes.
SendAudioCallback = Callable[[bytes], Awaitable[None]]
# Async callback that is awaited with one event payload (a dict with an "event" key).
SendEventCallback = Callable[[dict], Awaitable[None]]
class VoicePipeline:
    """Glue between a streaming STT client, wake-word filtering and streaming TTS.

    Inbound:  audio chunks -> STT streamer -> final transcript -> optional
    wake-word check -> ``{"event": "transcript", "text": ...}`` via ``send_event``.

    Outbound: ``speak(text)`` -> TTS provider stream -> audio chunks via
    ``send_audio``, followed by ``{"event": "tts_end"}`` (or an ``error`` event
    with code ``TTS_ERROR`` if synthesis fails).

    NOTE(review): an earlier docstring stated that the question is sent to the
    chatbot service (POST /api/agents/call) and that ``answer.audio_text``
    (falling back to ``answer.content``) is spoken. No such call is made in
    this class; presumably the WebSocket handler / front end does that and then
    calls ``speak()`` -- confirm against the caller.
    """

    def __init__(
        self,
        send_audio: SendAudioCallback,
        send_event: SendEventCallback,
        stt_provider: str = "deepgram",
        tts_provider: str = "cartesia",
        wake_word_enabled: bool = WAKE_WORD_ENABLED,
        **kwargs,
    ):
        """Create the pipeline and its STT streamer.

        Args:
            send_audio: Awaited with every synthesized audio chunk.
            send_event: Awaited with every event dict emitted by the pipeline.
            stt_provider: ``"gemini"`` or ``"chirp3"`` select those streamers;
                any other value selects Deepgram.
            tts_provider: ``"gemini"`` selects the Gemini TTS stream; any other
                value selects Cartesia.
            wake_word_enabled: When true, transcripts without a wake word are
                dropped and only the text after the wake word is emitted.
            **kwargs: Accepted and ignored, so callers passing extra options
                keep working.
        """
        self._send_audio = send_audio
        self._send_event = send_event
        # NOTE(review): presumably constructed from inside a running event loop
        # (get_event_loop() is deprecated outside of one) -- confirm.
        self._loop = asyncio.get_event_loop()
        self._tts_provider = tts_provider
        self._wake_word_enabled = wake_word_enabled

        # Optional providers are imported lazily so their modules are only
        # loaded when actually selected.
        if stt_provider == "gemini":
            from src.stt.gemini_stt import GeminiSTTStreamer

            self._stt = GeminiSTTStreamer(
                on_final_transcript=self._on_final_transcript,
                loop=self._loop,
            )
        elif stt_provider == "chirp3":
            from src.stt.chirp3_client import Chirp3STTStreamer

            self._stt = Chirp3STTStreamer(
                on_final_transcript=self._on_final_transcript,
                loop=self._loop,
            )
        else:
            self._stt = DeepgramStreamer(
                on_final_transcript=self._on_final_transcript,
                loop=self._loop,
            )

        # Serializes utterances so audio from two speak() calls never interleaves.
        self._tts_lock = asyncio.Lock()
        # Most recently started TTS task (kept for backward compatibility).
        self._tts_task: asyncio.Task | None = None
        # Every TTS task that has not finished yet. Holding strong references
        # keeps queued tasks alive and lets interrupt()/stop() cancel all of
        # them, not just the newest one.
        self._tts_tasks: set[asyncio.Task] = set()

    def start(self) -> None:
        """Start the underlying STT streamer."""
        self._stt.start()

    def feed_audio(self, chunk: bytes) -> None:
        """Forward one inbound audio chunk to the STT streamer.

        Empty chunks carry no audio and are dropped instead of being forwarded
        (an empty frame can be interpreted by a streaming STT backend as an
        end-of-stream signal).
        """
        if not chunk:
            return
        self._stt.send_audio(chunk)

    async def _on_final_transcript(self, text: str) -> None:
        """Handle a final transcript passed to the STT streamer's callback.

        With wake-word detection enabled, transcripts without a wake word (or
        with nothing after it) are ignored and only the text after the wake
        word is emitted. Otherwise any non-blank transcript is emitted as-is.
        """
        if self._wake_word_enabled:
            question = self._extract_question(text)
            if question is None:
                logger.debug("No wake word detected, ignoring: %s", text)
                return
            logger.info("Wake word detected, question: %s", question)
            await self._send_event({"event": "transcript", "text": question})
        else:
            if not text.strip():
                return
            logger.info("Transcript: %s", text)
            await self._send_event({"event": "transcript", "text": text})

    def _extract_question(self, transcript: str) -> str | None:
        """Return the text following the earliest wake word, or ``None``.

        Matching is done against the lower-cased transcript. When several wake
        words start at the same position (e.g. "hey" and "hey bot"), the
        longest one wins so that no part of the wake phrase leaks into the
        question. Empty wake-word entries are skipped because they would match
        every transcript.

        Returns:
            The remainder of the transcript with surrounding spaces, commas and
            periods stripped, or ``None`` if no wake word is present or nothing
            follows it.
        """
        # NOTE(review): offsets found in the lower-cased text are applied to
        # the original transcript; this assumes lower() preserves length, which
        # holds for all but a few code points (e.g. U+0130).
        lower = transcript.lower()
        best_start = -1
        best_end = -1
        for word in WAKE_WORDS:
            if not word:
                continue
            start = lower.find(word)
            if start < 0:
                continue
            end = start + len(word)
            earlier = best_start < 0 or start < best_start
            longer_at_same_pos = start == best_start and end > best_end
            if earlier or longer_at_same_pos:
                best_start, best_end = start, end

        if best_start < 0:
            return None
        question = transcript[best_end:].strip(" ,.")
        return question or None

    async def speak(self, text: str) -> None:
        """Called by the WebSocket handler when the FE sends a speak action.

        Starts synthesis in a background task and returns immediately.
        Utterances are played one after another (see ``_tts_lock``).
        """
        task = asyncio.create_task(self._speak(text))
        self._tts_tasks.add(task)
        # Drop the reference automatically once the task finishes.
        task.add_done_callback(self._tts_tasks.discard)
        self._tts_task = task

    async def _speak(self, text: str) -> None:
        """Synthesize ``text`` and stream the audio through ``send_audio``.

        Emits ``tts_end`` after the last chunk, or an ``error`` event if the
        provider fails. Cancellation is propagated so ``interrupt()``/``stop()``
        can abort playback.
        """
        logger.info("TTS speak: provider=%s, text_len=%d, text=%r", self._tts_provider, len(text), text)
        async with self._tts_lock:
            try:
                stream = (
                    gemini_synthesize_stream(text)
                    if self._tts_provider == "gemini"
                    else cartesia_synthesize_stream(text)
                )
                async for audio_chunk in stream:
                    await self._send_audio(audio_chunk)
                await self._send_event({"event": "tts_end"})
            except asyncio.CancelledError:
                # Must not be swallowed, otherwise the task would not end up cancelled.
                raise
            except Exception:
                logger.exception("TTS error for text: %s", text)
                await self._send_event({
                    "event": "error",
                    "code": "TTS_ERROR",
                    "message": "TTS generation failed",
                })

    def _cancel_tts(self) -> list[asyncio.Task]:
        """Cancel every unfinished TTS task and return the cancelled tasks."""
        candidates = set(self._tts_tasks)
        if self._tts_task is not None:
            candidates.add(self._tts_task)
        pending = [task for task in candidates if not task.done()]
        for task in pending:
            task.cancel()
        return pending

    async def interrupt(self) -> None:
        """Abort all running or queued utterances and notify the client.

        Emits ``{"event": "interrupted"}`` only if something was actually
        cancelled; otherwise this is a no-op.
        """
        pending = self._cancel_tts()
        if not pending:
            return
        # Wait until the tasks have really stopped so that no audio chunk is
        # sent after the "interrupted" event.
        results = await asyncio.gather(*pending, return_exceptions=True)
        for result in results:
            # CancelledError is a BaseException and is the expected outcome.
            if isinstance(result, Exception):
                logger.error("TTS task failed while being interrupted", exc_info=result)
        await self._send_event({"event": "interrupted"})
        logger.info("TTS interrupted by client")

    def stop(self) -> None:
        """Cancel pending TTS work and stop the STT streamer."""
        self._cancel_tts()
        self._stt.stop()

    async def stop_async(self) -> None:
        """Cancel pending TTS work and shut the STT streamer down.

        Streamers that expose ``stop_and_transcribe`` are awaited through it;
        all others are stopped with the synchronous ``stop()``.
        """
        self._cancel_tts()
        if hasattr(self._stt, "stop_and_transcribe"):
            await self._stt.stop_and_transcribe()
        else:
            self._stt.stop()
# Backward-compatible alias: EchoPipeline refers to the same class object as
# VoicePipeline (presumably the class's previous name -- confirm before removing).
EchoPipeline = VoicePipeline