MarisUK's picture
Maris AI model sync
f440f03 verified
"""Provider-agnostic realtime voice assistant orchestration."""
from __future__ import annotations
from collections.abc import Iterable, Iterator
from dataclasses import dataclass, field
from typing import Any, Protocol
from uuid import uuid4
class StreamingSttProvider(Protocol):
    """Structural interface for a streaming speech-to-text backend.

    Any object exposing a compatible ``stream_transcript`` method satisfies
    this protocol; no inheritance is required.
    """

    def stream_transcript(
        self,
        audio_chunks: list[bytes],
        *,
        language: str,
        session_id: str,
    ) -> Iterable[dict[str, Any]]:
        """Transcribe the given audio and return transcript chunks.

        Args:
            audio_chunks: Raw audio of one user turn, split into byte chunks.
            language: Language identifier to transcribe in.
            session_id: Identifier of the voice session the audio belongs to.

        Returns:
            An iterable of chunk dictionaries. The orchestrator in this
            module reads the optional ``"text"`` and ``"is_final"`` keys of
            each chunk.
        """
        ...
class StreamingLlmProvider(Protocol):
    """Structural interface for a streaming language-model backend.

    Any object exposing a compatible ``stream_response`` method satisfies
    this protocol; no inheritance is required.
    """

    def stream_response(
        self,
        transcript: str,
        *,
        session_id: str,
        persona_id: str | None,
    ) -> Iterable[str]:
        """Generate a reply to ``transcript`` as a sequence of text deltas.

        Args:
            transcript: Final transcript of the user's turn.
            session_id: Identifier of the voice session being answered.
            persona_id: Optional persona identifier, or ``None`` when the
                caller did not select one.

        Returns:
            An iterable of text fragments; the orchestrator in this module
            concatenates them in order to form the full reply.
        """
        ...
class StreamingTtsProvider(Protocol):
    """Structural interface for a streaming text-to-speech backend.

    Any object exposing a compatible ``stream_audio`` method satisfies this
    protocol; no inheritance is required.
    """

    def stream_audio(
        self,
        text_chunks: Iterable[str],
        *,
        language: str,
        voice: str,
        session_id: str,
    ) -> Iterable[dict[str, Any]]:
        """Synthesize speech for ``text_chunks`` and return audio events.

        Args:
            text_chunks: Text fragments to speak, in order.
            language: Language identifier to synthesize in.
            voice: Identifier of the voice to use.
            session_id: Identifier of the voice session being answered.

        Returns:
            An iterable of audio event dictionaries. The orchestrator in
            this module reads the optional ``"audio"``, ``"mime_type"`` and
            ``"is_final"`` keys of each event.
        """
        ...
@dataclass(slots=True, frozen=True)
class RealtimeVoiceAssistantConfig:
default_language: str = "lv"
default_voice: str = "maris"
vad_provider: str = "silero"
allow_barge_in: bool = True
emit_partial_transcripts: bool = True
@dataclass(slots=True)
class VoiceSessionState:
session_id: str
active_response_id: str | None = None
interrupted_response_ids: set[str] = field(default_factory=set)
class BargeInController:
    """Track which response is active in a session and which were interrupted."""

    def start_response(self, session: VoiceSessionState) -> str:
        """Register a new active response on ``session`` and return its id.

        The id is a fresh random hex string; any interruption mark stored
        under the same id is removed.
        """
        new_id = uuid4().hex
        session.active_response_id = new_id
        session.interrupted_response_ids.discard(new_id)
        return new_id

    def interrupt(self, session: VoiceSessionState) -> str | None:
        """Mark the session's active response as interrupted.

        Returns:
            The id of the interrupted response, or ``None`` when the session
            had no active response (in which case nothing is changed).
        """
        current = session.active_response_id
        if current is not None:
            session.interrupted_response_ids.add(current)
            session.active_response_id = None
        return current

    def is_interrupted(self, session: VoiceSessionState, response_id: str) -> bool:
        """Return whether ``response_id`` has been interrupted in ``session``."""
        interrupted = session.interrupted_response_ids
        return response_id in interrupted
class RealtimeVoiceAssistant:
    """Orchestrate a streaming STT → LLM → TTS turn with VAD/barge-in events.

    The assistant only relies on the three provider objects handed to the
    constructor and reports progress as plain ``dict`` events, each carrying
    ``"type"``, ``"session_id"`` and ``"response_id"`` keys.
    """

    def __init__(
        self,
        *,
        stt: StreamingSttProvider,
        llm: StreamingLlmProvider,
        tts: StreamingTtsProvider,
        config: RealtimeVoiceAssistantConfig | None = None,
        barge_in: BargeInController | None = None,
    ) -> None:
        """Store the providers and fall back to default config/controller.

        Args:
            stt: Speech-to-text provider used to transcribe the turn audio.
            llm: Language-model provider used to generate the reply.
            tts: Text-to-speech provider used to voice the reply.
            config: Assistant settings; a default config is created when
                omitted.
            barge_in: Interruption controller; a default controller is
                created when omitted.
        """
        self.stt = stt
        self.llm = llm
        self.tts = tts
        self.config = config or RealtimeVoiceAssistantConfig()
        self.barge_in = barge_in or BargeInController()

    def interrupt(self, session: VoiceSessionState) -> dict[str, Any]:
        """Interrupt the session's active response and describe the outcome.

        Returns:
            A ``"barge_in"`` event. ``"response_id"`` is ``None`` and
            ``"interrupted"`` is ``False`` when no response was active.
        """
        response_id = self.barge_in.interrupt(session)
        return self._event(
            "barge_in",
            session,
            response_id,
            interrupted=response_id is not None,
        )

    def handle_turn(
        self,
        audio_chunks: list[bytes],
        *,
        session: VoiceSessionState,
        language: str | None = None,
        voice: str | None = None,
        persona_id: str | None = None,
    ) -> Iterator[dict[str, Any]]:
        """Run one user turn and lazily yield its events.

        Events are yielded in this order: one ``"vad"`` event, zero or more
        ``"transcript"`` events, then either an ``"error"`` event (no final
        transcript was produced) or ``"llm_delta"`` events followed by
        ``"tts_audio"`` events and a closing ``"complete"`` event. When
        barge-in is enabled and the response is marked interrupted, a
        ``"barge_in"`` event is yielded instead and the turn ends.

        Whatever way the turn ends (completion, error, barge-in, an exception
        from a provider, or the consumer closing the generator), the
        session's ``active_response_id`` is cleared if it still refers to
        this turn's response.

        Args:
            audio_chunks: Raw audio of the user's turn.
            session: Session state; its active response id is updated.
            language: Language override; falls back to the config default.
            voice: Voice override; falls back to the config default.
            persona_id: Optional persona identifier passed to the LLM.

        Yields:
            Event dictionaries as described above.
        """
        resolved_language = language or self.config.default_language
        resolved_voice = voice or self.config.default_voice
        response_id = self.barge_in.start_response(session)
        try:
            yield from self._run_turn(
                audio_chunks,
                session=session,
                response_id=response_id,
                language=resolved_language,
                voice=resolved_voice,
                persona_id=persona_id,
            )
        finally:
            # Covers the error/barge-in returns, provider exceptions and an
            # early generator close, none of which reach the normal clean-up.
            self._release(session, response_id)

    def _run_turn(
        self,
        audio_chunks: list[bytes],
        *,
        session: VoiceSessionState,
        response_id: str,
        language: str,
        voice: str,
        persona_id: str | None,
    ) -> Iterator[dict[str, Any]]:
        """Yield the events of one turn; clean-up is left to ``handle_turn``."""
        # Key order of this event differs from the others, so it is spelled out.
        yield {
            "type": "vad",
            "provider": self.config.vad_provider,
            "event": "speech_end_detected",
            "session_id": session.session_id,
            "response_id": response_id,
        }

        final_transcript = ""
        for chunk in self.stt.stream_transcript(
            audio_chunks,
            language=language,
            session_id=session.session_id,
        ):
            text = str(chunk.get("text", "") or "")
            is_final = bool(chunk.get("is_final"))
            if text and (self.config.emit_partial_transcripts or is_final):
                yield self._event(
                    "transcript",
                    session,
                    response_id,
                    text=text,
                    is_final=is_final,
                )
            if is_final and text:
                # The last non-empty final chunk wins.
                final_transcript = text

        if not final_transcript:
            yield self._event(
                "error",
                session,
                response_id,
                message="Streaming STT neatgrieza gala transkripciju.",
            )
            return

        llm_chunks: list[str] = []
        for delta in self.llm.stream_response(
            final_transcript,
            session_id=session.session_id,
            persona_id=persona_id,
        ):
            if self._interrupted(session, response_id):
                yield self._event("barge_in", session, response_id, interrupted=True)
                return
            llm_chunks.append(delta)
            yield self._event("llm_delta", session, response_id, delta=delta)

        # An interruption that arrived with the last delta must stop the turn
        # before any speech synthesis is started.
        if self._interrupted(session, response_id):
            yield self._event("barge_in", session, response_id, interrupted=True)
            return

        for audio_event in self.tts.stream_audio(
            llm_chunks,
            language=language,
            voice=voice,
            session_id=session.session_id,
        ):
            if self._interrupted(session, response_id):
                yield self._event("barge_in", session, response_id, interrupted=True)
                return
            yield self._event(
                "tts_audio",
                session,
                response_id,
                audio=audio_event.get("audio", b""),
                mime_type=audio_event.get("mime_type", "audio/wav"),
                is_final=bool(audio_event.get("is_final")),
            )

        # An interrupted response must not be reported as complete.
        if self._interrupted(session, response_id):
            yield self._event("barge_in", session, response_id, interrupted=True)
            return

        # Clear the active id before "complete" is observed by the consumer.
        self._release(session, response_id)
        yield self._event(
            "complete",
            session,
            response_id,
            text="".join(llm_chunks),
            voice=voice,
            language=language,
        )

    def _interrupted(self, session: VoiceSessionState, response_id: str) -> bool:
        """Return whether barge-in is enabled and this response was interrupted."""
        return bool(
            self.config.allow_barge_in
            and self.barge_in.is_interrupted(session, response_id)
        )

    @staticmethod
    def _event(
        event_type: str,
        session: VoiceSessionState,
        response_id: str | None,
        **payload: Any,
    ) -> dict[str, Any]:
        """Build an event dict: common header keys first, then ``payload``."""
        return {
            "type": event_type,
            "session_id": session.session_id,
            "response_id": response_id,
            **payload,
        }

    @staticmethod
    def _release(session: VoiceSessionState, response_id: str) -> None:
        """Clear the session's active id only if it still belongs to this turn."""
        if session.active_response_id == response_id:
            session.active_response_id = None