Spaces:
Runtime error
Runtime error
| import os | |
| from pathlib import Path | |
| from typing import List | |
| import chainlit as cl | |
| import chainlit.data as cl_data | |
| from langchain.callbacks.base import BaseCallbackHandler | |
| from langchain.indexes import SQLRecordManager, index | |
| from langchain.prompts import ChatPromptTemplate | |
| from langchain.schema import Document | |
| from langchain.schema import StrOutputParser | |
| from langchain.schema.runnable import Runnable, RunnablePassthrough, RunnableConfig | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.document_loaders import ( | |
| PyPDFDirectoryLoader, | |
| ) | |
| from langchain_community.vectorstores import Chroma | |
| # from langchain_openai import ChatOpenAI, OpenAIEmbeddings | |
| from langchain_groq import ChatGroq | |
| from langchain_huggingface import HuggingFaceEndpointEmbeddings | |
| from feedback import CustomDataLayer | |
| from rag_bot import RagBot | |
# --- Ingestion settings -----------------------------------------------------
# Directory scanned for PDF files when the module is loaded.
PDF_STORAGE_PATH = "./data"

# Splitter settings consumed by process_pdfs(); both are passed to a splitter
# configured with length_function=len.
chunk_size, chunk_overlap = 1024, 50

# --- Embeddings -------------------------------------------------------------
# Hugging Face endpoint embeddings client; the API token is read from the
# environment (None when the variable is unset).
embeddings_model = HuggingFaceEndpointEmbeddings(
    model="sentence-transformers/all-MiniLM-L12-v2",
    huggingfacehub_api_token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
)

# --- Feedback ---------------------------------------------------------------
# Install the project's CustomDataLayer as Chainlit's data layer.
cl_data._data_layer = CustomDataLayer()
def process_pdfs(pdf_storage_path: str):
    """Load, chunk, embed and index every PDF found under ``pdf_storage_path``.

    The PDFs are read with ``PyPDFDirectoryLoader``, split with a
    ``RecursiveCharacterTextSplitter`` configured from the module-level
    ``chunk_size`` / ``chunk_overlap`` values, embedded into a Chroma store
    using ``embeddings_model`` and finally passed through LangChain's
    ``index`` helper backed by a SQLite record manager.

    Args:
        pdf_storage_path: Directory containing the PDF files to ingest.

    Returns:
        The Chroma vector store built from the chunks.

    Raises:
        ValueError: If splitting the loaded PDFs yields no chunks.
    """
    raw_pages = PyPDFDirectoryLoader(Path(pdf_storage_path)).load()

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
    )
    chunks: List[Document] = splitter.split_documents(raw_pages)
    if not chunks:
        raise ValueError("No documents found in the specified directory.")

    # NOTE(review): the store is populated here and the same chunks are then
    # handed to index() below -- confirm this does not insert duplicates when
    # the record-manager cache is fresh.
    vector_store = Chroma.from_documents(chunks, embeddings_model)

    # The record manager persists its bookkeeping in a local SQLite file.
    tracker = SQLRecordManager(
        "chromadb/my_documents",
        db_url="sqlite:///record_manager_cache.sql",
    )
    tracker.create_schema()

    stats = index(
        chunks,
        tracker,
        vector_store,
        cleanup="full",
        source_id_key="source",
    )
    print(f"Indexing stats: {stats}")

    return vector_store
# Build the vector store once, at import time, from the PDFs on disk.
doc_search = process_pdfs(PDF_STORAGE_PATH)

# Chat model shared by every session. The API key is read from the
# environment (None when the variable is unset).
# NOTE(review): confirm that this model id is still served by Groq.
model = ChatGroq(
    api_key=os.environ.get("GROQ_API_KEY"),
    model="llama-3.1-70b-versatile",
    temperature=0,
    max_tokens=1024,
    max_retries=5,
    timeout=None,
)
@cl.on_chat_start
async def on_chat_start():
    """Build the retrieval-augmented chain for a new chat session.

    The chain retrieves the five closest chunks from the module-level
    ``doc_search`` store, joins their text into the ``context`` prompt
    variable, sends the prompt to the module-level ``model`` and parses the
    output to a plain string. The finished chain is stored in the Chainlit
    user session under the key ``"runnable"``.

    The ``@cl.on_chat_start`` decorator registers this coroutine with
    Chainlit; without it the function is defined but never invoked.
    """
    # Adjacent literals are concatenated by the parser, so the separator
    # before "Context:" has to be spelled out explicitly; otherwise the
    # prompt reads "...retrieved below.Context: ...".
    system_prompt = (
        "You are a helpful assistant that can answer questions about "
        "technical documents in any language.\n"
        "Keep your answers only in the language of the question(s).\n"
        "Only use the factual information from the document(s) to answer "
        "the question(s). Keep your answers concise and to the point.\n"
        "If you do not have have sufficient information to answer a "
        'question, politely refuse to answer and say "I don\'t know".\n'
        "\n\nRelevant documents will be retrieved below."
        "\n\nContext: {context}"
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{question}"),
        ]
    )

    def format_docs(docs):
        """Join the page contents of the retrieved documents with blank lines."""
        return "\n\n".join(d.page_content for d in docs)

    retriever = doc_search.as_retriever(search_kwargs={"k": 5})

    # The incoming question feeds both the retriever (to build the context)
    # and, unchanged, the "question" prompt variable.
    runnable = (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )
    cl.user_session.set("runnable", runnable)
class PostMessageHandler(BaseCallbackHandler):
    """Collect retrieved sources and attach them to the outgoing message.

    Used to post the sources of the retrieved documents as Chainlit
    elements once the LLM has finished generating.
    """

    def __init__(self, msg: cl.Message):
        """Remember the message to decorate and start with no sources."""
        super().__init__()
        self.msg = msg
        # One entry per unique (source, page) pair, in retrieval order.
        self.sources = []

    def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
        """Record each retrieved document, skipping repeated (source, page) pairs."""
        for doc in documents:
            source = doc.metadata.get("source", "Unknown Source")
            page = doc.metadata.get("page", "N/A")
            already_listed = any(
                s["source"] == source and s["page"] == page
                for s in self.sources
            )
            if already_listed:
                continue
            self.sources.append(
                {
                    "source": source,
                    "page": page,
                    "content": doc.page_content,
                }
            )

    def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
        """Append a "Sources:" line and one side-panel text element per source."""
        if not self.sources:
            return

        source_references = [
            f"{src['source']} p.{src['page']}" for src in self.sources
        ]
        # Each element is named exactly like its reference in the message
        # text, so the two can be matched up.
        text_elements = [
            cl.Text(name=name, content=src["content"], display="side")
            for name, src in zip(source_references, self.sources)
        ]

        # Join outside the f-string: reusing double quotes inside a
        # replacement field is a SyntaxError before Python 3.12.
        joined_references = ", ".join(source_references)
        self.msg.content += f"\n\nSources: {joined_references}"
        self.msg.elements.extend(text_elements)


@cl.on_message
async def on_message(message: cl.Message):
    """Answer an incoming user message with the session's RAG chain.

    The chain stored under ``"runnable"`` in the user session is streamed
    token by token into a new Chainlit message. A ``PostMessageHandler``
    callback adds the retrieved sources before the message is sent.

    The ``@cl.on_message`` decorator registers this coroutine with
    Chainlit; without it the function is defined but never invoked.

    Args:
        message: The user message whose ``content`` is used as the question.
    """
    runnable = cl.user_session.get("runnable")  # type: Runnable
    msg = cl.Message(content="")

    callbacks = [cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
    async for chunk in runnable.astream(
        message.content,
        config=RunnableConfig(callbacks=callbacks),
    ):
        await msg.stream_token(chunk)

    await msg.send()