Spaces:
Sleeping
Sleeping
File size: 7,958 Bytes
a1d8504 c13fca7 a1d8504 a24d5f0 a1d8504 a24d5f0 a1d8504 c13fca7 a1d8504 c13fca7 a1d8504 7fdafe4 a1d8504 7fdafe4 a1d8504 758164b 7fdafe4 6549fd6 7fdafe4 a1d8504 7fdafe4 a1d8504 7fdafe4 a1d8504 7fdafe4 a1d8504 c13fca7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 | """Agent implementation with session event hooks."""
from __future__ import annotations
import asyncio
import contextlib
from collections.abc import AsyncGenerator, AsyncIterable
from typing import Any
from livekit.agents import Agent, llm
from livekit.agents.llm.tool_context import (
get_function_info,
get_raw_function_info,
is_function_tool,
is_raw_function_tool,
)
from livekit.agents.voice.events import (
AgentStateChangedEvent,
CloseEvent,
ConversationItemAddedEvent,
ErrorEvent,
FunctionToolsExecutedEvent,
MetricsCollectedEvent,
SpeechCreatedEvent,
UserInputTranscribedEvent,
)
from src.agent.prompts.assistant import build_assistant_instructions
from src.agent.tools.feedback import ToolFeedbackController
from src.agent.tools.pre_tool_feedback import inject_pre_tool_feedback
from src.agent.traces.errors import error_detail, error_recoverable, error_type_name
from src.agent.traces.metrics_collector import MetricsCollector
from src.core.logger import logger
class Assistant(Agent):
    """Voice agent that forwards session events to a metrics collector.

    The LLM output stream is routed through ``inject_pre_tool_feedback``; the
    optional tool-feedback controller's typing sound is stopped when tools
    finish executing, when the pipeline reports an error, and when the session
    closes.
    """

    def __init__(
        self,
        metrics_collector: MetricsCollector,
        *,
        room_name: str,
        job_id: str,
        tool_feedback: ToolFeedbackController | None = None,
    ) -> None:
        """Create the assistant.

        Args:
            metrics_collector: Receives the session events this agent subscribes to.
            room_name: Room identifier, included in log lines.
            job_id: Job identifier, included in log lines.
            tool_feedback: Optional controller passed to ``inject_pre_tool_feedback``;
                when set, its ``stop_typing_sound`` is scheduled on tool completion,
                pipeline errors and session close.
        """
        super().__init__(
            instructions=build_assistant_instructions(),
        )
        self._metrics_collector = metrics_collector
        self._room_name = room_name
        self._job_id = job_id
        self._tool_feedback = tool_feedback
        # Strong references to fire-and-forget tasks: the event loop only keeps
        # weak references, so an unreferenced task may be garbage-collected
        # before it finishes.
        self._typing_sound_tasks: set[asyncio.Task[Any]] = set()
        # Session whose events we already listen to; guards against registering
        # duplicate listeners if on_enter runs again for the same session.
        self._subscribed_session: object | None = None

    async def llm_node(
        self,
        chat_ctx: llm.ChatContext,
        tools: list[llm.FunctionTool | llm.RawFunctionTool],
        model_settings: Any,
    ) -> AsyncGenerator[llm.ChatChunk | str | Any, None]:
        """Run the default LLM node and pass its output through pre-tool feedback.

        A plain-string result from the default node is yielded as-is; a result
        that is neither a string nor an async iterable yields nothing. The
        underlying stream is closed (errors suppressed) when this generator
        finishes or is closed early.
        """
        llm_node = Agent.default.llm_node(self, chat_ctx, tools, model_settings)
        # The default node may be a coroutine or an async generator.
        if asyncio.iscoroutine(llm_node):
            llm_node = await llm_node
        if isinstance(llm_node, str):
            yield llm_node
            return
        if not isinstance(llm_node, AsyncIterable):
            return
        allowed_tool_names = _resolve_allowed_tool_names(tools)
        aclose = getattr(llm_node, "aclose", None)
        try:
            async for chunk in inject_pre_tool_feedback(
                llm_node,
                tool_feedback=self._tool_feedback,
                should_announce_tool_step=self._metrics_collector.on_tool_step_started,
                allowed_tool_names=allowed_tool_names,
            ):
                yield chunk
        finally:
            if callable(aclose):
                with contextlib.suppress(Exception):
                    await aclose()

    def _stop_typing_sound(self, reason: str) -> None:
        """Schedule ``tool_feedback.stop_typing_sound(reason=...)`` without awaiting it.

        Does nothing when no tool-feedback controller was configured. The task is
        kept referenced until it completes, and any exception it raises is logged.
        """
        if self._tool_feedback is None:
            return
        task = asyncio.create_task(self._tool_feedback.stop_typing_sound(reason=reason))
        self._typing_sound_tasks.add(task)
        task.add_done_callback(self._on_typing_sound_task_done)

    def _on_typing_sound_task_done(self, task: asyncio.Task[Any]) -> None:
        """Drop the finished task's reference and log its exception, if any."""
        self._typing_sound_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.warning(
                "Failed to stop typing sound: room=%s job_id=%s error=%r",
                self._room_name,
                self._job_id,
                exc,
            )

    async def on_enter(self) -> None:
        """Called when the agent enters the session. Set up metrics listeners.

        Listeners are registered at most once per session object, so a repeated
        on_enter for the same session does not double-count events.
        """
        session = self.session
        if self._subscribed_session is session:
            return
        self._subscribed_session = session

        def metrics_wrapper(event: MetricsCollectedEvent) -> None:
            self._metrics_collector.submit_metrics_collected(event.metrics)

        def transcript_wrapper(event: UserInputTranscribedEvent) -> None:
            self._metrics_collector.submit_user_input_transcribed(
                event.transcript,
                is_final=event.is_final,
            )

        def conversation_item_wrapper(event: ConversationItemAddedEvent) -> None:
            item = event.item
            # Read defensively: not every conversation item type has these fields.
            role = getattr(item, "role", None)
            content = getattr(item, "content", None)
            item_created_at = getattr(item, "created_at", None)
            self._metrics_collector.submit_conversation_item_added(
                role=role,
                item=item,
                content=content,
                event_created_at=event.created_at,
                item_created_at=item_created_at,
            )

        def speech_created_wrapper(event: SpeechCreatedEvent) -> None:
            self._metrics_collector.submit_speech_created(event.speech_handle)

        def function_tools_executed_wrapper(event: FunctionToolsExecutedEvent) -> None:
            self._metrics_collector.submit_function_tools_executed(
                function_calls=event.function_calls,
                function_call_outputs=event.function_call_outputs,
                created_at=event.created_at,
            )
            self._stop_typing_sound("function_tools_executed")

        def agent_state_changed_wrapper(event: AgentStateChangedEvent) -> None:
            self._metrics_collector.submit_agent_state_changed(
                old_state=event.old_state,
                new_state=event.new_state,
            )

        def error_wrapper(event: ErrorEvent) -> None:
            self._stop_typing_sound("error")
            source = type(event.source).__name__
            error_type = error_type_name(event.error)
            recoverable = error_recoverable(event.error)
            detail = error_detail(event.error)
            logger.error(
                "Agent session pipeline error: room=%s job_id=%s source=%s error_type=%s recoverable=%s detail=%s",
                self._room_name,
                self._job_id,
                source,
                error_type,
                recoverable,
                detail,
            )

        def close_wrapper(event: CloseEvent) -> None:
            reason = event.reason.value
            self._stop_typing_sound(f"close:{reason}")
            if event.error is None:
                logger.info(
                    "Agent session closed: room=%s job_id=%s reason=%s",
                    self._room_name,
                    self._job_id,
                    reason,
                )
                return
            error_type = error_type_name(event.error)
            recoverable = error_recoverable(event.error)
            detail = error_detail(event.error)
            logger.warning(
                "Agent session closed with error: room=%s job_id=%s reason=%s error_type=%s recoverable=%s detail=%s",
                self._room_name,
                self._job_id,
                reason,
                error_type,
                recoverable,
                detail,
            )

        session.on("metrics_collected", metrics_wrapper)
        session.on("user_input_transcribed", transcript_wrapper)
        session.on("conversation_item_added", conversation_item_wrapper)
        session.on("speech_created", speech_created_wrapper)
        session.on("function_tools_executed", function_tools_executed_wrapper)
        session.on("agent_state_changed", agent_state_changed_wrapper)
        session.on("error", error_wrapper)
        session.on("close", close_wrapper)
def _resolve_allowed_tool_names(
    tools: list[llm.FunctionTool | llm.RawFunctionTool],
) -> set[str]:
    """Collect the names of the given tools.

    Names come from LiveKit's ``get_function_info`` / ``get_raw_function_info``
    when the tool is recognised by the matching ``is_*`` check. Otherwise, or if
    those helpers raise, the fallback is the tool's ``name`` attribute and then
    ``raw_schema["name"]``. Fallback names are whitespace-stripped. Tools with no
    resolvable name are skipped.
    """
    names: set[str] = set()
    for tool in tools:
        try:
            if is_function_tool(tool):
                names.add(get_function_info(tool).name)
                continue
            if is_raw_function_tool(tool):
                names.add(get_raw_function_info(tool).name)
                continue
        except Exception as exc:
            # Best-effort: fall through to attribute-based lookup, but leave a
            # trace instead of silently discarding the failure.
            logger.debug(
                "Could not read tool info via livekit helpers, using fallback: tool=%r error=%r",
                tool,
                exc,
            )
        fallback_name = _normalize_tool_name(getattr(tool, "name", None))
        if fallback_name:
            names.add(fallback_name)
            continue
        raw_schema = getattr(tool, "raw_schema", None)
        if isinstance(raw_schema, dict):
            raw_name = _normalize_tool_name(raw_schema.get("name"))
            if raw_name:
                names.add(raw_name)
    return names
def _normalize_tool_name(value: Any) -> str | None:
if not isinstance(value, str):
return None
stripped = value.strip()
return stripped or None
|