Spaces:
Runtime error
Runtime error
improve chat experience
Browse files- app.py +1 -1
- callbacks/arxiv_callbacks.py +32 -3
- chat.py +23 -4
- helper.py +66 -8
app.py
CHANGED
|
@@ -28,7 +28,7 @@ st.markdown(
|
|
| 28 |
)
|
| 29 |
st.header("ChatData")
|
| 30 |
|
| 31 |
-
if '
|
| 32 |
st.session_state["sel_map_obj"] = build_all()
|
| 33 |
st.session_state["tools"] = build_tools()
|
| 34 |
|
|
|
|
| 28 |
)
|
| 29 |
st.header("ChatData")
|
| 30 |
|
| 31 |
+
if 'sel_map_obj' not in st.session_state:
|
| 32 |
st.session_state["sel_map_obj"] = build_all()
|
| 33 |
st.session_state["tools"] = build_tools()
|
| 34 |
|
callbacks/arxiv_callbacks.py
CHANGED
|
@@ -1,8 +1,11 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
|
|
|
|
|
|
|
| 3 |
from sql_formatter.core import format_sql
|
| 4 |
-
from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
|
| 5 |
from langchain.schema.output import LLMResult
|
|
|
|
| 6 |
|
| 7 |
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
|
| 8 |
def __init__(self) -> None:
|
|
@@ -91,4 +94,30 @@ class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
|
|
| 91 |
self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
|
| 92 |
self.status_bar = st.empty()
|
| 93 |
self.prog_value = 0
|
| 94 |
-
self.prog_interval = 0.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
import json
|
| 3 |
+
import textwrap
|
| 4 |
+
from typing import Dict, Any, List
|
| 5 |
from sql_formatter.core import format_sql
|
| 6 |
+
from langchain.callbacks.streamlit.streamlit_callback_handler import LLMThought, StreamlitCallbackHandler
|
| 7 |
from langchain.schema.output import LLMResult
|
| 8 |
+
from streamlit.delta_generator import DeltaGenerator
|
| 9 |
|
| 10 |
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
|
| 11 |
def __init__(self) -> None:
|
|
|
|
| 94 |
self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
|
| 95 |
self.status_bar = st.empty()
|
| 96 |
self.prog_value = 0
|
| 97 |
+
self.prog_interval = 0.1
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class LLMThoughtWithKB(LLMThought):
|
| 101 |
+
def on_tool_end(self, output: str, color: str | None = None, observation_prefix: str | None = None, llm_prefix: str | None = None, **kwargs: Any) -> None:
|
| 102 |
+
try:
|
| 103 |
+
self._container.markdown("\n\n".join(["### Retrieved Documents:"] + \
|
| 104 |
+
[f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}"
|
| 105 |
+
for i, r in enumerate(json.loads(output))]))
|
| 106 |
+
except Exception as e:
|
| 107 |
+
super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
|
| 111 |
+
|
| 112 |
+
def on_llm_start(
|
| 113 |
+
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
| 114 |
+
) -> None:
|
| 115 |
+
if self._current_thought is None:
|
| 116 |
+
self._current_thought = LLMThoughtWithKB(
|
| 117 |
+
parent_container=self._parent_container,
|
| 118 |
+
expanded=self._expand_new_thoughts,
|
| 119 |
+
collapse_on_complete=self._collapse_completed_thoughts,
|
| 120 |
+
labeler=self._thought_labeler,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
self._current_thought.on_llm_start(serialized, prompts)
|
chat.py
CHANGED
|
@@ -5,6 +5,8 @@ import datetime
|
|
| 5 |
import streamlit as st
|
| 6 |
from lib.sessions import SessionManager
|
| 7 |
from langchain.schema import HumanMessage, FunctionMessage
|
|
|
|
|
|
|
| 8 |
|
| 9 |
from helper import (
|
| 10 |
build_agents,
|
|
@@ -25,8 +27,14 @@ TOOL_NAMES = {
|
|
| 25 |
|
| 26 |
|
| 27 |
def on_chat_submit():
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
def clear_history():
|
|
@@ -136,6 +144,12 @@ def chat_page():
|
|
| 136 |
with st.sidebar:
|
| 137 |
with st.expander("Session Management"):
|
| 138 |
refresh_sessions()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
st.data_editor(
|
| 140 |
st.session_state.current_sessions,
|
| 141 |
num_rows="dynamic",
|
|
@@ -144,6 +158,8 @@ def chat_page():
|
|
| 144 |
)
|
| 145 |
st.button("Submit Change!", on_click=on_session_change_submit)
|
| 146 |
with st.expander("Session Selection", expanded=True):
|
|
|
|
|
|
|
| 147 |
try:
|
| 148 |
dfl_indx = [
|
| 149 |
x["session_id"] for x in st.session_state.current_sessions
|
|
@@ -152,7 +168,7 @@ def chat_page():
|
|
| 152 |
print("*** ", str(e))
|
| 153 |
dfl_indx = 0
|
| 154 |
st.selectbox(
|
| 155 |
-
"Choose a session
|
| 156 |
options=st.session_state.current_sessions,
|
| 157 |
index=dfl_indx,
|
| 158 |
key="sel_sess",
|
|
@@ -161,10 +177,12 @@ def chat_page():
|
|
| 161 |
)
|
| 162 |
print(st.session_state.sel_sess)
|
| 163 |
with st.expander("Tool Settings", expanded=True):
|
|
|
|
|
|
|
| 164 |
st.multiselect(
|
| 165 |
"Knowledge Base",
|
| 166 |
st.session_state.tools.keys(),
|
| 167 |
-
default=["
|
| 168 |
key="selected_tools",
|
| 169 |
on_change=refresh_agent,
|
| 170 |
)
|
|
@@ -195,4 +213,5 @@ def chat_page():
|
|
| 195 |
f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
|
| 196 |
)
|
| 197 |
st.write(f"{msg.content}")
|
|
|
|
| 198 |
st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
|
|
|
|
| 5 |
import streamlit as st
|
| 6 |
from lib.sessions import SessionManager
|
| 7 |
from langchain.schema import HumanMessage, FunctionMessage
|
| 8 |
+
from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
|
| 9 |
+
from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
|
| 10 |
|
| 11 |
from helper import (
|
| 12 |
build_agents,
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
def on_chat_submit():
|
| 30 |
+
with st.session_state.next_round.container():
|
| 31 |
+
with st.chat_message('user'):
|
| 32 |
+
st.write(st.session_state.chat_input)
|
| 33 |
+
with st.chat_message('assistant'):
|
| 34 |
+
container = st.container()
|
| 35 |
+
st_callback = ChatDataAgentCallBackHandler(container, collapse_completed_thoughts=False)
|
| 36 |
+
ret = st.session_state.agent({"input": st.session_state.chat_input}, callbacks=[st_callback])
|
| 37 |
+
print(ret)
|
| 38 |
|
| 39 |
|
| 40 |
def clear_history():
|
|
|
|
| 144 |
with st.sidebar:
|
| 145 |
with st.expander("Session Management"):
|
| 146 |
refresh_sessions()
|
| 147 |
+
st.info("Here you can set up your session! \n\nYou can **change your prompt** here!",
|
| 148 |
+
icon="π€")
|
| 149 |
+
st.info(("**Add columns by clicking the empty row**.\n"
|
| 150 |
+
"And **delete columns by selecting rows with a press on `DEL` Key**"),
|
| 151 |
+
icon="π‘")
|
| 152 |
+
st.info("Don't forget to **click `Submit Change` to save your change**!", icon="π")
|
| 153 |
st.data_editor(
|
| 154 |
st.session_state.current_sessions,
|
| 155 |
num_rows="dynamic",
|
|
|
|
| 158 |
)
|
| 159 |
st.button("Submit Change!", on_click=on_session_change_submit)
|
| 160 |
with st.expander("Session Selection", expanded=True):
|
| 161 |
+
st.info("Here you can select your session!", icon="π€")
|
| 162 |
+
st.info("If no session is attach to your account, then we will add a default session to you!", icon="β€οΈ")
|
| 163 |
try:
|
| 164 |
dfl_indx = [
|
| 165 |
x["session_id"] for x in st.session_state.current_sessions
|
|
|
|
| 168 |
print("*** ", str(e))
|
| 169 |
dfl_indx = 0
|
| 170 |
st.selectbox(
|
| 171 |
+
"Choose a session to chat:",
|
| 172 |
options=st.session_state.current_sessions,
|
| 173 |
index=dfl_indx,
|
| 174 |
key="sel_sess",
|
|
|
|
| 177 |
)
|
| 178 |
print(st.session_state.sel_sess)
|
| 179 |
with st.expander("Tool Settings", expanded=True):
|
| 180 |
+
st.info("Here you can select your tools.", icon="π§")
|
| 181 |
+
st.info("We provides you several knowledge base tools for you. We are building more tools!", icon="π·ββοΈ")
|
| 182 |
st.multiselect(
|
| 183 |
"Knowledge Base",
|
| 184 |
st.session_state.tools.keys(),
|
| 185 |
+
default=["Wikipedia + Self Querying"],
|
| 186 |
key="selected_tools",
|
| 187 |
on_change=refresh_agent,
|
| 188 |
)
|
|
|
|
| 213 |
f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
|
| 214 |
)
|
| 215 |
st.write(f"{msg.content}")
|
| 216 |
+
st.session_state["next_round"] = st.empty()
|
| 217 |
st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
|
helper.py
CHANGED
|
@@ -2,12 +2,15 @@
|
|
| 2 |
import json
|
| 3 |
import time
|
| 4 |
import hashlib
|
| 5 |
-
from typing import Dict, Any
|
| 6 |
import re
|
| 7 |
import pandas as pd
|
| 8 |
from os import environ
|
| 9 |
import streamlit as st
|
| 10 |
import datetime
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
from sqlalchemy import Column, Text, create_engine, MetaData
|
| 13 |
from langchain.agents import AgentExecutor
|
|
@@ -28,7 +31,7 @@ from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
|
|
| 28 |
SystemMessagePromptTemplate, HumanMessagePromptTemplate
|
| 29 |
from langchain.prompts.prompt import PromptTemplate
|
| 30 |
from langchain.chat_models import ChatOpenAI
|
| 31 |
-
from langchain.schema import BaseRetriever
|
| 32 |
from langchain import OpenAI
|
| 33 |
from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
|
| 34 |
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
|
@@ -36,12 +39,12 @@ from langchain.retrievers.self_query.myscale import MyScaleTranslator
|
|
| 36 |
from langchain.embeddings import HuggingFaceInstructEmbeddings, SentenceTransformerEmbeddings
|
| 37 |
from langchain.vectorstores import MyScaleSettings
|
| 38 |
from chains.arxiv_chains import MyScaleWithoutMetadataJson
|
| 39 |
-
from langchain.schema import Document
|
| 40 |
from langchain.prompts.prompt import PromptTemplate
|
| 41 |
from langchain.prompts.chat import MessagesPlaceholder
|
| 42 |
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
|
| 43 |
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
| 44 |
-
from langchain.schema import BaseMessage, HumanMessage, AIMessage, FunctionMessage
|
|
|
|
| 45 |
from langchain.memory import SQLChatMessageHistory
|
| 46 |
from langchain.memory.chat_message_histories.sql import \
|
| 47 |
BaseMessageConverter, DefaultMessageConverter
|
|
@@ -389,6 +392,26 @@ def create_message_model(table_name, DynamicBase): # type: ignore
|
|
| 389 |
|
| 390 |
return Message
|
| 391 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
class DefaultClickhouseMessageConverter(DefaultMessageConverter):
|
| 393 |
"""The default message converter for SQLChatMessageHistory."""
|
| 394 |
|
|
@@ -411,9 +434,10 @@ class DefaultClickhouseMessageConverter(DefaultMessageConverter):
|
|
| 411 |
"additional_kwargs": {"timestamp": tstamp},
|
| 412 |
"data": message.dict()})
|
| 413 |
)
|
|
|
|
| 414 |
def from_sql_model(self, sql_message: Any) -> BaseMessage:
|
| 415 |
msg_dump = json.loads(sql_message.message)
|
| 416 |
-
msg =
|
| 417 |
msg.additional_kwargs = msg_dump["additional_kwargs"]
|
| 418 |
return msg
|
| 419 |
|
|
@@ -447,6 +471,38 @@ def create_agent_executor(name, session_id, llm, tools, system_prompt, **kwargs)
|
|
| 447 |
**kwargs
|
| 448 |
)
|
| 449 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
@st.cache_resource
|
| 451 |
def build_tools():
|
| 452 |
"""build all resources
|
|
@@ -465,13 +521,15 @@ def build_tools():
|
|
| 465 |
if "langchain_retriever" not in st.session_state.sel_map_obj[k] or "vecsql_retriever" not in st.session_state.sel_map_obj[k]:
|
| 466 |
st.session_state.sel_map_obj[k].update(build_chains_retrievers(k))
|
| 467 |
sel_map_obj.update({
|
| 468 |
-
f"
|
| 469 |
-
f"Vector SQL
|
| 470 |
})
|
| 471 |
return sel_map_obj
|
| 472 |
|
| 473 |
def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temperature=0.6, system_prompt=DEFAULT_SYSTEM_PROMPT):
|
| 474 |
-
chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature,
|
|
|
|
|
|
|
| 475 |
tools = [st.session_state.tools[k] for k in tool_names]
|
| 476 |
agent = create_agent_executor(
|
| 477 |
"chat_memory",
|
|
|
|
| 2 |
import json
|
| 3 |
import time
|
| 4 |
import hashlib
|
| 5 |
+
from typing import Dict, Any, List
|
| 6 |
import re
|
| 7 |
import pandas as pd
|
| 8 |
from os import environ
|
| 9 |
import streamlit as st
|
| 10 |
import datetime
|
| 11 |
+
from langchain.schema import BaseRetriever
|
| 12 |
+
from langchain.tools import Tool
|
| 13 |
+
from langchain.pydantic_v1 import BaseModel, Field
|
| 14 |
|
| 15 |
from sqlalchemy import Column, Text, create_engine, MetaData
|
| 16 |
from langchain.agents import AgentExecutor
|
|
|
|
| 31 |
SystemMessagePromptTemplate, HumanMessagePromptTemplate
|
| 32 |
from langchain.prompts.prompt import PromptTemplate
|
| 33 |
from langchain.chat_models import ChatOpenAI
|
| 34 |
+
from langchain.schema import BaseRetriever, Document
|
| 35 |
from langchain import OpenAI
|
| 36 |
from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
|
| 37 |
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
|
|
|
| 39 |
from langchain.embeddings import HuggingFaceInstructEmbeddings, SentenceTransformerEmbeddings
|
| 40 |
from langchain.vectorstores import MyScaleSettings
|
| 41 |
from chains.arxiv_chains import MyScaleWithoutMetadataJson
|
|
|
|
| 42 |
from langchain.prompts.prompt import PromptTemplate
|
| 43 |
from langchain.prompts.chat import MessagesPlaceholder
|
| 44 |
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
|
| 45 |
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
| 46 |
+
from langchain.schema.messages import BaseMessage, HumanMessage, AIMessage, FunctionMessage,\
|
| 47 |
+
SystemMessage, ChatMessage, ToolMessage
|
| 48 |
from langchain.memory import SQLChatMessageHistory
|
| 49 |
from langchain.memory.chat_message_histories.sql import \
|
| 50 |
BaseMessageConverter, DefaultMessageConverter
|
|
|
|
| 392 |
|
| 393 |
return Message
|
| 394 |
|
| 395 |
+
def _message_from_dict(message: dict) -> BaseMessage:
|
| 396 |
+
_type = message["type"]
|
| 397 |
+
if _type == "human":
|
| 398 |
+
return HumanMessage(**message["data"])
|
| 399 |
+
elif _type == "ai":
|
| 400 |
+
return AIMessage(**message["data"])
|
| 401 |
+
elif _type == "system":
|
| 402 |
+
return SystemMessage(**message["data"])
|
| 403 |
+
elif _type == "chat":
|
| 404 |
+
return ChatMessage(**message["data"])
|
| 405 |
+
elif _type == "function":
|
| 406 |
+
return FunctionMessage(**message["data"])
|
| 407 |
+
elif _type == "tool":
|
| 408 |
+
return ToolMessage(**message["data"])
|
| 409 |
+
elif _type == "AIMessageChunk":
|
| 410 |
+
message["data"]["type"] = "ai"
|
| 411 |
+
return AIMessage(**message["data"])
|
| 412 |
+
else:
|
| 413 |
+
raise ValueError(f"Got unexpected message type: {_type}")
|
| 414 |
+
|
| 415 |
class DefaultClickhouseMessageConverter(DefaultMessageConverter):
|
| 416 |
"""The default message converter for SQLChatMessageHistory."""
|
| 417 |
|
|
|
|
| 434 |
"additional_kwargs": {"timestamp": tstamp},
|
| 435 |
"data": message.dict()})
|
| 436 |
)
|
| 437 |
+
|
| 438 |
def from_sql_model(self, sql_message: Any) -> BaseMessage:
|
| 439 |
msg_dump = json.loads(sql_message.message)
|
| 440 |
+
msg = _message_from_dict(msg_dump)
|
| 441 |
msg.additional_kwargs = msg_dump["additional_kwargs"]
|
| 442 |
return msg
|
| 443 |
|
|
|
|
| 471 |
**kwargs
|
| 472 |
)
|
| 473 |
|
| 474 |
+
class RetrieverInput(BaseModel):
|
| 475 |
+
query: str = Field(description="query to look up in retriever")
|
| 476 |
+
|
| 477 |
+
def create_retriever_tool(
|
| 478 |
+
retriever: BaseRetriever, name: str, description: str
|
| 479 |
+
) -> Tool:
|
| 480 |
+
"""Create a tool to do retrieval of documents.
|
| 481 |
+
|
| 482 |
+
Args:
|
| 483 |
+
retriever: The retriever to use for the retrieval
|
| 484 |
+
name: The name for the tool. This will be passed to the language model,
|
| 485 |
+
so should be unique and somewhat descriptive.
|
| 486 |
+
description: The description for the tool. This will be passed to the language
|
| 487 |
+
model, so should be descriptive.
|
| 488 |
+
|
| 489 |
+
Returns:
|
| 490 |
+
Tool class to pass to an agent
|
| 491 |
+
"""
|
| 492 |
+
def wrap(func):
|
| 493 |
+
def wrapped_retrieve(*args, **kwargs):
|
| 494 |
+
docs: List[Document] = func(*args, **kwargs)
|
| 495 |
+
return json.dumps([d.dict() for d in docs])
|
| 496 |
+
return wrapped_retrieve
|
| 497 |
+
|
| 498 |
+
return Tool(
|
| 499 |
+
name=name,
|
| 500 |
+
description=description,
|
| 501 |
+
func=wrap(retriever.get_relevant_documents),
|
| 502 |
+
coroutine=retriever.aget_relevant_documents,
|
| 503 |
+
args_schema=RetrieverInput,
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
@st.cache_resource
|
| 507 |
def build_tools():
|
| 508 |
"""build all resources
|
|
|
|
| 521 |
if "langchain_retriever" not in st.session_state.sel_map_obj[k] or "vecsql_retriever" not in st.session_state.sel_map_obj[k]:
|
| 522 |
st.session_state.sel_map_obj[k].update(build_chains_retrievers(k))
|
| 523 |
sel_map_obj.update({
|
| 524 |
+
f"{k} + Self Querying": create_retriever_tool(st.session_state.sel_map_obj[k]["retriever"], *sel_map[k]["tool_desc"],),
|
| 525 |
+
f"{k} + Vector SQL": create_retriever_tool(st.session_state.sel_map_obj[k]["sql_retriever"], *sel_map[k]["tool_desc"],),
|
| 526 |
})
|
| 527 |
return sel_map_obj
|
| 528 |
|
| 529 |
def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temperature=0.6, system_prompt=DEFAULT_SYSTEM_PROMPT):
|
| 530 |
+
chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature,
|
| 531 |
+
openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY, streaming=True,
|
| 532 |
+
)
|
| 533 |
tools = [st.session_state.tools[k] for k in tool_names]
|
| 534 |
agent = create_agent_executor(
|
| 535 |
"chat_memory",
|