File size: 5,301 Bytes
226ff5d
 
 
38a5904
226ff5d
38a5904
 
226ff5d
 
 
 
 
 
 
e75bac4
226ff5d
e75bac4
226ff5d
e75bac4
 
 
226ff5d
 
e75bac4
 
 
 
38a5904
 
 
986403e
e75bac4
226ff5d
 
 
38a5904
 
e75bac4
38a5904
 
 
 
 
 
986403e
 
 
 
 
 
38a5904
 
 
 
 
226ff5d
 
 
 
 
 
 
 
 
 
38a5904
 
 
 
 
 
986403e
38a5904
986403e
38a5904
986403e
 
226ff5d
e75bac4
226ff5d
 
 
 
 
 
 
 
 
e75bac4
 
 
986403e
 
 
226ff5d
 
c2e783d
226ff5d
 
38a5904
 
 
 
 
 
226ff5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e75bac4
986403e
 
 
 
 
 
 
 
e75bac4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import asyncio
import logging
from typing import Callable, Awaitable
from src.config import WAKE_WORDS, WAKE_WORD_ENABLED
from src.stt.deepgram_client import DeepgramStreamer
from src.tts.cartesia_client import synthesize_stream as cartesia_synthesize_stream
from src.tts.gemini_client import synthesize_stream as gemini_synthesize_stream, GEMINI_SAMPLE_RATE

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Async callable that is awaited with one chunk of audio bytes (TTS output).
SendAudioCallback = Callable[[bytes], Awaitable[None]]
# Async callable that is awaited with one event payload dict,
# e.g. {"event": "transcript", "text": ...}.
SendEventCallback = Callable[[dict], Awaitable[None]]


class VoicePipeline:
    """
    Voice pipeline: audio in -> STT -> wake word check -> "transcript" event,
    and speak request -> TTS -> audio out.

    Incoming audio chunks are forwarded to the configured STT streamer. Each
    final transcript is (optionally) filtered by wake word and then emitted
    through ``send_event`` as ``{"event": "transcript", "text": ...}``.
    Speech synthesis is started explicitly via :meth:`speak`; synthesized
    audio chunks are delivered through ``send_audio`` and followed by a
    ``{"event": "tts_end"}`` event.

    NOTE(review): an earlier version of this docstring described sending the
    text after the wake word to the chatbot service (POST /api/agents/call)
    and speaking ``answer.audio_text`` (falling back to ``answer.content``).
    No such call is made in this class; presumably whoever receives the
    "transcript" event does that and then calls :meth:`speak` -- confirm.
    """

    def __init__(
        self,
        send_audio: SendAudioCallback,
        send_event: SendEventCallback,
        stt_provider: str = "deepgram",
        tts_provider: str = "cartesia",
        wake_word_enabled: bool = WAKE_WORD_ENABLED,
        **kwargs,
    ):
        """
        Build the pipeline and its STT streamer.

        Args:
            send_audio: awaited with every synthesized audio chunk.
            send_event: awaited with every event payload dict.
            stt_provider: "gemini", "chirp3" or "deepgram". Any other value
                falls back to Deepgram (a warning is logged).
            tts_provider: "gemini" selects the Gemini TTS stream; any other
                value selects Cartesia.
            wake_word_enabled: when true, transcripts without a wake word
                are dropped and only the text after the wake word is emitted.
            **kwargs: accepted for call-site compatibility and ignored.
        """
        if kwargs:
            logger.debug("VoicePipeline ignoring unsupported options: %s", sorted(kwargs))

        self._send_audio = send_audio
        self._send_event = send_event
        # Prefer the running loop: asyncio.get_event_loop() is deprecated when
        # no loop is running. Fall back to it so construction outside a
        # coroutine behaves exactly as before.
        try:
            self._loop = asyncio.get_running_loop()
        except RuntimeError:
            self._loop = asyncio.get_event_loop()
        self._tts_provider = tts_provider
        self._wake_word_enabled = wake_word_enabled

        # Serializes utterances so audio from two speak() calls never interleaves.
        self._tts_lock = asyncio.Lock()
        # Most recently started TTS task (kept for backward compatibility).
        self._tts_task: asyncio.Task | None = None
        # Every TTS task that has not finished yet, including those still
        # waiting on the lock, so that interrupt()/stop() can cancel all of them.
        self._tts_tasks: set[asyncio.Task] = set()

        # Optional providers are imported lazily so their modules are only
        # loaded when actually selected.
        if stt_provider == "gemini":
            from src.stt.gemini_stt import GeminiSTTStreamer as stt_cls
        elif stt_provider == "chirp3":
            from src.stt.chirp3_client import Chirp3STTStreamer as stt_cls
        else:
            if stt_provider != "deepgram":
                logger.warning(
                    "Unknown stt_provider %r, falling back to deepgram", stt_provider
                )
            stt_cls = DeepgramStreamer
        self._stt = stt_cls(
            on_final_transcript=self._on_final_transcript,
            loop=self._loop,
        )

    def start(self) -> None:
        """Start the STT streamer."""
        self._stt.start()

    def feed_audio(self, chunk: bytes) -> None:
        """Forward one chunk of input audio to the STT streamer."""
        self._stt.send_audio(chunk)

    async def _on_final_transcript(self, text: str) -> None:
        """
        STT callback: emit a "transcript" event for a final transcript.

        With wake word detection enabled only the text following the wake
        word is emitted, and transcripts without a wake word (or with nothing
        after it) are dropped. Otherwise the transcript is emitted unchanged
        unless it is blank.
        """
        if self._wake_word_enabled:
            question = self._extract_question(text)
            if question is None:
                logger.debug("No wake word detected, ignoring: %s", text)
                return
            logger.info("Wake word detected, question: %s", question)
            payload = question
        else:
            if not text.strip():
                return
            logger.info("Transcript: %s", text)
            payload = text
        await self._send_event({"event": "transcript", "text": payload})

    def _extract_question(self, transcript: str) -> str | None:
        """
        Return the text that follows the first wake word in ``transcript``.

        Matching is case-insensitive and plain substring based (no word
        boundaries). The earliest occurrence wins; when several wake words
        start at that same position the longest one is consumed, so with
        both "hey" and "hey bot" configured, "hey bot, hi" yields "hi".

        Returns:
            The remaining text with surrounding spaces, commas and periods
            removed, or None when no wake word is present or nothing follows it.
        """
        import re

        # Longest first: at a given position the regex engine tries the
        # alternatives in order, so this prefers the longest wake word.
        # Empty entries are skipped because they would match everywhere.
        words = sorted((w for w in WAKE_WORDS if w), key=len, reverse=True)
        if not words:
            return None

        # Search the original string (not a lowercased copy) so the match
        # offsets are always valid indices into ``transcript``.
        pattern = "|".join(re.escape(w) for w in words)
        found = re.search(pattern, transcript, flags=re.IGNORECASE)
        if found is None:
            return None

        question = transcript[found.end():].strip(" ,.")
        return question or None

    async def speak(self, text: str) -> None:
        """
        Called by the WebSocket handler when the FE sends a speak action.

        Schedules synthesis in a background task and returns immediately.
        Utterances are played one after another (see ``_tts_lock``).
        """
        task = asyncio.create_task(self._speak(text))
        # Track the task until it finishes; otherwise a later speak() would
        # overwrite the only reference and the earlier utterance could no
        # longer be cancelled.
        self._tts_tasks.add(task)
        task.add_done_callback(self._tts_tasks.discard)
        self._tts_task = task

    async def _speak(self, text: str) -> None:
        """
        Synthesize ``text`` and stream the audio through ``send_audio``.

        Sends ``{"event": "tts_end"}`` after the last chunk. On failure a
        ``TTS_ERROR`` error event is sent instead. Cancellation is propagated
        to the caller.
        """
        logger.info("TTS speak: provider=%s, text_len=%d, text=%r", self._tts_provider, len(text), text)
        async with self._tts_lock:
            stream = None
            try:
                stream = (
                    gemini_synthesize_stream(text)
                    if self._tts_provider == "gemini"
                    else cartesia_synthesize_stream(text)
                )
                async for audio_chunk in stream:
                    await self._send_audio(audio_chunk)
                await self._send_event({"event": "tts_end"})
            except asyncio.CancelledError:
                raise
            except Exception:
                logger.exception("TTS error for text: %s", text)
                await self._send_event({
                    "event": "error",
                    "code": "TTS_ERROR",
                    "message": "TTS generation failed",
                })
            finally:
                # If we were cancelled or failed while sending a chunk, the
                # stream is still suspended mid-iteration; close it now rather
                # than waiting for garbage collection. Streams without an
                # ``aclose`` method are left alone.
                aclose = getattr(stream, "aclose", None)
                if aclose is not None:
                    try:
                        await aclose()
                    except Exception:
                        logger.debug("Error while closing TTS stream", exc_info=True)

    def _cancel_tts(self) -> list[asyncio.Task]:
        """Cancel every unfinished TTS task and return the cancelled tasks."""
        candidates = set(self._tts_tasks)
        if self._tts_task is not None:
            candidates.add(self._tts_task)
        pending = [task for task in candidates if not task.done()]
        for task in pending:
            task.cancel()
        return pending

    async def interrupt(self) -> None:
        """
        Cancel all running and queued speech, then send an "interrupted" event.

        Waits for the cancelled tasks to finish so that no audio chunk is
        sent after the "interrupted" event.
        """
        cancelled = self._cancel_tts()
        if cancelled:
            # return_exceptions=True collects each task's CancelledError as a
            # result instead of raising it, while cancellation of the caller
            # itself still propagates.
            outcomes = await asyncio.gather(*cancelled, return_exceptions=True)
            for outcome in outcomes:
                # CancelledError is not an Exception subclass, so only real
                # failures are reported here.
                if isinstance(outcome, Exception):
                    logger.error("TTS task failed during interrupt", exc_info=outcome)
        await self._send_event({"event": "interrupted"})
        logger.info("TTS interrupted by client")

    def stop(self) -> None:
        """Cancel all pending speech and stop the STT streamer."""
        self._cancel_tts()
        self._stt.stop()

    async def stop_async(self) -> None:
        """
        Cancel all pending speech and shut down the STT streamer.

        Awaits the streamer's ``stop_and_transcribe()`` when it provides one,
        otherwise calls its synchronous ``stop()``.
        """
        self._cancel_tts()
        if hasattr(self._stt, "stop_and_transcribe"):
            await self._stt.stop_and_transcribe()
        else:
            self._stt.stop()

# Backward-compatible alias: EchoPipeline is the same class object as
# VoicePipeline (presumably an earlier name, kept so existing imports keep working).
EchoPipeline = VoicePipeline