from operator import itemgetter

import chainlit as cl
import tiktoken
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.vectorstores import Qdrant
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_openai.chat_models import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorStoreRetriever) -> None:
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

        base_rag_prompt_template = """\
        Use the provided context to answer the provided user question. Only use the provided context to answer the question. If you do not know the answer, respond with "I don't know".

        Context:
        {context}

        Question:
        {question}
        """

        base_rag_prompt = ChatPromptTemplate.from_template(base_rag_prompt_template)
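        # A minimal sketch (not executed here) of how the template renders; the values
        # are hypothetical placeholders:
        #   base_rag_prompt.format_messages(
        #       context="<retrieved chunks>", question="<user question>"
        #   )
        # returns a one-element list containing a human message with both fields filled in.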
        self.retrieval_augmented_qa_chain = (
            # INVOKE CHAIN WITH: {"question": "<<SOME USER QUESTION>>"}
            # "question" : populated by getting the value of the "question" key
            # "context"  : populated by piping the value of the "question" key into the retriever
            {"context": itemgetter("question") | self.vector_db_retriever, "question": itemgetter("question")}
            # RunnablePassthrough.assign passes both keys through, (re)assigning "context"
            # from the previous step so the prompt can consume it
            | RunnablePassthrough.assign(context=itemgetter("context"))
            # "response" : the "context" and "question" values format the prompt, which is
            # piped into the LLM; the result is stored under the "response" key
            | {"response": base_rag_prompt | self.llm}
        )
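        # Sketch of an ad hoc (non-streaming) call, an assumption rather than part of
        # the app flow:
        #   out = self.retrieval_augmented_qa_chain.invoke({"question": "..."})
        #   out["response"].content  # the answer text as a string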
    async def arun_pipeline(self, user_query: str):
        async def generate_response():
            # Stream the chain; each chunk carries the newest piece of the "response"
            async for chunk in self.retrieval_augmented_qa_chain.astream({"question": user_query}):
                yield chunk["response"].content

        return {"response": generate_response()}
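# A minimal sketch of driving the pipeline outside Chainlit; "demo_retriever" is a
# hypothetical retriever built the same way as qdrant_retriever below:
#
#   import asyncio
#
#   async def demo(demo_retriever):
#       pipeline = RetrievalAugmentedQAPipeline(
#           llm=ChatOpenAI(model="gpt-4o-mini"),
#           vector_db_retriever=demo_retriever,
#       )
#       result = await pipeline.arun_pipeline("What does the AI Bill of Rights cover?")
#       async for token in result["response"]:
#           print(token, end="", flush=True)
#
#   asyncio.run(demo(demo_retriever))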
@cl.on_chat_start
async def on_chat_start():
    def tiktoken_len(text):
        # Measure chunk length in tokens rather than characters
        tokens = tiktoken.encoding_for_model("gpt-4o").encode(text)
        return len(tokens)
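    # e.g. tiktoken_len("hello world") should come out to 2 with the gpt-4o tokenizer
    # (a sketch; exact counts depend on the tiktoken version)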
    msg = cl.Message(
        content="Getting ready...", disable_human_feedback=True
    )
    await msg.send()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        length_function=tiktoken_len,
    )
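    # With chunk_size=500 tokens and chunk_overlap=50, consecutive chunks share roughly
    # 50 tokens, so passages that straddle a chunk boundary remain retrievable.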
    doc1 = PyMuPDFLoader("Blueprint-for-an-AI-Bill-of-Rights.pdf").load()
    doc2 = PyMuPDFLoader("NIST.AI.600-1.pdf").load()
    split_chunks1 = text_splitter.split_documents(doc1)
    split_chunks2 = text_splitter.split_documents(doc2)

    embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")

    qdrant_vectorstore = Qdrant.from_documents(
        documents=split_chunks1 + split_chunks2,
        embedding=embedding_model,
        location=":memory:",
    )
    qdrant_retriever = qdrant_vectorstore.as_retriever()
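    # Sketch: as_retriever() defaults to top-k similarity search; tuning k is an
    # assumption, not part of the original app:
    #   qdrant_retriever = qdrant_vectorstore.as_retriever(search_kwargs={"k": 4})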
    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
        vector_db_retriever=qdrant_retriever,
        llm=ChatOpenAI(model="gpt-4o-mini", tags=["base_llm"]),
    )

    # Let the user know that the system is ready
    msg.content = "Ready. You can now ask questions!"
    await msg.update()

    cl.user_session.set("chain", retrieval_augmented_qa_pipeline)
@cl.on_message
async def main(message):
    chain = cl.user_session.get("chain")

    msg = cl.Message(content="")

    result = await chain.arun_pipeline(message.content)

    # Forward tokens to the UI as they arrive
    async for stream_resp in result["response"]:
        await msg.stream_token(stream_resp)

    await msg.send()
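# Run locally with:  chainlit run app.py -w
# (assumes this file is saved as app.py and OPENAI_API_KEY is set in the environment)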