ez7051 committed on
Commit
82eed88
·
verified ·
1 Parent(s): 8954616

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +163 -0
app.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
2
+
3
+ from langchain_community.vectorstores import FAISS
4
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain.document_loaders import ArxivLoader
6
+ from faiss import IndexFlatL2
7
+ from langchain_community.docstore.in_memory import InMemoryDocstore
8
+ from langchain.document_transformers import LongContextReorder
9
+ from langchain_core.runnables import RunnableLambda
10
+ from langchain_core.runnables.passthrough import RunnableAssign
11
+ from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
12
+
13
+ from langchain_core.prompts import ChatPromptTemplate
14
+ from langchain_core.output_parsers import StrOutputParser
15
+
16
+ import gradio as gr
17
+ from functools import partial
18
+ from operator import itemgetter
19
+
20
+
21
## Splitter shared by every paper: recursive descent that prefers paragraph >
## line > sentence > clause boundaries before falling back to raw characters.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
    separators=["\n\n", "\n", ".", ";", ",", " ", ""],
)
25
+
26
+
27
print("Loading Documents")

## Arxiv IDs of the papers backing the chatbot; each .load() call returns a
## one-element list of Documents for that paper.
_paper_ids = [
    "1706.03762",  ## Attention Is All You Need Paper
    "1810.04805",  ## BERT Paper
    "2005.11401",  ## RAG Paper
    "2205.00445",  ## MRKL Paper
    "2310.06825",  ## Mistral Paper
    "2306.05685",  ## LLM-as-a-Judge
    ## Some longer papers, currently disabled:
    # "2210.03629",  ## ReAct Paper
    # "2112.10752",  ## Latent Stable Diffusion Paper
    # "2103.00020",  ## CLIP Paper
]
docs = [ArxivLoader(query=paper_id).load() for paper_id in _paper_ids]
40
+
41
+
42
def _trim_references(text):
    """Return `text` truncated just before its "References" section.

    If no "References" marker is present, the text is returned unchanged.
    Dropping the bibliography keeps citation noise out of the vector store.
    """
    marker = "References"
    if marker in text:
        return text[:text.index(marker)]
    return text

## Cut off each paper at its bibliography so reference lists are never
## chunked and embedded.
for doc in docs:
    doc[0].page_content = _trim_references(doc[0].page_content)
46
+
47
## Split the documents and also filter out stubs (overly short chunks that
## make poor retrieval targets).
print("Chunking Documents")
docs_chunks = [
    [chunk for chunk in text_splitter.split_documents(doc) if len(chunk.page_content) > 200]
    for doc in docs
]
51
+
52
## Make some custom chunks that give big-picture details about the corpus:
## one listing of all titles, plus the raw metadata of each paper.
doc_string = "Available Documents:"
doc_metadata = []
for chunks in docs_chunks:
    metadata = getattr(chunks[0], 'metadata', {})
    ## Fall back to a placeholder: .get('Title') alone may return None,
    ## which would raise a TypeError when concatenated below.
    doc_string += "\n - " + (metadata.get('Title') or "Untitled Document")
    doc_metadata += [str(metadata)]

extra_chunks = [doc_string] + doc_metadata
61
+
62
## Embedding model used both for indexing and for querying.
embedder = NVIDIAEmbeddings(model="nvolveqa_40k", model_type=None)

## Construct a series of document vector stores: the big-picture summary
## chunks in front, then one store per paper.
print("Constructing Vector Stores")
vecstores = [FAISS.from_texts(extra_chunks, embedder)] + [
    FAISS.from_documents(chunk_group, embedder) for chunk_group in docs_chunks
]
68
+
69
## Dimensionality of the embedding space, probed with a throwaway query.
embed_dims = len(embedder.embed_query("test"))


def default_FAISS():
    """Construct and return an empty FAISS vector store tied to `embedder`."""
    empty_index = IndexFlatL2(embed_dims)
    return FAISS(
        embedding_function=embedder,
        index=empty_index,
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
        normalize_L2=False,
    )
79
+
80
def aggregate_vstores(vectorstores):
    """Merge every store in `vectorstores` into one fresh FAISS index.

    Starts from an empty default_FAISS() store (which shares the module's
    embedder by reference) and folds the constituents in with merge_from.
    """
    merged = default_FAISS()
    for component in vectorstores:
        merged.merge_from(component)
    return merged
87
+
88
## Build the aggregate store at most once per interpreter session; merge_from
## appears to consume the constituent stores, so rebuilding is not free.
if 'docstore' not in globals():
    docstore = aggregate_vstores(vecstores)

chunk_total = len(docstore.docstore._dict)
print(f"Constructed aggregate docstore with {chunk_total} chunks")

## Chat model piped straight into a string parser, plus an initially-empty
## store that will accumulate the running conversation.
llm = ChatNVIDIA(model="mixtral_8x7b") | StrOutputParser()
convstore = default_FAISS()
96
+
97
def save_memory_and_get_output(d, vstore):
    """Accepts 'input'/'output' dictionary and saves to convstore"""
    user_turn = f"User previously responded with {d.get('input')}"
    agent_turn = f"Agent previously responded with {d.get('output')}"
    vstore.add_texts([user_turn, agent_turn])
    return d.get('output')
104
+
105
## Greeting shown before any user input; advertises the indexed papers.
initial_msg = (
    "Hello! I am a document chat agent here to help the user!"
    f" I have access to the following documents: {doc_string}\n\nHow can I help you?"
)

## System prompt wires the user question together with both retrieval
## results (conversation history and document chunks).
_system_template = (
    "You are a document chatbot. Help the user as they ask questions about documents."
    " User messaged just asked: {input}\n\n"
    " From this, we have retrieved the following potentially-useful info: "
    " Conversation History Retrieval:\n{history}\n\n"
    " Document Retrieval:\n{context}\n\n"
    " (Answer only from retrieval. Only cite sources that are used. Make your response conversational.)"
)
chat_prompt = ChatPromptTemplate.from_messages([
    ("system", _system_template),
    ("user", "{input}"),
])
118
+
119
## --- Helpers the retrieval chain depends on (previously undefined) --------

## Reorders retrieved chunks so the most relevant land at the start and end
## of the context window (mitigates "lost in the middle").
long_reorder = RunnableLambda(LongContextReorder().transform_documents)


def docs2str(docs, title="Document"):
    """Flatten a list of retrieved Documents into one quoted context string."""
    out_str = ""
    for doc in docs:
        doc_name = getattr(doc, 'metadata', {}).get('Title', title)
        if doc_name:
            out_str += f"[Quote from {doc_name}] "
        out_str += getattr(doc, 'page_content', str(doc)) + "\n"
    return out_str


def RPrint(preface=""):
    """Debugging runnable: print the running state and pass it through."""
    def print_and_return(x):
        print(f"{preface}{x}")
        return x
    return RunnableLambda(print_and_return)


## Given a raw user string, attach a 'history' key (conversation retrieval)
## and a 'context' key (document retrieval), then log the assembled state.
retrieval_chain = (
    {'input' : (lambda x: x)}
    | RunnableAssign({'history' : itemgetter('input') | convstore.as_retriever() | long_reorder | docs2str})
    | RunnableAssign({'context' : itemgetter('input') | docstore.as_retriever() | long_reorder | docs2str})
    | RPrint()
)


## Prompt + model: consumes the dict produced by retrieval_chain and streams text.
stream_chain = chat_prompt | llm
128
+
129
def chat_gen(message, history=None, return_buffer=True):
    """Gradio generator: retrieve context for `message`, then stream a reply.

    Args:
        message: Latest user utterance.
        history: Chat history supplied by gr.ChatInterface (unused here —
            conversational memory lives in `convstore` instead). Defaults to
            None rather than a mutable list to avoid shared-state bugs.
        return_buffer: When True, yield the growing full response (what
            gr.ChatInterface expects); when False, yield token-by-token with
            manual line wrapping suitable for console printing.

    Yields:
        str: Either the accumulated buffer or individual (wrapped) tokens.
    """
    buffer = ""
    ## First perform the retrieval based on the input message
    retrieval = retrieval_chain.invoke(message)
    line_buffer = ""

    ## Then, stream the results of the stream_chain
    for token in stream_chain.stream(retrieval):
        buffer += token
        ## If you're using standard print, keep line from getting too long
        if not return_buffer:
            line_buffer += token
            if "\n" in line_buffer:
                line_buffer = ""
            if ((len(line_buffer)>84 and token and token[0] == " ") or len(line_buffer)>100):
                line_buffer = ""
                yield "\n"
                token = " " + token.lstrip()
        yield buffer if return_buffer else token

    ## Lastly, save the chat exchange to the conversation memory buffer
    save_memory_and_get_output({'input': message, 'output': buffer}, convstore)
151
+
152
+
153
## Seed the chat UI with the agent's greeting as the opening bot message.
chatbot = gr.Chatbot(value = [[None, initial_msg]])
demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()

try:
    ## debug=True blocks here until the server exits; share=True requests a
    ## public gradio link; show_api=False hides the REST API docs page.
    demo.launch(debug=True, share=True, show_api=False)
    demo.close()
except Exception as e:
    ## Shut the server down before surfacing the error, then re-raise so the
    ## process still fails loudly.
    demo.close()
    print(e)
    raise e
163
+