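"""Chainlit front end for a multi-agent RAG chatbot.

On chat start, the user uploads a text or PDF file; the file is parsed,
chunked, embedded into an in-memory Qdrant collection, and wired into a
LangGraph supervisor ("super") graph built from research, authoring, and
correctness sub-graphs. Subsequent messages are streamed through that graph.
"""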
import os
from operator import itemgetter
from langchain_core.runnables.base import RunnableSerializable
import chainlit as cl
import aiofiles
from langchain_community.vectorstores import Qdrant
from langchain_core.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langgraph.errors import GraphRecursionError
import messages as msg
from parsers import ParsersMap
from splitters import TextSplitter
from models import EmbeddingLLM, MiniLLM
from prompts import RAG_PROMPT
from constructors import (
    construct_research_graph,
    construct_authoring_graph,
    construct_super_graph,
    construct_correctness_graph,
)
from utils import enter_chain
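

# Session bootstrap: ingest the uploaded document and assemble the agent graph.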
@cl.on_chat_start
async def init_chat():
    files = None
    # Wait for the user to upload a file
    while files is None:
        files = await cl.AskFileMessage(
            content=msg.INIT_MSG,
            accept=["text/plain", "application/pdf"],
        ).send()
    file = files[0]
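    # [1/4] Pick a parser for the uploaded file based on its extension.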
    _, ext = os.path.splitext(file.name)
    parser = ParsersMap[ext]
    async with aiofiles.open(file.path, "rb") as f:
        file_content = await f.read()
    file_text = parser.load(file_content)
    status_msg_text = """
Processing...please wait
[1/4] Read file: {read_status}
[2/4] Chunk file: {split_status}
[3/4] Load file to DB: {load_status}
[4/4] Build graph: {graph_status}
"""
    status_msg = cl.Message(
        status_msg_text.format(
            read_status="🟢", split_status="🔹", load_status="🔹", graph_status="🔹"
        )
    )
    await status_msg.send()
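    # [2/4] Split the raw text into chunks sized for embedding.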
    documents = TextSplitter.split_text(file_text)
    status_msg.content = status_msg_text.format(
        read_status="🟢", split_status="🟢", load_status="🔹", graph_status="🔹"
    )
    # update() edits the existing status message instead of sending a new one.
    await status_msg.update()
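    # [3/4] Embed the chunks into an in-memory Qdrant collection.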
    db = Qdrant.from_texts(
        documents,
        EmbeddingLLM,
        location=":memory:",
        collection_name="multi-agent-chatbot",
    )
    status_msg.content = status_msg_text.format(
        read_status="🟢", split_status="🟢", load_status="🟢", graph_status="🔹"
    )
    await status_msg.update()
    retriever = db.as_retriever()
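    # Retrieval chain: fetch context for the question, fill the RAG prompt,
    # and parse the model output to a plain string.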
    rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
    chain = (
        {
            "context": itemgetter("question") | retriever,
            "question": itemgetter("question"),
        }
        | rag_prompt
        | MiniLLM
        | StrOutputParser()
    )
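    # [4/4] Compile each sub-graph; enter_chain lifts raw text into graph
    # state. The super graph routes between the research, authoring, and
    # correctness agents.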
    research_chain = (
        enter_chain | construct_research_graph(research_chain=chain).compile()
    )
    authoring_chain = enter_chain | construct_authoring_graph().compile()
    correctness_chain = enter_chain | construct_correctness_graph().compile()
    super_chain = (
        enter_chain
        | construct_super_graph(
            research_chain, authoring_chain, correctness_chain
        ).compile()
    )
    cl.user_session.set("chain", super_chain)
    status_msg.content = status_msg_text.format(
        read_status="🟢", split_status="🟢", load_status="🟢", graph_status="🟢"
    )
    await status_msg.update()
    await cl.Message(msg.ASK_FOR_QUERY_MSG).send()
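

# Stream each user message through the supervisor graph stored in the session.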
@cl.on_message
async def process_chat(message: cl.Message):
    chain: RunnableSerializable = cl.user_session.get("chain")
    try:
        # Stream intermediate graph states; skip the terminal "__end__" marker.
        async for s in chain.astream(message.content, {"recursion_limit": 30}):
            if "__end__" not in s:
                await cl.Message(str(s)).send()
    except GraphRecursionError:
        await cl.Message("Max recursion depth reached").send()