Spaces:
Runtime error
Runtime error
Synced repo using 'sync_with_huggingface' Github Action
Browse files- .streamlit/secrets.example.toml +4 -1
- app.py +7 -9
- callbacks/arxiv_callbacks.py +4 -3
- chains/arxiv_chains.py +11 -8
- chat.py +6 -6
- lib/helper.py +43 -28
- lib/json_conv.py +5 -2
- lib/private_kb.py +2 -1
- lib/schemas.py +1 -1
- lib/sessions.py +14 -11
- login.py +9 -8
- prompts/arxiv_prompt.py +1 -1
.streamlit/secrets.example.toml
CHANGED
|
@@ -1,6 +1,9 @@
|
|
| 1 |
-
MYSCALE_HOST = "msc-4a9e710a.us-east-1.aws.staging.myscale.cloud"
|
| 2 |
MYSCALE_PORT = 443
|
| 3 |
MYSCALE_USER = "chatdata"
|
| 4 |
MYSCALE_PASSWORD = "myscale_rocks"
|
| 5 |
OPENAI_API_BASE = "https://api.openai.com/v1"
|
| 6 |
OPENAI_API_KEY = "<your-openai-key>"
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MYSCALE_HOST = "msc-4a9e710a.us-east-1.aws.staging.myscale.cloud" # read-only database provided by MyScale
|
| 2 |
MYSCALE_PORT = 443
|
| 3 |
MYSCALE_USER = "chatdata"
|
| 4 |
MYSCALE_PASSWORD = "myscale_rocks"
|
| 5 |
OPENAI_API_BASE = "https://api.openai.com/v1"
|
| 6 |
OPENAI_API_KEY = "<your-openai-key>"
|
| 7 |
+
UNSTRUCTURED_API = "<your-unstructured-io-api>" # optional if you don't upload documents
|
| 8 |
+
AUTH0_DOMAIN = "<your-auth0-domain>" # optional if you don't user management
|
| 9 |
+
AUTH0_CLIENT_ID = "<your-auth0-client-id>" # optiona
|
app.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
-
import json
|
| 2 |
-
import time
|
| 3 |
import pandas as pd
|
| 4 |
from os import environ
|
| 5 |
import streamlit as st
|
|
@@ -13,10 +11,10 @@ from login import login, back_to_main
|
|
| 13 |
from lib.helper import build_tools, build_all, sel_map, display
|
| 14 |
|
| 15 |
|
| 16 |
-
|
| 17 |
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
| 18 |
|
| 19 |
-
st.set_page_config(page_title="ChatData",
|
|
|
|
| 20 |
st.markdown(
|
| 21 |
f"""
|
| 22 |
<style>
|
|
@@ -36,11 +34,12 @@ if login():
|
|
| 36 |
if "user_name" in st.session_state:
|
| 37 |
chat_page()
|
| 38 |
elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:
|
| 39 |
-
|
| 40 |
sel = st.selectbox('Choose the knowledge base you want to ask with:',
|
| 41 |
-
|
| 42 |
sel_map[sel]['hint']()
|
| 43 |
-
tab_sql, tab_self_query = st.tabs(
|
|
|
|
| 44 |
with tab_sql:
|
| 45 |
sel_map[sel]['hint_sql']()
|
| 46 |
st.text_input("Ask a question:", key='query_sql')
|
|
@@ -85,7 +84,6 @@ if login():
|
|
| 85 |
st.write('Oops π΅ Something bad happened...')
|
| 86 |
raise e
|
| 87 |
|
| 88 |
-
|
| 89 |
with tab_self_query:
|
| 90 |
st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='π‘')
|
| 91 |
st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
|
|
@@ -132,4 +130,4 @@ if login():
|
|
| 132 |
docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
|
| 133 |
except Exception as e:
|
| 134 |
st.write('Oops π΅ Something bad happened...')
|
| 135 |
-
raise e
|
|
|
|
|
|
|
|
|
|
| 1 |
import pandas as pd
|
| 2 |
from os import environ
|
| 3 |
import streamlit as st
|
|
|
|
| 11 |
from lib.helper import build_tools, build_all, sel_map, display
|
| 12 |
|
| 13 |
|
|
|
|
| 14 |
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
| 15 |
|
| 16 |
+
st.set_page_config(page_title="ChatData",
|
| 17 |
+
page_icon="https://myscale.com/favicon.ico")
|
| 18 |
st.markdown(
|
| 19 |
f"""
|
| 20 |
<style>
|
|
|
|
| 34 |
if "user_name" in st.session_state:
|
| 35 |
chat_page()
|
| 36 |
elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:
|
| 37 |
+
|
| 38 |
sel = st.selectbox('Choose the knowledge base you want to ask with:',
|
| 39 |
+
options=['ArXiv Papers', 'Wikipedia'])
|
| 40 |
sel_map[sel]['hint']()
|
| 41 |
+
tab_sql, tab_self_query = st.tabs(
|
| 42 |
+
['Vector SQL', 'Self-Query Retrievers'])
|
| 43 |
with tab_sql:
|
| 44 |
sel_map[sel]['hint_sql']()
|
| 45 |
st.text_input("Ask a question:", key='query_sql')
|
|
|
|
| 84 |
st.write('Oops π΅ Something bad happened...')
|
| 85 |
raise e
|
| 86 |
|
|
|
|
| 87 |
with tab_self_query:
|
| 88 |
st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='π‘')
|
| 89 |
st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
|
|
|
|
| 130 |
docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
|
| 131 |
except Exception as e:
|
| 132 |
st.write('Oops π΅ Something bad happened...')
|
| 133 |
+
raise e
|
callbacks/arxiv_callbacks.py
CHANGED
|
@@ -8,7 +8,6 @@ from langchain.callbacks.streamlit.streamlit_callback_handler import (
|
|
| 8 |
StreamlitCallbackHandler,
|
| 9 |
)
|
| 10 |
from langchain.schema.output import LLMResult
|
| 11 |
-
from streamlit.delta_generator import DeltaGenerator
|
| 12 |
|
| 13 |
|
| 14 |
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
|
|
@@ -26,7 +25,8 @@ class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
|
|
| 26 |
self.progress_bar.progress(value=0.6, text="Searching in DB...")
|
| 27 |
if "repr" in outputs:
|
| 28 |
st.markdown("### Generated Filter")
|
| 29 |
-
st.markdown(
|
|
|
|
| 30 |
|
| 31 |
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
|
| 32 |
pass
|
|
@@ -88,7 +88,8 @@ class ChatDataSQLSearchCallBackHandler(StreamlitCallbackHandler):
|
|
| 88 |
st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
|
| 89 |
print(f"Vector SQL: {text}")
|
| 90 |
self.prog_value += self.prog_interval
|
| 91 |
-
self.progress_bar.progress(
|
|
|
|
| 92 |
|
| 93 |
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
|
| 94 |
cid = ".".join(serialized["id"])
|
|
|
|
| 8 |
StreamlitCallbackHandler,
|
| 9 |
)
|
| 10 |
from langchain.schema.output import LLMResult
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
|
|
|
|
| 25 |
self.progress_bar.progress(value=0.6, text="Searching in DB...")
|
| 26 |
if "repr" in outputs:
|
| 27 |
st.markdown("### Generated Filter")
|
| 28 |
+
st.markdown(
|
| 29 |
+
f"```python\n{outputs['repr']}\n```", unsafe_allow_html=True)
|
| 30 |
|
| 31 |
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
|
| 32 |
pass
|
|
|
|
| 88 |
st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
|
| 89 |
print(f"Vector SQL: {text}")
|
| 90 |
self.prog_value += self.prog_interval
|
| 91 |
+
self.progress_bar.progress(
|
| 92 |
+
value=self.prog_value, text="Searching in DB...")
|
| 93 |
|
| 94 |
def on_chain_start(self, serialized, inputs, **kwargs) -> None:
|
| 95 |
cid = ".".join(serialized["id"])
|
chains/arxiv_chains.py
CHANGED
|
@@ -8,7 +8,6 @@ from langchain.callbacks.manager import (
|
|
| 8 |
CallbackManagerForChainRun,
|
| 9 |
)
|
| 10 |
from langchain.embeddings.base import Embeddings
|
| 11 |
-
from langchain.schema import BaseRetriever
|
| 12 |
from langchain.callbacks.manager import Callbacks
|
| 13 |
from langchain.schema.prompt_template import format_document
|
| 14 |
from langchain.docstore.document import Document
|
|
@@ -20,11 +19,12 @@ from langchain_experimental.sql.vector_sql import VectorSQLOutputParser
|
|
| 20 |
|
| 21 |
logger = logging.getLogger()
|
| 22 |
|
|
|
|
| 23 |
class MyScaleWithoutMetadataJson(MyScale):
|
| 24 |
def __init__(self, embedding: Embeddings, config: Optional[MyScaleSettings] = None, must_have_cols: List[str] = [], **kwargs: Any) -> None:
|
| 25 |
super().__init__(embedding, config, **kwargs)
|
| 26 |
self.must_have_cols: List[str] = must_have_cols
|
| 27 |
-
|
| 28 |
def _build_qstr(
|
| 29 |
self, q_emb: List[float], topk: int, where_str: Optional[str] = None
|
| 30 |
) -> str:
|
|
@@ -43,7 +43,7 @@ class MyScaleWithoutMetadataJson(MyScale):
|
|
| 43 |
LIMIT {topk}
|
| 44 |
"""
|
| 45 |
return q_str
|
| 46 |
-
|
| 47 |
def similarity_search_by_vector(self, embedding: List[float], k: int = 4, where_str: Optional[str] = None, **kwargs: Any) -> List[Document]:
|
| 48 |
q_str = self._build_qstr(embedding, k, where_str)
|
| 49 |
try:
|
|
@@ -55,9 +55,11 @@ class MyScaleWithoutMetadataJson(MyScale):
|
|
| 55 |
for r in self.client.query(q_str).named_results()
|
| 56 |
]
|
| 57 |
except Exception as e:
|
| 58 |
-
logger.error(
|
|
|
|
| 59 |
return []
|
| 60 |
|
|
|
|
| 61 |
class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
|
| 62 |
"""Based on VectorSQLOutputParser
|
| 63 |
It also modify the SQL to get all columns
|
|
@@ -73,9 +75,11 @@ class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
|
|
| 73 |
start = text.upper().find("SELECT")
|
| 74 |
if start >= 0:
|
| 75 |
end = text.upper().find("FROM")
|
| 76 |
-
text = text.replace(
|
|
|
|
| 77 |
return super().parse(text)
|
| 78 |
|
|
|
|
| 79 |
class ArXivStuffDocumentChain(StuffDocumentsChain):
|
| 80 |
"""Combine arxiv documents with PDF reference number"""
|
| 81 |
|
|
@@ -172,8 +176,7 @@ class ArXivQAwithSourcesChain(RetrievalQAWithSourcesChain):
|
|
| 172 |
answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]")
|
| 173 |
sources.append(d)
|
| 174 |
ref_cnt += 1
|
| 175 |
-
|
| 176 |
-
|
| 177 |
result: Dict[str, Any] = {
|
| 178 |
self.answer_key: answer,
|
| 179 |
self.sources_answer_key: sources,
|
|
@@ -191,4 +194,4 @@ class ArXivQAwithSourcesChain(RetrievalQAWithSourcesChain):
|
|
| 191 |
|
| 192 |
@property
|
| 193 |
def _chain_type(self) -> str:
|
| 194 |
-
return "arxiv_qa_with_sources_chain"
|
|
|
|
| 8 |
CallbackManagerForChainRun,
|
| 9 |
)
|
| 10 |
from langchain.embeddings.base import Embeddings
|
|
|
|
| 11 |
from langchain.callbacks.manager import Callbacks
|
| 12 |
from langchain.schema.prompt_template import format_document
|
| 13 |
from langchain.docstore.document import Document
|
|
|
|
| 19 |
|
| 20 |
logger = logging.getLogger()
|
| 21 |
|
| 22 |
+
|
| 23 |
class MyScaleWithoutMetadataJson(MyScale):
|
| 24 |
def __init__(self, embedding: Embeddings, config: Optional[MyScaleSettings] = None, must_have_cols: List[str] = [], **kwargs: Any) -> None:
|
| 25 |
super().__init__(embedding, config, **kwargs)
|
| 26 |
self.must_have_cols: List[str] = must_have_cols
|
| 27 |
+
|
| 28 |
def _build_qstr(
|
| 29 |
self, q_emb: List[float], topk: int, where_str: Optional[str] = None
|
| 30 |
) -> str:
|
|
|
|
| 43 |
LIMIT {topk}
|
| 44 |
"""
|
| 45 |
return q_str
|
| 46 |
+
|
| 47 |
def similarity_search_by_vector(self, embedding: List[float], k: int = 4, where_str: Optional[str] = None, **kwargs: Any) -> List[Document]:
|
| 48 |
q_str = self._build_qstr(embedding, k, where_str)
|
| 49 |
try:
|
|
|
|
| 55 |
for r in self.client.query(q_str).named_results()
|
| 56 |
]
|
| 57 |
except Exception as e:
|
| 58 |
+
logger.error(
|
| 59 |
+
f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
|
| 60 |
return []
|
| 61 |
|
| 62 |
+
|
| 63 |
class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
|
| 64 |
"""Based on VectorSQLOutputParser
|
| 65 |
It also modify the SQL to get all columns
|
|
|
|
| 75 |
start = text.upper().find("SELECT")
|
| 76 |
if start >= 0:
|
| 77 |
end = text.upper().find("FROM")
|
| 78 |
+
text = text.replace(
|
| 79 |
+
text[start + len("SELECT") + 1: end - 1], ", ".join(self.must_have_columns))
|
| 80 |
return super().parse(text)
|
| 81 |
|
| 82 |
+
|
| 83 |
class ArXivStuffDocumentChain(StuffDocumentsChain):
|
| 84 |
"""Combine arxiv documents with PDF reference number"""
|
| 85 |
|
|
|
|
| 176 |
answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]")
|
| 177 |
sources.append(d)
|
| 178 |
ref_cnt += 1
|
| 179 |
+
|
|
|
|
| 180 |
result: Dict[str, Any] = {
|
| 181 |
self.answer_key: answer,
|
| 182 |
self.sources_answer_key: sources,
|
|
|
|
| 194 |
|
| 195 |
@property
|
| 196 |
def _chain_type(self) -> str:
|
| 197 |
+
return "arxiv_qa_with_sources_chain"
|
chat.py
CHANGED
|
@@ -8,9 +8,6 @@ from lib.sessions import SessionManager
|
|
| 8 |
from lib.private_kb import PrivateKnowledgeBase
|
| 9 |
from langchain.schema import HumanMessage, FunctionMessage
|
| 10 |
from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
|
| 11 |
-
from langchain.callbacks.streamlit.streamlit_callback_handler import (
|
| 12 |
-
StreamlitCallbackHandler,
|
| 13 |
-
)
|
| 14 |
from lib.json_conv import CustomJSONDecoder
|
| 15 |
|
| 16 |
from lib.helper import (
|
|
@@ -313,7 +310,8 @@ def chat_page():
|
|
| 313 |
key="b_tool_files",
|
| 314 |
format_func=lambda x: x["file_name"],
|
| 315 |
)
|
| 316 |
-
st.text_input(
|
|
|
|
| 317 |
st.text_input(
|
| 318 |
"Tool Description",
|
| 319 |
"Searches among user's private files and returns related documents",
|
|
@@ -359,14 +357,16 @@ def chat_page():
|
|
| 359 |
)
|
| 360 |
st.markdown("### Uploaded Files")
|
| 361 |
st.dataframe(
|
| 362 |
-
st.session_state.private_kb.list_files(
|
|
|
|
| 363 |
use_container_width=True,
|
| 364 |
)
|
| 365 |
col_1, col_2 = st.columns(2)
|
| 366 |
with col_1:
|
| 367 |
st.button("Add Files", on_click=add_file)
|
| 368 |
with col_2:
|
| 369 |
-
st.button("Clear Files and All Tools",
|
|
|
|
| 370 |
|
| 371 |
st.button("Clear Chat History", on_click=clear_history)
|
| 372 |
st.button("Logout", on_click=back_to_main)
|
|
|
|
| 8 |
from lib.private_kb import PrivateKnowledgeBase
|
| 9 |
from langchain.schema import HumanMessage, FunctionMessage
|
| 10 |
from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
|
|
|
|
|
|
|
|
|
|
| 11 |
from lib.json_conv import CustomJSONDecoder
|
| 12 |
|
| 13 |
from lib.helper import (
|
|
|
|
| 310 |
key="b_tool_files",
|
| 311 |
format_func=lambda x: x["file_name"],
|
| 312 |
)
|
| 313 |
+
st.text_input(
|
| 314 |
+
"Tool Name", "get_relevant_documents", key="b_tool_name")
|
| 315 |
st.text_input(
|
| 316 |
"Tool Description",
|
| 317 |
"Searches among user's private files and returns related documents",
|
|
|
|
| 357 |
)
|
| 358 |
st.markdown("### Uploaded Files")
|
| 359 |
st.dataframe(
|
| 360 |
+
st.session_state.private_kb.list_files(
|
| 361 |
+
st.session_state.user_name),
|
| 362 |
use_container_width=True,
|
| 363 |
)
|
| 364 |
col_1, col_2 = st.columns(2)
|
| 365 |
with col_1:
|
| 366 |
st.button("Add Files", on_click=add_file)
|
| 367 |
with col_2:
|
| 368 |
+
st.button("Clear Files and All Tools",
|
| 369 |
+
on_click=clear_files)
|
| 370 |
|
| 371 |
st.button("Clear Chat History", on_click=clear_history)
|
| 372 |
st.button("Logout", on_click=back_to_main)
|
lib/helper.py
CHANGED
|
@@ -4,10 +4,8 @@ import time
|
|
| 4 |
import hashlib
|
| 5 |
from typing import Dict, Any, List, Tuple
|
| 6 |
import re
|
| 7 |
-
import pandas as pd
|
| 8 |
from os import environ
|
| 9 |
import streamlit as st
|
| 10 |
-
import datetime
|
| 11 |
from langchain.schema import BaseRetriever
|
| 12 |
from langchain.tools import Tool
|
| 13 |
from langchain.pydantic_v1 import BaseModel, Field
|
|
@@ -20,7 +18,7 @@ except ImportError:
|
|
| 20 |
from sqlalchemy.ext.declarative import declarative_base
|
| 21 |
from sqlalchemy.orm import sessionmaker
|
| 22 |
from clickhouse_sqlalchemy import (
|
| 23 |
-
|
| 24 |
)
|
| 25 |
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
|
| 26 |
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
|
|
@@ -43,12 +41,12 @@ from langchain.prompts.prompt import PromptTemplate
|
|
| 43 |
from langchain.prompts.chat import MessagesPlaceholder
|
| 44 |
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
|
| 45 |
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
| 46 |
-
from langchain.schema.messages import BaseMessage, HumanMessage, AIMessage, FunctionMessage
|
| 47 |
SystemMessage, ChatMessage, ToolMessage
|
| 48 |
from langchain.memory import SQLChatMessageHistory
|
| 49 |
from langchain.memory.chat_message_histories.sql import \
|
| 50 |
-
|
| 51 |
-
from langchain.schema.messages import BaseMessage
|
| 52 |
# from langchain.agents.agent_toolkits import create_retriever_tool
|
| 53 |
from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
|
| 54 |
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
|
|
@@ -73,7 +71,7 @@ UNSTRUCTURED_API = st.secrets['UNSTRUCTURED_API']
|
|
| 73 |
|
| 74 |
COMBINE_PROMPT = ChatPromptTemplate.from_strings(
|
| 75 |
string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
|
| 76 |
-
|
| 77 |
DEFAULT_SYSTEM_PROMPT = (
|
| 78 |
"Do your best to answer the questions. "
|
| 79 |
"Feel free to use any tools available to look up "
|
|
@@ -81,6 +79,7 @@ DEFAULT_SYSTEM_PROMPT = (
|
|
| 81 |
"when calling search functions."
|
| 82 |
)
|
| 83 |
|
|
|
|
| 84 |
def hint_arxiv():
|
| 85 |
st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
|
| 86 |
"For example: \n\n"
|
|
@@ -150,7 +149,8 @@ sel_map = {
|
|
| 150 |
"hint": hint_wiki,
|
| 151 |
"hint_sql": hint_sql_wiki,
|
| 152 |
"doc_prompt": PromptTemplate(
|
| 153 |
-
input_variables=["page_content",
|
|
|
|
| 154 |
template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"),
|
| 155 |
"metadata_cols": [
|
| 156 |
AttributeInfo(
|
|
@@ -224,6 +224,7 @@ sel_map = {
|
|
| 224 |
}
|
| 225 |
}
|
| 226 |
|
|
|
|
| 227 |
def build_embedding_model(_sel):
|
| 228 |
"""Build embedding model
|
| 229 |
"""
|
|
@@ -253,7 +254,8 @@ def build_chains_retrievers(_sel: str) -> Dict[str, Any]:
|
|
| 253 |
"sql_retriever": sql_retriever,
|
| 254 |
"sql_chain": sql_chain
|
| 255 |
}
|
| 256 |
-
|
|
|
|
| 257 |
def build_self_query(_sel: str) -> SelfQueryRetriever:
|
| 258 |
"""Build self querying retriever
|
| 259 |
|
|
@@ -278,18 +280,20 @@ def build_self_query(_sel: str) -> SelfQueryRetriever:
|
|
| 278 |
"vector": sel_map[_sel]["vector_col"],
|
| 279 |
"metadata": sel_map[_sel]["metadata_col"]
|
| 280 |
})
|
| 281 |
-
doc_search = MyScaleWithoutMetadataJson(st.session_state[f"emb_model_{_sel}"], config,
|
| 282 |
must_have_cols=sel_map[_sel]['must_have_cols'])
|
| 283 |
|
| 284 |
with st.spinner(f"Building Self Query Retriever for {_sel}..."):
|
| 285 |
metadata_field_info = sel_map[_sel]["metadata_cols"]
|
| 286 |
retriever = SelfQueryRetriever.from_llm(
|
| 287 |
-
OpenAI(model_name=query_model_name,
|
|
|
|
| 288 |
doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
|
| 289 |
use_original_query=False, structured_query_translator=MyScaleTranslator())
|
| 290 |
return retriever
|
| 291 |
|
| 292 |
-
|
|
|
|
| 293 |
"""Build Vector SQL Database Retriever
|
| 294 |
|
| 295 |
:param _sel: selected knowledge base
|
|
@@ -308,7 +312,8 @@ def build_vector_sql(_sel: str)->VectorSQLDatabaseChainRetriever:
|
|
| 308 |
output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
|
| 309 |
model=st.session_state[f'emb_model_{_sel}'], must_have_columns=sel_map[_sel]["must_have_cols"])
|
| 310 |
sql_query_chain = VectorSQLDatabaseChain.from_llm(
|
| 311 |
-
llm=OpenAI(model_name=query_model_name,
|
|
|
|
| 312 |
prompt=PROMPT,
|
| 313 |
top_k=10,
|
| 314 |
return_direct=True,
|
|
@@ -319,8 +324,9 @@ def build_vector_sql(_sel: str)->VectorSQLDatabaseChainRetriever:
|
|
| 319 |
sql_retriever = VectorSQLDatabaseChainRetriever(
|
| 320 |
sql_db_chain=sql_query_chain, page_content_key=sel_map[_sel]["text_col"])
|
| 321 |
return sql_retriever
|
| 322 |
-
|
| 323 |
-
|
|
|
|
| 324 |
"""_summary_
|
| 325 |
|
| 326 |
:param _sel: selected knowledge base
|
|
@@ -350,6 +356,7 @@ def build_qa_chain(_sel: str, retriever: BaseRetriever, name: str="Self-query")
|
|
| 350 |
)
|
| 351 |
return chain
|
| 352 |
|
|
|
|
| 353 |
@st.cache_resource
|
| 354 |
def build_all() -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
| 355 |
"""build all resources
|
|
@@ -365,6 +372,7 @@ def build_all() -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
| 365 |
sel_map_obj[k] = build_chains_retrievers(k)
|
| 366 |
return sel_map_obj, embeddings
|
| 367 |
|
|
|
|
| 368 |
def create_message_model(table_name, DynamicBase): # type: ignore
|
| 369 |
"""
|
| 370 |
Create a message model for a given table name.
|
|
@@ -397,6 +405,7 @@ def create_message_model(table_name, DynamicBase): # type: ignore
|
|
| 397 |
|
| 398 |
return Message
|
| 399 |
|
|
|
|
| 400 |
def _message_from_dict(message: dict) -> BaseMessage:
|
| 401 |
_type = message["type"]
|
| 402 |
if _type == "human":
|
|
@@ -417,6 +426,7 @@ def _message_from_dict(message: dict) -> BaseMessage:
|
|
| 417 |
else:
|
| 418 |
raise ValueError(f"Got unexpected message type: {_type}")
|
| 419 |
|
|
|
|
| 420 |
class DefaultClickhouseMessageConverter(DefaultMessageConverter):
|
| 421 |
"""The default message converter for SQLChatMessageHistory."""
|
| 422 |
|
|
@@ -425,27 +435,28 @@ class DefaultClickhouseMessageConverter(DefaultMessageConverter):
|
|
| 425 |
|
| 426 |
def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
|
| 427 |
tstamp = time.time()
|
| 428 |
-
msg_id = hashlib.sha256(
|
|
|
|
| 429 |
user_id, _ = session_id.split("?")
|
| 430 |
return self.model_class(
|
| 431 |
-
id=tstamp,
|
| 432 |
msg_id=msg_id,
|
| 433 |
user_id=user_id,
|
| 434 |
-
session_id=session_id,
|
| 435 |
type=message.type,
|
| 436 |
addtionals=json.dumps(message.additional_kwargs),
|
| 437 |
message=json.dumps({
|
| 438 |
-
"type": message.type,
|
| 439 |
"additional_kwargs": {"timestamp": tstamp},
|
| 440 |
"data": message.dict()})
|
| 441 |
)
|
| 442 |
-
|
| 443 |
def from_sql_model(self, sql_message: Any) -> BaseMessage:
|
| 444 |
msg_dump = json.loads(sql_message.message)
|
| 445 |
msg = _message_from_dict(msg_dump)
|
| 446 |
msg.additional_kwargs = msg_dump["additional_kwargs"]
|
| 447 |
return msg
|
| 448 |
-
|
| 449 |
def get_sql_model_class(self) -> Any:
|
| 450 |
return self.model_class
|
| 451 |
|
|
@@ -458,7 +469,7 @@ def create_agent_executor(name, session_id, llm, tools, system_prompt, **kwargs)
|
|
| 458 |
connection_string=f'{conn_str}/chat?protocol=https',
|
| 459 |
custom_message_converter=DefaultClickhouseMessageConverter(name))
|
| 460 |
memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
|
| 461 |
-
|
| 462 |
_system_message = SystemMessage(
|
| 463 |
content=system_prompt
|
| 464 |
)
|
|
@@ -475,10 +486,12 @@ def create_agent_executor(name, session_id, llm, tools, system_prompt, **kwargs)
|
|
| 475 |
return_intermediate_steps=True,
|
| 476 |
**kwargs
|
| 477 |
)
|
| 478 |
-
|
|
|
|
| 479 |
class RetrieverInput(BaseModel):
|
| 480 |
query: str = Field(description="query to look up in retriever")
|
| 481 |
|
|
|
|
| 482 |
def create_retriever_tool(
|
| 483 |
retriever: BaseRetriever, name: str, description: str
|
| 484 |
) -> Tool:
|
|
@@ -499,7 +512,7 @@ def create_retriever_tool(
|
|
| 499 |
docs: List[Document] = func(*args, **kwargs)
|
| 500 |
return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)
|
| 501 |
return wrapped_retrieve
|
| 502 |
-
|
| 503 |
return Tool(
|
| 504 |
name=name,
|
| 505 |
description=description,
|
|
@@ -507,7 +520,8 @@ def create_retriever_tool(
|
|
| 507 |
coroutine=retriever.aget_relevant_documents,
|
| 508 |
args_schema=RetrieverInput,
|
| 509 |
)
|
| 510 |
-
|
|
|
|
| 511 |
@st.cache_resource
|
| 512 |
def build_tools():
|
| 513 |
"""build all resources
|
|
@@ -531,8 +545,9 @@ def build_tools():
|
|
| 531 |
})
|
| 532 |
return sel_map_obj
|
| 533 |
|
|
|
|
| 534 |
def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temperature=0.6, system_prompt=DEFAULT_SYSTEM_PROMPT):
|
| 535 |
-
chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature,
|
| 536 |
openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY, streaming=True,
|
| 537 |
)
|
| 538 |
tools = st.session_state.tools if "tools_with_users" not in st.session_state else st.session_state.tools_with_users
|
|
@@ -543,7 +558,7 @@ def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temper
|
|
| 543 |
chat_llm,
|
| 544 |
tools=sel_tools,
|
| 545 |
system_prompt=system_prompt
|
| 546 |
-
|
| 547 |
return agent
|
| 548 |
|
| 549 |
|
|
@@ -556,4 +571,4 @@ def display(dataframe, columns_=None, index=None):
|
|
| 556 |
else:
|
| 557 |
st.dataframe(dataframe)
|
| 558 |
else:
|
| 559 |
-
st.write("Sorry π΅ we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True)
|
|
|
|
| 4 |
import hashlib
|
| 5 |
from typing import Dict, Any, List, Tuple
|
| 6 |
import re
|
|
|
|
| 7 |
from os import environ
|
| 8 |
import streamlit as st
|
|
|
|
| 9 |
from langchain.schema import BaseRetriever
|
| 10 |
from langchain.tools import Tool
|
| 11 |
from langchain.pydantic_v1 import BaseModel, Field
|
|
|
|
| 18 |
from sqlalchemy.ext.declarative import declarative_base
|
| 19 |
from sqlalchemy.orm import sessionmaker
|
| 20 |
from clickhouse_sqlalchemy import (
|
| 21 |
+
types, engines
|
| 22 |
)
|
| 23 |
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
|
| 24 |
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
|
|
|
|
| 41 |
from langchain.prompts.chat import MessagesPlaceholder
|
| 42 |
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
|
| 43 |
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
|
| 44 |
+
from langchain.schema.messages import BaseMessage, HumanMessage, AIMessage, FunctionMessage, \
|
| 45 |
SystemMessage, ChatMessage, ToolMessage
|
| 46 |
from langchain.memory import SQLChatMessageHistory
|
| 47 |
from langchain.memory.chat_message_histories.sql import \
|
| 48 |
+
DefaultMessageConverter
|
| 49 |
+
from langchain.schema.messages import BaseMessage
|
| 50 |
# from langchain.agents.agent_toolkits import create_retriever_tool
|
| 51 |
from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
|
| 52 |
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
|
|
|
|
| 71 |
|
| 72 |
COMBINE_PROMPT = ChatPromptTemplate.from_strings(
|
| 73 |
string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
|
| 74 |
+
(HumanMessagePromptTemplate, '{question}')])
|
| 75 |
DEFAULT_SYSTEM_PROMPT = (
|
| 76 |
"Do your best to answer the questions. "
|
| 77 |
"Feel free to use any tools available to look up "
|
|
|
|
| 79 |
"when calling search functions."
|
| 80 |
)
|
| 81 |
|
| 82 |
+
|
| 83 |
def hint_arxiv():
|
| 84 |
st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
|
| 85 |
"For example: \n\n"
|
|
|
|
| 149 |
"hint": hint_wiki,
|
| 150 |
"hint_sql": hint_sql_wiki,
|
| 151 |
"doc_prompt": PromptTemplate(
|
| 152 |
+
input_variables=["page_content",
|
| 153 |
+
"url", "title", "ref_id", "views"],
|
| 154 |
template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"),
|
| 155 |
"metadata_cols": [
|
| 156 |
AttributeInfo(
|
|
|
|
| 224 |
}
|
| 225 |
}
|
| 226 |
|
| 227 |
+
|
| 228 |
def build_embedding_model(_sel):
|
| 229 |
"""Build embedding model
|
| 230 |
"""
|
|
|
|
| 254 |
"sql_retriever": sql_retriever,
|
| 255 |
"sql_chain": sql_chain
|
| 256 |
}
|
| 257 |
+
|
| 258 |
+
|
| 259 |
def build_self_query(_sel: str) -> SelfQueryRetriever:
|
| 260 |
"""Build self querying retriever
|
| 261 |
|
|
|
|
| 280 |
"vector": sel_map[_sel]["vector_col"],
|
| 281 |
"metadata": sel_map[_sel]["metadata_col"]
|
| 282 |
})
|
| 283 |
+
doc_search = MyScaleWithoutMetadataJson(st.session_state[f"emb_model_{_sel}"], config,
|
| 284 |
must_have_cols=sel_map[_sel]['must_have_cols'])
|
| 285 |
|
| 286 |
with st.spinner(f"Building Self Query Retriever for {_sel}..."):
|
| 287 |
metadata_field_info = sel_map[_sel]["metadata_cols"]
|
| 288 |
retriever = SelfQueryRetriever.from_llm(
|
| 289 |
+
OpenAI(model_name=query_model_name,
|
| 290 |
+
openai_api_key=OPENAI_API_KEY, temperature=0),
|
| 291 |
doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
|
| 292 |
use_original_query=False, structured_query_translator=MyScaleTranslator())
|
| 293 |
return retriever
|
| 294 |
|
| 295 |
+
|
| 296 |
+
def build_vector_sql(_sel: str) -> VectorSQLDatabaseChainRetriever:
|
| 297 |
"""Build Vector SQL Database Retriever
|
| 298 |
|
| 299 |
:param _sel: selected knowledge base
|
|
|
|
| 312 |
output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
|
| 313 |
model=st.session_state[f'emb_model_{_sel}'], must_have_columns=sel_map[_sel]["must_have_cols"])
|
| 314 |
sql_query_chain = VectorSQLDatabaseChain.from_llm(
|
| 315 |
+
llm=OpenAI(model_name=query_model_name,
|
| 316 |
+
openai_api_key=OPENAI_API_KEY, temperature=0),
|
| 317 |
prompt=PROMPT,
|
| 318 |
top_k=10,
|
| 319 |
return_direct=True,
|
|
|
|
| 324 |
sql_retriever = VectorSQLDatabaseChainRetriever(
|
| 325 |
sql_db_chain=sql_query_chain, page_content_key=sel_map[_sel]["text_col"])
|
| 326 |
return sql_retriever
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def build_qa_chain(_sel: str, retriever: BaseRetriever, name: str = "Self-query") -> ArXivQAwithSourcesChain:
|
| 330 |
"""_summary_
|
| 331 |
|
| 332 |
:param _sel: selected knowledge base
|
|
|
|
| 356 |
)
|
| 357 |
return chain
|
| 358 |
|
| 359 |
+
|
| 360 |
@st.cache_resource
|
| 361 |
def build_all() -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
| 362 |
"""build all resources
|
|
|
|
| 372 |
sel_map_obj[k] = build_chains_retrievers(k)
|
| 373 |
return sel_map_obj, embeddings
|
| 374 |
|
| 375 |
+
|
| 376 |
def create_message_model(table_name, DynamicBase): # type: ignore
|
| 377 |
"""
|
| 378 |
Create a message model for a given table name.
|
|
|
|
| 405 |
|
| 406 |
return Message
|
| 407 |
|
| 408 |
+
|
| 409 |
def _message_from_dict(message: dict) -> BaseMessage:
|
| 410 |
_type = message["type"]
|
| 411 |
if _type == "human":
|
|
|
|
| 426 |
else:
|
| 427 |
raise ValueError(f"Got unexpected message type: {_type}")
|
| 428 |
|
| 429 |
+
|
| 430 |
class DefaultClickhouseMessageConverter(DefaultMessageConverter):
|
| 431 |
"""The default message converter for SQLChatMessageHistory."""
|
| 432 |
|
|
|
|
| 435 |
|
| 436 |
def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
|
| 437 |
tstamp = time.time()
|
| 438 |
+
msg_id = hashlib.sha256(
|
| 439 |
+
f"{session_id}_{message}_{tstamp}".encode('utf-8')).hexdigest()
|
| 440 |
user_id, _ = session_id.split("?")
|
| 441 |
return self.model_class(
|
| 442 |
+
id=tstamp,
|
| 443 |
msg_id=msg_id,
|
| 444 |
user_id=user_id,
|
| 445 |
+
session_id=session_id,
|
| 446 |
type=message.type,
|
| 447 |
addtionals=json.dumps(message.additional_kwargs),
|
| 448 |
message=json.dumps({
|
| 449 |
+
"type": message.type,
|
| 450 |
"additional_kwargs": {"timestamp": tstamp},
|
| 451 |
"data": message.dict()})
|
| 452 |
)
|
| 453 |
+
|
| 454 |
def from_sql_model(self, sql_message: Any) -> BaseMessage:
|
| 455 |
msg_dump = json.loads(sql_message.message)
|
| 456 |
msg = _message_from_dict(msg_dump)
|
| 457 |
msg.additional_kwargs = msg_dump["additional_kwargs"]
|
| 458 |
return msg
|
| 459 |
+
|
| 460 |
def get_sql_model_class(self) -> Any:
|
| 461 |
return self.model_class
|
| 462 |
|
|
|
|
| 469 |
connection_string=f'{conn_str}/chat?protocol=https',
|
| 470 |
custom_message_converter=DefaultClickhouseMessageConverter(name))
|
| 471 |
memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
|
| 472 |
+
|
| 473 |
_system_message = SystemMessage(
|
| 474 |
content=system_prompt
|
| 475 |
)
|
|
|
|
| 486 |
return_intermediate_steps=True,
|
| 487 |
**kwargs
|
| 488 |
)
|
| 489 |
+
|
| 490 |
+
|
| 491 |
class RetrieverInput(BaseModel):
|
| 492 |
query: str = Field(description="query to look up in retriever")
|
| 493 |
|
| 494 |
+
|
| 495 |
def create_retriever_tool(
|
| 496 |
retriever: BaseRetriever, name: str, description: str
|
| 497 |
) -> Tool:
|
|
|
|
| 512 |
docs: List[Document] = func(*args, **kwargs)
|
| 513 |
return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)
|
| 514 |
return wrapped_retrieve
|
| 515 |
+
|
| 516 |
return Tool(
|
| 517 |
name=name,
|
| 518 |
description=description,
|
|
|
|
| 520 |
coroutine=retriever.aget_relevant_documents,
|
| 521 |
args_schema=RetrieverInput,
|
| 522 |
)
|
| 523 |
+
|
| 524 |
+
|
| 525 |
@st.cache_resource
|
| 526 |
def build_tools():
|
| 527 |
"""build all resources
|
|
|
|
| 545 |
})
|
| 546 |
return sel_map_obj
|
| 547 |
|
| 548 |
+
|
| 549 |
def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temperature=0.6, system_prompt=DEFAULT_SYSTEM_PROMPT):
|
| 550 |
+
chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature,
|
| 551 |
openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY, streaming=True,
|
| 552 |
)
|
| 553 |
tools = st.session_state.tools if "tools_with_users" not in st.session_state else st.session_state.tools_with_users
|
|
|
|
| 558 |
chat_llm,
|
| 559 |
tools=sel_tools,
|
| 560 |
system_prompt=system_prompt
|
| 561 |
+
)
|
| 562 |
return agent
|
| 563 |
|
| 564 |
|
|
|
|
| 571 |
else:
|
| 572 |
st.dataframe(dataframe)
|
| 573 |
else:
|
| 574 |
+
st.write("Sorry π΅ we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True)
|
lib/json_conv.py
CHANGED
|
@@ -1,15 +1,18 @@
|
|
| 1 |
import json
|
| 2 |
import datetime
|
| 3 |
|
|
|
|
| 4 |
class CustomJSONEncoder(json.JSONEncoder):
|
| 5 |
def default(self, obj):
|
| 6 |
if isinstance(obj, datetime.datetime):
|
| 7 |
return datetime.datetime.isoformat(obj)
|
| 8 |
return json.JSONEncoder.default(self, obj)
|
| 9 |
|
|
|
|
| 10 |
class CustomJSONDecoder(json.JSONDecoder):
|
| 11 |
def __init__(self, *args, **kwargs):
|
| 12 |
-
json.JSONDecoder.__init__(
|
|
|
|
| 13 |
|
| 14 |
def object_hook(self, source):
|
| 15 |
for k, v in source.items():
|
|
@@ -18,4 +21,4 @@ class CustomJSONDecoder(json.JSONDecoder):
|
|
| 18 |
source[k] = datetime.datetime.fromisoformat(str(v))
|
| 19 |
except:
|
| 20 |
pass
|
| 21 |
-
return source
|
|
|
|
| 1 |
import json
|
| 2 |
import datetime
|
| 3 |
|
| 4 |
+
|
| 5 |
class CustomJSONEncoder(json.JSONEncoder):
|
| 6 |
def default(self, obj):
|
| 7 |
if isinstance(obj, datetime.datetime):
|
| 8 |
return datetime.datetime.isoformat(obj)
|
| 9 |
return json.JSONEncoder.default(self, obj)
|
| 10 |
|
| 11 |
+
|
| 12 |
class CustomJSONDecoder(json.JSONDecoder):
|
| 13 |
def __init__(self, *args, **kwargs):
|
| 14 |
+
json.JSONDecoder.__init__(
|
| 15 |
+
self, object_hook=self.object_hook, *args, **kwargs)
|
| 16 |
|
| 17 |
def object_hook(self, source):
|
| 18 |
for k, v in source.items():
|
|
|
|
| 21 |
source[k] = datetime.datetime.fromisoformat(str(v))
|
| 22 |
except:
|
| 23 |
pass
|
| 24 |
+
return source
|
lib/private_kb.py
CHANGED
|
@@ -52,7 +52,8 @@ def parse_files(api_key, user_id, files: List[UploadedFile]):
|
|
| 52 |
|
| 53 |
def extract_embedding(embeddings: Embeddings, texts):
|
| 54 |
if len(texts) > 0:
|
| 55 |
-
embs = embeddings.embed_documents(
|
|
|
|
| 56 |
for i, _ in enumerate(texts):
|
| 57 |
texts[i]["vector"] = embs[i]
|
| 58 |
return texts
|
|
|
|
| 52 |
|
| 53 |
def extract_embedding(embeddings: Embeddings, texts):
|
| 54 |
if len(texts) > 0:
|
| 55 |
+
embs = embeddings.embed_documents(
|
| 56 |
+
[t["text"] for _, t in enumerate(texts)])
|
| 57 |
for i, _ in enumerate(texts):
|
| 58 |
texts[i]["vector"] = embs[i]
|
| 59 |
return texts
|
lib/schemas.py
CHANGED
|
@@ -49,4 +49,4 @@ def create_session_table(table_name, DynamicBase): # type: ignore
|
|
| 49 |
order_by=('session_id')),
|
| 50 |
{'comment': 'Store Session and Prompts'}
|
| 51 |
)
|
| 52 |
-
return Session
|
|
|
|
| 49 |
order_by=('session_id')),
|
| 50 |
{'comment': 'Store Session and Prompts'}
|
| 51 |
)
|
| 52 |
+
return Session
|
lib/sessions.py
CHANGED
|
@@ -6,9 +6,9 @@ except ImportError:
|
|
| 6 |
from langchain.schema import BaseChatMessageHistory
|
| 7 |
from datetime import datetime
|
| 8 |
from sqlalchemy import Column, Text, orm, create_engine
|
| 9 |
-
from clickhouse_sqlalchemy import types, engines
|
| 10 |
from .schemas import create_message_model, create_session_table
|
| 11 |
|
|
|
|
| 12 |
def get_sessions(engine, model_class, user_id):
|
| 13 |
with orm.sessionmaker(engine)() as session:
|
| 14 |
result = (
|
|
@@ -20,14 +20,17 @@ def get_sessions(engine, model_class, user_id):
|
|
| 20 |
)
|
| 21 |
return json.loads(result)
|
| 22 |
|
|
|
|
| 23 |
class SessionManager:
|
| 24 |
def __init__(self, session_state, host, port, username, password,
|
| 25 |
db='chat', sess_table='sessions', msg_table='chat_memory') -> None:
|
| 26 |
conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https'
|
| 27 |
self.engine = create_engine(conn_str, echo=False)
|
| 28 |
-
self.sess_model_class = create_session_table(
|
|
|
|
| 29 |
self.sess_model_class.metadata.create_all(self.engine)
|
| 30 |
-
self.msg_model_class = create_message_model(
|
|
|
|
| 31 |
self.msg_model_class.metadata.create_all(self.engine)
|
| 32 |
self.Session = orm.sessionmaker(self.engine)
|
| 33 |
self.session_state = session_state
|
|
@@ -46,14 +49,15 @@ class SessionManager:
|
|
| 46 |
sessions.append({
|
| 47 |
"session_id": r.session_id.split("?")[-1],
|
| 48 |
"system_prompt": r.system_prompt,
|
| 49 |
-
|
| 50 |
return sessions
|
| 51 |
-
|
| 52 |
def modify_system_prompt(self, session_id, sys_prompt):
|
| 53 |
with self.Session() as session:
|
| 54 |
-
session.update(self.sess_model_class).where(
|
|
|
|
| 55 |
session.commit()
|
| 56 |
-
|
| 57 |
def add_session(self, user_id, session_id, system_prompt, **kwargs):
|
| 58 |
with self.Session() as session:
|
| 59 |
elem = self.sess_model_class(
|
|
@@ -62,14 +66,13 @@ class SessionManager:
|
|
| 62 |
)
|
| 63 |
session.add(elem)
|
| 64 |
session.commit()
|
| 65 |
-
|
| 66 |
def remove_session(self, session_id):
|
| 67 |
with self.Session() as session:
|
| 68 |
-
session.query(self.sess_model_class).where(
|
|
|
|
| 69 |
# session.query(self.msg_model_class).where(self.msg_model_class.session_id==session_id).delete()
|
| 70 |
if "agent" in self.session_state:
|
| 71 |
self.session_state.agent.memory.chat_memory.clear()
|
| 72 |
if "file_analyzer" in self.session_state:
|
| 73 |
self.session_state.file_analyzer.clear_files()
|
| 74 |
-
|
| 75 |
-
|
|
|
|
| 6 |
from langchain.schema import BaseChatMessageHistory
|
| 7 |
from datetime import datetime
|
| 8 |
from sqlalchemy import Column, Text, orm, create_engine
|
|
|
|
| 9 |
from .schemas import create_message_model, create_session_table
|
| 10 |
|
| 11 |
+
|
| 12 |
def get_sessions(engine, model_class, user_id):
|
| 13 |
with orm.sessionmaker(engine)() as session:
|
| 14 |
result = (
|
|
|
|
| 20 |
)
|
| 21 |
return json.loads(result)
|
| 22 |
|
| 23 |
+
|
| 24 |
class SessionManager:
|
| 25 |
def __init__(self, session_state, host, port, username, password,
|
| 26 |
db='chat', sess_table='sessions', msg_table='chat_memory') -> None:
|
| 27 |
conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https'
|
| 28 |
self.engine = create_engine(conn_str, echo=False)
|
| 29 |
+
self.sess_model_class = create_session_table(
|
| 30 |
+
sess_table, declarative_base())
|
| 31 |
self.sess_model_class.metadata.create_all(self.engine)
|
| 32 |
+
self.msg_model_class = create_message_model(
|
| 33 |
+
msg_table, declarative_base())
|
| 34 |
self.msg_model_class.metadata.create_all(self.engine)
|
| 35 |
self.Session = orm.sessionmaker(self.engine)
|
| 36 |
self.session_state = session_state
|
|
|
|
| 49 |
sessions.append({
|
| 50 |
"session_id": r.session_id.split("?")[-1],
|
| 51 |
"system_prompt": r.system_prompt,
|
| 52 |
+
})
|
| 53 |
return sessions
|
| 54 |
+
|
| 55 |
def modify_system_prompt(self, session_id, sys_prompt):
|
| 56 |
with self.Session() as session:
|
| 57 |
+
session.update(self.sess_model_class).where(
|
| 58 |
+
self.sess_model_class == session_id).value(system_prompt=sys_prompt)
|
| 59 |
session.commit()
|
| 60 |
+
|
| 61 |
def add_session(self, user_id, session_id, system_prompt, **kwargs):
|
| 62 |
with self.Session() as session:
|
| 63 |
elem = self.sess_model_class(
|
|
|
|
| 66 |
)
|
| 67 |
session.add(elem)
|
| 68 |
session.commit()
|
| 69 |
+
|
| 70 |
def remove_session(self, session_id):
|
| 71 |
with self.Session() as session:
|
| 72 |
+
session.query(self.sess_model_class).where(
|
| 73 |
+
self.sess_model_class.session_id == session_id).delete()
|
| 74 |
# session.query(self.msg_model_class).where(self.msg_model_class.session_id==session_id).delete()
|
| 75 |
if "agent" in self.session_state:
|
| 76 |
self.session_state.agent.memory.chat_memory.clear()
|
| 77 |
if "file_analyzer" in self.session_state:
|
| 78 |
self.session_state.file_analyzer.clear_files()
|
|
|
|
|
|
login.py
CHANGED
|
@@ -1,21 +1,21 @@
|
|
| 1 |
-
import json
|
| 2 |
-
import time
|
| 3 |
-
import pandas as pd
|
| 4 |
-
from os import environ
|
| 5 |
import streamlit as st
|
| 6 |
from auth0_component import login_button
|
| 7 |
|
| 8 |
AUTH0_CLIENT_ID = st.secrets['AUTH0_CLIENT_ID']
|
| 9 |
AUTH0_DOMAIN = st.secrets['AUTH0_DOMAIN']
|
| 10 |
|
|
|
|
| 11 |
def login():
|
| 12 |
if "user_name" in st.session_state or ("jump_query_ask" in st.session_state and st.session_state.jump_query_ask):
|
| 13 |
return True
|
| 14 |
-
st.subheader(
|
|
|
|
| 15 |
st.write("You can now chat with ArXiv and Wikipedia! π\n")
|
| 16 |
st.write("Built purely with streamlit π , LangChain π¦π and love β€οΈ for AI!")
|
| 17 |
-
st.write(
|
| 18 |
-
|
|
|
|
|
|
|
| 19 |
st.divider()
|
| 20 |
col1, col2 = st.columns(2, gap='large')
|
| 21 |
with col1.container():
|
|
@@ -33,7 +33,7 @@ def login():
|
|
| 33 |
st.write("- [Privacy Policy](https://myscale.com/privacy/)\n"
|
| 34 |
"- [Terms of Sevice](https://myscale.com/terms/)")
|
| 35 |
if st.session_state.auth0 is not None:
|
| 36 |
-
st.session_state.user_info = dict(st.session_state.auth0)
|
| 37 |
if 'email' in st.session_state.user_info:
|
| 38 |
email = st.session_state.user_info["email"]
|
| 39 |
else:
|
|
@@ -44,6 +44,7 @@ def login():
|
|
| 44 |
if st.session_state.jump_query_ask:
|
| 45 |
st.experimental_rerun()
|
| 46 |
|
|
|
|
| 47 |
def back_to_main():
|
| 48 |
if "user_info" in st.session_state:
|
| 49 |
del st.session_state.user_info
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from auth0_component import login_button
|
| 3 |
|
| 4 |
AUTH0_CLIENT_ID = st.secrets['AUTH0_CLIENT_ID']
|
| 5 |
AUTH0_DOMAIN = st.secrets['AUTH0_DOMAIN']
|
| 6 |
|
| 7 |
+
|
| 8 |
def login():
|
| 9 |
if "user_name" in st.session_state or ("jump_query_ask" in st.session_state and st.session_state.jump_query_ask):
|
| 10 |
return True
|
| 11 |
+
st.subheader(
|
| 12 |
+
"π€ Welcom to [MyScale](https://myscale.com)'s [ChatData](https://github.com/myscale/ChatData)! π€ ")
|
| 13 |
st.write("You can now chat with ArXiv and Wikipedia! π\n")
|
| 14 |
st.write("Built purely with streamlit π , LangChain π¦π and love β€οΈ for AI!")
|
| 15 |
+
st.write(
|
| 16 |
+
"Follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!")
|
| 17 |
+
st.write(
|
| 18 |
+
"For more details, please refer to [our repository on GitHub](https://github.com/myscale/ChatData)!")
|
| 19 |
st.divider()
|
| 20 |
col1, col2 = st.columns(2, gap='large')
|
| 21 |
with col1.container():
|
|
|
|
| 33 |
st.write("- [Privacy Policy](https://myscale.com/privacy/)\n"
|
| 34 |
"- [Terms of Sevice](https://myscale.com/terms/)")
|
| 35 |
if st.session_state.auth0 is not None:
|
| 36 |
+
st.session_state.user_info = dict(st.session_state.auth0)
|
| 37 |
if 'email' in st.session_state.user_info:
|
| 38 |
email = st.session_state.user_info["email"]
|
| 39 |
else:
|
|
|
|
| 44 |
if st.session_state.jump_query_ask:
|
| 45 |
st.experimental_rerun()
|
| 46 |
|
| 47 |
+
|
| 48 |
def back_to_main():
|
| 49 |
if "user_info" in st.session_state:
|
| 50 |
del st.session_state.user_info
|
prompts/arxiv_prompt.py
CHANGED
|
@@ -6,7 +6,7 @@ combine_prompt_template = (
|
|
| 6 |
+ "relevant information but still try to provide an answer based on your general knowledge. You must refer to the "
|
| 7 |
+ "corresponding section name and page that you refer to when answering. The following is the related information "
|
| 8 |
+ "about the document that will help you answer users' questions, you MUST answer it using question's language:\n\n {summaries}"
|
| 9 |
-
+ "Now you should
|
| 10 |
)
|
| 11 |
|
| 12 |
_myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
|
|
|
|
| 6 |
+ "relevant information but still try to provide an answer based on your general knowledge. You must refer to the "
|
| 7 |
+ "corresponding section name and page that you refer to when answering. The following is the related information "
|
| 8 |
+ "about the document that will help you answer users' questions, you MUST answer it using question's language:\n\n {summaries}"
|
| 9 |
+
+ "Now you should answer user's question. Remember you must use `Doc #` to refer papers:\n\n"
|
| 10 |
)
|
| 11 |
|
| 12 |
_myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
|