Spaces:
Runtime error
Runtime error
| """A fake callback handler for testing purposes.""" | |
| from typing import Any, Dict, List, Union | |
| from pydantic import BaseModel | |
| from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler | |
| from langchain.schema import AgentAction, AgentFinish, LLMResult | |
class BaseFakeCallbackHandler(BaseModel):
    """Base fake callback handler for testing.

    Holds the event counters shared by the sync and async fake handlers, and
    exposes the ``ignore_*`` / ``always_verbose`` switches as read-only
    properties. Each switch is backed by a pydantic field with a trailing
    underscore, so a test can set it as a constructor keyword argument
    (e.g. ``ignore_llm_=True``).
    """

    # Aggregate counters across all event kinds.
    starts: int = 0
    ends: int = 0
    errors: int = 0
    text: int = 0

    # Backing fields for the switch properties below.
    ignore_llm_: bool = False
    ignore_chain_: bool = False
    ignore_agent_: bool = False
    always_verbose_: bool = False

    # Finer-grained counters for easier debugging of failing tests.
    chain_starts: int = 0
    chain_ends: int = 0
    llm_starts: int = 0
    llm_ends: int = 0
    llm_streams: int = 0
    tool_starts: int = 0
    tool_ends: int = 0
    agent_ends: int = 0

    # These switches must be properties, not plain methods. A plain method
    # accessed as an attribute (``handler.ignore_llm``) yields a bound method,
    # which is always truthy. A check like ``if not handler.ignore_llm`` would
    # then treat the handler as ignoring every event, regardless of the
    # backing field's value.

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return self.always_verbose_

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return self.ignore_llm_

    @property
    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return self.ignore_chain_

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return self.ignore_agent_
class FakeCallbackHandler(BaseFakeCallbackHandler, BaseCallbackHandler):
    """Fake synchronous callback handler for testing.

    Every hook only increments counters inherited from
    ``BaseFakeCallbackHandler``; none of the hook arguments are read.

    - Start hooks bump ``starts`` plus a per-kind ``*_starts`` counter.
    - End hooks bump ``ends`` plus a per-kind ``*_ends`` counter.
    - All error hooks share the single ``errors`` counter.
    """

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        self.llm_starts += 1
        self.starts += 1

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""
        # Streamed tokens are tallied separately; they are not starts or ends.
        self.llm_streams += 1

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.llm_ends += 1
        self.ends += 1

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        self.errors += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        self.chain_starts += 1
        self.starts += 1

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
        self.chain_ends += 1
        self.ends += 1

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when chain errors."""
        self.errors += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self.tool_starts += 1
        self.starts += 1

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""
        self.tool_ends += 1
        self.ends += 1

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when tool errors."""
        self.errors += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on arbitrary text."""
        self.text += 1

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running."""
        self.agent_ends += 1
        self.ends += 1

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
        """Run on agent action.

        Agent actions are tallied as tool starts. ``tool_starts`` therefore
        counts both ``on_agent_action`` and ``on_tool_start`` events.
        """
        self.tool_starts += 1
        self.starts += 1
class FakeAsyncCallbackHandler(BaseFakeCallbackHandler, AsyncCallbackHandler):
    """Fake async callback handler for testing.

    This is the coroutine counterpart of the synchronous fake handler. Every
    hook only increments counters inherited from ``BaseFakeCallbackHandler``
    and never awaits anything; none of the hook arguments are read.

    - Start hooks bump ``starts`` plus a per-kind ``*_starts`` counter.
    - End hooks bump ``ends`` plus a per-kind ``*_ends`` counter.
    - All error hooks share the single ``errors`` counter.
    """

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        self.llm_starts += 1
        self.starts += 1

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""
        # Streamed tokens are tallied separately; they are not starts or ends.
        self.llm_streams += 1

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.llm_ends += 1
        self.ends += 1

    async def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when LLM errors."""
        self.errors += 1

    async def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        self.chain_starts += 1
        self.starts += 1

    async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
        self.chain_ends += 1
        self.ends += 1

    async def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when chain errors."""
        self.errors += 1

    async def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self.tool_starts += 1
        self.starts += 1

    async def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Run when tool ends running."""
        self.tool_ends += 1
        self.ends += 1

    async def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Run when tool errors."""
        self.errors += 1

    async def on_text(self, text: str, **kwargs: Any) -> None:
        """Run on arbitrary text."""
        self.text += 1

    async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running."""
        self.agent_ends += 1
        self.ends += 1

    async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
        """Run on agent action.

        Agent actions are tallied as tool starts. ``tool_starts`` therefore
        counts both ``on_agent_action`` and ``on_tool_start`` events.
        """
        self.tool_starts += 1
        self.starts += 1