File size: 4,190 Bytes
2df6a19
4e2ea27
2df6a19
 
 
 
 
 
 
 
4e2ea27
2df6a19
 
 
 
 
 
 
 
 
 
 
 
4e2ea27
2df6a19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e2ea27
2df6a19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e2ea27
2df6a19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e2ea27
2df6a19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e2ea27
2df6a19
 
 
 
 
 
 
 
 
 
 
 
 
4e2ea27
2df6a19
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os

import gradio as gr
import torch
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from pypdf import PdfReader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Function to extract text from PDFs
def extract_text_from_pdf(pdf_file):
    """Return the concatenated text of every page in *pdf_file*.

    Pages with no extractable text are skipped. On any failure an
    "Error reading PDF: ..." string is returned instead of raising, so
    callers can filter failures with a prefix check.
    """
    try:
        reader = PdfReader(pdf_file)
        parts = [
            content + "\n"
            for page in reader.pages
            if (content := page.extract_text())
        ]
        return "".join(parts)
    except Exception as exc:
        return f"Error reading PDF: {exc}"

# Function to process PDFs and create vector store
def process_pdfs(pdf_files):
    """Build a FAISS vector store from the text of the uploaded PDF files.

    Files whose extraction failed (error-string sentinel returned by
    extract_text_from_pdf) are skipped.
    """
    texts = []
    for uploaded in pdf_files:
        content = extract_text_from_pdf(uploaded)
        if content and not content.startswith("Error"):
            texts.append(content)

    # Split each document into overlapping ~1000-character chunks.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=150,
        length_function=len
    )
    pieces = [piece for doc in texts for piece in splitter.split_text(doc)]

    # Embed every chunk and index the vectors with FAISS.
    embedder = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return FAISS.from_texts(pieces, embedder)

# Initialize LLM
def initialize_llm():
    """Load google/flan-t5-base and wrap it as a LangChain LLM.

    Returns:
        HuggingFacePipeline: a text2text-generation pipeline running on
        GPU when CUDA is available, otherwise on CPU.
    """
    model_name = "google/flan-t5-base"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=512,
        # Fix: transformers ignores `temperature` unless sampling is
        # enabled (do_sample defaults to False, i.e. greedy decoding),
        # so the configured 0.7 previously had no effect.
        do_sample=True,
        temperature=0.7,
        device=0 if torch.cuda.is_available() else -1
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    return llm

# Create RAG chain
def create_rag_chain(vector_store, llm):
    """Assemble a RetrievalQA chain over *vector_store* using *llm*.

    NOTE(review): PromptTemplate is referenced here but is not imported
    at the top of this file as written — confirm
    `from langchain.prompts import PromptTemplate` is present, otherwise
    this raises NameError on first call.
    """
    # Grounding prompt: instructs the model to answer only from the
    # retrieved context and to admit when it does not know.
    prompt_template = """Use the following pieces of context to answer the question. If you don't know the answer, say so. Do not make up information.
{context}
Question: {question}
Answer: """
    prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    # NOTE(review): `max_len` is not a documented ConversationBufferMemory
    # parameter — depending on the langchain version it may be silently
    # ignored or rejected; ConversationBufferWindowMemory(k=...) is the
    # windowed variant. Confirm intent.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        input_key="question",
        output_key="answer",
        max_len=4
    )
    # Passing `memory` through chain_type_kwargs attaches it to the inner
    # documents chain rather than the RetrievalQA wrapper — presumably why
    # rag_interface reads it back via chain.combine_documents_chain.memory;
    # verify this path holds for the installed langchain version.
    chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        # The top-5 retrieved chunks are stuffed verbatim into {context}.
        retriever=vector_store.as_retriever(search_kwargs={"k": 5}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt, "memory": memory}
    )
    return chain

# Gradio interface function
def rag_interface(pdf_files, question):
    """Answer *question* over the uploaded PDFs.

    Returns a (answer, formatted chat history) pair of strings. When no
    files were uploaded, returns a prompt-the-user message instead.
    """
    if not pdf_files:
        return "Please upload at least one PDF file.", ""

    # Build the retrieval index from the uploads, then wire up the chain.
    store = process_pdfs(pdf_files)
    chain = create_rag_chain(store, initialize_llm())

    # Run the query and pull the stored conversation back out.
    output = chain({"query": question})
    answer = output["result"]
    messages = chain.combine_documents_chain.memory.chat_memory.messages

    # Messages alternate question/answer; pair even with odd entries,
    # dropping any trailing unpaired message.
    history_text = "".join(
        f"Q: {q.content}\nA: {a.content}\n\n"
        for q, a in zip(messages[0::2], messages[1::2])
    )

    return answer, history_text

# Gradio interface: file upload + question in, answer + history out.
with gr.Blocks() as demo:
    gr.Markdown("# RAG Question Answering System")
    uploads = gr.File(label="Upload PDFs", file_count="multiple", file_types=[".pdf"])
    query_box = gr.Textbox(label="Ask a question")
    answer_box = gr.Textbox(label="Answer")
    history_box = gr.Textbox(label="Chat History")
    ask = gr.Button("Submit")
    ask.click(
        fn=rag_interface,
        inputs=[uploads, query_box],
        outputs=[answer_box, history_box]
    )

demo.launch(share=True)