"""RAG-based chatbot over uploaded PDFs.

Pipeline: extract text from each PDF page -> split pages into overlapping
word chunks -> embed chunks with a sentence-transformers model -> retrieve
by cosine similarity -> answer questions with a Groq-hosted LLM, citing
file names and page numbers as sources. A Gradio Blocks UI ties it together.
"""

import gradio as gr
import PyPDF2
import io
from sentence_transformers import SentenceTransformer
import numpy as np
from typing import List, Tuple, Dict
import os
from groq import Groq
import json

# Initialize Groq client (will use API key from environment variable).
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    print("Warning: GROQ_API_KEY not found in environment variables. Please set it to use the chatbot.")
    client = None
else:
    client = Groq(api_key=groq_api_key)

# Initialize sentence transformer model (loaded once at import time).
print("Loading sentence transformer model...")
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
print("Model loaded!")

# Global in-memory stores. These are parallel sequences: index i in each
# describes the same chunk.
documents_store: List[str] = []                 # chunk texts
embeddings_store = []                           # chunk embeddings (ndarray after processing)
metadata_store: List[Tuple[str, int]] = []      # (filename, page_number) per chunk


def extract_text_from_pdf(pdf_file) -> List[Tuple[str, str, int]]:
    """Extract per-page text from a PDF file.

    Args:
        pdf_file: A file path or file-like object accepted by PyPDF2.

    Returns:
        List of tuples (text, filename, page_number); empty on failure.
    """
    text_chunks = []
    try:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        filename = pdf_file.name if hasattr(pdf_file, 'name') else "uploaded_file.pdf"
        for page_num, page in enumerate(pdf_reader.pages, start=1):
            # extract_text() may return None for pages with no extractable
            # text in some PyPDF2 versions; normalize to "" before strip().
            text = page.extract_text() or ""
            if text.strip():  # Only add non-empty pages
                text_chunks.append((text, filename, page_num))
        return text_chunks
    except Exception as e:
        # Best-effort: a corrupt/unreadable PDF is reported and skipped.
        print(f"Error extracting text from PDF: {e}")
        return []


def split_into_chunks(text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
    """Split text into overlapping word chunks.

    Args:
        text: Source text.
        chunk_size: Maximum number of words per chunk.
        overlap: Number of words shared between consecutive chunks.

    Returns:
        List of chunk strings (possibly empty for blank input).
    """
    words = text.split()
    # Clamp the stride to at least 1 so overlap >= chunk_size cannot
    # produce a zero/negative range step (which would raise ValueError).
    step = max(chunk_size - overlap, 1)
    chunks = []
    for i in range(0, len(words), step):
        chunk = ' '.join(words[i:i + chunk_size])
        if chunk.strip():
            chunks.append(chunk)
    return chunks


def process_pdfs(pdf_files) -> Tuple[str, str]:
    """Process uploaded PDF files and build the embedding index.

    Resets the global stores, extracts/chunks all pages, embeds the chunks,
    and produces a human-readable preview of what was indexed.

    Returns:
        (status_message, preview_text)
    """
    global documents_store, embeddings_store, metadata_store

    if pdf_files is None or len(pdf_files) == 0:
        return "No files uploaded.", ""

    documents_store = []
    embeddings_store = []
    metadata_store = []

    all_text_chunks = []
    preview_text = "=== PDF PREVIEW ===\n\n"

    for pdf_file in pdf_files:
        extracted_chunks = extract_text_from_pdf(pdf_file)
        if not extracted_chunks:
            continue

        filename = extracted_chunks[0][1]
        # Show the actual file name in the preview.
        preview_text += f"šŸ“„ File: {filename}\n"
        preview_text += f" Pages: {len(extracted_chunks)}\n"

        # Get first page preview
        if extracted_chunks:
            first_page_text = extracted_chunks[0][0][:500]  # First 500 chars
            preview_text += f" Preview (Page 1): {first_page_text}...\n\n"

        # Split each page into smaller chunks
        for page_text, file_name, page_num in extracted_chunks:
            chunks = split_into_chunks(page_text)
            for chunk in chunks:
                all_text_chunks.append((chunk, file_name, page_num))

    if not all_text_chunks:
        return "No text could be extracted from the PDFs.", preview_text

    # Create embeddings for every chunk in one batch.
    print(f"Creating embeddings for {len(all_text_chunks)} chunks...")
    texts = [chunk[0] for chunk in all_text_chunks]
    embeddings = embedding_model.encode(texts, show_progress_bar=True)

    documents_store = texts
    embeddings_store = embeddings
    metadata_store = [(chunk[1], chunk[2]) for chunk in all_text_chunks]

    # Generate summary
    total_chunks = len(all_text_chunks)
    unique_files = len(set(chunk[1] for chunk in all_text_chunks))

    preview_text += "\n=== SUMMARY ===\n"
    preview_text += f"Total documents processed: {unique_files}\n"
    preview_text += f"Total text chunks created: {total_chunks}\n"
    preview_text += "Ready for questions!\n"

    return f"āœ… Successfully processed {unique_files} PDF file(s) with {total_chunks} chunks!", preview_text


def retrieve_relevant_chunks(query: str, top_k: int = 3) -> List[Tuple[str, str, int, float]]:
    """Retrieve the top-k most relevant chunks via cosine similarity.

    Args:
        query: Natural-language question.
        top_k: Number of chunks to return.

    Returns:
        List of (chunk_text, filename, page_num, similarity_score),
        best match first; empty if nothing has been indexed yet.
    """
    if len(documents_store) == 0:
        return []

    # Encode query
    query_embedding = embedding_model.encode([query])[0]

    # Cosine similarity against all stored embeddings. A tiny epsilon in
    # the denominator guards against a pathological zero-norm vector.
    denom = np.linalg.norm(embeddings_store, axis=1) * np.linalg.norm(query_embedding)
    similarities = np.dot(embeddings_store, query_embedding) / np.maximum(denom, 1e-12)

    # Get top-k indices, highest similarity first.
    top_indices = np.argsort(similarities)[::-1][:top_k]

    # Return top chunks with their metadata.
    results = []
    for idx in top_indices:
        results.append((
            documents_store[idx],
            metadata_store[idx][0],
            metadata_store[idx][1],
            float(similarities[idx])
        ))
    return results


def convert_history_to_gradio_format(history):
    """Convert history from tuple format to Gradio 6 message-dict format."""
    if not history:
        return []
    gradio_history = []
    for item in history:
        if isinstance(item, tuple) and len(item) == 2:
            # Convert tuple (user_msg, assistant_msg) to dict format
            gradio_history.append({"role": "user", "content": item[0]})
            gradio_history.append({"role": "assistant", "content": item[1]})
        elif isinstance(item, dict):
            # Already in correct format
            gradio_history.append(item)
    return gradio_history


def convert_history_from_gradio_format(history):
    """Convert history from Gradio 6 message-dict format to tuple format."""
    if not history:
        return []
    tuple_history = []
    i = 0
    while i < len(history):
        if isinstance(history[i], dict):
            # Pair a user message with the assistant message that follows it.
            if history[i].get("role") == "user" and i + 1 < len(history):
                if history[i + 1].get("role") == "assistant":
                    tuple_history.append((history[i]["content"], history[i + 1]["content"]))
                    i += 2
                    continue
        elif isinstance(history[i], tuple):
            tuple_history.append(history[i])
        i += 1
    return tuple_history


def generate_answer(question: str, history: List) -> Tuple[str, List]:
    """Answer a question with RAG context via the Groq LLM.

    Args:
        question: User question.
        history: Chat history in Gradio message-dict (or legacy tuple) format.

    Returns:
        ("", updated_history) — the empty string clears the input textbox.
    """
    # Convert history from Gradio 6 format to internal format
    internal_history = convert_history_from_gradio_format(history) if history else []

    if client is None:
        error_msg = "Error: GROQ_API_KEY not configured. Please set it as an environment variable or in Hugging Face Space secrets."
        internal_history.append((question, error_msg))
        return "", convert_history_to_gradio_format(internal_history)

    if len(documents_store) == 0:
        error_msg = "Please upload PDF files first!"
        internal_history.append((question, error_msg))
        return "", convert_history_to_gradio_format(internal_history)

    if not question.strip():
        return "", convert_history_to_gradio_format(internal_history)

    # Retrieve relevant chunks
    relevant_chunks = retrieve_relevant_chunks(question, top_k=3)
    if not relevant_chunks:
        error_msg = "No relevant context found in the documents."
        internal_history.append((question, error_msg))
        return "", convert_history_to_gradio_format(internal_history)

    # Build context with source references (use the real filename).
    context_parts = []
    sources = []
    for i, (chunk, filename, page_num, score) in enumerate(relevant_chunks, 1):
        context_parts.append(f"[Source {i} - {filename}, Page {page_num}]\n{chunk}")
        sources.append(f"Source {i}: {filename}, Page {page_num} (similarity: {score:.3f})")

    context = "\n\n".join(context_parts)

    # Create prompt for Groq
    prompt = f"""You are a helpful assistant that answers questions based on the provided context from PDF documents.

Context from documents:
{context}

Question: {question}

Please provide a clear and accurate answer based on the context above. If the context doesn't contain enough information to answer the question, say so. At the end of your answer, mention the source references.

Answer:"""

    try:
        # Call Groq API
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            model="llama-3.1-8b-instant",
            temperature=0.7,
            max_tokens=1024
        )
        answer = chat_completion.choices[0].message.content

        # Append sources to answer
        answer += "\n\nšŸ“š Sources:\n" + "\n".join(sources)

        # Update history
        internal_history.append((question, answer))
        return "", convert_history_to_gradio_format(internal_history)
    except Exception as e:
        error_msg = f"Error generating answer: {str(e)}"
        internal_history.append((question, error_msg))
        return "", convert_history_to_gradio_format(internal_history)


def clear_all():
    """Clear all stored documents, embeddings, and metadata."""
    global documents_store, embeddings_store, metadata_store
    documents_store = []
    embeddings_store = []
    metadata_store = []
    return "", "", []


# Create Gradio interface. NOTE: `theme` belongs on gr.Blocks(), not on
# demo.launch() — launch() has no `theme` parameter.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # šŸ“š RAG-Based Chatbot with PDF Support

    Upload multiple PDF files, and ask questions based on their content!

    **Features:**
    - šŸ“„ Upload multiple PDF files
    - šŸ‘ļø Preview PDF content before asking questions
    - šŸ” Semantic search using sentence transformers
    - šŸ’¬ Chat with your documents using Groq LLM
    - šŸ“– Source references with page numbers
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### šŸ“¤ Upload PDFs")
            pdf_upload = gr.File(
                file_count="multiple",
                file_types=[".pdf"],
                label="Upload PDF Files"
            )
            process_btn = gr.Button("Process PDFs", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)

            gr.Markdown("### šŸ‘ļø PDF Preview & Summary")
            preview = gr.Textbox(
                label="Preview",
                lines=15,
                interactive=False,
                placeholder="PDF preview and summary will appear here after processing..."
            )

        with gr.Column(scale=1):
            gr.Markdown("### šŸ’¬ Chat with Your Documents")
            chatbot = gr.Chatbot(
                label="Chat",
                height=400
            )
            question_input = gr.Textbox(
                label="Ask a question",
                placeholder="Type your question here...",
                lines=2
            )
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.Button("Clear Chat")
                clear_all_btn = gr.Button("Clear All", variant="stop")

    # Event handlers
    process_btn.click(
        fn=process_pdfs,
        inputs=[pdf_upload],
        outputs=[status, preview]
    )

    submit_btn.click(
        fn=generate_answer,
        inputs=[question_input, chatbot],
        outputs=[question_input, chatbot]
    )

    question_input.submit(
        fn=generate_answer,
        inputs=[question_input, chatbot],
        outputs=[question_input, chatbot]
    )

    clear_btn.click(
        fn=lambda: ("", []),
        outputs=[question_input, chatbot]
    )

    clear_all_btn.click(
        fn=clear_all,
        outputs=[status, preview, chatbot]
    )


if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860
    )