| | """Callback Handler captures all callbacks in a session for future offline playback.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import pickle |
| | import time |
| | from typing import Any, TypedDict |
| |
|
| | from langchain.callbacks.base import BaseCallbackHandler |
| |
|
| |
|
| | |
| | |
class CallbackType:
    """String tags identifying each kind of callback that gets captured.

    Every value is spelled exactly like the handler method it stands for
    (for example ON_TEXT == "on_text"). The tag is what gets stored in a
    record's "callback_type" field.
    """

    # LLM lifecycle
    ON_LLM_START = "on_llm_start"
    ON_LLM_NEW_TOKEN = "on_llm_new_token"
    ON_LLM_END = "on_llm_end"
    ON_LLM_ERROR = "on_llm_error"

    # Tool lifecycle
    ON_TOOL_START = "on_tool_start"
    ON_TOOL_END = "on_tool_end"
    ON_TOOL_ERROR = "on_tool_error"

    # Free-form text
    ON_TEXT = "on_text"

    # Chain lifecycle
    ON_CHAIN_START = "on_chain_start"
    ON_CHAIN_END = "on_chain_end"
    ON_CHAIN_ERROR = "on_chain_error"

    # Agent decisions
    ON_AGENT_ACTION = "on_agent_action"
    ON_AGENT_FINISH = "on_agent_finish"
| |
|
| |
|
| | |
| | |
class CallbackRecord(TypedDict, total=True):
    """A single captured callback invocation, ready to be pickled and replayed."""

    # Which callback fired: one of the CallbackType string values.
    callback_type: str
    # Positional arguments the callback received.
    args: tuple[Any, ...]
    # Keyword arguments the callback received.
    kwargs: dict[str, Any]
    # Seconds elapsed since the previously captured callback (0 for the first).
    time_delta: float
| |
|
| |
|
def load_records_from_file(path: str) -> list[CallbackRecord]:
    """Load the list of CallbackRecords from a pickle file at the given path.

    Args:
        path: Path of a pickle file, such as one produced by
            CapturingCallbackHandler.dump_records_to_file.

    Returns:
        The unpickled list of records.

    Raises:
        RuntimeError: If the file does not hold a list whose every element is
            a dict containing the callback_type, args, kwargs and time_delta
            keys.

    Warning:
        Unpickling can execute arbitrary code embedded in the file. Only load
        files from a trusted source; never pass a user-supplied path or upload.
    """
    # SECURITY: pickle.load is only safe on trusted input (see docstring).
    with open(path, "rb") as file:
        records = pickle.load(file)

    # Check the elements as well as the container, so a malformed file fails
    # here with a clear message instead of later during playback with an
    # opaque TypeError/KeyError on record["..."].
    required_keys = {"callback_type", "args", "kwargs", "time_delta"}
    if not isinstance(records, list) or not all(
        isinstance(record, dict) and required_keys <= record.keys()
        for record in records
    ):
        raise RuntimeError(f"Bad CallbackRecord data in {path}")
    return records
| |
|
| |
|
def playback_callbacks(
    handlers: list[BaseCallbackHandler],
    records_or_filename: list[CallbackRecord] | str,
    max_pause_time: float,
) -> str:
    """Replay recorded callbacks against each handler and return the agent output.

    Records are replayed in order. Before each record the function sleeps for
    the record's time_delta, capped at max_pause_time, to approximate the
    original pacing. Each record is then delivered to every handler in turn.

    Args:
        handlers: Callback handlers that receive every replayed callback.
        records_or_filename: The records to replay, or a path to a pickle file
            readable by load_records_from_file (trusted files only).
        max_pause_time: Upper bound, in seconds, on each inter-record sleep.

    Returns:
        The "output" value from the first on_agent_finish record, or
        "[Missing Agent Result]" if no such record exists.
    """
    if isinstance(records_or_filename, list):
        records = records_or_filename
    else:
        records = load_records_from_file(records_or_filename)

    # Each CallbackType value is spelled exactly like the handler method it
    # represents, so dispatch is a getattr on the handler. Records whose type
    # is not in this set are skipped.
    replayable = frozenset(
        (
            CallbackType.ON_LLM_START,
            CallbackType.ON_LLM_NEW_TOKEN,
            CallbackType.ON_LLM_END,
            CallbackType.ON_LLM_ERROR,
            CallbackType.ON_TOOL_START,
            CallbackType.ON_TOOL_END,
            CallbackType.ON_TOOL_ERROR,
            CallbackType.ON_TEXT,
            CallbackType.ON_CHAIN_START,
            CallbackType.ON_CHAIN_END,
            CallbackType.ON_CHAIN_ERROR,
            CallbackType.ON_AGENT_ACTION,
            CallbackType.ON_AGENT_FINISH,
        )
    )

    for record in records:
        pause_time = min(record["time_delta"], max_pause_time)
        if pause_time > 0:
            time.sleep(pause_time)

        callback_type = record["callback_type"]
        if callback_type not in replayable:
            continue
        for handler in handlers:
            getattr(handler, callback_type)(*record["args"], **record["kwargs"])

    for record in records:
        if record["callback_type"] == CallbackType.ON_AGENT_FINISH:
            # args[0] is the first positional argument passed to
            # on_agent_finish; [0] on it takes its first field, which is then
            # indexed by "output". NOTE(review): this assumes a tuple-like
            # finish object whose first field is a dict of return values —
            # confirm against the installed LangChain version.
            return record["args"][0][0]["output"]

    return "[Missing Agent Result]"
| |
|
| |
|
class CapturingCallbackHandler(BaseCallbackHandler):
    """Callback handler that records every callback it receives for later replay.

    Each callback becomes a CallbackRecord holding the callback type, the
    positional and keyword arguments it was called with, and the seconds
    elapsed since the previous callback. The collected records can be saved
    with dump_records_to_file and replayed with playback_callbacks.
    """

    def __init__(self) -> None:
        super().__init__()
        self._records: list[CallbackRecord] = []
        # time.monotonic() reading at the most recent callback; None until the
        # first callback arrives.
        self._last_time: float | None = None

    def dump_records_to_file(self, path: str) -> None:
        """Write the list of CallbackRecords to a pickle file at the given path.

        Every recorded argument must be picklable, otherwise pickle.dump
        raises. The file can be read back with load_records_from_file.
        """
        with open(path, "wb") as file:
            pickle.dump(self._records, file)

    def _append_record(
        self, callback_type: str, args: tuple[Any, ...], kwargs: dict[str, Any]
    ) -> None:
        """Store one callback invocation plus the time since the previous one."""
        # A monotonic clock keeps deltas non-negative even if the system wall
        # clock is adjusted mid-session; only differences are ever stored.
        time_now = time.monotonic()
        time_delta = 0.0 if self._last_time is None else time_now - self._last_time
        self._last_time = time_now
        self._records.append(
            CallbackRecord(
                callback_type=callback_type,
                args=args,
                kwargs=kwargs,
                time_delta=time_delta,
            )
        )

    # Each override below records its call under the matching CallbackType
    # and does nothing else.

    def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_LLM_START, args, kwargs)

    def on_llm_new_token(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_LLM_NEW_TOKEN, args, kwargs)

    def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_LLM_END, args, kwargs)

    def on_llm_error(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_LLM_ERROR, args, kwargs)

    def on_tool_start(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_TOOL_START, args, kwargs)

    def on_tool_end(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_TOOL_END, args, kwargs)

    def on_tool_error(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_TOOL_ERROR, args, kwargs)

    def on_text(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_TEXT, args, kwargs)

    def on_chain_start(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_CHAIN_START, args, kwargs)

    def on_chain_end(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_CHAIN_END, args, kwargs)

    def on_chain_error(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_CHAIN_ERROR, args, kwargs)

    def on_agent_action(self, *args: Any, **kwargs: Any) -> Any:
        self._append_record(CallbackType.ON_AGENT_ACTION, args, kwargs)

    def on_agent_finish(self, *args: Any, **kwargs: Any) -> None:
        self._append_record(CallbackType.ON_AGENT_FINISH, args, kwargs)
| |
|