Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
|
| 2 |
+
|
| 3 |
+
from langchain_community.vectorstores import FAISS
|
| 4 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 5 |
+
from langchain.document_loaders import ArxivLoader
|
| 6 |
+
from faiss import IndexFlatL2
|
| 7 |
+
from langchain_community.docstore.in_memory import InMemoryDocstore
|
| 8 |
+
from langchain.document_transformers import LongContextReorder
|
| 9 |
+
from langchain_core.runnables import RunnableLambda
|
| 10 |
+
from langchain_core.runnables.passthrough import RunnableAssign
|
| 11 |
+
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
|
| 12 |
+
|
| 13 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 14 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 15 |
+
|
| 16 |
+
import gradio as gr
|
| 17 |
+
from functools import partial
|
| 18 |
+
from operator import itemgetter
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
## Chunker shared by every paper: ~1000-char windows with 100-char overlap,
## splitting preferentially at paragraph, then line, then sentence boundaries.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=100,
    separators=["\n\n", "\n", ".", ";", ",", " ", ""],
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
## Pull each paper from arXiv by its identifier; each .load() yields a list of Documents.
print("Loading Documents")
docs = [
    ArxivLoader(query="1706.03762").load(),  ## Attention Is All You Need Paper
    ArxivLoader(query="1810.04805").load(),  ## BERT Paper
    ArxivLoader(query="2005.11401").load(),  ## RAG Paper
    ArxivLoader(query="2205.00445").load(),  ## MRKL Paper
    ArxivLoader(query="2310.06825").load(),  ## Mistral Paper
    ArxivLoader(query="2306.05685").load(),  ## LLM-as-a-Judge
    ## Some longer papers
    # ArxivLoader(query="2210.03629").load(),  ## ReAct Paper
    # ArxivLoader(query="2112.10752").load(),  ## Latent Stable Diffusion Paper
    # ArxivLoader(query="2103.00020").load(),  ## CLIP Paper
]


## Cut each paper off at its bibliography so citations don't pollute retrieval.
## NOTE(review): only the first Document of each load() result is trimmed —
## assumes one Document per paper; confirm against ArxivLoader's behavior.
for paper in docs:
    body = paper[0].page_content
    if "References" in body:
        paper[0].page_content = body[:body.index("References")]
|
| 46 |
+
|
| 47 |
+
## Split the documents and also filter out stubs (overly short chunks)
print("Chunking Documents")
docs_chunks = [
    [chunk for chunk in text_splitter.split_documents(doc) if len(chunk.page_content) > 200]
    for doc in docs
]
|
| 51 |
+
|
| 52 |
+
## Make some custom chunks that give big-picture details about the corpus
doc_string = "Available Documents:"
doc_metadata = []
for chunks in docs_chunks:
    metadata = getattr(chunks[0], 'metadata', {})
    ## .get('Title') returns None when the metadata lacks that key, which would
    ## raise TypeError on the str concatenation below; fall back to a placeholder.
    doc_string += "\n - " + (metadata.get('Title') or "Untitled Document")
    doc_metadata += [str(metadata)]

## Summary string first, then one raw-metadata chunk per paper.
extra_chunks = [doc_string] + doc_metadata
|
| 61 |
+
|
| 62 |
+
## NVIDIA-hosted embedding model; model_type=None presumably lets the endpoint
## infer passage/query mode — TODO confirm against NVIDIAEmbeddings docs.
embedder = NVIDIAEmbeddings(model="nvolveqa_40k", model_type=None)

## Construct series of document vector stores
print("Constructing Vector Stores")
## One store for the big-picture summary chunks, plus one store per paper.
vecstores = [FAISS.from_texts(extra_chunks, embedder)]
vecstores += [FAISS.from_documents(doc_chunks, embedder) for doc_chunks in docs_chunks]
|
| 68 |
+
|
| 69 |
+
## Probe the embedder once to discover the dimensionality of its vectors.
embed_dims = len(embedder.embed_query("test"))

def default_FAISS():
    '''Return a fresh, empty FAISS vectorstore wired to the module-level embedder.'''
    empty_store = FAISS(
        embedding_function=embedder,
        index=IndexFlatL2(embed_dims),
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
        normalize_L2=False,
    )
    return empty_store
|
| 79 |
+
|
| 80 |
+
def aggregate_vstores(vectorstores):
    '''Merge every FAISS store in *vectorstores* into one fresh aggregate index.'''
    ## Start from an empty index (tied to the module-level embedder by reference)
    ## and fold each constituent store into it.
    merged = default_FAISS()
    for store in vectorstores:
        merged.merge_from(store)
    return merged
|
| 87 |
+
|
| 88 |
+
## Guard against rebuilding on interactive re-runs: only aggregate if this
## module scope doesn't already hold a docstore.
if 'docstore' not in globals():
    ## Unintuitive optimization; merge_from seems to optimize constituent vector stores away
    docstore = aggregate_vstores(vecstores)

## NOTE(review): reaches into FAISS internals (docstore._dict) just to count
## chunks — confirm no public size accessor exists in this langchain version.
print(f"Constructed aggregate docstore with {len(docstore.docstore._dict)} chunks")
|
| 93 |
+
|
| 94 |
+
## Chat model piped straight into a string parser so downstream chains see plain text.
llm = ChatNVIDIA(model="mixtral_8x7b") | StrOutputParser()
## Separate empty vectorstore used as the conversation-memory buffer.
convstore = default_FAISS()
|
| 96 |
+
|
| 97 |
+
def save_memory_and_get_output(d, vstore):
    """Record an 'input'/'output' exchange from *d* into *vstore*, then return the output."""
    exchange = [
        f"User previously responded with {d.get('input')}",
        f"Agent previously responded with {d.get('output')}",
    ]
    vstore.add_texts(exchange)
    return d.get('output')
|
| 104 |
+
|
| 105 |
+
## Greeting shown as the assistant's first message; lists the loaded paper titles.
initial_msg = (
    "Hello! I am a document chat agent here to help the user!"
    f" I have access to the following documents: {doc_string}\n\nHow can I help you?"
)
|
| 109 |
+
|
| 110 |
+
## System prompt: hands the model the user's question plus both retrieval results
## (conversation memory and document chunks) and constrains it to answer from them.
## Fixed the garbled sentence "User messaged just asked" in the prompt text.
chat_prompt = ChatPromptTemplate.from_messages([("system",
    "You are a document chatbot. Help the user as they ask questions about documents."
    " The user just asked: {input}\n\n"
    " From this, we have retrieved the following potentially-useful info: "
    " Conversation History Retrieval:\n{history}\n\n"
    " Document Retrieval:\n{context}\n\n"
    " (Answer only from retrieval. Only cite sources that are used. Make your response conversational.)"
), ('user', '{input}')])
|
| 118 |
+
|
| 119 |
+
## --- Helpers the retrieval chain needs (previously referenced but never defined,
## --- which made this module raise NameError at import time) ---

def docs2str(docs, title="Document"):
    """Flatten a list of retrieved Documents into one quotable string for the prompt."""
    out_str = ""
    for doc in docs:
        doc_name = getattr(doc, 'metadata', {}).get('Title', title)
        if doc_name:
            out_str += f"[Quote from {doc_name}] "
        out_str += getattr(doc, 'page_content', str(doc)) + "\n"
    return out_str

## Reorders retrieved chunks so the most relevant ones sit at the ends of the
## context — the "lost in the middle" mitigation LongContextReorder implements.
long_reorder = RunnableLambda(LongContextReorder().transform_documents)

def RPrint(preface=""):
    """Passthrough Runnable that prints the value flowing through the chain."""
    def print_and_return(x, preface):
        if preface:
            print(preface, end="")
        print(x)
        return x
    return RunnableLambda(partial(print_and_return, preface=preface))

## Map the raw user message into a dict, then attach history/context retrievals.
retrieval_chain = (
    {'input' : (lambda x: x)}
    | RunnableAssign({'history' : itemgetter('input') | convstore.as_retriever() | long_reorder | docs2str})
    | RunnableAssign({'context' : itemgetter('input') | docstore.as_retriever() | long_reorder | docs2str})
    | RPrint()
)

## Prompt + model: consumes the retrieval dict and streams the answer tokens.
stream_chain = chat_prompt | llm
|
| 128 |
+
|
| 129 |
+
def chat_gen(message, history=None, return_buffer=True):
    """Gradio generator: retrieve context for *message*, stream the LLM reply, save the exchange.

    Args:
        message: The user's latest utterance.
        history: Chat history supplied by Gradio's ChatInterface (never read here;
            default changed from the mutable ``[]`` to ``None`` to avoid the
            shared-mutable-default pitfall).
        return_buffer: If True, yield the growing full response each step;
            otherwise yield per-token strings with manual line-wrapping for
            plain-print display.
    """
    buffer = ""

    ## First perform the retrieval based on the input message
    retrieval = retrieval_chain.invoke(message)
    line_buffer = ""

    ## Then, stream the results of the stream_chain
    for token in stream_chain.stream(retrieval):
        buffer += token
        ## If you're using standard print, keep line from getting too long
        if not return_buffer:
            line_buffer += token
            if "\n" in line_buffer:
                line_buffer = ""
            if ((len(line_buffer) > 84 and token and token[0] == " ") or len(line_buffer) > 100):
                line_buffer = ""
                yield "\n"
                token = " " + token.lstrip()
        yield buffer if return_buffer else token

    ## Lastly, save the chat exchange to the conversation memory buffer
    save_memory_and_get_output({'input': message, 'output': buffer}, convstore)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
## Seed the chat window with the greeting, then hand generation off to chat_gen.
chatbot = gr.Chatbot(value = [[None, initial_msg]])
demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()

try:
    demo.launch(debug=True, share=True, show_api=False)
except Exception as e:
    print(e)
    ## Bare raise preserves the original traceback (was `raise e`).
    raise
finally:
    ## Single close path — previously duplicated in both success and failure branches.
    demo.close()
|
| 163 |
+
|