Spaces:
Runtime error
Runtime error
update to new preview
Browse files- app.py +9 -8
- chains/arxiv_chains.py +1 -1
app.py
CHANGED
|
@@ -9,7 +9,7 @@ environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
|
|
| 9 |
from langchain.vectorstores import MyScale, MyScaleSettings
|
| 10 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
| 11 |
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
| 12 |
-
from langchain.chains.query_constructor.base import AttributeInfo
|
| 13 |
from langchain import OpenAI
|
| 14 |
from langchain.chat_models import ChatOpenAI
|
| 15 |
|
|
@@ -19,9 +19,9 @@ from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
|
|
| 19 |
from sqlalchemy import create_engine, MetaData
|
| 20 |
from langchain.chains import LLMChain
|
| 21 |
|
| 22 |
-
from
|
| 23 |
-
from langchain_experimental.retrievers.
|
| 24 |
-
from langchain_experimental.sql.
|
| 25 |
|
| 26 |
from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
|
| 27 |
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
|
|
@@ -82,7 +82,7 @@ def build_retriever():
|
|
| 82 |
with st.spinner("Building Self Query Retriever..."):
|
| 83 |
metadata_field_info = [
|
| 84 |
AttributeInfo(
|
| 85 |
-
name="pubdate",
|
| 86 |
description="The year the paper is published",
|
| 87 |
type="timestamp",
|
| 88 |
),
|
|
@@ -155,7 +155,7 @@ def build_retriever():
|
|
| 155 |
|
| 156 |
output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
|
| 157 |
model=embeddings)
|
| 158 |
-
sql_query_chain =
|
| 159 |
llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),
|
| 160 |
prompt=PROMPT,
|
| 161 |
top_k=10,
|
|
@@ -164,7 +164,7 @@ def build_retriever():
|
|
| 164 |
sql_cmd_parser=output_parser,
|
| 165 |
native_format=True
|
| 166 |
)
|
| 167 |
-
sql_retriever =
|
| 168 |
sql_db_chain=sql_query_chain, page_content_key="abstract")
|
| 169 |
|
| 170 |
with st.spinner('Building QA Chain with Vector SQL...'):
|
|
@@ -184,7 +184,8 @@ def build_retriever():
|
|
| 184 |
max_tokens_limit=12000,
|
| 185 |
)
|
| 186 |
|
| 187 |
-
return [{'name': m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info],
|
|
|
|
| 188 |
|
| 189 |
|
| 190 |
if 'retriever' not in st.session_state:
|
|
|
|
| 9 |
from langchain.vectorstores import MyScale, MyScaleSettings
|
| 10 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
| 11 |
from langchain.retrievers.self_query.base import SelfQueryRetriever
|
| 12 |
+
from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
|
| 13 |
from langchain import OpenAI
|
| 14 |
from langchain.chat_models import ChatOpenAI
|
| 15 |
|
|
|
|
| 19 |
from sqlalchemy import create_engine, MetaData
|
| 20 |
from langchain.chains import LLMChain
|
| 21 |
|
| 22 |
+
from langchain.utilities.sql_database import SQLDatabase
|
| 23 |
+
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
|
| 24 |
+
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
|
| 25 |
|
| 26 |
from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
|
| 27 |
from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
|
|
|
|
| 82 |
with st.spinner("Building Self Query Retriever..."):
|
| 83 |
metadata_field_info = [
|
| 84 |
AttributeInfo(
|
| 85 |
+
name=VirtualColumnName(name="pubdate"),
|
| 86 |
description="The year the paper is published",
|
| 87 |
type="timestamp",
|
| 88 |
),
|
|
|
|
| 155 |
|
| 156 |
output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
|
| 157 |
model=embeddings)
|
| 158 |
+
sql_query_chain = VectorSQLDatabaseChain.from_llm(
|
| 159 |
llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),
|
| 160 |
prompt=PROMPT,
|
| 161 |
top_k=10,
|
|
|
|
| 164 |
sql_cmd_parser=output_parser,
|
| 165 |
native_format=True
|
| 166 |
)
|
| 167 |
+
sql_retriever = VectorSQLDatabaseChainRetriever(
|
| 168 |
sql_db_chain=sql_query_chain, page_content_key="abstract")
|
| 169 |
|
| 170 |
with st.spinner('Building QA Chain with Vector SQL...'):
|
|
|
|
| 184 |
max_tokens_limit=12000,
|
| 185 |
)
|
| 186 |
|
| 187 |
+
return [{'name': m.name.name if type(m.name) is VirtualColumnName else m.name, 'desc': m.description, 'type': m.type} for m in metadata_field_info], \
|
| 188 |
+
retriever, chain, sql_retriever, sql_chain
|
| 189 |
|
| 190 |
|
| 191 |
if 'retriever' not in st.session_state:
|
chains/arxiv_chains.py
CHANGED
|
@@ -15,7 +15,7 @@ from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesCha
|
|
| 15 |
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
| 16 |
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
| 17 |
|
| 18 |
-
from langchain_experimental.sql.
|
| 19 |
|
| 20 |
|
| 21 |
class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
|
|
|
|
| 15 |
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
| 16 |
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
| 17 |
|
| 18 |
+
from langchain_experimental.sql.vector_sql import VectorSQLOutputParser
|
| 19 |
|
| 20 |
|
| 21 |
class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
|