|
|
import chainlit as cl |
|
|
from langchain.retrievers import BM25Retriever, EnsembleRetriever |
|
|
from langchain.vectorstores import FAISS |
|
|
from langchain.embeddings.openai import OpenAIEmbeddings |
|
|
from langchain.embeddings import CacheBackedEmbeddings |
|
|
from langchain.storage import LocalFileStore |
|
|
from langchain.agents.agent_toolkits import create_retriever_tool |
|
|
from langchain.agents.agent_toolkits import create_conversational_retrieval_agent |
|
|
from langchain.document_loaders import WikipediaLoader, CSVLoader |
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
from langchain.prompts import ChatPromptTemplate |
|
|
from langchain.agents import Tool |
|
|
from langchain.agents import ZeroShotAgent, AgentExecutor |
|
|
from langchain.chat_models import ChatOpenAI |
|
|
from langchain import LLMChain |
|
|
|
|
|
@cl.author_rename
def rename(orig_author: str):
    """Map internal LangChain component names to friendlier UI display names.

    Any author not present in the mapping is shown unchanged.
    """
    display_names = {"RetrievalQA": "Consulting The Barbenheimer"}
    return display_names.get(orig_author, orig_author)
|
|
|
|
|
def _build_wikipedia_ensemble_retriever(query, text_splitter, embedder):
    """Load one Wikipedia article and return a BM25 + FAISS ensemble retriever.

    BM25 (sparse keyword match, weight 0.25) and FAISS (dense semantic match,
    weight 0.75) each contribute their single best chunk (k=1).
    """
    docs = WikipediaLoader(
        query=query,
        load_max_docs=1,
        # Effectively "the whole article" — avoid the loader's default truncation.
        doc_content_chars_max=10000000,
    ).load()
    chunks = text_splitter.transform_documents(docs)

    bm25_retriever = BM25Retriever.from_documents(chunks)
    bm25_retriever.k = 1

    faiss_retriever = FAISS.from_documents(chunks, embedder).as_retriever(
        search_kwargs={"k": 1}
    )

    return EnsembleRetriever(
        retrievers=[bm25_retriever, faiss_retriever],
        weights=[0.25, 0.75],
    )


def _build_csv_review_retriever(file_path, text_splitter, embedder):
    """Load a review CSV (one review per row, 'Review' column) into a FAISS retriever."""
    docs = CSVLoader(file_path=file_path, source_column="Review").load()
    chunks = text_splitter.transform_documents(docs)
    return FAISS.from_documents(chunks, embedder).as_retriever()


@cl.on_chat_start
async def init():
    """Build the Barbenheimer agent and stash it in the user session.

    Pipeline:
      1. Build text splitters and a disk-cached OpenAI embedder.
      2. Barbie: two retriever tools (Wikipedia ensemble + review CSV) wrapped
         in a conversational retrieval agent.
      3. Oppenheimer: an LCEL multi-source chain (reviews + Wikipedia) behind
         a single prompt.
      4. Expose both as Tools to a ZeroShotAgent and store the resulting
         AgentExecutor under the "chain" session key for main().
    """
    msg = cl.Message(content="Building Index...")
    await msg.send()

    llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)

    # Wikipedia articles use "==" section headers, so prefer splitting there.
    wikipedia_text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1024,
        chunk_overlap=512,
        length_function=len,
        is_separator_regex=False,
        separators=["\n==", "\n", " "],
    )
    csv_text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1024,
        chunk_overlap=512,
        length_function=len,
        is_separator_regex=False,
        separators=["\n", " "],
    )

    # Cache embeddings on disk so repeated sessions don't re-embed (and re-pay).
    store = LocalFileStore("./.cache/")
    core_embeddings_model = OpenAIEmbeddings()
    embedder = CacheBackedEmbeddings.from_bytes_store(
        core_embeddings_model,
        store,
        namespace=core_embeddings_model.model,
    )

    # ---- Barbie: conversational retrieval agent over two retriever tools ----
    barbie_ensemble_retriever = _build_wikipedia_ensemble_retriever(
        "Barbie (film)", wikipedia_text_splitter, embedder
    )
    barbie_csv_retriever = _build_csv_review_retriever(
        "./barbie_data/barbie.csv", csv_text_splitter, embedder
    )

    barbie_retriever_tools = [
        create_retriever_tool(
            retriever=barbie_ensemble_retriever,
            name='Search_Wikipedia',
            description='Useful for when you need to answer questions about plot, cast, production, release, music, marketing, reception, themes and analysis of the Barbie movie.'
        ),
        create_retriever_tool(
            retriever=barbie_csv_retriever,
            name='Search_Reviews',
            description='Useful for when you need to answer questions about public reviews of the Barbie movie.'
        ),
    ]

    barbie_retriever_agent_executor = create_conversational_retrieval_agent(
        llm=llm, tools=barbie_retriever_tools, verbose=True
    )

    # ---- Oppenheimer: single multi-source LCEL chain -----------------------
    opp_ensemble_retriever = _build_wikipedia_ensemble_retriever(
        "Oppenheimer", wikipedia_text_splitter, embedder
    )
    opp_csv_retriever = _build_csv_review_retriever(
        "./oppenheimer_data/oppenheimer.csv", csv_text_splitter, embedder
    )

    system_message = """Use the information from the below two sources to answer any questions.

Source 1: public user reviews about the Oppenheimer movie
<source1>
{source1}
</source1>

Source 2: the wikipedia page for the Oppenheimer movie including the plot summary, cast, and production information
<source2>
{source2}
</source2>
"""
    opp_prompt = ChatPromptTemplate.from_messages(
        [("system", system_message), ("human", "{question}")]
    )

    # The system message describes source1 as the user reviews and source2 as
    # the Wikipedia page, so route the CSV reviews to {source1} and the
    # Wikipedia ensemble to {source2}. (The original wiring had these two
    # swapped, feeding each slot the wrong corpus.)
    oppenheimer_multisource_chain = {
        "source1": (lambda x: x["question"]) | opp_csv_retriever,
        "source2": (lambda x: x["question"]) | opp_ensemble_retriever,
        "question": lambda x: x["question"],
    } | opp_prompt | llm

    # ---- Top-level agent dispatching between the two movies ----------------
    def query_barbie(query):
        # Runs the Barbie conversational retrieval agent.
        return barbie_retriever_agent_executor({"input": query})

    def query_oppenheimer(query):
        # Runs the Oppenheimer multi-source chain.
        return oppenheimer_multisource_chain.invoke({"question": query})

    tools = [
        Tool(
            name="BarbieInfo",
            func=query_barbie,
            description='Useful when you need to answer questions about the Barbie movie'
        ),
        Tool(
            name="OppenheimerInfo",
            func=query_oppenheimer,
            description='Useful when you need to answer questions about the Oppenheimer movie'
        ),
    ]

    prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
    # Original suffix had a stray '"' after "Begin!" — removed.
    suffix = """Begin!

Question: {input}
{agent_scratchpad}"""

    agent_prompt = ZeroShotAgent.create_prompt(
        tools=tools,
        prefix=prefix,
        suffix=suffix,
        input_variables=['input', 'agent_scratchpad'],
    )

    llm_chain = LLMChain(llm=llm, prompt=agent_prompt, verbose=True)

    barbenheimer_agent = ZeroShotAgent(
        llm_chain=llm_chain,
        tools=tools,
        verbose=True,
    )

    barbenheimer_agent_chain = AgentExecutor.from_agent_and_tools(
        agent=barbenheimer_agent,
        tools=tools,
        verbose=True,
    )

    # main() retrieves the executor from the session under this key.
    cl.user_session.set("chain", barbenheimer_agent_chain)

    msg.content = "Agent ready!"
    await msg.send()
|
|
|
|
|
@cl.on_message
async def main(message):
    """Answer one incoming chat message with the Barbenheimer agent.

    Retrieves the AgentExecutor stored in the session by init(), runs it on
    the message, and sends the answer back to the UI.
    """
    # AgentExecutor stored by init(). (Original annotated this with the
    # undefined name `Chain`; dropped to avoid a misleading annotation.)
    chain = cl.user_session.get("chain")

    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=False, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    cb.answer_reached = True

    # Pass the handler to the run so intermediate agent steps surface in the
    # Chainlit UI — the original constructed `cb` but never wired it in.
    answer = chain.run({"input": message}, callbacks=[cb])

    await cl.Message(content=answer).send()
|
|
|