hanoch@raized.ai commited on
Commit
b505cc3
·
1 Parent(s): f4483df

working version

Browse files
Files changed (5) hide show
  1. .streamlit/config.toml +3 -0
  2. app.py +177 -163
  3. googleai.py +8 -8
  4. openai_utils.py +5 -1
  5. semsearch.pyproj +1 -0
.streamlit/config.toml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [global]
2
+ exclude = ["env/Lib/site-packages/torch"]
3
+ disableWatchdog = true
app.py CHANGED
@@ -10,7 +10,7 @@ logger.setLevel(logging.DEBUG)
10
 
11
  import streamlit as st
12
 
13
- from googleai import send_message as google_send_message, init_googleai
14
 
15
  from langchain.chains import RetrievalQA
16
  from langchain_community.embeddings import OpenAIEmbeddings
@@ -64,6 +64,7 @@ carddict = {
64
 
65
  @st.cache_resource
66
  def init_models():
 
67
  retriever = SentenceTransformer("msmarco-distilbert-base-v4")
68
  #model_name = "sentence-transformers/all-MiniLM-L6-v2"
69
  model_name = "sentence-transformers/msmarco-distilbert-base-v4"
@@ -76,6 +77,7 @@ def init_models():
76
 
77
  @st.cache_resource
78
  def init_openai():
 
79
  st.session_state.openai_client = oai.get_client()
80
  assistants = st.session_state.openai_client.beta.assistants.list(
81
  order="desc",
@@ -147,6 +149,25 @@ def card(company_id, name, description, score, data_type, region, country, metad
147
  #print(f" markdown for {company_id}\n{markdown}")
148
  return markdown
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  def run_query(query, report_type, top_k , regions, countries, is_debug, index_namespace, openai_model, default_prompt):
151
 
152
  #Summarize the results
@@ -156,176 +177,161 @@ def run_query(query, report_type, top_k , regions, countries, is_debug, index_na
156
  # For every company find its uniqueness over the other companies. Use only information from the descriptions.
157
  # """
158
  content_container = st.container() #, col_sidepanel = st.columns([4, 1], gap="small")
159
- if report_type == "gemini":
160
- try:
161
- logger.debug(f"User: {query}")
162
- response = google_send_message(query)
163
- response = response['output']
164
- logger.debug(f"Agent: {response }")
165
- with content_container:
166
- with st.chat_message(name = 'User'):
167
- st.write(query)
168
- with st.chat_message(name = 'Agent', avatar = assistant_avatar):
169
- st.write(response)
170
- except Exception as e:
171
- logger.exception(f"Error processing user message", exc_info=e)
172
-
173
- else:
174
- if report_type=="guided":
175
- prompt_txt = utils.query_finetune_prompt + """
176
- User query: {query}
177
- """
178
- prompt_template = PromptTemplate(template=prompt_txt, input_variables=["query"])
179
- prompt = prompt_template.format(query = query)
180
- m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0, max_tokens=20, log_message = False)
181
 
182
- print(f"Keywords: {m_text}")
183
- results = utils.search_index(m_text, top_k, regions, countries, retriever, index_namespace)
184
 
185
- descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
186
- ntokens = len(descriptions.split(" "))
187
-
188
- print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")
189
-
190
- prompt_txt = utils.summarization_prompt + """
191
- User query: {query}
192
- Company descriptions: {descriptions}
193
- """
194
- prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
195
- prompt = prompt_template.format(descriptions = descriptions, query = query)
196
- print(f"==============================\nPrompt:\n{prompt}\n==============================\n")
197
-
198
- m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
199
- m_text
200
- elif report_type=="company_list": # or st.session_state.new_conversation:
201
- results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
202
- descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
203
- elif report_type=="assistant":
204
- #results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
205
- #descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
206
- messages = oai.call_assistant(query, engine=openai_model)
207
- st.session_state.messages = messages
208
- results = st.session_state.db_search_results
209
- if not messages is None:
210
- with content_container:
211
- for message in list(messages)[::-1]:
212
- if hasattr(message, 'role'):
213
- # print(f"\n-----\nMessage: {message}\n")
214
- # with st.chat_message(name = message.role):
215
- # st.write(message.content[0].text.value)
216
- if message.role == "assistant":
217
- with st.chat_message(name = message.role, avatar = assistant_avatar):
218
- st.write(message.content[0].text.value)
219
- else:
220
- with st.chat_message(name = message.role):
221
- st.write(message.content[0].text.value)
222
- # st.session_state.messages.append({"role": "user", "content": query})
223
- # st.session_state.messages.append({"role": "system", "content": m_text})
224
 
225
- else:
226
- st.session_state.new_conversation = False
227
-
228
- results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
229
- descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
230
- ntokens = len(descriptions.split(" "))
231
-
232
- print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")
233
- prompt = utils.clustering_prompt if report_type=="clustered" else utils.default_prompt
234
- prompt_txt = prompt + """
235
- User query: {query}
236
- Company descriptions: {descriptions}
237
- """
238
- prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
239
- prompt = prompt_template.format(descriptions = descriptions, query = query)
240
- print(f"==============================\nPrompt:\n{prompt[:1000]}\n==============================\n")
241
-
242
- m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
243
- m_text
244
- st.session_state.messages.append({"role": "user", "content": query})
245
- i = m_text.find("-----")
246
- i = 0 if i<0 else i
247
- st.session_state.messages.append({"role": "system", "content": m_text[:i]})
248
-
249
-
250
-
251
- #render_history()
252
- # for message in st.session_state.messages:
253
- # with st.chat_message(message["role"]):
254
- # st.markdown(message["content"])
255
- # print(f"History: \n {st.session_state.messages}")
256
 
257
- sorted_results = sorted(results, key=lambda x: x['score'], reverse=True)
258
-
259
- names = []
260
- # list_html = """
261
- # <h2>Companies list</h2>
262
- # <div class="container-fluid">
263
- # <div class="row align-items-start" style="padding-bottom:10px;">
264
- # <div class="col-md-8 col-sm-8">
265
- # <span>Company</span>
266
- # </div>
267
- # <div class="col-md-1 col-sm-1">
268
- # <span>Country</span>
269
- # </div>
270
- # <div class="col-md-1 col-sm-1">
271
- # <span>Customer Problem</span>
272
- # </div>
273
- # <div class="col-md-1 col-sm-1">
274
- # <span>Business Model</span>
275
- # </div>
276
- # <div class="col-md-1 col-sm-1">
277
- # Actions
278
- # </div>
279
- # </div>
280
- # """
281
- list_html = "<div class='container-fluid'>"
282
-
283
- locations = set()
284
- for r in sorted_results:
285
- company_name = r["name"]
286
- if company_name in names:
287
- continue
288
- else:
289
- names.append(company_name)
290
- description = r["description"] #.replace(company_name, f"<mark>{company_name}</mark>")
291
- if description is None or len(description.strip())<10:
292
- continue
293
 
294
- score = round(r["score"], 4)
295
- data_type = r["metadata"]["type"] if "type" in r["metadata"] else ""
296
- region = r["metadata"]["region"]
297
- country = r["metadata"]["country"]
298
- company_id = r["metadata"]["company_id"]
299
 
300
- locations.add(country)
301
- list_html = list_html + card(company_id, company_name, description, score, data_type, region, country, r['data'], is_debug)
302
 
303
- list_html = list_html + '</div>'
304
 
305
- pins = country_geo[country_geo['name'].isin(locations)].loc[:, ['latitude', 'longitude']]
306
 
307
- if len(pins)>0:
308
- with st.expander("Map view"):
309
- st.map(pins)
310
- #st.markdown(list_html, unsafe_allow_html=True)
311
 
312
- df = pd.DataFrame.from_dict(carddict, orient="columns")
313
 
314
- if len(df)>0:
315
- df.index += 1
316
- with content_container:
317
- st.dataframe(df,
318
- hide_index=False,
319
- column_config ={
320
- "name": st.column_config.TextColumn("Name"),
321
- "company_id": st.column_config.LinkColumn("Link"),
322
- "description": st.column_config.TextColumn("Description"),
323
- "country": st.column_config.TextColumn("Country", width="small"),
324
- "customer_problem": st.column_config.TextColumn("Customer problem"),
325
- "target_customer": st.column_config.TextColumn(label="Target customer", width="small"),
326
- "business_model": st.column_config.TextColumn(label="Business model")
327
- },
328
- use_container_width=True)
329
  st.session_state.last_user_query = query
330
 
331
 
@@ -449,6 +455,8 @@ if utils.check_password():
449
 
450
  tab_advanced = st.sidebar.expander("Settings")
451
  with tab_advanced:
 
 
452
  #prompt_title = st.selectbox("Report Type", index = 0, options = utils.get_prompts(), on_change=on_prompt_selected, key="advanced_prompts_select", )
453
  #prompt_title_editable = st.text_input("Title", key="prompt_title_editable")
454
  report_type = st.selectbox(label="Response Type", options=["gemini", "assistant", "standard", "guided", "company_list", "clustered"], index=0)
@@ -464,10 +472,9 @@ if utils.check_password():
464
  index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
465
  liked_companies = st.text_input(label="liked companies", key='liked_companies')
466
  disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')
467
- default_prompt = st.text_area("Default Prompt", value = utils.default_prompt, height=400, key="advanced_default_prompt_content")
468
  clustering_prompt = st.text_area("Clustering Prompt", value = utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")
469
 
470
- if not "assistant_thread" in st.session_state:
471
  st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()
472
 
473
 
@@ -488,7 +495,14 @@ if utils.check_password():
488
  st.session_state.index_namespace = index_namespace
489
  st.session_state.region = region_selectbox
490
  st.session_state.country = countries_selectbox
491
- run_query(query, report_type, top_k, region_selectbox, countries_selectbox, is_debug, index_namespace, openai_model, default_prompt)
 
 
 
 
 
 
 
492
  else:
493
  st.session_state.new_conversation = False
494
 
 
10
 
11
  import streamlit as st
12
 
13
+ from googleai import send_message as google_send_message, init_googleai, DEFAULT_INSTRUCTIONS as google_default_instructions
14
 
15
  from langchain.chains import RetrievalQA
16
  from langchain_community.embeddings import OpenAIEmbeddings
 
64
 
65
  @st.cache_resource
66
  def init_models():
67
+ logger.debug("init_models")
68
  retriever = SentenceTransformer("msmarco-distilbert-base-v4")
69
  #model_name = "sentence-transformers/all-MiniLM-L6-v2"
70
  model_name = "sentence-transformers/msmarco-distilbert-base-v4"
 
77
 
78
  @st.cache_resource
79
  def init_openai():
80
+ logger.debug("init_openai")
81
  st.session_state.openai_client = oai.get_client()
82
  assistants = st.session_state.openai_client.beta.assistants.list(
83
  order="desc",
 
149
  #print(f" markdown for {company_id}\n{markdown}")
150
  return markdown
151
 
152
+ def run_googleai(query, prompt):
153
+ try:
154
+ logger.debug(f"User: {query}")
155
+ response = google_send_message(query, prompt)
156
+ response = response['output']
157
+ logger.debug(f"Agent: {response }")
158
+ content_container = st.container() #, col_sidepanel = st.columns([4, 1], gap="small")
159
+ with content_container:
160
+ with st.chat_message(name = 'User'):
161
+ st.write(query)
162
+ with st.chat_message(name = 'Agent', avatar = assistant_avatar):
163
+ st.write(response)
164
+ st.session_state.messages.append({"role": "user", "content": query})
165
+ st.session_state.messages.append({"role": "system", "content": response})
166
+ render_history()
167
+ except Exception as e:
168
+ logger.exception(f"Error processing user message", exc_info=e)
169
+ st.session_state.last_user_query = query
170
+
171
  def run_query(query, report_type, top_k , regions, countries, is_debug, index_namespace, openai_model, default_prompt):
172
 
173
  #Summarize the results
 
177
  # For every company find its uniqueness over the other companies. Use only information from the descriptions.
178
  # """
179
  content_container = st.container() #, col_sidepanel = st.columns([4, 1], gap="small")
180
+ if report_type=="guided":
181
+ prompt_txt = utils.query_finetune_prompt + """
182
+ User query: {query}
183
+ """
184
+ prompt_template = PromptTemplate(template=prompt_txt, input_variables=["query"])
185
+ prompt = prompt_template.format(query = query)
186
+ m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0, max_tokens=20, log_message = False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
+ print(f"Keywords: {m_text}")
189
+ results = utils.search_index(m_text, top_k, regions, countries, retriever, index_namespace)
190
 
191
+ descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
192
+ ntokens = len(descriptions.split(" "))
193
+
194
+ print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")
195
+
196
+ prompt_txt = utils.summarization_prompt + """
197
+ User query: {query}
198
+ Company descriptions: {descriptions}
199
+ """
200
+ prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
201
+ prompt = prompt_template.format(descriptions = descriptions, query = query)
202
+ print(f"==============================\nPrompt:\n{prompt}\n==============================\n")
203
+
204
+ m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
205
+ m_text
206
+ elif report_type=="company_list": # or st.session_state.new_conversation:
207
+ results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
208
+ descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
209
+ elif report_type=="assistant":
210
+ #results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
211
+ #descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
212
+ messages = oai.call_assistant(query, engine=openai_model)
213
+ st.session_state.messages = messages
214
+ results = st.session_state.db_search_results
215
+ if not messages is None:
216
+ with content_container:
217
+ for message in list(messages)[::-1]:
218
+ if hasattr(message, 'role'):
219
+ # print(f"\n-----\nMessage: {message}\n")
220
+ # with st.chat_message(name = message.role):
221
+ # st.write(message.content[0].text.value)
222
+ if message.role == "assistant":
223
+ with st.chat_message(name = message.role, avatar = assistant_avatar):
224
+ st.write(message.content[0].text.value)
225
+ else:
226
+ with st.chat_message(name = message.role):
227
+ st.write(message.content[0].text.value)
228
+ # st.session_state.messages.append({"role": "user", "content": query})
229
+ # st.session_state.messages.append({"role": "system", "content": m_text})
230
 
231
+ else:
232
+ st.session_state.new_conversation = False
233
+
234
+ results = utils.search_index(query, top_k, regions, countries, retriever, index_namespace)
235
+ descriptions = "\n".join([f"Description of company \"{res['name']}\": {res['data']['Summary']}.\n" for res in results[:20] if 'Summary' in res['data']])
236
+ ntokens = len(descriptions.split(" "))
237
+
238
+ print(f"Descriptions ({ntokens} tokens):\n {descriptions[:1000]}")
239
+ prompt = utils.clustering_prompt if report_type=="clustered" else utils.default_prompt
240
+ prompt_txt = prompt + """
241
+ User query: {query}
242
+ Company descriptions: {descriptions}
243
+ """
244
+ prompt_template = PromptTemplate(template=prompt_txt, input_variables=["descriptions", "query"])
245
+ prompt = prompt_template.format(descriptions = descriptions, query = query)
246
+ print(f"==============================\nPrompt:\n{prompt[:1000]}\n==============================\n")
247
+
248
+ m_text = oai.call_openai(prompt, engine=openai_model, temp=0, top_p=1.0)
249
+ m_text
250
+ st.session_state.messages.append({"role": "user", "content": query})
251
+ i = m_text.find("-----")
252
+ i = 0 if i<0 else i
253
+ st.session_state.messages.append({"role": "system", "content": m_text[:i]})
254
+
255
+
256
+
257
+ #render_history()
258
+ # for message in st.session_state.messages:
259
+ # with st.chat_message(message["role"]):
260
+ # st.markdown(message["content"])
261
+ # print(f"History: \n {st.session_state.messages}")
262
 
263
+ sorted_results = sorted(results, key=lambda x: x['score'], reverse=True)
264
+
265
+ names = []
266
+ # list_html = """
267
+ # <h2>Companies list</h2>
268
+ # <div class="container-fluid">
269
+ # <div class="row align-items-start" style="padding-bottom:10px;">
270
+ # <div class="col-md-8 col-sm-8">
271
+ # <span>Company</span>
272
+ # </div>
273
+ # <div class="col-md-1 col-sm-1">
274
+ # <span>Country</span>
275
+ # </div>
276
+ # <div class="col-md-1 col-sm-1">
277
+ # <span>Customer Problem</span>
278
+ # </div>
279
+ # <div class="col-md-1 col-sm-1">
280
+ # <span>Business Model</span>
281
+ # </div>
282
+ # <div class="col-md-1 col-sm-1">
283
+ # Actions
284
+ # </div>
285
+ # </div>
286
+ # """
287
+ list_html = "<div class='container-fluid'>"
288
+
289
+ locations = set()
290
+ for r in sorted_results:
291
+ company_name = r["name"]
292
+ if company_name in names:
293
+ continue
294
+ else:
295
+ names.append(company_name)
296
+ description = r["description"] #.replace(company_name, f"<mark>{company_name}</mark>")
297
+ if description is None or len(description.strip())<10:
298
+ continue
299
 
300
+ score = round(r["score"], 4)
301
+ data_type = r["metadata"]["type"] if "type" in r["metadata"] else ""
302
+ region = r["metadata"]["region"]
303
+ country = r["metadata"]["country"]
304
+ company_id = r["metadata"]["company_id"]
305
 
306
+ locations.add(country)
307
+ list_html = list_html + card(company_id, company_name, description, score, data_type, region, country, r['data'], is_debug)
308
 
309
+ list_html = list_html + '</div>'
310
 
311
+ #pins = country_geo[country_geo['name'].isin(locations)].loc[:, ['latitude', 'longitude']]
312
 
313
+ # if len(pins)>0:
314
+ # with st.expander("Map view"):
315
+ # st.map(pins)
316
+ #st.markdown(list_html, unsafe_allow_html=True)
317
 
318
+ df = pd.DataFrame.from_dict(carddict, orient="columns")
319
 
320
+ if len(df)>0:
321
+ df.index += 1
322
+ with content_container:
323
+ st.dataframe(df,
324
+ hide_index=False,
325
+ column_config ={
326
+ "name": st.column_config.TextColumn("Name"),
327
+ "company_id": st.column_config.LinkColumn("Link"),
328
+ "description": st.column_config.TextColumn("Description"),
329
+ "country": st.column_config.TextColumn("Country", width="small"),
330
+ "customer_problem": st.column_config.TextColumn("Customer problem"),
331
+ "target_customer": st.column_config.TextColumn(label="Target customer", width="small"),
332
+ "business_model": st.column_config.TextColumn(label="Business model")
333
+ },
334
+ use_container_width=True)
335
  st.session_state.last_user_query = query
336
 
337
 
 
455
 
456
  tab_advanced = st.sidebar.expander("Settings")
457
  with tab_advanced:
458
+ gemini_prompt = st.text_area("Gemini Prompt", value = google_default_instructions, height=400, key="advanced_gemini_prompt_content")
459
+ default_prompt = st.text_area("Default Prompt", value = utils.default_prompt, height=400, key="advanced_default_prompt_content")
460
  #prompt_title = st.selectbox("Report Type", index = 0, options = utils.get_prompts(), on_change=on_prompt_selected, key="advanced_prompts_select", )
461
  #prompt_title_editable = st.text_input("Title", key="prompt_title_editable")
462
  report_type = st.selectbox(label="Response Type", options=["gemini", "assistant", "standard", "guided", "company_list", "clustered"], index=0)
 
472
  index_namespace = st.selectbox(label="Data Type", options=["websummarized", "web", "cbli", "all"], index=0)
473
  liked_companies = st.text_input(label="liked companies", key='liked_companies')
474
  disliked_companies = st.text_input(label="disliked companies", key='disliked_companies')
 
475
  clustering_prompt = st.text_area("Clustering Prompt", value = utils.clustering_prompt, height=400, key="advanced_clustering_prompt_content")
476
 
477
+ if report_type == "assistant" and not "assistant_thread" in st.session_state:
478
  st.session_state.assistant_thread = st.session_state.openai_client.beta.threads.create()
479
 
480
 
 
495
  st.session_state.index_namespace = index_namespace
496
  st.session_state.region = region_selectbox
497
  st.session_state.country = countries_selectbox
498
+ if report_type=="gemini":
499
+ run_googleai(query, gemini_prompt)
500
+ else:
501
+ run_query(query, report_type, top_k,
502
+ region_selectbox, countries_selectbox, is_debug,
503
+ index_namespace, openai_model,
504
+ default_prompt,
505
+ gemini_prompt)
506
  else:
507
  st.session_state.new_conversation = False
508
 
googleai.py CHANGED
@@ -73,7 +73,7 @@ In order to query the database you have a semantic search tool called 'query_pin
73
  def search_index(query):
74
  return pc_search(query, top_k=1000, countries=[], regions = [], retriever = st.session_state.retriever)
75
 
76
- def init_googleai(instructions=DEFAULT_INSTRUCTIONS):
77
  logger.debug("Initiailizing google ai")
78
  pinecone_tool = Tool(
79
  name="query_pinecone",
@@ -84,15 +84,14 @@ def init_googleai(instructions=DEFAULT_INSTRUCTIONS):
84
  )
85
 
86
  llm = ChatGoogleGenerativeAI(
87
- # model="gemini-1.5-pro",
88
- model="gemini-1.5-flash",
89
  temperature=0.1,
90
  google_api_key=GOOGLE_API_KEY
91
  )
92
 
93
  tools = [pinecone_tool]
94
 
95
- st.session_state.agent_chain = initialize_agent(
96
  tools=tools,
97
  llm=llm,
98
  agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
@@ -101,7 +100,8 @@ def init_googleai(instructions=DEFAULT_INSTRUCTIONS):
101
  )
102
 
103
 
104
- def send_message(user_message:str):
105
- if not 'agent_chain' in st.session_state:
106
- init_googleai()
107
- return st.session_state.agent_chain.invoke(user_message)
 
 
73
  def search_index(query):
74
  return pc_search(query, top_k=1000, countries=[], regions = [], retriever = st.session_state.retriever)
75
 
76
+ def init_googleai(instructions=DEFAULT_INSTRUCTIONS, model = "gemini-1.5-flash"): # model="gemini-1.5-pro",
77
  logger.debug("Initiailizing google ai")
78
  pinecone_tool = Tool(
79
  name="query_pinecone",
 
84
  )
85
 
86
  llm = ChatGoogleGenerativeAI(
87
+ model=model,
 
88
  temperature=0.1,
89
  google_api_key=GOOGLE_API_KEY
90
  )
91
 
92
  tools = [pinecone_tool]
93
 
94
+ st.session_state.googleai_agent_chain = initialize_agent(
95
  tools=tools,
96
  llm=llm,
97
  agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
 
100
  )
101
 
102
 
103
+ def send_message(user_message:str, prompt):
104
+ if not 'googleai_agent_chain' in st.session_state or st.session_state.googleai_default_instructions != prompt:
105
+ st.session_state.googleai_default_instructions = prompt
106
+ init_googleai(prompt)
107
+ return st.session_state.googleai_agent_chain.invoke(user_message)
openai_utils.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  import json
2
  import time
3
  import traceback
@@ -21,7 +25,7 @@ def getListOfCompanies(query, filters = {}):
21
  return descriptions
22
 
23
  def report_error(txt):
24
- print(f"\nEEEEEEEEEEEEE\n{txt}")
25
 
26
  def wait_for_response(thread, run):
27
  timeout = 60 #timeout in seconds
 
1
+ import logging
2
+ logger = logging.getLogger(__name__)
3
+ logger.setLevel(logging.DEBUG)
4
+
5
  import json
6
  import time
7
  import traceback
 
25
  return descriptions
26
 
27
  def report_error(txt):
28
+ logger.debug(f"\nError: \n{txt}")
29
 
30
  def wait_for_response(thread, run):
31
  timeout = 60 #timeout in seconds
semsearch.pyproj CHANGED
@@ -36,6 +36,7 @@
36
  <Content Include=".gitignore" />
37
  <Content Include=".streamlit\config.toml" />
38
  <Content Include=".streamlit\secrets.toml" />
 
39
  <Content Include="Dockerfile" />
40
  <Content Include="README.md" />
41
  <Content Include="requirements.txt" />
 
36
  <Content Include=".gitignore" />
37
  <Content Include=".streamlit\config.toml" />
38
  <Content Include=".streamlit\secrets.toml" />
39
+ <Content Include="cloudrun.yaml" />
40
  <Content Include="Dockerfile" />
41
  <Content Include="README.md" />
42
  <Content Include="requirements.txt" />