open-voice-agent / tests /test_nvidia_tts_patch.py
dvalle08's picture
feat: Implement NVIDIA TTS stream patch for improved synthesis handling
9fbd33c
from __future__ import annotations
import asyncio
import math
import threading
import time
from typing import Any
from livekit.plugins import nvidia
from src.agent.models import tts_factory
class _FakeResponse:
    """Minimal response object that carries raw audio bytes in ``audio``."""

    def __init__(self, audio: bytes) -> None:
        # Stored untouched so consumers read back exactly what was supplied.
        self.audio = audio
class _FakeSynthesisService:
    """Test double exposing a generator-based ``synthesize_online``.

    The generator yields one ``_FakeResponse`` per configured chunk. It can
    sleep before yielding, and it can set threading events when the generator
    body starts executing and when it exits.
    """

    def __init__(
        self,
        *,
        chunks: list[bytes] | None = None,
        delay_sec: float = 0.0,
        started_event: threading.Event | None = None,
        finished_event: threading.Event | None = None,
    ) -> None:
        self._started_event = started_event
        self._finished_event = finished_event
        self._delay_sec = delay_sec
        # A missing (or empty) chunk list means the generator yields nothing.
        self._chunks = chunks if chunks else []

    @staticmethod
    def _signal(event: threading.Event | None) -> None:
        """Set *event* when one was supplied; otherwise do nothing."""
        if event is None:
            return
        event.set()

    def synthesize_online(self, *args: Any, **kwargs: Any) -> Any:
        """Yield a ``_FakeResponse`` per chunk, ignoring all call arguments.

        Because this is a generator, nothing below runs until the first item
        is requested; the started event is therefore set on first iteration,
        not at call time. The finished event is set in a ``finally`` clause,
        so it fires whether the generator is exhausted or closed early.
        """
        del args, kwargs  # Accepted for signature compatibility only.
        self._signal(self._started_event)
        try:
            pause = self._delay_sec
            if pause > 0:
                time.sleep(pause)
            for payload in self._chunks:
                yield _FakeResponse(payload)
        finally:
            self._signal(self._finished_event)
def test_patched_nvidia_stream_emits_tts_metrics() -> None:
    """Check that a patched NVIDIA TTS stream emits exactly one metrics event.

    The engine's session factory is replaced with a fake service that yields
    two audio chunks, so the test does not depend on a real synthesis backend.
    """
    tts_factory._patch_nvidia_tts_stream_once()
    collected_metrics: list[Any] = []

    def _record(metric: Any) -> None:
        collected_metrics.append(metric)

    def _fake_session() -> _FakeSynthesisService:
        return _FakeSynthesisService(chunks=[b"\0\0" * 1600, b"\0\0" * 1600])

    async def _run() -> int:
        tts_engine = nvidia.TTS(use_ssl=False)
        # Swap the session factory on this instance for the canned fake.
        tts_engine._ensure_session = _fake_session
        tts_engine.on("metrics_collected", _record)

        stream = tts_engine.stream()
        stream.push_text("hello world")
        stream.end_input()

        # Drain the stream and count every item it produces.
        return sum([1 async for _ in stream])

    frame_count = asyncio.run(_run())

    assert frame_count > 0
    assert len(collected_metrics) == 1

    metric = collected_metrics[0]
    assert math.isfinite(metric.ttfb)
    assert metric.ttfb >= 0
    assert metric.characters_count == len("hello world")
def test_patched_nvidia_stream_avoids_invalid_state_on_shutdown() -> None:
    """Close a stream while synthesis is in flight and check loop errors.

    Every context passed to the event loop's exception handler during the run
    is recorded; the test passes when none of them carries an
    ``asyncio.InvalidStateError``.
    """
    tts_factory._patch_nvidia_tts_stream_once()
    loop_errors: list[dict[str, Any]] = []
    started_event = threading.Event()
    finished_event = threading.Event()

    def _slow_session() -> _FakeSynthesisService:
        # The short delay keeps the fake generator busy while the stream closes.
        return _FakeSynthesisService(
            delay_sec=0.05,
            started_event=started_event,
            finished_event=finished_event,
        )

    async def _run() -> None:
        loop = asyncio.get_running_loop()
        previous_handler = loop.get_exception_handler()
        # Record every context the loop reports instead of logging it.
        loop.set_exception_handler(
            lambda _loop, context: loop_errors.append(context)
        )
        try:
            tts_engine = nvidia.TTS(use_ssl=False)
            tts_engine._ensure_session = _slow_session

            stream = tts_engine.stream()
            stream.push_text("cancel me")
            stream.end_input()

            # Wait (at most 1 s, off the loop thread) for the fake to start,
            # close the stream, then wait likewise for the fake to finish.
            await asyncio.to_thread(started_event.wait, 1.0)
            await stream.aclose()
            await asyncio.to_thread(finished_event.wait, 1.0)
            # Give the loop a moment to run any remaining callbacks.
            await asyncio.sleep(0.05)
        finally:
            loop.set_exception_handler(previous_handler)

    asyncio.run(_run())

    def _is_invalid_state(context: dict[str, Any]) -> bool:
        return isinstance(context.get("exception"), asyncio.InvalidStateError)

    invalid_state_errors = list(filter(_is_invalid_state, loop_errors))
    assert not invalid_state_errors