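"""Gradio RAG demo: upload PDFs, index them in Chroma with FastEmbed embeddings,
and answer questions with a local Ollama LLM (falling back to general knowledge
when no documents are indexed). Earlier related questions are surfaced via
cosine similarity over question embeddings."""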
import gradio as gr
from langchain_community.vectorstores import Chroma
from langchain_community.chat_models import ChatOllama
from langchain_community.embeddings import FastEmbedEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema.output_parser import StrOutputParser
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# Initialize embeddings model and vector store
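# fastembed downloads the BGE model weights on first use (assumes network access)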
embeddings_model = FastEmbedEmbeddings(model_name="BAAI/bge-small-en-v1.5")
vector_store = None

# Chat history: list of (question, answer) display tuples
chat_history = []

# Embeddings of previous questions, aligned index-for-index with chat_history
question_embeddings = []

# Prompt templates for LLM
prompt_with_context_template = """Answer the question based only on the following context:
{context}

Question: {question}
"""
prompt_without_context_template = """Provide an answer to the question based on general knowledge.
Question: {question}
"""
prompt_with_context = PromptTemplate.from_template(prompt_with_context_template)
prompt_without_context = PromptTemplate.from_template(prompt_without_context_template)

# Function to load, split PDFs, and store in vector store
def process_documents(uploaded_files):
    global vector_store
    all_docs = []
    # One splitter instance can be reused across all files
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    for uploaded_file in uploaded_files:
        # Load each PDF using PyPDFLoader (Gradio passes filepaths as strings)
        loader = PyPDFLoader(uploaded_file)
        pages = loader.load_and_split()

        # Split pages into overlapping chunks
        docs = text_splitter.split_documents(pages)
        all_docs.extend(docs)

    # Create or update the vector store
    if vector_store is None:
        vector_store = Chroma.from_documents(all_docs, embeddings_model)
    else:
        vector_store.add_documents(all_docs)
    
    return f"Uploaded {len(uploaded_files)} files and indexed {len(all_docs)} chunks."

# Function to handle question answering with RAG and maintain chat history
def answer_question(question):
    global vector_store, chat_history, question_embeddings
    
    # Set up retriever and LLM
    retriever = vector_store.as_retriever() if vector_store else None
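    # Assumes a local Ollama server is running with llama3 pulled (ollama pull llama3)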
    llm = ChatOllama(model="llama3:latest", verbose=True)

    if retriever:
        # Define the RAG chain with document context
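        # LCEL pipeline: the retriever fills {context}; RunnablePassthrough
        # forwards the raw question string into {question}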
        chain = (
            {"context": retriever, "question": RunnablePassthrough()}
            | prompt_with_context
            | llm
            | StrOutputParser()
        )
        # Process the question through the RAG chain with context. strip() rather
        # than capitalize(): capitalize() lowercases every character after the
        # first, mangling acronyms and proper nouns in the answer.
        answer = chain.invoke(question).strip()
    else:
        # Define the RAG chain without document context
        chain = (
            {"question": RunnablePassthrough()}
            | prompt_without_context
            | llm
            | StrOutputParser()
        )
        # Process the question through the chain without document context
        answer = chain.invoke(question).strip()

    # Append the question and answer to the chat history
    chat_history.append((f"Q: {question}", f"A: {answer}"))

    # Embed the current question
    current_question_embedding = embeddings_model.embed_query(question)

    # Find the most similar previous question. Compare only against embeddings
    # stored on earlier turns; including the current embedding would make
    # argmax always pick the current question itself (similarity 1.0).
    related_question = "No related questions found."
    if question_embeddings:
        similarities = cosine_similarity([current_question_embedding], question_embeddings)
        related_idx = int(np.argmax(similarities))
        if similarities[0][related_idx] > 0.5:  # 0.5 is a heuristic threshold
            related_question = chat_history[related_idx][0]

    # Store the embedding only after the comparison above
    question_embeddings.append(current_question_embedding)
    
    # Format the chat history for display
    chat_display = "\n\n".join([f"{q}\n{a}" for q, a in chat_history])
    
    return answer, chat_display, related_question

# Function to clear the vector store (chat history is preserved)
def clear_documents():
    global vector_store
    if vector_store is not None:
        vector_store.delete_collection()
        vector_store = None
    # Re-render the unchanged chat history as text; the Textbox expects a
    # string, not the raw list of tuples
    chat_display = "\n\n".join([f"{q}\n{a}" for q, a in chat_history])
    return "Document collection cleared.", chat_display, ""

# Gradio interface
with gr.Blocks() as demo:
    # Main layout with two columns
    with gr.Row():
        # Left column for file upload and question input
        with gr.Column(scale=1):
            file_uploader = gr.File(label="Upload PDFs", file_types=[".pdf"], file_count="multiple", type="filepath")
            upload_button = gr.Button("Upload and Process")
            clear_button = gr.Button("Clear Document Collection")
            status_display = gr.Textbox(label="Status", lines=2)

            question_input = gr.Textbox(label="Ask a question about the documents")
            ask_button = gr.Button("Ask")
        
        # Center column for answer and chat history
        with gr.Column(scale=2):
            answer_display = gr.Textbox(label="Answer", lines=4)
            chat_history_display = gr.Textbox(label="Chat History", lines=10, interactive=False)
            related_question_display = gr.Textbox(label="Related Question", lines=4, interactive=False)

    # Link buttons to functions
    upload_button.click(process_documents, inputs=[file_uploader], outputs=[status_display])
    ask_button.click(answer_question, inputs=[question_input], outputs=[answer_display, chat_history_display, related_question_display])
    clear_button.click(clear_documents, outputs=[status_display, chat_history_display, related_question_display])

# Launch the app
if __name__ == "__main__":
    demo.launch(inline=False)
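
# Usage sketch (assumptions: the file is saved as app.py and the package names
# below match the imports; adjust for your LangChain version):
#   pip install gradio langchain langchain-community chromadb fastembed pypdf scikit-learn numpy
#   ollama pull llama3
#   python app.py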