# TeLLAgent / app.py
# Author: jinysun -- last commit "Update app.py" (3b0f61d, verified)
import os
import asyncio
# Init with fake key: when no OpenAI key is configured in the environment,
# seed placeholder values for every credential the app uses.
# NOTE(review): presumably this keeps modules imported below from failing on a
# missing key at import time -- confirm. The real values are copied in from the
# sidebar inputs on every rerun.
if 'OPENAI_API_KEY' not in os.environ:
    os.environ.update({
        'OPENAI_API_KEY': 'none',
        'OPENAI_API_BASE': 'none',
        'SERP_API_KEY': 'none',
        'SEMANTIC_SCHOLAR_API_KEY': 'none',
    })
# On Windows, explicitly select the Proactor event loop policy.
if os.name == 'nt':
    asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
import openai
import pandas as pd
import streamlit as st
from PIL import Image
from agent import TeLLAgent, make_tools
from streamlit_callback_handler import \
StreamlitCallbackHandlerChem
import base64
import pandas as pd
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI , OpenAI
import base64
from io import BytesIO
from PIL import Image
import tempfile
def convert_to_base64(pil_image):
    """Serialize an image as PNG and return it as base64 text.

    Args:
        pil_image: Any object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The PNG bytes, base64-encoded and decoded to a UTF-8 ``str``.
    """
    with BytesIO() as png_buffer:
        pil_image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
def oai_key_isvalid(api_key):
    """Check if a given OpenAI key is valid by sending one short test prompt.

    When the ``OPENAI_API_BASE`` environment variable is set to a non-empty
    value, it is used as the client's ``base_url``.

    Args:
        api_key: The OpenAI API key to check.

    Returns:
        ``True`` if the test request completed, ``False`` if creating the
        client or invoking it raised any ``Exception``.
    """
    llm_kwargs = {"openai_api_key": api_key}
    base_url = os.getenv("OPENAI_API_BASE")
    if base_url:
        llm_kwargs["base_url"] = base_url
    try:
        ChatOpenAI(**llm_kwargs).invoke("This is a test")
    except Exception:
        # Auth, network or bad-base-URL failures all mean "not usable". Catch
        # Exception rather than using a bare except, so KeyboardInterrupt and
        # Streamlit's rerun/stop control exceptions (BaseException) still
        # propagate.
        return False
    return True
# NOTE(review): load_dotenv() does not override variables that are already set
# (python-dotenv default). So when the 'none' placeholders were seeded at
# start-up, .env values for those keys are not applied -- confirm intended.
load_dotenv()
# Short alias for Streamlit's per-session state, used throughout the app.
ss = st.session_state
ss.prompt = None
# Holds an example prompt queued by a sidebar button. Created once per session
# so the value survives reruns until the main area consumes it.
if 'pending_prompt' not in ss:
    ss['pending_prompt'] = None
# Set width of sidebar: pin the expanded sidebar to exactly 500px via injected
# CSS. The <style> element is now properly closed.
st.markdown(
    """
<style>
[data-testid="stSidebar"][aria-expanded="true"]{
    min-width: 500px;
    max-width: 500px;
}
</style>
""",
    unsafe_allow_html=True,
)
def instantiate_agent(model1, model2, file_path='...', image_path='...', tools=None):
    """Build a TeLLAgent, store it as ``session_state.agent``, and return it.

    Args:
        model1: Name of the global planning model.
        model2: Name of the local execution model.
        file_path: Path of an uploaded CSV/PDF file. The ``'...'`` default
            appears to mean "no file" (assumption -- confirm in ``TeLLAgent``).
        image_path: Path of an uploaded image; same default convention.
        tools: Tool-set selector forwarded to ``TeLLAgent`` (this file passes
            ``'drug'`` for the drug-discovery domain, otherwise ``None``).

    The tools model is fixed to ``gpt-4o-2024-11-20`` with temperature 0.1.
    The API key is read from ``session_state['api_key']``.
    """
    new_agent = TeLLAgent(
        tools=tools,
        model1=model1,
        model2=model2,
        tools_model='gpt-4o-2024-11-20',
        temp=0.1,
        openai_api_key=ss.get('api_key'),
        file_path=file_path,
        image_path=image_path,
    )
    ss.agent = new_agent
    return new_agent
def on_api_key_change():
    """Streamlit callback for the API-key field: warn if the key fails validation.

    Uses the key typed into the sidebar (``session_state['api_key']``). Falls
    back to the ``OPENAI_API_KEY`` environment variable when that is empty.
    """
    candidate_key = ss.get('api_key') or os.getenv('OPENAI_API_KEY')
    if oai_key_isvalid(candidate_key):
        return
    st.write("Please input a valid OpenAI API key.")
def _render_agent_response(response):
    """Display an agent response inside the current chat message.

    In CSV mode (``file_type == 'CSV (.csv)'``), a non-string response is
    converted to a DataFrame and offered as a ``predict results.csv`` download.
    A plain-text response, a failed conversion, or any other file type shows
    the response as-is.
    """
    # A str would otherwise become a one-character-per-row DataFrame, so the
    # fallback below would never trigger for text answers.
    if ss.get('file_type') != 'CSV (.csv)' or isinstance(response, str):
        st.write(response)
        return
    try:
        fx = pd.DataFrame(list(response))
        st.markdown(":red[Prediction finished! ]")
        st.download_button(
            "⬇️Download the predicted files as .csv",
            fx.to_csv(),
            "predict results.csv",
            use_container_width=True,
        )
    except Exception:
        # Not tabular after all: show the raw response instead.
        st.write(response)


def run_prompt(prompt, file_path = '...', image_path = '...'):
    """Run *prompt* through a freshly built agent and render the exchange in the chat.

    A new agent is created on every call from the sidebar settings
    (``model1_select``, ``model2_select`` and, for the 'Drug discovery'
    domain, the ``'drug'`` tool set). It is stored in ``session_state.agent``.

    Args:
        prompt: The user's request; shown as the user message and sent to the agent.
        file_path: Path of an uploaded CSV/PDF handed to the agent. The
            ``'...'`` default appears to mean "no file" (assumption -- confirm).
        image_path: Path of an uploaded image handed to the agent; same default.

    OpenAI authentication and API errors are reported in the chat, not raised.
    """
    agent_kwargs = {
        'model1': ss.get('model1_select'),
        'model2': ss.get('model2_select'),
        'file_path': file_path,
        'image_path': image_path,
    }
    if ss.get('domain') == 'Drug discovery':
        agent_kwargs['tools'] = 'drug'
    agent = instantiate_agent(**agent_kwargs)

    st.chat_message("user").write(prompt)
    with st.chat_message("assistant"):
        try:
            response = agent.run(prompt)
        except openai.AuthenticationError:
            # AuthenticationError subclasses APIError, so it must be caught first.
            st.write("Please input a valid OpenAI API key")
            return
        except openai.APIError:
            # Show the error in the chat as well as the server log, so the user
            # is not left with an empty assistant reply.
            print("OpenAI API error, please try again!")
            st.write("OpenAI API error, please try again!")
            return
        _render_agent_response(response)
# Example prompts behind the four sidebar buttons (indices 0-3). The trailing
# space in the first entry is part of the prompt text.
pre_prompts = [
    'Generate a donor with PCE = 10% ',
    'The history and development of Y6',
    'Predict the LogP of PM6',
    'Predict the PCE of Y6',
]
# sidebar
with st.sidebar:
    st.header("🤖 :blue[TeLLAgent] ")

    # Input OpenAI api key (checked by on_api_key_change whenever it changes).
    st.text_input(
        'Input your OpenAI API key.',
        placeholder='Input your OpenAI API key.',
        type='password',
        key='api_key',
        on_change=on_api_key_change,
        label_visibility="collapsed",
    )
    st.text_input(
        'Input base url (optional).',
        placeholder='Input base url (optional)',
        key='base_url',
        type='password',
        label_visibility="collapsed",
    )
    # Input model to use
    st.text_input('Input global planning model to use', key='model1_select')
    st.text_input('Input local execution model to use', key='model2_select')
    st.text_input(
        'Input SERP API KEY (optional).',
        placeholder='Input SERP API KEY (optional)',
        key='serp_api',
        type='password',
        label_visibility="collapsed",
    )
    st.text_input(
        'Input SEMANTIC SCHOLAR API KEY (optional).',
        placeholder='Input SEMANTIC SCHOLAR API KEY (optional)',
        key='semantic_scholar_url',
        type='password',
        label_visibility="collapsed",
    )

    # Copy the sidebar credentials into the environment on every rerun.
    # An empty field deliberately overwrites any earlier value (including the
    # 'none' start-up placeholders). ``or ''`` only guards against a missing
    # state entry, because os.environ rejects None.
    for env_name, state_key in (
        ('OPENAI_API_KEY', 'api_key'),
        ('OPENAI_API_BASE', 'base_url'),
        ('SERP_API_KEY', 'serp_api'),
        ('SEMANTIC_SCHOLAR_API_KEY', 'semantic_scholar_url'),
    ):
        os.environ[env_name] = ss.get(state_key) or ''

    # Display prompt examples. A click queues the prompt in
    # session_state.pending_prompt for the main area to run.
    st.markdown('# What can I ask?')
    cols = st.columns(2)
    with cols[0]:
        if st.button('👑 Generate a donor with PCE = 10% 🧨 '):
            st.session_state.pending_prompt = pre_prompts[0]
        if st.button('📚 The history and development of Y6 '):
            st.session_state.pending_prompt = pre_prompts[1]
    with cols[1]:
        if st.button("🎄Predict the LogP of PM6 "):
            st.session_state.pending_prompt = pre_prompts[2]
        if st.button('💎 Predict the PCE of Y6'):
            st.session_state.pending_prompt = pre_prompts[3]

    st.selectbox(
        'Select the file type ',
        ['None', 'CSV (.csv)', 'Figure (.jpg, .png, .jpeg)', 'PDF (.pdf)'],
        key='file_type',
    )
    # Only the uploader matching the selected file type is shown. The chosen
    # file (or None) is read by the chat-input handler below.
    uploaded_file = None
    selected_file_type = ss.get('file_type')
    if selected_file_type == 'Figure (.jpg, .png, .jpeg)':
        uploaded_file = st.file_uploader("Choose a Figure", type=["jpg", "jpeg", "png"])
    elif selected_file_type == 'PDF (.pdf)':
        # Restrict to PDFs, as the other uploaders restrict their types.
        uploaded_file = st.file_uploader("Choose a PDF file", type=["pdf"])
    elif selected_file_type == 'CSV (.csv)':
        uploaded_file = st.file_uploader("Choose a csv file", type='csv')

    st.selectbox(
        '📚 Choose the domain ',
        ['Organic solar cell', 'Drug discovery'],
        key='domain',
    )

    # Display available tools. An agent with fixed default models is built
    # here just to list the domain's tools (instantiate_agent also stores it
    # in ss.agent).
    domain_tool_kwargs = {'tools': 'drug'} if ss.get('domain') == 'Drug discovery' else {}
    instantiate_agent(model1='gpt-4o-2024-11-20', model2='gpt-4o-2024-11-20', **domain_tool_kwargs)
    tools = ss.agent.agent_executor2.tools
    # Keyed by display name, so tools that share a name collapse to one row.
    tool_descriptions = {f"✅ {t.name}": t.description for t in tools}
    tool_list = pd.DataFrame({
        'Tool': list(tool_descriptions),
        'Description': list(tool_descriptions.values()),
    })
    st.markdown(f"# {len(tool_list)} available tools")
    st.dataframe(
        tool_list,
        width='stretch',
        hide_index=True,
        height=200,
    )
# Run an example prompt queued by a sidebar button. The queue is cleared
# first, so the prompt fires only once (it is dropped if models are missing).
if st.session_state.pending_prompt is not None:
    prompt_to_run = st.session_state.pending_prompt
    st.session_state.pending_prompt = None
    if not ss.get('model1_select') or not ss.get('model2_select'):
        st.error("⚠️ Please input both model names in the sidebar first!")
    else:
        run_prompt(prompt_to_run)
def _save_upload_to_temp(upload, suffix):
    """Write the whole upload to a new named temp file and return its path.

    The file is closed, and therefore fully flushed, before the path is
    returned, so code that reopens it by name sees the complete contents. It
    is created with ``delete=False``; the caller must remove it.
    """
    # getvalue() returns the full buffer regardless of the current read position.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(upload.getvalue())
    return tmp.name


def _run_prompt_with_upload(prompt_text, upload, suffix, path_arg, append_path):
    """Save *upload* to a temp file, run the agent with its path, then delete it.

    Args:
        prompt_text: The user's prompt.
        upload: The file chosen in the sidebar uploader.
        suffix: Extension for the temp file (e.g. '.csv').
        path_arg: ``run_prompt`` keyword that receives the temp path
            ('file_path' or 'image_path').
        append_path: Also append the temp path to the prompt text.
    """
    tmp_path = _save_upload_to_temp(upload, suffix)
    try:
        full_prompt = f"{prompt_text} {tmp_path}" if append_path else prompt_text
        run_prompt(full_prompt, **{path_arg: tmp_path})
    finally:
        # Best-effort cleanup: a leftover temp file must not fail the request.
        try:
            os.remove(tmp_path)
        except OSError:
            pass


# Execute agent on user input
if prompt := st.chat_input("Say something and/or attach files"):
    if not ss.get('model1_select') or not ss.get('model2_select'):
        st.error("⚠️ Please input both model names in the sidebar first!")
    elif uploaded_file is not None:
        chosen_type = ss.get('file_type')
        if chosen_type == 'CSV (.csv)':
            _run_prompt_with_upload(prompt, uploaded_file, '.csv', 'file_path', append_path=True)
        elif chosen_type == 'Figure (.jpg, .png, .jpeg)':
            st.image(uploaded_file, width=500)
            _run_prompt_with_upload(prompt, uploaded_file, '.png', 'image_path', append_path=True)
        elif chosen_type == 'PDF (.pdf)':
            # For PDFs the path goes only through file_path; it is not added to the prompt.
            _run_prompt_with_upload(prompt, uploaded_file, '.pdf', 'file_path', append_path=False)
    else:
        run_prompt(prompt)