Spaces:
Sleeping
Sleeping
| """Agent implementation with session event hooks.""" | |
| from __future__ import annotations | |
| import asyncio | |
| import contextlib | |
| from collections.abc import AsyncGenerator, AsyncIterable | |
| from typing import Any | |
| from livekit.agents import Agent, llm | |
| from livekit.agents.llm.tool_context import ( | |
| get_function_info, | |
| get_raw_function_info, | |
| is_function_tool, | |
| is_raw_function_tool, | |
| ) | |
| from livekit.agents.voice.events import ( | |
| AgentStateChangedEvent, | |
| CloseEvent, | |
| ConversationItemAddedEvent, | |
| ErrorEvent, | |
| FunctionToolsExecutedEvent, | |
| MetricsCollectedEvent, | |
| SpeechCreatedEvent, | |
| UserInputTranscribedEvent, | |
| ) | |
| from src.agent.prompts.assistant import build_assistant_instructions | |
| from src.agent.tools.feedback import ToolFeedbackController | |
| from src.agent.tools.pre_tool_feedback import inject_pre_tool_feedback | |
| from src.agent.traces.errors import error_detail, error_recoverable, error_type_name | |
| from src.agent.traces.metrics_collector import MetricsCollector | |
| from src.core.logger import logger | |
class Assistant(Agent):
    """LiveKit voice agent that instruments its session and wraps the LLM node.

    On entering a session it subscribes to session events and forwards them to a
    ``MetricsCollector`` (pipeline errors and session close are also logged), and
    it streams the default LLM node's output through ``inject_pre_tool_feedback``.
    """

    def __init__(
        self,
        metrics_collector: MetricsCollector,
        *,
        room_name: str,
        job_id: str,
        tool_feedback: ToolFeedbackController | None = None,
    ) -> None:
        """Create the assistant.

        Args:
            metrics_collector: Sink for session events and tool-step notifications.
            room_name: Room name, included in log lines.
            job_id: Job id, included in log lines.
            tool_feedback: Optional controller whose ``stop_typing_sound`` is
                scheduled after tool execution, on pipeline errors and on close.
        """
        super().__init__(
            instructions=build_assistant_instructions(),
        )
        self._metrics_collector = metrics_collector
        self._room_name = room_name
        self._job_id = job_id
        self._tool_feedback = tool_feedback
        # Strong references to fire-and-forget tasks: the event loop only keeps weak
        # references, so an otherwise unreferenced task can be garbage-collected
        # before it finishes.
        self._background_tasks: set[asyncio.Task[Any]] = set()
        # (session, event_name, handler) triples registered by on_enter, kept so a
        # repeated on_enter can detach them instead of stacking duplicates.
        self._session_listeners: list[tuple[Any, str, Any]] = []

    async def llm_node(
        self,
        chat_ctx: llm.ChatContext,
        tools: list[llm.FunctionTool | llm.RawFunctionTool],
        model_settings: Any,
    ) -> AsyncGenerator[llm.ChatChunk | str | Any, None]:
        """Run the default LLM node and pass its output through pre-tool feedback.

        A coroutine result is awaited first. A plain-string result is yielded once
        as-is, and a result that is not async-iterable yields nothing. Otherwise
        chunks are streamed through ``inject_pre_tool_feedback`` with the names of
        ``tools`` as the allow-list. When this generator finishes or is closed,
        both the wrapper stream and the upstream stream are closed (best effort).
        """
        llm_node = Agent.default.llm_node(self, chat_ctx, tools, model_settings)
        if asyncio.iscoroutine(llm_node):
            llm_node = await llm_node
        if isinstance(llm_node, str):
            yield llm_node
            return
        if not isinstance(llm_node, AsyncIterable):
            return
        allowed_tool_names = _resolve_allowed_tool_names(tools)
        upstream_aclose = getattr(llm_node, "aclose", None)
        feedback_stream: Any = None
        try:
            feedback_stream = inject_pre_tool_feedback(
                llm_node,
                tool_feedback=self._tool_feedback,
                should_announce_tool_step=self._metrics_collector.on_tool_step_started,
                allowed_tool_names=allowed_tool_names,
            )
            async for chunk in feedback_stream:
                yield chunk
        finally:
            # Close the wrapper before the upstream stream it consumes, so that any
            # cleanup in the wrapper runs deterministically on an early exit instead
            # of being left to garbage collection. Closing is best effort.
            for closer in (getattr(feedback_stream, "aclose", None), upstream_aclose):
                if callable(closer):
                    with contextlib.suppress(Exception):
                        await closer()

    async def on_enter(self) -> None:
        """Called when the agent enters the session. Set up metrics listeners."""
        # A repeated on_enter on the same instance would otherwise register a second
        # copy of every handler and double-count events.
        self._detach_session_listeners()

        def metrics_wrapper(event: MetricsCollectedEvent) -> None:
            self._metrics_collector.submit_metrics_collected(event.metrics)

        def transcript_wrapper(event: UserInputTranscribedEvent) -> None:
            self._metrics_collector.submit_user_input_transcribed(
                event.transcript,
                is_final=event.is_final,
            )

        def conversation_item_wrapper(event: ConversationItemAddedEvent) -> None:
            item = event.item
            # Items of different kinds may lack some attributes; read them defensively.
            role = getattr(item, "role", None)
            content = getattr(item, "content", None)
            item_created_at = getattr(item, "created_at", None)
            self._metrics_collector.submit_conversation_item_added(
                role=role,
                item=item,
                content=content,
                event_created_at=event.created_at,
                item_created_at=item_created_at,
            )

        def speech_created_wrapper(event: SpeechCreatedEvent) -> None:
            self._metrics_collector.submit_speech_created(event.speech_handle)

        def function_tools_executed_wrapper(event: FunctionToolsExecutedEvent) -> None:
            self._metrics_collector.submit_function_tools_executed(
                function_calls=event.function_calls,
                function_call_outputs=event.function_call_outputs,
                created_at=event.created_at,
            )
            self._schedule_stop_typing_sound("function_tools_executed")

        def agent_state_changed_wrapper(event: AgentStateChangedEvent) -> None:
            self._metrics_collector.submit_agent_state_changed(
                old_state=event.old_state,
                new_state=event.new_state,
            )

        def error_wrapper(event: ErrorEvent) -> None:
            self._schedule_stop_typing_sound("error")
            source = type(event.source).__name__
            error_type = error_type_name(event.error)
            recoverable = error_recoverable(event.error)
            detail = error_detail(event.error)
            logger.error(
                "Agent session pipeline error: room=%s job_id=%s source=%s error_type=%s recoverable=%s detail=%s",
                self._room_name,
                self._job_id,
                source,
                error_type,
                recoverable,
                detail,
            )

        def close_wrapper(event: CloseEvent) -> None:
            reason = event.reason.value
            self._schedule_stop_typing_sound(f"close:{reason}")
            if event.error is None:
                logger.info(
                    "Agent session closed: room=%s job_id=%s reason=%s",
                    self._room_name,
                    self._job_id,
                    reason,
                )
                return
            error_type = error_type_name(event.error)
            recoverable = error_recoverable(event.error)
            detail = error_detail(event.error)
            logger.warning(
                "Agent session closed with error: room=%s job_id=%s reason=%s error_type=%s recoverable=%s detail=%s",
                self._room_name,
                self._job_id,
                reason,
                error_type,
                recoverable,
                detail,
            )

        session = self.session
        handlers = (
            ("metrics_collected", metrics_wrapper),
            ("user_input_transcribed", transcript_wrapper),
            ("conversation_item_added", conversation_item_wrapper),
            ("speech_created", speech_created_wrapper),
            ("function_tools_executed", function_tools_executed_wrapper),
            ("agent_state_changed", agent_state_changed_wrapper),
            ("error", error_wrapper),
            ("close", close_wrapper),
        )
        for event_name, handler in handlers:
            session.on(event_name, handler)
            self._session_listeners.append((session, event_name, handler))

    def _detach_session_listeners(self) -> None:
        """Unregister the handlers previously added by ``on_enter`` (no-op if none)."""
        listeners, self._session_listeners = self._session_listeners, []
        for session, event_name, handler in listeners:
            # Best effort: failing to detach must not break entering the session.
            with contextlib.suppress(Exception):
                session.off(event_name, handler)

    def _schedule_stop_typing_sound(self, reason: str) -> None:
        """Stop the tool typing sound in a background task, if tool feedback is enabled."""
        if self._tool_feedback is None:
            return
        task = asyncio.create_task(self._tool_feedback.stop_typing_sound(reason=reason))
        self._background_tasks.add(task)
        task.add_done_callback(self._on_background_task_done)

    def _on_background_task_done(self, task: asyncio.Task[Any]) -> None:
        """Release a finished background task and log its failure instead of losing it."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.warning(
                "Background task failed: room=%s job_id=%s error=%r",
                self._room_name,
                self._job_id,
                exc,
            )
def _resolve_allowed_tool_names(
    tools: list[llm.FunctionTool | llm.RawFunctionTool],
) -> set[str]:
    """Collect the names of *tools*, used as the allow-list for pre-tool feedback.

    Each tool's name is first read via livekit's introspection helpers. If that
    fails or yields no usable name, the tool's ``name`` attribute is tried, then
    ``raw_schema["name"]``. Every name is normalized by ``_normalize_tool_name``.
    Tools without a usable name are skipped.
    """
    names: set[str] = set()
    for tool in tools:
        name = _introspect_tool_name(tool)
        if name is None:
            name = _fallback_tool_name(tool)
        if name is not None:
            names.add(name)
    return names


def _introspect_tool_name(tool: Any) -> str | None:
    """Return the tool's name via livekit's function-info helpers, or None."""
    try:
        if is_function_tool(tool):
            return _normalize_tool_name(get_function_info(tool).name)
        if is_raw_function_tool(tool):
            return _normalize_tool_name(get_raw_function_info(tool).name)
    except Exception as exc:
        # Best effort: an unrecognized tool shape falls back to attribute lookup,
        # but leave a trace instead of swallowing the failure silently.
        logger.debug(
            "Tool name introspection failed, using attribute fallback: tool=%r error=%r",
            tool,
            exc,
        )
    return None


def _fallback_tool_name(tool: Any) -> str | None:
    """Return the name from ``tool.name`` or ``tool.raw_schema["name"]``, or None."""
    name = _normalize_tool_name(getattr(tool, "name", None))
    if name:
        return name
    raw_schema = getattr(tool, "raw_schema", None)
    if isinstance(raw_schema, dict):
        return _normalize_tool_name(raw_schema.get("name"))
    return None
| def _normalize_tool_name(value: Any) -> str | None: | |
| if not isinstance(value, str): | |
| return None | |
| stripped = value.strip() | |
| return stripped or None | |