| """Provider-agnostic realtime voice assistant orchestration.""" |
|
|
| from __future__ import annotations |
|
|
| from collections.abc import Iterable, Iterator |
| from dataclasses import dataclass, field |
| from typing import Any, Protocol |
| from uuid import uuid4 |
|
|
|
|
class StreamingSttProvider(Protocol):
    """Structural interface for a streaming speech-to-text backend."""

    def stream_transcript(
        self,
        audio_chunks: list[bytes],
        *,
        language: str,
        session_id: str,
    ) -> Iterable[dict[str, Any]]:
        """Transcribe ``audio_chunks`` and produce transcript events.

        Args:
            audio_chunks: Raw audio buffers for one user turn.
            language: Language identifier for the transcription.
            session_id: Identifier of the voice session the audio belongs to.

        Returns:
            An iterable of event dicts. The orchestrator in this module reads
            the ``"text"`` and ``"is_final"`` keys of each event.
        """
|
|
|
|
class StreamingLlmProvider(Protocol):
    """Structural interface for a language model that streams its reply."""

    def stream_response(
        self,
        transcript: str,
        *,
        session_id: str,
        persona_id: str | None,
    ) -> Iterable[str]:
        """Generate a reply to ``transcript`` as a stream of text deltas.

        Args:
            transcript: The user's utterance as text.
            session_id: Identifier of the voice session being served.
            persona_id: Optional persona selector; ``None`` when unspecified.

        Returns:
            An iterable of text fragments. The orchestrator in this module
            concatenates them without a separator to form the full reply.
        """
|
|
|
|
class StreamingTtsProvider(Protocol):
    """Structural interface for a streaming text-to-speech backend."""

    def stream_audio(
        self,
        text_chunks: Iterable[str],
        *,
        language: str,
        voice: str,
        session_id: str,
    ) -> Iterable[dict[str, Any]]:
        """Synthesize speech for ``text_chunks`` and produce audio events.

        Args:
            text_chunks: Text fragments to speak, in order.
            language: Language identifier for synthesis.
            voice: Voice identifier for synthesis.
            session_id: Identifier of the voice session being served.

        Returns:
            An iterable of event dicts. The orchestrator in this module reads
            the ``"audio"``, ``"mime_type"`` and ``"is_final"`` keys, all of
            which it treats as optional.
        """
|
|
|
|
| @dataclass(slots=True, frozen=True) |
| class RealtimeVoiceAssistantConfig: |
| default_language: str = "lv" |
| default_voice: str = "maris" |
| vad_provider: str = "silero" |
| allow_barge_in: bool = True |
| emit_partial_transcripts: bool = True |
|
|
|
|
| @dataclass(slots=True) |
| class VoiceSessionState: |
| session_id: str |
| active_response_id: str | None = None |
| interrupted_response_ids: set[str] = field(default_factory=set) |
|
|
|
|
class BargeInController:
    """Hold the state of active responses and their interruptions."""

    def start_response(self, session: VoiceSessionState) -> str:
        """Register a new active response on ``session`` and return its id.

        The id is a freshly generated UUID4 hex string. Any interruption mark
        already stored under that id is removed.
        """
        new_id = uuid4().hex
        session.interrupted_response_ids.discard(new_id)
        session.active_response_id = new_id
        return new_id

    def interrupt(self, session: VoiceSessionState) -> str | None:
        """Mark the session's active response as interrupted.

        Returns:
            The id of the interrupted response, or ``None`` when the session
            had no active response (in which case nothing is changed).
        """
        current = session.active_response_id
        if current is not None:
            session.interrupted_response_ids.add(current)
            session.active_response_id = None
        return current

    def is_interrupted(self, session: VoiceSessionState, response_id: str) -> bool:
        """Report whether ``response_id`` was interrupted in ``session``."""
        interrupted_ids = session.interrupted_response_ids
        return response_id in interrupted_ids
|
|
|
|
class RealtimeVoiceAssistant:
    """Orchestrate the streaming STT → LLM → TTS flow with VAD/barge-in events.

    A turn is exposed as a generator of plain ``dict`` events. Every event
    carries ``"type"``, ``"session_id"`` and ``"response_id"`` keys plus
    type-specific payload.
    """

    def __init__(
        self,
        *,
        stt: StreamingSttProvider,
        llm: StreamingLlmProvider,
        tts: StreamingTtsProvider,
        config: RealtimeVoiceAssistantConfig | None = None,
        barge_in: BargeInController | None = None,
    ) -> None:
        """Store the providers and fall back to defaults where omitted.

        Args:
            stt: Speech-to-text provider used for the user's audio.
            llm: Language model provider used to produce the reply.
            tts: Text-to-speech provider used to voice the reply.
            config: Assistant settings; a default instance is created when
                ``None``.
            barge_in: Interruption tracker; a default instance is created
                when ``None``.
        """
        self.stt = stt
        self.llm = llm
        self.tts = tts
        # Explicit None checks so a caller-supplied object is never replaced
        # merely because it happens to be falsy.
        self.config = config if config is not None else RealtimeVoiceAssistantConfig()
        self.barge_in = barge_in if barge_in is not None else BargeInController()

    def interrupt(self, session: VoiceSessionState) -> dict[str, Any]:
        """Interrupt the session's active response and describe the outcome.

        Returns:
            A ``"barge_in"`` event. ``"interrupted"`` is ``False`` and
            ``"response_id"`` is ``None`` when the controller reported no
            active response.
        """
        response_id = self.barge_in.interrupt(session)
        return self._event(
            "barge_in",
            session,
            response_id,
            interrupted=response_id is not None,
        )

    def handle_turn(
        self,
        audio_chunks: list[bytes],
        *,
        session: VoiceSessionState,
        language: str | None = None,
        voice: str | None = None,
        persona_id: str | None = None,
    ) -> Iterator[dict[str, Any]]:
        """Run one user turn and yield events as each stage progresses.

        Event order: one ``"vad"`` event, then ``"transcript"`` events, then
        ``"llm_delta"`` events, then ``"tts_audio"`` events, then a single
        ``"complete"`` event. The turn ends early with an ``"error"`` event
        when STT produced no non-empty final transcript, or with a
        ``"barge_in"`` event when the response was interrupted (only checked
        when ``config.allow_barge_in`` is true).

        Args:
            audio_chunks: Audio buffers for the user's utterance.
            session: Session state; its ``active_response_id`` is set for the
                duration of the turn and released on every exit path.
            language: Overrides ``config.default_language`` when truthy.
            voice: Overrides ``config.default_voice`` when truthy.
            persona_id: Passed through to the LLM provider.

        Yields:
            Event dicts as described above.
        """
        resolved_language = language or self.config.default_language
        resolved_voice = voice or self.config.default_voice
        response_id = self.barge_in.start_response(session)

        try:
            yield {
                "type": "vad",
                "provider": self.config.vad_provider,
                "event": "speech_end_detected",
                "session_id": session.session_id,
                "response_id": response_id,
            }

            final_transcript = ""
            for chunk in self.stt.stream_transcript(
                audio_chunks,
                language=resolved_language,
                session_id=session.session_id,
            ):
                text = str(chunk.get("text", "") or "")
                if not text:
                    continue
                is_final = bool(chunk.get("is_final"))
                if is_final or self.config.emit_partial_transcripts:
                    yield self._event(
                        "transcript", session, response_id, text=text, is_final=is_final
                    )
                if is_final:
                    # The most recent non-empty final chunk wins.
                    final_transcript = text

            if not final_transcript:
                yield self._event(
                    "error",
                    session,
                    response_id,
                    message="Streaming STT neatgrieza gala transkripciju.",
                )
                return

            # An interruption that arrived during STT must not start the LLM.
            if self._was_interrupted(session, response_id):
                yield self._event("barge_in", session, response_id, interrupted=True)
                return

            llm_chunks: list[str] = []
            for delta in self.llm.stream_response(
                final_transcript,
                session_id=session.session_id,
                persona_id=persona_id,
            ):
                if self._was_interrupted(session, response_id):
                    yield self._event("barge_in", session, response_id, interrupted=True)
                    return
                llm_chunks.append(delta)
                yield self._event("llm_delta", session, response_id, delta=delta)

            # An interruption that arrived after the last delta must not start
            # TTS; otherwise a TTS stream with no events would end the turn
            # as "complete".
            if self._was_interrupted(session, response_id):
                yield self._event("barge_in", session, response_id, interrupted=True)
                return

            for audio_event in self.tts.stream_audio(
                llm_chunks,
                language=resolved_language,
                voice=resolved_voice,
                session_id=session.session_id,
            ):
                if self._was_interrupted(session, response_id):
                    yield self._event("barge_in", session, response_id, interrupted=True)
                    return
                yield self._event(
                    "tts_audio",
                    session,
                    response_id,
                    audio=audio_event.get("audio", b""),
                    mime_type=audio_event.get("mime_type", "audio/wav"),
                    is_final=bool(audio_event.get("is_final")),
                )

            # Release before the final event so the consumer already sees the
            # session as idle when it receives "complete".
            self._release(session, response_id)
            yield self._event(
                "complete",
                session,
                response_id,
                text="".join(llm_chunks),
                voice=resolved_voice,
                language=resolved_language,
            )
        finally:
            # Also covers the error path, provider exceptions and the consumer
            # closing the generator early.
            self._release(session, response_id)

    def _was_interrupted(self, session: VoiceSessionState, response_id: str) -> bool:
        """Return whether barge-in is enabled and this response was interrupted."""
        return bool(
            self.config.allow_barge_in
            and self.barge_in.is_interrupted(session, response_id)
        )

    @staticmethod
    def _release(session: VoiceSessionState, response_id: str) -> None:
        """Clear the session's active response only if it is still this one."""
        if session.active_response_id == response_id:
            session.active_response_id = None

    @staticmethod
    def _event(
        event_type: str,
        session: VoiceSessionState,
        response_id: str | None,
        **payload: Any,
    ) -> dict[str, Any]:
        """Build an event dict: common keys first, then ``payload`` in order."""
        return {
            "type": event_type,
            "session_id": session.session_id,
            "response_id": response_id,
            **payload,
        }
|
|