jinysun commited on
Commit
f800093
·
verified ·
1 Parent(s): abd520a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -219
app.py CHANGED
@@ -1,219 +1,216 @@
1
- import os
2
- import asyncio
3
- # Init with fake key
4
- if 'OPENAI_API_KEY' not in os.environ:
5
- os.environ['OPENAI_API_KEY'] = 'none'
6
- if os.name == 'nt':
7
- asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
8
-
9
- import openai
10
- import pandas as pd
11
- import streamlit as st
12
- from IPython.core.display import HTML
13
- from PIL import Image
14
- from agent import TeLLAgent, make_tools
15
- from streamlit_callback_handler import \
16
- StreamlitCallbackHandlerChem
17
- import base64
18
- import pandas as pd
19
- from dotenv import load_dotenv
20
- from langchain_openai import ChatOpenAI , OpenAI
21
- import base64
22
- from io import BytesIO
23
- from PIL import Image
24
- import tempfile
25
-
26
-
27
- def convert_to_base64(pil_image):
28
- buffered = BytesIO()
29
- pil_image.save(buffered, format="PNG")
30
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
31
- return img_str
32
-
33
- def oai_key_isvalid(api_key):
34
- """Check if a given OpenAI key is valid"""
35
- try:
36
- llm = ChatOpenAI(openai_api_key = api_key, base_url="https://www.dmxapi.com/v1/")
37
- out = llm.invoke("This is a test")
38
- return True
39
- except:
40
- return False
41
-
42
- load_dotenv()
43
- ss = st.session_state
44
- ss.prompt = None
45
-
46
- # Set width of sidebar
47
- st.markdown(
48
- """
49
- <style>
50
- [data-testid="stSidebar"][aria-expanded="true"]{
51
- min-width: 450px;
52
- max-width: 450px;
53
- }
54
- """,
55
- unsafe_allow_html=True,
56
- )
57
-
58
-
59
- def instantiate_agent(model,file_path = '...',
60
- image_path ='...'):
61
- ss.agent = TeLLAgent(
62
- model=model,
63
- tools_model=model,
64
- temp=0.1,
65
- openai_api_key=ss.get('api_key') , file_path = file_path,
66
- image_path =image_path
67
-
68
- )
69
- return ss.agent
70
-
71
- instantiate_agent('gpt-4o-2024-11-20')
72
- tools = ss.agent.agent_executor.tools
73
-
74
- tool_list = pd.Series(
75
- {f"✅ {t.name}":t.description for t in tools}
76
- ).reset_index()
77
- tool_list.columns = ['Tool', 'Description']
78
-
79
- def on_api_key_change():
80
- api_key = ss.get('api_key') or os.getenv('OPENAI_API_KEY')
81
-
82
- # Check if key is valid
83
- if not oai_key_isvalid(api_key):
84
- st.write("Please input a valid OpenAI API key.")
85
-
86
- def run_prompt(prompt, file_path = '...', image_path = '...'):
87
- agent = instantiate_agent(ss.get('model_select'),file_path = file_path, image_path =image_path)
88
- st.chat_message("user").write(prompt)
89
- with st.chat_message("assistant") :
90
- try:
91
-
92
- response = agent.run(prompt)
93
- if ss.get('file_type') == 'CSV (.csv)':
94
- try:
95
- fx = pd.DataFrame(list(response))
96
- st.markdown(":red[Prediction finished! ]")
97
- st.download_button( "⬇️Download the predicted files as .csv", fx.to_csv(), "predict results.csv", use_container_width=True)
98
- except:
99
- st.write(response)
100
- else:
101
- st.write(response)
102
- except openai.AuthenticationError:
103
- st.write("Please input a valid OpenAI API key")
104
- except openai.APIError:
105
- # Handle specific API errors here
106
- print("OpenAI API error, please try again!")
107
-
108
-
109
- pre_prompts = [
110
- 'Who are you?',
111
- ('The history and development of Y6'
112
-
113
- ),
114
- (
115
- 'Predict the LogP of Y6'
116
- ),
117
- 'Generate a donor material with PCE = 10'
118
- ]
119
-
120
- # sidebar
121
- with st.sidebar:
122
-
123
- st.header("🤖 :blue[TeLLAgent] ")
124
- # Input OpenAI api key
125
- st.text_input(
126
- 'Input your OpenAI API key.',
127
- placeholder = 'Input your OpenAI API key.',
128
- type='password',
129
- key='api_key',
130
- on_change=on_api_key_change,
131
- label_visibility="collapsed"
132
- )
133
-
134
- # Input model to use
135
- st.selectbox(
136
- 'Select model to use',
137
- ['gpt-4o-2024-11-20', 'deepseek-v3', 'gpt-4o-mini'],
138
- key='model_select',
139
- )
140
-
141
- # Display prompt examples
142
- st.markdown('# What can I ask?')
143
- cols = st.columns(2)
144
- with cols[0]:
145
- st.button(
146
- r'👑 Who are you ? 🧨 ',
147
- on_click=lambda: run_prompt(pre_prompts[0]),
148
- )
149
- st.button(
150
- r'📚 The history and development of Y6 ',
151
- on_click=lambda: run_prompt(pre_prompts[1]),
152
- )
153
- with cols[1]:
154
- st.button(
155
- r"🎄Predict the LogP of Y6 ",
156
- on_click=lambda: run_prompt(pre_prompts[2]),
157
- )
158
- st.button(
159
- r'💎 Generate a donor material with PCE = 10',
160
- on_click=lambda: run_prompt(pre_prompts[3]),
161
- )
162
-
163
- st.selectbox(
164
- 'Select the file type ',
165
- ['None', 'CSV (.csv)', 'Figure (.jpg, .png, .jpeg)', 'PDF (.pdf)'],
166
- key='file_type',
167
- )
168
- uploaded_file = None
169
- if ss.get('file_type') == 'Figure (.jpg, .png, .jpeg)':
170
- uploaded_file = st.file_uploader("Choose a Figure", type = ["jpg", "jpeg", "png"])
171
- if ss.get('file_type') == 'PDF (.pdf)':
172
- uploaded_file = st.file_uploader("Choose a PDF file")
173
- if ss.get('file_type') == 'CSV (.csv)':
174
- uploaded_file = st.file_uploader("Choose a csv file", type = 'csv')
175
-
176
- # Display available tools
177
- st.markdown(f"# {len(tool_list)} available tools")
178
- st.dataframe(
179
- tool_list,
180
- use_container_width=True,
181
- hide_index=True,
182
- height=200
183
- )
184
-
185
- # Execute agent on user input
186
- if prompt := st.chat_input("Say something and/or attach files"):
187
-
188
- if uploaded_file is not None:
189
- if ss.get('file_type') == 'CSV (.csv)':
190
- with tempfile.NamedTemporaryFile( dir = 'j:/', suffix ='.csv' ,delete=False) as f:
191
- f.write(uploaded_file.read())
192
- run_prompt(prompt + str(' ') + str(f.name), file_path = f.name)
193
- f.close()
194
-
195
- if ss.get('file_type') == 'Figure (.jpg, .png, .jpeg)':
196
-
197
- st.image(uploaded_file, width = 500)
198
- with tempfile.NamedTemporaryFile(dir = 'j:/',delete=False, suffix=".png") as temp:
199
-
200
- mg_str = base64.b64encode(uploaded_file.getvalue()).decode("utf-8")
201
- temp.write(base64.b64decode(mg_str))
202
-
203
- run_prompt(prompt+ str(' ') + str(temp.name), image_path = temp.name )
204
-
205
- if ss.get('file_type') == 'PDF (.pdf)':
206
- with tempfile.NamedTemporaryFile( dir = 'j:/', suffix ='.pdf' ,delete=False) as f:
207
- f.write(uploaded_file.read())
208
- run_prompt(prompt, file_path = f.name)
209
- f.close()
210
-
211
- # with open("input.png","wb") as af:
212
- # mg_str = base64.b64encode(files.getvalue()).decode("utf-8")
213
- # af.write(base64.b64decode(mg_str))
214
-
215
- # run_prompt(prompt.text+str(f.name), image_path =f.name )
216
- # except:
217
- # st.markdown("Please input correct files or query ")
218
- else:
219
- run_prompt(prompt)
 
1
+ import os
2
+ import asyncio
3
+ # Init with fake key
4
+ if 'OPENAI_API_KEY' not in os.environ:
5
+ os.environ['OPENAI_API_KEY'] = 'none'
6
+ if os.name == 'nt':
7
+ asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
8
+
9
+ import pandas as pd
10
+ import streamlit as st
11
+ from IPython.core.display import HTML
12
+ from PIL import Image
13
+ from agent import TeLLAgent, make_tools
14
+ from streamlit_callback_handler import \
15
+ StreamlitCallbackHandlerChem
16
+ import base64
17
+ import pandas as pd
18
+ from dotenv import load_dotenv
19
+ from langchain_openai import ChatOpenAI , OpenAI
20
+ import base64
21
+ from io import BytesIO
22
+ from PIL import Image
23
+ import tempfile
24
+
25
+
26
+ def convert_to_base64(pil_image):
27
+ buffered = BytesIO()
28
+ pil_image.save(buffered, format="PNG")
29
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
30
+ return img_str
31
+
32
+ def oai_key_isvalid(api_key):
33
+ """Check if a given OpenAI key is valid"""
34
+ try:
35
+ llm = ChatOpenAI(openai_api_key = api_key, base_url="https://www.dmxapi.com/v1/")
36
+ out = llm.invoke("This is a test")
37
+ return True
38
+ except:
39
+ return False
40
+
41
+ load_dotenv()
42
+ ss = st.session_state
43
+ ss.prompt = None
44
+
45
+ # Set width of sidebar
46
+ st.markdown(
47
+ """
48
+ <style>
49
+ [data-testid="stSidebar"][aria-expanded="true"]{
50
+ min-width: 450px;
51
+ max-width: 450px;
52
+ }
53
+ """,
54
+ unsafe_allow_html=True,
55
+ )
56
+
57
+
58
+ def instantiate_agent(model,file_path = '...',
59
+ image_path ='...'):
60
+ ss.agent = TeLLAgent(
61
+ model=model,
62
+ tools_model=model,
63
+ temp=0.1,
64
+ openai_api_key=ss.get('api_key') , file_path = file_path,
65
+ image_path =image_path
66
+
67
+ )
68
+ return ss.agent
69
+
70
+ instantiate_agent('gpt-4o-2024-11-20')
71
+ tools = ss.agent.agent_executor.tools
72
+
73
+ tool_list = pd.Series(
74
+ {f"✅ {t.name}":t.description for t in tools}
75
+ ).reset_index()
76
+ tool_list.columns = ['Tool', 'Description']
77
+
78
+ def on_api_key_change():
79
+ api_key = ss.get('api_key') or os.getenv('OPENAI_API_KEY')
80
+
81
+ # Check if key is valid
82
+ if not oai_key_isvalid(api_key):
83
+ st.write("Please input a valid OpenAI API key.")
84
+
85
+ def run_prompt(prompt, file_path = '...', image_path = '...'):
86
+ agent = instantiate_agent(ss.get('model_select'),file_path = file_path, image_path =image_path)
87
+ st.chat_message("user").write(prompt)
88
+ with st.chat_message("assistant") :
89
+ try:
90
+
91
+ response = agent.run(prompt)
92
+ if ss.get('file_type') == 'CSV (.csv)':
93
+ try:
94
+ fx = pd.DataFrame(list(response))
95
+ st.markdown(":red[Prediction finished! ]")
96
+ st.download_button( "⬇️Download the predicted files as .csv", fx.to_csv(), "predict results.csv", use_container_width=True)
97
+ except:
98
+ st.write(response)
99
+ else:
100
+ st.write(response)
101
+ except:
102
+ st.write("Please input a valid OpenAI API key")
103
+
104
+
105
+
106
+ pre_prompts = [
107
+ 'Who are you?',
108
+ ('The history and development of Y6'
109
+
110
+ ),
111
+ (
112
+ 'Predict the LogP of Y6'
113
+ ),
114
+ 'Generate a donor material with PCE = 10'
115
+ ]
116
+
117
+ # sidebar
118
+ with st.sidebar:
119
+
120
+ st.header("🤖 :blue[TeLLAgent] ")
121
+ # Input OpenAI api key
122
+ st.text_input(
123
+ 'Input your OpenAI API key.',
124
+ placeholder = 'Input your OpenAI API key.',
125
+ type='password',
126
+ key='api_key',
127
+ on_change=on_api_key_change,
128
+ label_visibility="collapsed"
129
+ )
130
+
131
+ # Input model to use
132
+ st.selectbox(
133
+ 'Select model to use',
134
+ ['gpt-4o-2024-11-20', 'deepseek-v3', 'gpt-4o-mini'],
135
+ key='model_select',
136
+ )
137
+
138
+ # Display prompt examples
139
+ st.markdown('# What can I ask?')
140
+ cols = st.columns(2)
141
+ with cols[0]:
142
+ st.button(
143
+ r'👑 Who are you ? 🧨 ',
144
+ on_click=lambda: run_prompt(pre_prompts[0]),
145
+ )
146
+ st.button(
147
+ r'📚 The history and development of Y6 ',
148
+ on_click=lambda: run_prompt(pre_prompts[1]),
149
+ )
150
+ with cols[1]:
151
+ st.button(
152
+ r"🎄Predict the LogP of Y6 ",
153
+ on_click=lambda: run_prompt(pre_prompts[2]),
154
+ )
155
+ st.button(
156
+ r'💎 Generate a donor material with PCE = 10',
157
+ on_click=lambda: run_prompt(pre_prompts[3]),
158
+ )
159
+
160
+ st.selectbox(
161
+ 'Select the file type ',
162
+ ['None', 'CSV (.csv)', 'Figure (.jpg, .png, .jpeg)', 'PDF (.pdf)'],
163
+ key='file_type',
164
+ )
165
+ uploaded_file = None
166
+ if ss.get('file_type') == 'Figure (.jpg, .png, .jpeg)':
167
+ uploaded_file = st.file_uploader("Choose a Figure", type = ["jpg", "jpeg", "png"])
168
+ if ss.get('file_type') == 'PDF (.pdf)':
169
+ uploaded_file = st.file_uploader("Choose a PDF file")
170
+ if ss.get('file_type') == 'CSV (.csv)':
171
+ uploaded_file = st.file_uploader("Choose a csv file", type = 'csv')
172
+
173
+ # Display available tools
174
+ st.markdown(f"# {len(tool_list)} available tools")
175
+ st.dataframe(
176
+ tool_list,
177
+ use_container_width=True,
178
+ hide_index=True,
179
+ height=200
180
+ )
181
+
182
+ # Execute agent on user input
183
+ if prompt := st.chat_input("Say something and/or attach files"):
184
+
185
+ if uploaded_file is not None:
186
+ if ss.get('file_type') == 'CSV (.csv)':
187
+ with tempfile.NamedTemporaryFile( dir = 'j:/', suffix ='.csv' ,delete=False) as f:
188
+ f.write(uploaded_file.read())
189
+ run_prompt(prompt + str(' ') + str(f.name), file_path = f.name)
190
+ f.close()
191
+
192
+ if ss.get('file_type') == 'Figure (.jpg, .png, .jpeg)':
193
+
194
+ st.image(uploaded_file, width = 500)
195
+ with tempfile.NamedTemporaryFile(dir = 'j:/',delete=False, suffix=".png") as temp:
196
+
197
+ mg_str = base64.b64encode(uploaded_file.getvalue()).decode("utf-8")
198
+ temp.write(base64.b64decode(mg_str))
199
+
200
+ run_prompt(prompt+ str(' ') + str(temp.name), image_path = temp.name )
201
+
202
+ if ss.get('file_type') == 'PDF (.pdf)':
203
+ with tempfile.NamedTemporaryFile( dir = 'j:/', suffix ='.pdf' ,delete=False) as f:
204
+ f.write(uploaded_file.read())
205
+ run_prompt(prompt, file_path = f.name)
206
+ f.close()
207
+
208
+ # with open("input.png","wb") as af:
209
+ # mg_str = base64.b64encode(files.getvalue()).decode("utf-8")
210
+ # af.write(base64.b64decode(mg_str))
211
+
212
+ # run_prompt(prompt.text+str(f.name), image_path =f.name )
213
+ # except:
214
+ # st.markdown("Please input correct files or query ")
215
+ else:
216
+ run_prompt(prompt)