Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import PyPDF2 | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import pipeline | |
| import numpy as np | |
| import faiss | |
| import pickle | |
| import os | |
| import re | |
class SimpleRAG:
    """Minimal retrieval-augmented QA over a single uploaded PDF.

    Pipeline: PDF text -> overlapping word chunks -> sentence embeddings
    stored in a FAISS inner-product index -> top-k retrieval -> answer
    generation with a text2text model conditioned on the retrieved chunks.
    """

    def __init__(self):
        # Initialize models
        print("Loading models...")
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        # NOTE(review): `temperature` only matters when sampling; without
        # `do_sample=True` decoding is greedy and the value is likely ignored
        # (transformers may warn about it) — confirm which behavior is intended.
        self.qa_pipeline = pipeline(
            "text2text-generation",
            model="google/flan-t5-base",
            max_length=512,
            temperature=0.7,
        )
        # Storage for document chunks and the vector database.
        self.documents = []    # chunk texts; list position == FAISS vector id
        self.vector_db = None  # FAISS index, built by create_vector_database()
        # all-MiniLM-L6-v2 output size; overwritten from the data when an index is built.
        self.embedding_dimension = 384
        self.is_ready = False  # True once an index and its matching chunks are available
        # Create directory for persistent storage
        self.db_path = "vector_db"
        os.makedirs(self.db_path, exist_ok=True)
        print("Models loaded successfully!")

    # -----------------------------------
    @staticmethod
    def _normalize_rows(vectors):
        """Return *vectors* (2-D, one vector per row) as float32 rows of unit L2 norm.

        Zero-norm rows stay all-zero instead of turning into NaN, which would
        otherwise poison FAISS inner-product scores.
        """
        vectors = np.asarray(vectors, dtype=np.float32)
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        norms[norms == 0] = 1.0
        # FAISS expects C-contiguous float32 input.
        return np.ascontiguousarray(vectors / norms)

    # -----------------------------------
    def extract_text_from_pdf(self, pdf_file):
        """Return the concatenated text of every page, each followed by a newline.

        Args:
            pdf_file: anything ``PyPDF2.PdfReader`` accepts (the UI passes a file path).

        Raises:
            Whatever PyPDF2 raises for unreadable files; ``process_pdf``
            converts that into a status message.
        """
        reader = PyPDF2.PdfReader(pdf_file)
        # extract_text() may give None/"" for pages without a text layer
        # (e.g. scanned images); treat those pages as empty instead of crashing.
        return "".join((page.extract_text() or "") + "\n" for page in reader.pages)

    # -----------------------------------
    def chunk_text(self, text, chunk_size=500, overlap=50):
        """Split *text* into word chunks of up to *chunk_size* words.

        Consecutive chunks share *overlap* words. Whitespace is normalized to
        single spaces. Chunking stops once a chunk reaches the end of the text,
        so no trailing chunk is a pure subset of its predecessor.

        Raises:
            ValueError: if ``chunk_size <= 0`` or ``overlap >= chunk_size``
                (the window would never advance).
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        if overlap >= chunk_size:
            raise ValueError("overlap must be smaller than chunk_size")

        words = text.split()  # split() already collapses any whitespace run
        step = chunk_size - overlap
        chunks = []
        for start in range(0, len(words), step):
            chunks.append(' '.join(words[start:start + chunk_size]))
            if start + chunk_size >= len(words):
                break  # this chunk already reaches the last word
        return chunks

    # -----------------------------------
    def create_vector_database(self, embeddings):
        """Build a fresh FAISS inner-product index (FAISS: Facebook AI Similarity Search).

        Rows are L2-normalized first, so inner product equals cosine similarity.

        Args:
            embeddings: 2-D array-like, one embedding per chunk, in the same
                order as ``self.documents``.
        """
        normalized = self._normalize_rows(embeddings)
        # Take the dimension from the data instead of trusting the hard-coded default.
        self.embedding_dimension = normalized.shape[1]
        self.vector_db = faiss.IndexFlatIP(self.embedding_dimension)
        self.vector_db.add(normalized)
        print(f"Vector database created with {self.vector_db.ntotal} vectors")

    # -----------------------------------
    def save_vector_database(self, filename="vector_db"):
        """Write the FAISS index and the chunk list under ``self.db_path``."""
        faiss.write_index(self.vector_db, os.path.join(self.db_path, f"{filename}.index"))
        with open(os.path.join(self.db_path, f"{filename}_docs.pkl"), 'wb') as f:
            pickle.dump(self.documents, f)
        print("Vector database saved to disk!")

    # -----------------------------------
    def load_vector_database(self, filename="vector_db"):
        """Load a previously saved index and chunk list from ``self.db_path``.

        Returns:
            True if both files existed and were loaded (and the instance is
            marked ready), otherwise False.
        """
        index_path = os.path.join(self.db_path, f"{filename}.index")
        docs_path = os.path.join(self.db_path, f"{filename}_docs.pkl")
        if not (os.path.exists(index_path) and os.path.exists(docs_path)):
            return False
        self.vector_db = faiss.read_index(index_path)
        # SECURITY: unpickling can execute code embedded in the file; only load
        # files written by this app itself, never user-supplied ones.
        with open(docs_path, 'rb') as f:
            self.documents = pickle.load(f)
        self.is_ready = True
        print(f"📂 Vector database loaded: {len(self.documents)} documents")
        return True

    # -----------------------------------
    def process_pdf(self, pdf_file):
        """Extract, chunk, embed, index and persist *pdf_file*; return a status string.

        State is only replaced after extraction, chunking and embedding
        succeed. A bad upload therefore leaves the previous document and
        index intact and consistent.
        """
        if pdf_file is None:
            return "Please upload a PDF file first."
        try:
            text = self.extract_text_from_pdf(pdf_file)
        except Exception as e:  # UI boundary: report instead of crashing the handler
            return f"Error reading PDF: {e}"

        chunks = self.chunk_text(text)
        if not chunks:
            return "No text could be extracted from the PDF."

        print(f"Creating embeddings for {len(chunks)} chunks...")
        embeddings = self.embedding_model.encode(chunks)

        # Commit: chunks and index must always describe the same data.
        self.is_ready = False
        self.documents = chunks
        self.create_vector_database(embeddings)
        self.save_vector_database()
        self.is_ready = True
        return "PDF processed successfully!"

    # -----------------------------------
    def retrieve_relevant_docs(self, query, top_k=3):
        """Return up to *top_k* chunks most similar to *query* (cosine similarity).

        Returns:
            List of dicts with keys ``text``, ``score``, ``rank`` (1-based
            position in the FAISS result) and ``doc_id`` (chunk index). The
            list is empty if no index is ready.
        """
        if not self.is_ready or self.vector_db is None:
            return []
        # Never ask for more neighbours than the index holds.
        k = min(top_k, self.vector_db.ntotal)
        if k <= 0:
            return []

        query_normalized = self._normalize_rows(self.embedding_model.encode([query]))
        scores, indices = self.vector_db.search(query_normalized, k)

        relevant_docs = []
        for rank, (score, idx) in enumerate(zip(scores[0], indices[0]), start=1):
            # FAISS pads missing results with id -1; a plain `< len` check would
            # let -1 through and silently return the last chunk.
            if 0 <= idx < len(self.documents):
                relevant_docs.append({
                    'text': self.documents[int(idx)],
                    'score': float(score),
                    'rank': rank,
                    'doc_id': int(idx),
                })
        return relevant_docs

    # -----------------------------------
    def generate_answer(self, query, context):
        """Generate an answer to *query* conditioned on *context* with the QA pipeline."""
        prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
        result = self.qa_pipeline(prompt)
        return result[0]['generated_text']

    # -----------------------------------
    def _answer(self, message, top_k=3, max_context_chars=2000):
        """Run retrieve -> augment -> generate for one question; return the reply text."""
        # STEP 1: RETRIEVE - find relevant chunks in the vector database
        relevant_docs = self.retrieve_relevant_docs(message, top_k=top_k)
        if not relevant_docs:
            return "I couldn't find relevant information in the document to answer your question."
        print(f"Retrieved {len(relevant_docs)} relevant chunks")

        # STEP 2: AUGMENT - combine the chunks into one context string,
        # capped by characters to keep the prompt within the model's input budget.
        context = "\n\n".join(doc['text'] for doc in relevant_docs)
        if len(context) > max_context_chars:
            context = context[:max_context_chars] + "..."

        # STEP 3: GENERATE - answer from the retrieved context
        print("Generating answer...")
        response = self.generate_answer(message, context)

        # Append source information with similarity scores.
        sources = "".join(
            f"\n• Chunk #{doc['doc_id']} (similarity: {doc['score']:.3f})"
            for doc in relevant_docs
        )
        return response + "\n\n **Retrieved Sources:**" + sources

    # -----------------------------------
    def chat(self, message, history):
        """Answer *message* and return a new chat history with the exchange appended.

        Args:
            message: the user's question.
            history: list of ``[user, bot]`` pairs (``None`` is treated as empty).
                The caller's list is never mutated.

        Returns:
            A new list: *history* plus ``[message, response]``.
        """
        history = list(history or [])  # copy so every path behaves the same
        if not self.is_ready:
            response = "Please upload and process a PDF file first."
        elif not message or not message.strip():
            response = "Please enter a question."
        else:
            response = self._answer(message)
        history.append([message, response])
        return history
# Create the RAG instance (loads both models when this module is imported).
# NOTE(review): this single module-level instance appears to back every
# session of the app, so one visitor's upload would replace the document for
# everyone — consider per-session state (gr.State) if that matters.
rag_system = SimpleRAG()

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # PDF upload section; "filepath" makes the handler receive a path string.
            pdf_input = gr.File(
                label="Upload PDF Document",
                file_types=[".pdf"],
                type="filepath",
            )
            process_btn = gr.Button("Process & Build Vector DB", variant="primary", size="lg")
            status_output = gr.Textbox(
                label="Processing Status",
                interactive=False,
                max_lines=10,
                show_label=True,
            )
        with gr.Column(scale=2):
            # Chat section
            # NOTE(review): `bubble_full_width` may be deprecated in newer Gradio
            # releases — confirm against the pinned version.
            chatbot = gr.Chatbot(
                label="RAG Conversation",
                height=150,
                show_label=True,
                bubble_full_width=False,
            )
            with gr.Row():
                msg_input = gr.Textbox(
                    label="Ask a question about your document...",
                    scale=4,
                    show_label=False,
                )
                send_btn = gr.Button("Ask", variant="primary", scale=1)
            with gr.Row():
                clear_btn = gr.Button("Clear Chat", variant="secondary")

    # Event handlers
    process_btn.click(
        fn=rag_system.process_pdf,
        inputs=[pdf_input],
        outputs=[status_output],
    )

    def chat_wrapper(message, history):
        """Return the updated chat history and an empty string to clear the input box."""
        return rag_system.chat(message, history), ""

    # Clicking "Ask" and pressing Enter in the textbox do the same thing.
    send_btn.click(
        fn=chat_wrapper,
        inputs=[msg_input, chatbot],
        outputs=[chatbot, msg_input],
    )
    msg_input.submit(
        fn=chat_wrapper,
        inputs=[msg_input, chatbot],
        outputs=[chatbot, msg_input],
    )
    clear_btn.click(
        fn=lambda: [],
        outputs=[chatbot],
    )

# Launch only when run as a script, so importing this module does not start a server.
if __name__ == "__main__":
    demo.launch(
        share=True,
        debug=True,
        show_error=True,
    )