|
|
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings |
|
|
|
|
|
from langchain_community.vectorstores import FAISS |
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
from langchain.document_loaders import ArxivLoader |
|
|
from faiss import IndexFlatL2 |
|
|
from langchain_community.docstore.in_memory import InMemoryDocstore |
|
|
from langchain.document_transformers import LongContextReorder |
|
|
from langchain_core.runnables import RunnableLambda |
|
|
from langchain_core.runnables.passthrough import RunnableAssign |
|
|
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings |
|
|
|
|
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
from langchain_core.output_parsers import StrOutputParser |
|
|
|
|
|
import gradio as gr |
|
|
from functools import partial |
|
|
from operator import itemgetter |
|
|
|
|
|
|
|
|
# Shared splitter: ~1000-char chunks with 100-char overlap, breaking on
# progressively finer separators (paragraph -> line -> sentence -> word).
CHUNK_SIZE = 1000
CHUNK_OVERLAP = 100
SPLIT_SEPARATORS = ["\n\n", "\n", ".", ";", ",", " ", ""]

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=CHUNK_OVERLAP,
    separators=SPLIT_SEPARATORS,
)
|
|
|
|
|
|
|
|
print("Loading Documents....") |
|
|
docs = [ |
|
|
ArxivLoader(query="1706.03762").load(), |
|
|
ArxivLoader(query="1810.04805").load(), |
|
|
ArxivLoader(query="2005.11401").load(), |
|
|
|
|
|
|
|
|
ArxivLoader(query="2306.05685").load(), |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
] |
|
|
|
|
|
|
|
|
# Truncate each paper at its first "References" heading so bibliography
# entries don't pollute retrieval results.
for doc in docs:
    body = doc[0].page_content
    cut = body.find("References")
    if cut != -1:
        doc[0].page_content = body[:cut]
|
|
|
|
|
|
|
|
print("Chunking Documents") |
|
|
docs_chunks = [text_splitter.split_documents(doc) for doc in docs] |
|
|
docs_chunks = [[c for c in dchunks if len(c.page_content) > 200] for dchunks in docs_chunks] |
|
|
|
|
|
|
|
|
doc_string = "Available Documents:" |
|
|
doc_metadata = [] |
|
|
for chunks in docs_chunks: |
|
|
metadata = getattr(chunks[0], 'metadata', {}) |
|
|
doc_string += "\n - " + metadata.get('Title') |
|
|
doc_metadata += [str(metadata)] |
|
|
|
|
|
extra_chunks = [doc_string] + doc_metadata |
|
|
|
|
|
# Query/passage embedding model served by NVIDIA AI Endpoints.
embedder = NVIDIAEmbeddings(model="nvidia/embed-qa-4", model_type=None)

print("Constructing Vector Stores")
# One store for the catalog/metadata strings, then one store per paper.
vecstores = [FAISS.from_texts(extra_chunks, embedder)]
for doc_chunks in docs_chunks:
    vecstores.append(FAISS.from_documents(doc_chunks, embedder))

# Probe the embedder once to discover the vector dimensionality.
embed_dims = len(embedder.embed_query("test"))
|
|
def default_FAISS():
    """Return an empty FAISS vectorstore wired to the module-level embedder."""
    empty_index = IndexFlatL2(embed_dims)
    empty_docstore = InMemoryDocstore()
    return FAISS(
        embedding_function=embedder,
        index=empty_index,
        docstore=empty_docstore,
        index_to_docstore_id={},
        normalize_L2=False,
    )
|
|
|
|
|
def aggregate_vstores(vectorstores):
    """Merge every FAISS store in `vectorstores` into one fresh combined index."""
    merged = default_FAISS()
    for store in vectorstores:
        merged.merge_from(store)
    return merged
|
|
|
|
|
# Guard so interactive re-runs (e.g. in a notebook) don't redo the merge.
if 'docstore' not in globals():
    docstore = aggregate_vstores(vecstores)

num_chunks = len(docstore.docstore._dict)
print(f"Constructed aggregate docstore with {num_chunks} chunks")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def RPrint(preface=""): |
|
|
"""Simple passthrough "prints, then returns" chain""" |
|
|
def print_and_return(x, preface): |
|
|
print(f"{preface}{x}") |
|
|
return x |
|
|
return RunnableLambda(partial(print_and_return, preface=preface)) |
|
|
|
|
|
def docs2str(docs, title="Document"):
    """Flatten retrieved chunks into a single context string.

    Each chunk contributes one line: an optional "[Quote from <Title>] "
    prefix (falling back to `title` when no Title metadata exists, and
    omitted entirely when the resolved name is falsy) followed by the
    chunk's page_content (or its str() if there is none).
    """
    lines = []
    for doc in docs:
        meta = getattr(doc, 'metadata', {})
        name = meta.get('Title', title)
        prefix = f"[Quote from {name}] " if name else ""
        body = getattr(doc, 'page_content', str(doc))
        lines.append(prefix + body + "\n")
    return "".join(lines)
|
|
|
|
|
|
|
|
# Reorders retrieved chunks so the most relevant land at the edges of the
# context window (LongContextReorder's "lost in the middle" mitigation).
long_reorder = RunnableLambda(LongContextReorder().transform_documents)

# Chat model piped into a string parser so the chain streams plain text.
chat_model = ChatNVIDIA(model="mixtral_8x7b")
llm = chat_model | StrOutputParser()

# Fresh conversation-memory store for this session.
convstore = default_FAISS()
|
|
|
|
|
def save_memory_and_get_output(d, vstore):
    """Record the latest user/agent exchange in `vstore`, then return the agent reply.

    `d` is an {'input': ..., 'output': ...} dictionary; both turns are saved
    as retrievable text so later questions can recall the conversation.
    """
    user_note = f"User previously responded with {d.get('input')}"
    agent_note = f"Agent previously responded with {d.get('output')}"
    vstore.add_texts([user_note, agent_note])
    return d.get('output')
|
|
|
|
|
# Greeting surfaced in the chat UI before the first user turn.
initial_msg = (
    "Hello! I am a document chat agent here to help the user!"
    " I have access to the following documents: "
    f"{doc_string}\n\nHow can I help you?"
)
|
|
|
|
|
# Prompt expects three template keys: {input} (user message), {history}
# (conversation-memory retrieval), and {context} (document retrieval) —
# all supplied by retrieval_chain below.
chat_prompt = ChatPromptTemplate.from_messages([
    ("system",
        "You are a document chatbot. Help the user as they ask questions about documents."
        # Fixed garbled phrasing: was "User messaged just asked".
        " The user just asked: {input}\n\n"
        " From this, we have retrieved the following potentially-useful info: "
        " Conversation History Retrieval:\n{history}\n\n"
        " Document Retrieval:\n{context}\n\n"
        " (Answer only from retrieval. Only cite sources that are used. Make your response conversational.)"
    ),
    ('user', '{input}'),
])
|
|
|
|
|
# Retrieval stage: wrap the raw user message into {'input': ...}, then
# attach conversation-memory hits as 'history' and document hits as
# 'context'. RPrint() logs the assembled dict before it reaches the prompt.
# NOTE(review): long_reorder and docs2str are piped as bare callables and
# rely on LCEL coercing them into runnables — confirm against the installed
# langchain-core version.
retrieval_chain = (
    {'input' : (lambda x: x)}
    | RunnableAssign({'history' : itemgetter('input') | convstore.as_retriever() | long_reorder | docs2str})
    | RunnableAssign({'context' : itemgetter('input') | docstore.as_retriever() | long_reorder | docs2str})
    | RPrint()
)
|
|
|
|
|
|
|
|
# Generation stage: fill the prompt with {input, history, context} and
# stream plain-string tokens from the parsed LLM output.
stream_chain = chat_prompt | llm
|
|
|
|
|
def chat_gen(message, history=None, return_buffer=True):
    """Generate a streamed chat response for `message`.

    Runs the retrieval chain, streams tokens from the LLM, and finally saves
    the exchange into the conversation memory store.

    Args:
        message: The user's latest message (plain string).
        history: Chat history supplied by Gradio's ChatInterface; currently
            unused by the body. Default changed from a mutable ``[]`` to
            ``None`` to avoid the shared-mutable-default pitfall; behavior
            is unchanged since the parameter is never read.
        return_buffer: When True, yield the growing full response after each
            token (Gradio-style). When False, yield individual tokens with
            soft line-wrapping applied.

    Yields:
        Either the accumulated response buffer or per-token strings.
    """
    buffer = ""

    # Assemble {'input', 'history', 'context'} for the prompt.
    retrieval = retrieval_chain.invoke(message)
    line_buffer = ""

    for token in stream_chain.stream(retrieval):
        buffer += token

        if not return_buffer:
            # Token mode: track the current line and inject a newline once it
            # grows past ~84 chars (at a word boundary) or a hard 100 chars.
            line_buffer += token
            if "\n" in line_buffer:
                line_buffer = ""
            if ((len(line_buffer) > 84 and token and token[0] == " ") or len(line_buffer) > 100):
                line_buffer = ""
                yield "\n"
                token = " " + token.lstrip()

        yield buffer if return_buffer else token

    # Persist the completed exchange so later turns can retrieve it.
    save_memory_and_get_output({'input': message, 'output': buffer}, convstore)
|
|
|
|
|
|
|
|
# Gradio UI: seed the chat widget with the agent's greeting, then serve
# chat_gen as the chat handler. .queue() enables streamed/generator output.
chatbot = gr.Chatbot(value = [[None, initial_msg]])
demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()

try:
    # debug=True blocks until the server stops; share=True opens a public link.
    demo.launch(debug=True, share=True, show_api=False)
    demo.close()
except Exception as e:
    # Release the server socket before surfacing the error.
    demo.close()
    print(e)
    raise e
|
|
|