# RAG-Pdf-App / app.py
# Last change: "Update app.py" by eik-1 (commit 2df6a19, verified)
import os
from functools import lru_cache

import gradio as gr
import torch
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from pypdf import PdfReader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# Function to extract text from PDFs
def extract_text_from_pdf(pdf_file):
    """Return the text of every page of *pdf_file*, one page per line block.

    Args:
        pdf_file: Anything ``pypdf.PdfReader`` accepts (a path or a binary
            file-like object); here it is whatever Gradio's File component
            passes through ``rag_interface``.

    Returns:
        The concatenated text of each page that yielded any text, each page
        followed by a newline. Pages whose ``extract_text()`` is empty/None
        are skipped. On any failure the string ``"Error reading PDF: <exc>"``
        is returned instead of raising; ``process_pdfs`` filters on that prefix.
    """
    try:
        reader = PdfReader(pdf_file)
        # Build the result with one join rather than `+=` in a loop, which is
        # quadratic outside CPython's in-place concatenation optimisation.
        # The join runs inside the try, so extraction errors are still caught.
        return "".join(
            f"{page_text}\n"
            for page_text in (page.extract_text() for page in reader.pages)
            if page_text
        )
    except Exception as e:
        # Deliberate best-effort: report the failure to the caller instead of
        # crashing the Gradio handler.
        return f"Error reading PDF: {e}"
# Function to process PDFs and create vector store
@lru_cache(maxsize=1)
def _get_embeddings():
    """Load the MiniLM sentence-transformers embedding model once and reuse it."""
    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


def process_pdfs(pdf_files):
    """Extract, chunk and embed the text of *pdf_files* into a FAISS index.

    Args:
        pdf_files: Iterable of PDFs; each item is passed to
            ``extract_text_from_pdf``.

    Returns:
        A ``FAISS`` vector store over chunks of at most 1000 characters
        (measured with ``len``) overlapping by 150 characters.

    Raises:
        ValueError: if no PDF yields any text (empty input, every read
            failed, or no page produced text), instead of letting
            ``FAISS.from_texts`` fail on an empty list.
    """
    no_text_message = "No extractable text was found in the uploaded PDF file(s)."

    documents = []
    for pdf_file in pdf_files:
        text = extract_text_from_pdf(pdf_file)
        # extract_text_from_pdf reports failures as "Error reading PDF: ..."
        # instead of raising. Match that exact prefix so documents whose real
        # text merely starts with "Error" (e.g. "Errors in ...") are kept.
        if text.strip() and not text.startswith("Error reading PDF: "):
            documents.append(text)
    if not documents:
        raise ValueError(no_text_message)

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=150,
        length_function=len,
    )
    chunks = [chunk for doc in documents for chunk in text_splitter.split_text(doc)]
    if not chunks:
        raise ValueError(no_text_message)

    # The embedding model is cached, so repeated requests don't reload it.
    return FAISS.from_texts(chunks, _get_embeddings())
# Initialize LLM
@lru_cache(maxsize=1)
def initialize_llm():
    """Build the google/flan-t5-base LLM once and return the cached instance.

    The weights are loaded on the first call. Later calls return the same
    ``HuggingFacePipeline``, so each Gradio submission no longer reloads the
    model. A failed load is not cached and will be retried on the next call.

    Returns:
        A LangChain ``HuggingFacePipeline`` wrapping a
        ``text2text-generation`` pipeline (output capped at ``max_length=512``),
        placed on CUDA device 0 when available and on CPU otherwise.
    """
    model_name = "google/flan-t5-base"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=512,
        # temperature is intentionally not set: sampling (do_sample=True) is
        # not enabled, so decoding is greedy and transformers would ignore a
        # temperature value apart from emitting a warning.
        device=0 if torch.cuda.is_available() else -1,
    )
    return HuggingFacePipeline(pipeline=pipe)
# Create RAG chain
def create_rag_chain(vector_store, llm):
    """Build a RetrievalQA chain that answers questions from *vector_store*.

    Args:
        vector_store: Store exposing ``as_retriever`` (a FAISS index in this
            app); the 5 most similar chunks are retrieved per question.
        llm: LangChain LLM that writes the answer.

    Returns:
        A ``RetrievalQA`` chain using the "stuff" strategy (all retrieved
        chunks go into a single prompt). Calling it with
        ``{"query": question}`` returns a dict containing ``"result"`` and,
        because ``return_source_documents=True``, ``"source_documents"``.
        A ``ConversationBufferMemory`` is attached to
        ``chain.combine_documents_chain``; ``rag_interface`` reads the
        conversation from there.
    """
    prompt_template = """Use the following pieces of context to answer the question. If you don't know the answer, say so. Do not make up information.
{context}
Question: {question}
Answer: """
    prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    # The memory is attached to the inner "stuff" chain. That chain receives
    # "question" as an input and returns its answer under "output_text", so
    # output_key must name that key; any other key (e.g. "answer") makes
    # save_context fail with a KeyError after the LLM call.
    # ConversationBufferMemory has no size-limit field (max_len is not one of
    # its options); use ConversationBufferWindowMemory(k=...) to cap history.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        input_key="question",
        output_key="output_text",
    )
    # NOTE(review): 5 chunks of up to 1000 characters may exceed what
    # flan-t5-base handles well (tokenizer model_max_length is 512 tokens —
    # confirm); consider a smaller k or chunk size.
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vector_store.as_retriever(search_kwargs={"k": 5}),
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt, "memory": memory},
    )
# Gradio interface function
def rag_interface(pdf_files, question):
if not pdf_files:
return "Please upload at least one PDF file.", ""
# Process PDFs and create vector store
vector_store = process_pdfs(pdf_files)
# Initialize LLM and RAG chain
llm = initialize_llm()
rag_chain = create_rag_chain(vector_store, llm)
# Get answer
result = rag_chain({"query": question})
answer = result["result"]
chat_history = rag_chain.combine_documents_chain.memory.chat_memory.messages
# Format chat history
history_text = ""
for i in range(0, len(chat_history), 2):
if i + 1 < len(chat_history):
history_text += f"Q: {chat_history[i].content}\nA: {chat_history[i+1].content}\n\n"
return answer, history_text
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# RAG Question Answering System")
    pdf_input = gr.File(label="Upload PDFs", file_count="multiple", file_types=[".pdf"])
    question_input = gr.Textbox(label="Ask a question")
    answer_output = gr.Textbox(label="Answer")
    history_output = gr.Textbox(label="Chat History")
    submit_button = gr.Button("Submit")
    # rag_interface returns (answer, history_text), matching these outputs.
    submit_button.click(
        fn=rag_interface,
        inputs=[pdf_input, question_input],
        outputs=[answer_output, history_output],
    )

# Start the server only when this file is run as a script, so importing the
# module (e.g. to reuse `demo` or the helpers) has no side effects.
if __name__ == "__main__":
    demo.launch(share=True)