File size: 10,175 Bytes
a1d8504
 
 
 
 
 
 
 
 
 
 
 
 
 
4a03ace
a1d8504
 
 
 
 
 
9b5c112
3344a84
a1d8504
 
 
 
d0f8e17
a1d8504
 
 
 
 
a524a82
a1d8504
 
 
 
60ef510
 
 
 
 
 
 
 
 
 
 
a1d8504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c394e23
 
 
 
 
 
 
 
 
4a03ace
 
 
 
 
 
 
 
 
3344a84
 
 
 
 
 
 
 
 
62404fa
3344a84
 
 
 
 
 
 
 
a1d8504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c394e23
a1d8504
 
 
 
 
 
 
 
7fdafe4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1d8504
 
 
 
 
 
 
 
 
 
 
7fdafe4
 
 
a1d8504
 
9b5c112
3344a84
1641cad
a1d8504
 
d0f8e17
 
 
 
 
 
a1d8504
 
 
 
 
 
4a03ace
 
 
 
 
 
 
 
 
 
 
 
 
a1d8504
4a03ace
a1d8504
 
 
 
 
 
 
4a03ace
a1d8504
 
cd02c6a
a1d8504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a524a82
 
 
 
 
 
 
 
 
 
 
 
 
 
a1d8504
 
 
 
 
62404fa
a1d8504
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
"""LiveKit server and session handler wiring."""

from __future__ import annotations

import asyncio
import json
import sys
from typing import Any

from livekit import agents, rtc
from livekit.agents import AgentServer, AgentSession, room_io
from livekit.agents.types import APIConnectOptions
from livekit.agents.voice.agent_session import SessionConnectOptions
from livekit.plugins import noise_cancellation, silero
from livekit.plugins.turn_detector.english import EnglishModel

from src.agent.models.llm_runtime import (
    build_llm_runtime,
    install_mcp_generate_reply_guard,
    run_startup_greeting,
)
from src.agent.models.tts_factory import create_tts
from src.agent.runtime.connect_options import build_api_connect_options
from src.agent.models.stt_factory import create_stt
from src.agent.runtime.assistant import Assistant
from src.agent.runtime.tasks import (
    cancel_task_for_shutdown,
    run_llm_warmup,
    schedule_startup_greeting_task,
)
from src.agent.tools.feedback import ToolFeedbackController
from src.agent.traces.langfuse import setup_langfuse_tracer
from src.agent.traces.metrics_collector import MetricsCollector
from src.agent.traces.text_output_tracing import install_tracing_text_output
from src.core.logger import logger
from src.core.settings import settings


def _build_server() -> AgentServer:
    """Create the worker's AgentServer from the ``settings.livekit`` section.

    Configures the idle process pool size, the per-job memory warning
    threshold, and the process initialization timeout.
    """
    livekit_cfg = settings.livekit
    return AgentServer(
        num_idle_processes=livekit_cfg.LIVEKIT_NUM_IDLE_PROCESSES,
        job_memory_warn_mb=livekit_cfg.LIVEKIT_JOB_MEMORY_WARN_MB,
        initialize_process_timeout=livekit_cfg.LIVEKIT_INITIALIZE_PROCESS_TIMEOUT_SEC,
    )


# Built at import time; ``session_handler`` registers itself on this instance
# through the ``@server.rtc_session`` decorator.
server = _build_server()


def fallback_session_prefix() -> str | None:
    """Use console-prefixed fallback session id when running `... console`."""
    if any(arg == "console" for arg in sys.argv[1:]):
        return "console"
    return None


def fallback_participant_prefix() -> str | None:
    """Use console-prefixed fallback participant id when running `... console`."""
    if any(arg == "console" for arg in sys.argv[1:]):
        return "console"
    return None


def _resolve_stt_metrics_model_name() -> str:
    """Return the model identifier configured for the active STT provider.

    ``STT_PROVIDER`` is compared case-insensitively. Any provider other than
    ``moonshine`` or ``deepgram`` falls back to the NVIDIA model setting.
    """
    stt_cfg = settings.stt
    provider_name = stt_cfg.STT_PROVIDER.lower()
    if provider_name == "deepgram":
        return stt_cfg.DEEPGRAM_STT_MODEL
    if provider_name == "moonshine":
        return stt_cfg.MOONSHINE_MODEL_ID
    return stt_cfg.NVIDIA_STT_MODEL


def _resolve_stt_language() -> str:
    """Return the language setting configured for the active STT provider.

    ``STT_PROVIDER`` is compared case-insensitively. Any provider other than
    ``moonshine`` or ``deepgram`` falls back to the NVIDIA language code.
    """
    stt_cfg = settings.stt
    provider_name = stt_cfg.STT_PROVIDER.lower()
    if provider_name == "deepgram":
        return stt_cfg.DEEPGRAM_STT_LANGUAGE
    if provider_name == "moonshine":
        return stt_cfg.MOONSHINE_LANGUAGE
    return stt_cfg.NVIDIA_STT_LANGUAGE_CODE


def _build_session_connect_options() -> tuple[APIConnectOptions, SessionConnectOptions]:
    """Build the connection options for the LLM and TTS clients.

    Returns:
        A ``(llm_conn_options, session_conn_options)`` pair. The first holds
        the LLM options on their own. The second bundles the LLM and TTS
        options into a ``SessionConnectOptions``.
    """
    retry_kwargs = {
        "max_retry": settings.llm.LLM_CONN_MAX_RETRY,
        "retry_interval_sec": settings.llm.LLM_CONN_RETRY_INTERVAL_SEC,
    }
    llm_options = build_api_connect_options(
        **retry_kwargs,
        timeout_sec=settings.llm.LLM_CONN_TIMEOUT_SEC,
    )
    # NOTE(review): TTS reuses the LLM retry count/interval; only its timeout
    # comes from a TTS-specific (Pocket TTS) setting. Confirm this is intended.
    tts_options = build_api_connect_options(
        **retry_kwargs,
        timeout_sec=settings.voice.POCKET_TTS_CONN_TIMEOUT_SEC,
    )
    return llm_options, SessionConnectOptions(
        llm_conn_options=llm_options,
        tts_conn_options=tts_options,
    )


def _decode_job_metadata(raw_metadata: str) -> dict[str, Any]:
    """Decode dispatch job metadata JSON into a dict.

    This is best effort. It returns an empty dict and logs a warning when the
    text is not valid JSON, or when it decodes to something other than a JSON
    object. Callers can therefore always call ``.get`` on the result.
    """
    try:
        decoded = json.loads(raw_metadata)
    except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
        logger.warning("Ignoring job metadata that is not valid JSON: %s", exc)
        return {}
    if not isinstance(decoded, dict):
        # e.g. "[]" or "\"abc\"" is valid JSON but has no .get().
        logger.warning(
            "Ignoring job metadata that is not a JSON object: type=%s",
            type(decoded).__name__,
        )
        return {}
    return decoded


@server.rtc_session(agent_name=settings.livekit.LIVEKIT_AGENT_NAME)
async def session_handler(ctx: agents.JobContext) -> None:
    """Run one agent session for a LiveKit job dispatched to this worker.

    The steps run in this order:

    1. Register shutdown callbacks that cancel the startup greeting, close
       tool feedback, and drain/flush Langfuse traces.
    2. Create the metrics collector and forward any dispatch metadata to it.
    3. Build the TTS, LLM and STT engines, and warm up the LLM.
    4. Start an ``AgentSession`` in ``ctx.room`` with the ``Assistant`` agent.
    5. Install streamed-text tracing when the collector supports it.
    6. Start tool feedback and trigger the startup greeting.

    Args:
        ctx: LiveKit job context providing the room and the job details.
    """
    logger.info(
        "Agent session started: room=%s job_id=%s",
        ctx.room.name,
        ctx.job.id,
    )
    trace_provider = setup_langfuse_tracer()
    startup_greeting_task: asyncio.Task[Any] | None = None
    # Disabled placeholder so the shutdown callback always has a controller to
    # close, even if setup fails before the real one is created below.
    tool_feedback = ToolFeedbackController(enabled=False)

    # These callbacks are closures: they read ``startup_greeting_task`` and
    # ``tool_feedback`` when shutdown runs, so they act on the values assigned
    # later in this function.
    async def cancel_startup_greeting(_: str) -> None:
        await cancel_task_for_shutdown(
            startup_greeting_task,
            task_name="startup greeting",
        )

    ctx.add_shutdown_callback(cancel_startup_greeting)

    async def close_tool_feedback(_: str) -> None:
        await tool_feedback.aclose()

    ctx.add_shutdown_callback(close_tool_feedback)

    # getattr defaults: a missing participant yields a None identity, and a
    # missing room sid falls back to the room name.
    participant = getattr(ctx.job, "participant", None)
    initial_participant_id = getattr(participant, "identity", None)
    room_info = getattr(ctx.job, "room", None)
    initial_room_id = getattr(room_info, "sid", None) or ctx.room.name
    metrics_collector = MetricsCollector(
        room=ctx.room,
        model_name=_resolve_stt_metrics_model_name(),
        room_name=ctx.room.name,
        room_id=initial_room_id,
        participant_id=initial_participant_id,
        fallback_session_prefix=fallback_session_prefix(),
        fallback_participant_prefix=fallback_participant_prefix(),
        langfuse_enabled=trace_provider is not None,
    )

    async def drain_pending_traces(_: str) -> None:
        # Best effort: a failed drain is logged and must not skip the flush.
        try:
            await metrics_collector.drain_pending_traces()
        except (TimeoutError, asyncio.TimeoutError):
            # Before Python 3.11, asyncio.TimeoutError is distinct from the builtin.
            logger.warning("Timed out while draining pending Langfuse traces during shutdown")
        except Exception as exc:
            logger.warning("Failed to drain pending Langfuse traces: %s", exc)
        if trace_provider is None:
            return
        try:
            trace_provider.force_flush()
        except Exception as exc:
            logger.warning("Failed to flush Langfuse traces: %s", exc)

    ctx.add_shutdown_callback(drain_pending_traces)

    raw_metadata = ctx.job.metadata
    if isinstance(raw_metadata, str) and raw_metadata.strip():
        metadata = _decode_job_metadata(raw_metadata)
        logger.info(
            "Session metadata received from dispatch: session_id=%s participant_id=%s room=%s",
            metadata.get("session_id"),
            metadata.get("participant_id"),
            ctx.room.name,
        )
        await metrics_collector.on_session_metadata(
            session_id=metadata.get("session_id"),
            participant_id=metadata.get("participant_id"),
        )

    tts_engine = create_tts()
    llm_conn_options, session_conn_options = _build_session_connect_options()
    llm_runtime = build_llm_runtime(settings.llm)
    mcp_runtime_active = llm_runtime.mcp_runtime_active
    # Replace the placeholder now that we know whether MCP tools are active.
    tool_feedback = ToolFeedbackController(enabled=mcp_runtime_active)
    logger.info(
        "Running LLM warm-up before session start: provider=%s model=%s",
        llm_runtime.provider,
        llm_runtime.model,
    )
    await run_llm_warmup(
        llm_client=llm_runtime.llm,
        conn_options=llm_conn_options,
        provider=llm_runtime.provider,
        model=llm_runtime.model,
    )

    stt_engine = create_stt()
    logger.info(
        "Turn profile: detector=%s stt_provider=%s stt_model=%s stt_language=%s vad_min_silence=%.2fs min_endpointing=%.2fs max_endpointing=%.2fs preemptive_generation=%s",
        "EnglishModel",
        settings.stt.STT_PROVIDER,
        _resolve_stt_metrics_model_name(),
        _resolve_stt_language(),
        settings.voice.VAD_MIN_SILENCE_DURATION,
        settings.voice.MIN_ENDPOINTING_DELAY,
        settings.voice.MAX_ENDPOINTING_DELAY,
        settings.voice.PREEMPTIVE_GENERATION,
    )

    session_kwargs: dict[str, Any] = dict(
        stt=stt_engine,
        llm=llm_runtime.llm,
        tts=tts_engine,
        vad=silero.VAD.load(
            min_speech_duration=settings.voice.VAD_MIN_SPEECH_DURATION,
            min_silence_duration=settings.voice.VAD_MIN_SILENCE_DURATION,
            activation_threshold=settings.voice.VAD_THRESHOLD,
        ),
        turn_detection=EnglishModel(),
        min_endpointing_delay=settings.voice.MIN_ENDPOINTING_DELAY,
        max_endpointing_delay=settings.voice.MAX_ENDPOINTING_DELAY,
        max_tool_steps=8,
        preemptive_generation=settings.voice.PREEMPTIVE_GENERATION,
        conn_options=session_conn_options,
    )
    # Only pass mcp_servers when the runtime provided some.
    if llm_runtime.mcp_servers is not None:
        session_kwargs["mcp_servers"] = llm_runtime.mcp_servers

    session = AgentSession(**session_kwargs)
    install_mcp_generate_reply_guard(session, mcp_runtime_active=mcp_runtime_active)

    await session.start(
        room=ctx.room,
        record=False,
        agent=Assistant(
            metrics_collector=metrics_collector,
            room_name=ctx.room.name,
            job_id=ctx.job.id,
            tool_feedback=tool_feedback,
        ),
        room_options=room_io.RoomOptions(
            audio_input=room_io.AudioInputOptions(
                sample_rate=settings.voice.LIVEKIT_SAMPLE_RATE,
                num_channels=settings.voice.LIVEKIT_NUM_CHANNELS,
                frame_size_ms=settings.voice.LIVEKIT_FRAME_SIZE_MS,
                pre_connect_audio=settings.voice.LIVEKIT_PRE_CONNECT_AUDIO,
                pre_connect_audio_timeout=settings.voice.LIVEKIT_PRE_CONNECT_TIMEOUT,
                # SIP participants get the telephony-tuned model; others get BVC.
                noise_cancellation=lambda params: noise_cancellation.BVCTelephony()
                if params.participant.kind == rtc.ParticipantKind.PARTICIPANT_KIND_SIP
                else noise_cancellation.BVC(),
            ),
        ),
    )
    # Streamed-text tracing is installed only if the collector exposes all
    # three hooks.
    if all(
        hasattr(metrics_collector, attr_name)
        for attr_name in (
            "submit_streamed_assistant_text_delta",
            "submit_streamed_assistant_text_flush",
            "submit_streamed_assistant_text_context_missing",
        )
    ):
        install_tracing_text_output(
            session=session,
            on_delta=metrics_collector.submit_streamed_assistant_text_delta,
            on_flush=metrics_collector.submit_streamed_assistant_text_flush,
            on_context_missing=metrics_collector.submit_streamed_assistant_text_context_missing,
        )
    await tool_feedback.start(room=ctx.room, session=session)
    if mcp_runtime_active:
        # Keep the task handle so the shutdown callback can cancel it.
        startup_greeting_task = schedule_startup_greeting_task(
            session,
            mcp_runtime_active=mcp_runtime_active,
            timeout_sec=settings.llm.MCP_STARTUP_GREETING_TIMEOUT_SEC,
        )
    else:
        run_startup_greeting(session, mcp_runtime_active=mcp_runtime_active)