import os
import pickle
import re

import faiss
import gradio as gr
import numpy as np
import PyPDF2
from sentence_transformers import SentenceTransformer
from transformers import pipeline


class SimpleRAG:
    """Minimal retrieval-augmented generation pipeline over a single PDF.

    The PDF text is split into overlapping word chunks, each chunk is embedded
    with a sentence-transformer model and stored in a FAISS inner-product
    index. Questions are answered by retrieving the best-matching chunks and
    passing them as context to a text2text-generation pipeline.
    """

    def __init__(self):
        """Load the models and prepare the on-disk storage directory."""
        print("Loading models...")
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        self.qa_pipeline = pipeline(
            "text2text-generation",
            model="google/flan-t5-base",
            max_length=512,
            temperature=0.7
        )

        # Storage for documents and vector database
        self.documents = []
        self.vector_db = None  # FAISS index
        self.embedding_dimension = 384  # all-MiniLM-L6-v2 dimension
        self.is_ready = False

        # Create directory for persistent storage
        self.db_path = "vector_db"
        os.makedirs(self.db_path, exist_ok=True)
        print("Models loaded successfully!")

    # -----------------------------------
    # Extract text from uploaded PDF file
    def extract_text_from_pdf(self, pdf_file):
        """Return the text of every page of *pdf_file*, one page per line block.

        Args:
            pdf_file: Anything ``PyPDF2.PdfReader`` accepts (the UI passes a
                file path).

        Returns:
            The concatenated page texts, each followed by a newline.
        """
        reader = PyPDF2.PdfReader(pdf_file)
        # A page without a text layer may yield None/empty; treat it as "".
        return "".join((page.extract_text() or "") + "\n" for page in reader.pages)

    # -----------------------------------
    # Chunk text into smaller overlapping chunks
    def chunk_text(self, text, chunk_size=500, overlap=50):
        """Split *text* into chunks of at most *chunk_size* words.

        Consecutive chunks share *overlap* words. Whitespace runs are collapsed
        to single spaces first.

        Args:
            text: The raw text to split.
            chunk_size: Maximum number of words per chunk (must be > 0).
            overlap: Number of words shared between neighbouring chunks
                (must satisfy ``0 <= overlap < chunk_size``).

        Returns:
            A list of chunk strings; empty if *text* contains no words.

        Raises:
            ValueError: If *chunk_size* / *overlap* would give a non-positive
                step between chunks.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        if not 0 <= overlap < chunk_size:
            raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

        # Clean the text
        words = re.sub(r'\s+', ' ', text).strip().split()
        step = chunk_size - overlap

        chunks = []
        for start in range(0, len(words), step):
            chunks.append(' '.join(words[start:start + chunk_size]))
            # Once a window reaches the end of the text, any further window
            # would only repeat words already contained in this chunk.
            if start + chunk_size >= len(words):
                break
        return chunks

    # -----------------------------------
    # Create FAISS vector database from embeddings
    # FAISS: Facebook AI Similarity Search
    def create_vector_database(self, embeddings):
        """Build the FAISS index from a 2-D array of chunk embeddings.

        Rows are L2-normalised so that the inner-product index scores by
        cosine similarity. The result is stored in ``self.vector_db``.

        Args:
            embeddings: Array-like of shape ``(n_chunks, dim)``.

        Raises:
            ValueError: If *embeddings* is not a non-empty 2-D array.
        """
        embeddings = np.asarray(embeddings, dtype='float32')
        if embeddings.ndim != 2 or embeddings.shape[0] == 0:
            raise ValueError("embeddings must be a non-empty 2-D array")

        # Normalize embeddings for cosine similarity; a zero vector is left
        # as zero instead of turning into NaN through division by zero.
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        embeddings_normalized = np.ascontiguousarray(embeddings / norms, dtype='float32')

        # Size the index from the data rather than a hard-coded constant so a
        # different embedding model keeps working.
        self.embedding_dimension = int(embeddings_normalized.shape[1])
        self.vector_db = faiss.IndexFlatIP(self.embedding_dimension)
        self.vector_db.add(embeddings_normalized)
        print(f"Vector database created with {self.vector_db.ntotal} vectors")

    # -----------------------------------
    # Save vector database and documents to disk
    def save_vector_database(self, filename="vector_db"):
        """Write the FAISS index and the chunk list under ``self.db_path``.

        Args:
            filename: Base name; produces ``<filename>.index`` and
                ``<filename>_docs.pkl``.

        Raises:
            RuntimeError: If no index has been built yet.
        """
        if self.vector_db is None:
            raise RuntimeError("No vector database to save; process a PDF first.")

        # Save FAISS index
        faiss.write_index(self.vector_db, f"{self.db_path}/{filename}.index")

        # Save documents
        with open(f"{self.db_path}/{filename}_docs.pkl", 'wb') as f:
            pickle.dump(self.documents, f)
        print("Vector database saved to disk!")

    # -----------------------------------
    # Load vector database and documents from disk
    def load_vector_database(self, filename="vector_db"):
        """Load a previously saved index and chunk list, if both files exist.

        Args:
            filename: Base name used when the database was saved.

        Returns:
            True if both files were found and loaded, False otherwise.
        """
        index_path = f"{self.db_path}/{filename}.index"
        docs_path = f"{self.db_path}/{filename}_docs.pkl"

        if not (os.path.exists(index_path) and os.path.exists(docs_path)):
            return False

        # Load FAISS index
        self.vector_db = faiss.read_index(index_path)

        # Load documents
        # SECURITY: pickle executes arbitrary code on load. Only load files
        # written by save_vector_database(); never point this at untrusted data.
        with open(docs_path, 'rb') as f:
            self.documents = pickle.load(f)

        self.is_ready = True
        print(f"📂 Vector database loaded: {len(self.documents)} documents")
        return True

    # -----------------------------------
    # Process PDF and create vector database
    def process_pdf(self, pdf_file):
        """Extract, chunk, embed and index *pdf_file*; return a status message.

        Failures are reported as a message string rather than raised, because
        the return value is shown in the UI status box.

        Args:
            pdf_file: Path (or file object) of the uploaded PDF, or None.

        Returns:
            A human-readable status string.
        """
        if pdf_file is None:
            return "Please upload a PDF file first."

        try:
            # Extract text
            text = self.extract_text_from_pdf(pdf_file)

            # Chunk the text
            chunks = self.chunk_text(text)
            if not chunks:
                return "No text could be extracted from the PDF."

            # Create embeddings
            print(f"Creating embeddings for {len(chunks)} chunks...")
            embeddings = self.embedding_model.encode(chunks)

            # Create vector database
            self.create_vector_database(embeddings)
            # Publish the chunks only after the index was built successfully so
            # documents and index never refer to different PDFs.
            self.documents = chunks

            # Save to disk
            self.save_vector_database()
        except Exception as exc:  # UI boundary: report instead of crashing
            print(f"Error processing PDF: {exc}")
            return f"Error processing PDF: {exc}"

        self.is_ready = True
        return "PDF processed successfully!"

    # -----------------------------------
    # Retrieve most relevant document chunks using FAISS vector database
    def retrieve_relevant_docs(self, query, top_k=3):
        """Return up to *top_k* chunks most similar to *query*.

        Args:
            query: The question text.
            top_k: Maximum number of chunks to return.

        Returns:
            A list of dicts with keys ``text``, ``score``, ``rank`` and
            ``doc_id``, best match first. Empty if the database is not ready.
        """
        if not self.is_ready or self.vector_db is None:
            return []

        # Never ask for more neighbours than the index holds.
        k = min(top_k, self.vector_db.ntotal, len(self.documents))
        if k <= 0:
            return []

        # Encode and normalize the query
        query_embedding = np.asarray(self.embedding_model.encode([query]), dtype='float32')
        query_norm = np.linalg.norm(query_embedding, axis=1, keepdims=True)
        query_norm[query_norm == 0] = 1.0
        query_normalized = np.ascontiguousarray(query_embedding / query_norm, dtype='float32')

        # Search in vector database
        scores, indices = self.vector_db.search(query_normalized, k)

        relevant_docs = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads missing results with -1, which would otherwise index
            # the last document; accept only real positions.
            if 0 <= idx < len(self.documents):
                relevant_docs.append({
                    'text': self.documents[idx],
                    'score': float(score),
                    'rank': len(relevant_docs) + 1,
                    'doc_id': int(idx)
                })
        return relevant_docs

    # -----------------------------------
    # Generate answer using the QA model
    def generate_answer(self, query, context):
        """Ask the generation pipeline to answer *query* given *context*.

        Args:
            query: The user's question.
            context: Retrieved text to ground the answer.

        Returns:
            The ``generated_text`` of the pipeline's first result.
        """
        # Create a prompt for the model
        prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"

        # Generate answer
        result = self.qa_pipeline(prompt)
        return result[0]['generated_text']

    # -----------------------------------
    # Main chat function that handles the RAG pipeline
    def chat(self, message, history):
        """Answer *message* and return the updated chat history.

        Args:
            message: The user's question.
            history: List of ``[user, assistant]`` pairs (may be None/empty).

        Returns:
            A new history list with the new exchange appended.
        """
        # Work on a copy so every branch returns a fresh list consistently.
        history = list(history) if history else []

        if not self.is_ready:
            return history + [[message, "Please upload and process a PDF file first."]]

        if not message or not message.strip():
            return history + [[message, "Please enter a question."]]

        # STEP 1: RETRIEVE - Find relevant documents using vector database
        relevant_docs = self.retrieve_relevant_docs(message, top_k=3)

        if not relevant_docs:
            response = "I couldn't find relevant information in the document to answer your question."
        else:
            print(f"Retrieved {len(relevant_docs)} relevant chunks")

            # STEP 2: AUGMENT - Combine relevant documents as context
            context = "\n\n".join(doc['text'] for doc in relevant_docs)

            # Limit context length to avoid model limits
            if len(context) > 2000:
                context = context[:2000] + "..."

            # STEP 3: GENERATE - Create answer using retrieved context
            print("Generating answer...")
            response = self.generate_answer(message, context)

            # Add source information with similarity scores
            response += "\n\n **Retrieved Sources:**"
            for doc in relevant_docs:
                response += f"\n• Chunk #{doc['doc_id']} (similarity: {doc['score']:.3f})"

        # Update history
        return history + [[message, response]]


# Create an instance of the RAG class
rag_system = SimpleRAG()

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # PDF upload section
            pdf_input = gr.File(
                label="Upload PDF Document",
                file_types=[".pdf"],
                type="filepath"
            )
            process_btn = gr.Button("Process & Build Vector DB", variant="primary", size="lg")
            status_output = gr.Textbox(
                label="Processing Status",
                interactive=False,
                max_lines=10,
                show_label=True
            )

        with gr.Column(scale=2):
            # Chat section
            chatbot = gr.Chatbot(
                label="RAG Conversation",
                height=150,
                show_label=True,
                bubble_full_width=False
            )
            with gr.Row():
                msg_input = gr.Textbox(
                    label="Ask a question about your document...",
                    scale=4,
                    show_label=False
                )
                send_btn = gr.Button("Ask", variant="primary", scale=1)
            with gr.Row():
                clear_btn = gr.Button("Clear Chat", variant="secondary")

    # Event handlers
    process_btn.click(
        fn=rag_system.process_pdf,
        inputs=[pdf_input],
        outputs=[status_output]
    )

    def chat_wrapper(message, history):
        """Run the RAG chat and clear the input box (second output)."""
        return rag_system.chat(message, history), ""

    send_btn.click(
        fn=chat_wrapper,
        inputs=[msg_input, chatbot],
        outputs=[chatbot, msg_input]
    )
    msg_input.submit(
        fn=chat_wrapper,
        inputs=[msg_input, chatbot],
        outputs=[chatbot, msg_input]
    )
    clear_btn.click(
        fn=lambda: [],
        outputs=[chatbot]
    )

demo.launch(
    share=True,
    debug=True,
    show_error=True
)