Spaces:
Runtime error
Runtime error
Fangrui Liu
committed on
Commit
·
eb820e1
1
Parent(s):
1a24bbc
revised prompt
Browse files- app.py +65 -52
- chains/arxiv_chains.py +131 -0
- prompts/arxiv_prompt.py +7 -8
app.py
CHANGED
|
@@ -10,15 +10,12 @@ from langchain.vectorstores import MyScale, MyScaleSettings
|
|
| 10 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
| 11 |
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
| 12 |
from langchain.chains.query_constructor.base import AttributeInfo
|
| 13 |
-
from langchain.chains import RetrievalQAWithSourcesChain
|
| 14 |
from langchain import OpenAI
|
| 15 |
from langchain.chat_models import ChatOpenAI
|
| 16 |
|
| 17 |
-
from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
|
| 18 |
-
from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
|
| 19 |
-
ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
|
| 20 |
-
ChatDataSQLAskCallBackHandler
|
| 21 |
from langchain.prompts.prompt import PromptTemplate
|
|
|
|
|
|
|
| 22 |
from sqlalchemy import create_engine, MetaData
|
| 23 |
from langchain.chains.sql_database.base import SQLDatabaseChain
|
| 24 |
from langchain.chains.sql_database.parser import VectorSQLRetrieveAllOutputParser
|
|
@@ -26,12 +23,17 @@ from langchain.chains import LLMChain
|
|
| 26 |
from langchain.sql_database import SQLDatabase
|
| 27 |
from langchain.retrievers import SQLDatabaseChainRetriever
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
st.set_page_config(page_title="ChatData")
|
| 31 |
|
| 32 |
st.header("ChatData")
|
| 33 |
|
| 34 |
-
columns = ['title', 'id', 'categories', 'abstract', 'authors', 'pubdate']
|
| 35 |
|
| 36 |
|
| 37 |
def try_eval(x):
|
|
@@ -41,7 +43,9 @@ def try_eval(x):
|
|
| 41 |
return x
|
| 42 |
|
| 43 |
|
| 44 |
-
def display(dataframe, columns=None):
|
|
|
|
|
|
|
| 45 |
if len(dataframe) > 0:
|
| 46 |
if columns:
|
| 47 |
st.dataframe(dataframe[columns])
|
|
@@ -108,24 +112,35 @@ def build_retriever():
|
|
| 108 |
doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
|
| 109 |
use_original_query=False)
|
| 110 |
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
| 121 |
retriever=retriever,
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
MYSCALE_USER = st.secrets['MYSCALE_USER']
|
| 130 |
MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD']
|
| 131 |
MYSCALE_HOST = st.secrets['MYSCALE_HOST']
|
|
@@ -141,7 +156,7 @@ def build_retriever():
|
|
| 141 |
output_parser = VectorSQLRetrieveAllOutputParser.from_embeddings(
|
| 142 |
model=embeddings)
|
| 143 |
sql_query_chain = SQLDatabaseChain.from_llm(
|
| 144 |
-
llm=OpenAI(openai_api_key=
|
| 145 |
prompt=PROMPT,
|
| 146 |
top_k=10,
|
| 147 |
return_direct=True,
|
|
@@ -151,15 +166,23 @@ def build_retriever():
|
|
| 151 |
)
|
| 152 |
sql_retriever = SQLDatabaseChainRetriever(
|
| 153 |
sql_db_chain=sql_query_chain, page_content_key="abstract")
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
retriever=sql_retriever,
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
return [{'name': m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info], retriever, chain, sql_retriever, sql_chain
|
| 165 |
|
|
@@ -220,7 +243,7 @@ ENGINE = ReplacingMergeTree ORDER BY id
|
|
| 220 |
display(docs)
|
| 221 |
except Exception as e:
|
| 222 |
st.write('Oops 😵 Something bad happened...')
|
| 223 |
-
|
| 224 |
|
| 225 |
if st.session_state.ask_sql:
|
| 226 |
plc_hldr = st.empty()
|
|
@@ -233,17 +256,12 @@ ENGINE = ReplacingMergeTree ORDER BY id
|
|
| 233 |
callback.progress_bar.progress(value=1.0, text="Done!")
|
| 234 |
st.markdown(
|
| 235 |
f"### Answer from LLM\n{ret['answer']}\n### References")
|
| 236 |
-
docs = ret['
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
ref += re.findall(
|
| 240 |
-
'(http://arxiv.org/abs/\d{4}.\d+v\d)', ret['answer'])
|
| 241 |
-
docs = pd.DataFrame([{**d.metadata, 'abstract': d.page_content}
|
| 242 |
-
for d in docs if d.metadata['id'] in set(ref)])
|
| 243 |
-
display(docs, columns)
|
| 244 |
except Exception as e:
|
| 245 |
st.write('Oops 😵 Something bad happened...')
|
| 246 |
-
|
| 247 |
|
| 248 |
|
| 249 |
with tab_self_query:
|
|
@@ -270,7 +288,7 @@ with tab_self_query:
|
|
| 270 |
display(docs, columns)
|
| 271 |
except Exception as e:
|
| 272 |
st.write('Oops 😵 Something bad happened...')
|
| 273 |
-
|
| 274 |
|
| 275 |
if st.session_state.ask_self:
|
| 276 |
plc_hldr = st.empty()
|
|
@@ -284,14 +302,9 @@ with tab_self_query:
|
|
| 284 |
callback.progress_bar.progress(value=1.0, text="Done!")
|
| 285 |
st.markdown(
|
| 286 |
f"### Answer from LLM\n{ret['answer']}\n### References")
|
| 287 |
-
docs = ret['
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
ref += re.findall(
|
| 291 |
-
'(http://arxiv.org/abs/\d{4}.\d+v\d)', ret['answer'])
|
| 292 |
-
docs = pd.DataFrame([{**d.metadata, 'abstract': d.page_content}
|
| 293 |
-
for d in docs if d.metadata['id'] in set(ref)])
|
| 294 |
-
display(docs, columns)
|
| 295 |
except Exception as e:
|
| 296 |
st.write('Oops 😵 Something bad happened...')
|
| 297 |
-
|
|
|
|
| 10 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
| 11 |
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
| 12 |
from langchain.chains.query_constructor.base import AttributeInfo
|
|
|
|
| 13 |
from langchain import OpenAI
|
| 14 |
from langchain.chat_models import ChatOpenAI
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
from langchain.prompts.prompt import PromptTemplate
|
| 17 |
+
from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
|
| 18 |
+
SystemMessagePromptTemplate, HumanMessagePromptTemplate
|
| 19 |
from sqlalchemy import create_engine, MetaData
|
| 20 |
from langchain.chains.sql_database.base import SQLDatabaseChain
|
| 21 |
from langchain.chains.sql_database.parser import VectorSQLRetrieveAllOutputParser
|
|
|
|
| 23 |
from langchain.sql_database import SQLDatabase
|
| 24 |
from langchain.retrievers import SQLDatabaseChainRetriever
|
| 25 |
|
| 26 |
+
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
|
| 27 |
+
from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
|
| 28 |
+
ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
|
| 29 |
+
ChatDataSQLAskCallBackHandler
|
| 30 |
+
from prompts.arxiv_prompt import combine_prompt_template, _myscale_prompt
|
| 31 |
|
| 32 |
st.set_page_config(page_title="ChatData")
|
| 33 |
|
| 34 |
st.header("ChatData")
|
| 35 |
|
| 36 |
+
columns = ['ref_id', 'title', 'id', 'categories', 'abstract', 'authors', 'pubdate']
|
| 37 |
|
| 38 |
|
| 39 |
def try_eval(x):
|
|
|
|
| 43 |
return x
|
| 44 |
|
| 45 |
|
| 46 |
+
def display(dataframe, columns=None, index=None):
|
| 47 |
+
if index:
|
| 48 |
+
dataframe.set_index(index)
|
| 49 |
if len(dataframe) > 0:
|
| 50 |
if columns:
|
| 51 |
st.dataframe(dataframe[columns])
|
|
|
|
| 112 |
doc_search, "Scientific papers indexes with abstracts. All in English.", metadata_field_info,
|
| 113 |
use_original_query=False)
|
| 114 |
|
| 115 |
+
|
| 116 |
+
document_with_metadata_prompt = PromptTemplate(
|
| 117 |
+
input_variables=["page_content", "id", "title", "ref_id",
|
| 118 |
+
"authors", "pubdate", "categories"],
|
| 119 |
+
template="Title for PDF #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\tDate of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}")
|
| 120 |
+
|
| 121 |
+
COMBINE_PROMPT = ChatPromptTemplate.from_strings(
|
| 122 |
+
string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
|
| 123 |
+
(HumanMessagePromptTemplate, '{question}')])
|
| 124 |
+
OPENAI_API_KEY = st.secrets['OPENAI_API_KEY']
|
| 125 |
+
|
| 126 |
+
with st.spinner('Building QA Chain with Self-query...'):
|
| 127 |
+
chain = ArXivQAwithSourcesChain(
|
| 128 |
retriever=retriever,
|
| 129 |
+
combine_documents_chain=ArXivStuffDocumentChain(
|
| 130 |
+
llm_chain=LLMChain(
|
| 131 |
+
prompt=COMBINE_PROMPT,
|
| 132 |
+
llm=ChatOpenAI(model_name='gpt-3.5-turbo-16k',
|
| 133 |
+
openai_api_key=OPENAI_API_KEY, temperature=0.6),
|
| 134 |
+
),
|
| 135 |
+
document_prompt=document_with_metadata_prompt,
|
| 136 |
+
document_variable_name="summaries",
|
| 137 |
|
| 138 |
+
),
|
| 139 |
+
return_source_documents=True,
|
| 140 |
+
max_tokens_limit=12000,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
with st.spinner('Building Vector SQL Database Retriever'):
|
| 144 |
MYSCALE_USER = st.secrets['MYSCALE_USER']
|
| 145 |
MYSCALE_PASSWORD = st.secrets['MYSCALE_PASSWORD']
|
| 146 |
MYSCALE_HOST = st.secrets['MYSCALE_HOST']
|
|
|
|
| 156 |
output_parser = VectorSQLRetrieveAllOutputParser.from_embeddings(
|
| 157 |
model=embeddings)
|
| 158 |
sql_query_chain = SQLDatabaseChain.from_llm(
|
| 159 |
+
llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),
|
| 160 |
prompt=PROMPT,
|
| 161 |
top_k=10,
|
| 162 |
return_direct=True,
|
|
|
|
| 166 |
)
|
| 167 |
sql_retriever = SQLDatabaseChainRetriever(
|
| 168 |
sql_db_chain=sql_query_chain, page_content_key="abstract")
|
| 169 |
+
|
| 170 |
+
with st.spinner('Building QA Chain with Vector SQL...'):
|
| 171 |
+
sql_chain = ArXivQAwithSourcesChain(
|
| 172 |
retriever=sql_retriever,
|
| 173 |
+
combine_documents_chain=ArXivStuffDocumentChain(
|
| 174 |
+
llm_chain=LLMChain(
|
| 175 |
+
prompt=COMBINE_PROMPT,
|
| 176 |
+
llm=ChatOpenAI(model_name='gpt-3.5-turbo-16k',
|
| 177 |
+
openai_api_key=OPENAI_API_KEY, temperature=0.6),
|
| 178 |
+
),
|
| 179 |
+
document_prompt=document_with_metadata_prompt,
|
| 180 |
+
document_variable_name="summaries",
|
| 181 |
+
|
| 182 |
+
),
|
| 183 |
+
return_source_documents=True,
|
| 184 |
+
max_tokens_limit=12000,
|
| 185 |
+
)
|
| 186 |
|
| 187 |
return [{'name': m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info], retriever, chain, sql_retriever, sql_chain
|
| 188 |
|
|
|
|
| 243 |
display(docs)
|
| 244 |
except Exception as e:
|
| 245 |
st.write('Oops 😵 Something bad happened...')
|
| 246 |
+
raise e
|
| 247 |
|
| 248 |
if st.session_state.ask_sql:
|
| 249 |
plc_hldr = st.empty()
|
|
|
|
| 256 |
callback.progress_bar.progress(value=1.0, text="Done!")
|
| 257 |
st.markdown(
|
| 258 |
f"### Answer from LLM\n{ret['answer']}\n### References")
|
| 259 |
+
docs = ret['sources']
|
| 260 |
+
docs = pd.DataFrame([{**d.metadata, 'abstract': d.page_content} for d in docs])
|
| 261 |
+
display(docs, columns, index='ref_id')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
except Exception as e:
|
| 263 |
st.write('Oops 😵 Something bad happened...')
|
| 264 |
+
raise e
|
| 265 |
|
| 266 |
|
| 267 |
with tab_self_query:
|
|
|
|
| 288 |
display(docs, columns)
|
| 289 |
except Exception as e:
|
| 290 |
st.write('Oops 😵 Something bad happened...')
|
| 291 |
+
raise e
|
| 292 |
|
| 293 |
if st.session_state.ask_self:
|
| 294 |
plc_hldr = st.empty()
|
|
|
|
| 302 |
callback.progress_bar.progress(value=1.0, text="Done!")
|
| 303 |
st.markdown(
|
| 304 |
f"### Answer from LLM\n{ret['answer']}\n### References")
|
| 305 |
+
docs = ret['sources']
|
| 306 |
+
docs = pd.DataFrame([{**d.metadata, 'abstract': d.page_content} for d in docs])
|
| 307 |
+
display(docs, columns, index='ref_id')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
except Exception as e:
|
| 309 |
st.write('Oops 😵 Something bad happened...')
|
| 310 |
+
raise e
|
chains/arxiv_chains.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import inspect
|
| 3 |
+
from typing import Dict, Any, Optional, List, Tuple
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from langchain.callbacks.manager import (
|
| 7 |
+
AsyncCallbackManagerForChainRun,
|
| 8 |
+
CallbackManagerForChainRun,
|
| 9 |
+
)
|
| 10 |
+
from langchain.schema import BaseRetriever
|
| 11 |
+
from langchain.callbacks.manager import Callbacks
|
| 12 |
+
from langchain.schema.prompt_template import format_document
|
| 13 |
+
from langchain.docstore.document import Document
|
| 14 |
+
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
|
| 15 |
+
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
| 16 |
+
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ArXivStuffDocumentChain(StuffDocumentsChain):
|
| 20 |
+
"""Combine arxiv documents with PDF reference number"""
|
| 21 |
+
|
| 22 |
+
def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
|
| 23 |
+
"""Construct inputs from kwargs and docs.
|
| 24 |
+
|
| 25 |
+
Format and then join all the documents together into one input with name
|
| 26 |
+
`self.document_variable_name`. Then pluck any additional variables
|
| 27 |
+
from **kwargs.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
docs: List of documents to format and then join into single input
|
| 31 |
+
**kwargs: additional inputs to chain, will pluck any other required
|
| 32 |
+
arguments from here.
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
dictionary of inputs to LLMChain
|
| 36 |
+
"""
|
| 37 |
+
# Format each document according to the prompt
|
| 38 |
+
doc_strings = []
|
| 39 |
+
for doc_id, doc in enumerate(docs):
|
| 40 |
+
# add temp reference number in metadata
|
| 41 |
+
doc.metadata.update({'ref_id': doc_id})
|
| 42 |
+
doc.page_content = doc.page_content.replace('\n', ' ')
|
| 43 |
+
doc_strings.append(format_document(doc, self.document_prompt))
|
| 44 |
+
# Join the documents together to put them in the prompt.
|
| 45 |
+
inputs = {
|
| 46 |
+
k: v
|
| 47 |
+
for k, v in kwargs.items()
|
| 48 |
+
if k in self.llm_chain.prompt.input_variables
|
| 49 |
+
}
|
| 50 |
+
inputs[self.document_variable_name] = self.document_separator.join(
|
| 51 |
+
doc_strings)
|
| 52 |
+
return inputs
|
| 53 |
+
|
| 54 |
+
def combine_docs(
|
| 55 |
+
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
|
| 56 |
+
) -> Tuple[str, dict]:
|
| 57 |
+
"""Stuff all documents into one prompt and pass to LLM.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
docs: List of documents to join together into one variable
|
| 61 |
+
callbacks: Optional callbacks to pass along
|
| 62 |
+
**kwargs: additional parameters to use to get inputs to LLMChain.
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
The first element returned is the single string output. The second
|
| 66 |
+
element returned is a dictionary of other keys to return.
|
| 67 |
+
"""
|
| 68 |
+
inputs = self._get_inputs(docs, **kwargs)
|
| 69 |
+
# Call predict on the LLM.
|
| 70 |
+
output = self.llm_chain.predict(callbacks=callbacks, **inputs)
|
| 71 |
+
return output, {}
|
| 72 |
+
|
| 73 |
+
@property
|
| 74 |
+
def _chain_type(self) -> str:
|
| 75 |
+
return "referenced_stuff_documents_chain"
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class ArXivQAwithSourcesChain(RetrievalQAWithSourcesChain):
|
| 79 |
+
"""QA with source chain for Chat ArXiv app with references
|
| 80 |
+
|
| 81 |
+
This chain will automatically assign a reference number to each article,
|
| 82 |
+
Then parse it back to titles or anything else.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
def _call(
|
| 86 |
+
self,
|
| 87 |
+
inputs: Dict[str, Any],
|
| 88 |
+
run_manager: Optional[CallbackManagerForChainRun] = None,
|
| 89 |
+
) -> Dict[str, str]:
|
| 90 |
+
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
| 91 |
+
accepts_run_manager = (
|
| 92 |
+
"run_manager" in inspect.signature(self._get_docs).parameters
|
| 93 |
+
)
|
| 94 |
+
if accepts_run_manager:
|
| 95 |
+
docs = self._get_docs(inputs, run_manager=_run_manager)
|
| 96 |
+
else:
|
| 97 |
+
docs = self._get_docs(inputs) # type: ignore[call-arg]
|
| 98 |
+
|
| 99 |
+
answer = self.combine_documents_chain.run(
|
| 100 |
+
input_documents=docs, callbacks=_run_manager.get_child(), **inputs
|
| 101 |
+
)
|
| 102 |
+
# parse source with ref_id
|
| 103 |
+
sources = []
|
| 104 |
+
ref_cnt = 1
|
| 105 |
+
for d in docs:
|
| 106 |
+
ref_id = d.metadata['ref_id']
|
| 107 |
+
if f"PDF #{ref_id}" in answer:
|
| 108 |
+
title = d.metadata['title'].replace('\n', '')
|
| 109 |
+
d.metadata['ref_id'] = ref_cnt
|
| 110 |
+
answer = answer.replace(f"PDF #{ref_id}", f"{title} [{ref_cnt}]")
|
| 111 |
+
sources.append(d)
|
| 112 |
+
ref_cnt += 1
|
| 113 |
+
|
| 114 |
+
result: Dict[str, Any] = {
|
| 115 |
+
self.answer_key: answer,
|
| 116 |
+
self.sources_answer_key: sources,
|
| 117 |
+
}
|
| 118 |
+
if self.return_source_documents:
|
| 119 |
+
result["source_documents"] = docs
|
| 120 |
+
return result
|
| 121 |
+
|
| 122 |
+
async def _acall(
|
| 123 |
+
self,
|
| 124 |
+
inputs: Dict[str, Any],
|
| 125 |
+
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
| 126 |
+
) -> Dict[str, Any]:
|
| 127 |
+
raise NotImplementedError
|
| 128 |
+
|
| 129 |
+
@property
|
| 130 |
+
def _chain_type(self) -> str:
|
| 131 |
+
return "arxiv_qa_with_sources_chain"
|
prompts/arxiv_prompt.py
CHANGED
|
@@ -1,15 +1,14 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
"
|
| 4 |
-
+ "related to PDFs given below. You should only use the abstract of the selected papers as your source of information "
|
| 5 |
+ "and try to provide concise and accurate answers to any questions asked by the user. If you are unable to find "
|
| 6 |
+ "relevant information in the given sections, you will need to let the user know that the source does not contain "
|
| 7 |
-
+ "relevant information but still try to provide an answer based on your general knowledge.
|
| 8 |
-
+ "
|
|
|
|
|
|
|
| 9 |
)
|
| 10 |
|
| 11 |
-
combine_prompt_template = combine_prompt_template_ + combine_prompt_template
|
| 12 |
-
|
| 13 |
_myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
|
| 14 |
MyScale queries has a vector distance function called `DISTANCE(column, array)` to compute relevance to the user's question and sort the feature array column by the relevance.
|
| 15 |
When the query is asking for {top_k} closest row, you have to use this distance function to calculate distance to entity's array on vector column and order by the distance to retrieve relevant rows.
|
|
|
|
| 1 |
+
combine_prompt_template = (
|
| 2 |
+
"You are a helpful PDF assistant. Your task is to provide information and answer any questions "
|
| 3 |
+
+ "related to PDFs given below. You should use the sections, title and abstract of the selected PDFs as your source of information "
|
|
|
|
| 4 |
+ "and try to provide concise and accurate answers to any questions asked by the user. If you are unable to find "
|
| 5 |
+ "relevant information in the given sections, you will need to let the user know that the source does not contain "
|
| 6 |
+
+ "relevant information but still try to provide an answer based on your general knowledge. You must refer to the "
|
| 7 |
+
+ "corresponding section name and page that you refer to when answering. The following is the related information "
|
| 8 |
+
+ "about the PDF file that will help you answer users' questions, you MUST answer it using question's language:\n\n {summaries}"
|
| 9 |
+
+ "Now you should anwser user's question. Remember you must use the PDF # to refer papers:\n\n"
|
| 10 |
)
|
| 11 |
|
|
|
|
|
|
|
| 12 |
_myscale_prompt = """You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
|
| 13 |
MyScale queries has a vector distance function called `DISTANCE(column, array)` to compute relevance to the user's question and sort the feature array column by the relevance.
|
| 14 |
When the query is asking for {top_k} closest row, you have to use this distance function to calculate distance to entity's array on vector column and order by the distance to retrieve relevant rows.
|