File size: 4,319 Bytes
2a5403f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
from typing import List

from langchain_chroma import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_groq import ChatGroq

from langchain_community.document_loaders import PyPDFLoader
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.prompts import PromptTemplate
from langchain import hub

import chainlit as cl
from io import BytesIO

##################################### Load the embeddings and model #####################################

# Read credentials once at startup. They are passed explicitly to the clients
# below (previously they were fetched but never used — the libraries silently
# re-read the environment themselves).
groq_api_key = os.getenv("GROQ_API_KEY")
embeddings_api_key = os.getenv('GOOGLE_API_KEY')

# Embedding model used to vectorize PDF chunks for retrieval.
# Passing None (unset env var) falls back to the library's own env lookup,
# so behavior is unchanged when the key is only in the environment.
embedding_model = GoogleGenerativeAIEmbeddings(
    model="models/embedding-001",
    google_api_key=embeddings_api_key,
)
# Chat model that generates answers; temperature=0 for deterministic output.
llm = ChatGroq(model="mixtral-8x7b-32768", temperature=0, api_key=groq_api_key)




##################################### on_chat_start event handler #######################################

@cl.on_chat_start
async def on_chat_start():
    """Start a chat session: prompt for a PDF, index it, and store the QA chain.

    Flow: ask the user to upload a PDF, split it into overlapping chunks,
    embed the chunks into an in-memory Chroma vector store, then build a
    ConversationalRetrievalChain with buffered chat memory and stash it in
    the user session for the on_message handler.
    """
    files = None

    # Re-prompt until the user actually uploads something
    # (AskFileMessage.send() returns None on timeout).
    while files is None:
        files = await cl.AskFileMessage(
            # Fixed: the prompt said "text file" although only PDFs are accepted.
            content="Please upload a PDF file to begin",
            accept=["application/pdf"],
            max_size_mb=20,
            timeout=300
        ).send()

    file = files[0]
    msg = cl.Message(content=f"Processing `{file.name}` ...")
    await msg.send()

    ##################################### Load the text from the file ####################################

    # PyPDFLoader.load() returns a list of Documents (one per page), not a
    # loader — renamed from the misleading `pdf_loader`.
    documents = PyPDFLoader(file.path).load()

    ##################################### Split the text into chunks #####################################

    # 100-character overlap keeps context intact across chunk boundaries.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = text_splitter.split_documents(documents)

    ##################################### Chroma DB setup ################################################

    # Embedding every chunk is blocking network work; run it off the event loop.
    docsearch = await cl.make_async(Chroma.from_documents)(
        chunks, embedding_model
    )

    message_history = ChatMessageHistory()

    # output_key="answer" is required because the chain also returns
    # source documents alongside the answer.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        chat_memory=message_history,
        return_messages=True
    )

    ##################################### Chain setup ###################################################

    # Custom prompt: answer strictly from the retrieved context, and admit
    # ignorance rather than hallucinate.
    custom_prompt_template = """

    Based on the provided context please answer . if you don't know the answer. just say i don't know.

    {context}



    Question: {question}

    """
    custom_prompt = PromptTemplate(
        template=custom_prompt_template,
        input_variables=["context", "question"],
    )

    chain = ConversationalRetrievalChain.from_llm(
        llm,
        chain_type="stuff",
        retriever=docsearch.as_retriever(),
        memory=memory,
        return_source_documents=True,
        combine_docs_chain_kwargs={"prompt": custom_prompt}
    )

    msg.content = f"Processing `{file.name}` ... Done!✅ You can ask questions now!"
    await msg.update()

    # Stash the chain so the on_message handler can retrieve it for this session.
    cl.user_session.set("chain", chain)


##################################### On message event handler ###########################################

@cl.on_message
async def main(message: cl.Message):
    """Answer a user question via the session's retrieval chain.

    Runs the ConversationalRetrievalChain built in on_chat_start, attaches each
    retrieved source chunk as a side-panel text element, and appends the
    element names to the answer so the user can open the sources.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler()

    res = await chain.acall(message.content, callbacks=[cb])
    answer = res['answer']

    source_documents = res["source_documents"]  # type: List[Document]

    # One side-panel element per retrieved chunk, named source_0, source_1, ...
    text_elements = [
        cl.Text(content=doc.page_content, name=f"source_{idx}", display="side")
        for idx, doc in enumerate(source_documents)
    ]  # type: List[cl.Text]

    if text_elements:
        # Note: the original "No sources found" branch was unreachable — the
        # element list is non-empty exactly when source_documents is non-empty.
        answer += f"\nSources: {', '.join(el.name for el in text_elements)}"

    await cl.Message(content=answer, elements=text_elements).send()