|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
from langchain.callbacks.streamlit.streamlit_callback_handler import (
|
|
|
|
|
|
LLMThought,
|
|
|
LLMThoughtLabeler,
|
|
|
LLMThoughtState,
|
|
|
StreamlitCallbackHandler,
|
|
|
ToolRecord,
|
|
|
)
|
|
|
from langchain_core.agents import AgentAction, AgentFinish
|
|
|
from streamlit.delta_generator import DeltaGenerator
|
|
|
|
|
|
from utils import is_smiles
|
|
|
|
|
|
import requests
|
|
|
from langchain import LLMChain, PromptTemplate
|
|
|
from langchain.chat_models import ChatOpenAI
|
|
|
from rdkit import Chem
|
|
|
|
|
|
|
|
|
def cdk(smiles, *, timeout=10.0):
    """Fetch an SVG depiction of a SMILES string from the CDK Depict web service.

    Callers in this module pass both plain SMILES and reaction SMILES of the
    form ``reactants>>products``; the string is forwarded unchanged as the
    ``smi`` query parameter.

    Args:
        smiles: SMILES (or reaction SMILES) string to depict.
        timeout: Seconds to wait for the remote service (forwarded to
            ``requests.get``). Without a timeout a stalled service would block
            the calling Streamlit callback indefinitely.

    Returns:
        The response body as text. The HTTP status is not checked, so on a
        server-side error this is whatever body the service returned.

    Raises:
        requests.RequestException: on connection failure or timeout.
    """
    url = "https://www.simolecule.com/cdkdepict/depict/wob/svg"
    headers = {"Content-Type": "application/json"}
    params = {
        "smi": smiles,
        "annotate": "colmap",
        "zoom": 2,
        "w": 150,
        "h": 80,
        "abbr": "off",
    }
    response = requests.get(url, headers=headers, params=params, timeout=timeout)
    return response.text
|
|
|
|
|
|
|
|
|
class LLMThoughtChem(LLMThought):
    """``LLMThought`` variant with chemistry-specific rendering.

    Compared with the base class it:

    * draws a CDK depiction for ``Name2SMILES`` and ``ReactionPredict`` results,
    * shows a "this may take a while" note when a slow tool starts,
    * backslash-escapes square brackets in every label / SMILES it renders
      (presumably so Streamlit markdown shows them literally).
    """

    # Escapes "[" -> "\[" and "]" -> "\]". Equivalent to chaining
    # .replace("[", "\\[").replace("]", "\\]"), but avoids the invalid "\["
    # escape sequence (a SyntaxWarning on Python 3.12+).
    _BRACKET_ESCAPES = str.maketrans({"[": "\\[", "]": "\\]"})

    # Tools slow enough to warrant a note in the UI when they start.
    _SLOW_TOOLS = ("ReactionRetrosynthesis", "LiteratureSearch")

    def __init__(
        self,
        parent_container: DeltaGenerator,
        labeler: LLMThoughtLabeler,
        expanded: bool,
        collapse_on_complete: bool,
    ):
        super().__init__(
            parent_container,
            labeler,
            expanded,
            collapse_on_complete,
        )

    @classmethod
    def _escape_brackets(cls, text: str) -> str:
        """Return ``text`` with ``[`` and ``]`` backslash-escaped."""
        return text.translate(cls._BRACKET_ESCAPES)

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        output_ph: Optional[dict] = None,
        input_tool: str = "",
        serialized: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Render a finished tool's result in this thought's container.

        Only the tools handled below produce visible output; other tools'
        results are not echoed.

        Args:
            output: Value returned by the tool.
            color, observation_prefix, llm_prefix: Accepted for signature
                compatibility with ``LLMThought.on_tool_end``; unused here.
            output_ph: Output-placeholder mapping forwarded by the handler;
                currently unused.
            input_tool: Input string the tool was called with; used as the
                reactant side of a ``ReactionPredict`` reaction.
            serialized: Serialized tool metadata; its ``"name"`` key selects the
                rendering. When it is missing, nothing is rendered.
        """
        # None instead of ``{}`` as the default: no shared mutable default, and
        # a missing dict/name no longer raises KeyError.
        tool_name = (serialized or {}).get("name")
        # Coerce defensively in case a tool returns a non-string value.
        output = str(output)

        if tool_name == "Name2SMILES":
            # Only draw when the result parses as SMILES (presumably the tool
            # can also return plain-text messages).
            if is_smiles(output):
                self._container.markdown(
                    f"**{self._escape_brackets(output)}**{cdk(output)}",
                    unsafe_allow_html=True,
                )
        elif tool_name == "ReactionPredict":
            # Depict the full reaction as reaction SMILES: reactants>>products.
            rxn = f"{input_tool}>>{output}"
            self._container.markdown(
                f"**{self._escape_brackets(rxn)}**{cdk(rxn)}",
                unsafe_allow_html=True,
            )
        # NOTE(review): a former ReactionRetrosynthesis branch only escaped
        # ``output`` into a local that was never used (nothing was rendered).
        # That dead code was removed; add a markdown call here if retrosynthesis
        # results should be displayed.

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Record the running tool and relabel the container with its name.

        Args:
            serialized: Serialized tool metadata; ``serialized["name"]`` is the
                tool name.
            input_str: Input the tool is being called with.
        """
        self._state = LLMThoughtState.RUNNING_TOOL
        tool_name = serialized["name"]
        self._last_tool = ToolRecord(name=tool_name, input_str=input_str)
        label = self._labeler.get_tool_label(self._last_tool, is_complete=False)
        self._container.update(new_label=self._escape_brackets(label))

        if tool_name in self._SLOW_TOOLS:
            self._container.markdown(
                "‼️ Note: This tool can take some time to complete execution ‼️",
                unsafe_allow_html=True,
            )

    def complete(self, final_label: Optional[str] = None) -> None:
        """Finish the thought and update its container label.

        Args:
            final_label: Label to display. When ``None`` while a tool is
                running, the labeler's completed-tool label is used. If it is
                still ``None`` it is passed to ``update`` unescaped (previously
                this raised AttributeError on ``None.replace``).
        """
        if final_label is None and self._state == LLMThoughtState.RUNNING_TOOL:
            assert (
                self._last_tool is not None
            ), "_last_tool should never be null when _state == RUNNING_TOOL"
            final_label = self._labeler.get_tool_label(
                self._last_tool, is_complete=True
            )
        self._state = LLMThoughtState.COMPLETE

        if final_label is not None:
            final_label = self._escape_brackets(final_label)

        if self._collapse_on_complete:
            self._container.update(new_label=final_label, new_expanded=False)
        else:
            self._container.update(new_label=final_label)
|
|
|
|
|
|
|
|
class StreamlitCallbackHandlerChem(StreamlitCallbackHandler):
    """Streamlit callback handler that creates ``LLMThoughtChem`` thoughts.

    It remembers the input string and serialized metadata of the tool started
    last (``on_tool_start``) and forwards them to the thought's ``on_tool_end``
    so chemistry tools can be rendered with depictions.
    """

    def __init__(
        self,
        parent_container: DeltaGenerator,
        *,
        max_thought_containers: int = 4,
        expand_new_thoughts: bool = True,
        collapse_completed_thoughts: bool = True,
        thought_labeler: Optional[LLMThoughtLabeler] = None,
        output_placeholder: Optional[dict] = None,
    ):
        super().__init__(
            parent_container,
            max_thought_containers=max_thought_containers,
            expand_new_thoughts=expand_new_thoughts,
            collapse_completed_thoughts=collapse_completed_thoughts,
            thought_labeler=thought_labeler,
        )
        # Fresh dict per handler; the old ``= {}`` default was a single dict
        # shared by every handler built without an explicit placeholder.
        self._output_placeholder = (
            {} if output_placeholder is None else output_placeholder
        )
        # Tool context captured in on_tool_start and consumed by on_tool_end.
        # Initialised here so on_tool_end cannot hit an AttributeError.
        self._last_input = ""
        self._serialized: Dict[str, Any] = {}

    @property
    def last_input(self) -> str:
        """Input string of the most recently started tool ("" before any)."""
        # Previously ``last_input`` was a separate attribute that was never
        # updated; it now reflects the value actually tracked by the handler.
        return self._last_input

    @last_input.setter
    def last_input(self, value: str) -> None:
        self._last_input = value

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Create an ``LLMThoughtChem`` if no thought is active, then delegate."""
        if self._current_thought is None:
            self._current_thought = LLMThoughtChem(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )

        self._current_thought.on_llm_start(serialized, prompts)

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Delegate to the current thought and remember the tool's context."""
        self._require_current_thought().on_tool_start(serialized, input_str, **kwargs)
        self._prune_old_thought_containers()
        self._last_input = input_str
        self._serialized = serialized

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Render the tool result via the current thought, then complete it.

        The input and serialized metadata recorded by ``on_tool_start`` are
        passed along so the thought can choose a tool-specific rendering.
        """
        self._require_current_thought().on_tool_end(
            output,
            color,
            observation_prefix,
            llm_prefix,
            output_ph=self._output_placeholder,
            input_tool=self._last_input,
            serialized=self._serialized,
            **kwargs,
        )
        self._complete_current_thought()

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Complete any active thought with the bracket-escaped final label."""
        if self._current_thought is not None:
            label = self._thought_labeler.get_final_agent_thought_label()
            # "\\[" rather than "\[": the latter is an invalid escape sequence.
            self._current_thought.complete(
                label.replace("[", "\\[").replace("]", "\\]")
            )
            self._current_thought = None