# open-voice-agent/src/agent/runtime/assistant.py
# dvalle08's picture
# feat: Enhance metrics collection and assistant item correlation
# 6549fd6
"""Agent implementation with session event hooks."""
from __future__ import annotations
import asyncio
import contextlib
from collections.abc import AsyncGenerator, AsyncIterable
from typing import Any
from livekit.agents import Agent, llm
from livekit.agents.llm.tool_context import (
get_function_info,
get_raw_function_info,
is_function_tool,
is_raw_function_tool,
)
from livekit.agents.voice.events import (
AgentStateChangedEvent,
CloseEvent,
ConversationItemAddedEvent,
ErrorEvent,
FunctionToolsExecutedEvent,
MetricsCollectedEvent,
SpeechCreatedEvent,
UserInputTranscribedEvent,
)
from src.agent.prompts.assistant import build_assistant_instructions
from src.agent.tools.feedback import ToolFeedbackController
from src.agent.tools.pre_tool_feedback import inject_pre_tool_feedback
from src.agent.traces.errors import error_detail, error_recoverable, error_type_name
from src.agent.traces.metrics_collector import MetricsCollector
from src.core.logger import logger
class Assistant(Agent):
    """Agent that forwards session events to a metrics collector.

    On entering a session the agent subscribes to the session's events and
    relays them to the injected ``MetricsCollector``. When a
    ``ToolFeedbackController`` is supplied, its typing sound is stopped
    whenever tools finish executing, an error is reported, or the session
    closes.
    """

    def __init__(
        self,
        metrics_collector: MetricsCollector,
        *,
        room_name: str,
        job_id: str,
        tool_feedback: ToolFeedbackController | None = None,
    ) -> None:
        """Store collaborators and build the agent instructions.

        Args:
            metrics_collector: Sink for every session event handled here.
            room_name: Room identifier, used only in log messages.
            job_id: Job identifier, used only in log messages.
            tool_feedback: Optional controller whose typing sound is stopped
                on tool completion, errors and session close.
        """
        super().__init__(
            instructions=build_assistant_instructions(),
        )
        self._metrics_collector = metrics_collector
        self._room_name = room_name
        self._job_id = job_id
        self._tool_feedback = tool_feedback
        # Strong references to in-flight "stop typing sound" tasks. The event
        # loop only keeps weak references to tasks, so an unreferenced task
        # can be garbage-collected before it finishes.
        self._tool_feedback_tasks: set[asyncio.Task[Any]] = set()

    async def llm_node(
        self,
        chat_ctx: llm.ChatContext,
        tools: list[llm.FunctionTool | llm.RawFunctionTool],
        model_settings: Any,
    ) -> AsyncGenerator[llm.ChatChunk | str | Any, None]:
        """Run the default LLM node and pass its output through pre-tool feedback.

        The result of ``Agent.default.llm_node`` is awaited when it is a
        coroutine. A plain string result is yielded as-is; a result that is
        not an async iterable produces nothing. Otherwise every chunk is
        routed through ``inject_pre_tool_feedback`` before being yielded.
        """
        stream = Agent.default.llm_node(self, chat_ctx, tools, model_settings)
        if asyncio.iscoroutine(stream):
            stream = await stream
        if isinstance(stream, str):
            yield stream
            return
        if not isinstance(stream, AsyncIterable):
            return

        allowed_tool_names = _resolve_allowed_tool_names(tools)
        aclose = getattr(stream, "aclose", None)
        try:
            async for chunk in inject_pre_tool_feedback(
                stream,
                tool_feedback=self._tool_feedback,
                should_announce_tool_step=self._metrics_collector.on_tool_step_started,
                allowed_tool_names=allowed_tool_names,
            ):
                yield chunk
        finally:
            # Best-effort close of the underlying stream; a failure while
            # closing must not mask whatever ended the iteration.
            if callable(aclose):
                with contextlib.suppress(Exception):
                    await aclose()

    def _schedule_stop_typing_sound(self, reason: str) -> None:
        """Stop the typing sound in the background, if a controller is set.

        Session event callbacks are synchronous, so the coroutine is run as a
        task. The task is retained until it completes so it cannot be
        garbage-collected mid-flight.
        """
        feedback = self._tool_feedback
        if feedback is None:
            return
        task = asyncio.create_task(feedback.stop_typing_sound(reason=reason))
        self._tool_feedback_tasks.add(task)
        task.add_done_callback(self._on_tool_feedback_task_done)

    def _on_tool_feedback_task_done(self, task: asyncio.Task[Any]) -> None:
        """Release a finished feedback task and log its failure, if any."""
        self._tool_feedback_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.warning(
                "Failed to stop typing sound: room=%s job_id=%s error=%r",
                self._room_name,
                self._job_id,
                exc,
            )

    async def on_enter(self) -> None:
        """Called when the agent enters the session. Set up metrics listeners."""

        def metrics_wrapper(event: MetricsCollectedEvent) -> None:
            self._metrics_collector.submit_metrics_collected(event.metrics)

        def transcript_wrapper(event: UserInputTranscribedEvent) -> None:
            self._metrics_collector.submit_user_input_transcribed(
                event.transcript,
                is_final=event.is_final,
            )

        def conversation_item_wrapper(event: ConversationItemAddedEvent) -> None:
            item = event.item
            # Attributes are read defensively: missing ones are forwarded
            # as None.
            self._metrics_collector.submit_conversation_item_added(
                role=getattr(item, "role", None),
                item=item,
                content=getattr(item, "content", None),
                event_created_at=event.created_at,
                item_created_at=getattr(item, "created_at", None),
            )

        def speech_created_wrapper(event: SpeechCreatedEvent) -> None:
            self._metrics_collector.submit_speech_created(event.speech_handle)

        def function_tools_executed_wrapper(event: FunctionToolsExecutedEvent) -> None:
            self._metrics_collector.submit_function_tools_executed(
                function_calls=event.function_calls,
                function_call_outputs=event.function_call_outputs,
                created_at=event.created_at,
            )
            self._schedule_stop_typing_sound("function_tools_executed")

        def agent_state_changed_wrapper(event: AgentStateChangedEvent) -> None:
            self._metrics_collector.submit_agent_state_changed(
                old_state=event.old_state,
                new_state=event.new_state,
            )

        def error_wrapper(event: ErrorEvent) -> None:
            self._schedule_stop_typing_sound("error")
            logger.error(
                "Agent session pipeline error: room=%s job_id=%s source=%s error_type=%s recoverable=%s detail=%s",
                self._room_name,
                self._job_id,
                type(event.source).__name__,
                error_type_name(event.error),
                error_recoverable(event.error),
                error_detail(event.error),
            )

        def close_wrapper(event: CloseEvent) -> None:
            reason = event.reason.value
            self._schedule_stop_typing_sound(f"close:{reason}")
            if event.error is None:
                logger.info(
                    "Agent session closed: room=%s job_id=%s reason=%s",
                    self._room_name,
                    self._job_id,
                    reason,
                )
                return
            logger.warning(
                "Agent session closed with error: room=%s job_id=%s reason=%s error_type=%s recoverable=%s detail=%s",
                self._room_name,
                self._job_id,
                reason,
                error_type_name(event.error),
                error_recoverable(event.error),
                error_detail(event.error),
            )

        self.session.on("metrics_collected", metrics_wrapper)
        self.session.on("user_input_transcribed", transcript_wrapper)
        self.session.on("conversation_item_added", conversation_item_wrapper)
        self.session.on("speech_created", speech_created_wrapper)
        self.session.on("function_tools_executed", function_tools_executed_wrapper)
        self.session.on("agent_state_changed", agent_state_changed_wrapper)
        self.session.on("error", error_wrapper)
        self.session.on("close", close_wrapper)
def _resolve_allowed_tool_names(
    tools: list[llm.FunctionTool | llm.RawFunctionTool],
) -> set[str]:
    """Collect the names of the given tools into a set.

    Each tool's name is taken from its function-tool or raw-function-tool
    info when available. If that lookup does not apply or raises, the name
    falls back to the tool's ``name`` attribute and then to the ``"name"``
    entry of a dict-valued ``raw_schema`` attribute. Tools for which no name
    can be determined are skipped.
    """
    resolved: set[str] = set()
    for candidate in tools:
        # Any failure in the typed lookup is ignored on purpose so the
        # attribute-based fallbacks below still get a chance.
        with contextlib.suppress(Exception):
            if is_function_tool(candidate):
                resolved.add(get_function_info(candidate).name)
                continue
            if is_raw_function_tool(candidate):
                resolved.add(get_raw_function_info(candidate).name)
                continue

        found = _normalize_tool_name(getattr(candidate, "name", None))
        if not found:
            schema = getattr(candidate, "raw_schema", None)
            if isinstance(schema, dict):
                found = _normalize_tool_name(schema.get("name"))
        if found:
            resolved.add(found)
    return resolved
def _normalize_tool_name(value: Any) -> str | None:
if not isinstance(value, str):
return None
stripped = value.strip()
return stripped or None