Spaces:
Sleeping
Sleeping
File size: 4,392 Bytes
a1d8504 c13fca7 a1d8504 c13fca7 a1d8504 c13fca7 a1d8504 c13fca7 a1d8504 c13fca7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 | """LLM stream pre-tool feedback injection."""
from __future__ import annotations
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable
from typing import Any
from livekit.agents import llm
from src.agent.prompts.runtime import TOOL_PRE_SPEECH_FALLBACK
from src.agent.tools.feedback import ToolFeedbackController
from src.core.logger import logger
async def inject_pre_tool_feedback(
    source: AsyncIterable[Any],
    *,
    tool_feedback: ToolFeedbackController | None,
    on_tool_step_started: Callable[[], Awaitable[bool | None]] | None = None,
    should_announce_tool_step: Callable[[], Awaitable[bool]] | None = None,
    allowed_tool_names: set[str] | None = None,
) -> AsyncGenerator[Any, None]:
    """Re-emit an LLM stream, inserting spoken feedback before the first tool step.

    Items from ``source`` are forwarded in order. Anything that is not an
    ``llm.ChatChunk``, and any chunk whose delta carries no tool calls, passes
    through untouched.

    The first tool-call chunk accepted by ``allowed_tool_names`` is the "tool
    step". For that chunk only, the generator decides whether to announce it
    (see ``_resolve_should_announce``). If it does, it yields a plain ``str`` to
    speak before the chunk itself:

    * the chunk's own text content, stripped. A deep copy of the chunk with
      ``delta.content`` cleared is then yielded instead of the original; or
    * otherwise a fallback phrase from ``tool_feedback.next_fallback_phrase()``
      (or ``TOOL_PRE_SPEECH_FALLBACK`` when there is no controller), followed by
      the original chunk.

    When ``tool_feedback`` is given, its typing sound is started between the
    spoken text and the chunk. Tool-call chunks rejected by the allow-list are
    logged and forwarded unchanged; they do not count as the tool step.

    Args:
        source: Upstream stream of chunks (or other items) to forward.
        tool_feedback: Optional controller supplying fallback phrases and the
            typing sound.
        on_tool_step_started: Awaited once at the tool step, and only when
            ``should_announce_tool_step`` is ``None``. Its return value is ignored.
        should_announce_tool_step: Awaited once at the tool step. A falsy result
            suppresses the announcement.
        allowed_tool_names: Optional allow-list of tool names; ``None`` accepts
            every tool call.

    Yields:
        Every item from ``source`` (the tool-step chunk possibly replaced by a
        content-free copy), plus at most one announcement ``str``.
    """
    step_seen = False
    async for chunk in source:
        if not isinstance(chunk, llm.ChatChunk):
            yield chunk
            continue

        delta = chunk.delta
        if not (delta and delta.tool_calls):
            # No tool calls in this chunk: forward as-is.
            yield chunk
            continue

        names = _extract_tool_call_names(delta)
        if not _tool_calls_supported(names, allowed_tool_names):
            logger.info(
                "tool_pre_speech_skipped reason=unknown_tool_names names=%s",
                ",".join(sorted(names)) if names else "<none>",
            )
            yield chunk
            continue

        if step_seen:
            # Only the first accepted tool step is announced.
            yield chunk
            continue
        step_seen = True

        if not await _resolve_should_announce(
            should_announce_tool_step, on_tool_step_started
        ):
            logger.debug("tool_pre_speech_skipped reason=announcement_suppressed")
            yield chunk
            continue

        # ``delta`` is known truthy here, so its content can be read directly.
        spoken = (delta.content or "").strip()
        if spoken:
            logger.info(
                "tool_pre_speech_source=model tool_pre_speech_text=%s",
                spoken,
            )
            yield spoken
            # The text was already emitted as a str; clear it from the chunk so
            # the same text is not emitted a second time.
            outgoing = _copy_without_content(chunk)
        else:
            spoken = (
                tool_feedback.next_fallback_phrase()
                if tool_feedback is not None
                else TOOL_PRE_SPEECH_FALLBACK
            )
            logger.info(
                "tool_pre_speech_source=fallback tool_pre_speech_text=%s",
                spoken,
            )
            yield spoken
            outgoing = chunk

        if tool_feedback is not None:
            await tool_feedback.start_typing_sound()
        yield outgoing


async def _resolve_should_announce(
    should_announce_tool_step: Callable[[], Awaitable[bool]] | None,
    on_tool_step_started: Callable[[], Awaitable[bool | None]] | None,
) -> bool:
    """Run the tool-step callbacks and report whether to announce the step.

    When ``should_announce_tool_step`` is provided, the answer is its awaited
    result coerced to ``bool``. Otherwise ``on_tool_step_started`` (if any) is
    awaited purely as a notification and the answer is ``True``. A callback
    that raises ``Exception`` is logged at debug level and the answer stays
    ``True``.

    NOTE(review): ``on_tool_step_started`` is annotated to return
    ``bool | None``, but that value is discarded. The callback is also skipped
    entirely whenever ``should_announce_tool_step`` is supplied. Confirm both
    behaviors are intended.
    """
    if should_announce_tool_step is None:
        if on_tool_step_started is not None:
            try:
                await on_tool_step_started()
            except Exception as exc:
                logger.debug("tool_step_started callback failed: %s", exc)
        return True
    try:
        return bool(await should_announce_tool_step())
    except Exception as exc:
        logger.debug("should_announce_tool_step callback failed: %s", exc)
        return True


def _copy_without_content(chunk: Any) -> Any:
    """Return a deep copy of *chunk* with ``delta.content`` set to ``None`` (when it has a delta)."""
    stripped = chunk.model_copy(deep=True)
    if stripped.delta is not None:
        stripped.delta.content = None
    return stripped
def _extract_tool_call_names(delta: Any) -> set[str]:
    """Collect the normalized tool names referenced by ``delta.tool_calls``.

    For each tool call, its own ``name`` attribute is used when it is a
    non-blank string. Otherwise the ``function.name`` attribute is consulted.
    Calls with no usable name contribute nothing. A missing or empty
    ``tool_calls`` yields an empty set.
    """
    names: set[str] = set()
    for call in getattr(delta, "tool_calls", None) or []:
        # ``or`` short-circuits, so ``function`` is only read when ``name`` is unusable.
        resolved = _normalize_tool_name(
            getattr(call, "name", None)
        ) or _normalize_tool_name(getattr(getattr(call, "function", None), "name", None))
        if resolved:
            names.add(resolved)
    return names
def _normalize_tool_name(value: Any) -> str | None:
if not isinstance(value, str):
return None
stripped = value.strip()
return stripped or None
def _tool_calls_supported(
tool_call_names: set[str],
allowed_tool_names: set[str] | None,
) -> bool:
if allowed_tool_names is None:
return True
if not tool_call_names:
return True
return any(name in allowed_tool_names for name in tool_call_names)
|