Spaces:
Running
Running
| import os | |
| # Init with fake key | |
| if 'OPENAI_API_KEY' not in os.environ: | |
| os.environ['OPENAI_API_KEY'] = 'none' | |
| import openai | |
| import pandas as pd | |
| import streamlit as st | |
| from IPython.core.display import HTML | |
| from PIL import Image | |
| from langchain.callbacks import wandb_tracing_enabled | |
| from chemcrow.agents import ChemCrow, make_tools | |
| from chemcrow.frontend.streamlit_callback_handler import \ | |
| StreamlitCallbackHandlerChem | |
| from utils import oai_key_isvalid | |
| from dotenv import load_dotenv | |
# Load variables from a .env file. load_dotenv() never overrides variables
# that already exist, so the 'none' placeholder installed before
# `import openai` would otherwise shadow an OPENAI_API_KEY defined in .env.
# Drop the placeholder while loading .env, and re-install it only if .env
# did not supply a usable key.
if os.environ.get('OPENAI_API_KEY') == 'none':
    del os.environ['OPENAI_API_KEY']
    load_dotenv()
    if not os.environ.get('OPENAI_API_KEY'):
        os.environ['OPENAI_API_KEY'] = 'none'
else:
    load_dotenv()
# Short alias for Streamlit's per-session state; reset the pending prompt on
# every script run.
ss = st.session_state
ss.prompt = None

icon = Image.open('assets/logo0.png')
st.set_page_config(
    page_title="ChemCrow",
    page_icon=icon,
)

# Fix the sidebar width while it is expanded. The injected <style> element is
# closed explicitly so the HTML passed to st.markdown is well-formed.
st.markdown(
    """
    <style>
    [data-testid="stSidebar"][aria-expanded="true"]{
        min-width: 450px;
        max-width: 450px;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
def instantiate_agent(model):
    """Build a ChemCrow agent executor for ``model`` and cache it in session state.

    The same model name is used for the agent and for its tools, with a
    sampling temperature of 0.1. The OpenAI key is taken from the sidebar
    field (``ss['api_key']``), which may be unset or empty. The RXN4Chem and
    ChemSpace keys are read by indexing ``st.secrets`` directly, so a missing
    secret is not defaulted here.

    Returns the new executor, which is also stored as ``ss.agent``.
    """
    third_party_keys = {
        secret_name: st.secrets[secret_name]
        for secret_name in ('RXN4CHEM_API_KEY', 'CHEMSPACE_API_KEY')
    }
    crow = ChemCrow(
        model=model,
        tools_model=model,
        temp=0.1,
        openai_api_key=ss.get('api_key'),
        api_keys=third_party_keys,
    )
    ss.agent = crow.agent_executor
    return ss.agent
# Build an agent with the default model at start-up so the sidebar can list
# the tools it exposes, as a two-column (Tool, Description) table.
tools = instantiate_agent('gpt-4-0613').tools
tool_descriptions = {f"✅ {t.name}": t.description for t in tools}
tool_list = pd.DataFrame({
    'Tool': list(tool_descriptions.keys()),
    'Description': list(tool_descriptions.values()),
})
def on_api_key_change():
    """Callback for the API-key field: warn when the effective key is invalid.

    The key typed in the sidebar takes precedence; if it is empty, the
    OPENAI_API_KEY environment variable is checked instead.
    """
    key = ss.get('api_key') or os.getenv('OPENAI_API_KEY')
    if oai_key_isvalid(key):
        return
    st.write("Please input a valid OpenAI API key.")
def run_prompt(prompt):
    """Run the ChemCrow agent on ``prompt`` and render the exchange as chat.

    The agent is rebuilt via ``instantiate_agent`` using the model currently
    selected in the sidebar (``ss['model_select']``). The user's prompt is
    shown as a "user" message; the agent's intermediate steps are streamed
    into the "assistant" message through ``StreamlitCallbackHandlerChem``,
    followed by the final response.

    OpenAI authentication and API errors are reported inside the assistant
    message instead of propagating.
    """
    agent = instantiate_agent(ss.get('model_select'))
    st.chat_message("user").write(prompt)
    with st.chat_message("assistant"):
        st_callback = StreamlitCallbackHandlerChem(
            st.container(),
            max_thought_containers=3,
            collapse_completed_thoughts=False,
            output_placeholder=ss,
        )
        try:
            with wandb_tracing_enabled():
                response = agent.run(prompt, callbacks=[st_callback])
        except openai.error.AuthenticationError:
            st.write("Please input a valid OpenAI API key")
        except openai.error.APIError:
            # Report in the chat: a print() only reaches the server log and
            # would leave the assistant message empty for the user.
            st.write("OpenAI API error, please try again!")
        else:
            st.write(response)
# Canned example prompts. The sidebar buttons run these by index, so the
# order must stay in step with the button labels.
pre_prompts = [
    "How can I synthesize safinamide?",
    (
        "Predict the product of a mixture of Ethylidenecyclohexane and HBr. "
        "Then predict the same reaction, adding methyl peroxide into the "
        "mixture. Compare the two products and explain the reaction "
        "mechanism."
    ),
    (
        "What is the boiling point of the reaction product between isoamyl "
        "alcohol and acetic acid?"
    ),
    "Tell me how to synthesize vanilline, and the price of the precursors.",
]
# Sidebar: logo, OpenAI credentials, model picker, example prompts, tool list.
with st.sidebar:
    chemcrow_logo = Image.open('assets/chemcrow-logo-bold-new.png')
    st.image(chemcrow_logo)

    # Masked key field; on_api_key_change validates the key whenever it changes.
    st.text_input(
        'Input your OpenAI API key.',
        placeholder='Input your OpenAI API key.',
        type='password',
        key='api_key',
        on_change=on_api_key_change,
        label_visibility="collapsed",
    )

    # The chosen model is read back from ss['model_select'] by run_prompt.
    st.selectbox(
        'Select model to use',
        ['gpt-4-0613', 'gpt-3.5-turbo', 'gpt-4o-mini'],
        key='model_select',
    )

    # Example prompts laid out in two columns.
    # Each entry: (column index, button label, index into pre_prompts).
    st.markdown('# What can I ask?')
    cols = st.columns(2)
    example_buttons = (
        (0, "How can I synthesize safinamide?", 0),
        (0, "Explain mechanism of bromoaddition reaction", 1),
        (1, 'Predict properties of a reaction product', 2),
        (1, 'Synthesize molecule with price of precursors', 3),
    )
    for col_idx, label, prompt_idx in example_buttons:
        with cols[col_idx]:
            st.button(
                label,
                on_click=run_prompt,
                args=(pre_prompts[prompt_idx],),
            )

    st.markdown('---')

    # Tools exposed by the agent built at start-up.
    st.markdown(f"# {len(tool_list)} available tools")
    st.dataframe(
        tool_list,
        use_container_width=True,
        hide_index=True,
        height=200,
    )
# Main area: run the agent on whatever the user submits in the chat box.
# st.chat_input returns None until something is submitted.
user_input = st.chat_input()
if user_input:
    run_prompt(user_input)