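"""Chainlit RAG demo: answers questions about the "Blueprint for an AI Bill of
Rights" and NIST.AI.600-1 PDFs using LangChain, an in-memory Qdrant vector
store, OpenAI embeddings, and gpt-4o-mini."""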
from operator import itemgetter

import chainlit as cl
import tiktoken
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.vectorstores import Qdrant
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_openai.chat_models import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorStoreRetriever) -> None:
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

        base_rag_prompt_template = """\
Use the provided context to answer the provided user question. Only use the provided context to answer the question. If you do not know the answer, respond with "I don't know".

Context:
{context}

Question:
{question}
"""
        base_rag_prompt = ChatPromptTemplate.from_template(base_rag_prompt_template)
        self.retrieval_augmented_qa_chain = (
            # INVOKE CHAIN WITH: {"question": "<<SOME USER QUESTION>>"}
            # "question": populated by getting the value of the "question" key
            # "context": populated by piping the value of the "question" key into the retriever
            {"context": itemgetter("question") | self.vector_db_retriever, "question": itemgetter("question")}
            # RunnablePassthrough.assign re-assigns "context" from the previous step while
            # passing "question" through unchanged (effectively an identity step here)
            | RunnablePassthrough.assign(context=itemgetter("context"))
            # "response": the "context" and "question" values are used to format the prompt,
            # which is then piped into the injected LLM; the result is stored under "response"
            | {"response": base_rag_prompt | self.llm}
        )
    async def arun_pipeline(self, user_query: str):
        async def generate_response():
            # Stream the chain's output chunk by chunk instead of waiting for the full answer.
            async for chunk in self.retrieval_augmented_qa_chain.astream({"question": user_query}):
                yield chunk["response"].content

        return {"response": generate_response()}
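# A minimal usage sketch (hypothetical, outside Chainlit) of how the pipeline's
# streamed response can be consumed; assumes OPENAI_API_KEY is set and a
# pipeline instance has been constructed as in on_chat_start below:
#
#   import asyncio
#
#   async def demo(pipeline: RetrievalAugmentedQAPipeline):
#       result = await pipeline.arun_pipeline("What rights does the blueprint cover?")
#       async for token in result["response"]:
#           print(token, end="", flush=True)
#
#   # asyncio.run(demo(retrieval_augmented_qa_pipeline))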
@cl.on_chat_start
async def on_chat_start():
    def tiktoken_len(text):
        # Count tokens with the gpt-4o tokenizer so chunk sizes are measured in tokens, not characters.
        tokens = tiktoken.encoding_for_model("gpt-4o").encode(text)
        return len(tokens)

    msg = cl.Message(
        content="Getting ready...", disable_human_feedback=True
    )
    await msg.send()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        length_function=tiktoken_len,
    )
    doc1 = PyMuPDFLoader("Blueprint-for-an-AI-Bill-of-Rights.pdf").load()
    doc2 = PyMuPDFLoader("NIST.AI.600-1.pdf").load()
    split_chunks1 = text_splitter.split_documents(doc1)
    split_chunks2 = text_splitter.split_documents(doc2)

    embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")

    qdrant_vectorstore = Qdrant.from_documents(
        documents=split_chunks1 + split_chunks2,
        embedding=embedding_model,
        location=":memory:",
    )
    qdrant_retriever = qdrant_vectorstore.as_retriever()
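    # Note: location=":memory:" keeps the index in RAM, so it is re-embedded on every
    # restart. A persistent local index is possible (assumption: same Qdrant wrapper), e.g.:
    #   Qdrant.from_documents(..., path="./qdrant_data", collection_name="ai_policy_docs")
    # Similarly, as_retriever() uses similarity search with its default k; passing
    # search_kwargs={"k": 4} controls how many chunks are retrieved per question.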
    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
        vector_db_retriever=qdrant_retriever,
        llm=ChatOpenAI(model="gpt-4o-mini", tags=["base_llm"]),
    )

    # Let the user know that the system is ready
    msg.content = "Ready. You can now ask questions!"
    await msg.update()

    cl.user_session.set("chain", retrieval_augmented_qa_pipeline)
@cl.on_message
async def main(message: cl.Message):
    chain = cl.user_session.get("chain")
    msg = cl.Message(content="")

    result = await chain.arun_pipeline(message.content)

    # Stream tokens to the UI as they arrive, then finalize the message.
    async for stream_resp in result["response"]:
        await msg.stream_token(stream_resp)

    await msg.send()
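# To launch the app (assumption: this file is saved as app.py, both PDFs are in the
# working directory, and OPENAI_API_KEY is set in the environment):
#   chainlit run app.py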