File size: 4,360 Bytes
d07d73c
c689c08
d07d73c
 
 
 
 
 
 
 
 
 
c689c08
 
 
d07d73c
c689c08
 
 
d07d73c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b950f71
d07d73c
c689c08
d07d73c
c689c08
b950f71
 
d07d73c
b950f71
c689c08
d07d73c
c689c08
 
 
 
d07d73c
 
 
 
 
c689c08
 
d07d73c
c689c08
 
 
d07d73c
fb3840a
 
d07d73c
 
c689c08
d07d73c
 
 
 
 
c689c08
d07d73c
 
 
 
 
 
c689c08
d07d73c
 
c689c08
d07d73c
c689c08
d07d73c
c689c08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d07d73c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from operator import itemgetter

import chainlit as cl
import tiktoken
from langchain.document_loaders import PyMuPDFLoader
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Qdrant
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_openai.chat_models import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings


class RetrievalAugmentedQAPipeline:
    """Minimal RAG pipeline: retrieve context for a user question, then answer
    with an LLM that is instructed to use only that retrieved context.
    """

    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorStoreRetriever) -> None:
        """Build the LCEL retrieval chain.

        Args:
            llm: Chat model used to generate the final answer.
            vector_db_retriever: Retriever used to fetch context documents.

        Note: the original implementation created a second ChatOpenAI instance
        here and silently ignored the injected ``llm``; the chain now uses
        ``self.llm`` so dependency injection actually works.
        """
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

        base_rag_prompt_template = """\
        Use the provided context to answer the provided user question. Only use the provided context to answer the question. If you do not know the answer, response with "I don't know"

        Context:
        {context}

        Question:
        {question}
        """
        base_rag_prompt = ChatPromptTemplate.from_template(base_rag_prompt_template)
        self.retrieval_augmented_qa_chain = (
            # INVOKE CHAIN WITH: {"question": "<<SOME USER QUESTION>>"}
            # "question" : taken verbatim from the input dict
            # "context"  : the question is piped into the retriever to fetch documents
                {"context": itemgetter("question") | self.vector_db_retriever, "question": itemgetter("question")}
                # Re-assign "context" from the previous step so both keys flow forward
                # together into the prompt-formatting stage.
                | RunnablePassthrough.assign(context=itemgetter("context"))
                # "response" : "context" and "question" format the prompt, which is
                #              piped into the injected LLM; the result is stored
                #              under the "response" key.
                | {"response": base_rag_prompt | self.llm}
        )

    async def arun_pipeline(self, user_query: str):
        """Run the chain for ``user_query`` and return a streaming response.

        Returns:
            dict with key "response" mapped to an async generator that yields
            answer text chunks as the LLM streams them.
        """
        async def generate_response():
            async for chunk in self.retrieval_augmented_qa_chain.astream({"question": user_query}):
                yield chunk["response"].content

        return {"response": generate_response()}


@cl.on_chat_start
async def on_chat_start():
    """Bootstrap a chat session: chunk the source PDFs, index them in an
    in-memory Qdrant store, and stash the assembled RAG pipeline in the
    user session under the "chain" key for the message handler to use.
    """
    # Resolve the tokenizer once. tiktoken.encoding_for_model is comparatively
    # expensive, and the length function runs for every candidate chunk during
    # splitting, so it must not re-resolve the encoding per call.
    encoding = tiktoken.encoding_for_model("gpt-4o")

    def tiktoken_len(text):
        # Chunk size is measured in model tokens, not characters.
        return len(encoding.encode(text))

    msg = cl.Message(
        content="Getting ready...", disable_human_feedback=True
    )
    await msg.send()

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        length_function=tiktoken_len,
    )

    # Load and chunk both source PDFs. Each document is split independently,
    # so splitting the combined list is equivalent to splitting each and
    # concatenating the results.
    documents = []
    for pdf_path in ("Blueprint-for-an-AI-Bill-of-Rights.pdf", "NIST.AI.600-1.pdf"):
        documents.extend(PyMuPDFLoader(pdf_path).load())
    split_chunks = text_splitter.split_documents(documents)

    embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")

    qdrant_vectorstore = Qdrant.from_documents(
        documents=split_chunks,
        embedding=embedding_model,
        location=":memory:",
    )
    qdrant_retriever = qdrant_vectorstore.as_retriever()
    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
        vector_db_retriever=qdrant_retriever,
        llm=ChatOpenAI(model="gpt-4o-mini", tags=["base_llm"]),
    )

    # Let the user know that the system is ready
    msg.content = "Ready. You can now ask questions!"
    await msg.update()

    cl.user_session.set("chain", retrieval_augmented_qa_pipeline)


@cl.on_message
async def main(message):
    """Answer an incoming user message by streaming the RAG pipeline output
    token-by-token into a Chainlit message.
    """
    pipeline = cl.user_session.get("chain")
    reply = cl.Message(content="")

    result = await pipeline.arun_pipeline(message.content)
    async for token in result["response"]:
        await reply.stream_token(token)

    await reply.send()