Spaces:
Runtime error
Runtime error
Fangrui Liu committed on
Commit ·
45180a0
1
Parent(s): d5a4cb4
add wikipedia
Browse files- app.py +235 -132
- callbacks/arxiv_callbacks.py +1 -1
- chains/arxiv_chains.py +49 -5
- prompts/arxiv_prompt.py +4 -4
app.py
CHANGED
|
@@ -1,3 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import re
|
| 2 |
import pandas as pd
|
| 3 |
from os import environ
|
|
@@ -6,34 +28,156 @@ import datetime
|
|
| 6 |
environ['TOKENIZERS_PARALLELISM'] = 'true'
|
| 7 |
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
| 8 |
|
| 9 |
-
from langchain.vectorstores import MyScale, MyScaleSettings
|
| 10 |
-
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
| 11 |
-
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
| 12 |
-
from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
|
| 13 |
-
from langchain import OpenAI
|
| 14 |
-
from langchain.chat_models import ChatOpenAI
|
| 15 |
|
| 16 |
-
|
| 17 |
-
from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
|
| 18 |
-
SystemMessagePromptTemplate, HumanMessagePromptTemplate
|
| 19 |
-
from sqlalchemy import create_engine, MetaData
|
| 20 |
-
from langchain.chains import LLMChain
|
| 21 |
|
| 22 |
-
|
| 23 |
-
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
|
| 24 |
-
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
|
| 30 |
-
ChatDataSQLAskCallBackHandler
|
| 31 |
-
from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
|
| 32 |
|
| 33 |
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
def try_eval(x):
|
|
@@ -55,14 +199,14 @@ def display(dataframe, columns_=None, index=None):
|
|
| 55 |
st.write("Sorry ๐ต we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True)
|
| 56 |
|
| 57 |
|
| 58 |
-
|
| 59 |
-
def build_retriever():
|
| 60 |
with st.spinner("Loading Model..."):
|
| 61 |
-
embeddings =
|
| 62 |
-
|
| 63 |
-
embed_instruction="Represent the question for retrieving supporting scientific papers: ")
|
| 64 |
|
| 65 |
-
|
|
|
|
|
|
|
| 66 |
myscale_connection = {
|
| 67 |
"host": st.secrets['MYSCALE_HOST'],
|
| 68 |
"port": st.secrets['MYSCALE_PORT'],
|
|
@@ -70,69 +214,40 @@ def build_retriever():
|
|
| 70 |
"password": st.secrets['MYSCALE_PASSWORD'],
|
| 71 |
}
|
| 72 |
|
| 73 |
-
config = MyScaleSettings(**myscale_connection,
|
|
|
|
|
|
|
| 74 |
column_map={
|
| 75 |
"id": "id",
|
| 76 |
-
"text": "
|
| 77 |
-
"vector": "
|
| 78 |
-
"metadata": "
|
| 79 |
})
|
| 80 |
-
doc_search =
|
|
|
|
| 81 |
|
| 82 |
-
with st.spinner("Building Self Query Retriever..."):
|
| 83 |
-
metadata_field_info = [
|
| 84 |
-
AttributeInfo(
|
| 85 |
-
name=VirtualColumnName(name="pubdate"),
|
| 86 |
-
description="The year the paper is published",
|
| 87 |
-
type="timestamp",
|
| 88 |
-
),
|
| 89 |
-
AttributeInfo(
|
| 90 |
-
name="authors",
|
| 91 |
-
description="List of author names",
|
| 92 |
-
type="list[string]",
|
| 93 |
-
),
|
| 94 |
-
AttributeInfo(
|
| 95 |
-
name="title",
|
| 96 |
-
description="Title of the paper",
|
| 97 |
-
type="string",
|
| 98 |
-
),
|
| 99 |
-
AttributeInfo(
|
| 100 |
-
name="categories",
|
| 101 |
-
description="arxiv categories to this paper",
|
| 102 |
-
type="list[string]"
|
| 103 |
-
),
|
| 104 |
-
AttributeInfo(
|
| 105 |
-
name="length(categories)",
|
| 106 |
-
description="length of arxiv categories to this paper",
|
| 107 |
-
type="int"
|
| 108 |
-
),
|
| 109 |
-
]
|
| 110 |
retriever = SelfQueryRetriever.from_llm(
|
| 111 |
-
OpenAI(openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0),
|
| 112 |
doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
|
| 113 |
-
use_original_query=False)
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
document_with_metadata_prompt = PromptTemplate(
|
| 117 |
-
input_variables=["page_content", "id", "title", "ref_id",
|
| 118 |
-
"authors", "pubdate", "categories"],
|
| 119 |
-
template="Title for PDF #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\tDate of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}")
|
| 120 |
|
| 121 |
COMBINE_PROMPT = ChatPromptTemplate.from_strings(
|
| 122 |
string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
|
| 123 |
-
|
| 124 |
OPENAI_API_KEY = st.secrets['OPENAI_API_KEY']
|
| 125 |
|
| 126 |
-
with st.spinner('Building QA Chain with Self-query...'):
|
| 127 |
chain = ArXivQAwithSourcesChain(
|
| 128 |
retriever=retriever,
|
| 129 |
combine_documents_chain=ArXivStuffDocumentChain(
|
| 130 |
llm_chain=LLMChain(
|
| 131 |
prompt=COMBINE_PROMPT,
|
| 132 |
-
llm=ChatOpenAI(model_name=
|
| 133 |
-
|
| 134 |
),
|
| 135 |
-
document_prompt=
|
| 136 |
document_variable_name="summaries",
|
| 137 |
|
| 138 |
),
|
|
@@ -140,23 +255,22 @@ def build_retriever():
|
|
| 140 |
max_tokens_limit=12000,
|
| 141 |
)
|
| 142 |
|
| 143 |
-
with st.spinner('Building Vector SQL Database Retriever'):
|
| 144 |
MYSCALE_USER = st.secrets['MYSCALE_USER']
|
| 145 |
MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD']
|
| 146 |
MYSCALE_HOST = st.secrets['MYSCALE_HOST']
|
| 147 |
MYSCALE_PORT = st.secrets['MYSCALE_PORT']
|
| 148 |
engine = create_engine(
|
| 149 |
-
f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/
|
| 150 |
metadata = MetaData(bind=engine)
|
| 151 |
PROMPT = PromptTemplate(
|
| 152 |
input_variables=["input", "table_info", "top_k"],
|
| 153 |
template=_myscale_prompt,
|
| 154 |
)
|
| 155 |
-
|
| 156 |
output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
|
| 157 |
-
model=
|
| 158 |
sql_query_chain = VectorSQLDatabaseChain.from_llm(
|
| 159 |
-
llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),
|
| 160 |
prompt=PROMPT,
|
| 161 |
top_k=10,
|
| 162 |
return_direct=True,
|
|
@@ -165,18 +279,18 @@ def build_retriever():
|
|
| 165 |
native_format=True
|
| 166 |
)
|
| 167 |
sql_retriever = VectorSQLDatabaseChainRetriever(
|
| 168 |
-
sql_db_chain=sql_query_chain, page_content_key="
|
| 169 |
|
| 170 |
-
with st.spinner('Building QA Chain with Vector SQL...'):
|
| 171 |
sql_chain = ArXivQAwithSourcesChain(
|
| 172 |
retriever=sql_retriever,
|
| 173 |
combine_documents_chain=ArXivStuffDocumentChain(
|
| 174 |
llm_chain=LLMChain(
|
| 175 |
prompt=COMBINE_PROMPT,
|
| 176 |
-
llm=ChatOpenAI(model_name=
|
| 177 |
-
|
| 178 |
),
|
| 179 |
-
document_prompt=
|
| 180 |
document_variable_name="summaries",
|
| 181 |
|
| 182 |
),
|
|
@@ -184,48 +298,33 @@ def build_retriever():
|
|
| 184 |
max_tokens_limit=12000,
|
| 185 |
)
|
| 186 |
|
| 187 |
-
return
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
|
| 191 |
if 'retriever' not in st.session_state:
|
| 192 |
-
st.session_state[
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
|
| 199 |
-
"For example: \n\n"
|
| 200 |
-
"*If you want to search papers with complex filters*:\n\n"
|
| 201 |
-
"- What is a Bayesian network? Please use articles published later than Feb 2018 and with more than 2 categories and whose title like `computer` and must have `cs.CV` in its category.\n\n"
|
| 202 |
-
"*If you want to ask questions based on papers in database*:\n\n"
|
| 203 |
-
"- What is PageRank?\n"
|
| 204 |
-
"- Did Geoffrey Hinton wrote paper about Capsule Neural Networks?\n"
|
| 205 |
-
"- Introduce some applications of GANs published around 2019.\n"
|
| 206 |
-
"- ่ฏทๆ นๆฎ 2019 ๅนดๅทฆๅณ็ๆ็ซ ไป็ปไธไธ GAN ็ๅบ็จ้ฝๆๅชไบ\n"
|
| 207 |
-
"- Veuillez prรฉsenter les applications du GAN sur la base des articles autour de 2019 ?\n"
|
| 208 |
-
"- Is it possible to synthesize room temperature super conductive material?")
|
| 209 |
tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])
|
| 210 |
with tab_sql:
|
| 211 |
-
|
| 212 |
-
st.markdown('''```sql
|
| 213 |
-
CREATE TABLE default.ChatArXiv (
|
| 214 |
-
`abstract` String,
|
| 215 |
-
`id` String,
|
| 216 |
-
`vector` Array(Float32),
|
| 217 |
-
`metadata` Object('JSON'),
|
| 218 |
-
`pubdate` DateTime,
|
| 219 |
-
`title` String,
|
| 220 |
-
`categories` Array(String),
|
| 221 |
-
`authors` Array(String),
|
| 222 |
-
`comment` String,
|
| 223 |
-
`primary_category` String,
|
| 224 |
-
VECTOR INDEX vec_idx vector TYPE MSTG('metric_type=Cosine'),
|
| 225 |
-
CONSTRAINT vec_len CHECK length(vector) = 768)
|
| 226 |
-
ENGINE = ReplacingMergeTree ORDER BY id
|
| 227 |
-
```''')
|
| 228 |
-
|
| 229 |
st.text_input("Ask a question:", key='query_sql')
|
| 230 |
cols = st.columns([1, 1, 7])
|
| 231 |
cols[0].button("Query", key='search_sql')
|
|
@@ -237,7 +336,7 @@ ENGINE = ReplacingMergeTree ORDER BY id
|
|
| 237 |
with plc_hldr.expander('Query Log', expanded=True):
|
| 238 |
callback = ChatDataSQLSearchCallBackHandler()
|
| 239 |
try:
|
| 240 |
-
docs = st.session_state.sql_retriever.get_relevant_documents(
|
| 241 |
st.session_state.query_sql, callbacks=[callback])
|
| 242 |
callback.progress_bar.progress(value=1.0, text="Done!")
|
| 243 |
docs = pd.DataFrame(
|
|
@@ -253,14 +352,16 @@ ENGINE = ReplacingMergeTree ORDER BY id
|
|
| 253 |
with plc_hldr.expander('Chat Log', expanded=True):
|
| 254 |
callback = ChatDataSQLAskCallBackHandler()
|
| 255 |
try:
|
| 256 |
-
ret = st.session_state.sql_chain(
|
| 257 |
st.session_state.query_sql, callbacks=[callback])
|
| 258 |
callback.progress_bar.progress(value=1.0, text="Done!")
|
| 259 |
st.markdown(
|
| 260 |
f"### Answer from LLM\n{ret['answer']}\n### References")
|
| 261 |
docs = ret['sources']
|
| 262 |
-
docs = pd.DataFrame(
|
| 263 |
-
|
|
|
|
|
|
|
| 264 |
except Exception as e:
|
| 265 |
st.write('Oops ๐ต Something bad happened...')
|
| 266 |
raise e
|
|
@@ -268,7 +369,7 @@ ENGINE = ReplacingMergeTree ORDER BY id
|
|
| 268 |
|
| 269 |
with tab_self_query:
|
| 270 |
st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='๐ก')
|
| 271 |
-
st.dataframe(st.session_state.metadata_columns)
|
| 272 |
st.text_input("Ask a question:", key='query_self')
|
| 273 |
cols = st.columns([1, 1, 7])
|
| 274 |
cols[0].button("Query", key='search_self')
|
|
@@ -281,13 +382,13 @@ with tab_self_query:
|
|
| 281 |
call_back = None
|
| 282 |
callback = ChatDataSelfSearchCallBackHandler()
|
| 283 |
try:
|
| 284 |
-
docs = st.session_state.retriever.get_relevant_documents(
|
| 285 |
st.session_state.query_self, callbacks=[callback])
|
|
|
|
| 286 |
callback.progress_bar.progress(value=1.0, text="Done!")
|
| 287 |
docs = pd.DataFrame(
|
| 288 |
[{**d.metadata, 'abstract': d.page_content} for d in docs])
|
| 289 |
-
|
| 290 |
-
display(docs, ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'])
|
| 291 |
except Exception as e:
|
| 292 |
st.write('Oops ๐ต Something bad happened...')
|
| 293 |
raise e
|
|
@@ -299,14 +400,16 @@ with tab_self_query:
|
|
| 299 |
call_back = None
|
| 300 |
callback = ChatDataSelfAskCallBackHandler()
|
| 301 |
try:
|
| 302 |
-
ret = st.session_state.chain(
|
| 303 |
st.session_state.query_self, callbacks=[callback])
|
| 304 |
callback.progress_bar.progress(value=1.0, text="Done!")
|
| 305 |
st.markdown(
|
| 306 |
f"### Answer from LLM\n{ret['answer']}\n### References")
|
| 307 |
docs = ret['sources']
|
| 308 |
-
docs = pd.DataFrame(
|
| 309 |
-
|
|
|
|
|
|
|
| 310 |
except Exception as e:
|
| 311 |
st.write('Oops ๐ต Something bad happened...')
|
| 312 |
raise e
|
|
|
|
| 1 |
+
from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
|
| 2 |
+
from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
|
| 3 |
+
ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
|
| 4 |
+
ChatDataSQLAskCallBackHandler
|
| 5 |
+
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
|
| 6 |
+
from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
|
| 7 |
+
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
|
| 8 |
+
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
|
| 9 |
+
from langchain.utilities.sql_database import SQLDatabase
|
| 10 |
+
from langchain.chains import LLMChain
|
| 11 |
+
from sqlalchemy import create_engine, MetaData
|
| 12 |
+
from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
|
| 13 |
+
SystemMessagePromptTemplate, HumanMessagePromptTemplate
|
| 14 |
+
from langchain.prompts.prompt import PromptTemplate
|
| 15 |
+
from langchain.chat_models import ChatOpenAI
|
| 16 |
+
from langchain import OpenAI
|
| 17 |
+
from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
|
| 18 |
+
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
| 19 |
+
from langchain.retrievers.self_query.myscale import MyScaleTranslator
|
| 20 |
+
from langchain.embeddings import HuggingFaceInstructEmbeddings, SentenceTransformerEmbeddings
|
| 21 |
+
from langchain.vectorstores import MyScaleSettings
|
| 22 |
+
from chains.arxiv_chains import MyScaleWithoutMetadataJson
|
| 23 |
import re
|
| 24 |
import pandas as pd
|
| 25 |
from os import environ
|
|
|
|
| 28 |
environ['TOKENIZERS_PARALLELISM'] = 'true'
|
| 29 |
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
+
st.set_page_config(page_title="ChatData")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
+
st.header("ChatData")
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
# Alternative completion model, currently disabled:
# query_model_name = "gpt-3.5-turbo-instruct"
|
| 37 |
+
# Completion model name; passed as `model_name` to the `OpenAI` LLMs built in this file.
query_model_name = "text-davinci-003"
|
| 38 |
+
# Chat model name; passed as `model_name` to the `ChatOpenAI` LLMs built in this file.
chat_model_name = "gpt-3.5-turbo-16k"
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
+
def hint_arxiv():
    """Render the usage hint for the ArXiv knowledge base.

    Emits one ``st.info`` box explaining that metadata columns can be
    filtered with natural language, followed by example questions in
    English, Chinese and French.
    """
    # The non-English examples were mis-decoded (UTF-8 bytes shown as Thai
    # code-page characters); they are restored to the intended text here.
    st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
            "For example: \n\n"
            "*If you want to search papers with complex filters*:\n\n"
            "- What is a Bayesian network? Please use articles published later than Feb 2018 and with more than 2 categories and whose title like `computer` and must have `cs.CV` in its category.\n\n"
            "*If you want to ask questions based on papers in database*:\n\n"
            "- What is PageRank?\n"
            "- Did Geoffrey Hinton wrote paper about Capsule Neural Networks?\n"
            "- Introduce some applications of GANs published around 2019.\n"
            "- 请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些\n"
            "- Veuillez présenter les applications du GAN sur la base des articles autour de 2019 ?\n"
            "- Is it possible to synthesize room temperature super conductive material?")
|
| 53 |
|
| 54 |
+
|
| 55 |
+
def hint_sql_arxiv():
    """Render the Vector SQL hint for the ArXiv knowledge base.

    Shows an ``st.info`` note about the `Query` / `Ask` buttons, then the
    ``CREATE TABLE`` statement of ``default.ChatArXiv`` as a fenced SQL
    block so users can see which columns are available.
    """
    # The icon was a mis-decoded byte sequence rather than a single emoji;
    # restored to the light-bulb emoji.
    st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
    st.markdown('''```sql
CREATE TABLE default.ChatArXiv (
`abstract` String,
`id` String,
`vector` Array(Float32),
`metadata` Object('JSON'),
`pubdate` DateTime,
`title` String,
`categories` Array(String),
`authors` Array(String),
`comment` String,
`primary_category` String,
VECTOR INDEX vec_idx vector TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
CONSTRAINT vec_len CHECK length(vector) = 768)
ENGINE = ReplacingMergeTree ORDER BY id
```''')
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def hint_wiki():
    """Render the usage hint for the Wikipedia knowledge base.

    Emits one ``st.info`` box with a short explanation and example
    questions in English and Chinese.
    """
    # The Chinese example was mis-decoded (UTF-8 bytes shown as Thai
    # code-page characters); it is restored to the intended text here.
    st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
            "For example: \n\n"
            "- Which company did Elon Musk found?\n"
            "- What is Iron Gwazi?\n"
            "- What is a Ring in mathematics?\n"
            "- 苹果的发源地是那里？\n")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def hint_sql_wiki():
    """Render the Vector SQL hint for the Wikipedia knowledge base.

    Shows an ``st.info`` note about the `Query` / `Ask` buttons, then the
    ``CREATE TABLE`` statement of ``wiki.Wikipedia`` as a fenced SQL block
    so users can see which columns are available.
    """
    # The icon was a mis-decoded byte sequence rather than a single emoji;
    # restored to the light-bulb emoji.
    st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
    st.markdown('''```sql
CREATE TABLE wiki.Wikipedia (
`id` String,
`title` String,
`text` String,
`url` String,
`wiki_id` UInt64,
`views` Float32,
`paragraph_id` UInt64,
`langs` UInt32,
`emb` Array(Float32),
VECTOR INDEX vec_idx emb TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
CONSTRAINT emb_len CHECK length(emb) = 768)
ENGINE = ReplacingMergeTree ORDER BY id
```''')
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# Per-knowledge-base configuration. Each entry carries:
#   database / table  - where the data lives
#   hint / hint_sql   - callables that render the usage hints
#   doc_prompt        - PromptTemplate used to format one retrieved document
#   metadata_cols     - AttributeInfo list handed to the self-query retriever
#   must_have_cols    - columns that are always fetched and displayed
#   vector_col / text_col / metadata_col - column mapping for the vector store
#   emb_model         - zero-argument factory; the model is only built when called
_WIKIPEDIA_KB = {
    "database": "wiki",
    "table": "Wikipedia",
    "hint": hint_wiki,
    "hint_sql": hint_sql_wiki,
    "doc_prompt": PromptTemplate(
        input_variables=["page_content", "url", "title", "ref_id", "views"],
        template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"),
    "metadata_cols": [
        AttributeInfo(name="title",
                      description="title of the wikipedia page",
                      type="string"),
        AttributeInfo(name="text",
                      description="paragraph from this wiki page",
                      type="string"),
        AttributeInfo(name="views",
                      description="number of views",
                      type="float"),
    ],
    "must_have_cols": ['id', 'title', 'url', 'text', 'views'],
    "vector_col": "emb",
    "text_col": "text",
    "metadata_col": "metadata",
    "emb_model": lambda: SentenceTransformerEmbeddings(
        model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',),
}

_ARXIV_KB = {
    "database": "default",
    "table": "ChatArXiv",
    "hint": hint_arxiv,
    "hint_sql": hint_sql_arxiv,
    "doc_prompt": PromptTemplate(
        input_variables=["page_content", "id", "title", "ref_id",
                         "authors", "pubdate", "categories"],
        template="Title for Doc #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\tDate of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}"),
    "metadata_cols": [
        AttributeInfo(name=VirtualColumnName(name="pubdate"),
                      description="The year the paper is published",
                      type="timestamp"),
        AttributeInfo(name="authors",
                      description="List of author names",
                      type="list[string]"),
        AttributeInfo(name="title",
                      description="Title of the paper",
                      type="string"),
        AttributeInfo(name="categories",
                      description="arxiv categories to this paper",
                      type="list[string]"),
        AttributeInfo(name="length(categories)",
                      description="length of arxiv categories to this paper",
                      type="int"),
    ],
    "must_have_cols": ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'],
    "vector_col": "vector",
    "text_col": "abstract",
    "metadata_col": "metadata",
    "emb_model": lambda: HuggingFaceInstructEmbeddings(
        model_name='hkunlp/instructor-xl',
        embed_instruction="Represent the question for retrieving supporting scientific papers: "),
}

# Insertion order is kept as before: Wikipedia first, then ArXiv Papers.
sel_map = {
    'Wikipedia': _WIKIPEDIA_KB,
    'ArXiv Papers': _ARXIV_KB,
}
|
| 181 |
|
| 182 |
|
| 183 |
def try_eval(x):
|
|
|
|
| 199 |
st.write("Sorry ๐ต we didn't find any articles related to your query.\n\nMaybe the LLM is too naughty that does not follow our instruction... \n\nPlease try again and use verbs that may match the datatype.", unsafe_allow_html=True)
|
| 200 |
|
| 201 |
|
| 202 |
+
def build_embedding_model(_sel):
    """Build the embedding model configured for knowledge base ``_sel``.

    Looks up the ``"emb_model"`` factory in ``sel_map[_sel]`` and calls it
    while a "Loading Model..." spinner is displayed.

    Returns:
        Whatever the configured factory returns.
    """
    with st.spinner("Loading Model..."):
        model_factory = sel_map[_sel]["emb_model"]
        return model_factory()
|
|
|
|
| 206 |
|
| 207 |
+
|
| 208 |
+
def build_retriever(_sel):
|
| 209 |
+
with st.spinner(f"Connecting DB for {_sel}..."):
|
| 210 |
myscale_connection = {
|
| 211 |
"host": st.secrets['MYSCALE_HOST'],
|
| 212 |
"port": st.secrets['MYSCALE_PORT'],
|
|
|
|
| 214 |
"password": st.secrets['MYSCALE_PASSWORD'],
|
| 215 |
}
|
| 216 |
|
| 217 |
+
config = MyScaleSettings(**myscale_connection,
|
| 218 |
+
database=sel_map[_sel]["database"],
|
| 219 |
+
table=sel_map[_sel]["table"],
|
| 220 |
column_map={
|
| 221 |
"id": "id",
|
| 222 |
+
"text": sel_map[_sel]["text_col"],
|
| 223 |
+
"vector": sel_map[_sel]["vector_col"],
|
| 224 |
+
"metadata": sel_map[_sel]["metadata_col"]
|
| 225 |
})
|
| 226 |
+
doc_search = MyScaleWithoutMetadataJson(st.session_state[f"emb_model_{_sel}"], config,
|
| 227 |
+
must_have_cols=sel_map[_sel]['must_have_cols'])
|
| 228 |
|
| 229 |
+
with st.spinner(f"Building Self Query Retriever for {_sel}..."):
|
| 230 |
+
metadata_field_info = sel_map[_sel]["metadata_cols"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
retriever = SelfQueryRetriever.from_llm(
|
| 232 |
+
OpenAI(model_name=query_model_name, openai_api_key=st.secrets['OPENAI_API_KEY'], temperature=0),
|
| 233 |
doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
|
| 234 |
+
use_original_query=False, structured_query_translator=MyScaleTranslator())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
|
| 236 |
COMBINE_PROMPT = ChatPromptTemplate.from_strings(
|
| 237 |
string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
|
| 238 |
+
(HumanMessagePromptTemplate, '{question}')])
|
| 239 |
OPENAI_API_KEY = st.secrets['OPENAI_API_KEY']
|
| 240 |
|
| 241 |
+
with st.spinner(f'Building QA Chain with Self-query for {_sel}...'):
|
| 242 |
chain = ArXivQAwithSourcesChain(
|
| 243 |
retriever=retriever,
|
| 244 |
combine_documents_chain=ArXivStuffDocumentChain(
|
| 245 |
llm_chain=LLMChain(
|
| 246 |
prompt=COMBINE_PROMPT,
|
| 247 |
+
llm=ChatOpenAI(model_name=chat_model_name,
|
| 248 |
+
openai_api_key=OPENAI_API_KEY, temperature=0.6),
|
| 249 |
),
|
| 250 |
+
document_prompt=sel_map[_sel]["doc_prompt"],
|
| 251 |
document_variable_name="summaries",
|
| 252 |
|
| 253 |
),
|
|
|
|
| 255 |
max_tokens_limit=12000,
|
| 256 |
)
|
| 257 |
|
| 258 |
+
with st.spinner(f'Building Vector SQL Database Retriever for {_sel}...'):
|
| 259 |
MYSCALE_USER = st.secrets['MYSCALE_USER']
|
| 260 |
MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD']
|
| 261 |
MYSCALE_HOST = st.secrets['MYSCALE_HOST']
|
| 262 |
MYSCALE_PORT = st.secrets['MYSCALE_PORT']
|
| 263 |
engine = create_engine(
|
| 264 |
+
f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/{sel_map[_sel]["database"]}?protocol=https')
|
| 265 |
metadata = MetaData(bind=engine)
|
| 266 |
PROMPT = PromptTemplate(
|
| 267 |
input_variables=["input", "table_info", "top_k"],
|
| 268 |
template=_myscale_prompt,
|
| 269 |
)
|
|
|
|
| 270 |
output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
|
| 271 |
+
model=st.session_state[f'emb_model_{_sel}'], must_have_columns=sel_map[_sel]["must_have_cols"])
|
| 272 |
sql_query_chain = VectorSQLDatabaseChain.from_llm(
|
| 273 |
+
llm=OpenAI(model_name=query_model_name, openai_api_key=OPENAI_API_KEY, temperature=0),
|
| 274 |
prompt=PROMPT,
|
| 275 |
top_k=10,
|
| 276 |
return_direct=True,
|
|
|
|
| 279 |
native_format=True
|
| 280 |
)
|
| 281 |
sql_retriever = VectorSQLDatabaseChainRetriever(
|
| 282 |
+
sql_db_chain=sql_query_chain, page_content_key=sel_map[_sel]["text_col"])
|
| 283 |
|
| 284 |
+
with st.spinner(f'Building QA Chain with Vector SQL for {_sel}...'):
|
| 285 |
sql_chain = ArXivQAwithSourcesChain(
|
| 286 |
retriever=sql_retriever,
|
| 287 |
combine_documents_chain=ArXivStuffDocumentChain(
|
| 288 |
llm_chain=LLMChain(
|
| 289 |
prompt=COMBINE_PROMPT,
|
| 290 |
+
llm=ChatOpenAI(model_name=chat_model_name,
|
| 291 |
+
openai_api_key=OPENAI_API_KEY, temperature=0.6),
|
| 292 |
),
|
| 293 |
+
document_prompt=sel_map[_sel]["doc_prompt"],
|
| 294 |
document_variable_name="summaries",
|
| 295 |
|
| 296 |
),
|
|
|
|
| 298 |
max_tokens_limit=12000,
|
| 299 |
)
|
| 300 |
|
| 301 |
+
return {
|
| 302 |
+
"metadata_columns": [{'name': m.name.name if type(m.name) is VirtualColumnName else m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info],
|
| 303 |
+
"retriever": retriever,
|
| 304 |
+
"chain": chain,
|
| 305 |
+
"sql_retriever": sql_retriever,
|
| 306 |
+
"sql_chain": sql_chain
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
@st.cache_resource
def build_all():
    """Build retrievers and chains for every knowledge base in ``sel_map``.

    For each key, the embedding model is built first and stored in
    ``st.session_state`` under ``emb_model_<key>``, because
    ``build_retriever`` reads the model from that slot. The result of
    ``build_retriever`` is then collected per key.

    Returns:
        dict mapping each ``sel_map`` key to its ``build_retriever`` result.
    """
    built = {}
    for kb_name in sel_map:
        model = build_embedding_model(kb_name)
        st.session_state[f'emb_model_{kb_name}'] = model
        built[kb_name] = build_retriever(kb_name)
    return built
|
| 317 |
|
| 318 |
|
| 319 |
if 'retriever' not in st.session_state:
|
| 320 |
+
st.session_state["sel_map_obj"] = build_all()
|
| 321 |
+
|
| 322 |
+
sel = st.selectbox('Choose the knowledge base you want to ask with:',
|
| 323 |
+
options=['ArXiv Papers', 'Wikipedia'])
|
| 324 |
+
sel_map[sel]['hint']()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])
|
| 326 |
with tab_sql:
|
| 327 |
+
sel_map[sel]['hint_sql']()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
st.text_input("Ask a question:", key='query_sql')
|
| 329 |
cols = st.columns([1, 1, 7])
|
| 330 |
cols[0].button("Query", key='search_sql')
|
|
|
|
| 336 |
with plc_hldr.expander('Query Log', expanded=True):
|
| 337 |
callback = ChatDataSQLSearchCallBackHandler()
|
| 338 |
try:
|
| 339 |
+
docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents(
|
| 340 |
st.session_state.query_sql, callbacks=[callback])
|
| 341 |
callback.progress_bar.progress(value=1.0, text="Done!")
|
| 342 |
docs = pd.DataFrame(
|
|
|
|
| 352 |
with plc_hldr.expander('Chat Log', expanded=True):
|
| 353 |
callback = ChatDataSQLAskCallBackHandler()
|
| 354 |
try:
|
| 355 |
+
ret = st.session_state.sel_map_obj[sel]["sql_chain"](
|
| 356 |
st.session_state.query_sql, callbacks=[callback])
|
| 357 |
callback.progress_bar.progress(value=1.0, text="Done!")
|
| 358 |
st.markdown(
|
| 359 |
f"### Answer from LLM\n{ret['answer']}\n### References")
|
| 360 |
docs = ret['sources']
|
| 361 |
+
docs = pd.DataFrame(
|
| 362 |
+
[{**d.metadata, 'abstract': d.page_content} for d in docs])
|
| 363 |
+
display(
|
| 364 |
+
docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
|
| 365 |
except Exception as e:
|
| 366 |
st.write('Oops ๐ต Something bad happened...')
|
| 367 |
raise e
|
|
|
|
| 369 |
|
| 370 |
with tab_self_query:
|
| 371 |
st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='๐ก')
|
| 372 |
+
st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
|
| 373 |
st.text_input("Ask a question:", key='query_self')
|
| 374 |
cols = st.columns([1, 1, 7])
|
| 375 |
cols[0].button("Query", key='search_self')
|
|
|
|
| 382 |
call_back = None
|
| 383 |
callback = ChatDataSelfSearchCallBackHandler()
|
| 384 |
try:
|
| 385 |
+
docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents(
|
| 386 |
st.session_state.query_self, callbacks=[callback])
|
| 387 |
+
print(docs)
|
| 388 |
callback.progress_bar.progress(value=1.0, text="Done!")
|
| 389 |
docs = pd.DataFrame(
|
| 390 |
[{**d.metadata, 'abstract': d.page_content} for d in docs])
|
| 391 |
+
display(docs, sel_map[sel]["must_have_cols"])
|
|
|
|
| 392 |
except Exception as e:
|
| 393 |
st.write('Oops ๐ต Something bad happened...')
|
| 394 |
raise e
|
|
|
|
| 400 |
call_back = None
|
| 401 |
callback = ChatDataSelfAskCallBackHandler()
|
| 402 |
try:
|
| 403 |
+
ret = st.session_state.sel_map_obj[sel]["chain"](
|
| 404 |
st.session_state.query_self, callbacks=[callback])
|
| 405 |
callback.progress_bar.progress(value=1.0, text="Done!")
|
| 406 |
st.markdown(
|
| 407 |
f"### Answer from LLM\n{ret['answer']}\n### References")
|
| 408 |
docs = ret['sources']
|
| 409 |
+
docs = pd.DataFrame(
|
| 410 |
+
[{**d.metadata, 'abstract': d.page_content} for d in docs])
|
| 411 |
+
display(
|
| 412 |
+
docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
|
| 413 |
except Exception as e:
|
| 414 |
st.write('Oops ๐ต Something bad happened...')
|
| 415 |
raise e
|
callbacks/arxiv_callbacks.py
CHANGED
|
@@ -90,4 +90,4 @@ class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
|
|
| 90 |
self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
|
| 91 |
self.status_bar = st.empty()
|
| 92 |
self.prog_value = 0
|
| 93 |
-
self.prog_interval = 0.1
|
|
|
|
| 90 |
self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
|
| 91 |
self.status_bar = st.empty()
|
| 92 |
self.prog_value = 0
|
| 93 |
+
self.prog_interval = 0.1
|
chains/arxiv_chains.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import
|
| 2 |
import inspect
|
| 3 |
from typing import Dict, Any, Optional, List, Tuple
|
| 4 |
|
|
@@ -7,21 +7,62 @@ from langchain.callbacks.manager import (
|
|
| 7 |
AsyncCallbackManagerForChainRun,
|
| 8 |
CallbackManagerForChainRun,
|
| 9 |
)
|
|
|
|
| 10 |
from langchain.schema import BaseRetriever
|
| 11 |
from langchain.callbacks.manager import Callbacks
|
| 12 |
from langchain.schema.prompt_template import format_document
|
| 13 |
from langchain.docstore.document import Document
|
| 14 |
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
|
| 15 |
-
from langchain.
|
| 16 |
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
| 17 |
|
| 18 |
from langchain_experimental.sql.vector_sql import VectorSQLOutputParser
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
|
| 22 |
"""Based on VectorSQLOutputParser
|
| 23 |
It also modify the SQL to get all columns
|
| 24 |
"""
|
|
|
|
| 25 |
|
| 26 |
@property
|
| 27 |
def _type(self) -> str:
|
|
@@ -123,12 +164,15 @@ class ArXivQAwithSourcesChain(RetrievalQAWithSourcesChain):
|
|
| 123 |
ref_cnt = 1
|
| 124 |
for d in docs:
|
| 125 |
ref_id = d.metadata['ref_id']
|
| 126 |
-
if f"
|
|
|
|
|
|
|
| 127 |
title = d.metadata['title'].replace('\n', '')
|
| 128 |
d.metadata['ref_id'] = ref_cnt
|
| 129 |
-
answer = answer.replace(f"
|
| 130 |
sources.append(d)
|
| 131 |
ref_cnt += 1
|
|
|
|
| 132 |
|
| 133 |
result: Dict[str, Any] = {
|
| 134 |
self.answer_key: answer,
|
|
@@ -147,4 +191,4 @@ class ArXivQAwithSourcesChain(RetrievalQAWithSourcesChain):
|
|
| 147 |
|
| 148 |
@property
|
| 149 |
def _chain_type(self) -> str:
|
| 150 |
-
return "arxiv_qa_with_sources_chain"
|
|
|
|
| 1 |
+
import logging
|
| 2 |
import inspect
|
| 3 |
from typing import Dict, Any, Optional, List, Tuple
|
| 4 |
|
|
|
|
| 7 |
AsyncCallbackManagerForChainRun,
|
| 8 |
CallbackManagerForChainRun,
|
| 9 |
)
|
| 10 |
+
from langchain.embeddings.base import Embeddings
|
| 11 |
from langchain.schema import BaseRetriever
|
| 12 |
from langchain.callbacks.manager import Callbacks
|
| 13 |
from langchain.schema.prompt_template import format_document
|
| 14 |
from langchain.docstore.document import Document
|
| 15 |
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
|
| 16 |
+
from langchain.vectorstores.myscale import MyScale, MyScaleSettings
|
| 17 |
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
| 18 |
|
| 19 |
from langchain_experimental.sql.vector_sql import VectorSQLOutputParser
|
| 20 |
|
| 21 |
+
logger = logging.getLogger()
|
| 22 |
+
|
| 23 |
+
class MyScaleWithoutMetadataJson(MyScale):
|
| 24 |
+
def __init__(self, embedding: Embeddings, config: Optional[MyScaleSettings] = None, must_have_cols: List[str] = [], **kwargs: Any) -> None:
|
| 25 |
+
super().__init__(embedding, config, **kwargs)
|
| 26 |
+
self.must_have_cols: List[str] = must_have_cols
|
| 27 |
+
|
| 28 |
+
def _build_qstr(
|
| 29 |
+
self, q_emb: List[float], topk: int, where_str: Optional[str] = None
|
| 30 |
+
) -> str:
|
| 31 |
+
q_emb_str = ",".join(map(str, q_emb))
|
| 32 |
+
if where_str:
|
| 33 |
+
where_str = f"PREWHERE {where_str}"
|
| 34 |
+
else:
|
| 35 |
+
where_str = ""
|
| 36 |
+
|
| 37 |
+
q_str = f"""
|
| 38 |
+
SELECT {self.config.column_map['text']}, dist, {','.join(self.must_have_cols)}
|
| 39 |
+
FROM {self.config.database}.{self.config.table}
|
| 40 |
+
{where_str}
|
| 41 |
+
ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}])
|
| 42 |
+
AS dist {self.dist_order}
|
| 43 |
+
LIMIT {topk}
|
| 44 |
+
"""
|
| 45 |
+
return q_str
|
| 46 |
+
|
| 47 |
+
def similarity_search_by_vector(self, embedding: List[float], k: int = 4, where_str: Optional[str] = None, **kwargs: Any) -> List[Document]:
|
| 48 |
+
q_str = self._build_qstr(embedding, k, where_str)
|
| 49 |
+
try:
|
| 50 |
+
return [
|
| 51 |
+
Document(
|
| 52 |
+
page_content=r[self.config.column_map["text"]],
|
| 53 |
+
metadata={k: r[k] for k in self.must_have_cols},
|
| 54 |
+
)
|
| 55 |
+
for r in self.client.query(q_str).named_results()
|
| 56 |
+
]
|
| 57 |
+
except Exception as e:
|
| 58 |
+
logger.error(f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
|
| 59 |
+
return []
|
| 60 |
|
| 61 |
class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
|
| 62 |
"""Based on VectorSQLOutputParser
|
| 63 |
It also modify the SQL to get all columns
|
| 64 |
"""
|
| 65 |
+
must_have_columns: List[str]
|
| 66 |
|
| 67 |
@property
|
| 68 |
def _type(self) -> str:
|
|
|
|
| 164 |
ref_cnt = 1
|
| 165 |
for d in docs:
|
| 166 |
ref_id = d.metadata['ref_id']
|
| 167 |
+
if f"Doc #{ref_id}" in answer:
|
| 168 |
+
answer = answer.replace(f"Doc #{ref_id}", f"#{ref_id}")
|
| 169 |
+
if f"#{ref_id}" in answer:
|
| 170 |
title = d.metadata['title'].replace('\n', '')
|
| 171 |
d.metadata['ref_id'] = ref_cnt
|
| 172 |
+
answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]")
|
| 173 |
sources.append(d)
|
| 174 |
ref_cnt += 1
|
| 175 |
+
|
| 176 |
|
| 177 |
result: Dict[str, Any] = {
|
| 178 |
self.answer_key: answer,
|
|
|
|
| 191 |
|
| 192 |
@property
|
| 193 |
def _chain_type(self) -> str:
|
| 194 |
+
return "arxiv_qa_with_sources_chain"
|
prompts/arxiv_prompt.py
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
combine_prompt_template = (
|
| 2 |
-
"You are a helpful
|
| 3 |
-
+ "related to
|
| 4 |
+ "and try to provide concise and accurate answers to any questions asked by the user. If you are unable to find "
|
| 5 |
+ "relevant information in the given sections, you will need to let the user know that the source does not contain "
|
| 6 |
+ "relevant information but still try to provide an answer based on your general knowledge. You must refer to the "
|
| 7 |
+ "corresponding section name and page that you refer to when answering. The following is the related information "
|
| 8 |
-
+ "about the
|
| 9 |
-
+ "Now you should anwser user's question. Remember you must use
|
| 10 |
)
|
| 11 |
|
| 12 |
_myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
|
|
|
|
| 1 |
combine_prompt_template = (
|
| 2 |
+
"You are a helpful document assistant. Your task is to provide information and answer any questions "
|
| 3 |
+
+ "related to documents given below. You should use the sections, title and abstract of the selected documents as your source of information "
|
| 4 |
+ "and try to provide concise and accurate answers to any questions asked by the user. If you are unable to find "
|
| 5 |
+ "relevant information in the given sections, you will need to let the user know that the source does not contain "
|
| 6 |
+ "relevant information but still try to provide an answer based on your general knowledge. You must refer to the "
|
| 7 |
+ "corresponding section name and page that you refer to when answering. The following is the related information "
|
| 8 |
+
+ "about the document that will help you answer users' questions, you MUST answer it using question's language:\n\n {summaries}"
|
| 9 |
+
+ "Now you should anwser user's question. Remember you must use `Doc #` to refer papers:\n\n"
|
| 10 |
)
|
| 11 |
|
| 12 |
_myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
|