File size: 1,863 Bytes
226ff5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import asyncio
import concurrent.futures
import logging
from typing import Awaitable, Callable

import assemblyai as aai

from src.config import ASSEMBLYAI_API_KEY, SAMPLE_RATE

logger: logging.Logger = logging.getLogger(__name__)

# Signature of the user-supplied transcript handler: a callable that receives
# the text of one final transcript and returns an awaitable yielding None.
OnTranscriptCallback = Callable[[str], Awaitable[None]]


class AssemblyAIStreamer:
    """
    Wraps AssemblyAI real-time streaming STT.

    Audio chunks (raw PCM bytes, documented as 16 kHz 16-bit mono; the rate
    must match ``SAMPLE_RATE``) are fed via ``send_audio()``. When a final,
    non-empty transcript arrives, ``on_final_transcript`` is scheduled on
    ``loop`` with ``asyncio.run_coroutine_threadsafe``. Thread-safe
    scheduling is used because the SDK callbacks may fire outside the
    event-loop thread. Exceptions raised by that coroutine are logged rather
    than silently dropped.

    Lifecycle: ``start()`` opens a session and ``stop()`` closes it. Audio
    sent while no session is open is ignored.
    """

    def __init__(self, on_final_transcript: OnTranscriptCallback, loop: asyncio.AbstractEventLoop):
        """
        Args:
            on_final_transcript: Callback invoked with each final transcript's
                text; its result is run as a coroutine on ``loop``.
            loop: Event loop on which the callback coroutine is executed.
        """
        self._on_final_transcript = on_final_transcript
        self._loop = loop
        # Active SDK session, or None before start() / after stop().
        self._transcriber: aai.RealtimeTranscriber | None = None

    def start(self) -> None:
        """Open a real-time session with AssemblyAI.

        If a session is already open, it is closed first so that its
        connection is not leaked.
        """
        if self._transcriber is not None:
            logger.warning("AssemblyAI STT already started; restarting session")
            self.stop()

        aai.settings.api_key = ASSEMBLYAI_API_KEY

        transcriber = aai.RealtimeTranscriber(
            sample_rate=SAMPLE_RATE,
            on_data=self._on_data,
            on_error=self._on_error,
        )
        # Publish the transcriber only after connect() succeeds, so a failed
        # connection never leaves send_audio() streaming into a dead session.
        transcriber.connect()
        self._transcriber = transcriber
        logger.info("AssemblyAI STT connected")

    def _on_data(self, transcript: aai.RealtimeTranscript) -> None:
        """SDK data callback: forward final, non-empty transcripts to the loop.

        Partial transcripts and blank finals are ignored.
        """
        if not isinstance(transcript, aai.RealtimeFinalTranscript):
            return
        text = (transcript.text or "").strip()
        if not text:
            return
        logger.info("Final transcript: %s", text)

        coro = self._on_final_transcript(text)
        try:
            future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        except RuntimeError:
            # The loop is closed, so nothing can run the coroutine. Close it
            # to avoid a "never awaited" warning, and do not propagate into
            # the SDK's callback thread.
            if asyncio.iscoroutine(coro):
                coro.close()
            logger.warning("Event loop unavailable; dropping transcript: %s", text)
            return
        # Without this, an exception raised by the callback is stored on a
        # Future nobody reads and disappears silently.
        future.add_done_callback(self._log_callback_failure)

    @staticmethod
    def _log_callback_failure(future: concurrent.futures.Future) -> None:
        """Log an exception raised by the on_final_transcript coroutine, if any."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            logger.error("on_final_transcript callback failed", exc_info=exc)

    def _on_error(self, error: aai.RealtimeError) -> None:
        """SDK error callback: log the error; the session is left as-is."""
        logger.error("AssemblyAI error: %s", error)

    def send_audio(self, chunk: bytes) -> None:
        """Stream one audio chunk to the open session; no-op if not started."""
        # Read the attribute once, so a stop() running elsewhere (if any)
        # cannot clear it between the check and the call.
        transcriber = self._transcriber
        if transcriber is not None:
            transcriber.stream(chunk)

    def stop(self) -> None:
        """Close the open session, if any. Safe to call repeatedly."""
        # Detach first so the streamer is reset even if close() raises, and
        # so send_audio() stops feeding a session that is shutting down.
        transcriber, self._transcriber = self._transcriber, None
        if transcriber is None:
            return
        transcriber.close()
        logger.info("AssemblyAI STT closed")