# TeLLAgent / streamlit_callback_handler.py
# Uploaded by jinysun ("Upload 9 files", commit dbaa85f, verified).
from typing import Any, Dict, List, Optional
from langchain.callbacks.streamlit.streamlit_callback_handler import (
LLMThought,
LLMThoughtLabeler,
LLMThoughtState,
StreamlitCallbackHandler,
ToolRecord,
)
from langchain_core.agents import AgentAction, AgentFinish
from streamlit.delta_generator import DeltaGenerator
from utils import is_smiles
import requests
from langchain import LLMChain, PromptTemplate
from langchain.chat_models import ChatOpenAI
from rdkit import Chem
def cdk(smiles, timeout=10.0):
    """
    Get a depiction of some smiles.

    Queries the CDK Depict web service at simolecule.com for an SVG rendering
    of ``smiles``. Callers in this module pass either a SMILES string or a
    reaction string of the form ``"<input>>><output>"``.

    Args:
        smiles: The string to depict.
        timeout: Seconds to wait for the service before giving up. Without a
            timeout, ``requests`` may block indefinitely, freezing whatever UI
            callback is waiting on the image.

    Returns:
        The response body (SVG markup) as a string. Returns ``""`` if the
        request fails (connection error, timeout, ...) or the service answers
        with an error status. Callers embed the result directly into markdown
        rendered with ``unsafe_allow_html=True``. Returning nothing is
        therefore preferable to injecting an error page or aborting the
        caller's rendering.
    """
    url = "https://www.simolecule.com/cdkdepict/depict/wob/svg"
    headers = {"Content-Type": "application/json"}
    try:
        response = requests.get(
            url,
            headers=headers,
            params={
                "smi": smiles,
                "annotate": "colmap",
                "zoom": 2,
                "w": 150,
                "h": 80,
                "abbr": "off",
            },
            timeout=timeout,
        )
    except requests.RequestException:
        # The depiction is decorative; degrade to "no image" instead of
        # propagating a network error into the caller.
        return ""
    if not response.ok:
        # 4xx/5xx: the body is an error page, not an SVG depiction.
        return ""
    return response.text
class LLMThoughtChem(LLMThought):
    """``LLMThought`` variant with chemistry-aware rendering.

    Differences from the base ``LLMThought`` visible in this class:

    * ``on_tool_end`` draws a CDK depiction (via ``cdk``) for ``Name2SMILES``
      and ``ReactionPredict`` results instead of echoing the tool output.
    * Square brackets in container labels and displayed SMILES are
      backslash-escaped before being handed to Streamlit markdown.
    * ``on_tool_start`` shows a "may take some time" note for slow tools.
    """

    # Translation table that backslash-escapes "[" and "]". Presumably this
    # keeps markdown from treating SMILES bracket atoms (e.g. "[NH4+]") as
    # link syntax. It is equivalent to .replace("[", r"\[").replace("]", r"\]").
    _MD_BRACKET_ESCAPES = str.maketrans({"[": r"\[", "]": r"\]"})

    # Tools that get a "this can take some time" note when they start.
    _SLOW_TOOLS = ("ReactionRetrosynthesis", "LiteratureSearch")

    def __init__(
        self,
        parent_container: DeltaGenerator,
        labeler: LLMThoughtLabeler,
        expanded: bool,
        collapse_on_complete: bool,
    ):
        """Create the thought. All arguments are forwarded unchanged to
        ``LLMThought.__init__``."""
        super().__init__(
            parent_container,
            labeler,
            expanded,
            collapse_on_complete,
        )

    @classmethod
    def _escape_brackets(cls, text: str) -> str:
        """Return ``text`` with every ``[`` and ``]`` backslash-escaped."""
        return text.translate(cls._MD_BRACKET_ESCAPES)

    def on_tool_end(
        self,
        output: Any,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        output_ph: Optional[dict] = None,
        input_tool: str = "",
        serialized: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Render the result of a finished tool call.

        Only two tools produce visible output here:

        * ``Name2SMILES``: if ``is_smiles(output)`` holds, shows the escaped
          SMILES in bold followed by its CDK depiction.
        * ``ReactionPredict``: shows ``"<input_tool>>><output>"`` in bold
          followed by its CDK depiction.

        Results of any other tool are not rendered by this method.

        Args:
            output: The tool's output. Converted with ``str()`` because tools
                are not guaranteed to return strings.
            color, observation_prefix, llm_prefix: Accepted for signature
                compatibility; not used here.
            output_ph: Output-placeholder mapping forwarded by the handler;
                not used here.
            input_tool: The input string the tool was called with. It forms
                the left-hand side of the depicted reaction.
            serialized: Serialized tool description. Only its ``"name"`` key
                is read. A missing dict or key simply renders nothing.
        """
        output = str(output)
        tool_name = (serialized or {}).get("name")

        if tool_name == "Name2SMILES":
            if is_smiles(output):
                self._container.markdown(
                    f"**{self._escape_brackets(output)}**{cdk(output)}",
                    unsafe_allow_html=True,
                )
        elif tool_name == "ReactionPredict":
            rxn = f"{input_tool}>>{output}"
            self._container.markdown(
                f"**{self._escape_brackets(rxn)}**{cdk(rxn)}",
                unsafe_allow_html=True,
            )
        # NOTE(review): a former "ReactionRetrosynthesis" branch escaped
        # `output` into a local variable and discarded it. That was a no-op,
        # so it was removed. If retrosynthesis output should be displayed,
        # add an explicit self._container.markdown(...) call for it here.

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Mark the thought as running a tool and relabel its container.

        The new label comes from the labeler's in-progress tool label, with
        brackets escaped. Tools listed in ``_SLOW_TOOLS`` also get a note
        warning that they may take a while.

        Args:
            serialized: Serialized tool description; ``serialized["name"]``
                is the tool name.
            input_str: The input the tool is about to be called with.
        """
        self._state = LLMThoughtState.RUNNING_TOOL
        tool_name = serialized["name"]
        self._last_tool = ToolRecord(name=tool_name, input_str=input_str)
        running_label = self._labeler.get_tool_label(
            self._last_tool, is_complete=False
        )
        self._container.update(new_label=self._escape_brackets(running_label))

        if tool_name in self._SLOW_TOOLS:
            self._container.markdown(
                "‼️ Note: This tool can take some time to complete execution ‼️",
                unsafe_allow_html=True,
            )

    def complete(self, final_label: Optional[str] = None) -> None:
        """Finish the thought.

        If no label is given while a tool is running, the labeler's completed
        tool label is used. A non-``None`` label is bracket-escaped before it
        is applied. The container is collapsed when the thought was created
        with ``collapse_on_complete``.
        """
        if final_label is None and self._state == LLMThoughtState.RUNNING_TOOL:
            assert (
                self._last_tool is not None
            ), "_last_tool should never be null when _state == RUNNING_TOOL"
            final_label = self._labeler.get_tool_label(
                self._last_tool, is_complete=True
            )
        self._state = LLMThoughtState.COMPLETE
        # final_label can still be None here (no label passed and no tool
        # running). Escaping it unconditionally raised AttributeError.
        if final_label is not None:
            final_label = self._escape_brackets(final_label)
        if self._collapse_on_complete:
            self._container.update(new_label=final_label, new_expanded=False)
        else:
            self._container.update(new_label=final_label)
class StreamlitCallbackHandlerChem(StreamlitCallbackHandler):
    """Streamlit callback handler whose thoughts are ``LLMThoughtChem``.

    It records the most recent tool's input and serialized description in
    ``on_tool_start``. It then forwards them to the thought's ``on_tool_end``,
    so chemistry tool results can be rendered with depictions.
    """

    def __init__(
        self,
        parent_container: DeltaGenerator,
        *,
        max_thought_containers: int = 4,
        expand_new_thoughts: bool = True,
        collapse_completed_thoughts: bool = True,
        thought_labeler: Optional[LLMThoughtLabeler] = None,
        output_placeholder: Optional[dict] = None,
    ):
        """Create the handler.

        All arguments except ``output_placeholder`` are forwarded to
        ``StreamlitCallbackHandler.__init__``. ``output_placeholder`` is
        stored and later passed to each thought's ``on_tool_end`` as
        ``output_ph``. It defaults to a new empty dict per handler.
        """
        super().__init__(
            parent_container,
            max_thought_containers=max_thought_containers,
            expand_new_thoughts=expand_new_thoughts,
            collapse_completed_thoughts=collapse_completed_thoughts,
            thought_labeler=thought_labeler,
        )
        # A fresh dict per instance: the previous `= {}` default was a single
        # dict shared by every handler ever created.
        self._output_placeholder = (
            {} if output_placeholder is None else output_placeholder
        )
        # Input string and serialized description of the most recent tool
        # call. They are set in on_tool_start and read in on_tool_end.
        # Initializing them here avoids an AttributeError if on_tool_end
        # runs first.
        self._last_input = ""
        self._serialized: Dict[str, Any] = {}
        # Legacy public attribute, kept for backward compatibility. The value
        # actually tracked and used is `_last_input`.
        self.last_input = ""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Start a new ``LLMThoughtChem`` if none is active, then forward the
        LLM-start event to the current thought."""
        if self._current_thought is None:
            self._current_thought = LLMThoughtChem(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )
        self._current_thought.on_llm_start(serialized, prompts)
        # We don't prune_old_thought_containers here, because our container
        # won't be visible until it has a child.

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Forward the tool-start event and remember the tool's input and
        serialized description for ``on_tool_end``."""
        self._require_current_thought().on_tool_start(serialized, input_str, **kwargs)
        self._prune_old_thought_containers()
        self._last_input = input_str
        self._serialized = serialized

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Forward the tool result to the current thought, then complete it.

        The input and serialized description recorded in ``on_tool_start`` are
        passed along so the thought can render chemistry-specific output.
        """
        self._require_current_thought().on_tool_end(
            output,
            color,
            observation_prefix,
            llm_prefix,
            output_ph=self._output_placeholder,
            input_tool=self._last_input,
            serialized=self._serialized,
            **kwargs,
        )
        self._complete_current_thought()

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Complete the active thought (if any) with the labeler's final
        label and clear it."""
        if self._current_thought is not None:
            # The current thought is created as an LLMThoughtChem in
            # on_llm_start, and its complete() already bracket-escapes the
            # label. Escaping here as well doubled the backslashes.
            self._current_thought.complete(
                self._thought_labeler.get_final_agent_thought_label()
            )
            self._current_thought = None