File size: 4,392 Bytes
a1d8504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c13fca7
 
 
a1d8504
 
 
 
 
 
 
 
 
 
 
 
 
 
c13fca7
 
 
 
 
 
 
 
 
a1d8504
 
c13fca7
 
 
 
 
 
 
a1d8504
 
 
 
c13fca7
 
 
 
 
 
a1d8504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c13fca7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""LLM stream pre-tool feedback injection."""

from __future__ import annotations

from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable
from typing import Any

from livekit.agents import llm

from src.agent.prompts.runtime import TOOL_PRE_SPEECH_FALLBACK
from src.agent.tools.feedback import ToolFeedbackController
from src.core.logger import logger


async def inject_pre_tool_feedback(
    source: AsyncIterable[Any],
    *,
    tool_feedback: ToolFeedbackController | None,
    on_tool_step_started: Callable[[], Awaitable[bool | None]] | None = None,
    should_announce_tool_step: Callable[[], Awaitable[bool]] | None = None,
    allowed_tool_names: set[str] | None = None,
) -> AsyncGenerator[Any, None]:
    """Forward ``source`` and inject spoken feedback before the first tool call.

    Every item from ``source`` is re-yielded in order. The first
    ``llm.ChatChunk`` whose delta carries tool calls accepted by
    ``allowed_tool_names`` triggers a one-time announcement:

    * if that chunk's ``delta.content`` contains non-whitespace text, the
      stripped text is yielded as a plain string and the chunk is forwarded
      as a deep copy with ``delta.content`` cleared;
    * otherwise a fallback phrase (``tool_feedback.next_fallback_phrase()``,
      or ``TOOL_PRE_SPEECH_FALLBACK`` when no controller is given) is yielded
      before the unchanged chunk.

    In both cases ``tool_feedback.start_typing_sound()`` is awaited between
    the spoken text and the forwarded chunk when a controller is given.
    Later tool-call chunks are forwarded untouched, even when the first
    announcement was suppressed.

    Args:
        source: Upstream LLM stream; items that are not ``llm.ChatChunk``
            pass through unchanged.
        tool_feedback: Optional controller supplying fallback phrases and the
            typing sound.
        on_tool_step_started: Awaited once at the first eligible tool step,
            but only when ``should_announce_tool_step`` is ``None``; its
            return value is ignored.
        should_announce_tool_step: Awaited once at the first eligible tool
            step; a falsy result suppresses the announcement. If it raises,
            the announcement still happens.
        allowed_tool_names: When given, a tool-call chunk is eligible only if
            at least one of its tool names is in this set. Chunks whose names
            cannot be extracted are always eligible.

    Yields:
        The original stream items, plus at most one injected text string.
    """
    first_step_handled = False

    async for chunk in source:
        delta = chunk.delta if isinstance(chunk, llm.ChatChunk) else None
        if not delta or not delta.tool_calls:
            yield chunk
            continue

        requested = _extract_tool_call_names(delta)
        if not _tool_calls_supported(requested, allowed_tool_names):
            logger.info(
                "tool_pre_speech_skipped reason=unknown_tool_names names=%s",
                ",".join(sorted(requested)) if requested else "<none>",
            )
            yield chunk
            continue

        # Only the first eligible tool-call chunk is ever considered for an
        # announcement; everything after it is forwarded as-is.
        if first_step_handled:
            yield chunk
            continue
        first_step_handled = True

        announce = await _resolve_tool_step_announcement(
            on_tool_step_started, should_announce_tool_step
        )
        if not announce:
            logger.debug("tool_pre_speech_skipped reason=announcement_suppressed")
            yield chunk
            continue

        leadin_text = (delta.content or "").strip()
        if leadin_text:
            # Prefer the model's own lead-in text over a canned phrase.
            logger.info(
                "tool_pre_speech_source=model tool_pre_speech_text=%s",
                leadin_text,
            )
            spoken = leadin_text
        else:
            spoken = (
                TOOL_PRE_SPEECH_FALLBACK
                if tool_feedback is None
                else tool_feedback.next_fallback_phrase()
            )
            logger.info(
                "tool_pre_speech_source=fallback tool_pre_speech_text=%s",
                spoken,
            )

        yield spoken
        # The lead-in was already emitted on its own, so forward a copy whose
        # text is cleared to avoid emitting it a second time.
        forwarded = _without_text_content(chunk) if leadin_text else chunk
        if tool_feedback is not None:
            await tool_feedback.start_typing_sound()
        yield forwarded


async def _resolve_tool_step_announcement(
    on_tool_step_started: Callable[[], Awaitable[bool | None]] | None,
    should_announce_tool_step: Callable[[], Awaitable[bool]] | None,
) -> bool:
    """Run the tool-step callback(s) and decide whether to announce.

    ``should_announce_tool_step`` takes precedence: when it is set,
    ``on_tool_step_started`` is not called. Callback failures are logged at
    debug level and swallowed, and the result defaults to ``True``.
    """
    if should_announce_tool_step is not None:
        try:
            return bool(await should_announce_tool_step())
        except Exception as exc:
            logger.debug("should_announce_tool_step callback failed: %s", exc)
            return True
    if on_tool_step_started is not None:
        try:
            await on_tool_step_started()
        except Exception as exc:
            logger.debug("tool_step_started callback failed: %s", exc)
    return True


def _without_text_content(chunk: llm.ChatChunk) -> llm.ChatChunk:
    """Return a deep copy of ``chunk`` with ``delta.content`` set to ``None``."""
    stripped = chunk.model_copy(deep=True)
    if stripped.delta is not None:
        stripped.delta.content = None
    return stripped


def _extract_tool_call_names(delta: Any) -> set[str]:
    """Collect the normalized tool names referenced by ``delta.tool_calls``.

    Each tool call contributes its own ``name`` attribute when that is a
    non-blank string. Otherwise it contributes the ``name`` of its
    ``function`` attribute, if present. Calls with neither are ignored, and
    a missing or empty ``tool_calls`` yields an empty set.
    """
    names: set[str] = set()
    for call in getattr(delta, "tool_calls", None) or []:
        resolved = _normalize_tool_name(getattr(call, "name", None))
        if resolved is None:
            # Fall back to a nested ``function.name`` attribute.
            resolved = _normalize_tool_name(
                getattr(getattr(call, "function", None), "name", None)
            )
        if resolved is not None:
            names.add(resolved)
    return names


def _normalize_tool_name(value: Any) -> str | None:
    """Return ``value`` with surrounding whitespace removed, or ``None``.

    Non-string inputs, and strings that are empty after stripping, both
    map to ``None``.
    """
    if isinstance(value, str) and (cleaned := value.strip()):
        return cleaned
    return None


def _tool_calls_supported(
    tool_call_names: set[str],
    allowed_tool_names: set[str] | None,
) -> bool:
    """Decide whether a tool-call chunk is eligible for pre-tool speech.

    Returns ``True`` in any of these cases:

    * no allow-list is configured;
    * no tool names could be extracted;
    * at least one extracted name appears in the allow-list.
    """
    no_filter = allowed_tool_names is None
    if no_filter or not tool_call_names:
        return True
    return any(candidate in allowed_tool_names for candidate in tool_call_names)