Spaces:
Runtime error
Runtime error
Fangrui Liu
committed on
Commit
·
042a946
1
Parent(s):
e1383d0
update session model
Browse files- app.py +10 -1
- chat.py +158 -14
- helper.py +23 -37
- lib/schemas.py +52 -0
- lib/sessions.py +68 -0
app.py
CHANGED
|
@@ -10,13 +10,22 @@ from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
|
|
| 10 |
|
| 11 |
from chat import chat_page
|
| 12 |
from login import login, back_to_main
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
-
from helper import build_tools, build_agents, build_all, sel_map, display
|
| 16 |
|
| 17 |
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
| 18 |
|
| 19 |
st.set_page_config(page_title="ChatData", page_icon="https://myscale.com/favicon.ico")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
st.header("ChatData")
|
| 21 |
|
| 22 |
if 'retriever' not in st.session_state:
|
|
|
|
| 10 |
|
| 11 |
from chat import chat_page
|
| 12 |
from login import login, back_to_main
|
| 13 |
+
from helper import build_tools, build_agents, build_all, sel_map, display
|
| 14 |
|
| 15 |
|
|
|
|
| 16 |
|
| 17 |
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
| 18 |
|
| 19 |
st.set_page_config(page_title="ChatData", page_icon="https://myscale.com/favicon.ico")
|
| 20 |
+
st.markdown(
|
| 21 |
+
f"""
|
| 22 |
+
<style>
|
| 23 |
+
.st-e4 {{
|
| 24 |
+
max-width: 500px
|
| 25 |
+
}}
|
| 26 |
+
</style>""",
|
| 27 |
+
unsafe_allow_html=True,
|
| 28 |
+
)
|
| 29 |
st.header("ChatData")
|
| 30 |
|
| 31 |
if 'retriever' not in st.session_state:
|
chat.py
CHANGED
|
@@ -1,20 +1,37 @@
|
|
| 1 |
import pandas as pd
|
| 2 |
from os import environ
|
|
|
|
| 3 |
import datetime
|
| 4 |
import streamlit as st
|
|
|
|
| 5 |
from langchain.schema import HumanMessage, FunctionMessage
|
| 6 |
|
| 7 |
-
from helper import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
from login import back_to_main
|
| 9 |
|
| 10 |
-
environ[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
def on_chat_submit():
|
| 13 |
-
ret = st.session_state.
|
| 14 |
print(ret)
|
| 15 |
-
|
|
|
|
| 16 |
def clear_history():
|
| 17 |
-
st.session_state
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
def back_to_main():
|
|
@@ -25,29 +42,156 @@ def back_to_main():
|
|
| 25 |
if "jump_query_ask" in st.session_state:
|
| 26 |
del st.session_state.jump_query_ask
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
def chat_page():
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
with st.sidebar:
|
| 31 |
-
st.
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
st.button("Clear Chat History", on_click=clear_history)
|
| 34 |
st.button("Logout", on_click=back_to_main)
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
| 36 |
speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
|
| 37 |
if isinstance(msg, FunctionMessage):
|
| 38 |
with st.chat_message("Knowledge Base", avatar="π"):
|
| 39 |
-
|
| 40 |
-
|
|
|
|
| 41 |
st.write("Retrieved from knowledge base:")
|
| 42 |
try:
|
| 43 |
-
st.dataframe(
|
|
|
|
|
|
|
| 44 |
except:
|
| 45 |
st.write(msg.content)
|
| 46 |
else:
|
| 47 |
if len(msg.content) > 0:
|
| 48 |
with st.chat_message(speaker):
|
| 49 |
print(type(msg), msg.dict())
|
| 50 |
-
st.write(
|
|
|
|
|
|
|
| 51 |
st.write(f"{msg.content}")
|
| 52 |
st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
|
| 53 |
-
|
|
|
|
| 1 |
import pandas as pd
|
| 2 |
from os import environ
|
| 3 |
+
from time import sleep
|
| 4 |
import datetime
|
| 5 |
import streamlit as st
|
| 6 |
+
from lib.sessions import SessionManager
|
| 7 |
from langchain.schema import HumanMessage, FunctionMessage
|
| 8 |
|
| 9 |
+
from helper import (
|
| 10 |
+
build_agents,
|
| 11 |
+
MYSCALE_HOST,
|
| 12 |
+
MYSCALE_PASSWORD,
|
| 13 |
+
MYSCALE_PORT,
|
| 14 |
+
MYSCALE_USER,
|
| 15 |
+
DEFAULT_SYSTEM_PROMPT,
|
| 16 |
+
)
|
| 17 |
from login import back_to_main
|
| 18 |
|
| 19 |
+
# Copy the OpenAI endpoint from the Streamlit secrets into the process
# environment so libraries that read OPENAI_API_BASE can see it.
environ["OPENAI_API_BASE"] = st.secrets["OPENAI_API_BASE"]

# Human-readable labels keyed by internal retriever tool identifier.
TOOL_NAMES = dict(
    langchain_retriever_tool="Self-querying retriever",
    vecsql_retriever_tool="Vector SQL",
)
|
| 25 |
+
|
| 26 |
|
| 27 |
def on_chat_submit():
    """Send the text held by the ``chat_input`` widget to the session agent.

    The agent's return value is only printed to stdout.
    """
    agent = st.session_state.agent
    payload = {"input": st.session_state.chat_input}
    print(agent(payload))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
def clear_history():
    """Clear the memory of the current agent, if an agent has been built."""
    if "agent" not in st.session_state:
        return
    st.session_state.agent.memory.clear()
|
| 35 |
|
| 36 |
|
| 37 |
def back_to_main():
|
|
|
|
| 42 |
if "jump_query_ask" in st.session_state:
|
| 43 |
del st.session_state.jump_query_ask
|
| 44 |
|
| 45 |
+
|
| 46 |
+
def on_session_change_submit():
    """Apply the pending edits of the ``session_editor`` data editor.

    Added rows become new sessions, stored under the key
    ``"<user_name>?<session_id>"``. Deleted rows are removed through the
    session manager. After a deletion the selection falls back to the
    session named ``default`` (or the first one). Any failure is reported
    with ``st.error`` instead of propagating, the editor's pending changes
    are always reset, and the agent is rebuilt at the end.

    Does nothing when ``session_manager`` or ``session_editor`` is missing
    from the Streamlit session state.
    """
    state = st.session_state
    if "session_manager" not in state or "session_editor" not in state:
        return

    print(state.session_editor)
    try:
        for elem in state.session_editor["added_rows"]:
            # Validation failures are ValueErrors: str(KeyError(msg)) would
            # wrap the message in quotes when shown to the user below.
            if "system_prompt" not in elem or "session_id" not in elem:
                raise ValueError(
                    "You should fill both `session_id` and `system_prompt` to add a row!"
                )
            # "?" separates the user name from the session name in the
            # stored key, so it cannot appear in the session name itself.
            if elem["session_id"] == "" or "?" in elem["session_id"]:
                raise ValueError(
                    "`session_id` must be non-empty and must not contain question marks."
                )
            state.session_manager.add_session(
                user_id=state.user_name,
                session_id=f"{state.user_name}?{elem['session_id']}",
                system_prompt=elem["system_prompt"],
            )

        # Deleted rows are positions in the list that was shown in the editor.
        for elem in state.session_editor["deleted_rows"]:
            state.session_manager.remove_session(
                session_id=f"{state.user_name}?{state.current_sessions[elem]['session_id']}",
            )

        refresh_sessions()
        if len(state.session_editor["deleted_rows"]) > 0:
            # The selected session may just have been deleted; reselect.
            try:
                dfl_indx = [
                    x["session_id"] for x in state.current_sessions
                ].index("default")
            except ValueError:
                dfl_indx = 0
            state.sel_sess = state.current_sessions[dfl_indx]
    except Exception as e:
        sleep(2)
        st.error(f"{type(e)}: {str(e)}")
    finally:
        state.session_editor["added_rows"] = []
        state.session_editor["deleted_rows"] = []
    refresh_agent()
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def build_session_manager():
    """Create a ``SessionManager`` from the MyScale connection constants."""
    connection = {
        "host": MYSCALE_HOST,
        "port": MYSCALE_PORT,
        "username": MYSCALE_USER,
        "password": MYSCALE_PASSWORD,
    }
    return SessionManager(**connection)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def refresh_sessions():
    """Reload the user's sessions into ``st.session_state["current_sessions"]``.

    When the reloaded value is not a dict and is empty, a session named
    ``default`` with ``DEFAULT_SYSTEM_PROMPT`` is created for the user and
    the list is loaded again.
    """
    state = st.session_state
    state["current_sessions"] = state.session_manager.list_sessions(state.user_name)

    nothing_to_create = (
        type(state.current_sessions) is dict or len(state.current_sessions) > 0
    )
    if nothing_to_create:
        return

    state.session_manager.add_session(
        state.user_name,
        f"{state.user_name}?default",
        DEFAULT_SYSTEM_PROMPT,
    )
    state["current_sessions"] = state.session_manager.list_sessions(state.user_name)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def refresh_agent():
    """Rebuild the agent and the session manager for the selected session.

    The agent is stored in ``st.session_state["agent"]`` and is keyed by
    ``"<user_name>?<session_id>"``. It uses the tools chosen in
    ``selected_tools``, defaulting to the Wikipedia self-query retriever
    when none were chosen yet. A fresh session manager is stored in
    ``st.session_state["session_manager"]``.
    """
    state = st.session_state

    # Resolve the selection once. Previously only the system prompt was
    # guarded against a missing ``sel_sess`` while the session key read it
    # unconditionally, so the guard could never take effect.
    sel_sess = state["sel_sess"] if "sel_sess" in state else None
    if not sel_sess:
        sel_sess = {
            "session_id": "default",
            "system_prompt": DEFAULT_SYSTEM_PROMPT,
        }

    session_key = f"{state.user_name}?{sel_sess['session_id']}"
    if "selected_tools" in state:
        tool_names = state.selected_tools
    else:
        tool_names = ["LangChain Self Query Retriever For Wikipedia"]

    with st.spinner("Initializing session..."):
        print("??? Changed to ", session_key)
        state["agent"] = build_agents(
            session_key,
            tool_names,
            system_prompt=sel_sess["system_prompt"],
        )
        state["session_manager"] = build_session_manager()
|
| 128 |
+
|
| 129 |
+
|
| 130 |
def chat_page():
    """Render the chat page.

    The sidebar holds the session editor, the session selector, the tool
    selector and the "Clear Chat History" / "Logout" buttons. The main area
    replays the messages held in the agent's chat memory and ends with the
    chat input box. The agent is built on first use.
    """
    if "sel_sess" not in st.session_state:
        st.session_state["sel_sess"] = {
            "session_id": "default",
            "system_prompt": DEFAULT_SYSTEM_PROMPT,
        }

    with st.sidebar:
        with st.expander("Session Management"):
            refresh_sessions()
            st.data_editor(
                st.session_state.current_sessions,
                num_rows="dynamic",
                key="session_editor",
                use_container_width=True,
            )
            st.button("Submit Change!", on_click=on_session_change_submit)
        with st.expander("Session Selection", expanded=True):
            # Preselect the session named "default", else the first one.
            try:
                dfl_indx = [
                    x["session_id"] for x in st.session_state.current_sessions
                ].index("default")
            except ValueError:
                dfl_indx = 0
            st.selectbox(
                "Choose a session be chat:",
                options=st.session_state.current_sessions,
                index=dfl_indx,
                key="sel_sess",
                format_func=lambda x: x["session_id"],
                on_change=refresh_agent,
            )
            print(st.session_state.sel_sess)
        with st.expander("Tool Settings", expanded=True):
            st.multiselect(
                "Knowledge Base",
                st.session_state.tools.keys(),
                default=["LangChain Self Query Retriever For Wikipedia"],
                key="selected_tools",
                on_change=refresh_agent,
            )
        st.button("Clear Chat History", on_click=clear_history)
        st.button("Logout", on_click=back_to_main)

    if "agent" not in st.session_state:
        refresh_agent()
    print("!!! ", st.session_state.agent.memory.chat_memory.session_id)

    for msg in st.session_state.agent.memory.chat_memory.messages:
        speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
        if isinstance(msg, FunctionMessage):
            # NOTE(review): the avatar arrived as the mojibake "π" (first byte
            # of a 4-byte emoji); restored as a book emoji - confirm against
            # the repository.
            with st.chat_message("Knowledge Base", avatar="📚"):
                stamp = _format_timestamp(msg)
                if stamp is not None:
                    st.write(stamp)
                st.write("Retrieved from knowledge base:")
                try:
                    # SECURITY(review): eval() executes whatever is stored in
                    # the message content. If retriever output can contain
                    # untrusted text, switch to a safe parser such as
                    # ast.literal_eval or json.loads.
                    st.dataframe(
                        pd.DataFrame.from_records(map(dict, eval(msg.content)))
                    )
                except Exception:
                    # Content that is not a list of records is shown verbatim.
                    # Catching Exception rather than using a bare except lets
                    # BaseException-derived control-flow exceptions propagate.
                    st.write(msg.content)
        else:
            if len(msg.content) > 0:
                with st.chat_message(speaker):
                    print(type(msg), msg.dict())
                    stamp = _format_timestamp(msg)
                    if stamp is not None:
                        st.write(stamp)
                    st.write(f"{msg.content}")
    st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")


def _format_timestamp(msg):
    """Return the message timestamp as italic markdown, or None if absent.

    The value under ``additional_kwargs["timestamp"]`` is passed to
    ``datetime.datetime.fromtimestamp`` and rendered in ISO format.
    """
    stamp = msg.additional_kwargs.get("timestamp")
    if stamp is None:
        return None
    return f"*{datetime.datetime.fromtimestamp(stamp).isoformat()}*"
|
|
|
helper.py
CHANGED
|
@@ -68,6 +68,12 @@ MYSCALE_PORT = st.secrets['MYSCALE_PORT']
|
|
| 68 |
COMBINE_PROMPT = ChatPromptTemplate.from_strings(
|
| 69 |
string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
|
| 70 |
(HumanMessagePromptTemplate, '{question}')])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
def hint_arxiv():
|
| 73 |
st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
|
|
@@ -415,7 +421,7 @@ class DefaultClickhouseMessageConverter(DefaultMessageConverter):
|
|
| 415 |
return self.model_class
|
| 416 |
|
| 417 |
|
| 418 |
-
def create_agent_executor(name, session_id, llm, tools, **kwargs):
|
| 419 |
name = name.replace(" ", "_")
|
| 420 |
conn_str = f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}'
|
| 421 |
chat_memory = SQLChatMessageHistory(
|
|
@@ -425,12 +431,7 @@ def create_agent_executor(name, session_id, llm, tools, **kwargs):
|
|
| 425 |
memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
|
| 426 |
|
| 427 |
_system_message = SystemMessage(
|
| 428 |
-
content=
|
| 429 |
-
"Do your best to answer the questions. "
|
| 430 |
-
"Feel free to use any tools available to look up "
|
| 431 |
-
"relevant information. Please keep all details in query "
|
| 432 |
-
"when calling search functions."
|
| 433 |
-
)
|
| 434 |
)
|
| 435 |
prompt = OpenAIFunctionsAgent.create_prompt(
|
| 436 |
system_message=_system_message,
|
|
@@ -463,38 +464,23 @@ def build_tools():
|
|
| 463 |
st.session_state["sel_map_obj"][k] = {}
|
| 464 |
if "langchain_retriever" not in st.session_state.sel_map_obj[k] or "vecsql_retriever" not in st.session_state.sel_map_obj[k]:
|
| 465 |
st.session_state.sel_map_obj[k].update(build_chains_retrievers(k))
|
| 466 |
-
sel_map_obj
|
| 467 |
-
"
|
| 468 |
-
"
|
| 469 |
-
}
|
| 470 |
return sel_map_obj
|
| 471 |
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
tools = []
|
| 484 |
-
else:
|
| 485 |
-
tools = [st.session_state.tools[k][m]]
|
| 486 |
-
if k not in agents:
|
| 487 |
-
agents[k] = {}
|
| 488 |
-
agents[k][n] = create_agent_executor(
|
| 489 |
-
"chat_memory",
|
| 490 |
-
session_id,
|
| 491 |
-
chat_llm,
|
| 492 |
-
tools=tools,
|
| 493 |
-
)
|
| 494 |
-
cnt += 1/6
|
| 495 |
-
p.progress(cnt, f"Building with Knowledge Base {k} via Retriever {n}...")
|
| 496 |
-
p.empty()
|
| 497 |
-
return agents
|
| 498 |
|
| 499 |
|
| 500 |
def display(dataframe, columns_=None, index=None):
|
|
|
|
| 68 |
COMBINE_PROMPT = ChatPromptTemplate.from_strings(
|
| 69 |
string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
|
| 70 |
(HumanMessagePromptTemplate, '{question}')])
|
| 71 |
+
# System prompt used when a session does not supply its own.
DEFAULT_SYSTEM_PROMPT = " ".join(
    [
        "Do your best to answer the questions.",
        "Feel free to use any tools available to look up relevant information.",
        "Please keep all details in query when calling search functions.",
    ]
)
|
| 77 |
|
| 78 |
def hint_arxiv():
|
| 79 |
st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
|
|
|
|
| 421 |
return self.model_class
|
| 422 |
|
| 423 |
|
| 424 |
+
def create_agent_executor(name, session_id, llm, tools, system_prompt, **kwargs):
|
| 425 |
name = name.replace(" ", "_")
|
| 426 |
conn_str = f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}'
|
| 427 |
chat_memory = SQLChatMessageHistory(
|
|
|
|
| 431 |
memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
|
| 432 |
|
| 433 |
_system_message = SystemMessage(
|
| 434 |
+
content=system_prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 435 |
)
|
| 436 |
prompt = OpenAIFunctionsAgent.create_prompt(
|
| 437 |
system_message=_system_message,
|
|
|
|
| 464 |
st.session_state["sel_map_obj"][k] = {}
|
| 465 |
if "langchain_retriever" not in st.session_state.sel_map_obj[k] or "vecsql_retriever" not in st.session_state.sel_map_obj[k]:
|
| 466 |
st.session_state.sel_map_obj[k].update(build_chains_retrievers(k))
|
| 467 |
+
sel_map_obj.update({
|
| 468 |
+
f"LangChain Self Query Retriever For {k}": create_retriever_tool(st.session_state.sel_map_obj[k]["retriever"], *sel_map[k]["tool_desc"],),
|
| 469 |
+
f"Vector SQL Retriever For {k}": create_retriever_tool(st.session_state.sel_map_obj[k]["sql_retriever"], *sel_map[k]["tool_desc"],),
|
| 470 |
+
})
|
| 471 |
return sel_map_obj
|
| 472 |
|
| 473 |
+
def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temperature=0.6, system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Build one agent executor for a chat session.

    Args:
        session_id: Identifier passed through to ``create_agent_executor``.
        tool_names: Keys into ``st.session_state.tools`` selecting the tools.
        chat_model_name: Model name handed to ``ChatOpenAI``.
        temperature: Sampling temperature handed to ``ChatOpenAI``.
        system_prompt: System prompt for the agent.

    Returns:
        Whatever ``create_agent_executor`` returns.
    """
    llm = ChatOpenAI(
        model_name=chat_model_name,
        temperature=temperature,
        openai_api_base=OPENAI_API_BASE,
        openai_api_key=OPENAI_API_KEY,
    )
    chosen_tools = []
    for tool_name in tool_names:
        chosen_tools.append(st.session_state.tools[tool_name])
    return create_agent_executor(
        "chat_memory",
        session_id,
        llm,
        tools=chosen_tools,
        system_prompt=system_prompt,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 484 |
|
| 485 |
|
| 486 |
def display(dataframe, columns_=None, index=None):
|
lib/schemas.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sqlalchemy import Column, Text
|
| 2 |
+
from clickhouse_sqlalchemy import types, engines
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def create_message_model(table_name, DynamicBase):  # type: ignore
    """Build the declarative chat-message model bound to ``table_name``.

    Args:
        table_name: Name of the table the model maps to.
        DynamicBase: Declarative base class the model derives from.

    Returns:
        The generated ``Message`` model class.
    """
    # Table engine and options are prepared up front; the class body below
    # only declares columns, in their original order.
    merge_tree = engines.ReplacingMergeTree(
        partition_by='session_id',
        order_by=('id', 'msg_id'),
    )
    table_options = {'comment': 'Store Chat History'}

    # The model is declared inside the function so the table name can vary.
    class Message(DynamicBase):
        __tablename__ = table_name
        id = Column(types.Float64)
        session_id = Column(Text)
        user_id = Column(Text)
        msg_id = Column(Text, primary_key=True)
        type = Column(Text)
        addtionals = Column(Text)
        message = Column(Text)
        __table_args__ = (merge_tree, table_options)

    return Message
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def create_session_table(table_name, DynamicBase):  # type: ignore
    """Build the declarative session model bound to ``table_name``.

    Args:
        table_name: Name of the table the model maps to.
        DynamicBase: Declarative base class the model derives from.

    Returns:
        The generated ``Session`` model class.
    """
    # ``order_by`` is the plain string 'session_id'; the original wrapped it
    # in parentheses without a trailing comma, which is the same value.
    merge_tree = engines.ReplacingMergeTree(order_by='session_id')
    table_options = {'comment': 'Store Session and Prompts'}

    # The model is declared inside the function so the table name can vary.
    class Session(DynamicBase):
        __tablename__ = table_name
        user_id = Column(Text)
        session_id = Column(Text, primary_key=True)
        system_prompt = Column(Text)
        create_by = Column(types.DateTime)
        additionals = Column(Text)
        __table_args__ = (merge_tree, table_options)

    return Session
|
lib/sessions.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
try:
|
| 3 |
+
from sqlalchemy.orm import declarative_base
|
| 4 |
+
except ImportError:
|
| 5 |
+
from sqlalchemy.ext.declarative import declarative_base
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from sqlalchemy import Column, Text, orm, create_engine
|
| 8 |
+
from clickhouse_sqlalchemy import types, engines
|
| 9 |
+
from .schemas import create_message_model, create_session_table
|
| 10 |
+
|
| 11 |
+
def get_sessions(engine, model_class, user_id):
    """Return the sessions stored for ``user_id``, newest first.

    Args:
        engine: SQLAlchemy engine used to open a short-lived ORM session.
        model_class: Session model exposing ``user_id``, ``session_id``,
            ``system_prompt`` and ``create_by`` columns.
        user_id: Owner whose sessions are listed.

    Returns:
        A list of ``{"session_id", "system_prompt"}`` dicts, where
        ``session_id`` is the part after the last ``"?"`` of the stored id.
    """
    with orm.sessionmaker(engine)() as session:
        rows = (
            session.query(model_class)
            # NOTE(review): the original compared ``session_id`` with the user
            # id; filtering on ``user_id`` is assumed to be the intent.
            .where(model_class.user_id == user_id)
            .order_by(model_class.create_by.desc())
        )
        # The original passed the query object to json.loads(), which raises
        # TypeError for anything but str/bytes; build plain dicts instead.
        return [
            {
                "session_id": row.session_id.split("?")[-1],
                "system_prompt": row.system_prompt,
            }
            for row in rows
        ]
|
| 21 |
+
|
| 22 |
+
class SessionManager:
    """Manage chat sessions and their messages in a ClickHouse database.

    On construction the session table and the message table are created if
    they do not exist yet. Stored session ids carry a prefix separated by
    ``"?"``; ``list_sessions`` returns only the part after the last ``"?"``.
    """

    def __init__(self, host, port, username, password, db='chat', sess_table='sessions', msg_table='chat_memory') -> None:
        """Connect to the database and make sure both tables exist.

        Args:
            host: Database host name.
            port: Database port.
            username: Database user.
            password: Database password.
            db: Database name.
            sess_table: Name of the table holding sessions.
            msg_table: Name of the table holding chat messages.
        """
        conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https'
        self.engine = create_engine(conn_str, echo=False)
        # Each model gets its own declarative base, hence its own metadata.
        self.sess_model_class = create_session_table(sess_table, declarative_base())
        self.sess_model_class.metadata.create_all(self.engine)
        self.msg_model_class = create_message_model(msg_table, declarative_base())
        self.msg_model_class.metadata.create_all(self.engine)
        self.Session = orm.sessionmaker(self.engine)

    def list_sessions(self, user_id):
        """Return the sessions of ``user_id``, newest first.

        Returns:
            A list of ``{"session_id", "system_prompt"}`` dicts, where
            ``session_id`` is the part after the last ``"?"`` of the stored id.
        """
        with self.Session() as session:
            result = (
                session.query(self.sess_model_class)
                .where(self.sess_model_class.user_id == user_id)
                .order_by(self.sess_model_class.create_by.desc())
            )
            return [
                {
                    "session_id": r.session_id.split("?")[-1],
                    "system_prompt": r.system_prompt,
                }
                for r in result
            ]

    def modify_system_prompt(self, session_id, sys_prompt):
        """Set the system prompt of the session stored under ``session_id``.

        Args:
            session_id: Full stored session id (including its prefix).
            sys_prompt: New system prompt text.
        """
        with self.Session() as session:
            # The original called ``session.update(...)``, which ORM sessions
            # do not provide, compared the model class itself with the id and
            # used ``.value``. A bulk update through the query is the
            # working equivalent.
            # NOTE(review): whether the ClickHouse dialect in use can compile
            # UPDATE statements should be confirmed.
            session.query(self.sess_model_class).where(
                self.sess_model_class.session_id == session_id
            ).update({"system_prompt": sys_prompt}, synchronize_session=False)
            session.commit()

    def add_session(self, user_id, session_id, system_prompt, **kwargs):
        """Insert a new session row.

        Args:
            user_id: Owner of the session.
            session_id: Full session id to store.
            system_prompt: System prompt for the session.
            **kwargs: Extra data, stored JSON-encoded in ``additionals``.
        """
        with self.Session() as session:
            elem = self.sess_model_class(
                user_id=user_id, session_id=session_id, system_prompt=system_prompt,
                create_by=datetime.now(), additionals=json.dumps(kwargs)
            )
            session.add(elem)
            session.commit()

    def remove_session(self, session_id):
        """Delete a session and the messages stored under the same id.

        Args:
            session_id: Full stored session id (including its prefix).
        """
        with self.Session() as session:
            session.query(self.sess_model_class).where(self.sess_model_class.session_id == session_id).delete()
            session.query(self.msg_model_class).where(self.msg_model_class.session_id == session_id).delete()
            # Commit explicitly, as add_session does, so the deletions do not
            # depend on the driver applying statements without a commit.
            session.commit()
|
| 67 |
+
|
| 68 |
+
|