jinysun commited on
Commit
caccbee
·
verified ·
1 Parent(s): a8e1f9d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -204
app.py CHANGED
@@ -1,14 +1,7 @@
1
  import os
2
  import asyncio
3
- # Init with fake key
4
- if 'OPENAI_API_KEY' not in os.environ:
5
- os.environ['OPENAI_API_KEY'] = 'none'
6
- os.environ["SERP_API_KEY"] = 'none'
7
- os.environ["SEMANTIC_SCHOLAR_API_KEY"] = 'none'
8
-
9
- if os.name == 'nt':
10
- asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
11
-
12
  import openai
13
  import pandas as pd
14
  import streamlit as st
@@ -45,86 +38,71 @@ def oai_key_isvalid(api_key):
45
  def initialize_session_state():
46
  """Initialize all session state variables"""
47
  if 'prompt' not in st.session_state:
48
- st.session_state.prompt = None
49
  if 'agent' not in st.session_state:
50
- st.session_state.agent = None
51
  if 'api_key' not in st.session_state:
52
- st.session_state.api_key = os.getenv('OPENAI_API_KEY', 'none')
53
  if 'base_url' not in st.session_state:
54
- st.session_state.base_url = ''
55
  if 'model1_select' not in st.session_state:
56
- st.session_state.model1_select = ''
57
  if 'model2_select' not in st.session_state:
58
- st.session_state.model2_select = ''
59
  if 'serp_api' not in st.session_state:
60
- st.session_state.serp_api = 'none'
61
  if 'semantic_scholar_url' not in st.session_state:
62
- st.session_state.semantic_scholar_url = 'none'
63
  if 'file_type' not in st.session_state:
64
- st.session_state.file_type = 'None'
65
  if 'domain' not in st.session_state:
66
- st.session_state.domain = 'Organic solar cell'
67
-
68
- # Set width of sidebar BEFORE initializing session state
69
- st.markdown(
70
- """
71
- <style>
72
- [data-testid="stSidebar"][aria-expanded="true"]{
73
- min-width: 500px;
74
- max-width: 500px;
75
- }
76
- </style>
77
- """,
78
- unsafe_allow_html=True,
79
- )
80
-
81
- # Initialize session state at the very beginning
82
- initialize_session_state()
83
-
84
- # Now safely access session state
85
- ss = st.session_state
86
 
87
  def instantiate_agent(model1, model2, file_path='...', image_path='...', tools=None):
88
  """Create or update agent instance"""
 
 
 
 
 
89
 
90
  try:
91
- ss.agent = TeLLAgent(
92
  tools=tools,
93
  model1=model1,
94
  model2=model2,
95
  tools_model='gpt-4o-2024-11-20',
96
  temp=0.1,
97
- openai_api_key=ss.get('api_key'),
98
  file_path=file_path,
99
  image_path=image_path
100
  )
101
  except Exception as e:
102
  st.error(f"Error instantiating agent: {str(e)}")
103
- ss.agent = None
104
 
105
- return ss.agent
106
 
107
  def on_api_key_change():
108
- api_key = ss.get('api_key') or os.getenv('OPENAI_API_KEY')
109
- # Check if key is valid
110
  if not oai_key_isvalid(api_key):
111
  st.write("Please input a valid OpenAI API key.")
112
 
113
  def run_prompt(prompt, file_path='...', image_path='...'):
114
  """Execute prompt with agent"""
115
  try:
116
- if ss.get('domain') == 'Drug discovery':
117
  agent = instantiate_agent(
118
- model1=ss.get('model1_select'),
119
- model2=ss.get('model2_select'),
120
  file_path=file_path,
121
  image_path=image_path,
122
  tools='drug'
123
  )
124
  else:
125
  agent = instantiate_agent(
126
- model1=ss.get('model1_select'),
127
- model2=ss.get('model2_select'),
128
  file_path=file_path,
129
  image_path=image_path
130
  )
@@ -138,7 +116,7 @@ def run_prompt(prompt, file_path='...', image_path='...'):
138
  with st.chat_message("assistant"):
139
  try:
140
  response = agent.run(prompt)
141
- if ss.get('file_type') == 'CSV (.csv)':
142
  try:
143
  fx = pd.DataFrame(list(response))
144
  st.markdown(":red[Prediction finished! ]")
@@ -161,173 +139,198 @@ def run_prompt(prompt, file_path='...', image_path='...'):
161
  except Exception as e:
162
  st.error(f"Error in run_prompt: {str(e)}")
163
 
164
- pre_prompts = [
165
- 'Generate a donor with PCE = 10% ',
166
- 'The history and development of Y6',
167
- 'Predict the LogP of PM6',
168
- 'Predict the PCE of Y6'
169
- ]
170
-
171
- # sidebar
172
- with st.sidebar:
173
- st.header("🤖 :blue[TeLLAgent] ")
174
-
175
- # Input OpenAI api key
176
- st.text_input(
177
- 'Input your OpenAI API key.',
178
- placeholder='Input your OpenAI API key.',
179
- type='password',
180
- key='api_key',
181
- on_change=on_api_key_change,
182
- label_visibility="collapsed"
183
- )
184
-
185
- st.text_input(
186
- 'Input base url (optional).',
187
- placeholder='Input base url (optional)',
188
- key='base_url',
189
- type='password',
190
- label_visibility="collapsed"
191
- )
192
 
193
- # Input model to use
194
- st.text_input(
195
- 'Input global planning model to use',
196
- key='model1_select',
197
- )
198
 
199
- st.text_input(
200
- 'Input local execution model to use',
201
- key='model2_select',
 
 
 
 
 
 
 
 
202
  )
203
 
204
- st.text_input(
205
- 'Input SERP API KEY (optional).',
206
- placeholder='Input SERP API KEY (optional)',
207
- key='serp_api',
208
- type='password',
209
- label_visibility="collapsed"
210
- )
211
 
212
- st.text_input(
213
- 'Input SEMANTIC SCHOLAR API KEY (optional).',
214
- placeholder='Input SEMANTIC SCHOLAR API KEY (optional)',
215
- key='semantic_scholar_url',
216
- type='password',
217
- label_visibility="collapsed"
218
- )
219
 
220
- user_api_key = ss.get('api_key')
221
- user_base_url = ss.get('base_url')
222
-
223
- if user_api_key:
224
- os.environ['OPENAI_API_KEY'] = user_api_key
 
225
 
226
- if user_base_url and user_base_url.strip() != "" and user_base_url != "none":
227
- os.environ["OPENAI_API_BASE"] = user_base_url
228
- else:
229
- if "OPENAI_API_BASE" in os.environ:
230
- del os.environ["OPENAI_API_BASE"]
231
-
232
- os.environ["SERP_API_KEY"] = ss.get('serp_api', 'none')
233
- os.environ["SEMANTIC_SCHOLAR_API_KEY"] = ss.get('semantic_scholar_url', 'none')
234
-
235
- # Display prompt examples
236
- st.markdown('# What can I ask?')
237
- cols = st.columns(2)
238
- with cols[0]:
239
- st.button(
240
- r'👑 Generate a donor with PCE = 10% 🧨',
241
- on_click=lambda: run_prompt(pre_prompts[0]),
242
- )
243
- st.button(
244
- r'📚 The history and development of Y6',
245
- on_click=lambda: run_prompt(pre_prompts[1]),
246
  )
247
- with cols[1]:
248
- st.button(
249
- r"🎄Predict the LogP of PM6",
250
- on_click=lambda: run_prompt(pre_prompts[2]),
 
 
 
 
 
 
 
 
 
251
  )
252
- st.button(
253
- r'💎 Predict the PCE of Y6',
254
- on_click=lambda: run_prompt(pre_prompts[3]),
 
255
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
 
257
- st.selectbox(
258
- 'Select the file type',
259
- ['None', 'CSV (.csv)', 'Figure (.jpg, .png, .jpeg)', 'PDF (.pdf)'],
260
- key='file_type',
261
- )
262
-
263
- uploaded_file = None
264
- if ss.get('file_type') == 'Figure (.jpg, .png, .jpeg)':
265
- uploaded_file = st.file_uploader("Choose a Figure", type=["jpg", "jpeg", "png"])
266
- if ss.get('file_type') == 'PDF (.pdf)':
267
- uploaded_file = st.file_uploader("Choose a PDF file")
268
- if ss.get('file_type') == 'CSV (.csv)':
269
- uploaded_file = st.file_uploader("Choose a csv file", type='csv')
270
 
271
- st.selectbox(
272
- r'📚 Choose the domain',
273
- ['Organic solar cell', 'Drug discovery'],
274
- key='domain',
275
- )
276
-
277
- # Display available tools - only initialize agent if not already done
278
- if ss.agent is None:
279
- if ss.get('domain') == 'Drug discovery':
280
- instantiate_agent(
281
- model1='gpt-4o-2024-11-20',
282
- model2='gpt-4o-2024-11-20',
283
- tools='drug'
284
  )
285
- else:
286
- instantiate_agent(
287
- model1='gpt-4o-2024-11-20',
288
- model2='gpt-4o-2024-11-20'
289
  )
290
-
291
- # Safely display tools
292
- if ss.agent is not None and hasattr(ss.agent, 'agent_executor2'):
293
- try:
294
- tools = ss.agent.agent_executor2.tools
295
- tool_list = pd.Series(
296
- {f"✅ {t.name}": t.description for t in tools}
297
- ).reset_index()
298
- tool_list.columns = ['Tool', 'Description']
299
- st.markdown(f"# {len(tool_list)} available tools")
300
- st.dataframe(
301
- tool_list,
302
- use_container_width=True,
303
- hide_index=True,
304
- height=200
305
  )
306
- except Exception as e:
307
- st.warning(f"Could not load tools: {str(e)}")
308
- else:
309
- st.info("Agent not initialized. Please check your API key.")
310
 
311
- # Execute agent on user input
312
- if prompt := st.chat_input("Say something and/or attach files"):
313
- if uploaded_file is not None:
314
- if ss.get('file_type') == 'CSV (.csv)':
315
- with tempfile.NamedTemporaryFile(suffix='.csv', delete=False) as f:
316
- f.write(uploaded_file.read())
317
- run_prompt(prompt + ' ' + str(f.name), file_path=f.name)
318
- f.close()
319
-
320
- if ss.get('file_type') == 'Figure (.jpg, .png, .jpeg)':
321
- st.image(uploaded_file, width=500)
322
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp:
323
- mg_str = base64.b64encode(uploaded_file.getvalue()).decode("utf-8")
324
- temp.write(base64.b64decode(mg_str))
325
- run_prompt(prompt + ' ' + str(temp.name), image_path=temp.name)
326
-
327
- if ss.get('file_type') == 'PDF (.pdf)':
328
- with tempfile.NamedTemporaryFile(suffix='.pdf', delete=False) as f:
329
- f.write(uploaded_file.read())
330
- run_prompt(prompt, file_path=f.name)
331
- f.close()
332
- else:
333
- run_prompt(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import asyncio
3
+
4
+
 
 
 
 
 
 
 
5
  import openai
6
  import pandas as pd
7
  import streamlit as st
 
38
  def initialize_session_state():
39
  """Initialize all session state variables"""
40
  if 'prompt' not in st.session_state:
41
+ st.session_state['prompt'] = None
42
  if 'agent' not in st.session_state:
43
+ st.session_state['agent'] = None
44
  if 'api_key' not in st.session_state:
45
+ st.session_state['api_key'] = os.getenv('OPENAI_API_KEY', 'none')
46
  if 'base_url' not in st.session_state:
47
+ st.session_state['base_url'] = ''
48
  if 'model1_select' not in st.session_state:
49
+ st.session_state['model1_select'] = ''
50
  if 'model2_select' not in st.session_state:
51
+ st.session_state['model2_select'] = ''
52
  if 'serp_api' not in st.session_state:
53
+ st.session_state['serp_api'] = 'none'
54
  if 'semantic_scholar_url' not in st.session_state:
55
+ st.session_state['semantic_scholar_url'] = 'none'
56
  if 'file_type' not in st.session_state:
57
+ st.session_state['file_type'] = 'None'
58
  if 'domain' not in st.session_state:
59
+ st.session_state['domain'] = 'Organic solar cell'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
  def instantiate_agent(model1, model2, file_path='...', image_path='...', tools=None):
62
  """Create or update agent instance"""
63
+ if not model1:
64
+ model1 = "deepseek-ai/DeepSeek-R1"
65
+
66
+ if not model2:
67
+ model2 = "deepseek-ai/DeepSeek-v3.1"
68
 
69
  try:
70
+ st.session_state['agent'] = TeLLAgent(
71
  tools=tools,
72
  model1=model1,
73
  model2=model2,
74
  tools_model='gpt-4o-2024-11-20',
75
  temp=0.1,
76
+ openai_api_key=st.session_state.get('api_key'),
77
  file_path=file_path,
78
  image_path=image_path
79
  )
80
  except Exception as e:
81
  st.error(f"Error instantiating agent: {str(e)}")
82
+ st.session_state['agent'] = None
83
 
84
+ return st.session_state.get('agent')
85
 
86
  def on_api_key_change():
87
+ api_key = st.session_state.get('api_key') or os.getenv('OPENAI_API_KEY')
 
88
  if not oai_key_isvalid(api_key):
89
  st.write("Please input a valid OpenAI API key.")
90
 
91
  def run_prompt(prompt, file_path='...', image_path='...'):
92
  """Execute prompt with agent"""
93
  try:
94
+ if st.session_state.get('domain') == 'Drug discovery':
95
  agent = instantiate_agent(
96
+ model1=st.session_state.get('model1_select'),
97
+ model2=st.session_state.get('model2_select'),
98
  file_path=file_path,
99
  image_path=image_path,
100
  tools='drug'
101
  )
102
  else:
103
  agent = instantiate_agent(
104
+ model1=st.session_state.get('model1_select'),
105
+ model2=st.session_state.get('model2_select'),
106
  file_path=file_path,
107
  image_path=image_path
108
  )
 
116
  with st.chat_message("assistant"):
117
  try:
118
  response = agent.run(prompt)
119
+ if st.session_state.get('file_type') == 'CSV (.csv)':
120
  try:
121
  fx = pd.DataFrame(list(response))
122
  st.markdown(":red[Prediction finished! ]")
 
139
  except Exception as e:
140
  st.error(f"Error in run_prompt: {str(e)}")
141
 
142
+ def main():
143
+ """Main application function"""
144
+ initialize_session_state()
145
+ # Init with fake key
146
+ if 'OPENAI_API_KEY' not in os.environ:
147
+ os.environ['OPENAI_API_KEY'] = 'none'
148
+ os.environ["SERP_API_KEY"] = 'none'
149
+ os.environ["SEMANTIC_SCHOLAR_API_KEY"] = 'none'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
+ if os.name == 'nt':
152
+ asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
 
 
 
153
 
154
+ # Set width of sidebar
155
+ st.markdown(
156
+ """
157
+ <style>
158
+ [data-testid="stSidebar"][aria-expanded="true"]{
159
+ min-width: 500px;
160
+ max-width: 500px;
161
+ }
162
+ </style>
163
+ """,
164
+ unsafe_allow_html=True,
165
  )
166
 
167
+ # Initialize session state
 
 
 
 
 
 
168
 
 
 
 
 
 
 
 
169
 
170
+ pre_prompts = [
171
+ 'Generate a donor with PCE = 10% ',
172
+ 'The history and development of Y6',
173
+ 'Predict the LogP of PM6',
174
+ 'Predict the PCE of Y6'
175
+ ]
176
 
177
+ # Sidebar
178
+ with st.sidebar:
179
+ st.header("🤖 :blue[TeLLAgent] ")
180
+
181
+ # Input OpenAI api key
182
+ st.text_input(
183
+ 'Input your OpenAI API key.',
184
+ placeholder='Input your OpenAI API key.',
185
+ type='password',
186
+ key='api_key',
187
+ on_change=on_api_key_change,
188
+ label_visibility="collapsed"
 
 
 
 
 
 
 
 
189
  )
190
+
191
+ st.text_input(
192
+ 'Input base url (optional).',
193
+ placeholder='Input base url (optional)',
194
+ key='base_url',
195
+ type='password',
196
+ label_visibility="collapsed"
197
+ )
198
+
199
+ # Input model to use
200
+ st.text_input(
201
+ 'Input global planning model to use',
202
+ key='model1_select',
203
  )
204
+
205
+ st.text_input(
206
+ 'Input local execution model to use',
207
+ key='model2_select',
208
  )
209
+
210
+ st.text_input(
211
+ 'Input SERP API KEY (optional).',
212
+ placeholder='Input SERP API KEY (optional)',
213
+ key='serp_api',
214
+ type='password',
215
+ label_visibility="collapsed"
216
+ )
217
+
218
+ st.text_input(
219
+ 'Input SEMANTIC SCHOLAR API KEY (optional).',
220
+ placeholder='Input SEMANTIC SCHOLAR API KEY (optional)',
221
+ key='semantic_scholar_url',
222
+ type='password',
223
+ label_visibility="collapsed"
224
+ )
225
+
226
+ user_api_key = st.session_state.get('api_key')
227
+ user_base_url = st.session_state.get('base_url')
228
 
229
+ if user_api_key:
230
+ os.environ['OPENAI_API_KEY'] = user_api_key
231
+
232
+ if user_base_url and user_base_url.strip() != "" and user_base_url != "none":
233
+ os.environ["OPENAI_API_BASE"] = user_base_url
234
+ else:
235
+ if "OPENAI_API_BASE" in os.environ:
236
+ del os.environ["OPENAI_API_BASE"]
237
+
238
+ os.environ["SERP_API_KEY"] = st.session_state.get('serp_api', 'none')
239
+ os.environ["SEMANTIC_SCHOLAR_API_KEY"] = st.session_state.get('semantic_scholar_url', 'none')
 
 
240
 
241
+ # Display prompt examples
242
+ st.markdown('# What can I ask?')
243
+ cols = st.columns(2)
244
+ with cols[0]:
245
+ st.button(
246
+ r'👑 Generate a donor with PCE = 10% 🧨',
247
+ on_click=lambda: run_prompt(pre_prompts[0]),
 
 
 
 
 
 
248
  )
249
+ st.button(
250
+ r'📚 The history and development of Y6',
251
+ on_click=lambda: run_prompt(pre_prompts[1]),
 
252
  )
253
+ with cols[1]:
254
+ st.button(
255
+ r"🎄Predict the LogP of PM6",
256
+ on_click=lambda: run_prompt(pre_prompts[2]),
257
+ )
258
+ st.button(
259
+ r'💎 Predict the PCE of Y6',
260
+ on_click=lambda: run_prompt(pre_prompts[3]),
 
 
 
 
 
 
 
261
  )
 
 
 
 
262
 
263
+ st.selectbox(
264
+ 'Select the file type',
265
+ ['None', 'CSV (.csv)', 'Figure (.jpg, .png, .jpeg)', 'PDF (.pdf)'],
266
+ key='file_type',
267
+ )
268
+
269
+ uploaded_file = None
270
+ if st.session_state.get('file_type') == 'Figure (.jpg, .png, .jpeg)':
271
+ uploaded_file = st.file_uploader("Choose a Figure", type=["jpg", "jpeg", "png"])
272
+ if st.session_state.get('file_type') == 'PDF (.pdf)':
273
+ uploaded_file = st.file_uploader("Choose a PDF file")
274
+ if st.session_state.get('file_type') == 'CSV (.csv)':
275
+ uploaded_file = st.file_uploader("Choose a csv file", type='csv')
276
+
277
+ st.selectbox(
278
+ r'📚 Choose the domain',
279
+ ['Organic solar cell', 'Drug discovery'],
280
+ key='domain',
281
+ )
282
+
283
+ # Display available tools - only initialize agent if not already done
284
+ if st.session_state.get('agent') is None:
285
+ if st.session_state.get('domain') == 'Drug discovery':
286
+ instantiate_agent(
287
+ model1='gpt-4o-2024-11-20',
288
+ model2='gpt-4o-2024-11-20',
289
+ tools='drug'
290
+ )
291
+ else:
292
+ instantiate_agent(
293
+ model1='gpt-4o-2024-11-20',
294
+ model2='gpt-4o-2024-11-20'
295
+ )
296
+
297
+ # Safely display tools
298
+ agent = st.session_state.get('agent')
299
+ if agent is not None and hasattr(agent, 'agent_executor2'):
300
+ try:
301
+ tools = agent.agent_executor2.tools
302
+ tool_list = pd.Series(
303
+ {f"✅ {t.name}": t.description for t in tools}
304
+ ).reset_index()
305
+ tool_list.columns = ['Tool', 'Description']
306
+ st.markdown(f"# {len(tool_list)} available tools")
307
+ st.dataframe(
308
+ tool_list,
309
+ use_container_width=True,
310
+ hide_index=True,
311
+ height=200
312
+ )
313
+ except Exception as e:
314
+ st.warning(f"Could not load tools: {str(e)}")
315
+ else:
316
+ st.info("Agent not initialized. Please check your API key.")
317
+
318
+ # Execute agent on user input
319
+ if prompt := st.chat_input("Say something and/or attach files"):
320
+ uploaded_file = None
321
+ if st.session_state.get('file_type') == 'Figure (.jpg, .png, .jpeg)':
322
+ # Re-get uploaded file for main area
323
+ pass
324
+ elif st.session_state.get('file_type') == 'PDF (.pdf)':
325
+ pass
326
+ elif st.session_state.get('file_type') == 'CSV (.csv)':
327
+ pass
328
+
329
+ # Check if file was uploaded in sidebar
330
+ # Note: uploaded_file needs to be handled differently in Streamlit
331
+ # For now, just run the prompt
332
+ run_prompt(prompt)
333
+
334
+ # Run the main function
335
+ if __name__ == "__main__":
336
+ main()