# Page-header residue captured from the hosting site (not part of the program):
#   DrElaheJ's picture / Update app.py / 2bb3083 verified
import gradio as gr
import PyPDF2
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import numpy as np
import faiss
import pickle
import os
import re
class SimpleRAG:
    """Minimal retrieval-augmented generation (RAG) pipeline over one PDF.

    The pipeline extracts text from a PDF, splits it into overlapping word
    chunks, embeds the chunks, stores the normalized embeddings in a FAISS
    inner-product index, and answers questions by retrieving the best
    matching chunks and feeding them to a text2text generation model.
    """

    def __init__(self):
        """Load the embedding and generation models and prepare storage."""
        print("Loading models...")
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.qa_pipeline = pipeline(
            "text2text-generation",
            model="google/flan-t5-base",
            max_length=512,
            temperature=0.7
        )
        # Storage for the text chunks and the vector index built from them.
        self.documents = []
        self.vector_db = None  # FAISS index, created by create_vector_database
        # Default for all-MiniLM-L6-v2; refreshed from the real embeddings
        # whenever an index is built.
        self.embedding_dimension = 384
        self.is_ready = False
        # Directory used for persisting the index and the chunks.
        self.db_path = "vector_db"
        os.makedirs(self.db_path, exist_ok=True)
        print("Models loaded successfully!")

    def extract_text_from_pdf(self, pdf_file):
        """Return the text of every page of ``pdf_file``, one page per line.

        Args:
            pdf_file: Path or file object accepted by ``PyPDF2.PdfReader``.

        Returns:
            The concatenated page texts, each followed by a newline.

        Raises:
            Whatever ``PyPDF2`` raises for unreadable or malformed files.
        """
        reader = PyPDF2.PdfReader(pdf_file)
        # Guard against a falsy extract_text() result (e.g. a page without
        # extractable text) so concatenation can never fail on it.
        return "".join(
            (page.extract_text() or "") + "\n" for page in reader.pages
        )

    def chunk_text(self, text, chunk_size=500, overlap=50):
        """Split ``text`` into overlapping chunks of whitespace-separated words.

        Args:
            text: Raw text; any run of whitespace is treated as one separator.
            chunk_size: Maximum number of words per chunk (must be > 0).
            overlap: Number of words shared by consecutive chunks
                (``0 <= overlap < chunk_size``).

        Returns:
            A list of chunk strings; empty when ``text`` contains no words.

        Raises:
            ValueError: If ``chunk_size`` or ``overlap`` is out of range.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        if not 0 <= overlap < chunk_size:
            raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

        # str.split() with no argument already collapses whitespace runs and
        # ignores leading/trailing whitespace.
        words = text.split()
        step = chunk_size - overlap
        chunks = []
        for start in range(0, len(words), step):
            chunks.append(' '.join(words[start:start + chunk_size]))
            # Once a chunk reaches the last word, any further window would
            # contain only words already covered by this chunk.
            if start + chunk_size >= len(words):
                break
        return chunks

    def create_vector_database(self, embeddings):
        """Build the FAISS (Facebook AI Similarity Search) index.

        Rows are L2-normalized so that the inner product used by
        ``IndexFlatIP`` equals cosine similarity.

        Args:
            embeddings: 2-D array-like with one embedding per row.

        Raises:
            ValueError: If ``embeddings`` is not a non-empty 2-D array.
        """
        vectors = np.asarray(embeddings, dtype='float32')
        if vectors.ndim != 2 or vectors.shape[0] == 0:
            raise ValueError("embeddings must be a non-empty 2-D array")

        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        # Leave all-zero rows untouched instead of dividing by zero (NaN).
        norms[norms == 0] = 1.0
        vectors = vectors / norms

        # Size the index from the data rather than a hard-coded constant.
        self.embedding_dimension = vectors.shape[1]
        self.vector_db = faiss.IndexFlatIP(self.embedding_dimension)
        self.vector_db.add(np.ascontiguousarray(vectors, dtype='float32'))
        print(f"Vector database created with {self.vector_db.ntotal} vectors")

    def save_vector_database(self, filename="vector_db"):
        """Persist the FAISS index and the chunk list under ``self.db_path``.

        Args:
            filename: Base name; writes ``<filename>.index`` and
                ``<filename>_docs.pkl``.

        Raises:
            RuntimeError: If no index has been built yet.
        """
        if self.vector_db is None:
            raise RuntimeError("No vector database to save; build one first.")
        faiss.write_index(
            self.vector_db, os.path.join(self.db_path, f"{filename}.index")
        )
        docs_path = os.path.join(self.db_path, f"{filename}_docs.pkl")
        with open(docs_path, 'wb') as f:
            pickle.dump(self.documents, f)
        print("Vector database saved to disk!")

    def load_vector_database(self, filename="vector_db"):
        """Load a previously saved index and chunk list from disk.

        Args:
            filename: Base name used by ``save_vector_database``.

        Returns:
            True if both files existed and were loaded, otherwise False.
        """
        index_path = os.path.join(self.db_path, f"{filename}.index")
        docs_path = os.path.join(self.db_path, f"{filename}_docs.pkl")
        if not (os.path.exists(index_path) and os.path.exists(docs_path)):
            return False

        # Load into locals first so a failure cannot leave the instance with
        # a new index but the old documents.
        index = faiss.read_index(index_path)
        # SECURITY NOTE: pickle.load executes arbitrary code from the file.
        # Only load files this application wrote itself; never point this at
        # untrusted data.
        with open(docs_path, 'rb') as f:
            documents = pickle.load(f)

        self.vector_db = index
        self.documents = documents
        self.is_ready = True
        print(f"📂 Vector database loaded: {len(self.documents)} documents")
        return True

    def process_pdf(self, pdf_file):
        """Extract, chunk, embed and index a PDF; return a status message.

        Args:
            pdf_file: Path or file object of the uploaded PDF, or None.

        Returns:
            A human-readable status string for the UI.
        """
        if pdf_file is None:
            return "Please upload a PDF file first."

        # Report extraction failures as a status message instead of sniffing
        # the extracted text for an "Error" prefix, which real documents may
        # legitimately start with.
        try:
            text = self.extract_text_from_pdf(pdf_file)
        except Exception as exc:  # UI boundary: surface any reader failure
            return f"Error reading PDF: {exc}"

        chunks = self.chunk_text(text)
        if not chunks:
            return "No text could be extracted from the PDF."

        print(f"Creating embeddings for {len(chunks)} chunks...")
        embeddings = self.embedding_model.encode(chunks)

        # Swap in the new state only after embedding succeeded, and keep the
        # pipeline marked not-ready while index and documents are replaced.
        self.is_ready = False
        self.documents = chunks
        self.create_vector_database(embeddings)
        self.save_vector_database()
        self.is_ready = True
        return "PDF processed successfully!"

    def retrieve_relevant_docs(self, query, top_k=3):
        """Return the chunks most similar to ``query``.

        Args:
            query: The user question.
            top_k: Maximum number of chunks to return.

        Returns:
            A list of dicts with keys ``text``, ``score``, ``rank`` and
            ``doc_id``, best match first; empty if nothing is indexed.
        """
        if not self.is_ready or self.vector_db is None:
            return []

        # Never ask for more neighbours than the index holds.
        k = min(top_k, self.vector_db.ntotal)
        if k <= 0:
            return []

        # Encode and normalize the query the same way as the stored chunks.
        query_embedding = np.asarray(
            self.embedding_model.encode([query]), dtype='float32'
        )
        norms = np.linalg.norm(query_embedding, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        query_normalized = query_embedding / norms

        scores, indices = self.vector_db.search(
            query_normalized.astype('float32'), k
        )

        relevant_docs = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS marks missing results with -1; without the lower bound
            # that value would silently select the last document.
            if 0 <= idx < len(self.documents):
                relevant_docs.append({
                    'text': self.documents[idx],
                    'score': float(score),
                    'rank': len(relevant_docs) + 1,
                    'doc_id': int(idx)
                })
        return relevant_docs

    def generate_answer(self, query, context):
        """Generate an answer to ``query`` from ``context`` with the QA model.

        Args:
            query: The user question.
            context: Retrieved text placed ahead of the question.

        Returns:
            The generated answer text.
        """
        prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
        result = self.qa_pipeline(prompt)
        return result[0]['generated_text']

    def chat(self, message, history):
        """Run one retrieve-augment-generate turn.

        Args:
            message: The user's question.
            history: List of ``[user, bot]`` pairs so far (may be None).

        Returns:
            A new history list with the current turn appended; the caller's
            list is not modified.
        """
        # Work on a copy so every branch behaves the same and the caller's
        # list is never mutated; tolerate a missing history.
        history = list(history or [])

        if not self.is_ready:
            return history + [[message, "Please upload and process a PDF file first."]]
        if not message or not message.strip():
            return history + [[message, "Please enter a question."]]

        # STEP 1: RETRIEVE - find relevant chunks in the vector database.
        relevant_docs = self.retrieve_relevant_docs(message, top_k=3)
        if not relevant_docs:
            response = "I couldn't find relevant information in the document to answer your question."
        else:
            print(f"Retrieved {len(relevant_docs)} relevant chunks")
            # STEP 2: AUGMENT - combine the retrieved chunks into one context.
            context = "\n\n".join(doc['text'] for doc in relevant_docs)
            # Limit context length to avoid model limits.
            if len(context) > 2000:
                context = context[:2000] + "..."
            # STEP 3: GENERATE - answer using the retrieved context.
            print("Generating answer...")
            response = self.generate_answer(message, context)
            # Append source information with similarity scores.
            response += "\n\n **Retrieved Sources:**"
            for doc in relevant_docs:
                response += f"\n• Chunk #{doc['doc_id']} (similarity: {doc['score']:.3f})"

        history.append([message, response])
        return history
# Single shared RAG pipeline used by every UI callback below.
rag_system = SimpleRAG()


def chat_wrapper(message, history):
    """Run one chat turn; the second return value clears the input box."""
    updated_history = rag_system.chat(message, history)
    return updated_history, ""


def _reset_chat():
    """Return an empty conversation for the chatbot component."""
    return []


# Gradio interface: upload/processing controls on the left, chat on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # PDF upload and indexing controls
            pdf_input = gr.File(
                label="Upload PDF Document",
                file_types=[".pdf"],
                type="filepath"
            )
            process_btn = gr.Button(
                "Process & Build Vector DB",
                variant="primary",
                size="lg"
            )
            status_output = gr.Textbox(
                label="Processing Status",
                interactive=False,
                max_lines=10,
                show_label=True
            )
        with gr.Column(scale=2):
            # Conversation view and question entry
            chatbot = gr.Chatbot(
                label="RAG Conversation",
                height=150,
                show_label=True,
                bubble_full_width=False
            )
            with gr.Row():
                msg_input = gr.Textbox(
                    label="Ask a question about your document...",
                    scale=4,
                    show_label=False
                )
                send_btn = gr.Button("Ask", variant="primary", scale=1)
            with gr.Row():
                clear_btn = gr.Button("Clear Chat", variant="secondary")

    # Event wiring
    process_btn.click(
        fn=rag_system.process_pdf,
        inputs=[pdf_input],
        outputs=[status_output]
    )
    # The Ask button and pressing Enter in the textbox trigger the same turn.
    for register_event in (send_btn.click, msg_input.submit):
        register_event(
            fn=chat_wrapper,
            inputs=[msg_input, chatbot],
            outputs=[chatbot, msg_input]
        )
    clear_btn.click(
        fn=_reset_chat,
        outputs=[chatbot]
    )

demo.launch(
    share=True,
    debug=True,
    show_error=True
)