Spaces:
Build error
Build error
| import os | |
| import asyncio | |
| from typing import List | |
| from chainlit.types import AskFileResponse | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_community.embeddings import OpenAIEmbeddings | |
| from langchain.chat_models import ChatOpenAI | |
| from langchain.schema import SystemMessage, HumanMessage | |
| from PyPDF2 import PdfReader | |
| import chainlit as cl | |
# Refuse to start without credentials: a missing or empty OPENAI_API_KEY
# stops the module at import time instead of failing on the first request.
if os.environ.get("OPENAI_API_KEY", "") == "":
    raise ValueError("OPENAI_API_KEY environment variable is not set")
# Prompt templates.
# The system message carries the answering instructions; the human message
# carries the retrieved context followed by the user's question. Both are
# combined, in that order, into `chat_prompt`.
system_template = (
    "Use the following context to answer a user's question. "
    "If you cannot find the answer in the context, say you don't know the answer."
)
human_template = "Context:\n{context}\n\n" "Question:\n{question}"

system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
chat_prompt = ChatPromptTemplate.from_messages(
    [system_message_prompt, human_message_prompt]
)
class RetrievalAugmentedQAPipeline:
    """Answer a question from the chunks most similar to it in a vector store.

    The pipeline retrieves the closest chunks for a query, joins them into a
    single context string, fills the module-level ``chat_prompt`` with that
    context and the question, and streams the chat model's reply.
    """

    def __init__(
        self,
        llm: ChatOpenAI,
        vector_db: Chroma,
        *,
        top_k: int = 2,
        max_context_length: int = 12000,
    ) -> None:
        """Store the collaborators and retrieval limits.

        Args:
            llm: Chat model used to generate the answer.
            vector_db: Vector store that is searched for context chunks.
            top_k: Number of chunks to retrieve per query.
            max_context_length: Maximum number of *characters* (not tokens)
                of joined context that is sent to the model.
        """
        self.llm = llm
        self.vector_db = vector_db
        self.top_k = top_k
        self.max_context_length = max_context_length

    async def arun_pipeline(self, user_query: str):
        """Stream the model's answer to ``user_query`` piece by piece.

        Args:
            user_query: The question to answer.

        Yields:
            The ``content`` of each chunk produced by ``llm.astream``.
        """
        # Await the async search so this coroutine does not hold the event
        # loop while the query is embedded and the store is searched.
        context_docs = await self.vector_db.asimilarity_search(
            user_query, k=self.top_k
        )
        context_prompt = "\n".join(doc.page_content for doc in context_docs)
        # Character-based cap; slicing is a no-op when the context is shorter.
        context_prompt = context_prompt[: self.max_context_length]

        messages = chat_prompt.format_prompt(
            context=context_prompt, question=user_query
        ).to_messages()
        async for chunk in self.llm.astream(messages):
            yield chunk.content
def process_pdf(file: AskFileResponse) -> List[str]:
    """Extract the text of an uploaded PDF and split it into chunks.

    Args:
        file: The uploaded file. Its ``path`` attribute is used when present
            and non-empty; otherwise its ``content`` attribute is read.

    Returns:
        The text chunks produced by ``RecursiveCharacterTextSplitter`` with
        ``chunk_size=500`` and ``chunk_overlap=40``. May be empty when no
        text could be extracted.
    """
    import io  # local import: only needed for the in-memory fallback below

    pdf_path = getattr(file, "path", None)
    if pdf_path:
        source = pdf_path
    else:
        content = file.content
        # PdfReader expects a path or a seekable stream, so raw bytes have to
        # be wrapped; anything else is passed through unchanged.
        if isinstance(content, (bytes, bytearray)):
            source = io.BytesIO(content)
        else:
            source = content

    pdf_reader = PdfReader(source)
    # Fall back to "" so a page without extractable text cannot break the join.
    page_texts = [page.extract_text() or "" for page in pdf_reader.pages]
    text = "\n".join(page_texts)

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=40)
    return text_splitter.split_text(text)
@cl.on_chat_start
async def on_chat_start():
    """Ask for a PDF, index it, and store a QA pipeline in the user session.

    Registered with Chainlit so it runs when a chat session starts. On success
    the pipeline is stored under the ``"pipeline"`` session key. If no file is
    uploaded, no text can be extracted, or processing fails, the user is told
    and no pipeline is stored.
    """
    import uuid  # local import: only used to name the per-session collection

    files = await cl.AskFileMessage(
        content="Please upload a PDF file to begin!",
        accept=["application/pdf"],
        max_size_mb=20,
    ).send()
    if not files:
        await cl.Message(content="No file was uploaded. Please try again.").send()
        return

    file = files[0]
    msg = cl.Message(content=f"Processing `{file.name}`...")
    await msg.send()

    try:
        texts = process_pdf(file)
        vector_db = None
        if texts:
            # NOTE(review): Chroma.from_texts otherwise uses one default
            # collection name; a unique name should keep this upload's chunks
            # apart from other sessions in the same process -- confirm
            # against the deployed chromadb version.
            vector_db = Chroma.from_texts(
                texts,
                OpenAIEmbeddings(),
                collection_name=f"pdf-{uuid.uuid4().hex}",
            )
    except Exception as e:
        # Handler boundary: report the failure instead of leaving the
        # "Processing..." message hanging.
        msg.content = f"Failed to process `{file.name}`: {e}"
        await msg.update()
        return

    if vector_db is None:
        # Nothing to index, e.g. a scanned PDF without a text layer.
        msg.content = (
            f"No extractable text was found in `{file.name}`. "
            "Please upload a different PDF."
        )
        await msg.update()
        return

    chat_openai = ChatOpenAI()
    retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
        vector_db=vector_db, llm=chat_openai
    )
    cl.user_session.set("pipeline", retrieval_augmented_qa_pipeline)

    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()
@cl.on_message
async def main(message: cl.Message):
    """Answer an incoming chat message with the session's QA pipeline.

    Registered with Chainlit so it runs for every user message. The answer is
    streamed token by token into a single reply message.

    Args:
        message: The user's message; its ``content`` is used as the question.
    """
    pipeline = cl.user_session.get("pipeline")
    if not pipeline:
        await cl.Message(content="Please upload a PDF file first.").send()
        return

    msg = cl.Message(content="")
    try:
        async for chunk in pipeline.arun_pipeline(message.content):
            await msg.stream_token(chunk)
    except Exception as e:
        # Handler boundary: surface the failure to the user rather than
        # letting the exception escape.
        await cl.Message(content=f"An error occurred: {e}").send()
        return

    # Finalize the streamed reply.
    await msg.send()