Hindi-Rag / app.py
wellwisherofindia's picture
Update app.py
3dc7d4f
raw
history blame
8.3 kB
import os
import tempfile
import gradio as gr
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
import google.generativeai as genai
import fitz # PyMuPDF
import traceback
# Sentence-transformer used to embed both document chunks and user queries.
sbert_model = SentenceTransformer('all-MiniLM-L6-v2')
# Module-level RAG state: written by process_pdf, read by retrieve_relevant_chunks / chat_fn.
chunks = []  # list of overlapping text chunks from the currently loaded PDF
faiss_index = None  # faiss.IndexFlatIP over the chunk embeddings; None until a PDF is processed
embedding_dimension = 384 # all-MiniLM-L6-v2 embedding dimension
def extract_text_from_pdf(pdf_file_path, start_page=None, end_page=None):
    """Extract text from a PDF file, optionally from a 1-based page range.

    Args:
        pdf_file_path: Path to the PDF on disk.
        start_page: First page to read (1-based, inclusive), or None for all.
        end_page: Last page to read (1-based, inclusive), or None for all.

    Returns:
        Tuple of (extracted_text, total_page_count). If the requested range
        is out of bounds or reversed, the whole document is read instead.
    """
    doc = fitz.open(pdf_file_path)
    try:
        num_pages_in_doc = doc.page_count
        # Default to every page; narrow only when a valid 1-based range is given.
        pages_to_process = range(num_pages_in_doc)
        if start_page is not None and end_page is not None:
            start_idx = start_page - 1  # convert to 0-based indices
            end_idx = end_page - 1
            if 0 <= start_idx <= end_idx < num_pages_in_doc:
                pages_to_process = range(start_idx, end_idx + 1)
        # Join per-page text in one pass instead of quadratic `+=` concatenation.
        text = "".join(doc.load_page(i).get_text() for i in pages_to_process)
        return text, num_pages_in_doc
    finally:
        # Release the document handle even if text extraction raises.
        doc.close()
def chunk_text(text, chunk_size=1000, overlap=200):
    """Split text into overlapping chunks.

    Args:
        text: Source text to split.
        chunk_size: Maximum characters per chunk.
        overlap: Characters shared between consecutive chunks.

    Returns:
        List of chunks longer than 100 characters; very short tail
        fragments are dropped as noise.

    Raises:
        ValueError: If the parameters would produce a non-positive stride
            (overlap >= chunk_size), which previously raised an opaque
            range() error or looped incorrectly.
    """
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        raise ValueError("require chunk_size > 0 and 0 <= overlap < chunk_size")
    stride = chunk_size - overlap
    # Walrus keeps a single slice per window while filtering short fragments.
    return [
        chunk
        for i in range(0, len(text), stride)
        if len(chunk := text[i:i + chunk_size]) > 100
    ]
def create_faiss_index(embeddings):
    """Build a FAISS inner-product index over L2-normalized embeddings.

    NOTE: faiss.normalize_L2 modifies `embeddings` in place.

    Args:
        embeddings: float32 numpy array of shape (n_vectors, dim).

    Returns:
        faiss.IndexFlatIP holding the normalized vectors; with normalized
        inputs, inner product is equivalent to cosine similarity.
    """
    # Normalize in place so inner-product search == cosine similarity.
    faiss.normalize_L2(embeddings)
    # Derive the dimension from the data instead of the module-level
    # constant, so swapping the embedding model cannot cause a mismatch.
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings)
    return index
def process_pdf(pdf_file_obj, api_key):
    """Process an uploaded PDF: extract text, chunk, embed, and build the FAISS index.

    Args:
        pdf_file_obj: Gradio file object (exposes a .name path) or None.
        api_key: Gemini API key; processing is refused without one.

    Returns:
        Tuple of (msg_input_value, chat_history): msg_input_value is always
        None; chat_history is a single [sender, message] status entry.

    Side effects:
        On success replaces the module-level `chunks` and `faiss_index`;
        on failure resets both to empty/None.
    """
    global chunks, faiss_index
    if not api_key:
        return None, [["System", "⚠️ Please set your Gemini API key first."]]
    if pdf_file_obj is None:
        return None, [["System", "📄 Please upload a PDF file."]]
    tmp_path = None
    try:
        # Copy the upload to a private temp file so PyMuPDF reads a stable path.
        with open(pdf_file_obj.name, "rb") as f_in:
            pdf_bytes = f_in.read()
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
            tmp.write(pdf_bytes)
            tmp_path = tmp.name
        text, total_pages = extract_text_from_pdf(tmp_path)
        if not text.strip():
            return None, [["System", "⚠️ No text found in the PDF. Please try a different file."]]
        current_chunks = chunk_text(text)
        if not current_chunks:
            return None, [["System", "⚠️ Could not create text chunks from the PDF."]]
        # Embed every chunk and index the vectors for similarity search.
        current_embeddings = np.array(sbert_model.encode(current_chunks), dtype=np.float32)
        current_index = create_faiss_index(current_embeddings)
        # Publish to module state only after every step succeeded.
        chunks = current_chunks
        faiss_index = current_index
        pdf_name = os.path.basename(pdf_file_obj.name)
        success_msg = f"✅ Successfully processed '{pdf_name}' ({total_pages} pages, {len(chunks)} chunks). FAISS index created! You can now ask questions!"
        return None, [["System", success_msg]]
    except Exception as e:
        chunks = []
        faiss_index = None
        traceback.print_exc()  # full stack trace to the server log for debugging
        return None, [["System", f"❌ Error processing PDF: {str(e)}"]]
    finally:
        # Remove the temp copy on every path — the original leaked it on errors.
        if tmp_path and os.path.exists(tmp_path):
            os.unlink(tmp_path)
def retrieve_relevant_chunks(query, top_k=3):
    """Return up to top_k document chunks most similar to `query`.

    Similarity is cosine: stored vectors and the query are both
    L2-normalized, and the index uses inner product.

    Args:
        query: Natural-language question.
        top_k: Maximum number of chunks to return.

    Returns:
        List of chunk strings, best match first; [] when no document is
        loaded or the search fails.
    """
    if not chunks or faiss_index is None:
        return []
    try:
        query_embedding = np.array(sbert_model.encode([query]), dtype=np.float32)
        # Normalize for cosine similarity against the normalized index.
        faiss.normalize_L2(query_embedding)
        scores, indices = faiss_index.search(query_embedding, top_k)
        # FAISS pads `indices` with -1 when fewer than top_k vectors exist;
        # the previous `idx < len(chunks)` check let -1 through, silently
        # returning the LAST chunk via negative indexing. Filter explicitly.
        return [chunks[idx] for idx in indices[0] if 0 <= idx < len(chunks)]
    except Exception as e:
        print(f"Error in FAISS search: {str(e)}")
        return []
def chat_fn(message, history, api_key):
    """Answer a user message with RAG over the processed PDF via Gemini.

    Args:
        message: The user's question.
        history: Chatbot history as a list of [user, assistant] pairs.
        api_key: Gemini API key.

    Returns:
        Tuple of (updated_history, "") — the empty string clears the input box.
    """
    if not message.strip():
        return history, ""
    # Append the user turn; the assistant slot is filled in below.
    history = history + [[message, None]]
    if not api_key:
        history[-1][1] = "⚠️ Please set your Gemini API key first."
        return history, ""
    if not chunks or faiss_index is None:
        history[-1][1] = "📄 Please upload and process a PDF document first."
        return history, ""
    try:
        genai.configure(api_key=api_key)
        # Ground the answer in the most similar chunks from the FAISS index.
        context_chunks = retrieve_relevant_chunks(message, top_k=5)
        if not context_chunks:
            history[-1][1] = "❌ Could not find relevant information in the document."
            return history, ""
        context = "\n\n".join(context_chunks)
        prompt = f"""Based on the following context from the document, answer the user's question.
Context:
{context}
Question: {message}
Please provide a clear, accurate answer based only on the information in the context. If the context doesn't contain enough information to answer the question, say so."""
        model = genai.GenerativeModel('gemini-1.5-flash-latest')
        response = model.generate_content(prompt)
        history[-1][1] = response.text
    except Exception as e:
        history[-1][1] = f"❌ Error: {str(e)}"
    return history, ""
# Custom CSS injected into the Gradio page: centers the app at a max width
# and rounds/pads the chat bubbles.
css = """
.gradio-container {
max-width: 800px !important;
margin: auto !important;
}
.chat-message {
padding: 10px !important;
margin: 5px 0 !important;
border-radius: 10px !important;
}
"""
with gr.Blocks(css=css, title="📚 Chat with Your PDF") as demo:
    # Session state holding the API key so every callback can read it.
    api_key_state = gr.State("")
    gr.Markdown("""
# 📚 Chat with Your PDF (FAISS-Powered)
Upload a PDF document and chat with it naturally. Now with FAISS for faster vector search!
""")
    with gr.Row():
        with gr.Column(scale=2):
            api_key_input = gr.Textbox(
                label="🔑 Gemini API Key",
                type="password",
                placeholder="Enter your API key here..."
            )
        with gr.Column(scale=1):
            pdf_input = gr.File(
                label="📄 Upload PDF",
                file_types=['.pdf']
            )
    # Chat interface
    chatbot = gr.Chatbot(
        label="💬 Chat",
        height=500,
        show_label=False,
        bubble_full_width=False
    )
    msg_input = gr.Textbox(
        label="Message",
        placeholder="Ask anything about your PDF...",
        show_label=False,
        container=False
    )
    with gr.Row():
        submit_btn = gr.Button("Send 💬", variant="primary")
        clear_btn = gr.Button("Clear Chat 🗑️")

    # --- Event wiring ---
    def update_api_key(key):
        # Mirror the textbox value into session state.
        return key

    api_key_input.change(
        fn=update_api_key,
        inputs=api_key_input,
        outputs=api_key_state
    )
    # Processing starts automatically on upload.
    pdf_input.upload(
        fn=process_pdf,
        inputs=[pdf_input, api_key_state],
        outputs=[msg_input, chatbot]
    )
    # Both the Send button and Enter in the textbox submit the message.
    submit_btn.click(
        fn=chat_fn,
        inputs=[msg_input, chatbot, api_key_state],
        outputs=[chatbot, msg_input]
    )
    msg_input.submit(
        fn=chat_fn,
        inputs=[msg_input, chatbot, api_key_state],
        outputs=[chatbot, msg_input]
    )
    clear_btn.click(
        fn=lambda: ([], ""),
        outputs=[chatbot, msg_input]
    )

if __name__ == "__main__":
    # share=True publishes a temporary public Gradio link.
    demo.launch(share=True)