# rag-pdf-chatbot / app.py
# (Hugging Face Space page artifacts, commented out so the file parses:)
# Ansnaeem's picture
# Update app.py
# a3382b8 verified
import gradio as gr
import PyPDF2
import io
from sentence_transformers import SentenceTransformer
import numpy as np
from typing import List, Tuple, Dict
import os
from groq import Groq
import json
# Initialize Groq client (will use API key from environment variable)
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    # No key: keep the app importable; generate_answer() checks `client`
    # and reports the missing key to the user at question time.
    print("Warning: GROQ_API_KEY not found in environment variables. Please set it to use the chatbot.")
    client = None
else:
    client = Groq(api_key=groq_api_key)

# Initialize sentence transformer model, used for both document chunks
# and queries (embeddings must come from the same model to be comparable).
print("Loading sentence transformer model...")
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
print("Model loaded!")

# Global variables to store documents and embeddings.
# Rebuilt by process_pdfs(), read by retrieve_relevant_chunks(),
# cleared by clear_all(). Single-process, in-memory state only.
documents_store = []   # chunk texts (list of str)
embeddings_store = []  # chunk embeddings (array after process_pdfs)
metadata_store = []  # Store filename and page number for each chunk
def extract_text_from_pdf(pdf_file) -> List[Tuple[str, str, int]]:
    """
    Extract text from a PDF file, one entry per non-empty page.

    Args:
        pdf_file: A file path or binary file-like object readable by
            PyPDF2.PdfReader. If it has a ``.name`` attribute, that is
            used as the filename in the returned metadata.

    Returns:
        List of (page_text, filename, page_number) tuples with 1-based
        page numbers; an empty list if extraction fails entirely.
    """
    text_chunks = []
    try:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        filename = pdf_file.name if hasattr(pdf_file, 'name') else "uploaded_file.pdf"
        for page_num, page in enumerate(pdf_reader.pages, start=1):
            # BUG FIX: extract_text() may return None (e.g. image-only
            # pages); normalize to "" so .strip() below cannot raise.
            text = page.extract_text() or ""
            if text.strip():  # Only add non-empty pages
                text_chunks.append((text, filename, page_num))
        return text_chunks
    except Exception as e:
        # Best-effort: a corrupt or unreadable PDF yields an empty result
        # instead of crashing the whole upload.
        print(f"Error extracting text from PDF: {e}")
        return []
def split_into_chunks(text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
    """
    Split text into overlapping word-based chunks.

    Args:
        text: Source text; tokenized by whitespace.
        chunk_size: Maximum number of words per chunk.
        overlap: Number of words shared between consecutive chunks.

    Returns:
        List of chunk strings (empty for blank input).

    Raises:
        ValueError: If ``overlap >= chunk_size`` — the window could never
            advance (previously this surfaced as a cryptic
            "range() arg 3 must not be zero" or a silent empty result).
    """
    if chunk_size <= overlap:
        raise ValueError("chunk_size must be greater than overlap")
    words = text.split()
    chunks = []
    step = chunk_size - overlap  # loop-invariant stride, hoisted
    for i in range(0, len(words), step):
        chunk = ' '.join(words[i:i + chunk_size])
        if chunk.strip():
            chunks.append(chunk)
    return chunks
def process_pdfs(pdf_files) -> Tuple[str, str]:
    """
    Process uploaded PDF files and build the in-memory embedding index.

    Resets the module-level stores, extracts per-page text from each
    uploaded PDF, splits it into overlapping chunks, and encodes every
    chunk with the sentence-transformer model in one batched call.

    Args:
        pdf_files: Sequence of uploaded file objects, or None.

    Returns:
        (status_message, preview_text) strings for the UI.
    """
    global documents_store, embeddings_store, metadata_store
    if pdf_files is None or len(pdf_files) == 0:
        return "No files uploaded.", ""
    documents_store = []
    embeddings_store = []
    metadata_store = []
    all_text_chunks = []
    preview_text = "=== PDF PREVIEW ===\n\n"
    for pdf_file in pdf_files:
        extracted_chunks = extract_text_from_pdf(pdf_file)
        if not extracted_chunks:
            continue
        filename = extracted_chunks[0][1]
        # BUG FIX: this previously printed the literal "(unknown)" instead
        # of the actual filename (leaving `filename` unused).
        preview_text += f"πŸ“„ File: {filename}\n"
        preview_text += f" Pages: {len(extracted_chunks)}\n"
        # First-page preview, truncated to 500 characters
        first_page_text = extracted_chunks[0][0][:500]
        preview_text += f" Preview (Page 1): {first_page_text}...\n\n"
        # Split each page into smaller chunks, keeping per-page metadata
        for page_text, file_name, page_num in extracted_chunks:
            for chunk in split_into_chunks(page_text):
                all_text_chunks.append((chunk, file_name, page_num))
    if not all_text_chunks:
        return "No text could be extracted from the PDFs.", preview_text
    # Create embeddings
    print(f"Creating embeddings for {len(all_text_chunks)} chunks...")
    texts = [chunk[0] for chunk in all_text_chunks]
    embeddings = embedding_model.encode(texts, show_progress_bar=True)
    documents_store = texts
    embeddings_store = embeddings
    metadata_store = [(chunk[1], chunk[2]) for chunk in all_text_chunks]
    # Generate summary
    total_chunks = len(all_text_chunks)
    unique_files = len(set(chunk[1] for chunk in all_text_chunks))
    preview_text += f"\n=== SUMMARY ===\n"
    preview_text += f"Total documents processed: {unique_files}\n"
    preview_text += f"Total text chunks created: {total_chunks}\n"
    preview_text += f"Ready for questions!\n"
    return f"βœ… Successfully processed {unique_files} PDF file(s) with {total_chunks} chunks!", preview_text
def retrieve_relevant_chunks(query: str, top_k: int = 3) -> List[Tuple[str, str, int, float]]:
    """
    Find the top_k indexed chunks most similar to the query.

    Uses cosine similarity between the query embedding and every stored
    chunk embedding.

    Args:
        query: Natural-language question.
        top_k: Maximum number of chunks to return.

    Returns:
        List of (chunk_text, filename, page_num, similarity_score),
        best match first; empty when no documents are indexed.
    """
    if not documents_store:
        return []
    # Embed the query with the same model used for the documents
    query_vec = embedding_model.encode([query])[0]
    # Cosine similarity of the query against every stored embedding
    doc_norms = np.linalg.norm(embeddings_store, axis=1)
    scores = np.dot(embeddings_store, query_vec) / (doc_norms * np.linalg.norm(query_vec))
    # Indices of the top_k highest scores, best first
    best_indices = np.argsort(scores)[::-1][:top_k]
    return [
        (documents_store[idx], metadata_store[idx][0], metadata_store[idx][1], float(scores[idx]))
        for idx in best_indices
    ]
def convert_history_to_gradio_format(history):
    """Convert chat history into the Gradio 6 messages format.

    Each (user_msg, assistant_msg) tuple becomes a pair of role/content
    dicts; entries already in dict form pass through unchanged.
    """
    if not history:
        return []
    converted = []
    for entry in history:
        if isinstance(entry, dict):
            # Already a role/content message
            converted.append(entry)
        elif isinstance(entry, tuple) and len(entry) == 2:
            user_msg, assistant_msg = entry
            converted.append({"role": "user", "content": user_msg})
            converted.append({"role": "assistant", "content": assistant_msg})
    return converted
def convert_history_from_gradio_format(history):
    """Convert Gradio 6 messages-format history to (user, assistant) tuples.

    A user message immediately followed by an assistant message becomes
    one tuple; tuple entries pass through as-is; unpaired dict messages
    are dropped.
    """
    if not history:
        return []
    pairs = []
    idx = 0
    total = len(history)
    while idx < total:
        entry = history[idx]
        if (
            isinstance(entry, dict)
            and entry.get("role") == "user"
            and idx + 1 < total
            and history[idx + 1].get("role") == "assistant"
        ):
            pairs.append((entry["content"], history[idx + 1]["content"]))
            idx += 2
        else:
            if isinstance(entry, tuple):
                pairs.append(entry)
            idx += 1
    return pairs
def generate_answer(question: str, history: List) -> Tuple[str, List]:
    """
    Answer a question with the Groq LLM using retrieved PDF context (RAG).

    Args:
        question: The user's question.
        history: Chat history in Gradio 6 messages format (legacy tuple
            entries are tolerated).

    Returns:
        ("", updated_history): the empty string clears the input textbox;
        updated_history is in Gradio 6 messages format. Failures are
        reported as assistant messages rather than raised.
    """
    # Convert history from Gradio 6 format to internal tuple format
    internal_history = convert_history_from_gradio_format(history) if history else []
    if client is None:
        error_msg = "Error: GROQ_API_KEY not configured. Please set it as an environment variable or in Hugging Face Space secrets."
        internal_history.append((question, error_msg))
        return "", convert_history_to_gradio_format(internal_history)
    if len(documents_store) == 0:
        error_msg = "Please upload PDF files first!"
        internal_history.append((question, error_msg))
        return "", convert_history_to_gradio_format(internal_history)
    if not question.strip():
        return "", convert_history_to_gradio_format(internal_history)
    # Retrieve relevant chunks
    relevant_chunks = retrieve_relevant_chunks(question, top_k=3)
    if not relevant_chunks:
        error_msg = "No relevant context found in the documents."
        internal_history.append((question, error_msg))
        return "", convert_history_to_gradio_format(internal_history)
    # Build context with source references.
    # BUG FIX: both strings below previously printed the literal
    # "(unknown)" instead of the chunk's source filename.
    context_parts = []
    sources = []
    for i, (chunk, filename, page_num, score) in enumerate(relevant_chunks, 1):
        context_parts.append(f"[Source {i} - {filename}, Page {page_num}]\n{chunk}")
        sources.append(f"Source {i}: {filename}, Page {page_num} (similarity: {score:.3f})")
    context = "\n\n".join(context_parts)
    # Create prompt for Groq
    prompt = f"""You are a helpful assistant that answers questions based on the provided context from PDF documents.
Context from documents:
{context}
Question: {question}
Please provide a clear and accurate answer based on the context above. If the context doesn't contain enough information to answer the question, say so. At the end of your answer, mention the source references.
Answer:"""
    try:
        # Call Groq API
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            model="llama-3.1-8b-instant",
            temperature=0.7,
            max_tokens=1024
        )
        answer = chat_completion.choices[0].message.content
        # Append sources to answer
        answer += "\n\nπŸ“š Sources:\n" + "\n".join(sources)
        # Update history
        internal_history.append((question, answer))
        return "", convert_history_to_gradio_format(internal_history)
    except Exception as e:
        # Report API failures in-chat instead of raising into Gradio
        error_msg = f"Error generating answer: {str(e)}"
        internal_history.append((question, error_msg))
        return "", convert_history_to_gradio_format(internal_history)
def clear_all():
    """Reset the in-memory document index and clear the UI outputs.

    Returns empty values for the status box, preview box, and chatbot.
    """
    global documents_store, embeddings_store, metadata_store
    documents_store, embeddings_store, metadata_store = [], [], []
    return "", "", []
# Create Gradio interface: left column uploads/previews PDFs, right column
# hosts the chat; handlers below wire the buttons to the functions above.
with gr.Blocks() as demo:
    gr.Markdown("""
# πŸ“š RAG-Based Chatbot with PDF Support
Upload multiple PDF files, and ask questions based on their content!
**Features:**
- πŸ“„ Upload multiple PDF files
- πŸ‘οΈ Preview PDF content before asking questions
- πŸ” Semantic search using sentence transformers
- πŸ’¬ Chat with your documents using Groq LLM
- πŸ“– Source references with page numbers
""")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### πŸ“€ Upload PDFs")
            pdf_upload = gr.File(
                file_count="multiple",
                file_types=[".pdf"],
                label="Upload PDF Files"
            )
            process_btn = gr.Button("Process PDFs", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)
            gr.Markdown("### πŸ‘οΈ PDF Preview & Summary")
            preview = gr.Textbox(
                label="Preview",
                lines=15,
                interactive=False,
                placeholder="PDF preview and summary will appear here after processing..."
            )
        with gr.Column(scale=1):
            gr.Markdown("### πŸ’¬ Chat with Your Documents")
            # NOTE(review): generate_answer() feeds this component
            # messages-format dicts, but no type="messages" is passed here —
            # confirm the installed Gradio version defaults to that format.
            chatbot = gr.Chatbot(
                label="Chat",
                height=400
            )
            question_input = gr.Textbox(
                label="Ask a question",
                placeholder="Type your question here...",
                lines=2
            )
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.Button("Clear Chat")
                clear_all_btn = gr.Button("Clear All", variant="stop")
    # Event handlers
    # Build/rebuild the in-memory index from the uploaded files
    process_btn.click(
        fn=process_pdfs,
        inputs=[pdf_upload],
        outputs=[status, preview]
    )
    # Submit via button or Enter in the textbox; generate_answer returns
    # ("", history) so the input box is cleared on each turn
    submit_btn.click(
        fn=generate_answer,
        inputs=[question_input, chatbot],
        outputs=[question_input, chatbot]
    )
    question_input.submit(
        fn=generate_answer,
        inputs=[question_input, chatbot],
        outputs=[question_input, chatbot]
    )
    # "Clear Chat" wipes only the transcript and input box
    clear_btn.click(
        fn=lambda: ("", []),
        outputs=[question_input, chatbot]
    )
    # "Clear All" also drops the document index (see clear_all)
    clear_all_btn.click(
        fn=clear_all,
        outputs=[status, preview, chatbot]
    )
if __name__ == "__main__":
    # BUG FIX: `theme` is not a parameter of Blocks.launch() — it belongs to
    # the gr.Blocks(...) constructor — and passing it here raises TypeError
    # on current Gradio releases, so it is dropped from the launch call.
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces (container/Spaces)
        server_port=7860        # Hugging Face Spaces' expected port
    )