jinysun commited on
Commit
e6c172f
·
verified ·
1 Parent(s): faa20d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -139
app.py CHANGED
@@ -3,30 +3,25 @@ import asyncio
3
  # Init with fake key
4
  if 'OPENAI_API_KEY' not in os.environ:
5
  os.environ['OPENAI_API_KEY'] = 'none'
 
 
6
 
7
-
8
- os.environ["SERP_API_KEY"] = 'none'
9
- os.environ["SEMANTIC_SCHOLAR_API_KEY"] = 'none'
10
  if os.name == 'nt':
11
  asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
12
 
13
  import openai
14
  import pandas as pd
15
  import streamlit as st
16
-
17
  from PIL import Image
18
- from agent import TeLLAgent, make_tools
19
- from streamlit_callback_handler import \
20
- StreamlitCallbackHandlerChem
21
  import base64
22
- import pandas as pd
23
  from dotenv import load_dotenv
24
- from langchain_openai import ChatOpenAI , OpenAI
25
- import base64
26
  from io import BytesIO
27
- from PIL import Image
28
  import tempfile
29
-
 
30
 
31
  def convert_to_base64(pil_image):
32
  buffered = BytesIO()
@@ -38,29 +33,39 @@ def oai_key_isvalid(api_key):
38
  """Check if a given OpenAI key is valid"""
39
  try:
40
  if os.getenv("OPENAI_API_BASE"):
41
- llm = ChatOpenAI(openai_api_key = api_key, base_url=os.getenv("OPENAI_API_BASE"))
42
  out = llm.invoke("This is a test")
43
  else:
44
- llm = ChatOpenAI(openai_api_key = api_key)
45
  out = llm.invoke("This is a test")
46
  return True
47
  except:
48
  return False
49
-
50
 
51
  def initialize_session_state():
52
- """Initialize session state variables"""
53
  if 'prompt' not in st.session_state:
54
  st.session_state.prompt = None
55
- # Add other session state initializations here as needed
56
-
57
- # Call initialization function
58
- initialize_session_state()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # Now you can safely use session state
61
- ss = st.session_state
62
-
63
- # Set width of sidebar
64
  st.markdown(
65
  """
66
  <style>
@@ -68,150 +73,185 @@ st.markdown(
68
  min-width: 500px;
69
  max-width: 500px;
70
  }
 
71
  """,
72
  unsafe_allow_html=True,
73
  )
74
 
 
 
75
 
76
- def instantiate_agent(model1, model2, file_path = '...', image_path ='...', tools=None):
77
-
 
 
 
78
  if not model1:
79
  model1 = "deepseek-ai/DeepSeek-R1"
80
 
81
-
82
  if not model2:
83
- model2 = "deepseek-ai/DeepSeek-v3.1"
84
 
85
-
86
- ss.agent = TeLLAgent( tools=tools,
87
- model1 = model1,
88
- model2 = model2,
89
- tools_model='gpt-4o-2024-11-20',
90
- temp=0.1,
91
- openai_api_key=ss.get('api_key'),
92
- file_path = file_path,
93
- image_path =image_path)
 
 
 
 
 
 
94
  return ss.agent
95
 
96
-
97
  def on_api_key_change():
98
  api_key = ss.get('api_key') or os.getenv('OPENAI_API_KEY')
99
-
100
  # Check if key is valid
101
  if not oai_key_isvalid(api_key):
102
  st.write("Please input a valid OpenAI API key.")
103
 
104
- def run_prompt(prompt, file_path = '...', image_path = '...'):
105
- if ss.get('domain') =='Drug discovery':
106
- agent = instantiate_agent(model1 = ss.get('model1_select'), model2 = ss.get('model2_select'), file_path = file_path, image_path =image_path, tools = 'drug')
107
- else:
108
- agent = instantiate_agent(model1 = ss.get('model1_select'), model2 = ss.get('model2_select'), file_path = file_path, image_path =image_path)
109
- st.chat_message("user").write(prompt)
110
- with st.chat_message("assistant") :
111
- try:
112
-
113
- response = agent.run(prompt)
114
- if ss.get('file_type') == 'CSV (.csv)':
115
- try:
116
- fx = pd.DataFrame(list(response))
117
- st.markdown(":red[Prediction finished! ]")
118
- st.download_button( "⬇️Download the predicted files as .csv", fx.to_csv(), "predict results.csv", use_container_width=True)
119
- except:
120
- st.write(response)
121
- else:
122
- st.write(response)
123
- except openai.AuthenticationError:
124
- st.write("Please input a valid OpenAI API key")
125
- except openai.APIError:
126
- # Handle specific API errors here
127
- print("OpenAI API error, please try again!")
128
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
  pre_prompts = [
131
  'Generate a donor with PCE = 10% ',
132
- ('The history and development of Y6'
133
-
134
- ),
135
- (
136
- 'Predict the LogP of PM6'
137
- ),
138
  'Predict the PCE of Y6'
139
  ]
140
 
141
  # sidebar
142
  with st.sidebar:
143
-
144
  st.header("🤖 :blue[TeLLAgent] ")
 
145
  # Input OpenAI api key
146
  st.text_input(
147
  'Input your OpenAI API key.',
148
- placeholder = 'Input your OpenAI API key.',
149
  type='password',
150
  key='api_key',
151
  on_change=on_api_key_change,
152
  label_visibility="collapsed"
153
  )
 
154
  st.text_input(
155
- 'Input base url (optional).',
156
- placeholder = 'Input base url (optional)',
157
- key='base_url',type='password',
 
158
  label_visibility="collapsed"
159
  )
 
160
  # Input model to use
161
  st.text_input(
162
  'Input global planning model to use',
163
-
164
  key='model1_select',
165
  )
 
166
  st.text_input(
167
  'Input local execution model to use',
168
-
169
  key='model2_select',
170
  )
 
171
  st.text_input(
172
- 'Input SERP API KEY (optional).',
173
- placeholder = 'Input SERP API KEY (optional)',
174
- key='serp_api',type='password',
 
175
  label_visibility="collapsed"
176
  )
 
177
  st.text_input(
178
- 'Input SEMANTIC SCHOLAR API KEY (optional).',
179
- placeholder = 'Input SEMANTIC SCHOLAR API KEY (optional)',
180
- key='semantic_scholar_url',type='password',
 
181
  label_visibility="collapsed"
182
  )
 
183
  user_api_key = ss.get('api_key')
184
  user_base_url = ss.get('base_url')
185
 
186
-
187
  if user_api_key:
188
  os.environ['OPENAI_API_KEY'] = user_api_key
189
 
190
  if user_base_url and user_base_url.strip() != "" and user_base_url != "none":
191
  os.environ["OPENAI_API_BASE"] = user_base_url
192
  else:
193
-
194
  if "OPENAI_API_BASE" in os.environ:
195
  del os.environ["OPENAI_API_BASE"]
196
 
197
- os.environ["SERP_API_KEY"] = ss.get('serp_api')
198
- os.environ["SEMANTIC_SCHOLAR_API_KEY"] = ss.get('semantic_scholar_url')
199
 
200
  # Display prompt examples
201
  st.markdown('# What can I ask?')
202
  cols = st.columns(2)
203
  with cols[0]:
204
  st.button(
205
- r'👑 Generate a donor with PCE = 10% 🧨 ',
206
  on_click=lambda: run_prompt(pre_prompts[0]),
207
  )
208
  st.button(
209
- r'📚 The history and development of Y6 ',
210
  on_click=lambda: run_prompt(pre_prompts[1]),
211
  )
212
  with cols[1]:
213
  st.button(
214
- r"🎄Predict the LogP of PM6 ",
215
  on_click=lambda: run_prompt(pre_prompts[2]),
216
  )
217
  st.button(
@@ -220,70 +260,79 @@ with st.sidebar:
220
  )
221
 
222
  st.selectbox(
223
- 'Select the file type ',
224
- ['None', 'CSV (.csv)', 'Figure (.jpg, .png, .jpeg)', 'PDF (.pdf)'],
225
- key='file_type',
226
- )
 
227
  uploaded_file = None
228
  if ss.get('file_type') == 'Figure (.jpg, .png, .jpeg)':
229
- uploaded_file = st.file_uploader("Choose a Figure", type = ["jpg", "jpeg", "png"])
230
  if ss.get('file_type') == 'PDF (.pdf)':
231
  uploaded_file = st.file_uploader("Choose a PDF file")
232
  if ss.get('file_type') == 'CSV (.csv)':
233
- uploaded_file = st.file_uploader("Choose a csv file", type = 'csv')
 
234
  st.selectbox(
235
- r'📚 Choose the domain ',
236
- ['Organic solar cell', 'Drug discovery'], key='domain',
237
- )
238
- # Display available tools
239
- if ss.get('domain') == 'Drug discovery':
240
- instantiate_agent(model1 = 'gpt-4o-2024-11-20', model2 = 'gpt-4o-2024-11-20' ,tools = 'drug')
241
- else:
242
- instantiate_agent(model1 = 'gpt-4o-2024-11-20', model2 = 'gpt-4o-2024-11-20' )
243
- tools = ss.agent.agent_executor2.tools
244
-
245
- tool_list = pd.Series( {f"✅ {t.name}": t.description for t in tools}).reset_index()
246
- tool_list.columns = ['Tool', 'Description']
247
- st.markdown(f"# {len(tool_list)} available tools")
248
- st.dataframe(
249
- tool_list,
250
- use_container_width=True,
251
- hide_index=True,
252
- height=200
253
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
  # Execute agent on user input
256
  if prompt := st.chat_input("Say something and/or attach files"):
257
-
258
  if uploaded_file is not None:
259
  if ss.get('file_type') == 'CSV (.csv)':
260
- with tempfile.NamedTemporaryFile( suffix ='.csv' ,delete=False) as f:
261
- f.write(uploaded_file.read())
262
- run_prompt(prompt + str(' ') + str(f.name), file_path = f.name)
263
- f.close()
264
 
265
  if ss.get('file_type') == 'Figure (.jpg, .png, .jpeg)':
266
-
267
- st.image(uploaded_file, width = 500)
268
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp:
269
-
270
- mg_str = base64.b64encode(uploaded_file.getvalue()).decode("utf-8")
271
- temp.write(base64.b64decode(mg_str))
272
-
273
- run_prompt(prompt+ str(' ') + str(temp.name), image_path = temp.name )
274
 
275
  if ss.get('file_type') == 'PDF (.pdf)':
276
- with tempfile.NamedTemporaryFile( suffix ='.pdf' ,delete=False) as f:
277
- f.write(uploaded_file.read())
278
- run_prompt(prompt, file_path = f.name)
279
- f.close()
280
-
281
- # with open("input.png","wb") as af:
282
- # mg_str = base64.b64encode(files.getvalue()).decode("utf-8")
283
- # af.write(base64.b64decode(mg_str))
284
-
285
- # run_prompt(prompt.text+str(f.name), image_path =f.name )
286
- # except:
287
- # st.markdown("Please input correct files or query ")
288
  else:
289
- run_prompt(prompt)
 
3
  # Init with fake key
4
  if 'OPENAI_API_KEY' not in os.environ:
5
  os.environ['OPENAI_API_KEY'] = 'none'
6
+ os.environ["SERP_API_KEY"] = 'none'
7
+ os.environ["SEMANTIC_SCHOLAR_API_KEY"] = 'none'
8
 
 
 
 
9
  if os.name == 'nt':
10
  asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
11
 
12
  import openai
13
  import pandas as pd
14
  import streamlit as st
 
15
  from PIL import Image
16
+ from agent import TeLLAgent, make_tools
17
+ from streamlit_callback_handler import StreamlitCallbackHandlerChem
 
18
  import base64
 
19
  from dotenv import load_dotenv
20
+ from langchain_openai import ChatOpenAI, OpenAI
 
21
  from io import BytesIO
 
22
  import tempfile
23
+
24
+ load_dotenv()
25
 
26
  def convert_to_base64(pil_image):
27
  buffered = BytesIO()
 
33
  """Check if a given OpenAI key is valid"""
34
  try:
35
  if os.getenv("OPENAI_API_BASE"):
36
+ llm = ChatOpenAI(openai_api_key=api_key, base_url=os.getenv("OPENAI_API_BASE"))
37
  out = llm.invoke("This is a test")
38
  else:
39
+ llm = ChatOpenAI(openai_api_key=api_key)
40
  out = llm.invoke("This is a test")
41
  return True
42
  except:
43
  return False
 
44
 
45
  def initialize_session_state():
46
+ """Initialize all session state variables"""
47
  if 'prompt' not in st.session_state:
48
  st.session_state.prompt = None
49
+ if 'agent' not in st.session_state:
50
+ st.session_state.agent = None
51
+ if 'api_key' not in st.session_state:
52
+ st.session_state.api_key = os.getenv('OPENAI_API_KEY', 'none')
53
+ if 'base_url' not in st.session_state:
54
+ st.session_state.base_url = ''
55
+ if 'model1_select' not in st.session_state:
56
+ st.session_state.model1_select = ''
57
+ if 'model2_select' not in st.session_state:
58
+ st.session_state.model2_select = ''
59
+ if 'serp_api' not in st.session_state:
60
+ st.session_state.serp_api = 'none'
61
+ if 'semantic_scholar_url' not in st.session_state:
62
+ st.session_state.semantic_scholar_url = 'none'
63
+ if 'file_type' not in st.session_state:
64
+ st.session_state.file_type = 'None'
65
+ if 'domain' not in st.session_state:
66
+ st.session_state.domain = 'Organic solar cell'
67
 
68
+ # Set width of sidebar BEFORE initializing session state
 
 
 
69
  st.markdown(
70
  """
71
  <style>
 
73
  min-width: 500px;
74
  max-width: 500px;
75
  }
76
+ </style>
77
  """,
78
  unsafe_allow_html=True,
79
  )
80
 
81
+ # Initialize session state at the very beginning
82
+ initialize_session_state()
83
 
84
+ # Now safely access session state
85
+ ss = st.session_state
86
+
87
+ def instantiate_agent(model1, model2, file_path='...', image_path='...', tools=None):
88
+ """Create or update agent instance"""
89
  if not model1:
90
  model1 = "deepseek-ai/DeepSeek-R1"
91
 
 
92
  if not model2:
93
+ model2 = "deepseek-ai/DeepSeek-v3.1"
94
 
95
+ try:
96
+ ss.agent = TeLLAgent(
97
+ tools=tools,
98
+ model1=model1,
99
+ model2=model2,
100
+ tools_model='gpt-4o-2024-11-20',
101
+ temp=0.1,
102
+ openai_api_key=ss.get('api_key'),
103
+ file_path=file_path,
104
+ image_path=image_path
105
+ )
106
+ except Exception as e:
107
+ st.error(f"Error instantiating agent: {str(e)}")
108
+ ss.agent = None
109
+
110
  return ss.agent
111
 
 
112
  def on_api_key_change():
113
  api_key = ss.get('api_key') or os.getenv('OPENAI_API_KEY')
 
114
  # Check if key is valid
115
  if not oai_key_isvalid(api_key):
116
  st.write("Please input a valid OpenAI API key.")
117
 
118
+ def run_prompt(prompt, file_path='...', image_path='...'):
119
+ """Execute prompt with agent"""
120
+ try:
121
+ if ss.get('domain') == 'Drug discovery':
122
+ agent = instantiate_agent(
123
+ model1=ss.get('model1_select'),
124
+ model2=ss.get('model2_select'),
125
+ file_path=file_path,
126
+ image_path=image_path,
127
+ tools='drug'
128
+ )
129
+ else:
130
+ agent = instantiate_agent(
131
+ model1=ss.get('model1_select'),
132
+ model2=ss.get('model2_select'),
133
+ file_path=file_path,
134
+ image_path=image_path
135
+ )
136
+
137
+ if agent is None:
138
+ st.error("Failed to create agent. Please check your API key and settings.")
139
+ return
140
+
141
+ st.chat_message("user").write(prompt)
142
+
143
+ with st.chat_message("assistant"):
144
+ try:
145
+ response = agent.run(prompt)
146
+ if ss.get('file_type') == 'CSV (.csv)':
147
+ try:
148
+ fx = pd.DataFrame(list(response))
149
+ st.markdown(":red[Prediction finished! ]")
150
+ st.download_button(
151
+ "⬇️Download the predicted files as .csv",
152
+ fx.to_csv(),
153
+ "predict results.csv",
154
+ use_container_width=True
155
+ )
156
+ except:
157
+ st.write(response)
158
+ else:
159
+ st.write(response)
160
+ except openai.AuthenticationError:
161
+ st.write("Please input a valid OpenAI API key")
162
+ except openai.APIError as e:
163
+ st.error(f"OpenAI API error: {str(e)}")
164
+ except Exception as e:
165
+ st.error(f"Error running prompt: {str(e)}")
166
+ except Exception as e:
167
+ st.error(f"Error in run_prompt: {str(e)}")
168
 
169
  pre_prompts = [
170
  'Generate a donor with PCE = 10% ',
171
+ 'The history and development of Y6',
172
+ 'Predict the LogP of PM6',
 
 
 
 
173
  'Predict the PCE of Y6'
174
  ]
175
 
176
  # sidebar
177
  with st.sidebar:
 
178
  st.header("🤖 :blue[TeLLAgent] ")
179
+
180
  # Input OpenAI api key
181
  st.text_input(
182
  'Input your OpenAI API key.',
183
+ placeholder='Input your OpenAI API key.',
184
  type='password',
185
  key='api_key',
186
  on_change=on_api_key_change,
187
  label_visibility="collapsed"
188
  )
189
+
190
  st.text_input(
191
+ 'Input base url (optional).',
192
+ placeholder='Input base url (optional)',
193
+ key='base_url',
194
+ type='password',
195
  label_visibility="collapsed"
196
  )
197
+
198
  # Input model to use
199
  st.text_input(
200
  'Input global planning model to use',
 
201
  key='model1_select',
202
  )
203
+
204
  st.text_input(
205
  'Input local execution model to use',
 
206
  key='model2_select',
207
  )
208
+
209
  st.text_input(
210
+ 'Input SERP API KEY (optional).',
211
+ placeholder='Input SERP API KEY (optional)',
212
+ key='serp_api',
213
+ type='password',
214
  label_visibility="collapsed"
215
  )
216
+
217
  st.text_input(
218
+ 'Input SEMANTIC SCHOLAR API KEY (optional).',
219
+ placeholder='Input SEMANTIC SCHOLAR API KEY (optional)',
220
+ key='semantic_scholar_url',
221
+ type='password',
222
  label_visibility="collapsed"
223
  )
224
+
225
  user_api_key = ss.get('api_key')
226
  user_base_url = ss.get('base_url')
227
 
 
228
  if user_api_key:
229
  os.environ['OPENAI_API_KEY'] = user_api_key
230
 
231
  if user_base_url and user_base_url.strip() != "" and user_base_url != "none":
232
  os.environ["OPENAI_API_BASE"] = user_base_url
233
  else:
 
234
  if "OPENAI_API_BASE" in os.environ:
235
  del os.environ["OPENAI_API_BASE"]
236
 
237
+ os.environ["SERP_API_KEY"] = ss.get('serp_api', 'none')
238
+ os.environ["SEMANTIC_SCHOLAR_API_KEY"] = ss.get('semantic_scholar_url', 'none')
239
 
240
  # Display prompt examples
241
  st.markdown('# What can I ask?')
242
  cols = st.columns(2)
243
  with cols[0]:
244
  st.button(
245
+ r'👑 Generate a donor with PCE = 10% 🧨',
246
  on_click=lambda: run_prompt(pre_prompts[0]),
247
  )
248
  st.button(
249
+ r'📚 The history and development of Y6',
250
  on_click=lambda: run_prompt(pre_prompts[1]),
251
  )
252
  with cols[1]:
253
  st.button(
254
+ r"🎄Predict the LogP of PM6",
255
  on_click=lambda: run_prompt(pre_prompts[2]),
256
  )
257
  st.button(
 
260
  )
261
 
262
  st.selectbox(
263
+ 'Select the file type',
264
+ ['None', 'CSV (.csv)', 'Figure (.jpg, .png, .jpeg)', 'PDF (.pdf)'],
265
+ key='file_type',
266
+ )
267
+
268
  uploaded_file = None
269
  if ss.get('file_type') == 'Figure (.jpg, .png, .jpeg)':
270
+ uploaded_file = st.file_uploader("Choose a Figure", type=["jpg", "jpeg", "png"])
271
  if ss.get('file_type') == 'PDF (.pdf)':
272
  uploaded_file = st.file_uploader("Choose a PDF file")
273
  if ss.get('file_type') == 'CSV (.csv)':
274
+ uploaded_file = st.file_uploader("Choose a csv file", type='csv')
275
+
276
  st.selectbox(
277
+ r'📚 Choose the domain',
278
+ ['Organic solar cell', 'Drug discovery'],
279
+ key='domain',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  )
281
+
282
+ # Display available tools - only initialize agent if not already done
283
+ if ss.agent is None:
284
+ if ss.get('domain') == 'Drug discovery':
285
+ instantiate_agent(
286
+ model1='gpt-4o-2024-11-20',
287
+ model2='gpt-4o-2024-11-20',
288
+ tools='drug'
289
+ )
290
+ else:
291
+ instantiate_agent(
292
+ model1='gpt-4o-2024-11-20',
293
+ model2='gpt-4o-2024-11-20'
294
+ )
295
+
296
+ # Safely display tools
297
+ if ss.agent is not None and hasattr(ss.agent, 'agent_executor2'):
298
+ try:
299
+ tools = ss.agent.agent_executor2.tools
300
+ tool_list = pd.Series(
301
+ {f"✅ {t.name}": t.description for t in tools}
302
+ ).reset_index()
303
+ tool_list.columns = ['Tool', 'Description']
304
+ st.markdown(f"# {len(tool_list)} available tools")
305
+ st.dataframe(
306
+ tool_list,
307
+ use_container_width=True,
308
+ hide_index=True,
309
+ height=200
310
+ )
311
+ except Exception as e:
312
+ st.warning(f"Could not load tools: {str(e)}")
313
+ else:
314
+ st.info("Agent not initialized. Please check your API key.")
315
 
316
  # Execute agent on user input
317
  if prompt := st.chat_input("Say something and/or attach files"):
 
318
  if uploaded_file is not None:
319
  if ss.get('file_type') == 'CSV (.csv)':
320
+ with tempfile.NamedTemporaryFile(suffix='.csv', delete=False) as f:
321
+ f.write(uploaded_file.read())
322
+ run_prompt(prompt + ' ' + str(f.name), file_path=f.name)
323
+ f.close()
324
 
325
  if ss.get('file_type') == 'Figure (.jpg, .png, .jpeg)':
326
+ st.image(uploaded_file, width=500)
 
327
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp:
328
+ mg_str = base64.b64encode(uploaded_file.getvalue()).decode("utf-8")
329
+ temp.write(base64.b64decode(mg_str))
330
+ run_prompt(prompt + ' ' + str(temp.name), image_path=temp.name)
 
 
331
 
332
  if ss.get('file_type') == 'PDF (.pdf)':
333
+ with tempfile.NamedTemporaryFile(suffix='.pdf', delete=False) as f:
334
+ f.write(uploaded_file.read())
335
+ run_prompt(prompt, file_path=f.name)
336
+ f.close()
 
 
 
 
 
 
 
 
337
  else:
338
+ run_prompt(prompt)