rafaaa2105 commited on
Commit
9b35070
·
verified ·
1 Parent(s): c5392bb

Upload 4 files

Browse files
Files changed (2) hide show
  1. app.py +121 -5
  2. requirements.txt +7 -0
app.py CHANGED
@@ -1,9 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import chainlit as cl
2
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  @cl.on_message
5
- async def on_message(msg: cl.Message):
6
- if cl.context.session.client_type == "copilot":
7
- fn = cl.CopilotFunction(name="test", args={"msg": msg.content})
8
- res = await fn.acall()
9
- await cl.Message(content=res).send()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from pathlib import Path
3
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
4
+ from langchain.prompts import ChatPromptTemplate
5
+ from langchain.schema import StrOutputParser
6
+ from langchain_community.document_loaders import (
7
+ PyMuPDFLoader,
8
+ )
9
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
10
+ from langchain.vectorstores.chroma import Chroma
11
+ from langchain.indexes import SQLRecordManager, index
12
+ from langchain.schema import Document
13
+ from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig
14
+ from langchain.callbacks.base import BaseCallbackHandler
15
+
16
  import chainlit as cl
17
 
18
 
19
+ chunk_size = 1024
20
+ chunk_overlap = 50
21
+
22
+ embeddings_model = OpenAIEmbeddings()
23
+
24
+ PDF_STORAGE_PATH = "./pdfs"
25
+
26
+
27
+ def process_pdfs(pdf_storage_path: str):
28
+ pdf_directory = Path(pdf_storage_path)
29
+ docs = [] # type: List[Document]
30
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
31
+
32
+ for pdf_path in pdf_directory.glob("*.pdf"):
33
+ loader = PyMuPDFLoader(str(pdf_path))
34
+ documents = loader.load()
35
+ docs += text_splitter.split_documents(documents)
36
+
37
+ doc_search = Chroma.from_documents(docs, embeddings_model)
38
+
39
+ namespace = "chromadb/my_documents"
40
+ record_manager = SQLRecordManager(
41
+ namespace, db_url="sqlite:///record_manager_cache.sql"
42
+ )
43
+ record_manager.create_schema()
44
+
45
+ index_result = index(
46
+ docs,
47
+ record_manager,
48
+ doc_search,
49
+ cleanup="incremental",
50
+ source_id_key="source",
51
+ )
52
+
53
+ print(f"Indexing stats: {index_result}")
54
+
55
+ return doc_search
56
+
57
+
58
+ doc_search = process_pdfs(PDF_STORAGE_PATH)
59
+ model = ChatOpenAI(model_name="gpt-4", streaming=True)
60
+
61
+
62
+ @cl.on_chat_start
63
+ async def on_chat_start():
64
+ template = """Answer the question based only on the following context:
65
+
66
+ {context}
67
+
68
+ Question: {question}
69
+ """
70
+ prompt = ChatPromptTemplate.from_template(template)
71
+
72
+ def format_docs(docs):
73
+ return "\n\n".join([d.page_content for d in docs])
74
+
75
+ retriever = doc_search.as_retriever()
76
+
77
+ runnable = (
78
+ {"context": retriever | format_docs, "question": RunnablePassthrough()}
79
+ | prompt
80
+ | model
81
+ | StrOutputParser()
82
+ )
83
+
84
+ cl.user_session.set("runnable", runnable)
85
+
86
+
87
  @cl.on_message
88
+ async def on_message(message: cl.Message):
89
+ runnable = cl.user_session.get("runnable") # type: Runnable
90
+ msg = cl.Message(content="")
91
+
92
+ class PostMessageHandler(BaseCallbackHandler):
93
+ """
94
+ Callback handler for handling the retriever and LLM processes.
95
+ Used to post the sources of the retrieved documents as a Chainlit element.
96
+ """
97
+
98
+ def __init__(self, msg: cl.Message):
99
+ BaseCallbackHandler.__init__(self)
100
+ self.msg = msg
101
+ self.sources = set() # To store unique pairs
102
+
103
+ def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
104
+ for d in documents:
105
+ source_page_pair = (d.metadata['source'], d.metadata['page'])
106
+ self.sources.add(source_page_pair) # Add unique pairs to the set
107
+
108
+ def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
109
+ if len(self.sources):
110
+ sources_text = "\n".join([f"{source}#page={page}" for source, page in self.sources])
111
+ self.msg.elements.append(
112
+ cl.Text(name="Sources", content=sources_text, display="inline")
113
+ )
114
+
115
+ async with cl.Step(type="run", name="QA Assistant"):
116
+ async for chunk in runnable.astream(
117
+ message.content,
118
+ config=RunnableConfig(callbacks=[
119
+ cl.LangchainCallbackHandler(),
120
+ PostMessageHandler(msg)
121
+ ]),
122
+ ):
123
+ await msg.stream_token(chunk)
124
+
125
+ await msg.send()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ langchain
2
+ chainlit
3
+ langchain_openai
4
+ openai
5
+ chromadb
6
+ tiktoken
7
+ pymupdf