File size: 3,558 Bytes
9fbd33c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from __future__ import annotations

import asyncio
import math
import threading
import time
from typing import Any

from livekit.plugins import nvidia

from src.agent.models import tts_factory


class _FakeResponse:
    def __init__(self, audio: bytes) -> None:
        self.audio = audio


class _FakeSynthesisService:
    def __init__(
        self,
        *,
        chunks: list[bytes] | None = None,
        delay_sec: float = 0.0,
        started_event: threading.Event | None = None,
        finished_event: threading.Event | None = None,
    ) -> None:
        self._chunks = chunks or []
        self._delay_sec = delay_sec
        self._started_event = started_event
        self._finished_event = finished_event

    def synthesize_online(self, *args: Any, **kwargs: Any) -> Any:
        _ = args
        _ = kwargs
        if self._started_event is not None:
            self._started_event.set()
        try:
            if self._delay_sec > 0:
                time.sleep(self._delay_sec)
            for chunk in self._chunks:
                yield _FakeResponse(chunk)
        finally:
            if self._finished_event is not None:
                self._finished_event.set()


def test_patched_nvidia_stream_emits_tts_metrics() -> None:
    """The patched NVIDIA stream yields audio and emits exactly one TTS metric.

    The real synthesis session is replaced by a fake that returns two
    3200-byte all-zero chunks. The test checks that at least one frame comes
    out and that a single ``metrics_collected`` event is emitted. That event
    must carry a finite, non-negative ``ttfb`` and the character count of the
    pushed text.
    """
    tts_factory._patch_nvidia_tts_stream_once()
    collected_metrics: list[Any] = []
    text = "hello world"

    def _fake_session() -> _FakeSynthesisService:
        # A fresh fake per call; the chunks are silent (all-zero) audio bytes.
        return _FakeSynthesisService(chunks=[b"\0\0" * 1600, b"\0\0" * 1600])

    def _record_metric(metric: Any) -> None:
        collected_metrics.append(metric)

    async def _count_frames() -> int:
        engine = nvidia.TTS(use_ssl=False)
        # Instance-level override, so the fake is not bound as a method.
        engine._ensure_session = _fake_session
        engine.on("metrics_collected", _record_metric)

        stream = engine.stream()
        stream.push_text(text)
        stream.end_input()

        frames = [frame async for frame in stream]
        return len(frames)

    frame_count = asyncio.run(_count_frames())

    assert frame_count > 0
    assert len(collected_metrics) == 1
    (metric,) = collected_metrics
    assert math.isfinite(metric.ttfb)
    assert metric.ttfb >= 0
    assert metric.characters_count == len(text)


def test_patched_nvidia_stream_avoids_invalid_state_on_shutdown() -> None:
    """Closing a stream mid-synthesis must not surface ``InvalidStateError``.

    The fake service signals ``started_event`` and then sleeps for
    ``delay_sec``. The stream is closed after the start signal, intended to
    land inside that sleep. The test then waits for the fake generator to
    finish, plus a short grace period, so that late callbacks scheduled from
    the synthesis thread reach the loop while the capturing exception handler
    is still installed.
    """
    tts_factory._patch_nvidia_tts_stream_once()
    loop_errors: list[dict[str, Any]] = []
    started_event = threading.Event()
    finished_event = threading.Event()

    async def _run() -> tuple[bool, bool]:
        """Drive the stream and return whether synthesis started and finished."""
        loop = asyncio.get_running_loop()
        previous_handler = loop.get_exception_handler()

        def _capture_exception(_loop: asyncio.AbstractEventLoop, context: dict[str, Any]) -> None:
            loop_errors.append(context)

        loop.set_exception_handler(_capture_exception)
        try:
            tts_engine = nvidia.TTS(use_ssl=False)
            tts_engine._ensure_session = lambda: _FakeSynthesisService(
                delay_sec=0.05,
                started_event=started_event,
                finished_event=finished_event,
            )

            stream = tts_engine.stream()
            stream.push_text("cancel me")
            stream.end_input()

            # Event.wait returns False on timeout. Keep the result so a
            # synthesis that never started cannot silently pass the test.
            started = await asyncio.to_thread(started_event.wait, 1.0)
            await stream.aclose()
            finished = await asyncio.to_thread(finished_event.wait, 1.0)
            # Let callbacks scheduled by the synthesis thread run while the
            # capturing handler is still installed.
            await asyncio.sleep(0.05)
        finally:
            loop.set_exception_handler(previous_handler)
        return started, finished

    started, finished = asyncio.run(_run())

    assert started, "fake synthesis never started; shutdown race was not exercised"
    assert finished, "fake synthesis did not finish within the timeout"

    invalid_state_errors = [
        context
        for context in loop_errors
        if isinstance(context.get("exception"), asyncio.InvalidStateError)
    ]
    assert not invalid_state_errors, invalid_state_errors