Spaces:
Sleeping
Sleeping
| """LLM stream pre-tool feedback injection.""" | |
| from __future__ import annotations | |
| from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable | |
| from typing import Any | |
| from livekit.agents import llm | |
| from src.agent.prompts.runtime import TOOL_PRE_SPEECH_FALLBACK | |
| from src.agent.tools.feedback import ToolFeedbackController | |
| from src.core.logger import logger | |
async def _resolve_should_announce(
    on_tool_step_started: Callable[[], Awaitable[bool | None]] | None,
    should_announce_tool_step: Callable[[], Awaitable[bool]] | None,
) -> bool:
    """Run the first-tool-step hook and decide whether a lead-in is spoken.

    ``should_announce_tool_step`` takes precedence: when given, its result is
    coerced with ``bool`` and returned. Only when it is absent is
    ``on_tool_step_started`` awaited, and then the answer is always True.
    A callback that raises is logged at debug level and the lead-in stays
    enabled.
    """
    if should_announce_tool_step is not None:
        try:
            return bool(await should_announce_tool_step())
        except Exception as exc:
            logger.debug("should_announce_tool_step callback failed: %s", exc)
            return True
    if on_tool_step_started is not None:
        try:
            await on_tool_step_started()
        except Exception as exc:
            logger.debug("tool_step_started callback failed: %s", exc)
    # NOTE(review): the hook is typed to return ``bool | None`` but its value
    # is discarded; confirm whether a False result was meant to suppress the
    # lead-in.
    return True


async def inject_pre_tool_feedback(
    source: AsyncIterable[Any],
    *,
    tool_feedback: ToolFeedbackController | None,
    on_tool_step_started: Callable[[], Awaitable[bool | None]] | None = None,
    should_announce_tool_step: Callable[[], Awaitable[bool]] | None = None,
    allowed_tool_names: set[str] | None = None,
) -> AsyncGenerator[Any, None]:
    """Re-emit an LLM stream, inserting spoken text ahead of the first tool step.

    Items that are not ``llm.ChatChunk``, and chunks whose delta carries no
    tool calls, are forwarded unchanged.

    The first chunk with accepted tool calls is handled specially; later
    tool-call chunks are forwarded as-is. For that first chunk, a lead-in
    string is yielded before it:
      - the chunk's own ``delta.content``, stripped, when non-blank; the
        chunk is then forwarded as a deep copy with ``delta.content`` set
        to ``None``;
      - otherwise a fallback phrase, and the chunk is forwarded as-is.

    Args:
        source: Upstream stream items.
        tool_feedback: Optional controller. It supplies the fallback phrase via
            ``next_fallback_phrase()``. Its ``start_typing_sound()`` is awaited
            after the lead-in is yielded and before the tool chunk is forwarded.
            When ``None``, ``TOOL_PRE_SPEECH_FALLBACK`` is used as the phrase.
        on_tool_step_started: Awaited once at the first accepted tool step,
            only when ``should_announce_tool_step`` is not provided; its
            return value is ignored.
        should_announce_tool_step: Awaited once at the first accepted tool
            step; a falsy result skips the lead-in for that step.
        allowed_tool_names: When given, a chunk whose extracted tool names
            are all outside this set is forwarded with no lead-in and does
            not count as the first tool step.

    Yields:
        The upstream items, with at most one lead-in ``str`` inserted.
    """
    announced = False

    async for item in source:
        tool_delta = item.delta if isinstance(item, llm.ChatChunk) else None
        # Non-chunk items and chunks without tool calls are relayed untouched.
        if not (tool_delta and tool_delta.tool_calls):
            yield item
            continue

        names = _extract_tool_call_names(tool_delta)
        if not _tool_calls_supported(names, allowed_tool_names):
            logger.info(
                "tool_pre_speech_skipped reason=unknown_tool_names names=%s",
                ",".join(sorted(names)) if names else "<none>",
            )
            yield item
            continue

        # Only the first accepted tool step of the stream gets a lead-in.
        if announced:
            yield item
            continue
        announced = True

        if not await _resolve_should_announce(
            on_tool_step_started, should_announce_tool_step
        ):
            logger.debug("tool_pre_speech_skipped reason=announcement_suppressed")
            yield item
            continue

        # tool_delta is non-None here: the tool_calls check above required it.
        spoken_by_model = (tool_delta.content or "").strip()
        if spoken_by_model:
            logger.info(
                "tool_pre_speech_source=model tool_pre_speech_text=%s",
                spoken_by_model,
            )
            yield spoken_by_model
            # The text was just yielded on its own, so strip it from the
            # forwarded copy rather than passing it along a second time.
            forwarded = item.model_copy(deep=True)
            if forwarded.delta is not None:
                forwarded.delta.content = None
        else:
            phrase = (
                TOOL_PRE_SPEECH_FALLBACK
                if tool_feedback is None
                else tool_feedback.next_fallback_phrase()
            )
            logger.info(
                "tool_pre_speech_source=fallback tool_pre_speech_text=%s",
                phrase,
            )
            yield phrase
            forwarded = item

        if tool_feedback is not None:
            await tool_feedback.start_typing_sound()
        yield forwarded
def _extract_tool_call_names(delta: Any) -> set[str]:
    """Collect the distinct, non-blank tool names referenced by ``delta``.

    Each entry of ``delta.tool_calls`` contributes its own ``name``, or, when
    that is missing or blank, the ``name`` of its ``function`` attribute.
    Entries with neither are skipped. Names are normalized with
    ``_normalize_tool_name``. A missing or empty ``tool_calls`` yields an
    empty set.
    """

    def _name_of(call: Any) -> str | None:
        # Prefer the call's own name; fall back to a nested ``function.name``.
        direct = _normalize_tool_name(getattr(call, "name", None))
        if direct:
            return direct
        nested = getattr(call, "function", None)
        return _normalize_tool_name(getattr(nested, "name", None))

    calls = getattr(delta, "tool_calls", None) or []
    resolved = (_name_of(call) for call in calls)
    return {name for name in resolved if name}
| def _normalize_tool_name(value: Any) -> str | None: | |
| if not isinstance(value, str): | |
| return None | |
| stripped = value.strip() | |
| return stripped or None | |
| def _tool_calls_supported( | |
| tool_call_names: set[str], | |
| allowed_tool_names: set[str] | None, | |
| ) -> bool: | |
| if allowed_tool_names is None: | |
| return True | |
| if not tool_call_names: | |
| return True | |
| return any(name in allowed_tool_names for name in tool_call_names) | |