# NOTE: removed HuggingFace Spaces page-scrape residue ("Spaces: / Sleeping")
# that preceded the actual source code.
import gradio as gr

from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
# Step 1: Load and Split Documents
def load_documents(pdf_files):
    """Load the uploaded PDFs and split them into overlapping text chunks.

    Args:
        pdf_files: Uploaded file objects, each exposing a ``.name`` filesystem
            path (as Gradio's ``gr.File`` component provides).

    Returns:
        A list of LangChain ``Document`` chunks of ~500 characters with a
        50-character overlap between consecutive chunks.
    """
    # Build one loader per uploaded file first, then load all pages.
    loaders = [PyPDFLoader(uploaded.name) for uploaded in pdf_files]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    # Chunk the pages so each piece fits comfortably in the retriever context.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
    )
    return splitter.split_documents(pages)
# Step 2: Create Vector Database
def create_vector_db(splits):
    """Embed the document chunks and index them in an in-memory FAISS store.

    Args:
        splits: Document chunks produced by ``load_documents``.

    Returns:
        A FAISS vector store built over *splits*.
    """
    # CPU-friendly sentence-transformer model for the embeddings.
    embedder = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    return FAISS.from_documents(splits, embedder)
# Step 3: Initialize Conversational Retrieval Chain
def initialize_qa_chain(vector_db, llm=None):
    """Build a ConversationalRetrievalChain over *vector_db*.

    Bug fix: the original called ``ConversationalRetrievalChain.from_chain_type``,
    which does not exist (``from_chain_type`` is a ``RetrievalQA`` constructor),
    and it supplied no language model at all, so the chain could never be
    constructed. The correct constructor is ``from_llm`` and it requires an LLM.

    Args:
        vector_db: Vector store exposing ``as_retriever()``.
        llm: The language model used to answer questions (new, backward-
            compatible parameter). Required at call time; failing early with
            a clear message replaces the original runtime ``AttributeError``.

    Returns:
        A ``ConversationalRetrievalChain`` wired to the retriever and a
        chat-history buffer memory.

    Raises:
        ValueError: If no LLM is supplied.
    """
    # Validate before touching any LangChain machinery so callers get a
    # precise error instead of a confusing constructor failure.
    if llm is None:
        raise ValueError(
            "initialize_qa_chain requires a language model: pass llm=<model>"
        )
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        memory=memory,
    )
# Step 4: Handle Conversation
def handle_conversation(qa_chain, query, history):
    """Run *query* through the QA chain and record the exchange in *history*.

    Args:
        qa_chain: Callable accepting ``{"question", "chat_history"}`` and
            returning a mapping with an ``"answer"`` key.
        query: The user's question.
        history: Mutable list of (question, answer) pairs; mutated in place.

    Returns:
        The same history list twice, matching the two Gradio outputs it feeds.
    """
    answer = qa_chain({"question": query, "chat_history": history})["answer"]
    history.append((query, answer))
    return history, history
# Gradio UI
def demo():
    """Build the Gradio Blocks interface for the RAG chatbot.

    Returns:
        The ``gr.Blocks`` interface (call ``.launch()`` on it to serve).
    """
    with gr.Blocks() as interface:
        # Bug fix: the original created these gr.State() objects *outside*
        # the Blocks context, so they were never attached to the interface.
        # Per-session state: the FAISS store and the QA chain built from it.
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.Markdown("<h1><center>CPU-Friendly RAG Chatbot</center></h1>")
        with gr.Tab("Step 1: Upload PDFs"):
            pdf_files = gr.File(file_types=[".pdf"], label="Upload PDF Files", file_count="multiple")
            create_db_button = gr.Button("Create Vector Database")
            db_status = gr.Textbox(label="Database Status", value="Not created", interactive=False)
        with gr.Tab("Step 2: Chat"):
            chatbot = gr.Chatbot()
            query = gr.Textbox(label="Your Query")
            send_button = gr.Button("Ask")
        # Bug fix: the original registered two independent .click handlers on
        # the same button, where the second read the vector_db State the first
        # wrote, with no ordering guarantee. Chain them with .then() so the QA
        # chain is initialized only after the database exists.
        create_db_button.click(
            fn=lambda files: (
                create_vector_db(load_documents(files)),
                "Database created successfully!",
            ),
            inputs=[pdf_files],
            outputs=[vector_db, db_status],
        ).then(
            fn=initialize_qa_chain,
            inputs=[vector_db],
            outputs=[qa_chain],
        )
        # Handle conversation; both outputs map to the same Chatbot component,
        # mirroring the original wiring and handle_conversation's two returns.
        send_button.click(
            fn=handle_conversation,
            inputs=[qa_chain, query, chatbot],
            outputs=[chatbot, chatbot],
        )
    return interface
# Launch the app only when executed as a script (not on import).
if __name__ == "__main__":
    app = demo()
    app.launch()