|
|
import streamlit as st |
|
|
from middlewares.utils import gen_augmented_prompt_via_websearch |
|
|
from middlewares.chat_client import chat |
|
|
import json |
|
|
from pprint import pformat |
|
|
from notion_client import Client |
|
|
|
|
|
def safe_get(data, dot_chained_keys):
    """Walk a nested dict/list structure along a dot-separated key path.

    Each segment of ``dot_chained_keys`` is used as a dict key or, when the
    current value is a list, converted with ``int()`` and used as an index.

    Example::

        >>> safe_get({'a': {'b': [{'c': 1}]}}, 'a.b.0.c')
        1

    Args:
        data: The nested structure to traverse.
        dot_chained_keys: Path such as ``'a.b.0.c'``.

    Returns:
        The value at the end of the path, or ``None`` if any segment is
        missing, out of range, not a valid integer for a list, or applied to
        a value that cannot be indexed.
    """
    for key in dot_chained_keys.split('.'):
        try:
            if isinstance(data, list):
                data = data[int(key)]
            else:
                data = data[key]
        # ValueError covers a non-numeric segment used against a list
        # (int('x')); previously that escaped instead of returning None.
        except (KeyError, TypeError, IndexError, ValueError):
            return None
    return data
|
|
|
|
|
def _notion_option_names(row, dot_chained_keys):
    """Return the ``name`` of each option in a multi_select property of ``row``.

    ``safe_get`` returns ``None`` when the property path does not resolve, so
    that case is treated as "no options" instead of crashing on iteration.
    """
    return [option['name'] for option in (safe_get(row, dot_chained_keys) or [])]


def _notion_row_to_record(row):
    """Flatten one Notion database page object into a plain dict of inventory fields.

    The property names must match the Notion database's column names exactly
    (including the trailing space in ``'Product '``). Unresolvable scalar
    properties become ``None``; unresolvable multi_select properties become ``[]``.
    """
    return {
        'Price Per unit': safe_get(row, 'properties.($) Per Unit.number'),
        'Store Link': safe_get(row, 'properties.Store Link.url'),
        'Supplier Email': safe_get(row, 'properties.Supplier Email.email'),
        'Expected Delivery': safe_get(row, 'properties.Expected Delivery.date'),
        'Collection': _notion_option_names(row, 'properties.Collection.multi_select'),
        'Status': safe_get(row, 'properties.Status.select.name'),
        'Supplier Phone': safe_get(row, 'properties.Supplier Phone.phone_number'),
        'Stock Alert': safe_get(row, 'properties.Stock Alert.status.name'),
        'Product Name': safe_get(row, 'properties.Product .title.0.text.content'),
        'SKU': safe_get(row, 'properties.SKU.number'),
        'Sizes': _notion_option_names(row, 'properties.Size.multi_select'),
        'Shipped Date': safe_get(row, 'properties.Shipped On.date'),
        'On Order': safe_get(row, 'properties.On Order.number'),
        "On Hand": safe_get(row, 'properties.On Hand.number'),
    }


def _query_all_database_pages(client, database_id):
    """Return every page object in a Notion database, following pagination.

    The query endpoint returns results one page at a time and signals more
    data via ``has_more`` / ``next_cursor``; reading only the first response
    would silently drop every later row.
    """
    results = []
    query_kwargs = {}
    while True:
        response = client.databases.query(database_id, **query_kwargs)
        results.extend(response.get('results', []))
        next_cursor = response.get('next_cursor')
        if not (response.get('has_more') and next_cursor):
            return results
        query_kwargs = {'start_cursor': next_cursor}


def get_notion_data(integration_token=None, notion_database_id=None):
    """Fetch the inventory database from Notion and return it as a pretty-printed string.

    Every row (across all result pages) is flattened by
    ``_notion_row_to_record`` and the resulting list is rendered with
    ``pprint.pformat`` so it can be embedded in a prompt.

    Args:
        integration_token: Notion integration secret. Defaults to the
            ``NOTION_INTEGRATION_TOKEN`` environment variable, then to the
            legacy hard-coded value.
        notion_database_id: ID of the database to query. Defaults to the
            ``NOTION_DATABASE_ID`` environment variable, then to the legacy
            hard-coded value.

    Returns:
        ``pformat`` of the list of row dicts.
    """
    import os

    # SECURITY: this token is committed to source control and must be treated
    # as leaked. Rotate it in Notion, provide the new one through
    # NOTION_INTEGRATION_TOKEN, and then delete this fallback. It is kept only
    # so existing deployments without the env var keep working.
    legacy_token = "secret_lTOe0q9dqqKQLRRb2KJwi7QFSl0vqoztroRFHW6MeQE"
    legacy_database_id = "6c0d877b823a4e3699016fa7083f3006"

    if integration_token is None:
        integration_token = os.environ.get('NOTION_INTEGRATION_TOKEN', legacy_token)
    if notion_database_id is None:
        notion_database_id = os.environ.get('NOTION_DATABASE_ID', legacy_database_id)

    client = Client(auth=integration_token)
    pages = _query_all_database_pages(client, notion_database_id)
    rows = [_notion_row_to_record(page) for page in pages]
    return pformat(rows)
|
|
|
|
|
def generate_chat_stream(session_state, query, config):
    """Build the final prompt for ``query`` and start a chat completion.

    When ``session_state.rag_enabled`` is truthy, the query is first
    augmented with web-search context by ``gen_augmented_prompt_via_websearch``,
    which also returns the source links. The Notion data string from
    ``get_notion_data()`` is then prepended (separated by a single space) to
    the possibly-augmented query, and the result is passed to ``chat``.

    Args:
        session_state: Streamlit session state holding the RAG settings
            (``rag_enabled``, ``pre_context``, ``post_context``, ``pre_prompt``,
            ``post_prompt``, ``search_vendor``, ``top_k``, ``n_crawl``,
            ``pass_prev``) and ``history``.
        query: The user's message.
        config: App configuration, forwarded unchanged to ``chat``.

    Returns:
        ``(chat_stream, links)``: whatever ``chat`` returns, and the web-search
        links (an empty list when RAG is disabled).
    """
    links = []

    if session_state.rag_enabled:
        history = session_state.history
        # Entries are read as history[-1][1]; presumably (user, assistant)
        # pairs — TODO confirm. On the first turn there is no previous output,
        # and indexing an empty history would raise IndexError.
        prev_output = history[-1][1] if history else ""

        with st.spinner("Fetching relevant documents from the web..."):
            query, links = gen_augmented_prompt_via_websearch(
                prompt=query,
                pre_context=session_state.pre_context,
                post_context=session_state.post_context,
                pre_prompt=session_state.pre_prompt,
                post_prompt=session_state.post_prompt,
                search_vendor=session_state.search_vendor,
                top_k=session_state.top_k,
                n_crawl=session_state.n_crawl,
                pass_prev=session_state.pass_prev,
                prev_output=prev_output,
            )

    # The Notion fetch is a blocking network call; show progress while it runs.
    with st.spinner("Fetching inventory data from Notion..."):
        notion_data = get_notion_data()

    with st.spinner("Generating response..."):
        chat_stream = chat(session_state, notion_data + " " + query, config)

    return chat_stream, links
|
|
|