Spaces:
Runtime error
Runtime error
Update chains/arxiv_chains.py
Browse files- chains/arxiv_chains.py +17 -0
chains/arxiv_chains.py
CHANGED
|
@@ -16,6 +16,23 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|
| 16 |
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
class ArXivStuffDocumentChain(StuffDocumentsChain):
|
| 20 |
"""Combine arxiv documents with PDF reference number"""
|
| 21 |
|
|
|
|
| 16 |
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
| 17 |
|
| 18 |
|
| 19 |
+
class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
|
| 20 |
+
"""Based on VectorSQLOutputParser
|
| 21 |
+
It also modify the SQL to get all columns
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
@property
|
| 25 |
+
def _type(self) -> str:
|
| 26 |
+
return "vector_sql_retrieve_custom"
|
| 27 |
+
|
| 28 |
+
def parse(self, text: str) -> Dict[str, Any]:
|
| 29 |
+
text = text.strip()
|
| 30 |
+
start = text.upper().find("SELECT")
|
| 31 |
+
if start >= 0:
|
| 32 |
+
end = text.upper().find("FROM")
|
| 33 |
+
text = text.replace(text[start + len("SELECT") + 1 : end - 1], "title, abstract, authors, pubdate, categories, id")
|
| 34 |
+
return super().parse(text)
|
| 35 |
+
|
| 36 |
class ArXivStuffDocumentChain(StuffDocumentsChain):
|
| 37 |
"""Combine arxiv documents with PDF reference number"""
|
| 38 |
|