Spaces:
Sleeping
Sleeping
| import os | |
| import tempfile | |
| import gradio as gr | |
| import numpy as np | |
| import faiss | |
| from sentence_transformers import SentenceTransformer | |
| import google.generativeai as genai | |
| import fitz # PyMuPDF | |
| import traceback | |
# Initialize embedding model — embeds both document chunks and user queries.
sbert_model = SentenceTransformer('all-MiniLM-L6-v2')
# Data storage: module-level retrieval state, rebuilt each time a PDF is processed.
chunks = []  # text chunks from the most recently processed PDF
faiss_index = None  # FAISS index over the chunk embeddings; None until a PDF is processed
embedding_dimension = 384  # all-MiniLM-L6-v2 embedding dimension
def extract_text_from_pdf(pdf_file_path, start_page=None, end_page=None):
    """Extract text from a PDF, optionally restricted to a 1-based page range.

    An out-of-bounds or inverted range silently falls back to the whole
    document. Returns a tuple ``(text, total_page_count)``.
    """
    doc = fitz.open(pdf_file_path)
    total_pages = doc.page_count
    # Default to every page; narrow only when a valid 1-based range is given.
    page_indices = range(total_pages)
    if start_page is not None and end_page is not None:
        first = start_page - 1
        last = end_page - 1
        if 0 <= first <= last < total_pages:
            page_indices = range(first, last + 1)
    pieces = [doc.load_page(idx).get_text() for idx in page_indices]
    doc.close()
    return "".join(pieces), total_pages
def chunk_text(text, chunk_size=1000, overlap=200):
    """Split *text* into overlapping chunks.

    Windows of ``chunk_size`` characters are taken every
    ``chunk_size - overlap`` characters; fragments of 100 characters or
    fewer (typically the tail) are discarded.
    """
    step = chunk_size - overlap
    windows = (text[pos:pos + chunk_size] for pos in range(0, len(text), step))
    return [piece for piece in windows if len(piece) > 100]
def create_faiss_index(embeddings):
    """Build an inner-product FAISS index over *embeddings*.

    The embeddings are L2-normalized IN PLACE so that inner-product search
    is equivalent to cosine similarity.

    Args:
        embeddings: float32 ndarray of shape (n_chunks, dim); mutated in place.

    Returns:
        A populated ``faiss.IndexFlatIP`` ready for ``search``.
    """
    # Normalize in place so inner product == cosine similarity.
    faiss.normalize_L2(embeddings)
    # Fix: derive the dimension from the data instead of trusting the module
    # constant — a different embedding model can no longer cause a silent
    # dimension mismatch (the original also declared `global` for a value it
    # only read).
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    return index
def process_pdf(pdf_file_obj, api_key):
    """Process an uploaded PDF: extract text, chunk it, embed, and index.

    On success the module-level ``chunks`` and ``faiss_index`` are replaced;
    on error they are reset to empty. Returns ``(None, chat_history)`` where
    the history carries a single ``["System", message]`` status entry for the
    UI (the ``None`` clears the message textbox output).
    """
    global chunks, faiss_index
    if not api_key:
        return None, [["System", "⚠️ Please set your Gemini API key first."]]
    if pdf_file_obj is None:
        return None, [["System", "📄 Please upload a PDF file."]]
    tmp_path = None
    try:
        # Copy the uploaded file into a private temp file before parsing.
        with open(pdf_file_obj.name, "rb") as f_in:
            pdf_bytes = f_in.read()
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
            tmp.write(pdf_bytes)
            tmp_path = tmp.name
        # Extract text
        text, total_pages = extract_text_from_pdf(tmp_path)
        if not text.strip():
            return None, [["System", "⚠️ No text found in the PDF. Please try a different file."]]
        # Create chunks
        current_chunks = chunk_text(text)
        if not current_chunks:
            return None, [["System", "⚠️ Could not create text chunks from the PDF."]]
        # Embed and index into locals first so a failure never leaves the
        # module state half-updated (it is reset in the except handler).
        current_embeddings = np.array(sbert_model.encode(current_chunks), dtype=np.float32)
        current_index = create_faiss_index(current_embeddings)
        chunks = current_chunks
        faiss_index = current_index
        pdf_name = os.path.basename(pdf_file_obj.name)
        success_msg = f"✅ Successfully processed '{pdf_name}' ({total_pages} pages, {len(chunks)} chunks). FAISS index created! You can now ask questions!"
        return None, [["System", success_msg]]
    except Exception as e:
        chunks = []
        faiss_index = None
        return None, [["System", f"❌ Error processing PDF: {str(e)}"]]
    finally:
        # Bug fix: the original deleted the temp file only on the success
        # path, leaking it on every early return and on errors.
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
def retrieve_relevant_chunks(query, top_k=3):
    """Return up to *top_k* chunks most similar to *query* via FAISS search.

    Returns an empty list when no document has been indexed yet or when the
    search fails (the error is printed, not raised).
    """
    global chunks, faiss_index
    if not chunks or faiss_index is None:
        return []
    try:
        # Encode and normalize the query so the inner-product index scores
        # cosine similarity (index vectors were normalized at build time).
        query_embedding = np.array(sbert_model.encode([query]), dtype=np.float32)
        faiss.normalize_L2(query_embedding)
        scores, indices = faiss_index.search(query_embedding, top_k)
        # Bug fix: FAISS pads results with -1 when fewer than top_k vectors
        # exist; the original `idx < len(chunks)` check let -1 through and
        # silently returned the LAST chunk. Require a non-negative index.
        return [chunks[idx] for idx in indices[0] if 0 <= idx < len(chunks)]
    except Exception as e:
        print(f"Error in FAISS search: {str(e)}")
        return []
def chat_fn(message, history, api_key):
    """Handle one chat turn: retrieve document context via FAISS, ask Gemini.

    Args:
        message: the user's question from the textbox.
        history: Gradio chat history as [user, bot] pairs.
        api_key: Gemini API key captured in the gr.State.

    Returns:
        (updated history, "") — the empty string clears the input textbox.
    """
    if not message.strip():
        return history, ""
    # Add user message to history with a placeholder reply filled in below.
    history = history + [[message, None]]
    if not api_key:
        history[-1][1] = "β οΈ Please set your Gemini API key first."
        return history, ""
    if not chunks or faiss_index is None:
        history[-1][1] = "π Please upload and process a PDF document first."
        return history, ""
    try:
        # Configure Gemini with the user-supplied key on every call.
        genai.configure(api_key=api_key)
        # Get relevant context using FAISS (5 chunks here vs. the helper's default 3).
        context_chunks = retrieve_relevant_chunks(message, top_k=5)
        if not context_chunks:
            history[-1][1] = "β Could not find relevant information in the document."
            return history, ""
        # Ground the model's answer in the retrieved chunks only.
        context = "\n\n".join(context_chunks)
        prompt = f"""Based on the following context from the document, answer the user's question.
Context:
{context}
Question: {message}
Please provide a clear, accurate answer based only on the information in the context. If the context doesn't contain enough information to answer the question, say so."""
        model = genai.GenerativeModel('gemini-1.5-flash-latest')
        response = model.generate_content(prompt)
        history[-1][1] = response.text
    except Exception as e:
        history[-1][1] = f"β Error: {str(e)}"
    return history, ""
# Custom CSS for better chat appearance: center/limit the app width and
# round the chat message bubbles.
css = """
.gradio-container {
    max-width: 800px !important;
    margin: auto !important;
}
.chat-message {
    padding: 10px !important;
    margin: 5px 0 !important;
    border-radius: 10px !important;
}
"""
# Build the Gradio UI; `demo` is launched from the __main__ guard below.
with gr.Blocks(css=css, title="π Chat with Your PDF") as demo:
    # Mirrors the API-key textbox so event handlers receive the current key.
    api_key_state = gr.State("")
    gr.Markdown("""
    # π Chat with Your PDF (FAISS-Powered)
    Upload a PDF document and chat with it naturally. Now with FAISS for faster vector search!
    """)
    with gr.Row():
        with gr.Column(scale=2):
            api_key_input = gr.Textbox(
                label="π Gemini API Key",
                type="password",
                placeholder="Enter your API key here..."
            )
        with gr.Column(scale=1):
            pdf_input = gr.File(
                label="π Upload PDF",
                file_types=['.pdf']
            )
    # Chat interface
    chatbot = gr.Chatbot(
        label="π¬ Chat",
        height=500,
        show_label=False,
        bubble_full_width=False
    )
    msg_input = gr.Textbox(
        label="Message",
        placeholder="Ask anything about your PDF...",
        show_label=False,
        container=False
    )
    with gr.Row():
        submit_btn = gr.Button("Send π¬", variant="primary")
        clear_btn = gr.Button("Clear Chat ποΈ")
    # Event handlers
    def update_api_key(key):
        # Copy the textbox value into the shared State component on change.
        return key
    api_key_input.change(
        fn=update_api_key,
        inputs=api_key_input,
        outputs=api_key_state
    )
    # Process the PDF immediately on upload; posts a status line to the chat.
    pdf_input.upload(
        fn=process_pdf,
        inputs=[pdf_input, api_key_state],
        outputs=[msg_input, chatbot]
    )
    # Both the Send button and pressing Enter submit the message.
    submit_btn.click(
        fn=chat_fn,
        inputs=[msg_input, chatbot, api_key_state],
        outputs=[chatbot, msg_input]
    )
    msg_input.submit(
        fn=chat_fn,
        inputs=[msg_input, chatbot, api_key_state],
        outputs=[chatbot, msg_input]
    )
    # Reset both the chat history and the input textbox.
    clear_btn.click(
        fn=lambda: ([], ""),
        outputs=[chatbot, msg_input]
    )
if __name__ == "__main__":
    # share=True additionally exposes a temporary public Gradio link.
    demo.launch(share=True)