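# Chainlit app for question answering over semi-structured PDFs.
# `unstructured` splits the upload into text and table chunks, GPT-4 summarizes
# each chunk, and a LangChain MultiVectorRetriever embeds the summaries in
# Chroma while keeping the raw chunks in an in-memory docstore.
# Launch with: chainlit run <this_file>.py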
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema import StrOutputParser
from langchain.schema.runnable import Runnable, RunnablePassthrough
from langchain.schema.runnable.config import RunnableConfig
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
import os
import chainlit as cl
from chainlit.types import AskFileResponse
from pydantic import BaseModel
from typing import Any
from unstructured.partition.pdf import partition_pdf
from prompts import *
import uuid
from langchain.vectorstores import Chroma
from langchain.storage import InMemoryStore
from langchain.schema.document import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever

welcome_message = """Welcome to the Semi-Structured PDF QA! To get started:
1. Upload a PDF or text file
2. Ask a question about the file
3. (Optional) Ask a question about any table in the PDF
Note: Loading the PDF takes time because `unstructured` is used to detect tables
and create summaries. Please be patient. The chatbot uses `gpt-4`.
"""
class Element(BaseModel):
    type: str
    text: Any


def get_elements(raw_pdf_elements):
    # Categorize extracted elements by type
    categorized_elements = []
    for element in raw_pdf_elements:
        if "unstructured.documents.elements.Table" in str(type(element)):
            categorized_elements.append(Element(type="table", text=str(element)))
        elif "unstructured.documents.elements.CompositeElement" in str(type(element)):
            categorized_elements.append(Element(type="text", text=str(element)))
    # Tables
    table_elements = [e for e in categorized_elements if e.type == "table"]
    print(len(table_elements))
    # Text
    text_elements = [e for e in categorized_elements if e.type == "text"]
    print(len(text_elements))
    return table_elements, text_elements
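

# Partition the uploaded file with `unstructured`. partition_pdf is given a
# filesystem path, so the upload bytes are written to a temporary file first.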
def process_docs(file: AskFileResponse):
    import tempfile

    # Persist the upload to a temporary file; both accepted types are
    # written the same way (`delete=False` keeps the file on disk after
    # the handle closes)
    with tempfile.NamedTemporaryFile(mode="wb", delete=False) as tmp:
        tmp.write(file.content)
    raw_pdf_elements = partition_pdf(
        filename=tmp.name,
        # Unstructured first finds embedded image blocks
        extract_images_in_pdf=False,
        # Use layout model (YOLOX) to get bounding boxes (for tables) and find titles
        # Titles are any sub-section of the document
        infer_table_structure=True,
        # Post-processing to aggregate text once we have the title
        chunking_strategy="by_title",
        # Chunking params to aggregate text blocks:
        # attempt to start a new chunk at 3800 chars, hard cap at 4000,
        # and merge chunks smaller than 2000 chars
        max_characters=4000,
        new_after_n_chars=3800,
        combine_text_under_n_chars=2000,
    )
    table_elements, text_elements = get_elements(raw_pdf_elements)
    return table_elements, text_elements


@cl.on_chat_start
async def on_chat_start():
    await cl.Avatar(
        name="QA Chatbot",
        url="https://avatars.githubusercontent.com/u/128686189?s=400&u=a1d1553023f8ea0921fba0debbe92a8c5f840dd9&v=4",
    ).send()
    await cl.Avatar(
        name="User",
        path="icon/avatar.png",
    ).send()
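    # Keep asking until the user uploads a file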
    files = None
    while files is None:
        files = await cl.AskFileMessage(
            content=welcome_message,
            accept=["text/plain", "application/pdf"],
            max_size_mb=20,
            timeout=180,
            disable_human_feedback=True,
        ).send()
    file = files[0]
    msg = cl.Message(
        content=f"Processing `{file.name}`... Please wait.", disable_human_feedback=True
    )
    await msg.send()
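    # partition_pdf is blocking and slow; cl.make_async runs it in a worker
    # thread so the asyncio event loop stays responsive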
    table_elements, text_elements = await cl.make_async(process_docs)(file)
    message_history = ChatMessageHistory()
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        chat_memory=message_history,
        return_messages=True,
    )
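    # GPT-4 at temperature 0 for consistent answers; streaming enables
    # token-by-token output in the UI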
    model = ChatOpenAI(
        streaming=True,
        temperature=0,
        model="gpt-4",
        openai_api_key=os.getenv("OPENAI_API_KEY"),
    )
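    # Summarization chain: each raw chunk is piped in as `element` and
    # condensed into a short, retrieval-friendly summary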
    prompt = ChatPromptTemplate.from_template(TABLE_TEXT_SUMMARY_PROMPT)
    summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()
    # Apply to tables
    tables = [i.text for i in table_elements]
    table_summaries = summarize_chain.batch(tables, {"max_concurrency": 5})
    # Apply to texts
    texts = [i.text for i in text_elements]
    text_summaries = summarize_chain.batch(texts, {"max_concurrency": 5})
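    # Chroma indexes embeddings of the summaries, not of the raw chunks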
    vectorstore = Chroma(
        persist_directory="./chroma_db",
        collection_name="summaries",
        embedding_function=OpenAIEmbeddings(),
    )
    # The storage layer for the parent documents
    store = InMemoryStore()
    id_key = "doc_id"
    # The retriever (empty to start)
    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=store,
        id_key=id_key,
    )
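    # Multi-vector pattern: embed the summary for search, but key it to the
    # full chunk in the docstore so retrieval returns the original content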
    # Add texts
    doc_ids = [str(uuid.uuid4()) for _ in texts]
    summary_texts = [
        Document(page_content=s, metadata={id_key: doc_ids[i]})
        for i, s in enumerate(text_summaries)
    ]
    retriever.vectorstore.add_documents(summary_texts)
    retriever.docstore.mset(list(zip(doc_ids, texts)))
    # Add tables
    table_ids = [str(uuid.uuid4()) for _ in tables]
    summary_tables = [
        Document(page_content=s, metadata={id_key: table_ids[i]})
        for i, s in enumerate(table_summaries)
    ]
    retriever.vectorstore.add_documents(summary_tables)
    retriever.docstore.mset(list(zip(table_ids, tables)))
    msg.content = f"`{file.name}` processed. You can now ask questions!"
    await msg.update()
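    # Final QA chain: the retriever supplies the matching raw chunks and
    # tables as context for GPT-4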
    # Prompt template
    prompt = ChatPromptTemplate.from_template(QA_PROMPT)
    runnable = (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )
    cl.user_session.set("runnable", runnable)
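

# Stream the chain's answer back to the UI for every user message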
@cl.on_message
async def on_message(message: cl.Message):
    runnable = cl.user_session.get("runnable")  # type: Runnable
    msg = cl.Message(content="")
    async for chunk in runnable.astream(
        message.content,
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        await msg.stream_token(chunk)
    await msg.send()