File size: 7,024 Bytes
82eed88 b65a947 82eed88 d9610b7 82eed88 65faae7 27a43fb 7135d32 82eed88 327db2e 82eed88 7135d32 82eed88 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 | from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import ArxivLoader
from faiss import IndexFlatL2
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain.document_transformers import LongContextReorder
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.passthrough import RunnableAssign
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
import gradio as gr
from functools import partial
from operator import itemgetter
## Chunking configuration shared by every document loaded below
text_splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n", ".", ";", ",", " ", ""],
    chunk_size=1000,
    chunk_overlap=100,
)

print("Loading Documents....")
## arXiv identifiers to load; commented-out entries are optional extras
arxiv_ids = (
    "1706.03762",    ## Attention Is All You Need Paper
    "1810.04805",    ## BERT Paper
    "2005.11401",    ## RAG Paper
    # "2205.00445",  ## MRKL Paper
    # "2310.06825",  ## Mistral Paper
    "2306.05685",    ## LLM-as-a-Judge
    ## Some longer papers
    # "2210.03629",  ## ReAct Paper
    # "2112.10752",  ## Latent Stable Diffusion Paper
    # "2103.00020",  ## CLIP Paper
)
docs = [ArxivLoader(query=arxiv_id).load() for arxiv_id in arxiv_ids]

## Cut each paper's first document at the first occurrence of "References"
for doc in docs:
    text = doc[0].page_content
    cut = text.find("References")
    if cut != -1:
        doc[0].page_content = text[:cut]

## Split the documents and also filter out stubs (overly short chunks)
print("Chunking Documents")
docs_chunks = [
    [chunk for chunk in text_splitter.split_documents(doc) if len(chunk.page_content) > 200]
    for doc in docs
]
## Make some custom Chunks to give big-picture details:
##   - doc_string: one bulleted list naming every loaded document
##   - doc_metadata: the stringified metadata of each document's first chunk
doc_titles = []
doc_metadata = []
for chunks in docs_chunks:
    ## A paper whose chunks were all filtered out as stubs has nothing to describe
    if not chunks:
        continue
    metadata = getattr(chunks[0], 'metadata', None) or {}
    ## Fall back to a placeholder so a missing or None title cannot break the string build
    doc_titles.append(str(metadata.get('Title') or "Untitled Document"))
    doc_metadata.append(str(metadata))

doc_string = "".join(["Available Documents:"] + ["\n - " + title for title in doc_titles])
extra_chunks = [doc_string] + doc_metadata
## Embedding model shared by every vector store in this file
embedder = NVIDIAEmbeddings(model="nvidia/embed-qa-4", model_type=None)

## Construct series of document vector stores:
## one for the big-picture summary chunks, then one per loaded paper
print("Constructing Vector Stores")
summary_store = FAISS.from_texts(extra_chunks, embedder)
paper_stores = [FAISS.from_documents(paper_chunks, embedder) for paper_chunks in docs_chunks]
vecstores = [summary_store, *paper_stores]

## Embed a throwaway query to learn the length of the embedding vectors
embed_dims = len(embedder.embed_query("test"))
def default_FAISS():
    """Create and return a new, empty FAISS vector store.

    The store uses the module-level `embedder` and a flat L2 index whose
    dimensionality is the module-level `embed_dims`.
    """
    empty_index = IndexFlatL2(embed_dims)
    return FAISS(
        embedding_function=embedder,
        index=empty_index,
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
        normalize_L2=False,
    )
def aggregate_vstores(vectorstores):
    """Merge every store in `vectorstores` into one newly created store and return it."""
    ## Start from an empty store; default_FAISS is used for simplicity,
    ## though it is tied to the shared embedder by reference
    merged = default_FAISS()
    for store in vectorstores:
        merged.merge_from(store)
    return merged
## Build the aggregate store only if this namespace does not already hold one
## Unintuitive optimization; merge_from seems to optimize constituent vector stores away
if "docstore" not in globals():
    docstore = aggregate_vstores(vecstores)

chunk_count = len(docstore.docstore._dict)
print(f"Constructed aggregate docstore with {chunk_count} chunks")
########################################################################
## Utility Runnables/Methods
def RPrint(preface=""):
    """Build a passthrough runnable that prints `preface` followed by its input, then returns the input unchanged."""

    def _echo(x, preface):
        print(f"{preface}{x}")
        return x

    echo_with_preface = partial(_echo, preface=preface)
    return RunnableLambda(echo_with_preface)
def docs2str(docs, title="Document"):
    """Flatten retrieved chunks into one context string, one chunk per line.

    Args:
        docs: Iterable of chunks. Each item may expose `metadata` (a dict that
            may hold a 'Title') and `page_content`; items lacking
            `page_content` are rendered with `str(item)`.
        title: Label used when a chunk has no 'Title' in its metadata.

    Returns:
        The concatenation of every chunk's text, each followed by a newline and
        prefixed with "[Quote from <name>] " whenever the resolved name is
        non-empty. An empty `docs` yields "".
    """
    pieces = []
    for doc in docs:
        ## Treat a missing or None metadata attribute as "no metadata"
        metadata = getattr(doc, 'metadata', None) or {}
        doc_name = metadata.get('Title', title)
        if doc_name:
            pieces.append(f"[Quote from {doc_name}] ")
        pieces.append(getattr(doc, 'page_content', str(doc)) + "\n")
    ## Join once at the end instead of growing a string inside the loop
    return "".join(pieces)
## Optional; Reorders longer documents to center of output text
reorderer = LongContextReorder()
long_reorder = RunnableLambda(reorderer.transform_documents)

########################################################################

## Chat model piped into a string output parser
chat_model = ChatNVIDIA(model="mixtral_8x7b")
llm = chat_model | StrOutputParser()

## Separate, initially empty vector store that receives the conversation exchanges
convstore = default_FAISS()
def save_memory_and_get_output(d, vstore):
    """Store one exchange from the 'input'/'output' dictionary `d` in `vstore` and return the output.

    Two texts are added to `vstore`: one quoting d['input'] and one quoting
    d['output']. A missing key is rendered as None.
    """
    user_text = d.get('input')
    agent_text = d.get('output')
    exchange = [
        f"User previously responded with {user_text}",
        f"Agent previously responded with {agent_text}",
    ]
    vstore.add_texts(exchange)
    return agent_text
## Greeting text that also lists the loaded documents
initial_msg = (
    "Hello! I am a document chat agent here to help the user!"
    f" I have access to the following documents: {doc_string}\n\nHow can I help you?"
)

## System instructions; {input}, {history} and {context} are filled in per request
system_template = (
    "You are a document chatbot. Help the user as they ask questions about documents."
    " User messaged just asked: {input}\n\n"
    " From this, we have retrieved the following potentially-useful info: "
    " Conversation History Retrieval:\n{history}\n\n"
    " Document Retrieval:\n{context}\n\n"
    " (Answer only from retrieval. Only cite sources that are used. Make your response conversational.)"
)

chat_prompt = ChatPromptTemplate.from_messages([
    ("system", system_template),
    ("user", "{input}"),
])
## Each lookup takes 'input' from the state dict, retrieves from one store,
## reorders the hits, and flattens them into a single string
history_lookup = itemgetter('input') | convstore.as_retriever() | long_reorder | docs2str
context_lookup = itemgetter('input') | docstore.as_retriever() | long_reorder | docs2str

## Wrap the raw message as {'input': ...}, attach 'history' and 'context', then print the state
retrieval_chain = (
    {'input': (lambda x: x)}
    | RunnableAssign({'history': history_lookup})
    | RunnableAssign({'context': context_lookup})
    | RPrint()
)

## Prompt formatting followed by the LLM
stream_chain = chat_prompt | llm
def chat_gen(message, history=None, return_buffer=True):
    """Stream the agent's reply to `message`, then save the exchange to conversation memory.

    Args:
        message: The user's latest message; it drives retrieval and is the prompt input.
        history: Accepted for interface compatibility and never read here
            (presumably the chat history supplied by the Gradio UI). Defaults to
            None rather than a shared mutable list.
        return_buffer: When True, each step yields the whole reply accumulated so
            far. When False, each step yields only the newest token, with a
            newline yielded beforehand whenever the current line grows too long.

    Yields:
        Strings as described under `return_buffer`.
    """
    buffer = ""

    ## First perform the retrieval based on the input message
    retrieval = retrieval_chain.invoke(message)
    line_buffer = ""

    ## Then, stream the results of the stream_chain
    for token in stream_chain.stream(retrieval):
        buffer += token
        ## If you're using standard print, keep line from getting too long
        if not return_buffer:
            line_buffer += token
            if "\n" in line_buffer:
                line_buffer = ""
            ## Wrap at a word boundary past 84 characters, or unconditionally past 100
            if ((len(line_buffer) > 84 and token and token[0] == " ") or len(line_buffer) > 100):
                line_buffer = ""
                yield "\n"
                token = " " + token.lstrip()
        yield buffer if return_buffer else token

    ## Lastly, save the chat exchange to the conversation memory buffer
    save_memory_and_get_output({'input': message, 'output': buffer}, convstore)
## Seed the chat window with the agent's greeting and serve chat_gen through a queued interface
chatbot = gr.Chatbot(value=[[None, initial_msg]])
demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()

try:
    demo.launch(debug=True, share=True, show_api=False)
except Exception as e:
    print(e)
    raise  ## bare raise re-raises the same exception with its traceback intact
finally:
    ## Close on every exit path, including KeyboardInterrupt, which `except Exception` does not catch
    demo.close()
|