Spaces:
Runtime error
Runtime error
File size: 3,950 Bytes
9b35070 8d1d51f 9b35070 01c5d61 9b35070 8b9c0c5 9b35070 8b9c0c5 9b35070 01c5d61 9b35070 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
from pathlib import Path
from typing import List

import chainlit as cl
from langchain.callbacks.base import BaseCallbackHandler
from langchain.indexes import SQLRecordManager, index
from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores.chroma import Chroma
from langchain_community.chat_models.huggingface import ChatHuggingFace
from langchain_community.document_loaders import (
    PyMuPDFLoader,
)
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
# Directory scanned for PDFs when the module is loaded (see process_pdfs).
PDF_STORAGE_PATH: str = "./pdfs"

# Text-splitting settings.
# NOTE(review): process_pdfs configures its own text splitter and does not
# read these two constants; confirm which chunking values are intended.
chunk_size: int = 1024
chunk_overlap: int = 50

# Embedding model used to build the vector store, constructed with the
# library's default settings (no arguments are passed).
embeddings_model: HuggingFaceEmbeddings = HuggingFaceEmbeddings()
def process_pdfs(
    pdf_storage_path: str,
    chunk_size: int = 1000,
    chunk_overlap: int = 100,
) -> "Chroma":
    """Load, split and index every PDF found in ``pdf_storage_path``.

    Each ``*.pdf`` file directly inside the directory (non-recursive) is
    loaded with ``PyMuPDFLoader`` and split into overlapping chunks. The
    chunks are embedded into a Chroma store (no ``persist_directory`` is
    given) and also passed through LangChain's ``index`` API, backed by a
    SQLite record manager stored in ``record_manager_cache.sql``.

    Args:
        pdf_storage_path: Directory containing the PDFs to index.
        chunk_size: Maximum chunk size handed to the text splitter.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        The populated Chroma vector store.

    Raises:
        FileNotFoundError: If the directory does not exist or holds no PDFs.
        ValueError: If the PDFs yield no text chunks at all.
    """
    pdf_directory = Path(pdf_storage_path)
    # Fail fast with a clear message instead of handing an empty document
    # list to the vector store, whose behaviour on empty input varies.
    if not pdf_directory.is_dir():
        raise FileNotFoundError(
            f"PDF storage directory not found: {pdf_directory.resolve()}"
        )
    # Sorted so the load/index order does not depend on the filesystem.
    pdf_paths = sorted(pdf_directory.glob("*.pdf"))
    if not pdf_paths:
        raise FileNotFoundError(f"No PDF files found in {pdf_directory.resolve()}")

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    docs: List[Document] = []
    for pdf_path in pdf_paths:
        loader = PyMuPDFLoader(str(pdf_path))
        docs.extend(text_splitter.split_documents(loader.load()))
    if not docs:
        raise ValueError(
            f"No text could be extracted from the PDFs in {pdf_directory.resolve()}"
        )

    # NOTE(review): ``from_documents`` already adds every chunk to the store,
    # and ``index`` below writes any chunk not yet recorded in the SQLite
    # record manager a second time. With a fresh record manager this likely
    # leaves each chunk stored twice; on later runs ``index`` skips them and
    # only the ``from_documents`` copy exists. Confirm which call should own
    # the writes before changing either.
    doc_search = Chroma.from_documents(docs, embeddings_model)

    namespace = "chromadb/my_documents"
    record_manager = SQLRecordManager(
        namespace, db_url="sqlite:///record_manager_cache.sql"
    )
    record_manager.create_schema()

    index_result = index(
        docs,
        record_manager,
        doc_search,
        cleanup="incremental",
        source_id_key="source",
    )
    print(f"Indexing stats: {index_result}")
    return doc_search
# Build the vector store once at import time; on_chat_start reads this
# module-level ``doc_search`` for every chat session.
doc_search = process_pdfs(PDF_STORAGE_PATH)

# ChatHuggingFace is a chat wrapper around an underlying Hugging Face LLM and
# requires that LLM via ``llm=``; it has no ``model_name`` argument, so the
# original call failed validation at import. The model must also be named by
# its full Hub repo id.
model = ChatHuggingFace(
    llm=HuggingFaceEndpoint(
        repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
        streaming=True,
    ),
)
@cl.on_chat_start
async def on_chat_start():
    """Assemble the retrieval-augmented QA chain for a new chat session.

    The chain feeds the question to the ``doc_search`` retriever, joins the
    retrieved chunks into one context string, fills the prompt, calls
    ``model`` and parses the reply into a plain string. The chain is stored
    in the Chainlit user session under the key ``"runnable"``.
    """

    def format_docs(docs):
        # Retrieved chunks are separated from each other by a blank line.
        return "\n\n".join(doc.page_content for doc in docs)

    qa_prompt = ChatPromptTemplate.from_template(
        "Answer the question based only on the following context:\n"
        "{context}\n"
        "Question: {question}\n"
    )

    # Both prompt variables are computed from the same user input: the
    # context via retrieval, the question passed through unchanged.
    prompt_inputs = {
        "context": doc_search.as_retriever() | format_docs,
        "question": RunnablePassthrough(),
    }
    chain = prompt_inputs | qa_prompt | model | StrOutputParser()

    cl.user_session.set("runnable", chain)
@cl.on_message
async def on_message(message: cl.Message):
    """Answer ``message`` with the session's chain, streaming the reply.

    Tokens from the chain are streamed into a single Chainlit message inside
    a "QA Assistant" step. A callback handler records where each retrieved
    chunk came from and, when the LLM finishes, attaches that list to the
    message as an inline "Sources" text element before the message is sent.
    """
    runnable: Runnable = cl.user_session.get("runnable")
    msg = cl.Message(content="")

    class PostMessageHandler(BaseCallbackHandler):
        """
        Callback handler for handling the retriever and LLM processes.
        Used to post the sources of the retrieved documents as a Chainlit element.
        """

        def __init__(self, msg: cl.Message):
            BaseCallbackHandler.__init__(self)
            self.msg = msg
            # Unique (source, page) pairs gathered from retriever results.
            self.sources = set()

        @staticmethod
        def _format_source(source, page):
            """Render one source as ``<path>#page=<n>`` with a 1-based page."""
            # PyMuPDFLoader stores a 0-based page index in metadata["page"],
            # while the ``#page=`` PDF open parameter counts pages from 1.
            if isinstance(page, int):
                return f"{source}#page={page + 1}"
            return str(source)

        def on_retriever_end(self, documents, *, run_id, parent_run_id=None, **kwargs):
            for d in documents:
                # ``page`` may be absent for documents not produced by a
                # paginated loader; keep the source without a page then.
                self.sources.add((d.metadata["source"], d.metadata.get("page")))

        def on_llm_end(self, response, *, run_id, parent_run_id=None, **kwargs):
            if not self.sources:
                return
            # Sort so the listing is stable between runs (sets are unordered);
            # pages without a number sort first within their file.
            ordered = sorted(
                self.sources,
                key=lambda pair: (
                    str(pair[0]),
                    pair[1] if isinstance(pair[1], int) else -1,
                ),
            )
            sources_text = "\n".join(
                self._format_source(source, page) for source, page in ordered
            )
            self.msg.elements.append(
                cl.Text(name="Sources", content=sources_text, display="inline")
            )

    async with cl.Step(type="run", name="QA Assistant"):
        async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(
                callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
            ),
        ):
            await msg.stream_token(chunk)

    await msg.send()
|