File size: 3,038 Bytes
424a2e7
84039f2
 
 
 
 
 
424a2e7
84039f2
 
 
 
 
 
 
 
 
 
 
 
424a2e7
84039f2
 
 
 
 
424a2e7
84039f2
 
 
 
 
 
 
 
 
 
 
 
424a2e7
84039f2
 
 
 
 
 
424a2e7
84039f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
424a2e7
84039f2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gradio as gr
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

# Step 1: Load and Split Documents
def load_documents(pdf_files, chunk_size=500, chunk_overlap=50):
    """Load the given PDF files and split them into overlapping text chunks.

    Args:
        pdf_files: Iterable of uploaded PDFs. Each item may be a path
            (``str`` / ``os.PathLike``) or an object whose ``.name`` attribute
            holds the path; which form Gradio passes depends on its version.
            ``None`` or an empty iterable yields no chunks.
        chunk_size: Maximum chunk length passed to
            ``RecursiveCharacterTextSplitter``.
        chunk_overlap: Overlap between consecutive chunks, passed to
            ``RecursiveCharacterTextSplitter``.

    Returns:
        List of LangChain ``Document`` chunks (empty when no files are given).
    """
    import os

    if not pdf_files:
        return []

    docs = []
    for item in pdf_files:
        # Newer Gradio hands back plain path strings; older versions pass
        # temp-file wrappers whose ``.name`` is the path on disk.
        path = item if isinstance(item, (str, os.PathLike)) else item.name
        docs.extend(PyPDFLoader(os.fspath(path)).load())

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return text_splitter.split_documents(docs)

# Step 2: Create Vector Database
def create_vector_db(splits, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Embed document chunks and index them in an in-memory FAISS store.

    Args:
        splits: Sequence of LangChain ``Document`` chunks, e.g. the output of
            ``load_documents``.
        model_name: Hugging Face sentence-transformers model used for the
            embeddings.

    Returns:
        A ``FAISS`` vector store built from *splits*.

    Raises:
        ValueError: If *splits* is empty or ``None``. This happens, for
            example, when the PDFs contain no extractable text. Without the
            check, FAISS would fail later with a much less helpful error.
    """
    if not splits:
        raise ValueError(
            "No text chunks to index; the uploaded PDFs may contain no extractable text."
        )
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    return FAISS.from_documents(splits, embeddings)

# Step 3: Initialize Conversational Retrieval Chain
def _default_llm(model_id="google/flan-t5-small"):
    """Build a small local Hugging Face pipeline LLM that can run on CPU.

    NOTE(review): the default model is a placeholder choice for a CPU-only
    demo; swap it for whatever model the deployment actually targets.
    """
    # Imported lazily so this dependency path is only exercised when the
    # caller does not supply its own LLM.
    from langchain.llms import HuggingFacePipeline

    return HuggingFacePipeline.from_model_id(
        model_id=model_id,
        task="text2text-generation",
        pipeline_kwargs={"max_new_tokens": 256},
    )


def initialize_qa_chain(vector_db, llm=None):
    """Create a conversational retrieval chain over *vector_db*.

    Args:
        vector_db: A LangChain vector store exposing ``as_retriever()``
            (e.g. the FAISS index returned by ``create_vector_db``).
        llm: LangChain LLM used to condense follow-up questions and to
            generate answers. When ``None``, a small local pipeline from
            ``_default_llm`` is used.

    Returns:
        A ``ConversationalRetrievalChain`` using the "stuff" combine strategy
        and a ``ConversationBufferMemory`` stored under ``"chat_history"``.

    Raises:
        ValueError: If *vector_db* is ``None`` (database not created yet).
    """
    if vector_db is None:
        raise ValueError("vector_db is required; create the vector database first.")
    if llm is None:
        llm = _default_llm()

    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        # Tell the memory which chain output to record as the AI turn.
        output_key="answer",
    )
    # ``from_chain_type`` belongs to RetrievalQA, not ConversationalRetrievalChain.
    # ``from_llm`` is the constructor here, and it requires an LLM.
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vector_db.as_retriever(),
        chain_type="stuff",
        memory=memory,
    )

# Step 4: Handle Conversation
def handle_conversation(qa_chain, query, history):
    """Send *query* to *qa_chain* and append the exchange to the chat history.

    Args:
        qa_chain: Callable taking a dict with ``"question"`` and
            ``"chat_history"`` and returning a dict with an ``"answer"`` key
            (e.g. the chain from ``initialize_qa_chain``). ``None`` means the
            database has not been created yet; a hint is appended instead.
        query: The user's question. Blank or ``None`` queries are ignored.
        history: Prior turns as ``(question, answer)`` pairs, either tuples
            or 2-item lists as a Gradio Chatbot may supply. ``None`` is
            treated as empty.

    Returns:
        ``(history, history)``: the same new list returned twice, matching
        the original two-output contract. The caller's list is not mutated.
    """
    # Normalise turns to tuples: LangChain's conversational chain formats
    # (human, ai) tuples, while Gradio may hand back 2-item lists.
    turns = [tuple(turn) for turn in (history or [])]

    if not query or not query.strip():
        return turns, turns

    if qa_chain is None:
        turns.append((query, "Please create the vector database first."))
        return turns, turns

    # Pass a copy so the chain never aliases the list we append to below.
    result = qa_chain({"question": query, "chat_history": list(turns)})
    turns.append((query, result["answer"]))
    return turns, turns

# Gradio UI
def demo():
    """Build and return the Gradio Blocks app for the PDF chatbot.

    The "Step 1" tab uploads PDFs and builds the FAISS index together with
    the QA chain. The "Step 2" tab asks questions against the indexed
    documents.

    Returns:
        The ``gr.Blocks`` interface (not yet launched).
    """
    import logging

    logger = logging.getLogger(__name__)

    def _build_db_and_chain(files):
        """Index *files* and build the QA chain in one step; return (db, chain, status)."""
        if not files:
            return None, None, "Please upload at least one PDF file."
        try:
            db = create_vector_db(load_documents(files))
            chain = initialize_qa_chain(db)
        except Exception as exc:  # UI boundary: report the failure in the status box
            logger.exception("Failed to create vector database")
            return None, None, f"Failed to create database: {exc}"
        return db, chain, "Database created successfully!"

    def _ask(chain, question, history):
        """Answer *question*, update the chat, and clear the query box."""
        updated, _ = handle_conversation(chain, question, history)
        return updated, ""

    with gr.Blocks() as interface:
        # State components must be created inside the Blocks context so they
        # are registered with this app and usable as event inputs/outputs.
        vector_db = gr.State()
        qa_chain = gr.State()

        gr.Markdown("<h1><center>CPU-Friendly RAG Chatbot</center></h1>")

        with gr.Tab("Step 1: Upload PDFs"):
            pdf_files = gr.File(file_types=[".pdf"], label="Upload PDF Files", file_count="multiple")
            create_db_button = gr.Button("Create Vector Database")
            db_status = gr.Textbox(label="Database Status", value="Not created", interactive=False)

        with gr.Tab("Step 2: Chat"):
            chatbot = gr.Chatbot()
            query = gr.Textbox(label="Your Query")
            send_button = gr.Button("Ask")

        # One handler builds the database *and* the chain. Two separate click
        # handlers are not ordered: the chain handler would read the
        # vector_db state as it was at click time (still empty).
        create_db_button.click(
            fn=_build_db_and_chain,
            inputs=[pdf_files],
            outputs=[vector_db, qa_chain, db_status],
        )

        # Ask via the button or by pressing Enter in the query box.
        send_button.click(
            fn=_ask,
            inputs=[qa_chain, query, chatbot],
            outputs=[chatbot, query],
        )
        query.submit(
            fn=_ask,
            inputs=[qa_chain, query, chatbot],
            outputs=[chatbot, query],
        )

    return interface

# Launch the app
if __name__ == "__main__":
    # Build the Blocks app first, then start the Gradio server.
    app = demo()
    app.launch()