|
|
|
|
|
import os |
|
|
import asyncio |
|
|
|
|
|
# Seed placeholder credentials before the third-party/project imports that
# follow (presumably so those modules can be imported without real keys —
# confirm). setdefault only fills in variables that are missing, so values
# already exported by the shell are no longer overwritten (previously only
# OPENAI_API_KEY was guarded; the other three were always clobbered).
os.environ.setdefault('OPENAI_API_KEY', 'none')
os.environ.setdefault("OPENAI_API_BASE", 'none')
os.environ.setdefault("SERP_API_KEY", 'none')
os.environ.setdefault("SEMANTIC_SCHOLAR_API_KEY", 'none')

if os.name == 'nt':
    # On Windows, explicitly select the Proactor event loop policy.
    asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
|
|
|
|
|
import openai |
|
|
import pandas as pd |
|
|
import streamlit as st |
|
|
|
|
|
from PIL import Image |
|
|
from agent import TeLLAgent, make_tools |
|
|
from streamlit_callback_handler import \ |
|
|
StreamlitCallbackHandlerChem |
|
|
import base64 |
|
|
import pandas as pd |
|
|
from dotenv import load_dotenv |
|
|
from langchain_openai import ChatOpenAI , OpenAI |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
from PIL import Image |
|
|
import tempfile |
|
|
|
|
|
|
|
|
def convert_to_base64(pil_image):
    """Serialize *pil_image* as PNG and return the bytes base64-encoded as text.

    Args:
        pil_image: Any object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The base64 encoding of the PNG bytes, decoded to a UTF-8 ``str``.
    """
    with BytesIO() as png_buffer:
        pil_image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
|
|
|
|
|
def oai_key_isvalid(api_key):
    """Check if a given OpenAI key is valid.

    Sends one short prompt through ``ChatOpenAI``. When the
    ``OPENAI_API_BASE`` environment variable holds a non-empty value it is
    passed as ``base_url``; otherwise the client is built without an explicit
    endpoint.

    Args:
        api_key: The key to test. ``None`` or ``""`` is rejected without
            making a request.

    Returns:
        True if the test request completed, False if building the client or
        the request raised any ``Exception``.
    """
    if not api_key:
        # Nothing to validate; skip the network round-trip.
        return False
    # Empty string and unset are treated the same: no explicit endpoint.
    base_url = os.getenv("OPENAI_API_BASE") or None
    try:
        llm = ChatOpenAI(openai_api_key=api_key, base_url=base_url)
        llm.invoke("This is a test")
    except Exception:
        # Deliberate catch-all: any failure means "not usable". Catching
        # Exception (not a bare except) keeps Ctrl-C / SystemExit working.
        return False
    return True
|
|
|
|
|
# NOTE(review): load_dotenv() does not override variables that are already
# set by default, so placeholder values exported earlier in this file take
# precedence over matching .env entries — confirm that is intended.
load_dotenv()

ss = st.session_state
ss.prompt = None

if 'pending_prompt' not in st.session_state:
    # Prompt queued by the example buttons; the handler further down runs it
    # once and resets it to None.
    st.session_state.pending_prompt = None

# Fix the sidebar width at 500px while it is expanded. The <style> element is
# now closed so the injected HTML is well-formed.
st.markdown(
    """
<style>
[data-testid="stSidebar"][aria-expanded="true"]{
    min-width: 500px;
    max-width: 500px;
}
</style>
""",
    unsafe_allow_html=True,
)
|
|
|
|
|
|
|
|
def instantiate_agent(model1, model2, file_path='...', image_path='...', tools=None,
                      tools_model='gpt-4o-2024-11-20', temp=0.1):
    """Create a ``TeLLAgent``, store it on ``ss.agent`` and return it.

    Args:
        model1: Global planning model name (the sidebar's ``model1_select``).
        model2: Local execution model name (the sidebar's ``model2_select``).
        file_path: Uploaded CSV/PDF path, or the ``'...'`` placeholder.
        image_path: Uploaded image path, or the ``'...'`` placeholder.
        tools: Tool-set selector forwarded to ``TeLLAgent``; callers in this
            file pass ``'drug'`` or ``None``.
        tools_model: Model name forwarded as ``tools_model`` (previously
            hard-coded; the default keeps the old value).
        temp: Temperature forwarded as ``temp`` (previously hard-coded; the
            default keeps the old value).

    Returns:
        The newly created agent, which is also kept in ``ss.agent``.
    """
    ss.agent = TeLLAgent(
        tools=tools,
        model1=model1,
        model2=model2,
        tools_model=tools_model,
        temp=temp,
        # The key typed into the sidebar (may be an empty string).
        openai_api_key=ss.get('api_key'),
        file_path=file_path,
        image_path=image_path,
    )
    return ss.agent
|
|
|
|
|
|
|
|
|
|
|
def on_api_key_change():
    """Warn on the page when the OpenAI key fails validation.

    Registered as the ``on_change`` callback of the API-key text input. The
    sidebar value is checked first; an empty field falls back to the
    ``OPENAI_API_KEY`` environment variable.
    """
    key_to_check = ss.get('api_key') or os.getenv('OPENAI_API_KEY')
    if oai_key_isvalid(key_to_check):
        return
    st.write("Please input a valid OpenAI API key.")
|
|
|
|
|
def run_prompt(prompt, file_path='...', image_path='...'):
    """Answer *prompt* with a freshly built agent and render the exchange as chat.

    A new agent is created on every call from the sidebar's model names and
    domain (``instantiate_agent`` also stores it on ``ss.agent``).

    Args:
        prompt: User text passed to ``agent.run``.
        file_path: Uploaded CSV/PDF path forwarded to the agent, or ``'...'``.
        image_path: Uploaded image path forwarded to the agent, or ``'...'``.

    OpenAI authentication and API errors are reported inside the assistant
    message instead of propagating.
    """
    # 'Drug discovery' selects the 'drug' tool set; other domains use tools=None.
    domain_tools = 'drug' if ss.get('domain') == 'Drug discovery' else None
    agent = instantiate_agent(
        model1=ss.get('model1_select'),
        model2=ss.get('model2_select'),
        file_path=file_path,
        image_path=image_path,
        tools=domain_tools,
    )

    st.chat_message("user").write(prompt)
    with st.chat_message("assistant"):
        try:
            response = agent.run(prompt)
        except openai.AuthenticationError:
            st.error("Please input a valid OpenAI API key")
            return
        except openai.APIError:
            # Show the failure in the UI; print() only reached the server console.
            st.error("OpenAI API error, please try again!")
            return

        # list() of a str yields single characters, so a plain-text answer is
        # shown as text even when a CSV was uploaded.
        if ss.get('file_type') != 'CSV (.csv)' or isinstance(response, str):
            st.write(response)
            return

        try:
            predictions_csv = pd.DataFrame(list(response)).to_csv()
        except Exception:
            # Response is not tabular; fall back to showing it as-is.
            st.write(response)
            return

        st.markdown(":red[Prediction finished! ]")
        st.download_button(
            "⬇️Download the predicted files as .csv",
            predictions_csv,
            "predict results.csv",
            use_container_width=True,
        )
|
|
|
|
|
# Canned example questions. Indices 0-3 are queued by the four example
# buttons in the sidebar's "What can I ask?" section.
pre_prompts = [
    'Generate a donor with PCE = 10% ',
    'The history and development of Y6',
    'Predict the LogP of PM6',
    'Predict the PCE of Y6',
]
|
|
|
|
|
|
|
|
with st.sidebar:
    st.header("🤖 :blue[TeLLAgent] ")

    # Credentials. Editing the key field triggers on_api_key_change.
    st.text_input(
        'Input your OpenAI API key.',
        placeholder='Input your OpenAI API key.',
        type='password',
        key='api_key',
        on_change=on_api_key_change,
        label_visibility="collapsed",
    )
    st.text_input(
        'Input base url (optional).',
        placeholder='Input base url (optional)',
        key='base_url',
        type='password',
        label_visibility="collapsed",
    )

    # Model names; run_prompt reads them from ss['model1_select'] / ss['model2_select'].
    st.text_input('Input global planning model to use', key='model1_select')
    st.text_input('Input local execution model to use', key='model2_select')

    st.text_input(
        'Input SERP API KEY (optional).',
        placeholder='Input SERP API KEY (optional)',
        key='serp_api',
        type='password',
        label_visibility="collapsed",
    )
    st.text_input(
        'Input SEMANTIC SCHOLAR API KEY (optional).',
        placeholder='Input SEMANTIC SCHOLAR API KEY (optional)',
        key='semantic_scholar_url',
        type='password',
        label_visibility="collapsed",
    )

    # Publish the sidebar values as environment variables for downstream code.
    # NOTE(review): os.environ is process-wide, so concurrent Streamlit
    # sessions overwrite each other's keys — confirm this app serves one user.
    os.environ['OPENAI_API_KEY'] = ss.get('api_key')
    os.environ["OPENAI_API_BASE"] = ss.get('base_url')
    os.environ["SERP_API_KEY"] = ss.get('serp_api')
    os.environ["SEMANTIC_SCHOLAR_API_KEY"] = ss.get('semantic_scholar_url')

    # Example prompts: a click queues the matching pre_prompts entry in
    # st.session_state.pending_prompt, which is run later in the script.
    # NOTE(review): the leading glyph of several labels below (and of the
    # domain selectbox) is mis-encoded; restore the intended emoji.
    st.markdown('# What can I ask?')
    cols = st.columns(2)
    with cols[0]:
        if st.button(r'๐ Generate a donor with PCE = 10% 🧨 '):
            st.session_state.pending_prompt = pre_prompts[0]
        if st.button(r'๐ The history and development of Y6 '):
            st.session_state.pending_prompt = pre_prompts[1]
    with cols[1]:
        if st.button(r"๐Predict the LogP of PM6 "):
            st.session_state.pending_prompt = pre_prompts[2]
        if st.button(r'๐ Predict the PCE of Y6'):
            st.session_state.pending_prompt = pre_prompts[3]

    # File upload: the uploader shown depends on the selected file type.
    st.selectbox(
        'Select the file type ',
        ['None', 'CSV (.csv)', 'Figure (.jpg, .png, .jpeg)', 'PDF (.pdf)'],
        key='file_type',
    )
    uploaded_file = None
    selected_file_type = ss.get('file_type')
    if selected_file_type == 'Figure (.jpg, .png, .jpeg)':
        uploaded_file = st.file_uploader("Choose a Figure", type=["jpg", "jpeg", "png"])
    elif selected_file_type == 'PDF (.pdf)':
        # Only PDFs are accepted for this file type.
        uploaded_file = st.file_uploader("Choose a PDF file", type=["pdf"])
    elif selected_file_type == 'CSV (.csv)':
        uploaded_file = st.file_uploader("Choose a csv file", type='csv')

    st.selectbox(
        r'๐ Choose the domain ',
        ['Organic solar cell', 'Drug discovery'],
        key='domain',
    )

    # Build an agent with default models only to list the tools for the domain.
    domain_tools = 'drug' if ss.get('domain') == 'Drug discovery' else None
    instantiate_agent(model1='gpt-4o-2024-11-20', model2='gpt-4o-2024-11-20', tools=domain_tools)
    tools = ss.agent.agent_executor2.tools

    tool_list = pd.Series({f"✅ {t.name}": t.description for t in tools}).reset_index()
    tool_list.columns = ['Tool', 'Description']
    st.markdown(f"# {len(tool_list)} available tools")
    st.dataframe(
        tool_list,
        width='stretch',
        hide_index=True,
        height=200,
    )
|
|
|
|
|
# Run a prompt queued by one of the example buttons, clearing it first so it
# executes only once.
if st.session_state.pending_prompt is not None:
    prompt_to_run = st.session_state.pending_prompt
    st.session_state.pending_prompt = None

    if ss.get('model1_select') and ss.get('model2_select'):
        run_prompt(prompt_to_run)
    else:
        st.error("⚠️ Please input both model names in the sidebar first!")
|
|
|
|
|
|
|
|
def _save_upload_to_tempfile(upload, suffix):
    """Copy an uploaded file's bytes into a new temp file and return its path.

    The file is closed before returning, which flushes the data so another
    reader (the agent) sees the complete contents. The caller must delete it.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        # getvalue() returns the whole buffer regardless of the read position.
        handle.write(upload.getvalue())
    return handle.name


def _remove_quietly(path):
    """Best-effort delete of a temp file; a failure only leaves a stray file."""
    try:
        os.remove(path)
    except OSError:
        pass


if prompt := st.chat_input("Say something and/or attach files"):
    chat_file_type = ss.get('file_type')

    if not ss.get('model1_select') or not ss.get('model2_select'):
        st.error("⚠️ Please input both model names in the sidebar first!")

    elif uploaded_file is None:
        run_prompt(prompt)

    elif chat_file_type == 'CSV (.csv)':
        csv_path = _save_upload_to_tempfile(uploaded_file, '.csv')
        try:
            # The temp path is also appended to the prompt text.
            run_prompt(f"{prompt} {csv_path}", file_path=csv_path)
        finally:
            _remove_quietly(csv_path)

    elif chat_file_type == 'Figure (.jpg, .png, .jpeg)':
        st.image(uploaded_file, width=500)
        image_path = _save_upload_to_tempfile(uploaded_file, '.png')
        try:
            run_prompt(f"{prompt} {image_path}", image_path=image_path)
        finally:
            _remove_quietly(image_path)

    elif chat_file_type == 'PDF (.pdf)':
        pdf_path = _save_upload_to_tempfile(uploaded_file, '.pdf')
        try:
            run_prompt(prompt, file_path=pdf_path)
        finally:
            _remove_quietly(pdf_path)
|
|
|
|
|
|