Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import logging | |
| import time | |
| from abc import ABC, abstractmethod | |
| from asyncio import CancelledError | |
| from functools import wraps | |
| from typing import ( | |
| TYPE_CHECKING, | |
| Any, | |
| Callable, | |
| Dict, | |
| List, | |
| NoReturn, | |
| Optional, | |
| Tuple, | |
| Type, | |
| Union, | |
| ) | |
| from langchain_core.agents import AgentAction, AgentFinish | |
| from langchain_core.load.dump import dumpd | |
| from langchain_core.outputs import RunInfo | |
| from langchain_core.utils.input import get_color_mapping | |
| from langchain.callbacks.manager import ( | |
| AsyncCallbackManager, | |
| AsyncCallbackManagerForChainRun, | |
| CallbackManager, | |
| CallbackManagerForChainRun, | |
| Callbacks, | |
| ) | |
| from langchain.schema import RUN_KEY | |
| from langchain.tools import BaseTool | |
| from langchain.utilities.asyncio import asyncio_timeout | |
| if TYPE_CHECKING: | |
| from langchain.agents.agent import AgentExecutor | |
| logger = logging.getLogger(__name__) | |
class BaseAgentExecutorIterator(ABC):
    """Base class for AgentExecutorIterator.

    Declares the single hook that the ``rebuild_callback_manager_on_set``
    decorator in this module calls after a decorated setter has run.
    """

    @abstractmethod
    def build_callback_manager(self) -> None:
        """(Re)build the callback manager from the instance's current state.

        Abstract so that a subclass cannot silently inherit a no-op
        implementation; concrete iterators must provide their own.
        """
def rebuild_callback_manager_on_set(
    setter_method: Callable[..., None]
) -> Callable[..., None]:
    """Decorator to force setters to rebuild callback mgr.

    Args:
        setter_method: The setter to wrap. It is called first, with the
            instance and whatever arguments the wrapper receives.

    Returns:
        A wrapper that runs ``setter_method`` and then calls
        ``self.build_callback_manager()``. The wrapper always returns ``None``.
    """

    # wraps() keeps the setter's __name__/__doc__ so the property built on top
    # of the wrapper still introspects as the original setter.
    @wraps(setter_method)
    def wrapper(
        self: "BaseAgentExecutorIterator", *args: Any, **kwargs: Any
    ) -> None:
        setter_method(self, *args, **kwargs)
        # Rebuild only after the new value is stored so the manager sees it.
        self.build_callback_manager()

    return wrapper
class AgentExecutorIterator(BaseAgentExecutorIterator):
    """Iterator for AgentExecutor.

    Drives the wrapped ``AgentExecutor`` one agent-loop step per ``__next__``
    / ``__anext__`` call. Each call returns either
    ``{"intermediate_step": [...]}`` for a non-terminal step or the executor's
    final output dict. Once a final output has been produced, the following
    call raises ``StopIteration`` / ``StopAsyncIteration`` carrying it.
    """

    def __init__(
        self,
        agent_executor: AgentExecutor,
        inputs: Any,
        callbacks: Callbacks = None,
        *,
        tags: Optional[list[str]] = None,
        include_run_info: bool = False,
        async_: bool = False,
    ):
        """
        Initialize the AgentExecutorIterator with the given AgentExecutor,
        inputs, and optional callbacks.

        Args:
            agent_executor: Executor whose step/return helpers are driven.
            inputs: Raw inputs; passed through ``agent_executor.prep_inputs``.
            callbacks: Callbacks handed to the callback manager.
            tags: Tags handed to the callback manager.
            include_run_info: If true, the final outputs get a ``RUN_KEY``
                entry holding the run id (when a run manager exists).
            async_: Selects ``AsyncCallbackManager`` over ``CallbackManager``.
        """
        # Assign the private attribute directly: the public setter would try
        # to rebuild the callback manager before tags/callbacks exist.
        self._agent_executor = agent_executor
        self.inputs = inputs
        self.async_ = async_
        # build callback manager on tags setter
        self._callbacks = callbacks
        self.tags = tags
        self.include_run_info = include_run_info
        self.run_manager = None
        # Only __aiter__ installs a real timeout manager; default to None so
        # the async code paths can test it even if __aiter__ was never called.
        self.timeout_manager = None
        self.reset()

    _callback_manager: Union[AsyncCallbackManager, CallbackManager]
    _inputs: dict[str, str]
    _final_outputs: Optional[dict[str, Any]]
    run_manager: Optional[
        Union[AsyncCallbackManagerForChainRun, CallbackManagerForChainRun]
    ]
    timeout_manager: Any  # TODO: Fix a type here; the shim makes it tricky.

    @property
    def inputs(self) -> dict[str, str]:
        """Inputs as returned by the executor's ``prep_inputs``."""
        return self._inputs

    @inputs.setter
    def inputs(self, inputs: Any) -> None:
        self._inputs = self.agent_executor.prep_inputs(inputs)

    @property
    def callbacks(self) -> Callbacks:
        """Callbacks passed to the callback manager."""
        return self._callbacks

    @callbacks.setter
    @rebuild_callback_manager_on_set
    def callbacks(self, callbacks: Callbacks) -> None:
        """When callbacks are changed after __init__, rebuild callback mgr"""
        self._callbacks = callbacks

    @property
    def tags(self) -> Optional[List[str]]:
        """Tags passed to the callback manager."""
        return self._tags

    @tags.setter
    @rebuild_callback_manager_on_set
    def tags(self, tags: Optional[List[str]]) -> None:
        """When tags are changed after __init__, rebuild callback mgr"""
        self._tags = tags

    @property
    def agent_executor(self) -> AgentExecutor:
        """The wrapped AgentExecutor."""
        return self._agent_executor

    @agent_executor.setter
    @rebuild_callback_manager_on_set
    def agent_executor(self, agent_executor: AgentExecutor) -> None:
        """Swap the executor; re-prep inputs and rebuild the callback mgr.

        The callback manager is configured from the executor's callbacks,
        verbose flag and tags, so it is rebuilt when the executor changes.
        """
        self._agent_executor = agent_executor
        # force re-prep inputs in case agent_executor's prep_inputs fn changed
        self.inputs = self.inputs

    @property
    def callback_manager(self) -> Union[AsyncCallbackManager, CallbackManager]:
        """Callback manager produced by the last ``build_callback_manager``."""
        return self._callback_manager

    def build_callback_manager(self) -> None:
        """
        Create and configure the callback manager based on the current
        callbacks and tags.
        """
        CallbackMgr: Union[Type[AsyncCallbackManager], Type[CallbackManager]] = (
            AsyncCallbackManager if self.async_ else CallbackManager
        )
        self._callback_manager = CallbackMgr.configure(
            self.callbacks,
            self.agent_executor.callbacks,
            self.agent_executor.verbose,
            self.tags,
            self.agent_executor.tags,
        )

    @property
    def name_to_tool_map(self) -> dict[str, BaseTool]:
        """Map each of the executor's tools by its ``name``."""
        return {tool.name: tool for tool in self.agent_executor.tools}

    @property
    def color_mapping(self) -> dict[str, str]:
        """Color mapping for the executor's tool names (no green or red)."""
        return get_color_mapping(
            [tool.name for tool in self.agent_executor.tools],
            excluded_colors=["green", "red"],
        )

    def reset(self) -> None:
        """
        Reset the iterator to its initial state, clearing intermediate steps,
        iterations, and time elapsed.
        """
        logger.debug("(Re)setting AgentExecutorIterator to fresh state")
        self.intermediate_steps: list[tuple[AgentAction, str]] = []
        self.iterations = 0
        # maybe better to start these on the first __anext__ call?
        self.time_elapsed = 0.0
        self.start_time = time.time()
        self._final_outputs = None

    def update_iterations(self) -> None:
        """
        Increment the number of iterations and update the time elapsed.
        """
        self.iterations += 1
        self.time_elapsed = time.time() - self.start_time
        logger.debug(
            f"Agent Iterations: {self.iterations} ({self.time_elapsed:.2f}s elapsed)"
        )

    def raise_stopiteration(self, output: Any) -> NoReturn:
        """
        Raise a StopIteration exception with the given output.
        """
        logger.debug("Chain end: stop iteration")
        raise StopIteration(output)

    async def raise_stopasynciteration(self, output: Any) -> NoReturn:
        """
        Raise a StopAsyncIteration exception with the given output.
        Close the timeout context manager.
        """
        logger.debug("Chain end: stop async iteration")
        if self.timeout_manager is not None:
            await self.timeout_manager.__aexit__(None, None, None)
        raise StopAsyncIteration(output)

    @property
    def final_outputs(self) -> Optional[dict[str, Any]]:
        """Prepared final outputs, or ``None`` while the loop is still running."""
        return self._final_outputs

    @final_outputs.setter
    def final_outputs(self, outputs: Optional[Dict[str, Any]]) -> None:
        """Store ``outputs`` after running them through ``prep_outputs``.

        A falsy ``outputs`` clears the stored value.
        """
        # have access to intermediate steps by design in iterator,
        # so return only outputs may as well always be true.
        self._final_outputs = None
        if outputs:
            prepared_outputs: dict[str, Any] = self.agent_executor.prep_outputs(
                self.inputs, outputs, return_only_outputs=True
            )
            if self.include_run_info and self.run_manager is not None:
                logger.debug("Assign run key")
                prepared_outputs[RUN_KEY] = RunInfo(run_id=self.run_manager.run_id)
            self._final_outputs = prepared_outputs

    def __iter__(self: "AgentExecutorIterator") -> "AgentExecutorIterator":
        """Reset state and start a synchronous chain run; return ``self``."""
        logger.debug("Initialising AgentExecutorIterator")
        self.reset()
        assert isinstance(self.callback_manager, CallbackManager)
        self.run_manager = self.callback_manager.on_chain_start(
            dumpd(self.agent_executor),
            self.inputs,
        )
        return self

    def __aiter__(self) -> "AgentExecutorIterator":
        """
        N.B. __aiter__ must be a normal method, so need to initialise async run manager
        on first __anext__ call where we can await it
        """
        logger.debug("Initialising AgentExecutorIterator (async)")
        self.reset()
        if self.agent_executor.max_execution_time:
            self.timeout_manager = asyncio_timeout(
                self.agent_executor.max_execution_time
            )
        else:
            self.timeout_manager = None
        return self

    def _on_first_step(self) -> None:
        """
        Perform any necessary setup for the first step of the synchronous iterator.
        """
        pass

    async def _on_first_async_step(self) -> None:
        """
        Perform any necessary setup for the first step of the asynchronous iterator.
        """
        # on first step, need to await callback manager and start async timeout ctxmgr
        if self.iterations == 0:
            assert isinstance(self.callback_manager, AsyncCallbackManager)
            self.run_manager = await self.callback_manager.on_chain_start(
                dumpd(self.agent_executor),
                self.inputs,
            )
            if self.timeout_manager:
                await self.timeout_manager.__aenter__()

    def __next__(self) -> dict[str, Any]:
        """
        AgentExecutor               AgentExecutorIterator
        __call__                    (__iter__ ->) __next__
        _call              <=>      _call_next
        _take_next_step             _take_next_step
        """
        # first step
        if self.iterations == 0:
            self._on_first_step()
        # N.B. timeout taken care of by "_should_continue" in sync case
        try:
            return self._call_next()
        except StopIteration:
            # Normal end of iteration: not an error, do not notify callbacks.
            raise
        except BaseException as e:
            if self.run_manager:
                self.run_manager.on_chain_error(e)
            raise

    async def __anext__(self) -> dict[str, Any]:
        """
        AgentExecutor               AgentExecutorIterator
        acall                       (__aiter__ ->) __anext__
        _acall             <=>      _acall_next
        _atake_next_step            _atake_next_step
        """
        if self.iterations == 0:
            await self._on_first_async_step()
        try:
            return await self._acall_next()
        except StopAsyncIteration:
            # Normal end of iteration: not an error, do not notify callbacks.
            raise
        except (TimeoutError, CancelledError):
            # The timeout manager is None when no max_execution_time is set;
            # only close it when one was actually installed.
            if self.timeout_manager is not None:
                await self.timeout_manager.__aexit__(None, None, None)
                self.timeout_manager = None
            return await self._astop()
        except BaseException as e:
            if self.run_manager:
                assert isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
                await self.run_manager.on_chain_error(e)
            raise

    def _execute_next_step(
        self, run_manager: Optional[CallbackManagerForChainRun]
    ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
        """
        Execute the next step in the chain using the
        AgentExecutor's _take_next_step method.
        """
        return self.agent_executor._take_next_step(
            self.name_to_tool_map,
            self.color_mapping,
            self.inputs,
            self.intermediate_steps,
            run_manager=run_manager,
        )

    async def _execute_next_async_step(
        self, run_manager: Optional[AsyncCallbackManagerForChainRun]
    ) -> Union[AgentFinish, List[Tuple[AgentAction, str]]]:
        """
        Execute the next step in the chain using the
        AgentExecutor's _atake_next_step method.
        """
        return await self.agent_executor._atake_next_step(
            self.name_to_tool_map,
            self.color_mapping,
            self.inputs,
            self.intermediate_steps,
            run_manager=run_manager,
        )

    def _process_next_step_output(
        self,
        next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
        run_manager: Optional[CallbackManagerForChainRun],
    ) -> Dict[str, Union[str, List[Tuple[AgentAction, str]]]]:
        """
        Process the output of the next step,
        handling AgentFinish and tool return cases.

        Returns the executor's final output dict when the step ends the run,
        otherwise ``{"intermediate_step": next_step_output}``.
        """
        logger.debug("Processing output of Agent loop step")
        if isinstance(next_step_output, AgentFinish):
            logger.debug(
                "Hit AgentFinish: _return -> on_chain_end -> run final output logic"
            )
            output = self.agent_executor._return(
                next_step_output, self.intermediate_steps, run_manager=run_manager
            )
            if self.run_manager:
                self.run_manager.on_chain_end(output)
            self.final_outputs = output
            return output

        self.intermediate_steps.extend(next_step_output)
        logger.debug("Updated intermediate_steps with step output")

        # Check for tool return
        if len(next_step_output) == 1:
            next_step_action = next_step_output[0]
            tool_return = self.agent_executor._get_tool_return(next_step_action)
            if tool_return is not None:
                output = self.agent_executor._return(
                    tool_return, self.intermediate_steps, run_manager=run_manager
                )
                if self.run_manager:
                    self.run_manager.on_chain_end(output)
                self.final_outputs = output
                return output

        output = {"intermediate_step": next_step_output}
        return output

    async def _aprocess_next_step_output(
        self,
        next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
        run_manager: Optional[AsyncCallbackManagerForChainRun],
    ) -> Dict[str, Union[str, List[Tuple[AgentAction, str]]]]:
        """
        Process the output of the next async step,
        handling AgentFinish and tool return cases.

        Returns the executor's final output dict when the step ends the run,
        otherwise ``{"intermediate_step": next_step_output}``.
        """
        logger.debug("Processing output of async Agent loop step")
        if isinstance(next_step_output, AgentFinish):
            logger.debug(
                "Hit AgentFinish: _areturn -> on_chain_end -> run final output logic"
            )
            output = await self.agent_executor._areturn(
                next_step_output, self.intermediate_steps, run_manager=run_manager
            )
            if run_manager:
                await run_manager.on_chain_end(output)
            self.final_outputs = output
            return output

        self.intermediate_steps.extend(next_step_output)
        logger.debug("Updated intermediate_steps with step output")

        # Check for tool return
        if len(next_step_output) == 1:
            next_step_action = next_step_output[0]
            tool_return = self.agent_executor._get_tool_return(next_step_action)
            if tool_return is not None:
                output = await self.agent_executor._areturn(
                    tool_return, self.intermediate_steps, run_manager=run_manager
                )
                if run_manager:
                    await run_manager.on_chain_end(output)
                self.final_outputs = output
                return output

        output = {"intermediate_step": next_step_output}
        return output

    def _stop(self) -> dict[str, Any]:
        """
        Stop the iterator and return the stopped response.

        The response is also stored as the final outputs, so the following
        ``_call_next`` raises StopIteration.
        """
        logger.warning("Stopping agent prematurely due to triggering stop condition")
        # this manually constructs agent finish with output key
        output = self.agent_executor.agent.return_stopped_response(
            self.agent_executor.early_stopping_method,
            self.intermediate_steps,
            **self.inputs,
        )
        assert (
            isinstance(self.run_manager, CallbackManagerForChainRun)
            or self.run_manager is None
        )
        returned_output = self.agent_executor._return(
            output, self.intermediate_steps, run_manager=self.run_manager
        )
        self.final_outputs = returned_output
        return returned_output

    async def _astop(self) -> dict[str, Any]:
        """
        Stop the async iterator and return the stopped response.

        The response is also stored as the final outputs, so the following
        ``_acall_next`` raises StopAsyncIteration.
        """
        logger.warning("Stopping agent prematurely due to triggering stop condition")
        output = self.agent_executor.agent.return_stopped_response(
            self.agent_executor.early_stopping_method,
            self.intermediate_steps,
            **self.inputs,
        )
        assert (
            isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
            or self.run_manager is None
        )
        returned_output = await self.agent_executor._areturn(
            output, self.intermediate_steps, run_manager=self.run_manager
        )
        self.final_outputs = returned_output
        return returned_output

    def _call_next(self) -> dict[str, Any]:
        """
        Perform a single iteration of the synchronous AgentExecutorIterator.
        """
        # final output already reached: stopiteration (final output)
        if self.final_outputs is not None:
            self.raise_stopiteration(self.final_outputs)
        # timeout/max iterations: stopiteration (stopped response)
        if not self.agent_executor._should_continue(self.iterations, self.time_elapsed):
            return self._stop()
        assert (
            isinstance(self.run_manager, CallbackManagerForChainRun)
            or self.run_manager is None
        )
        next_step_output = self._execute_next_step(self.run_manager)
        output = self._process_next_step_output(next_step_output, self.run_manager)
        self.update_iterations()
        return output

    async def _acall_next(self) -> dict[str, Any]:
        """
        Perform a single iteration of the asynchronous AgentExecutorIterator.
        """
        # final output already reached: stopiteration (final output)
        if self.final_outputs is not None:
            await self.raise_stopasynciteration(self.final_outputs)
        # timeout/max iterations: stopiteration (stopped response)
        if not self.agent_executor._should_continue(self.iterations, self.time_elapsed):
            return await self._astop()
        assert (
            isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
            or self.run_manager is None
        )
        next_step_output = await self._execute_next_async_step(self.run_manager)
        output = await self._aprocess_next_step_output(
            next_step_output, self.run_manager
        )
        self.update_iterations()
        return output