"""Callback Handler that prints to Streamlit."""
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.outputs import LLMResult
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.streamlit.mutable_expander import MutableExpander
if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator
def _convert_newlines(text: str) -> str:
"""Convert newline characters to markdown newline sequences
(space, space, newline).
"""
return text.replace("\n", " \n")
# Emoji (literal characters or ":shortcode:" strings) embedded in the
# markdown labels of thought containers.
CHECKMARK_EMOJI: str = "✅"  # finished tool call / final agent thought
THINKING_EMOJI: str = ":thinking_face:"  # thought or tool still in progress
HISTORY_EMOJI: str = ":books:"  # the overflow "History" container
EXCEPTION_EMOJI: str = "⚠️"  # the "_Exception" pseudo-tool (parsing error)
class LLMThoughtState(Enum):
    """Lifecycle state of an ``LLMThought``.

    A thought starts out THINKING, switches to RUNNING_TOOL when a tool is
    started, and becomes COMPLETE once the thought is completed.
    """

    THINKING = "THINKING"  # Deciding what to do next; no tool chosen yet.
    RUNNING_TOOL = "RUNNING_TOOL"  # A tool has started; no results yet.
    COMPLETE = "COMPLETE"  # The thought has been finished.
class ToolRecord(NamedTuple):
    """Immutable record of a tool invocation: which tool ran, with what input."""

    name: str  # the tool's name, as reported in the callback's `serialized` dict
    input_str: str  # the raw input string handed to the tool
class LLMThoughtLabeler:
    """
    Generates markdown labels for LLMThought containers. Pass a custom
    subclass of this to StreamlitCallbackHandler to override its default
    labeling logic.
    """

    # Tool inputs longer than this many characters are truncated (and suffixed
    # with "...") in tool labels. Subclasses may override it.
    max_tool_input_length: int = 60

    def get_initial_label(self) -> str:
        """Return the markdown label for a new LLMThought that doesn't have
        an associated tool yet.
        """
        return f"{THINKING_EMOJI} **Thinking...**"

    def get_tool_label(self, tool: ToolRecord, is_complete: bool) -> str:
        """Return the label for an LLMThought that has an associated tool.

        Parameters
        ----------
        tool
            The tool's ToolRecord.
        is_complete
            True if the thought is complete; False if the thought
            is still receiving input.

        Returns
        -------
        The markdown label for the thought's container: an emoji, the bold
        tool name, and the tool input (truncated to
        ``max_tool_input_length`` characters) flattened onto one line.
        """
        name = tool.name
        emoji = CHECKMARK_EMOJI if is_complete else THINKING_EMOJI
        # The "_Exception" tool name is rendered as a parsing error with a
        # warning emoji, regardless of completion state.
        if name == "_Exception":
            emoji = EXCEPTION_EMOJI
            name = "Parsing error"

        limit = self.max_tool_input_length
        tool_input = tool.input_str[:limit]
        if len(tool.input_str) > limit:
            tool_input += "..."
        # The label is a single line, so flatten any newlines in the input.
        tool_input = tool_input.replace("\n", " ")
        return f"{emoji} **{name}:** {tool_input}"

    def get_history_label(self) -> str:
        """Return a markdown label for the special 'history' container
        that contains overflow thoughts.
        """
        return f"{HISTORY_EMOJI} **History**"

    def get_final_agent_thought_label(self) -> str:
        """Return the markdown label for the agent's final thought -
        the "Now I have the answer" thought, that doesn't involve
        a tool.
        """
        return f"{CHECKMARK_EMOJI} **Complete!**"
class LLMThought:
    """A thought in the LLM's thought stream.

    Each thought owns one MutableExpander created in the parent container.
    The expander's label starts as the labeler's initial label, switches to
    a tool label when a tool starts, and is finalized by ``complete()``.
    """

    def __init__(
        self,
        parent_container: DeltaGenerator,
        labeler: LLMThoughtLabeler,
        expanded: bool,
        collapse_on_complete: bool,
    ):
        """Initialize the LLMThought.

        Args:
            parent_container: The container the thought's expander is created in.
            labeler: Produces the markdown labels for this thought.
            expanded: Whether the expander starts out expanded.
            collapse_on_complete: Whether ``complete()`` should collapse the
                expander.
        """
        initial_label = labeler.get_initial_label()
        self._container = MutableExpander(
            parent_container=parent_container,
            label=initial_label,
            expanded=expanded,
        )
        self._state = LLMThoughtState.THINKING
        # Streamed tokens accumulated so far, and the index returned by the
        # container for the markdown element showing them (None before the
        # first token).
        self._llm_token_stream = ""
        self._llm_token_writer_idx: Optional[int] = None
        self._last_tool: Optional[ToolRecord] = None
        self._collapse_on_complete = collapse_on_complete
        self._labeler = labeler

    @property
    def container(self) -> MutableExpander:
        """The MutableExpander this thought writes into."""
        return self._container

    @property
    def last_tool(self) -> Optional[ToolRecord]:
        """The most recent tool started by this thought, or None."""
        return self._last_tool

    def _reset_llm_token_stream(self) -> None:
        """Drop buffered tokens so the next token is written with index=None."""
        self._llm_token_writer_idx = None
        self._llm_token_stream = ""

    def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str]) -> None:
        """Begin a new LLM call by clearing the token buffer."""
        self._reset_llm_token_stream()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Append a streamed token and re-render the accumulated text."""
        # Only reached when the LLM was constructed with `streaming=True`.
        self._llm_token_stream = self._llm_token_stream + _convert_newlines(token)
        # NOTE(review): presumably MutableExpander.markdown rewrites the element
        # at `index` when one is given and returns the element's index — confirm.
        self._llm_token_writer_idx = self._container.markdown(
            self._llm_token_stream, index=self._llm_token_writer_idx
        )

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """End the LLM call without rendering `response`."""
        # When streaming, `response` only repeats the tokens that
        # on_llm_new_token already displayed.
        self._reset_llm_token_stream()

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Show an LLM error message followed by the exception."""
        self._container.markdown("**LLM encountered an error...**")
        self._container.exception(error)

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Record the tool being started and relabel the expander with it."""
        self._state = LLMThoughtState.RUNNING_TOOL
        self._last_tool = ToolRecord(name=serialized["name"], input_str=input_str)
        in_progress_label = self._labeler.get_tool_label(
            self._last_tool, is_complete=False
        )
        self._container.update(new_label=in_progress_label)

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Show the tool's output in bold."""
        self._container.markdown(f"**{output}**")

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Show a tool error message followed by the exception."""
        self._container.markdown("**Tool encountered an error...**")
        self._container.exception(error)

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        """Do nothing: the same tool name and input arrive via on_tool_start."""

    def complete(self, final_label: Optional[str] = None) -> None:
        """Finish the thought and update its expander.

        Args:
            final_label: Label to show. If None while a tool is running, the
                labeler's completed-tool label is used instead; otherwise it
                is passed through to the container's ``update`` unchanged.
        """
        tool_running = self._state == LLMThoughtState.RUNNING_TOOL
        if final_label is None and tool_running:
            assert (
                self._last_tool is not None
            ), "_last_tool should never be null when _state == RUNNING_TOOL"
            final_label = self._labeler.get_tool_label(
                self._last_tool, is_complete=True
            )
        self._state = LLMThoughtState.COMPLETE
        update_kwargs: Dict[str, Any] = {"new_label": final_label}
        if self._collapse_on_complete:
            update_kwargs["new_expanded"] = False
        self._container.update(**update_kwargs)

    def clear(self) -> None:
        """Remove the thought from the screen. A cleared thought can't be reused."""
        self._container.clear()
class StreamlitCallbackHandler(BaseCallbackHandler):
    """A callback handler that writes to a Streamlit app.

    Each LLM "thought" is rendered in its own expander. When more than
    ``max_thought_containers`` containers would be showing, the oldest
    completed thoughts are copied into a collapsed "History" expander and
    removed from the screen.
    """

    def __init__(
        self,
        parent_container: DeltaGenerator,
        *,
        max_thought_containers: int = 4,
        expand_new_thoughts: bool = True,
        collapse_completed_thoughts: bool = True,
        thought_labeler: Optional[LLMThoughtLabeler] = None,
    ):
        """Create a StreamlitCallbackHandler instance.

        Parameters
        ----------
        parent_container
            The `st.container` that will contain all the Streamlit elements that the
            Handler creates.
        max_thought_containers
            The max number of completed LLM thought containers to show at once. When
            this threshold is reached, a new thought will cause the oldest thoughts to
            be collapsed into a "History" expander. Values below 1 are treated as 1.
            Defaults to 4.
        expand_new_thoughts
            Each LLM "thought" gets its own `st.expander`. This param controls whether
            that expander is expanded by default. Defaults to True.
        collapse_completed_thoughts
            If True, LLM thought expanders will be collapsed when completed.
            Defaults to True.
        thought_labeler
            An optional custom LLMThoughtLabeler instance. If unspecified, the handler
            will use the default thought labeling logic. Defaults to None.
        """
        self._parent_container = parent_container
        # Reserved before any thought is created — presumably so the history
        # expander, if one is ever needed, renders above the thoughts.
        self._history_parent = parent_container.container()
        self._history_container: Optional[MutableExpander] = None
        self._current_thought: Optional[LLMThought] = None
        self._completed_thoughts: List[LLMThought] = []
        # Always leave room for at least the current thought.
        self._max_thought_containers = max(max_thought_containers, 1)
        self._expand_new_thoughts = expand_new_thoughts
        self._collapse_completed_thoughts = collapse_completed_thoughts
        # Compare against None so a caller-supplied labeler is never replaced,
        # even if it happens to be falsy.
        self._thought_labeler = (
            thought_labeler if thought_labeler is not None else LLMThoughtLabeler()
        )

    def _require_current_thought(self) -> LLMThought:
        """Return our current LLMThought.

        Raises:
            RuntimeError: If there is no current thought.
        """
        if self._current_thought is None:
            raise RuntimeError("Current LLMThought is unexpectedly None!")
        return self._current_thought

    def _get_last_completed_thought(self) -> Optional[LLMThought]:
        """Return our most recent completed LLMThought, or None if we don't have one."""
        return self._completed_thoughts[-1] if self._completed_thoughts else None

    @property
    def _num_thought_containers(self) -> int:
        """The number of 'thought containers' we're currently showing: the
        number of completed thought containers, the history container (if it exists),
        and the current thought container (if it exists).
        """
        count = len(self._completed_thoughts)
        if self._history_container is not None:
            count += 1
        if self._current_thought is not None:
            count += 1
        return count

    def _complete_current_thought(self, final_label: Optional[str] = None) -> None:
        """Complete the current thought, optionally assigning it a new label.
        Add it to our _completed_thoughts list.

        Raises:
            RuntimeError: If there is no current thought.
        """
        thought = self._require_current_thought()
        thought.complete(final_label)
        self._completed_thoughts.append(thought)
        self._current_thought = None

    def _prune_old_thought_containers(self) -> None:
        """If we have too many thoughts onscreen, move older thoughts to the
        'history container.'

        With ``max_thought_containers == 1`` there is no room for a history
        container, so pruned thoughts are cleared without being copied.
        """
        while (
            self._num_thought_containers > self._max_thought_containers
            and self._completed_thoughts
        ):
            # Lazily create the history container the first time it's needed.
            if self._history_container is None and self._max_thought_containers > 1:
                self._history_container = MutableExpander(
                    self._history_parent,
                    label=self._thought_labeler.get_history_label(),
                    expanded=False,
                )

            oldest_thought = self._completed_thoughts.pop(0)
            if self._history_container is not None:
                # Copy the thought's label, then its contents, into history
                # before removing the original from the screen.
                self._history_container.markdown(oldest_thought.container.label)
                self._history_container.append_copy(oldest_thought.container)
            oldest_thought.clear()

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Start a new thought (or reuse the current one) for an LLM call."""
        if self._current_thought is None:
            self._current_thought = LLMThought(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )

        self._current_thought.on_llm_start(serialized, prompts)

        # We don't prune_old_thought_containers here, because our container won't
        # be visible until it has a child.

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Forward a streamed token to the current thought, then prune."""
        self._require_current_thought().on_llm_new_token(token, **kwargs)
        self._prune_old_thought_containers()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Forward the end of the LLM call to the current thought, then prune."""
        self._require_current_thought().on_llm_end(response, **kwargs)
        self._prune_old_thought_containers()

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Forward an LLM error to the current thought, then prune."""
        self._require_current_thought().on_llm_error(error, **kwargs)
        self._prune_old_thought_containers()

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Forward the tool start to the current thought, then prune."""
        self._require_current_thought().on_tool_start(serialized, input_str, **kwargs)
        self._prune_old_thought_containers()

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Show the tool output and complete the current thought.

        Pruning is not done here; it happens on the next event that prunes.
        """
        self._require_current_thought().on_tool_end(
            output, color, observation_prefix, llm_prefix, **kwargs
        )
        self._complete_current_thought()

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Forward a tool error to the current thought, then prune."""
        self._require_current_thought().on_tool_error(error, **kwargs)
        self._prune_old_thought_containers()

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Any,
    ) -> None:
        """Ignored: free-form text is not rendered."""
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Ignored: chain events are not rendered."""
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Ignored: chain events are not rendered."""
        pass

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Ignored: chain events are not rendered."""
        pass

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        """Forward the agent action to the current thought, then prune."""
        self._require_current_thought().on_agent_action(action, color, **kwargs)
        self._prune_old_thought_containers()

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Complete the current thought (if any) with the final-answer label.

        Unlike ``_complete_current_thought``, the thought is not appended to
        ``_completed_thoughts``, so it is never moved into history.
        """
        if self._current_thought is not None:
            self._current_thought.complete(
                self._thought_labeler.get_final_agent_thought_label()
            )
            self._current_thought = None
|