Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import PyPDF2 | |
| import io | |
| from sentence_transformers import SentenceTransformer | |
| import numpy as np | |
| from typing import List, Tuple, Dict | |
| import os | |
| from groq import Groq | |
| import json | |
# Initialize Groq client (will use API key from environment variable)
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    # App still starts without a key; generate_answer() checks `client is None`
    # and shows an error message in the chat instead of crashing.
    print("Warning: GROQ_API_KEY not found in environment variables. Please set it to use the chatbot.")
    client = None
else:
    client = Groq(api_key=groq_api_key)

# Initialize sentence transformer model
# Loaded eagerly at import time so the first query doesn't pay the model
# download/load cost.
print("Loading sentence transformer model...")
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
print("Model loaded!")

# Global variables to store documents and embeddings
# These three lists are kept index-aligned: entry i of each refers to the
# same text chunk. They are rebuilt by process_pdfs() and emptied by clear_all().
documents_store = []   # chunk texts (list[str])
embeddings_store = []  # chunk embeddings (ndarray after process_pdfs runs)
metadata_store = []  # Store filename and page number for each chunk
def extract_text_from_pdf(pdf_file) -> List[Tuple[str, str, int]]:
    """
    Extract text from a PDF file, one entry per non-empty page.

    Args:
        pdf_file: A path or binary file-like object accepted by PyPDF2.PdfReader.

    Returns:
        List of tuples (text, filename, page_number), 1-indexed pages;
        an empty list if nothing could be extracted or an error occurred.
    """
    text_chunks = []
    try:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        filename = pdf_file.name if hasattr(pdf_file, 'name') else "uploaded_file.pdf"
        for page_num, page in enumerate(pdf_reader.pages, start=1):
            text = page.extract_text()
            # BUG FIX: extract_text() may return None for pages with no
            # extractable text layer (e.g. scanned images); guard before
            # calling .strip() to avoid an AttributeError.
            if text and text.strip():  # Only add non-empty pages
                text_chunks.append((text, filename, page_num))
        return text_chunks
    except Exception as e:
        # Best-effort: log and return an empty list so one corrupt PDF
        # doesn't take down the whole upload batch.
        print(f"Error extracting text from PDF: {e}")
        return []
| def split_into_chunks(text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]: | |
| """ | |
| Split text into overlapping chunks. | |
| """ | |
| words = text.split() | |
| chunks = [] | |
| for i in range(0, len(words), chunk_size - overlap): | |
| chunk = ' '.join(words[i:i + chunk_size]) | |
| if chunk.strip(): | |
| chunks.append(chunk) | |
| return chunks | |
def process_pdfs(pdf_files) -> Tuple[str, str]:
    """
    Process uploaded PDF files and create embeddings.

    Rebuilds the module-level stores (documents/embeddings/metadata) from
    scratch on every call, so a new upload fully replaces the old corpus.

    Args:
        pdf_files: Sequence of uploaded file objects from gr.File, or None.

    Returns:
        (status_message, preview_text)
    """
    global documents_store, embeddings_store, metadata_store

    if pdf_files is None or len(pdf_files) == 0:
        return "No files uploaded.", ""

    # Reset stores before re-ingesting.
    documents_store = []
    embeddings_store = []
    metadata_store = []

    all_text_chunks = []
    preview_text = "=== PDF PREVIEW ===\n\n"

    for pdf_file in pdf_files:
        extracted_chunks = extract_text_from_pdf(pdf_file)
        if not extracted_chunks:
            continue  # unreadable/empty PDF: skip, keep processing the rest

        # BUG FIX: `filename` was computed but never interpolated into the
        # preview header (the f-string showed a literal placeholder).
        filename = extracted_chunks[0][1]
        preview_text += f"📄 File: {filename}\n"
        preview_text += f"  Pages: {len(extracted_chunks)}\n"

        # Show the start of the first page so users can sanity-check the upload.
        first_page_text = extracted_chunks[0][0][:500]  # First 500 chars
        preview_text += f"  Preview (Page 1): {first_page_text}...\n\n"

        # Split each page into smaller overlapping chunks for retrieval.
        for page_text, file_name, page_num in extracted_chunks:
            for chunk in split_into_chunks(page_text):
                all_text_chunks.append((chunk, file_name, page_num))

    if not all_text_chunks:
        return "No text could be extracted from the PDFs.", preview_text

    # Create embeddings for every chunk in one batched call.
    print(f"Creating embeddings for {len(all_text_chunks)} chunks...")
    texts = [chunk[0] for chunk in all_text_chunks]
    embeddings = embedding_model.encode(texts, show_progress_bar=True)

    # Stores are index-aligned: documents_store[i] <-> metadata_store[i].
    documents_store = texts
    embeddings_store = embeddings
    metadata_store = [(chunk[1], chunk[2]) for chunk in all_text_chunks]

    # Generate summary
    total_chunks = len(all_text_chunks)
    unique_files = len(set(chunk[1] for chunk in all_text_chunks))
    preview_text += "\n=== SUMMARY ===\n"
    preview_text += f"Total documents processed: {unique_files}\n"
    preview_text += f"Total text chunks created: {total_chunks}\n"
    preview_text += "Ready for questions!\n"

    return f"✅ Successfully processed {unique_files} PDF file(s) with {total_chunks} chunks!", preview_text
def retrieve_relevant_chunks(query: str, top_k: int = 3) -> List[Tuple[str, str, int, float]]:
    """
    Find the stored chunks most similar to the query by cosine similarity.

    Args:
        query: Natural-language question to match against the corpus.
        top_k: Maximum number of chunks to return.

    Returns:
        List of (chunk_text, filename, page_num, similarity_score),
        ordered from most to least similar; empty if no documents are loaded.
    """
    if len(documents_store) == 0:
        return []

    # Embed the query with the same model used for the document chunks.
    query_vec = embedding_model.encode([query])[0]

    # Cosine similarity = dot product scaled by both vector norms.
    chunk_norms = np.linalg.norm(embeddings_store, axis=1)
    scores = np.dot(embeddings_store, query_vec) / (chunk_norms * np.linalg.norm(query_vec))

    # Indices of the highest-scoring chunks, best first.
    best_indices = np.argsort(scores)[::-1][:top_k]

    return [
        (
            documents_store[idx],
            metadata_store[idx][0],   # filename
            metadata_store[idx][1],   # page number
            float(scores[idx]),
        )
        for idx in best_indices
    ]
def convert_history_to_gradio_format(history):
    """Convert chat history into Gradio's role/content dict format.

    Legacy (user, assistant) tuples are expanded into a user dict followed
    by an assistant dict; items already in dict form pass through untouched.
    Anything else is silently dropped.
    """
    converted = []
    for entry in history or []:
        if isinstance(entry, tuple) and len(entry) == 2:
            user_msg, assistant_msg = entry
            converted.append({"role": "user", "content": user_msg})
            converted.append({"role": "assistant", "content": assistant_msg})
        elif isinstance(entry, dict):
            converted.append(entry)
    return converted
def convert_history_from_gradio_format(history):
    """Convert history from Gradio's role/content dict format to
    (user, assistant) tuple pairs for internal use.

    A user dict immediately followed by an assistant dict becomes one tuple;
    existing tuples pass through; unpaired dicts are dropped.
    """
    if not history:
        return []
    tuple_history = []
    i = 0
    while i < len(history):
        item = history[i]
        if isinstance(item, dict):
            nxt = history[i + 1] if i + 1 < len(history) else None
            # BUG FIX: verify the next item is a dict before calling .get();
            # mixed dict/tuple histories used to raise AttributeError here.
            if item.get("role") == "user" and isinstance(nxt, dict) and nxt.get("role") == "assistant":
                tuple_history.append((item["content"], nxt["content"]))
                i += 2
                continue
        elif isinstance(item, tuple):
            tuple_history.append(item)
        i += 1
    return tuple_history
def generate_answer(question: str, history: List) -> Tuple[str, List]:
    """
    Generate an answer with the Groq LLM using retrieved document context (RAG).

    Args:
        question: The user's question.
        history: Chat history in Gradio messages format (or legacy tuples).

    Returns:
        ("", updated_history) — the empty string clears the input textbox.
    """
    # Convert history from Gradio 6 format to internal (user, assistant) tuples.
    internal_history = convert_history_from_gradio_format(history) if history else []

    # Guard clauses: surface configuration/usage problems in the chat itself.
    if client is None:
        error_msg = "Error: GROQ_API_KEY not configured. Please set it as an environment variable or in Hugging Face Space secrets."
        internal_history.append((question, error_msg))
        return "", convert_history_to_gradio_format(internal_history)

    if len(documents_store) == 0:
        error_msg = "Please upload PDF files first!"
        internal_history.append((question, error_msg))
        return "", convert_history_to_gradio_format(internal_history)

    if not question.strip():
        return "", convert_history_to_gradio_format(internal_history)

    # Retrieve the most relevant chunks for the question.
    relevant_chunks = retrieve_relevant_chunks(question, top_k=3)
    if not relevant_chunks:
        error_msg = "No relevant context found in the documents."
        internal_history.append((question, error_msg))
        return "", convert_history_to_gradio_format(internal_history)

    # Build context with source references.
    # BUG FIX: interpolate the actual filename — it was unpacked from the
    # chunk tuple but never used in the source labels.
    context_parts = []
    sources = []
    for i, (chunk, filename, page_num, score) in enumerate(relevant_chunks, 1):
        context_parts.append(f"[Source {i} - {filename}, Page {page_num}]\n{chunk}")
        sources.append(f"Source {i}: {filename}, Page {page_num} (similarity: {score:.3f})")
    context = "\n\n".join(context_parts)

    # Create prompt for Groq.
    prompt = f"""You are a helpful assistant that answers questions based on the provided context from PDF documents.

Context from documents:
{context}

Question: {question}

Please provide a clear and accurate answer based on the context above. If the context doesn't contain enough information to answer the question, say so. At the end of your answer, mention the source references.

Answer:"""

    try:
        # Call Groq API with a single user message containing the RAG prompt.
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model="llama-3.1-8b-instant",
            temperature=0.7,
            max_tokens=1024,
        )
        answer = chat_completion.choices[0].message.content

        # Append the retrieval sources so users can verify the answer.
        answer += "\n\n📚 Sources:\n" + "\n".join(sources)

        internal_history.append((question, answer))
        return "", convert_history_to_gradio_format(internal_history)
    except Exception as e:
        # Show API failures in the chat rather than crashing the UI callback.
        error_msg = f"Error generating answer: {str(e)}"
        internal_history.append((question, error_msg))
        return "", convert_history_to_gradio_format(internal_history)
def clear_all():
    """Drop every ingested document and blank out the UI widgets.

    Returns:
        (status_text, preview_text, chat_history) — all emptied, matching
        the outputs wired to the "Clear All" button.
    """
    global documents_store, embeddings_store, metadata_store
    documents_store, embeddings_store, metadata_store = [], [], []
    return "", "", []
# Create Gradio interface
# NOTE(review): the emoji in the Markdown strings below appear to be
# mojibake-corrupted (e.g. "π") — TODO restore the intended glyphs.
with gr.Blocks() as demo:
    # Header and feature overview shown at the top of the app.
    gr.Markdown("""
    # π RAG-Based Chatbot with PDF Support
    Upload multiple PDF files, and ask questions based on their content!
    **Features:**
    - π Upload multiple PDF files
    - ποΈ Preview PDF content before asking questions
    - π Semantic search using sentence transformers
    - π¬ Chat with your documents using Groq LLM
    - π Source references with page numbers
    """)
    with gr.Row():
        # Left column: upload, processing status, and document preview.
        with gr.Column(scale=1):
            gr.Markdown("### π€ Upload PDFs")
            pdf_upload = gr.File(
                file_count="multiple",
                file_types=[".pdf"],
                label="Upload PDF Files"
            )
            process_btn = gr.Button("Process PDFs", variant="primary")
            status = gr.Textbox(label="Status", interactive=False)
            gr.Markdown("### ποΈ PDF Preview & Summary")
            preview = gr.Textbox(
                label="Preview",
                lines=15,
                interactive=False,
                placeholder="PDF preview and summary will appear here after processing..."
            )
        # Right column: the chat interface.
        with gr.Column(scale=1):
            gr.Markdown("### π¬ Chat with Your Documents")
            chatbot = gr.Chatbot(
                label="Chat",
                height=400
            )
            question_input = gr.Textbox(
                label="Ask a question",
                placeholder="Type your question here...",
                lines=2
            )
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.Button("Clear Chat")
                clear_all_btn = gr.Button("Clear All", variant="stop")
    # Event handlers
    # Ingest the uploaded PDFs and build embeddings.
    process_btn.click(
        fn=process_pdfs,
        inputs=[pdf_upload],
        outputs=[status, preview]
    )
    # Ask a question via the Submit button or by pressing Enter in the textbox;
    # generate_answer returns ("", history) so the input box is cleared.
    submit_btn.click(
        fn=generate_answer,
        inputs=[question_input, chatbot],
        outputs=[question_input, chatbot]
    )
    question_input.submit(
        fn=generate_answer,
        inputs=[question_input, chatbot],
        outputs=[question_input, chatbot]
    )
    # Clear only the chat transcript, keeping the processed documents.
    clear_btn.click(
        fn=lambda: ("", []),
        outputs=[question_input, chatbot]
    )
    # Clear the chat AND drop the document/embedding stores.
    clear_all_btn.click(
        fn=clear_all,
        outputs=[status, preview, chatbot]
    )
if __name__ == "__main__":
    # BUG FIX: `theme` is not a parameter of Blocks.launch() and raises
    # TypeError at startup; themes must be passed to the gr.Blocks(...)
    # constructor instead, so the stray kwarg is removed here.
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces (required on HF Spaces)
        server_port=7860,       # standard Hugging Face Spaces port
    )