File size: 3,963 Bytes
01e95d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
from operator import itemgetter
from langchain_core.runnables.base import RunnableSerializable
import chainlit as cl
import aiofiles
from langchain_community.vectorstores import Qdrant
from langchain_core.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langgraph.errors import GraphRecursionError

import messages as msg
from parsers import ParsersMap
from splitters import TextSplitter
from models import EmbeddingLLM, MiniLLM
from prompts import RAG_PROMPT
from constructors import (
    construct_research_graph,
    construct_authoring_graph,
    construct_super_graph,
    construct_correctness_graph,
)
from utils import enter_chain


# Template for the progress message shown while an uploaded file is ingested.
_STATUS_TEMPLATE = """
    Processing...please wait
    [1/4] Read file: {read_status}
    [2/4] Chunk file: {split_status}
    [3/4] Load file to DB: {load_status}
    [4/4] Build graph: {graph_status}
    """
# Placeholders in _STATUS_TEMPLATE, in the order the ingestion steps complete.
_STATUS_FIELDS = ("read_status", "split_status", "load_status", "graph_status")
# NOTE(review): these glyphs look mis-encoded (mojibake of emoji); they are kept
# byte-for-byte from the original — confirm the intended characters.
_STEP_DONE = "🟒"
_STEP_PENDING = "πŸ”Ή"


def _status_text(steps_done: int) -> str:
    """Render the progress message with the first *steps_done* steps marked done."""
    return _STATUS_TEMPLATE.format(
        **{
            field: _STEP_DONE if index < steps_done else _STEP_PENDING
            for index, field in enumerate(_STATUS_FIELDS)
        }
    )


async def _ask_for_file():
    """Prompt until the user uploads a file whose extension has a parser.

    Returns:
        A ``(file, parser)`` pair: the first uploaded file and the
        ``ParsersMap`` entry registered for its extension.
    """
    while True:
        files = None
        # send() may return None (e.g. when the prompt times out); ask again.
        while files is None:
            files = await cl.AskFileMessage(
                content=msg.INIT_MSG,
                accept=["text/plain", "application/pdf"],
            ).send()

        file = files[0]
        _, ext = os.path.splitext(file.name)
        try:
            return file, ParsersMap[ext]
        except KeyError:
            # A file can pass the MIME filter yet carry an extension with no
            # registered parser; re-prompt instead of crashing the session.
            await cl.Message(
                f"Unsupported file type '{ext or file.name}'. Please upload another file."
            ).send()


def _build_super_chain(retriever):
    """Compile the supervisor graph over a RAG chain backed by *retriever*.

    Returns:
        ``enter_chain`` piped into the compiled super graph. The super graph
        routes between the research, authoring and correctness sub-graphs.
    """
    rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
    # The "question" value is used both as the retrieval query (for "context")
    # and passed through unchanged to the prompt.
    rag_chain = (
        {
            "context": itemgetter("question") | retriever,
            "question": itemgetter("question"),
        }
        | rag_prompt
        | MiniLLM
        | StrOutputParser()
    )
    # NOTE(review): enter_chain presumably adapts raw input into each graph's
    # initial state — confirm against utils.enter_chain.
    research_chain = (
        enter_chain | construct_research_graph(research_chain=rag_chain).compile()
    )
    authoring_chain = enter_chain | construct_authoring_graph().compile()
    correctness_chain = enter_chain | construct_correctness_graph().compile()
    return (
        enter_chain
        | construct_super_graph(
            research_chain, authoring_chain, correctness_chain
        ).compile()
    )


@cl.on_chat_start
async def init_chat():
    """Ingest an uploaded document and build this session's agent graph.

    The stages are:

    1. Ask the user for a text/PDF file and parse it.
    2. Split the text into chunks.
    3. Index the chunks in an in-memory Qdrant collection.
    4. Compile the multi-agent super graph over a RAG chain that uses that index.

    The compiled graph is stored in the user session under ``"chain"``. A
    four-step status message is edited in place as each stage finishes.
    """
    file, parser = await _ask_for_file()

    async with aiofiles.open(file.path, "rb") as f:
        file_content = await f.read()

    # Parsing and embedding are synchronous calls; run them in a worker thread
    # so they do not block the event loop shared by every chat session.
    file_text = await cl.make_async(parser.load)(file_content)

    status_msg = cl.Message(_status_text(1))
    await status_msg.send()

    documents = TextSplitter.split_text(file_text)
    status_msg.content = _status_text(2)
    # The message is already on screen; edit it rather than sending it again.
    await status_msg.update()

    db = await cl.make_async(Qdrant.from_texts)(
        documents,
        EmbeddingLLM,
        location=":memory:",
        collection_name="multi-agent-chatbot",
    )
    retriever = db.as_retriever()
    status_msg.content = _status_text(3)
    await status_msg.update()

    cl.user_session.set("chain", _build_super_chain(retriever))
    status_msg.content = _status_text(4)
    await status_msg.update()

    await cl.Message(msg.ASK_FOR_QUERY_MSG).send()


# Upper bound on graph steps per query; exceeding it raises GraphRecursionError.
_RECURSION_LIMIT = 30


@cl.on_message
async def process_chat(message: cl.Message):
    """Run the session's super graph on a user message and post each step.

    Every streamed update except the terminal ``"__end__"`` one is posted as its
    own chat message. If the graph exceeds ``_RECURSION_LIMIT`` steps, the user
    is told so instead of the handler failing.

    Args:
        message: The incoming chat message; its ``content`` is the graph input.
    """
    chain: RunnableSerializable = cl.user_session.get("chain")
    if chain is None:
        # init_chat stores the chain only after ingestion completes, so a
        # message arriving earlier (or after a failed ingestion) has no graph.
        await cl.Message(
            "The document is still being processed. Please try again once it is ready."
        ).send()
        return

    try:
        async for step in chain.astream(
            message.content, {"recursion_limit": _RECURSION_LIMIT}
        ):
            if "__end__" not in step:
                await cl.Message(step).send()
    except GraphRecursionError:
        await cl.Message("Max recursion depth reached").send()