cmagganas committed on
Commit
4539200
·
1 Parent(s): c885f5a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -23
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import chainlit as cl
2
  from langchain.embeddings.openai import OpenAIEmbeddings
3
- from langchain.document_loaders.csv_loader import CSVLoader
4
  from langchain.embeddings import CacheBackedEmbeddings
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
  from langchain.vectorstores import FAISS
@@ -12,7 +12,10 @@ from langchain.prompts.chat import (
12
  SystemMessagePromptTemplate,
13
  HumanMessagePromptTemplate,
14
  )
15
- import chainlit as cl
 
 
 
16
 
17
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
18
 
@@ -39,40 +42,162 @@ messages = [
39
  prompt = ChatPromptTemplate(messages=messages)
40
  chain_type_kwargs = {"prompt": prompt}
41
 
42
- @cl.author_rename
43
- def rename(orig_author: str):
44
- rename_dict = {"RetrievalQA": "Consulting The Kens"}
45
- return rename_dict.get(orig_author, orig_author)
46
 
47
  @cl.on_chat_start
48
  async def init():
49
  msg = cl.Message(content=f"Building Index...")
50
  await msg.send()
51
 
52
- # build FAISS index from csv
53
- loader = CSVLoader(file_path="./data/barbie.csv", source_column="Review_Url")
54
- data = loader.load()
55
- documents = text_splitter.transform_documents(data)
56
- store = LocalFileStore("./cache/")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  core_embeddings_model = OpenAIEmbeddings()
58
- embedder = CacheBackedEmbeddings.from_bytes_store(
59
- core_embeddings_model, store, namespace=core_embeddings_model.model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  )
61
- # make async docsearch
62
- docsearch = await cl.make_async(FAISS.from_documents)(documents, embedder)
63
-
64
- chain = RetrievalQA.from_chain_type(
65
- ChatOpenAI(model="gpt-4", temperature=0, streaming=True),
66
- chain_type="stuff",
67
- return_source_documents=True,
68
- retriever=docsearch.as_retriever(),
69
- chain_type_kwargs = {"prompt": prompt}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  msg.content = f"Index built!"
73
  await msg.send()
74
 
75
- cl.user_session.set("chain", chain)
76
 
77
 
78
  @cl.on_message
 
1
  import chainlit as cl
2
  from langchain.embeddings.openai import OpenAIEmbeddings
3
+ from langchain.document_loaders import WikipediaLoader, CSVLoader
4
  from langchain.embeddings import CacheBackedEmbeddings
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
  from langchain.vectorstores import FAISS
 
12
  SystemMessagePromptTemplate,
13
  HumanMessagePromptTemplate,
14
  )
15
+ from langchain.retrievers import BM25Retriever, EnsembleRetriever
16
+ from langchain.agents import Tool, ZeroShotAgent, AgentExecutor
17
+ from langchain.agents.agent_toolkits import create_retriever_tool, create_conversational_retrieval_agent
18
+ from langchain import LLMChain
19
 
20
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
21
 
 
42
  prompt = ChatPromptTemplate(messages=messages)
43
  chain_type_kwargs = {"prompt": prompt}
44
 
45
+ # @cl.author_rename
46
+ # def rename(orig_author: str):
47
+ # rename_dict = {"RetrievalQA": "Consulting The Kens"}
48
+ # return rename_dict.get(orig_author, orig_author)
49
 
50
  @cl.on_chat_start
51
  async def init():
52
  msg = cl.Message(content=f"Building Index...")
53
  await msg.send()
54
 
55
+ ### start building retrievers, stores and agents
56
+ llm = ChatOpenAI(model="gpt-3.5-turbo", temperature = 0)
57
+
58
+ barbie_wikipedia_docs = WikipediaLoader(query="Barbie (film)", load_max_docs=1, doc_content_chars_max=1_000_000).load()
59
+ barbie_csv_docs = CSVLoader(file_path="./barbie_data/barbie.csv", source_column="Review_Url").load()
60
+ oppenheimer_wikipedia_docs = WikipediaLoader(query="Oppenheimer (film)", load_max_docs=1, doc_content_chars_max=1_000_000).load()
61
+ oppenheimer_csv_docs = CSVLoader(file_path="./oppenheimer_data/oppenheimer.csv", source_column="Review_Url").load()
62
+
63
+ wikipedia_text_splitter = RecursiveCharacterTextSplitter(
64
+ chunk_size = 500,
65
+ chunk_overlap = 0,
66
+ length_function = len,
67
+ is_separator_regex= False,
68
+ separators = ["\n==", "\n", " "] # keep headings, then paragraphs, then sentences
69
+ )
70
+
71
+ csv_text_splitter = RecursiveCharacterTextSplitter(
72
+ chunk_size = 1000,
73
+ chunk_overlap = 50,
74
+ length_function = len,
75
+ is_separator_regex= False,
76
+ separators = ["\n", " "] # keep paragraphs, then sentences
77
+ )
78
+
79
+ chunked_barbie_wikipedia_docs = wikipedia_text_splitter.transform_documents(barbie_wikipedia_docs)
80
+ chunked_barbie_csv_docs = csv_text_splitter.transform_documents(barbie_csv_docs)
81
+ chunked_opp_wikipedia_docs = wikipedia_text_splitter.transform_documents(oppenheimer_wikipedia_docs)
82
+ chunked_opp_csv_docs = csv_text_splitter.transform_documents(oppenheimer_csv_docs)
83
+
84
+ # #### Retrieval and Embedding Strategy
85
+ # set up cached embeddings store
86
+ store = LocalFileStore("./shared_cache/")
87
  core_embeddings_model = OpenAIEmbeddings()
88
+ embedder = CacheBackedEmbeddings.from_bytes_store(core_embeddings_model, store, namespace=core_embeddings_model.model)
89
+
90
+ # We'll implement a `FAISS` vectorstore, and create a retriever from it.
91
+ barbie_csv_faiss_retriever = await cl.make_async(FAISS.from_documents)(chunked_barbie_csv_docs, embedder).as_retriever()
92
+ opp_csv_faiss_retriever = await cl.make_async(FAISS.from_documents)(chunked_opp_csv_docs, embedder).as_retriever()
93
+ opp_wikipedia_faiss_store = await cl.make_async(FAISS.from_documents)(chunked_opp_wikipedia_docs, embedder)
94
+ opp_wikipedia_faiss_retriever = opp_wikipedia_faiss_store.as_retriever(search_kwargs={"k": 1})
95
+
96
+ # set up BM25 retriever
97
+ barbie_wikipedia_bm25_retriever = BM25Retriever.from_documents(chunked_barbie_wikipedia_docs)
98
+ barbie_wikipedia_bm25_retriever.k = 1
99
+ opp_wikipedia_bm25_retriever = BM25Retriever.from_documents(chunked_opp_wikipedia_docs)
100
+ opp_wikipedia_bm25_retriever.k = 1
101
+
102
+ # set up FAISS vector store
103
+ barbie_wikipedia_faiss_store = await cl.make_async(FAISS.from_documents)(chunked_barbie_wikipedia_docs, embedder)
104
+ barbie_wikipedia_faiss_retriever = barbie_wikipedia_faiss_store.as_retriever(search_kwargs={"k": 1})
105
+ opp_wikipedia_faiss_store = await cl.make_async(FAISS.from_documents)(chunked_opp_wikipedia_docs, embedder)
106
+ opp_wikipedia_faiss_retriever = opp_wikipedia_faiss_store.as_retriever(search_kwargs={"k": 1})
107
+
108
+ # set up ensemble retriever
109
+ barbie_ensemble_retriever = EnsembleRetriever(retrievers=[barbie_wikipedia_bm25_retriever, barbie_wikipedia_faiss_retriever],weights=[0.25, 0.75])
110
+ opp_ensemble_retriever = EnsembleRetriever(retrievers=[opp_wikipedia_bm25_retriever, opp_wikipedia_faiss_retriever],weights=[0.25, 0.75])
111
+
112
+ # #### Retrieval Agent
113
+ barbie_wikipedia_retrieval_tool = create_retriever_tool(
114
+ barbie_ensemble_retriever,
115
+ "Wikipedia",
116
+ "Searches and returns documents regarding the plot, history, and cast of the Barbie movie"
117
+ )
118
+
119
+ barbie_csv_retrieval_tool = create_retriever_tool(
120
+ barbie_csv_faiss_retriever,
121
+ "PublicReviews",
122
+ "Searches and returns documents regarding public reviews of the Barbie movie"
123
  )
124
+
125
+ barbie_retriever_tools = [barbie_wikipedia_retrieval_tool, barbie_csv_retrieval_tool]
126
+
127
+ barbie_retriever_agent_executor = create_conversational_retrieval_agent(llm, barbie_retriever_tools, verbose=True)
128
+
129
+ # #### Multi-source chain
130
+ system_message = """Use the information from the below two sources to answer any questions.
131
+
132
+ Source 1: public user reviews about the Oppenheimer movie
133
+ <source1>
134
+ {source1}
135
+ </source1>
136
+
137
+ Source 2: the wikipedia page for the Oppenheimer movie including the plot summary, cast, and production information
138
+ <source2>
139
+ {source2}
140
+ </source2>
141
+ """
142
+
143
+ prompt = ChatPromptTemplate.from_messages([("system", system_message), ("human", "{question}")])
144
+
145
+ oppenheimer_multisource_chain = {
146
+ "source1": (lambda x: x["question"]) | opp_ensemble_retriever,
147
+ "source2": (lambda x: x["question"]) | opp_csv_faiss_retriever,
148
+ "question": lambda x: x["question"],
149
+ } | prompt | llm
150
+
151
+ # # Agent Creation
152
+
153
+ def query_oppenheimer(input):
154
+ return oppenheimer_multisource_chain.invoke({"question" : input})
155
+
156
+ tools = [
157
+ Tool(
158
+ name = "BarbieInfo",
159
+ func=barbie_retriever_agent_executor.invoke,
160
+ description="useful for when you need to answer questions about Barbie. Input should be a fully formed question."
161
+ ),
162
+ Tool(
163
+ name = "OppenheimerInfo",
164
+ func=query_oppenheimer,
165
+ description="useful for when you need to answer questions about Oppenheimer. Input should be a fully formed question."
166
+ ),
167
+ ]
168
+
169
+ prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
170
+ suffix = """Begin!"
171
+
172
+ Question: {input}
173
+ {agent_scratchpad}"""
174
+
175
+ prompt = ZeroShotAgent.create_prompt(
176
+ tools,
177
+ prefix=prefix,
178
+ suffix=suffix,
179
+ input_variables=["input", "agent_scratchpad"]
180
  )
181
+
182
+ llm_chain = LLMChain(llm=llm, prompt=prompt)
183
+
184
+ barbenheimer_agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
185
+ barbenheimer_agent_chain = AgentExecutor.from_agent_and_tools(agent=barbenheimer_agent, tools=tools, verbose=True)
186
+
187
+ ########################################################################################################################################
188
+
189
+ # chain = RetrievalQA.from_chain_type(
190
+ # ChatOpenAI(model="gpt-4", temperature=0, streaming=True),
191
+ # chain_type="stuff",
192
+ # return_source_documents=True,
193
+ # retriever=docsearch.as_retriever(),
194
+ # chain_type_kwargs = {"prompt": prompt}
195
+ # )
196
 
197
  msg.content = f"Index built!"
198
  await msg.send()
199
 
200
+ cl.user_session.set("chain", barbenheimer_agent_chain)
201
 
202
 
203
  @cl.on_message