Spaces:
Runtime error
Runtime error
Fangrui Liu
committed on
Commit
·
c6f6149
1
Parent(s):
04f0bde
add knowledge base management
Browse files
- app.py +1 -1
- callbacks/arxiv_callbacks.py +64 -40
- chat.py +173 -42
- helper.py → lib/helper.py +7 -4
- lib/json_conv.py +21 -0
- lib/private_kb.py +95 -21
- lib/sessions.py +0 -1
app.py
CHANGED
|
@@ -10,7 +10,7 @@ from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
|
|
| 10 |
|
| 11 |
from chat import chat_page
|
| 12 |
from login import login, back_to_main
|
| 13 |
-
from helper import build_tools,
|
| 14 |
|
| 15 |
|
| 16 |
|
|
|
|
| 10 |
|
| 11 |
from chat import chat_page
|
| 12 |
from login import login, back_to_main
|
| 13 |
+
from lib.helper import build_tools, build_all, sel_map, display
|
| 14 |
|
| 15 |
|
| 16 |
|
callbacks/arxiv_callbacks.py
CHANGED
|
@@ -3,70 +3,79 @@ import json
|
|
| 3 |
import textwrap
|
| 4 |
from typing import Dict, Any, List
|
| 5 |
from sql_formatter.core import format_sql
|
| 6 |
-
from langchain.callbacks.streamlit.streamlit_callback_handler import
|
|
|
|
|
|
|
|
|
|
| 7 |
from langchain.schema.output import LLMResult
|
| 8 |
from streamlit.delta_generator import DeltaGenerator
|
| 9 |
|
|
|
|
| 10 |
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
|
| 11 |
def __init__(self) -> None:
|
| 12 |
self.progress_bar = st.progress(value=0.0, text="Working...")
|
| 13 |
self.tokens_stream = ""
|
| 14 |
-
|
| 15 |
def on_llm_start(self, serialized, prompts, **kwargs) -> None:
|
| 16 |
pass
|
| 17 |
-
|
| 18 |
def on_text(self, text: str, **kwargs) -> None:
|
| 19 |
self.progress_bar.progress(value=0.2, text="Asking LLM...")
|
| 20 |
-
|
| 21 |
def on_chain_end(self, outputs, **kwargs) -> None:
|
| 22 |
-
self.progress_bar.progress(value=0.6, text=
|
| 23 |
-
if
|
| 24 |
-
st.markdown(
|
| 25 |
st.markdown(f"```python\n{outputs['repr']}\n```", unsafe_allow_html=True)
|
| 26 |
-
|
| 27 |
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
|
| 28 |
pass
|
| 29 |
|
|
|
|
| 30 |
class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
|
| 31 |
def __init__(self) -> None:
|
| 32 |
-
self.progress_bar = st.progress(value=0.0, text=
|
| 33 |
self.status_bar = st.empty()
|
| 34 |
self.prog_value = 0.0
|
| 35 |
self.prog_map = {
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
}
|
| 40 |
|
| 41 |
def on_llm_start(self, serialized, prompts, **kwargs) -> None:
|
| 42 |
pass
|
| 43 |
-
|
| 44 |
def on_text(self, text: str, **kwargs) -> None:
|
| 45 |
pass
|
| 46 |
-
|
| 47 |
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
|
| 48 |
-
cid =
|
| 49 |
-
if cid !=
|
| 50 |
-
self.progress_bar.progress(
|
|
|
|
|
|
|
| 51 |
self.prog_value = self.prog_map[cid]
|
| 52 |
else:
|
| 53 |
self.prog_value += 0.1
|
| 54 |
-
self.progress_bar.progress(
|
|
|
|
|
|
|
| 55 |
|
| 56 |
def on_chain_end(self, outputs, **kwargs) -> None:
|
| 57 |
pass
|
| 58 |
-
|
| 59 |
|
| 60 |
class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
|
| 61 |
def __init__(self) -> None:
|
| 62 |
-
self.progress_bar = st.progress(value=0.0, text=
|
| 63 |
self.status_bar = st.empty()
|
| 64 |
self.prog_value = 0
|
| 65 |
self.prog_interval = 0.2
|
| 66 |
|
| 67 |
def on_llm_start(self, serialized, prompts, **kwargs) -> None:
|
| 68 |
pass
|
| 69 |
-
|
| 70 |
def on_llm_end(
|
| 71 |
self,
|
| 72 |
response: LLMResult,
|
|
@@ -74,41 +83,56 @@ class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
|
|
| 74 |
**kwargs,
|
| 75 |
):
|
| 76 |
text = response.generations[0][0].text
|
| 77 |
-
if text.replace(
|
| 78 |
-
st.write(
|
| 79 |
-
st.markdown(f
|
| 80 |
print(f"Vector SQL: {text}")
|
| 81 |
self.prog_value += self.prog_interval
|
| 82 |
self.progress_bar.progress(value=self.prog_value, text="Searching in DB...")
|
| 83 |
-
|
| 84 |
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
|
| 85 |
-
cid =
|
| 86 |
self.prog_value += self.prog_interval
|
| 87 |
-
self.progress_bar.progress(
|
| 88 |
-
|
|
|
|
|
|
|
| 89 |
def on_chain_end(self, outputs, **kwargs) -> None:
|
| 90 |
pass
|
| 91 |
-
|
|
|
|
| 92 |
class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
|
| 93 |
def __init__(self) -> None:
|
| 94 |
-
self.progress_bar = st.progress(value=0.0, text=
|
| 95 |
self.status_bar = st.empty()
|
| 96 |
self.prog_value = 0
|
| 97 |
self.prog_interval = 0.1
|
| 98 |
-
|
| 99 |
-
|
| 100 |
class LLMThoughtWithKB(LLMThought):
|
| 101 |
-
def on_tool_end(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
try:
|
| 103 |
-
self._container.markdown(
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
except Exception as e:
|
| 107 |
super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)
|
| 108 |
-
|
| 109 |
-
|
| 110 |
class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
|
| 111 |
-
|
| 112 |
def on_llm_start(
|
| 113 |
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
| 114 |
) -> None:
|
|
@@ -120,4 +144,4 @@ class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
|
|
| 120 |
labeler=self._thought_labeler,
|
| 121 |
)
|
| 122 |
|
| 123 |
-
self._current_thought.on_llm_start(serialized, prompts)
|
|
|
|
| 3 |
import textwrap
|
| 4 |
from typing import Dict, Any, List
|
| 5 |
from sql_formatter.core import format_sql
|
| 6 |
+
from langchain.callbacks.streamlit.streamlit_callback_handler import (
|
| 7 |
+
LLMThought,
|
| 8 |
+
StreamlitCallbackHandler,
|
| 9 |
+
)
|
| 10 |
from langchain.schema.output import LLMResult
|
| 11 |
from streamlit.delta_generator import DeltaGenerator
|
| 12 |
|
| 13 |
+
|
| 14 |
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
|
| 15 |
def __init__(self) -> None:
|
| 16 |
self.progress_bar = st.progress(value=0.0, text="Working...")
|
| 17 |
self.tokens_stream = ""
|
| 18 |
+
|
| 19 |
def on_llm_start(self, serialized, prompts, **kwargs) -> None:
|
| 20 |
pass
|
| 21 |
+
|
| 22 |
def on_text(self, text: str, **kwargs) -> None:
|
| 23 |
self.progress_bar.progress(value=0.2, text="Asking LLM...")
|
| 24 |
+
|
| 25 |
def on_chain_end(self, outputs, **kwargs) -> None:
|
| 26 |
+
self.progress_bar.progress(value=0.6, text="Searching in DB...")
|
| 27 |
+
if "repr" in outputs:
|
| 28 |
+
st.markdown("### Generated Filter")
|
| 29 |
st.markdown(f"```python\n{outputs['repr']}\n```", unsafe_allow_html=True)
|
| 30 |
+
|
| 31 |
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
|
| 32 |
pass
|
| 33 |
|
| 34 |
+
|
| 35 |
class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
|
| 36 |
def __init__(self) -> None:
|
| 37 |
+
self.progress_bar = st.progress(value=0.0, text="Searching DB...")
|
| 38 |
self.status_bar = st.empty()
|
| 39 |
self.prog_value = 0.0
|
| 40 |
self.prog_map = {
|
| 41 |
+
"langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain": 0.2,
|
| 42 |
+
"langchain.chains.combine_documents.map_reduce.MapReduceDocumentsChain": 0.4,
|
| 43 |
+
"langchain.chains.combine_documents.stuff.StuffDocumentsChain": 0.8,
|
| 44 |
}
|
| 45 |
|
| 46 |
def on_llm_start(self, serialized, prompts, **kwargs) -> None:
|
| 47 |
pass
|
| 48 |
+
|
| 49 |
def on_text(self, text: str, **kwargs) -> None:
|
| 50 |
pass
|
| 51 |
+
|
| 52 |
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
|
| 53 |
+
cid = ".".join(serialized["id"])
|
| 54 |
+
if cid != "langchain.chains.llm.LLMChain":
|
| 55 |
+
self.progress_bar.progress(
|
| 56 |
+
value=self.prog_map[cid], text=f"Running Chain `{cid}`..."
|
| 57 |
+
)
|
| 58 |
self.prog_value = self.prog_map[cid]
|
| 59 |
else:
|
| 60 |
self.prog_value += 0.1
|
| 61 |
+
self.progress_bar.progress(
|
| 62 |
+
value=self.prog_value, text=f"Running Chain `{cid}`..."
|
| 63 |
+
)
|
| 64 |
|
| 65 |
def on_chain_end(self, outputs, **kwargs) -> None:
|
| 66 |
pass
|
| 67 |
+
|
| 68 |
|
| 69 |
class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
|
| 70 |
def __init__(self) -> None:
|
| 71 |
+
self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
|
| 72 |
self.status_bar = st.empty()
|
| 73 |
self.prog_value = 0
|
| 74 |
self.prog_interval = 0.2
|
| 75 |
|
| 76 |
def on_llm_start(self, serialized, prompts, **kwargs) -> None:
|
| 77 |
pass
|
| 78 |
+
|
| 79 |
def on_llm_end(
|
| 80 |
self,
|
| 81 |
response: LLMResult,
|
|
|
|
| 83 |
**kwargs,
|
| 84 |
):
|
| 85 |
text = response.generations[0][0].text
|
| 86 |
+
if text.replace(" ", "").upper().startswith("SELECT"):
|
| 87 |
+
st.write("We generated Vector SQL for you:")
|
| 88 |
+
st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
|
| 89 |
print(f"Vector SQL: {text}")
|
| 90 |
self.prog_value += self.prog_interval
|
| 91 |
self.progress_bar.progress(value=self.prog_value, text="Searching in DB...")
|
| 92 |
+
|
| 93 |
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
|
| 94 |
+
cid = ".".join(serialized["id"])
|
| 95 |
self.prog_value += self.prog_interval
|
| 96 |
+
self.progress_bar.progress(
|
| 97 |
+
value=self.prog_value, text=f"Running Chain `{cid}`..."
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
def on_chain_end(self, outputs, **kwargs) -> None:
|
| 101 |
pass
|
| 102 |
+
|
| 103 |
+
|
| 104 |
class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
|
| 105 |
def __init__(self) -> None:
|
| 106 |
+
self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
|
| 107 |
self.status_bar = st.empty()
|
| 108 |
self.prog_value = 0
|
| 109 |
self.prog_interval = 0.1
|
| 110 |
+
|
| 111 |
+
|
| 112 |
class LLMThoughtWithKB(LLMThought):
|
| 113 |
+
def on_tool_end(
|
| 114 |
+
self,
|
| 115 |
+
output: str,
|
| 116 |
+
color=None,
|
| 117 |
+
observation_prefix=None,
|
| 118 |
+
llm_prefix=None,
|
| 119 |
+
**kwargs: Any,
|
| 120 |
+
) -> None:
|
| 121 |
try:
|
| 122 |
+
self._container.markdown(
|
| 123 |
+
"\n\n".join(
|
| 124 |
+
["### Retrieved Documents:"]
|
| 125 |
+
+ [
|
| 126 |
+
f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}"
|
| 127 |
+
for i, r in enumerate(json.loads(output))
|
| 128 |
+
]
|
| 129 |
+
)
|
| 130 |
+
)
|
| 131 |
except Exception as e:
|
| 132 |
super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
|
|
|
|
| 136 |
def on_llm_start(
|
| 137 |
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
| 138 |
) -> None:
|
|
|
|
| 144 |
labeler=self._thought_labeler,
|
| 145 |
)
|
| 146 |
|
| 147 |
+
self._current_thought.on_llm_start(serialized, prompts)
|
chat.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import pandas as pd
|
| 2 |
from os import environ
|
| 3 |
from time import sleep
|
|
@@ -7,9 +8,12 @@ from lib.sessions import SessionManager
|
|
| 7 |
from lib.private_kb import PrivateKnowledgeBase
|
| 8 |
from langchain.schema import HumanMessage, FunctionMessage
|
| 9 |
from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
|
| 10 |
-
from langchain.callbacks.streamlit.streamlit_callback_handler import
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
from helper import (
|
| 13 |
build_agents,
|
| 14 |
MYSCALE_HOST,
|
| 15 |
MYSCALE_PASSWORD,
|
|
@@ -30,12 +34,16 @@ TOOL_NAMES = {
|
|
| 30 |
|
| 31 |
def on_chat_submit():
|
| 32 |
with st.session_state.next_round.container():
|
| 33 |
-
with st.chat_message(
|
| 34 |
st.write(st.session_state.chat_input)
|
| 35 |
-
with st.chat_message(
|
| 36 |
container = st.container()
|
| 37 |
-
st_callback = ChatDataAgentCallBackHandler(
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
print(ret)
|
| 40 |
|
| 41 |
|
|
@@ -105,7 +113,10 @@ def refresh_sessions():
|
|
| 105 |
st.session_state[
|
| 106 |
"current_sessions"
|
| 107 |
] = st.session_state.session_manager.list_sessions(st.session_state.user_name)
|
| 108 |
-
if
|
|
|
|
|
|
|
|
|
|
| 109 |
st.session_state.session_manager.add_session(
|
| 110 |
st.session_state.user_name,
|
| 111 |
f"{st.session_state.user_name}?default",
|
|
@@ -114,14 +125,64 @@ def refresh_sessions():
|
|
| 114 |
st.session_state[
|
| 115 |
"current_sessions"
|
| 116 |
] = st.session_state.session_manager.list_sessions(st.session_state.user_name)
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
try:
|
| 119 |
-
dfl_indx = [x["session_id"] for x in st.session_state.current_sessions].index(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
except ValueError:
|
| 121 |
dfl_indx = 0
|
| 122 |
st.session_state.sel_sess = st.session_state.current_sessions[dfl_indx]
|
| 123 |
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
def refresh_agent():
|
| 126 |
with st.spinner("Initializing session..."):
|
| 127 |
print(
|
|
@@ -138,22 +199,29 @@ def refresh_agent():
|
|
| 138 |
else st.session_state.sel_sess["system_prompt"],
|
| 139 |
)
|
| 140 |
|
|
|
|
| 141 |
def add_file():
|
| 142 |
-
if
|
|
|
|
|
|
|
|
|
|
| 143 |
st.session_state.tool_status.error("Please upload files!", icon="⚠️")
|
| 144 |
sleep(2)
|
| 145 |
return
|
| 146 |
try:
|
| 147 |
st.session_state.tool_status.info("Uploading...")
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
|
|
|
| 151 |
except ValueError as e:
|
| 152 |
st.session_state.tool_status.error("Failed to upload! " + str(e))
|
| 153 |
sleep(2)
|
| 154 |
-
|
|
|
|
| 155 |
def clear_files():
|
| 156 |
st.session_state.private_kb.clear(st.session_state.user_name)
|
|
|
|
| 157 |
|
| 158 |
|
| 159 |
def chat_page():
|
|
@@ -168,7 +236,7 @@ def chat_page():
|
|
| 168 |
port=MYSCALE_PORT,
|
| 169 |
username=MYSCALE_USER,
|
| 170 |
password=MYSCALE_PASSWORD,
|
| 171 |
-
embedding=st.session_state.embeddings[
|
| 172 |
parser_api_key=UNSTRUCTURED_API,
|
| 173 |
)
|
| 174 |
if "session_manager" not in st.session_state:
|
|
@@ -177,12 +245,21 @@ def chat_page():
|
|
| 177 |
with st.expander("Session Management"):
|
| 178 |
if "current_sessions" not in st.session_state:
|
| 179 |
refresh_sessions()
|
| 180 |
-
st.info(
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
st.data_editor(
|
| 187 |
st.session_state.current_sessions,
|
| 188 |
num_rows="dynamic",
|
|
@@ -191,12 +268,18 @@ def chat_page():
|
|
| 191 |
)
|
| 192 |
st.button("Submit Change!", on_click=on_session_change_submit)
|
| 193 |
with st.expander("Session Selection", expanded=True):
|
| 194 |
-
st.info(
|
| 195 |
-
|
|
|
|
|
|
|
| 196 |
try:
|
| 197 |
dfl_indx = [
|
| 198 |
x["session_id"] for x in st.session_state.current_sessions
|
| 199 |
-
].index(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
except Exception as e:
|
| 201 |
print("*** ", str(e))
|
| 202 |
dfl_indx = 0
|
|
@@ -210,39 +293,84 @@ def chat_page():
|
|
| 210 |
)
|
| 211 |
print(st.session_state.sel_sess)
|
| 212 |
with st.expander("Tool Settings", expanded=True):
|
| 213 |
-
st.info(
|
| 214 |
-
|
|
|
|
|
|
|
| 215 |
st.session_state["tool_status"] = st.empty()
|
| 216 |
-
tab_kb, tab_file
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
with tab_kb:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
st.multiselect(
|
| 219 |
"Select a Knowledge Base Tool",
|
| 220 |
-
st.session_state.tools.keys()
|
|
|
|
|
|
|
| 221 |
default=["Wikipedia + Self Querying"],
|
| 222 |
key="selected_tools",
|
| 223 |
on_change=refresh_agent,
|
| 224 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
with tab_file:
|
| 226 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
st.markdown("### Uploaded Files")
|
| 228 |
-
st.dataframe(
|
|
|
|
|
|
|
|
|
|
| 229 |
col_1, col_2 = st.columns(2)
|
| 230 |
with col_1:
|
| 231 |
st.button("Add Files", on_click=add_file)
|
| 232 |
with col_2:
|
| 233 |
-
st.button("Clear Files", on_click=clear_files)
|
| 234 |
-
|
| 235 |
-
# st.text_input("Give this knowledge base a description:")
|
| 236 |
-
# col_3, col_4 = st.columns(2)
|
| 237 |
-
# with col_3:
|
| 238 |
-
# st.button("Build Your KB!")
|
| 239 |
-
# with col_4:
|
| 240 |
-
# st.button("Delete Your KB")
|
| 241 |
-
|
| 242 |
-
|
| 243 |
st.button("Clear Chat History", on_click=clear_history)
|
| 244 |
st.button("Logout", on_click=back_to_main)
|
| 245 |
-
if
|
| 246 |
refresh_agent()
|
| 247 |
print("!!! ", st.session_state.agent.memory.chat_memory.session_id)
|
| 248 |
for msg in st.session_state.agent.memory.chat_memory.messages:
|
|
@@ -255,7 +383,10 @@ def chat_page():
|
|
| 255 |
st.write("Retrieved from knowledge base:")
|
| 256 |
try:
|
| 257 |
st.dataframe(
|
| 258 |
-
pd.DataFrame.from_records(
|
|
|
|
|
|
|
|
|
|
| 259 |
)
|
| 260 |
except:
|
| 261 |
st.write(msg.content)
|
|
|
|
| 1 |
+
import json
|
| 2 |
import pandas as pd
|
| 3 |
from os import environ
|
| 4 |
from time import sleep
|
|
|
|
| 8 |
from lib.private_kb import PrivateKnowledgeBase
|
| 9 |
from langchain.schema import HumanMessage, FunctionMessage
|
| 10 |
from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
|
| 11 |
+
from langchain.callbacks.streamlit.streamlit_callback_handler import (
|
| 12 |
+
StreamlitCallbackHandler,
|
| 13 |
+
)
|
| 14 |
+
from lib.json_conv import CustomJSONDecoder
|
| 15 |
|
| 16 |
+
from lib.helper import (
|
| 17 |
build_agents,
|
| 18 |
MYSCALE_HOST,
|
| 19 |
MYSCALE_PASSWORD,
|
|
|
|
| 34 |
|
| 35 |
def on_chat_submit():
|
| 36 |
with st.session_state.next_round.container():
|
| 37 |
+
with st.chat_message("user"):
|
| 38 |
st.write(st.session_state.chat_input)
|
| 39 |
+
with st.chat_message("assistant"):
|
| 40 |
container = st.container()
|
| 41 |
+
st_callback = ChatDataAgentCallBackHandler(
|
| 42 |
+
container, collapse_completed_thoughts=False
|
| 43 |
+
)
|
| 44 |
+
ret = st.session_state.agent(
|
| 45 |
+
{"input": st.session_state.chat_input}, callbacks=[st_callback]
|
| 46 |
+
)
|
| 47 |
print(ret)
|
| 48 |
|
| 49 |
|
|
|
|
| 113 |
st.session_state[
|
| 114 |
"current_sessions"
|
| 115 |
] = st.session_state.session_manager.list_sessions(st.session_state.user_name)
|
| 116 |
+
if (
|
| 117 |
+
type(st.session_state.current_sessions) is not dict
|
| 118 |
+
and len(st.session_state.current_sessions) <= 0
|
| 119 |
+
):
|
| 120 |
st.session_state.session_manager.add_session(
|
| 121 |
st.session_state.user_name,
|
| 122 |
f"{st.session_state.user_name}?default",
|
|
|
|
| 125 |
st.session_state[
|
| 126 |
"current_sessions"
|
| 127 |
] = st.session_state.session_manager.list_sessions(st.session_state.user_name)
|
| 128 |
+
st.session_state["user_files"] = st.session_state.private_kb.list_files(
|
| 129 |
+
st.session_state.user_name
|
| 130 |
+
)
|
| 131 |
+
st.session_state["user_tools"] = st.session_state.private_kb.list_tools(
|
| 132 |
+
st.session_state.user_name
|
| 133 |
+
)
|
| 134 |
+
st.session_state["tools_with_users"] = {
|
| 135 |
+
**st.session_state.tools,
|
| 136 |
+
**st.session_state.private_kb.as_tools(st.session_state.user_name),
|
| 137 |
+
}
|
| 138 |
try:
|
| 139 |
+
dfl_indx = [x["session_id"] for x in st.session_state.current_sessions].index(
|
| 140 |
+
"default"
|
| 141 |
+
if "" not in st.session_state
|
| 142 |
+
else st.session_state.sel_session["session_id"]
|
| 143 |
+
)
|
| 144 |
except ValueError:
|
| 145 |
dfl_indx = 0
|
| 146 |
st.session_state.sel_sess = st.session_state.current_sessions[dfl_indx]
|
| 147 |
|
| 148 |
|
| 149 |
+
def build_kb_as_tool():
|
| 150 |
+
if (
|
| 151 |
+
"b_tool_name" in st.session_state
|
| 152 |
+
and "b_tool_desc" in st.session_state
|
| 153 |
+
and "b_tool_files" in st.session_state
|
| 154 |
+
and len(st.session_state.b_tool_name) > 0
|
| 155 |
+
and len(st.session_state.b_tool_desc) > 0
|
| 156 |
+
and len(st.session_state.b_tool_files) > 0
|
| 157 |
+
):
|
| 158 |
+
st.session_state.private_kb.create_tool(
|
| 159 |
+
st.session_state.user_name,
|
| 160 |
+
st.session_state.b_tool_name,
|
| 161 |
+
st.session_state.b_tool_desc,
|
| 162 |
+
[f["file_name"] for f in st.session_state.b_tool_files],
|
| 163 |
+
)
|
| 164 |
+
refresh_sessions()
|
| 165 |
+
else:
|
| 166 |
+
st.session_state.tool_status.error(
|
| 167 |
+
"You should fill all fields to build up a tool!"
|
| 168 |
+
)
|
| 169 |
+
sleep(2)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def remove_kb():
|
| 173 |
+
if "r_tool_names" in st.session_state and len(st.session_state.r_tool_names) > 0:
|
| 174 |
+
st.session_state.private_kb.remove_tools(
|
| 175 |
+
st.session_state.user_name,
|
| 176 |
+
[f["tool_name"] for f in st.session_state.r_tool_names],
|
| 177 |
+
)
|
| 178 |
+
refresh_sessions()
|
| 179 |
+
else:
|
| 180 |
+
st.session_state.tool_status.error(
|
| 181 |
+
"You should specify at least one tool to delete!"
|
| 182 |
+
)
|
| 183 |
+
sleep(2)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
def refresh_agent():
|
| 187 |
with st.spinner("Initializing session..."):
|
| 188 |
print(
|
|
|
|
| 199 |
else st.session_state.sel_sess["system_prompt"],
|
| 200 |
)
|
| 201 |
|
| 202 |
+
|
| 203 |
def add_file():
|
| 204 |
+
if (
|
| 205 |
+
"uploaded_files" not in st.session_state
|
| 206 |
+
or len(st.session_state.uploaded_files) == 0
|
| 207 |
+
):
|
| 208 |
st.session_state.tool_status.error("Please upload files!", icon="⚠️")
|
| 209 |
sleep(2)
|
| 210 |
return
|
| 211 |
try:
|
| 212 |
st.session_state.tool_status.info("Uploading...")
|
| 213 |
+
st.session_state.private_kb.add_by_file(
|
| 214 |
+
st.session_state.user_name, st.session_state.uploaded_files
|
| 215 |
+
)
|
| 216 |
+
refresh_sessions()
|
| 217 |
except ValueError as e:
|
| 218 |
st.session_state.tool_status.error("Failed to upload! " + str(e))
|
| 219 |
sleep(2)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
def clear_files():
|
| 223 |
st.session_state.private_kb.clear(st.session_state.user_name)
|
| 224 |
+
refresh_sessions()
|
| 225 |
|
| 226 |
|
| 227 |
def chat_page():
|
|
|
|
| 236 |
port=MYSCALE_PORT,
|
| 237 |
username=MYSCALE_USER,
|
| 238 |
password=MYSCALE_PASSWORD,
|
| 239 |
+
embedding=st.session_state.embeddings["Wikipedia"],
|
| 240 |
parser_api_key=UNSTRUCTURED_API,
|
| 241 |
)
|
| 242 |
if "session_manager" not in st.session_state:
|
|
|
|
| 245 |
with st.expander("Session Management"):
|
| 246 |
if "current_sessions" not in st.session_state:
|
| 247 |
refresh_sessions()
|
| 248 |
+
st.info(
|
| 249 |
+
"Here you can set up your session! \n\nYou can **change your prompt** here!",
|
| 250 |
+
icon="π€",
|
| 251 |
+
)
|
| 252 |
+
st.info(
|
| 253 |
+
(
|
| 254 |
+
"**Add columns by clicking the empty row**.\n"
|
| 255 |
+
"And **delete columns by selecting rows with a press on `DEL` Key**"
|
| 256 |
+
),
|
| 257 |
+
icon="π‘",
|
| 258 |
+
)
|
| 259 |
+
st.info(
|
| 260 |
+
"Don't forget to **click `Submit Change` to save your change**!",
|
| 261 |
+
icon="π",
|
| 262 |
+
)
|
| 263 |
st.data_editor(
|
| 264 |
st.session_state.current_sessions,
|
| 265 |
num_rows="dynamic",
|
|
|
|
| 268 |
)
|
| 269 |
st.button("Submit Change!", on_click=on_session_change_submit)
|
| 270 |
with st.expander("Session Selection", expanded=True):
|
| 271 |
+
st.info(
|
| 272 |
+
"If no session is attach to your account, then we will add a default session to you!",
|
| 273 |
+
icon="❤️",
|
| 274 |
+
)
|
| 275 |
try:
|
| 276 |
dfl_indx = [
|
| 277 |
x["session_id"] for x in st.session_state.current_sessions
|
| 278 |
+
].index(
|
| 279 |
+
"default"
|
| 280 |
+
if "" not in st.session_state
|
| 281 |
+
else st.session_state.sel_session["session_id"]
|
| 282 |
+
)
|
| 283 |
except Exception as e:
|
| 284 |
print("*** ", str(e))
|
| 285 |
dfl_indx = 0
|
|
|
|
| 293 |
)
|
| 294 |
print(st.session_state.sel_sess)
|
| 295 |
with st.expander("Tool Settings", expanded=True):
|
| 296 |
+
st.info(
|
| 297 |
+
"We provides you several knowledge base tools for you. We are building more tools!",
|
| 298 |
+
icon="π§",
|
| 299 |
+
)
|
| 300 |
st.session_state["tool_status"] = st.empty()
|
| 301 |
+
tab_kb, tab_file = st.tabs(
|
| 302 |
+
[
|
| 303 |
+
"Knowledge Bases",
|
| 304 |
+
"File Upload",
|
| 305 |
+
]
|
| 306 |
+
)
|
| 307 |
with tab_kb:
|
| 308 |
+
st.markdown("#### Build You Own Knowledge")
|
| 309 |
+
st.multiselect(
|
| 310 |
+
"Select Files to Build up",
|
| 311 |
+
st.session_state.user_files,
|
| 312 |
+
placeholder="You should upload files first",
|
| 313 |
+
key="b_tool_files",
|
| 314 |
+
format_func=lambda x: x["file_name"],
|
| 315 |
+
)
|
| 316 |
+
st.text_input("Tool Name", "get_relevant_documents", key="b_tool_name")
|
| 317 |
+
st.text_input(
|
| 318 |
+
"Tool Description",
|
| 319 |
+
"Searches among user's private files and returns related documents",
|
| 320 |
+
key="b_tool_desc",
|
| 321 |
+
)
|
| 322 |
+
st.button("Build!", on_click=build_kb_as_tool)
|
| 323 |
+
st.markdown("### Knowledge Base Selection")
|
| 324 |
+
if (
|
| 325 |
+
"user_tools" in st.session_state
|
| 326 |
+
and len(st.session_state.user_tools) > 0
|
| 327 |
+
):
|
| 328 |
+
st.markdown("***User Created Knowledge Bases***")
|
| 329 |
+
st.dataframe(st.session_state.user_tools)
|
| 330 |
st.multiselect(
|
| 331 |
"Select a Knowledge Base Tool",
|
| 332 |
+
st.session_state.tools.keys()
|
| 333 |
+
if "tools_with_users" not in st.session_state
|
| 334 |
+
else st.session_state.tools_with_users,
|
| 335 |
default=["Wikipedia + Self Querying"],
|
| 336 |
key="selected_tools",
|
| 337 |
on_change=refresh_agent,
|
| 338 |
)
|
| 339 |
+
st.markdown("### Delete Knowledge Base")
|
| 340 |
+
st.multiselect(
|
| 341 |
+
"Choose Knowledge Base to Remove",
|
| 342 |
+
st.session_state.user_tools,
|
| 343 |
+
format_func=lambda x: x["tool_name"],
|
| 344 |
+
key="r_tool_names",
|
| 345 |
+
)
|
| 346 |
+
st.button("Delete", on_click=remove_kb)
|
| 347 |
with tab_file:
|
| 348 |
+
st.info(
|
| 349 |
+
(
|
| 350 |
+
"We adopted [Unstructured API](https://unstructured.io/api-key) "
|
| 351 |
+
"here and we only store the processed texts from your documents. "
|
| 352 |
+
"For privacy concerns, please refer to "
|
| 353 |
+
"[our policy issue](https://myscale.com/privacy/)."
|
| 354 |
+
),
|
| 355 |
+
icon="π",
|
| 356 |
+
)
|
| 357 |
+
st.file_uploader(
|
| 358 |
+
"Upload files", key="uploaded_files", accept_multiple_files=True
|
| 359 |
+
)
|
| 360 |
st.markdown("### Uploaded Files")
|
| 361 |
+
st.dataframe(
|
| 362 |
+
st.session_state.private_kb.list_files(st.session_state.user_name),
|
| 363 |
+
use_container_width=True,
|
| 364 |
+
)
|
| 365 |
col_1, col_2 = st.columns(2)
|
| 366 |
with col_1:
|
| 367 |
st.button("Add Files", on_click=add_file)
|
| 368 |
with col_2:
|
| 369 |
+
st.button("Clear Files and All Tools", on_click=clear_files)
|
| 370 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
st.button("Clear Chat History", on_click=clear_history)
|
| 372 |
st.button("Logout", on_click=back_to_main)
|
| 373 |
+
if "agent" not in st.session_state:
|
| 374 |
refresh_agent()
|
| 375 |
print("!!! ", st.session_state.agent.memory.chat_memory.session_id)
|
| 376 |
for msg in st.session_state.agent.memory.chat_memory.messages:
|
|
|
|
| 383 |
st.write("Retrieved from knowledge base:")
|
| 384 |
try:
|
| 385 |
st.dataframe(
|
| 386 |
+
pd.DataFrame.from_records(
|
| 387 |
+
json.loads(msg.content, cls=CustomJSONDecoder)
|
| 388 |
+
),
|
| 389 |
+
use_container_width=True,
|
| 390 |
)
|
| 391 |
except:
|
| 392 |
st.write(msg.content)
|
helper.py → lib/helper.py
RENAMED
|
@@ -49,10 +49,12 @@ from langchain.memory import SQLChatMessageHistory
|
|
| 49 |
from langchain.memory.chat_message_histories.sql import \
|
| 50 |
BaseMessageConverter, DefaultMessageConverter
|
| 51 |
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
|
| 52 |
-
from langchain.agents.agent_toolkits import create_retriever_tool
|
| 53 |
from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
|
| 54 |
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
|
| 55 |
from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
|
|
|
|
|
|
|
| 56 |
environ['TOKENIZERS_PARALLELISM'] = 'true'
|
| 57 |
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
| 58 |
|
|
@@ -495,7 +497,7 @@ def create_retriever_tool(
|
|
| 495 |
def wrap(func):
|
| 496 |
def wrapped_retrieve(*args, **kwargs):
|
| 497 |
docs: List[Document] = func(*args, **kwargs)
|
| 498 |
-
return json.dumps([d.dict() for d in docs])
|
| 499 |
return wrapped_retrieve
|
| 500 |
|
| 501 |
return Tool(
|
|
@@ -533,12 +535,13 @@ def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temper
|
|
| 533 |
chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature,
|
| 534 |
openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY, streaming=True,
|
| 535 |
)
|
| 536 |
-
tools =
|
|
|
|
| 537 |
agent = create_agent_executor(
|
| 538 |
"chat_memory",
|
| 539 |
session_id,
|
| 540 |
chat_llm,
|
| 541 |
-
tools=
|
| 542 |
system_prompt=system_prompt
|
| 543 |
)
|
| 544 |
return agent
|
|
|
|
| 49 |
from langchain.memory.chat_message_histories.sql import \
|
| 50 |
BaseMessageConverter, DefaultMessageConverter
|
| 51 |
from langchain.schema.messages import BaseMessage, _message_to_dict, messages_from_dict
|
| 52 |
+
# from langchain.agents.agent_toolkits import create_retriever_tool
|
| 53 |
from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
|
| 54 |
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
|
| 55 |
from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
|
| 56 |
+
from .json_conv import CustomJSONEncoder
|
| 57 |
+
|
| 58 |
environ['TOKENIZERS_PARALLELISM'] = 'true'
|
| 59 |
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
| 60 |
|
|
|
|
| 497 |
def wrap(func):
|
| 498 |
def wrapped_retrieve(*args, **kwargs):
|
| 499 |
docs: List[Document] = func(*args, **kwargs)
|
| 500 |
+
return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)
|
| 501 |
return wrapped_retrieve
|
| 502 |
|
| 503 |
return Tool(
|
|
|
|
| 535 |
chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature,
|
| 536 |
openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY, streaming=True,
|
| 537 |
)
|
| 538 |
+
tools = st.session_state.tools if "tools_with_users" not in st.session_state else st.session_state.tools_with_users
|
| 539 |
+
sel_tools = [tools[k] for k in tool_names]
|
| 540 |
agent = create_agent_executor(
|
| 541 |
"chat_memory",
|
| 542 |
session_id,
|
| 543 |
chat_llm,
|
| 544 |
+
tools=sel_tools,
|
| 545 |
system_prompt=system_prompt
|
| 546 |
)
|
| 547 |
return agent
|
lib/json_conv.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import datetime
|
| 3 |
+
|
| 4 |
+
class CustomJSONEncoder(json.JSONEncoder):
|
| 5 |
+
def default(self, obj):
|
| 6 |
+
if isinstance(obj, datetime.datetime):
|
| 7 |
+
return datetime.datetime.isoformat(obj)
|
| 8 |
+
return json.JSONEncoder.default(self, obj)
|
| 9 |
+
|
| 10 |
+
class CustomJSONDecoder(json.JSONDecoder):
|
| 11 |
+
def __init__(self, *args, **kwargs):
|
| 12 |
+
json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
|
| 13 |
+
|
| 14 |
+
def object_hook(self, source):
|
| 15 |
+
for k, v in source.items():
|
| 16 |
+
if isinstance(v, str):
|
| 17 |
+
try:
|
| 18 |
+
source[k] = datetime.datetime.fromisoformat(str(v))
|
| 19 |
+
except:
|
| 20 |
+
pass
|
| 21 |
+
return source
|
lib/private_kb.py
CHANGED
|
@@ -1,18 +1,19 @@
|
|
| 1 |
import pandas as pd
|
| 2 |
import hashlib
|
| 3 |
import requests
|
| 4 |
-
from typing import List
|
| 5 |
from datetime import datetime
|
| 6 |
from langchain.schema.embeddings import Embeddings
|
| 7 |
from streamlit.runtime.uploaded_file_manager import UploadedFile
|
| 8 |
from clickhouse_connect import get_client
|
| 9 |
from multiprocessing.pool import ThreadPool
|
| 10 |
from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings
|
|
|
|
| 11 |
|
| 12 |
parser_url = "https://api.unstructured.io/general/v0/general"
|
| 13 |
|
| 14 |
|
| 15 |
-
def parse_files(api_key, user_id, files: List[UploadedFile]
|
| 16 |
def parse_file(file: UploadedFile):
|
| 17 |
headers = {
|
| 18 |
"accept": "application/json",
|
|
@@ -31,9 +32,10 @@ def parse_files(api_key, user_id, files: List[UploadedFile], collection="default
|
|
| 31 |
{
|
| 32 |
"text": t["text"],
|
| 33 |
"file_name": t["metadata"]["filename"],
|
| 34 |
-
"entity_id": hashlib.sha256(
|
|
|
|
|
|
|
| 35 |
"user_id": user_id,
|
| 36 |
-
"collection_id": collection,
|
| 37 |
"created_by": datetime.now(),
|
| 38 |
}
|
| 39 |
for t in json_response
|
|
@@ -43,7 +45,7 @@ def parse_files(api_key, user_id, files: List[UploadedFile], collection="default
|
|
| 43 |
|
| 44 |
with ThreadPool(8) as p:
|
| 45 |
rows = []
|
| 46 |
-
for r in
|
| 47 |
rows.extend(r)
|
| 48 |
return rows
|
| 49 |
|
|
@@ -68,21 +70,33 @@ class PrivateKnowledgeBase:
|
|
| 68 |
parser_api_key,
|
| 69 |
db="chat",
|
| 70 |
kb_table="private_kb",
|
|
|
|
| 71 |
) -> None:
|
| 72 |
super().__init__()
|
| 73 |
-
|
| 74 |
CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
|
| 75 |
entity_id String,
|
| 76 |
file_name String,
|
| 77 |
text String,
|
| 78 |
user_id String,
|
| 79 |
-
collection_id String,
|
| 80 |
created_by DateTime,
|
| 81 |
vector Array(Float32),
|
| 82 |
CONSTRAINT cons_vec_len CHECK length(vector) = 768,
|
| 83 |
VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
|
| 84 |
) ENGINE = ReplacingMergeTree ORDER BY entity_id
|
| 85 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
config = MyScaleSettings(
|
| 87 |
host=host,
|
| 88 |
port=port,
|
|
@@ -98,41 +112,101 @@ class PrivateKnowledgeBase:
|
|
| 98 |
password=config.password,
|
| 99 |
)
|
| 100 |
client.command("SET allow_experimental_object_type=1")
|
| 101 |
-
client.command(
|
|
|
|
| 102 |
self.parser_api_key = parser_api_key
|
| 103 |
self.vstore = MyScaleWithoutJSON(
|
| 104 |
embedding=embedding,
|
| 105 |
config=config,
|
| 106 |
-
must_have_cols=["file_name", "text", "
|
| 107 |
)
|
| 108 |
-
self.retriever = self.vstore.as_retriever()
|
| 109 |
|
| 110 |
-
def list_files(self, user_id):
|
| 111 |
query = f"""
|
| 112 |
-
SELECT DISTINCT file_name
|
| 113 |
-
|
|
|
|
|
|
|
| 114 |
"""
|
| 115 |
return [r for r in self.vstore.client.query(query).named_results()]
|
| 116 |
|
| 117 |
def add_by_file(
|
| 118 |
-
self, user_id, files: List[UploadedFile],
|
| 119 |
):
|
| 120 |
-
data = parse_files(self.parser_api_key, user_id, files
|
| 121 |
data = extract_embedding(self.vstore.embeddings, data)
|
| 122 |
self.vstore.client.insert_df(
|
| 123 |
-
self.
|
| 124 |
pd.DataFrame(data),
|
| 125 |
database=self.vstore.config.database,
|
| 126 |
)
|
| 127 |
|
| 128 |
def clear(self, user_id):
|
| 129 |
self.vstore.client.command(
|
| 130 |
-
f"DELETE FROM {self.vstore.config.database}.{self.
|
| 131 |
f"WHERE user_id='{user_id}'"
|
| 132 |
)
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
-
def
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
-
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import pandas as pd
|
| 2 |
import hashlib
|
| 3 |
import requests
|
| 4 |
+
from typing import List, Optional
|
| 5 |
from datetime import datetime
|
| 6 |
from langchain.schema.embeddings import Embeddings
|
| 7 |
from streamlit.runtime.uploaded_file_manager import UploadedFile
|
| 8 |
from clickhouse_connect import get_client
|
| 9 |
from multiprocessing.pool import ThreadPool
|
| 10 |
from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings
|
| 11 |
+
from .helper import create_retriever_tool
|
| 12 |
|
| 13 |
parser_url = "https://api.unstructured.io/general/v0/general"
|
| 14 |
|
| 15 |
|
| 16 |
+
def parse_files(api_key, user_id, files: List[UploadedFile]):
|
| 17 |
def parse_file(file: UploadedFile):
|
| 18 |
headers = {
|
| 19 |
"accept": "application/json",
|
|
|
|
| 32 |
{
|
| 33 |
"text": t["text"],
|
| 34 |
"file_name": t["metadata"]["filename"],
|
| 35 |
+
"entity_id": hashlib.sha256(
|
| 36 |
+
(file_hash + t["text"]).encode()
|
| 37 |
+
).hexdigest(),
|
| 38 |
"user_id": user_id,
|
|
|
|
| 39 |
"created_by": datetime.now(),
|
| 40 |
}
|
| 41 |
for t in json_response
|
|
|
|
| 45 |
|
| 46 |
with ThreadPool(8) as p:
|
| 47 |
rows = []
|
| 48 |
+
for r in p.imap_unordered(parse_file, files):
|
| 49 |
rows.extend(r)
|
| 50 |
return rows
|
| 51 |
|
|
|
|
| 70 |
parser_api_key,
|
| 71 |
db="chat",
|
| 72 |
kb_table="private_kb",
|
| 73 |
+
tool_table="private_tool",
|
| 74 |
) -> None:
|
| 75 |
super().__init__()
|
| 76 |
+
kb_schema_ = f"""
|
| 77 |
CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
|
| 78 |
entity_id String,
|
| 79 |
file_name String,
|
| 80 |
text String,
|
| 81 |
user_id String,
|
|
|
|
| 82 |
created_by DateTime,
|
| 83 |
vector Array(Float32),
|
| 84 |
CONSTRAINT cons_vec_len CHECK length(vector) = 768,
|
| 85 |
VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
|
| 86 |
) ENGINE = ReplacingMergeTree ORDER BY entity_id
|
| 87 |
"""
|
| 88 |
+
tool_schema_ = f"""
|
| 89 |
+
CREATE TABLE IF NOT EXISTS {db}.{tool_table}(
|
| 90 |
+
tool_id String,
|
| 91 |
+
tool_name String,
|
| 92 |
+
file_names Array(String),
|
| 93 |
+
user_id String,
|
| 94 |
+
created_by DateTime,
|
| 95 |
+
tool_description String
|
| 96 |
+
) ENGINE = ReplacingMergeTree ORDER BY tool_id
|
| 97 |
+
"""
|
| 98 |
+
self.kb_table = kb_table
|
| 99 |
+
self.tool_table = tool_table
|
| 100 |
config = MyScaleSettings(
|
| 101 |
host=host,
|
| 102 |
port=port,
|
|
|
|
| 112 |
password=config.password,
|
| 113 |
)
|
| 114 |
client.command("SET allow_experimental_object_type=1")
|
| 115 |
+
client.command(kb_schema_)
|
| 116 |
+
client.command(tool_schema_)
|
| 117 |
self.parser_api_key = parser_api_key
|
| 118 |
self.vstore = MyScaleWithoutJSON(
|
| 119 |
embedding=embedding,
|
| 120 |
config=config,
|
| 121 |
+
must_have_cols=["file_name", "text", "created_by"],
|
| 122 |
)
|
|
|
|
| 123 |
|
| 124 |
+
def list_files(self, user_id, tool_name=None):
    """List the files a user has stored in the knowledge-base table.

    Args:
        user_id: Owner whose rows are aggregated.
        tool_name: Accepted for interface compatibility; the query does not
            use it.

    Returns:
        A list with one entry per row yielded by ``named_results()``. Each row
        carries ``file_name``, ``num_paragraph`` (count of ``entity_id`` for
        that file) and ``max_chars`` (largest ``length(text)`` for that file).
    """
    # Table identifiers cannot be bound as query parameters, so they are still
    # interpolated from instance configuration. The user-supplied value is
    # bound by the client instead of being pasted into the SQL text.
    query = f"""
    SELECT DISTINCT file_name, COUNT(entity_id) AS num_paragraph,
        arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars
    FROM {self.vstore.config.database}.{self.kb_table}
    WHERE user_id = %(user_id)s GROUP BY file_name
    """
    result = self.vstore.client.query(query, parameters={"user_id": user_id})
    return list(result.named_results())
|
| 132 |
|
| 133 |
def add_by_file(
    self, user_id, files: List[UploadedFile], **kwargs
):
    """Parse uploaded files, embed the parsed rows and store them.

    The files are sent through ``parse_files`` with this instance's parser API
    key, the result is passed to ``extract_embedding`` together with the vector
    store's embedding model, and the outcome is inserted into the
    knowledge-base table as a DataFrame. ``kwargs`` is accepted but unused.
    """
    parsed_rows = parse_files(self.parser_api_key, user_id, files)
    embedded_rows = extract_embedding(self.vstore.embeddings, parsed_rows)
    frame = pd.DataFrame(embedded_rows)
    target_db = self.vstore.config.database
    self.vstore.client.insert_df(self.kb_table, frame, database=target_db)
|
| 143 |
|
| 144 |
def clear(self, user_id):
    """Delete all knowledge-base rows and all tool rows owned by ``user_id``.

    The knowledge-base table is cleared first, then the tool table.
    """
    database = self.vstore.config.database
    # Bind the user id through the client rather than formatting it into the
    # statement, so quotes in the value cannot alter the DELETE.
    params = {"user_id": user_id}
    for table in (self.kb_table, self.tool_table):
        self.vstore.client.command(
            f"DELETE FROM {database}.{table} WHERE user_id = %(user_id)s",
            parameters=params,
        )
|
| 152 |
|
| 153 |
+
def create_tool(
    self, user_id, tool_name, tool_description, files: Optional[List[str]] = None
):
    """Insert one tool definition row into the tool table.

    Args:
        user_id: Owner of the tool.
        tool_name: Name of the tool; together with ``user_id`` it determines
            the ``tool_id`` (SHA-256 of their concatenation).
        tool_description: Free-text description stored with the tool.
        files: File names the tool covers. ``None`` is stored as an empty
            list, because the column is declared ``Array(String)``.
    """
    tool_id = hashlib.sha256((user_id + tool_name).encode("utf-8")).hexdigest()
    record = {
        "tool_id": tool_id,
        "tool_name": tool_name,
        # Normalise the default so the array column never receives None.
        "file_names": list(files) if files is not None else [],
        "user_id": user_id,
        "created_by": datetime.now(),
        "tool_description": tool_description,
    }
    self.vstore.client.insert_df(
        self.tool_table,
        pd.DataFrame([record]),
        database=self.vstore.config.database,
    )
|
| 174 |
|
| 175 |
+
def list_tools(self, user_id, tool_name=None):
    """Return the tool rows owned by ``user_id``.

    Args:
        user_id: Owner whose tools are listed.
        tool_name: When truthy, restrict the result to this tool name.

    Returns:
        A list of the rows yielded by ``named_results()``; each row holds
        ``tool_name``, ``tool_description`` and ``length(file_names)``.
    """
    # Values are bound by the client; only the configured table identifiers
    # are formatted into the SQL text.
    params = {"user_id": user_id}
    extended_where = ""
    if tool_name:
        extended_where = "AND tool_name = %(tool_name)s"
        params["tool_name"] = tool_name
    query = f"""
    SELECT tool_name, tool_description, length(file_names)
    FROM {self.vstore.config.database}.{self.tool_table}
    WHERE user_id = %(user_id)s {extended_where}
    """
    result = self.vstore.client.query(query, parameters=params)
    return list(result.named_results())
|
| 183 |
+
|
| 184 |
+
def remove_tools(self, user_id, tool_names):
    """Delete the named tools belonging to ``user_id`` from the tool table.

    Args:
        user_id: Owner of the tools.
        tool_names: Iterable of tool names to delete. When it is empty no
            statement is issued, since there is nothing to remove.
    """
    names = list(tool_names)
    if not names:
        return
    # A list parameter is rendered by the client as an escaped array literal,
    # which keeps the original ``IN [...]`` form without manual quoting.
    query = f"""DELETE FROM {self.vstore.config.database}.{self.tool_table}
    WHERE user_id = %(user_id)s AND tool_name IN %(tool_names)s"""
    self.vstore.client.command(
        query, parameters={"user_id": user_id, "tool_names": names}
    )
|
| 189 |
+
|
| 190 |
+
def as_tools(self, user_id, tool_name=None):
    """Build one retriever tool per stored tool definition of ``user_id``.

    Args:
        user_id: Owner whose tools are turned into retriever tools.
        tool_name: Optional tool name forwarded to ``list_tools`` as a filter.

    Returns:
        A dict mapping each tool name to the object returned by
        ``create_retriever_tool`` for a retriever limited to that user's rows
        whose ``file_name`` is listed in the tool's ``file_names``.
    """
    database = self.vstore.config.database
    retrievers = {}
    for entry in self.list_tools(user_id=user_id, tool_name=tool_name):
        name = entry["tool_name"]
        # NOTE(review): user_id and the tool name are formatted directly into
        # the filter because it is handed over as a raw ``where_str``; confirm
        # both values are trusted or escaped upstream.
        where_str = (
            f"user_id='{user_id}' "
            f"""AND file_name IN (
                SELECT arrayJoin(file_names) FROM (
                SELECT file_names
                FROM {database}.{self.tool_table}
                WHERE user_id = '{user_id}' AND tool_name = '{name}')
            )"""
        )
        retriever = self.vstore.as_retriever(
            search_kwargs={"where_str": where_str},
        )
        retrievers[name] = create_retriever_tool(
            retriever,
            name=name,
            description=entry["tool_description"],
        )
    return retrievers
|
lib/sessions.py
CHANGED
|
@@ -8,7 +8,6 @@ from datetime import datetime
|
|
| 8 |
from sqlalchemy import Column, Text, orm, create_engine
|
| 9 |
from clickhouse_sqlalchemy import types, engines
|
| 10 |
from .schemas import create_message_model, create_session_table
|
| 11 |
-
from .private_kb import PrivateKnowledgeBase
|
| 12 |
|
| 13 |
def get_sessions(engine, model_class, user_id):
|
| 14 |
with orm.sessionmaker(engine)() as session:
|
|
|
|
| 8 |
from sqlalchemy import Column, Text, orm, create_engine
|
| 9 |
from clickhouse_sqlalchemy import types, engines
|
| 10 |
from .schemas import create_message_model, create_session_table
|
|
|
|
| 11 |
|
| 12 |
def get_sessions(engine, model_class, user_id):
|
| 13 |
with orm.sessionmaker(engine)() as session:
|