jinysun commited on
Commit
aabfc97
·
verified ·
1 Parent(s): de41971

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +218 -239
app.py CHANGED
@@ -1,20 +1,33 @@
 
1
  import os
2
  import asyncio
 
 
 
 
 
 
 
 
 
 
3
  import openai
4
  import pandas as pd
5
  import streamlit as st
 
6
  from PIL import Image
7
- from agent import TeLLAgent, make_tools
8
- from streamlit_callback_handler import StreamlitCallbackHandlerChem
 
9
  import base64
 
10
  from dotenv import load_dotenv
11
- from langchain_openai import ChatOpenAI, OpenAI
 
12
  from io import BytesIO
 
13
  import tempfile
14
-
15
- load_dotenv()
16
-
17
- # --- Helper Functions ---
18
 
19
  def convert_to_base64(pil_image):
20
  buffered = BytesIO()
@@ -25,259 +38,225 @@ def convert_to_base64(pil_image):
25
  def oai_key_isvalid(api_key):
26
  """Check if a given OpenAI key is valid"""
27
  try:
28
- # Simple check without invoking expensive calls if possible,
29
- # or keep the invoke if you want strict validation.
30
- if not api_key or api_key == 'none': return False
31
-
32
  if os.getenv("OPENAI_API_BASE"):
33
- llm = ChatOpenAI(openai_api_key=api_key, base_url=os.getenv("OPENAI_API_BASE"))
 
34
  else:
35
- llm = ChatOpenAI(openai_api_key=api_key)
36
- llm.invoke("Hi")
37
  return True
38
  except:
39
  return False
40
-
41
- def initialize_session_state():
42
- """Initialize all session state variables"""
43
- # UI States
44
- if 'prompt' not in st.session_state: st.session_state['prompt'] = None
45
- if 'file_type' not in st.session_state: st.session_state['file_type'] = 'None'
46
- if 'domain' not in st.session_state: st.session_state['domain'] = 'Organic solar cell'
47
 
48
- # Config States (Defaults)
49
- if 'model1_select' not in st.session_state: st.session_state['model1_select'] = 'deepseek-reasoner'
50
- if 'model2_select' not in st.session_state: st.session_state['model2_select'] = 'deepseek-reasoner'
51
- if 'api_key' not in st.session_state: st.session_state['api_key'] = os.getenv('OPENAI_API_KEY', '')
52
-
53
- # Agent & Logic States
54
- if 'agent' not in st.session_state: st.session_state['agent'] = None
55
- if 'last_agent_config' not in st.session_state: st.session_state['last_agent_config'] = {}
56
-
57
- # File Paths
58
- if 'current_file_path' not in st.session_state: st.session_state['current_file_path'] = '...'
59
- if 'current_image_path' not in st.session_state: st.session_state['current_image_path'] = '...'
60
-
61
- def update_agent_if_needed(current_config):
62
- """
63
- Compare current_config with the last used config.
64
- If changed, instantiate the agent automatically.
65
  """
66
- last_config = st.session_state.get('last_agent_config', {})
67
-
68
- # Check if config changed or agent is missing
69
- if st.session_state['agent'] is None or current_config != last_config:
70
-
71
- # Validations
72
- if not current_config['api_key'] or current_config['api_key'] == 'none':
73
- # Don't try to init if no key, just return
74
- return
75
 
76
- model1 = current_config['model1'] or "deepseek-reasoner"
77
- model2 = current_config['model2'] or "deepseek-reasoner"
78
- domain = current_config['domain']
79
- file_path = current_config['file_path']
80
- image_path = current_config['image_path']
81
-
82
- tools_setting = 'drug' if domain == 'Drug discovery' else None
83
 
84
- try:
85
- # Create the agent
86
- new_agent = TeLLAgent(
87
- tools=tools_setting,
88
- model1=model1,
89
- model2=model2,
90
- tools_model='gpt-4o-2024-11-20',
91
- temp=0.1,
92
- openai_api_key=current_config['api_key'],
93
- file_path=file_path,
94
- image_path=image_path
95
- )
96
-
97
- st.session_state['agent'] = new_agent
98
- st.session_state['last_agent_config'] = current_config # Update history
99
-
100
- # Optional: Toast notification to show user it updated
101
- # st.toast(f"Agent updated! (Model: {model1})", icon="🤖")
102
-
103
- except Exception as e:
104
- st.error(f"Auto-initialization failed: {str(e)}")
105
- st.session_state['agent'] = None
106
 
107
- def run_prompt(prompt):
108
- """Execute prompt using the already initialized agent"""
109
- agent = st.session_state.get('agent')
110
-
111
- if agent is None:
112
- st.error("Agent is not ready. Please check your API Key and Model settings.")
113
- return
114
-
 
 
 
 
 
 
115
  st.chat_message("user").write(prompt)
116
-
117
- with st.chat_message("assistant"):
118
  try:
119
- response = agent.run(prompt)
120
- if st.session_state.get('file_type') == 'CSV (.csv)':
121
- try:
122
- fx = pd.DataFrame(list(response))
123
- st.markdown(":red[Prediction finished! ]")
124
- st.download_button(
125
- "⬇️Download the predicted files as .csv",
126
- fx.to_csv(),
127
- "predict results.csv",
128
- use_container_width=True
129
- )
130
- except:
131
- st.write(response)
132
- else:
133
  st.write(response)
134
  except openai.AuthenticationError:
135
  st.write("Please input a valid OpenAI API key")
136
- except openai.APIError as e:
137
- st.error(f"OpenAI API error: {str(e)}")
138
- except Exception as e:
139
- st.error(f"Error running prompt: {str(e)}")
140
 
141
- def on_api_key_change():
142
- """Callback when API key changes"""
143
- # Logic is now handled by auto-update, but we can do validation here
144
- pass
 
 
 
 
 
 
145
 
146
- def main():
147
- initialize_session_state()
148
-
149
- # Environment Defaults
150
- if 'OPENAI_API_KEY' not in os.environ:
151
- os.environ['OPENAI_API_KEY'] = 'none'
152
- os.environ["SERP_API_KEY"] = 'none'
153
- os.environ["SEMANTIC_SCHOLAR_API_KEY"] = 'none'
154
-
155
- if os.name == 'nt':
156
- asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
157
-
158
- # Sidebar Setup
159
- st.markdown(
160
- """
161
- <style>
162
- [data-testid="stSidebar"][aria-expanded="true"]{
163
- min-width: 500px; max-width: 500px;
164
- }
165
- </style>
166
- """,
167
- unsafe_allow_html=True,
168
  )
169
-
170
- # --- Sidebar Inputs ---
171
- with st.sidebar:
172
- st.header("🤖 :blue[TeLLAgent] ")
173
-
174
- # 1. Credentials
175
- api_key_input = st.text_input('Input your OpenAI API key.', type='password', key='api_key', label_visibility="collapsed")
176
- base_url_input = st.text_input('Input base url (optional).', key='base_url', type='password', label_visibility="collapsed")
177
-
178
- # 2. Model Configuration
179
- # Defaults are set in initialize_session_state, user can change them here
180
- model1_input = st.text_input('Input global planning model to use', key='model1_select')
181
- model2_input = st.text_input('Input local execution model to use', key='model2_select')
182
-
183
- # 3. External Tools Keys
184
- serp_input = st.text_input('Input SERP API KEY (optional).', key='serp_api', type='password', label_visibility="collapsed")
185
- scholar_input = st.text_input('Input SEMANTIC SCHOLAR API KEY (optional).', key='semantic_scholar_url', type='password', label_visibility="collapsed")
186
-
187
- # Update OS Environ
188
- if api_key_input: os.environ['OPENAI_API_KEY'] = api_key_input
189
- if base_url_input and base_url_input != "none": os.environ["OPENAI_API_BASE"] = base_url_input
190
- elif "OPENAI_API_BASE" in os.environ: del os.environ["OPENAI_API_BASE"]
191
- os.environ["SERP_API_KEY"] = serp_input if serp_input else 'none'
192
- os.environ["SEMANTIC_SCHOLAR_API_KEY"] = scholar_input if scholar_input else 'none'
193
-
194
- # 4. Domain & Files
195
- st.selectbox(r'📚 Choose the domain', ['Organic solar cell', 'Drug discovery'], key='domain')
196
- st.selectbox('Select the file type', ['None', 'CSV (.csv)', 'Figure (.jpg, .png, .jpeg)', 'PDF (.pdf)'], key='file_type')
197
-
198
- # --- File Processing (Immediate) ---
199
- # We process files HERE so the paths are ready for the Agent Auto-Init immediately after.
200
- uploaded_file = None
201
- current_file_path = '...'
202
- current_image_path = '...'
203
 
204
- if st.session_state['file_type'] == 'Figure (.jpg, .png, .jpeg)':
205
- uploaded_file = st.file_uploader("Choose a Figure", type=["jpg", "jpeg", "png"])
206
- if uploaded_file:
207
- st.image(uploaded_file, width=200)
208
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp:
209
- img_str = base64.b64encode(uploaded_file.getvalue()).decode("utf-8")
210
- temp.write(base64.b64decode(img_str))
211
- current_image_path = temp.name
212
- st.session_state['current_image_path'] = temp.name # Persist
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
 
214
- elif st.session_state['file_type'] == 'PDF (.pdf)':
215
- uploaded_file = st.file_uploader("Choose a PDF file")
216
- if uploaded_file:
217
- with tempfile.NamedTemporaryFile(suffix='.pdf', delete=False) as f:
218
- f.write(uploaded_file.read())
219
- current_file_path = f.name
220
- st.session_state['current_file_path'] = f.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
- elif st.session_state['file_type'] == 'CSV (.csv)':
223
- uploaded_file = st.file_uploader("Choose a csv file", type='csv')
224
- if uploaded_file:
225
- with tempfile.NamedTemporaryFile(suffix='.csv', delete=False) as f:
226
- f.write(uploaded_file.read())
227
- current_file_path = f.name
228
- st.session_state['current_file_path'] = f.name
229
-
230
- # --- AUTOMATIC AGENT INITIALIZATION ---
231
- # Collect current config into a dict
232
- current_config = {
233
- 'api_key': api_key_input,
234
- 'base_url': base_url_input,
235
- 'model1': model1_input,
236
- 'model2': model2_input,
237
- 'domain': st.session_state['domain'],
238
- 'file_path': current_file_path, # Uses the temp path created just now
239
- 'image_path': current_image_path # Uses the temp path created just now
240
- }
241
-
242
- # Trigger the check & init logic
243
- update_agent_if_needed(current_config)
244
 
245
- # Show Tools status
246
- agent = st.session_state.get('agent')
247
- if agent is not None and hasattr(agent, 'agent_executor2'):
248
- try:
249
- tools = agent.agent_executor2.tools
250
- tool_list = pd.Series({f"✅ {t.name}": t.description for t in tools}).reset_index()
251
- tool_list.columns = ['Tool', 'Description']
252
- st.dataframe(tool_list, use_container_width=True, hide_index=True, height=150)
253
- except: pass
254
- else:
255
- if not api_key_input:
256
- st.warning("🔑 Waiting for API Key...")
257
-
258
- # Pre-defined Prompts
259
-
260
- st.markdown('# What can I ask?')
261
-
262
- pre_prompts = [
263
- 'Generate a donor with PCE = 10% ',
264
- 'The history and development of Y6',
265
- 'Predict the LogP of PM6',
266
- 'Predict the PCE of Y6'
267
- ]
268
-
269
- cols = st.columns(2)
270
- with cols[0]:
271
- st.button(r'👑 Generate a donor...', on_click=lambda: run_prompt(pre_prompts[0]))
272
- st.button(r'📚 History of Y6', on_click=lambda: run_prompt(pre_prompts[1]))
273
- with cols[1]:
274
- st.button(r"🎄Predict LogP of PM6", on_click=lambda: run_prompt(pre_prompts[2]))
275
- st.button(r'💎 Predict PCE of Y6', on_click=lambda: run_prompt(pre_prompts[3]))
276
 
277
- # --- Main Chat Area ---
278
- if prompt := st.chat_input("Say something..."):
279
- # Append context about file if needed, though Agent handles files via init args mostly.
280
- run_prompt(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
- if __name__ == "__main__":
283
- main()
 
1
+
2
  import os
3
  import asyncio
4
+ # Init with fake key
5
+ if 'OPENAI_API_KEY' not in os.environ:
6
+ os.environ['OPENAI_API_KEY'] = 'none'
7
+ os.environ["OPENAI_API_BASE"] = 'none'
8
+
9
+ os.environ["SERP_API_KEY"] = 'none'
10
+ os.environ["SEMANTIC_SCHOLAR_API_KEY"] = 'none'
11
+ if os.name == 'nt':
12
+ asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
13
+
14
  import openai
15
  import pandas as pd
16
  import streamlit as st
17
+
18
  from PIL import Image
19
+ from agent import TeLLAgent, make_tools
20
+ from streamlit_callback_handler import \
21
+ StreamlitCallbackHandlerChem
22
  import base64
23
+ import pandas as pd
24
  from dotenv import load_dotenv
25
+ from langchain_openai import ChatOpenAI , OpenAI
26
+ import base64
27
  from io import BytesIO
28
+ from PIL import Image
29
  import tempfile
30
+
 
 
 
31
 
32
  def convert_to_base64(pil_image):
33
  buffered = BytesIO()
 
38
  def oai_key_isvalid(api_key):
39
  """Check if a given OpenAI key is valid"""
40
  try:
 
 
 
 
41
  if os.getenv("OPENAI_API_BASE"):
42
+ llm = ChatOpenAI(openai_api_key = api_key, base_url=os.getenv("OPENAI_API_BASE"))
43
+ out = llm.invoke("This is a test")
44
  else:
45
+ llm = ChatOpenAI(openai_api_key = api_key)
46
+ out = llm.invoke("This is a test")
47
  return True
48
  except:
49
  return False
 
 
 
 
 
 
 
50
 
51
+ load_dotenv()
52
+ ss = st.session_state
53
+ ss.prompt = None
54
+
55
+ # Set width of sidebar
56
+ st.markdown(
 
 
 
 
 
 
 
 
 
 
 
57
  """
58
+ <style>
59
+ [data-testid="stSidebar"][aria-expanded="true"]{
60
+ min-width: 500px;
61
+ max-width: 500px;
62
+ }
63
+ """,
64
+ unsafe_allow_html=True,
65
+ )
 
66
 
 
 
 
 
 
 
 
67
 
68
+ def instantiate_agent(model1, model2, file_path = '...', image_path ='...', tools=None):
69
+ ss.agent = TeLLAgent( tools=tools,
70
+ model1 = model1,
71
+ model2 = model2,
72
+ tools_model='gpt-4o-2024-11-20',
73
+ temp=0.1,
74
+ openai_api_key=ss.get('api_key') , file_path = file_path,
75
+ image_path =image_path)
76
+ return ss.agent
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
+
79
+
80
+ def on_api_key_change():
81
+ api_key = ss.get('api_key') or os.getenv('OPENAI_API_KEY')
82
+
83
+ # Check if key is valid
84
+ if not oai_key_isvalid(api_key):
85
+ st.write("Please input a valid OpenAI API key.")
86
+
87
+ def run_prompt(prompt, file_path = '...', image_path = '...'):
88
+ if ss.get('domain') =='Drug discovery':
89
+ agent = instantiate_agent(model1 = ss.get('model1_select'), model2 = ss.get('model2_select'), file_path = file_path, image_path =image_path, tools = 'drug')
90
+ else:
91
+ agent = instantiate_agent(model1 = ss.get('model1_select'), model2 = ss.get('model2_select'), file_path = file_path, image_path =image_path)
92
  st.chat_message("user").write(prompt)
93
+ with st.chat_message("assistant") :
 
94
  try:
95
+
96
+ response = agent.run(prompt)
97
+ if ss.get('file_type') == 'CSV (.csv)':
98
+ try:
99
+ fx = pd.DataFrame(list(response))
100
+ st.markdown(":red[Prediction finished! ]")
101
+ st.download_button( "⬇️Download the predicted files as .csv", fx.to_csv(), "predict results.csv", use_container_width=True)
102
+ except:
103
+ st.write(response)
104
+ else:
 
 
 
 
105
  st.write(response)
106
  except openai.AuthenticationError:
107
  st.write("Please input a valid OpenAI API key")
108
+ except openai.APIError:
109
+ # Handle specific API errors here
110
+ print("OpenAI API error, please try again!")
111
+
112
 
113
+ pre_prompts = [
114
+ 'Generate a donor with PCE = 10% ',
115
+ ('The history and development of Y6'
116
+
117
+ ),
118
+ (
119
+ 'Predict the LogP of PM6'
120
+ ),
121
+ 'Predict the PCE of Y6'
122
+ ]
123
 
124
+ # sidebar
125
+ with st.sidebar:
126
+
127
+ st.header("🤖 :blue[TeLLAgent] ")
128
+ # Input OpenAI api key
129
+ st.text_input(
130
+ 'Input your OpenAI API key.',
131
+ placeholder = 'Input your OpenAI API key.',
132
+ type='password',
133
+ key='api_key',
134
+ on_change=on_api_key_change,
135
+ label_visibility="collapsed"
 
 
 
 
 
 
 
 
 
 
136
  )
137
+ st.text_input(
138
+ 'Input base url (optional).',
139
+ placeholder = 'Input base url (optional)',
140
+ key='base_url',type='password',
141
+ label_visibility="collapsed"
142
+ )
143
+ # Input model to use
144
+ st.text_input(
145
+ 'Input global planning model to use',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
+ key='model1_select',
148
+ )
149
+ st.text_input(
150
+ 'Input local execution model to use',
151
+
152
+ key='model2_select',
153
+ )
154
+ st.text_input(
155
+ 'Input SERP API KEY (optional).',
156
+ placeholder = 'Input SERP API KEY (optional)',
157
+ key='serp_api',type='password',
158
+ label_visibility="collapsed"
159
+ )
160
+ st.text_input(
161
+ 'Input SEMANTIC SCHOLAR API KEY (optional).',
162
+ placeholder = 'Input SEMANTIC SCHOLAR API KEY (optional)',
163
+ key='semantic_scholar_url',type='password',
164
+ label_visibility="collapsed"
165
+ )
166
+ os.environ['OPENAI_API_KEY'] = ss.get('api_key')
167
+ os.environ["OPENAI_API_BASE"] = ss.get('base_url')
168
+
169
+ os.environ["SERP_API_KEY"] = ss.get('serp_api')
170
+ os.environ["SEMANTIC_SCHOLAR_API_KEY"] = ss.get('semantic_scholar_url')
171
+
172
+ # Display prompt examples
173
+ st.markdown('# What can I ask?')
174
+ cols = st.columns(2)
175
+ with cols[0]:
176
+ st.button(
177
+ r'👑 Generate a donor with PCE = 10% 🧨 ',
178
+ on_click=lambda: run_prompt(pre_prompts[0]),
179
+ )
180
+ st.button(
181
+ r'📚 The history and development of Y6 ',
182
+ on_click=lambda: run_prompt(pre_prompts[1]),
183
+ )
184
+ with cols[1]:
185
+ st.button(
186
+ r"🎄Predict the LogP of PM6 ",
187
+ on_click=lambda: run_prompt(pre_prompts[2]),
188
+ )
189
+ st.button(
190
+ r'💎 Predict the PCE of Y6',
191
+ on_click=lambda: run_prompt(pre_prompts[3]),
192
+ )
193
 
194
+ st.selectbox(
195
+ 'Select the file type ',
196
+ ['None', 'CSV (.csv)', 'Figure (.jpg, .png, .jpeg)', 'PDF (.pdf)'],
197
+ key='file_type',
198
+ )
199
+ uploaded_file = None
200
+ if ss.get('file_type') == 'Figure (.jpg, .png, .jpeg)':
201
+ uploaded_file = st.file_uploader("Choose a Figure", type = ["jpg", "jpeg", "png"])
202
+ if ss.get('file_type') == 'PDF (.pdf)':
203
+ uploaded_file = st.file_uploader("Choose a PDF file")
204
+ if ss.get('file_type') == 'CSV (.csv)':
205
+ uploaded_file = st.file_uploader("Choose a csv file", type = 'csv')
206
+ st.selectbox(
207
+ r'📚 Choose the domain ',
208
+ ['Organic solar cell', 'Drug discovery'], key='domain',
209
+ )
210
+ # Display available tools
211
+ if ss.get('domain') == 'Drug discovery':
212
+ instantiate_agent(model1 = 'gpt-4o-2024-11-20', model2 = 'gpt-4o-2024-11-20' ,tools = 'drug')
213
+ else:
214
+ instantiate_agent(model1 = 'gpt-4o-2024-11-20', model2 = 'gpt-4o-2024-11-20' )
215
+ tools = ss.agent.agent_executor2.tools
216
 
217
+ tool_list = pd.Series( {f"✅ {t.name}": t.description for t in tools}).reset_index()
218
+ tool_list.columns = ['Tool', 'Description']
219
+ st.markdown(f"# {len(tool_list)} available tools")
220
+ st.dataframe(
221
+ tool_list,
222
+ use_container_width=True,
223
+ hide_index=True,
224
+ height=200
225
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
+ # Execute agent on user input
228
+ if prompt := st.chat_input("Say something and/or attach files"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
+ if uploaded_file is not None:
231
+ if ss.get('file_type') == 'CSV (.csv)':
232
+ with tempfile.NamedTemporaryFile( suffix ='.csv' ,delete=False) as f:
233
+ f.write(uploaded_file.read())
234
+ run_prompt(prompt + str(' ') + str(f.name), file_path = f.name)
235
+ f.close()
236
+
237
+ if ss.get('file_type') == 'Figure (.jpg, .png, .jpeg)':
238
+
239
+ st.image(uploaded_file, width = 500)
240
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp:
241
+
242
+ mg_str = base64.b64encode(uploaded_file.getvalue()).decode("utf-8")
243
+ temp.write(base64.b64decode(mg_str))
244
+
245
+ run_prompt(prompt+ str(' ') + str(temp.name), image_path = temp.name )
246
+
247
+ if ss.get('file_type') == 'PDF (.pdf)':
248
+ with tempfile.NamedTemporaryFile( suffix ='.pdf' ,delete=False) as f:
249
+ f.write(uploaded_file.read())
250
+ run_prompt(prompt, file_path = f.name)
251
+ f.close()
252
+
253
+ # with open("input.png","wb") as af:
254
+ # mg_str = base64.b64encode(files.getvalue()).decode("utf-8")
255
+ # af.write(base64.b64decode(mg_str))
256
+
257
+ # run_prompt(prompt.text+str(f.name), image_path =f.name )
258
+ # except:
259
+ # st.markdown("Please input correct files or query ")
260
+ else:
261
+ run_prompt(prompt)
262