## RAG / app.py
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import ArxivLoader
from faiss import IndexFlatL2
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.document_transformers import LongContextReorder
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.passthrough import RunnableAssign
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
import gradio as gr
from functools import partial
from operator import itemgetter
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=100,
    separators=["\n\n", "\n", ".", ";", ",", " ", ""],
)
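## RecursiveCharacterTextSplitter walks the separator list in order ("\n\n" first, "" as a
## last resort), recursing until every chunk fits under chunk_size; chunk_overlap keeps
## ~100 shared characters between neighboring chunks so a sentence cut at a boundary still
## appears intact somewhere. Rough illustration (not executed at startup):
##   text_splitter.split_text("First paragraph.\n\nSecond paragraph. " * 100)
##   # -> list of <=1000-character strings, split at "\n\n" where possible, then "\n", ".", ...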
print("Loading Documents....")
docs = [
    ArxivLoader(query="1706.03762").load(),  ## Attention Is All You Need Paper
    ArxivLoader(query="1810.04805").load(),  ## BERT Paper
    ArxivLoader(query="2005.11401").load(),  ## RAG Paper
    # ArxivLoader(query="2205.00445").load(),  ## MRKL Paper
    # ArxivLoader(query="2310.06825").load(),  ## Mistral Paper
    ArxivLoader(query="2306.05685").load(),  ## LLM-as-a-Judge Paper
    ## Some longer papers
    # ArxivLoader(query="2210.03629").load(),  ## ReAct Paper
    # ArxivLoader(query="2112.10752").load(),  ## Latent Stable Diffusion Paper
    # ArxivLoader(query="2103.00020").load(),  ## CLIP Paper
]
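## Note: each ArxivLoader(...).load() call returns a *list* of Document objects (typically
## one per paper here), so `docs` is a list of lists; metadata carries fields like 'Title'.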
## Cut each paper off at its "References" section to keep bibliographies out of the index
for doc in docs:
    content = doc[0].page_content
    if "References" in content:
        doc[0].page_content = content[:content.index("References")]
## Split the documents and also filter out stubs (overly short chunks)
print("Chunking Documents")
docs_chunks = [text_splitter.split_documents(doc) for doc in docs]
docs_chunks = [[c for c in dchunks if len(c.page_content) > 200] for dchunks in docs_chunks]
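## Rough sanity check of the stub filter (illustrative, not executed at startup):
##   [len(dchunks) for dchunks in docs_chunks]  # chunks per paper after dropping <=200-char stubs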
## Make some custom Chunks to give big-picture details
doc_string = "Available Documents:"
doc_metadata = []
for chunks in docs_chunks:
    metadata = getattr(chunks[0], 'metadata', {})
    doc_string += "\n - " + metadata.get('Title', 'Untitled')
    doc_metadata += [str(metadata)]
extra_chunks = [doc_string] + doc_metadata
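## `extra_chunks` now looks roughly like (abbreviated):
##   ["Available Documents:\n - Attention Is All You Need\n - ...",
##    "{'Published': '...', 'Title': 'Attention Is All You Need', ...}", ...]
## Indexing these alongside the paper chunks lets big-picture queries like
## "what documents do you know about?" retrieve the catalog itself.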
embedder = NVIDIAEmbeddings(model="nvidia/embed-qa-4", model_type=None)
## Construct series of document vector stores
print("Constructing Vector Stores")
vecstores = [FAISS.from_texts(extra_chunks, embedder)]
vecstores += [FAISS.from_documents(doc_chunks, embedder) for doc_chunks in docs_chunks]
embed_dims = len(embedder.embed_query("test"))
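## Probing the embedding width once keeps the IndexFlatL2 below in sync with the model
## (1024 dimensions for embed-qa-4 at the time of writing -- worth re-verifying if you
## swap embedding models).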
def default_FAISS():
    '''Useful utility for making an empty FAISS vectorstore'''
    return FAISS(
        embedding_function=embedder,
        index=IndexFlatL2(embed_dims),
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
        normalize_L2=False,
    )
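## Minimal usage sketch (illustrative):
##   vs = default_FAISS()
##   vs.add_texts(["hello world"])          # embeds via `embedder`, keeps the raw text too
##   vs.similarity_search("greeting", k=1)  # -> [Document(page_content='hello world')]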
def aggregate_vstores(vectorstores):
    ## Initialize an empty FAISS index and merge the others into it
    ## We'll use default_FAISS for simplicity, though it's tied to your embedder by reference
    agg_vstore = default_FAISS()
    for vstore in vectorstores:
        agg_vstore.merge_from(vstore)
    return agg_vstore
if 'docstore' not in globals():
    ## Unintuitive optimization: merge_from seems to optimize the constituent vector stores away
    docstore = aggregate_vstores(vecstores)

print(f"Constructed aggregate docstore with {len(docstore.docstore._dict)} chunks")
########################################################################
## Utility Runnables/Methods
def RPrint(preface=""):
    """Simple passthrough "prints, then returns" chain"""
    def print_and_return(x, preface):
        print(f"{preface}{x}")
        return x
    return RunnableLambda(partial(print_and_return, preface=preface))
def docs2str(docs, title="Document"):
    """Useful utility for making chunks into context string. Optional, but useful"""
    out_str = ""
    for doc in docs:
        doc_name = getattr(doc, 'metadata', {}).get('Title', title)
        if doc_name:
            out_str += f"[Quote from {doc_name}] "
        out_str += getattr(doc, 'page_content', str(doc)) + "\n"
    return out_str
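## Example output shape (illustrative):
##   [Quote from Attention Is All You Need] The dominant sequence transduction models are...
##   [Quote from BERT: Pre-training of Deep Bidirectional Transformers...] We introduce...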
## Optional: reorders retrieved documents so the most relevant sit at the start and end
## of the context rather than getting buried in the middle
long_reorder = RunnableLambda(LongContextReorder().transform_documents)
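## transform_documents assumes the retriever returns hits sorted by descending relevance;
## it then interleaves them outward so the weakest matches land mid-context -- a mitigation
## for the "lost in the middle" effect in long prompts.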
########################################################################
llm = ChatNVIDIA(model="mixtral_8x7b") | StrOutputParser()
convstore = default_FAISS()
def save_memory_and_get_output(d, vstore):
    """Accepts an 'input'/'output' dictionary and saves the exchange to convstore"""
    vstore.add_texts([
        f"User previously responded with {d.get('input')}",
        f"Agent previously responded with {d.get('output')}",
    ])
    return d.get('output')
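## After one exchange, convstore holds two retrievable texts, roughly:
##   "User previously responded with <message>"
##   "Agent previously responded with <response>"
## so later turns can retrieve prior conversation the same way document chunks are retrieved.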
initial_msg = (
"Hello! I am a document chat agent here to help the user!"
f" I have access to the following documents: {doc_string}\n\nHow can I help you?"
)
chat_prompt = ChatPromptTemplate.from_messages([("system",
    "You are a document chatbot. Help the user as they ask questions about documents."
    " The user just asked: {input}\n\n"
    " From this, we have retrieved the following potentially-useful info: "
    " Conversation History Retrieval:\n{history}\n\n"
    " Document Retrieval:\n{context}\n\n"
    " (Answer only from retrieval. Only cite sources that are used. Make your response conversational.)"
), ('user', '{input}')])
retrieval_chain = (
    {'input': (lambda x: x)}
    | RunnableAssign({'history': itemgetter('input') | convstore.as_retriever() | long_reorder | docs2str})
    | RunnableAssign({'context': itemgetter('input') | docstore.as_retriever() | long_reorder | docs2str})
    | RPrint()
)
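## By the end of retrieval_chain, the payload is a dict shaped roughly like:
##   {'input':   <user message>,
##    'history': <reordered conversation hits, stringified by docs2str>,
##    'context': <reordered document hits, stringified by docs2str>}
## which matches the {input}/{history}/{context} slots in chat_prompt exactly.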
stream_chain = chat_prompt | llm
def chat_gen(message, history=[], return_buffer=True):
    buffer = ""

    ## First perform the retrieval based on the input message
    retrieval = retrieval_chain.invoke(message)
    line_buffer = ""

    ## Then, stream the results of the stream_chain
    for token in stream_chain.stream(retrieval):
        buffer += token
        ## If you're using standard print, keep the line from getting too long
        if not return_buffer:
            line_buffer += token
            if "\n" in line_buffer:
                line_buffer = ""
            if (len(line_buffer) > 84 and token and token[0] == " ") or len(line_buffer) > 100:
                line_buffer = ""
                yield "\n"
                token = "  " + token.lstrip()
        yield buffer if return_buffer else token

    ## Lastly, save the chat exchange to the conversation memory buffer
    save_memory_and_get_output({'input': message, 'output': buffer}, convstore)
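## Quick console test (hypothetical), bypassing the Gradio UI:
##   for chunk in chat_gen("Tell me about the BERT paper!", return_buffer=False):
##       print(chunk, end="")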
chatbot = gr.Chatbot(value=[[None, initial_msg]])
demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
try:
    demo.launch(debug=True, share=True, show_api=False)
    demo.close()
except Exception as e:
    demo.close()
    print(e)
    raise e