cmagganas committed on
Commit
b8a4fc5
·
1 Parent(s): 3023464

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -136
app.py CHANGED
@@ -1,217 +1,239 @@
1
  import chainlit as cl
 
 
2
  from langchain.embeddings.openai import OpenAIEmbeddings
3
- from langchain.document_loaders import WikipediaLoader, CSVLoader
4
  from langchain.embeddings import CacheBackedEmbeddings
 
 
 
 
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
- from langchain.vectorstores import FAISS
7
- from langchain.chains import RetrievalQA
 
8
  from langchain.chat_models import ChatOpenAI
9
- from langchain.storage import LocalFileStore
10
- from langchain.prompts.chat import (
11
- ChatPromptTemplate,
12
- SystemMessagePromptTemplate,
13
- HumanMessagePromptTemplate,
14
- )
15
- from langchain.retrievers import BM25Retriever, EnsembleRetriever
16
- from langchain.agents import Tool, ZeroShotAgent, AgentExecutor
17
- from langchain.agents.agent_toolkits import create_retriever_tool, create_conversational_retrieval_agent
18
  from langchain import LLMChain
19
 
20
- llm = ChatOpenAI(model="gpt-3.5-turbo", temperature = 0)
 
 
 
21
 
22
  @cl.on_chat_start
23
  async def init():
 
24
  msg = cl.Message(content=f"Building Index...")
25
  await msg.send()
26
-
27
- barbie_wikipedia_docs = WikipediaLoader(query="Barbie (film)", load_max_docs=1, doc_content_chars_max=1_000_000).load()
28
- barbie_csv_docs = CSVLoader(file_path="./barbie_data/barbie.csv", source_column="Review_Url").load()
29
- oppenheimer_wikipedia_docs = WikipediaLoader(query="Oppenheimer (film)", load_max_docs=1, doc_content_chars_max=1_000_000).load()
30
- oppenheimer_csv_docs = CSVLoader(file_path="./oppenheimer_data/oppenheimer.csv", source_column="Review_Url").load()
31
 
 
 
 
32
  wikipedia_text_splitter = RecursiveCharacterTextSplitter(
33
- chunk_size = 500,
34
- chunk_overlap = 0,
35
  length_function = len,
36
  is_separator_regex= False,
37
  separators = ["\n==", "\n", " "] # keep headings, then paragraphs, then sentences
38
  )
39
-
40
  csv_text_splitter = RecursiveCharacterTextSplitter(
41
- chunk_size = 1000,
42
- chunk_overlap = 50,
43
  length_function = len,
44
  is_separator_regex= False,
45
  separators = ["\n", " "] # keep paragraphs, then sentences
46
  )
47
 
48
- chunked_barbie_wikipedia_docs = wikipedia_text_splitter.transform_documents(barbie_wikipedia_docs)
49
- chunked_barbie_csv_docs = csv_text_splitter.transform_documents(barbie_csv_docs)
50
- chunked_opp_wikipedia_docs = wikipedia_text_splitter.transform_documents(oppenheimer_wikipedia_docs)
51
- chunked_opp_csv_docs = csv_text_splitter.transform_documents(oppenheimer_csv_docs)
52
-
53
- # #### Retrieval and Embedding Strategy
54
  # set up cached embeddings store
55
- store = LocalFileStore("./shared_cache/")
56
  core_embeddings_model = OpenAIEmbeddings()
57
- embedder = CacheBackedEmbeddings.from_bytes_store(core_embeddings_model, store, namespace=core_embeddings_model.model)
58
-
59
- # set up FAISS vector store for csv
60
- barbie_csv_faiss_async = await cl.make_async(FAISS.from_documents)(chunked_barbie_csv_docs, embedder)
61
- barbie_csv_faiss_retriever = barbie_csv_faiss_async.as_retriever()
62
- opp_csv_faiss_async = await cl.make_async(FAISS.from_documents)(chunked_opp_csv_docs, embedder)
63
- opp_csv_faiss_retriever = opp_csv_faiss_async.as_retriever()
64
-
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  # set up BM25 retriever
66
- barbie_wikipedia_bm25_retriever = await cl.make_async(BM25Retriever.from_documents)(chunked_barbie_wikipedia_docs)
 
 
67
  barbie_wikipedia_bm25_retriever.k = 1
68
- opp_wikipedia_bm25_retriever = await cl.make_async(BM25Retriever.from_documents)(chunked_opp_wikipedia_docs)
69
- opp_wikipedia_bm25_retriever.k = 1
70
-
71
- # set up FAISS vector store for Wiki
72
- barbie_wikipedia_faiss_store = await cl.make_async(FAISS.from_documents)(chunked_barbie_wikipedia_docs, embedder)
73
  barbie_wikipedia_faiss_retriever = barbie_wikipedia_faiss_store.as_retriever(search_kwargs={"k": 1})
74
- opp_wikipedia_faiss_store = await cl.make_async(FAISS.from_documents)(chunked_opp_wikipedia_docs, embedder)
75
- opp_wikipedia_faiss_retriever = opp_wikipedia_faiss_store.as_retriever(search_kwargs={"k": 1})
76
-
77
  # set up ensemble retriever
78
- barbie_ensemble_retriever = await cl.make_async(EnsembleRetriever)(
79
  retrievers=[barbie_wikipedia_bm25_retriever, barbie_wikipedia_faiss_retriever],
80
- weights=[0.25, 0.75]
81
- )
82
- opp_ensemble_retriever = await cl.make_async(EnsembleRetriever)(
83
- retrievers=[opp_wikipedia_bm25_retriever, opp_wikipedia_faiss_retriever],
84
- weights=[0.25, 0.75]
85
  )
86
-
87
- # #### Retrieval Agent
88
  barbie_wikipedia_retrieval_tool = create_retriever_tool(
89
- barbie_ensemble_retriever,
90
- "Wikipedia",
91
- "Searches and returns documents regarding the plot, history, and cast of the Barbie movie"
92
  )
93
-
94
  barbie_csv_retrieval_tool = create_retriever_tool(
95
- barbie_csv_faiss_retriever,
96
- "PublicReviews",
97
- "Searches and returns documents regarding public reviews of the Barbie movie"
98
  )
99
-
100
- barbie_retriever_tools = [barbie_wikipedia_retrieval_tool, barbie_csv_retrieval_tool]
101
-
102
- barbie_retriever_agent_executor = create_conversational_retrieval_agent(llm, barbie_retriever_tools, verbose=True)
103
-
104
- # #### Multi-source chain
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  system_message = """Use the information from the below two sources to answer any questions.
106
 
107
  Source 1: public user reviews about the Oppenheimer movie
108
  <source1>
109
  {source1}
110
  </source1>
111
-
112
  Source 2: the wikipedia page for the Oppenheimer movie including the plot summary, cast, and production information
113
  <source2>
114
  {source2}
115
  </source2>
116
  """
117
-
118
  prompt = ChatPromptTemplate.from_messages([("system", system_message), ("human", "{question}")])
119
-
120
  oppenheimer_multisource_chain = {
121
  "source1": (lambda x: x["question"]) | opp_ensemble_retriever,
122
  "source2": (lambda x: x["question"]) | opp_csv_faiss_retriever,
123
  "question": lambda x: x["question"],
124
  } | prompt | llm
125
 
126
- # # Agent Creation
127
-
 
 
 
128
  def query_oppenheimer(input):
129
  return oppenheimer_multisource_chain.invoke({"question" : input})
130
-
131
  tools = [
132
  Tool(
133
- name = "BarbieInfo",
134
- func=barbie_retriever_agent_executor.invoke,
135
- description="useful for when you need to answer questions about Barbie. Input should be a fully formed question."
136
  ),
137
  Tool(
138
- name = "OppenheimerInfo",
139
  func=query_oppenheimer,
140
- description="useful for when you need to answer questions about Oppenheimer. Input should be a fully formed question."
141
  ),
142
  ]
143
-
144
  prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
145
  suffix = """Begin!"
146
-
147
  Question: {input}
148
  {agent_scratchpad}"""
149
-
150
  prompt = ZeroShotAgent.create_prompt(
151
- tools,
152
  prefix=prefix,
153
  suffix=suffix,
154
- input_variables=["input", "agent_scratchpad"]
155
  )
156
-
157
- llm_chain = LLMChain(llm=llm, prompt=prompt)
158
-
159
- barbenheimer_agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
160
- barbenheimer_agent_chain = AgentExecutor.from_agent_and_tools(agent=barbenheimer_agent, tools=tools, verbose=True)
161
- # barbenheimer_agent_chain = cl.make_async(AgentExecutor.from_agent_and_tools)(
162
- # agent=barbenheimer_agent,
163
- # tools=tools,
164
- # verbose=True
165
- # )
166
-
167
- # ######################
168
- # reference code from v1
169
- # docsearch = await cl.make_async(FAISS.from_documents)(documents, embedder)
170
-
171
- # chain = RetrievalQA.from_chain_type(
172
- # ChatOpenAI(model="gpt-4", temperature=0, streaming=True),
173
- # chain_type="stuff",
174
- # return_source_documents=True,
175
- # retriever=docsearch.as_retriever(),
176
- # chain_type_kwargs = {"prompt": prompt}
177
- # )
178
-
179
- msg.content = f"Index built!"
180
  await msg.send()
181
 
182
- cl.user_session.set("barbenheimer_agent_chain", barbenheimer_agent_chain)
183
-
184
-
185
  @cl.on_message
186
  async def main(message):
187
- chain = cl.user_session.get("barbenheimer_agent_chain")
188
- cb = cl.AsyncLangchainCallbackHandler(
 
 
 
 
189
  stream_final_answer=False, answer_prefix_tokens=["FINAL", "ANSWER"]
190
  )
191
  cb.answer_reached = True
192
- res = await chain.acall(message, callbacks=[cb], )
193
-
194
- answer = res["result"]
 
 
195
  source_elements = []
196
- visited_sources = set()
197
-
198
- # Get the documents from the user session
199
- docs = res["source_documents"]
200
- metadatas = [doc.metadata for doc in docs]
201
- all_sources = [m["source"] for m in metadatas]
202
-
203
- for source in all_sources:
204
- if source in visited_sources:
205
- continue
206
- visited_sources.add(source)
207
- # Create the text element referenced in the message
208
- source_elements.append(
209
- cl.Text(content="https://www.imdb.com" + source, name="Review URL")
210
- )
211
-
212
- if source_elements:
213
- answer += f"\nSources: {', '.join([e.content.decode('utf-8') for e in source_elements])}"
214
- else:
215
- answer += "\nNo sources found"
216
 
217
  await cl.Message(content=answer, elements=source_elements).send()
 
1
  import chainlit as cl
2
+ from langchain.retrievers import BM25Retriever, EnsembleRetriever
3
+ from langchain.vectorstores import FAISS
4
  from langchain.embeddings.openai import OpenAIEmbeddings
 
5
  from langchain.embeddings import CacheBackedEmbeddings
6
+ from langchain.storage import LocalFileStore
7
+ from langchain.agents.agent_toolkits import create_retriever_tool
8
+ from langchain.agents.agent_toolkits import create_conversational_retrieval_agent
9
+ from langchain.document_loaders import WikipediaLoader, CSVLoader
10
  from langchain.text_splitter import RecursiveCharacterTextSplitter
11
+ from langchain.prompts import ChatPromptTemplate
12
+ from langchain.agents import Tool
13
+ from langchain.agents import ZeroShotAgent, AgentExecutor
14
  from langchain.chat_models import ChatOpenAI
 
 
 
 
 
 
 
 
 
15
  from langchain import LLMChain
16
 
17
@cl.author_rename
def rename(orig_author: str):
    """Replace internal LangChain author names with a friendlier UI label."""
    display_names = {"RetrievalQA": "Consulting The Barbenheimer"}
    return display_names.get(orig_author, orig_author)
21
 
22
@cl.on_chat_start
async def init():
    """Build the full Barbenheimer QA stack when a chat session starts.

    Constructs, in order:
      1. shared pieces: the chat LLM, two text splitters, and a
         disk-cached OpenAI embedder,
      2. a Barbie retrieval agent (Wikipedia BM25+FAISS ensemble retriever
         and a review-CSV retriever exposed as tools of a conversational
         retrieval agent),
      3. an Oppenheimer multi-source LCEL chain (review-CSV retriever and
         Wikipedia ensemble retriever piped into a prompt and the LLM),
      4. a top-level ZeroShotAgent that dispatches between the two,
    then stores the resulting AgentExecutor in the user session under the
    key "chain" for `main` to use.
    """

    msg = cl.Message(content=f"Building Index...")
    await msg.send()

    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature = 0)

    # set up text splitters (1024-char chunks with 50% overlap)
    wikipedia_text_splitter = RecursiveCharacterTextSplitter(
        chunk_size = 1024,
        chunk_overlap = 512,
        length_function = len,
        is_separator_regex= False,
        separators = ["\n==", "\n", " "]  # keep headings, then paragraphs, then sentences
    )
    csv_text_splitter = RecursiveCharacterTextSplitter(
        chunk_size = 1024,
        chunk_overlap = 512,
        length_function = len,
        is_separator_regex= False,
        separators = ["\n", " "]  # keep paragraphs, then sentences
    )

    # set up cached embeddings store so repeated startups don't re-pay the
    # OpenAI embedding cost; cache entries are namespaced by model name
    store = LocalFileStore("./.cache/")
    core_embeddings_model = OpenAIEmbeddings()
    embedder = CacheBackedEmbeddings.from_bytes_store(core_embeddings_model,
                                                      store,
                                                      namespace=core_embeddings_model.model)

    # ---- Barbie retrieval system (Wikipedia, CSV) ----
    # load the multiple source documents for Barbie
    barbie_wikipedia_docs = WikipediaLoader(
        query="Barbie (film)",
        load_max_docs= 1,
        doc_content_chars_max=10000000
    ).load()
    barbie_csv_docs = CSVLoader(
        file_path= "./barbie_data/barbie.csv",
        source_column="Review"  # assumes the CSV has a "Review" column — TODO confirm against the data file
    ).load()
    # chunk the loaded documents using the text splitters
    chunked_barbie_wikipedia_docs = wikipedia_text_splitter.transform_documents(barbie_wikipedia_docs)
    chunked_barbie_csv_docs = csv_text_splitter.transform_documents(barbie_csv_docs)
    # FAISS *store* for the review CSV; converted with .as_retriever() below
    barbie_csv_faiss_store = FAISS.from_documents(chunked_barbie_csv_docs, embedder)
    # set up BM25 retriever (sparse / keyword search over the Wikipedia chunks)
    barbie_wikipedia_bm25_retriever = BM25Retriever.from_documents(
        chunked_barbie_wikipedia_docs
    )
    barbie_wikipedia_bm25_retriever.k = 1
    # set up FAISS vector store and create retriever (dense / semantic search)
    barbie_wikipedia_faiss_store = FAISS.from_documents(
        chunked_barbie_wikipedia_docs,
        embedder
    )
    barbie_wikipedia_faiss_retriever = barbie_wikipedia_faiss_store.as_retriever(search_kwargs={"k": 1})
    # ensemble retriever blending sparse (BM25) and dense (FAISS) results
    barbie_ensemble_retriever = EnsembleRetriever(
        retrievers=[barbie_wikipedia_bm25_retriever, barbie_wikipedia_faiss_retriever],
        weights= [0.25, 0.75]  # should sum to 1
    )
    # create retriever tools the conversational agent can choose between
    barbie_wikipedia_retrieval_tool = create_retriever_tool(
        retriever=barbie_ensemble_retriever,
        name='Search_Wikipedia',
        description='Useful for when you need to answer questions about plot, cast, production, release, music, marketing, reception, themes and analysis of the Barbie movie.'
    )
    barbie_csv_retrieval_tool = create_retriever_tool(
        retriever=barbie_csv_faiss_store.as_retriever(),
        name='Search_Reviews',
        description='Useful for when you need to answer questions about public reviews of the Barbie movie.'
    )
    barbie_retriever_tools = [barbie_wikipedia_retrieval_tool, barbie_csv_retrieval_tool]
    # retrieval agent: picks a tool per question and synthesizes an answer
    barbie_retriever_agent_executor = create_conversational_retrieval_agent(llm=llm, tools=barbie_retriever_tools, verbose=True)

    # ---- Oppenheimer retrieval system (Wikipedia, CSV) ----
    # load the multiple source documents for Oppenheimer
    oppenheimer_wikipedia_docs = WikipediaLoader(
        query="Oppenheimer",
        load_max_docs=1,
        doc_content_chars_max=10000000
    ).load()
    oppenheimer_csv_docs = CSVLoader(
        file_path="./oppenheimer_data/oppenheimer.csv",
        source_column="Review"  # assumes the CSV has a "Review" column — TODO confirm against the data file
    ).load()
    # chunk the loaded documents using the text splitters
    chunked_opp_wikipedia_docs = wikipedia_text_splitter.transform_documents(oppenheimer_wikipedia_docs)
    chunked_opp_csv_docs = csv_text_splitter.transform_documents(oppenheimer_csv_docs)
    # set up FAISS vector store and create retriever for CSV docs
    opp_csv_faiss_retriever = FAISS.from_documents(chunked_opp_csv_docs, embedder).as_retriever()
    # set up BM25 retriever
    opp_wikipedia_bm25_retriever = BM25Retriever.from_documents(chunked_opp_wikipedia_docs)
    opp_wikipedia_bm25_retriever.k = 1
    # set up FAISS vector store and create retriever
    opp_wikipedia_faiss_store = FAISS.from_documents(
        chunked_opp_wikipedia_docs,
        embedder
    )
    opp_wikipedia_faiss_retriever = opp_wikipedia_faiss_store.as_retriever(search_kwargs={"k": 1})
    # set up ensemble retriever
    opp_ensemble_retriever = EnsembleRetriever(
        retrievers=[opp_wikipedia_bm25_retriever, opp_wikipedia_faiss_retriever],
        weights= [0.25, 0.75]  # should sum to 1
    )
    # prompt that stuffs both retrieval results into the system message
    system_message = """Use the information from the below two sources to answer any questions.

Source 1: public user reviews about the Oppenheimer movie
<source1>
{source1}
</source1>
Source 2: the wikipedia page for the Oppenheimer movie including the plot summary, cast, and production information
<source2>
{source2}
</source2>
"""
    prompt = ChatPromptTemplate.from_messages([("system", system_message), ("human", "{question}")])
    # build multi-source chain (LCEL): fan the question out to both
    # retrievers, then feed everything into prompt | llm.
    # BUG FIX: the prompt declares Source 1 = public reviews and Source 2 =
    # the Wikipedia page, but the previous wiring bound source1 to the
    # Wikipedia ensemble retriever and source2 to the review CSV retriever.
    # Bind them to match the prompt's labels.
    oppenheimer_multisource_chain = {
        "source1": (lambda x: x["question"]) | opp_csv_faiss_retriever,
        "source2": (lambda x: x["question"]) | opp_ensemble_retriever,
        "question": lambda x: x["question"],
    } | prompt | llm

    # ---- Agent creation ----
    # plain-function adapters so each subsystem can be exposed as a Tool
    def query_barbie(query):
        # delegate to the Barbie conversational retrieval agent
        return barbie_retriever_agent_executor({"input" : query})

    def query_oppenheimer(query):
        # delegate to the Oppenheimer multi-source chain
        return oppenheimer_multisource_chain.invoke({"question" : query})

    tools = [
        Tool(
            name="BarbieInfo",
            func=query_barbie,
            description='Useful when you need to answer questions about the Barbie movie'
        ),
        Tool(
            name="OppenheimerInfo",
            func=query_oppenheimer,
            description='Useful when you need to answer questions about the Oppenheimer movie'
        ),
    ]
    # create prompt for the zero-shot dispatch agent
    prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
    suffix = """Begin!"

Question: {input}
{agent_scratchpad}"""
    # NOTE: `prompt` is deliberately rebound here; the multisource chain
    # above already captured the earlier ChatPromptTemplate object.
    prompt = ZeroShotAgent.create_prompt(
        tools=tools,
        prefix=prefix,
        suffix=suffix,
        input_variables=['input', 'agent_scratchpad']
    )
    # chain llm with prompt
    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=True)
    # create reasoning agent
    barbenheimer_agent = ZeroShotAgent(
        llm_chain=llm_chain,
        tools=tools,
        verbose=True )
    # create execution agent
    barbenheimer_agent_chain = AgentExecutor.from_agent_and_tools(
        agent=barbenheimer_agent,
        tools=tools,
        verbose=True )

    # stash the executor for `main` to retrieve per message
    cl.user_session.set("chain", barbenheimer_agent_chain)

    msg.content = f"Agent ready!"
    await msg.send()
200
 
 
 
 
201
@cl.on_message
async def main(message):
    """Answer an incoming chat message with the Barbenheimer agent.

    Pulls the AgentExecutor stored by `init` out of the user session, runs
    it with a Chainlit callback handler so intermediate steps surface in
    the UI, and sends the agent's final answer back to the user.
    """
    chain = cl.user_session.get("chain")
    cb = cl.LangchainCallbackHandler(
        stream_final_answer=False, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    cb.answer_reached = True

    # BUG FIX: the previous code ran the agent synchronously via
    # `chain.__call__(...)` inside this async handler, blocking Chainlit's
    # event loop for the whole agent run. Off-load the sync call with
    # `cl.make_async` so the websocket stays responsive.
    # The executor's result dict keys are "input" and "output".
    res = await cl.make_async(chain)(message, callbacks=[cb])

    answer = res["output"]
    # This agent executor does not propagate `source_documents`, so there
    # are no source elements to attach (the v1 source-listing code relied
    # on RetrievalQA's output and was dead here; it has been removed).
    source_elements = []

    await cl.Message(content=answer, elements=source_elements).send()