"""Enhanced RAG chatbot.

Upload PDF/DOCX files, embed their text with a sentence-transformer, and
answer questions via a Groq-hosted LLM using the most similar chunks as
context. Built as a Gradio Blocks app.
"""

import json
import os
from datetime import datetime

import docx
import gradio as gr
import numpy as np
import PyPDF2
from groq import Groq
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity


def _create_groq_client():
    """Build a Groq client from GROQ_API_KEY; return None if no key is set."""
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        return None
    # Local import: httpx is only needed when a key is actually configured.
    import httpx
    return Groq(api_key=api_key, http_client=httpx.Client())


# Initialize Groq client (best-effort: app still loads without a key).
client = None
try:
    client = _create_groq_client()
    if client is not None:
        print("Groq client initialized successfully")
except Exception as e:
    print(f"Error initializing Groq client: {e}")

# Initialize sentence transformer model
print("Loading sentence transformer model...")
embedder = SentenceTransformer('all-MiniLM-L6-v2')
print("Model loaded successfully!")

# Global storage shared by the processing and chat callbacks.
document_store = {
    'chunks': [],                # list[str] of retrievable text chunks
    'embeddings': [],            # chunk embeddings, parallel to 'chunks'
    'metadata': [],              # per-chunk source info (filename/page/chunk_id)
    'conversation_history': []   # query/answer log, downloadable as JSON
}


def extract_text_from_pdf(pdf_file):
    """Extract text from PDF file.

    Accepts a path string or a Gradio file object exposing ``.name``.
    Returns one ``{'text', 'page', 'filename'}`` dict per non-empty page,
    or ``[]`` on any read error.
    """
    try:
        path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
        pdf_reader = PyPDF2.PdfReader(path)
        filename = os.path.basename(path)

        text_data = []
        for page_num, page in enumerate(pdf_reader.pages):
            text = page.extract_text()
            if text and text.strip():
                text_data.append({
                    'text': text,
                    'page': page_num + 1,  # 1-based page numbers for display
                    'filename': filename
                })
        return text_data
    except Exception as e:
        print(f"Error reading PDF: {e}")
        return []


def extract_text_from_docx(docx_file):
    """Extract text from DOCX file.

    Accepts a path string or a Gradio file object exposing ``.name``.
    DOCX has no page concept here, so the whole document is returned as a
    single entry with ``page == 1``; ``[]`` on any read error.
    """
    try:
        path = docx_file if isinstance(docx_file, str) else docx_file.name
        doc = docx.Document(path)
        filename = os.path.basename(path)

        text = '\n'.join(p.text for p in doc.paragraphs if p.text.strip())
        return [{'text': text, 'page': 1, 'filename': filename}]
    except Exception as e:
        print(f"Error reading DOCX: {e}")
        return []


def chunk_text(text_data, chunk_size=500, overlap=50):
    """Split extracted text into overlapping word-window chunks.

    Parameters:
        text_data: list of ``{'text', 'page', 'filename'}`` dicts.
        chunk_size: window size in words.
        overlap: words shared between consecutive windows.

    Returns ``(chunks, metadata)`` — parallel lists of chunk strings and
    their ``{'page', 'filename', 'chunk_id'}`` source records.
    """
    chunks = []
    metadata = []
    # Guard against a non-positive stride if overlap >= chunk_size,
    # which would otherwise make range() loop forever or not at all.
    step = max(1, chunk_size - overlap)
    for data in text_data:
        words = data['text'].split()
        for i in range(0, len(words), step):
            chunk = ' '.join(words[i:i + chunk_size])
            # Skip tiny fragments that carry no useful retrieval context.
            if len(chunk.strip()) > 50:
                chunks.append(chunk)
                metadata.append({
                    'page': data['page'],
                    'filename': data['filename'],
                    'chunk_id': len(chunks)  # 1-based id of this chunk
                })
    return chunks, metadata


def process_files(files):
    """Process uploaded files: extract, chunk, embed, and store them.

    Resets ``document_store`` on every call, so each upload replaces the
    previous document set. Returns a Markdown status string.
    """
    global document_store

    if not files:
        return "[ERROR] Please upload at least one file."

    try:
        document_store = {
            'chunks': [],
            'embeddings': [],
            'metadata': [],
            'conversation_history': []
        }

        all_text_data = []
        file_summaries = []

        for file in files:
            file_path = file.name if hasattr(file, 'name') else file
            file_ext = os.path.splitext(file_path)[1].lower()
            print(f"Processing file: {file_path}")

            if file_ext == '.pdf':
                text_data = extract_text_from_pdf(file)
            elif file_ext == '.docx':
                text_data = extract_text_from_docx(file)
            else:
                # Unsupported extensions are silently skipped.
                continue

            all_text_data.extend(text_data)
            total_chars = sum(len(d['text']) for d in text_data)
            filename = os.path.basename(file_path)
            # BUGFIX: summary previously printed the literal "(unknown)"
            # instead of the computed filename.
            file_summaries.append(
                f"- **{filename}**: {len(text_data)} pages, {total_chars} characters"
            )

        if not all_text_data:
            return "[ERROR] No valid text extracted."

        chunks, metadata = chunk_text(all_text_data)
        if not chunks:
            return "[ERROR] No text chunks created."

        embeddings = embedder.encode(chunks, show_progress_bar=False)
        document_store['chunks'] = chunks
        document_store['embeddings'] = embeddings
        document_store['metadata'] = metadata

        summary = f"**Successfully Processed {len(files)} file(s)**\n\n"
        summary += "\n".join(file_summaries)
        summary += f"\n\n**Created {len(chunks)} text chunks for retrieval.**"
        return summary
    except Exception as e:
        print(f"Error processing files: {e}")
        return f"[ERROR] {str(e)}"


def retrieve_relevant_chunks(query, top_k=3):
    """Return the ``top_k`` most similar chunks (and their metadata) to ``query``.

    Similarity is cosine similarity over sentence-transformer embeddings.
    Returns ``([], [])`` when no documents are loaded or on error.
    """
    if not document_store['chunks']:
        return [], []
    try:
        query_embedding = embedder.encode([query], show_progress_bar=False)
        similarities = cosine_similarity(
            query_embedding, document_store['embeddings']
        )[0]
        # Highest-similarity indices first.
        top_indices = np.argsort(similarities)[-top_k:][::-1]

        relevant_chunks = [document_store['chunks'][i] for i in top_indices]
        relevant_metadata = [document_store['metadata'][i] for i in top_indices]
        return relevant_chunks, relevant_metadata
    except Exception as e:
        print(f"Error retrieving chunks: {e}")
        return [], []


def chat(message, history):
    """Chat function - returns response string for ChatInterface.

    Retrieves the top-3 relevant chunks, builds a context-grounded prompt
    (including up to the last 3 exchanges from ``history``), calls Groq,
    and appends a Sources footer. Logs each exchange in ``document_store``.
    """
    global client

    # Reinitialize client if needed (e.g. key set after app start).
    if client is None:
        try:
            client = _create_groq_client()
        except Exception:
            # Best-effort: fall through to the error message below.
            pass

    if client is None:
        return "[ERROR] Groq API not initialized. Set GROQ_API_KEY in Settings."
    if not document_store['chunks']:
        return "[WARNING] Please upload and process documents first."

    try:
        # Retrieve context
        relevant_chunks, metadata = retrieve_relevant_chunks(message, top_k=3)
        if not relevant_chunks:
            return "[ERROR] No relevant information found."

        # Build context with inline source attribution per chunk.
        context = "\n\n".join([
            f"[Source: {meta['filename']}, Page {meta['page']}]\n{chunk}"
            for chunk, meta in zip(relevant_chunks, metadata)
        ])

        # Build messages for Groq
        messages = [
            {"role": "system", "content": "You are a helpful assistant that answers questions based on provided context. Be concise and accurate."}
        ]

        # Add history - convert from tuples to message format
        if history:
            for user_msg, assistant_msg in history[-3:]:  # Last 3 exchanges
                messages.append({"role": "user", "content": user_msg})
                messages.append({"role": "assistant", "content": assistant_msg})

        # Add current query
        messages.append({
            "role": "user",
            "content": f"Context:\n{context}\n\nQuestion: {message}"
        })

        # Call Groq
        response = client.chat.completions.create(
            messages=messages,
            model="llama-3.1-8b-instant",
            temperature=0.3,
            max_tokens=1024,
        )
        answer = response.choices[0].message.content

        # Add sources
        sources = "\n\n**Sources:**\n" + "\n".join([
            f"- {m['filename']} (Page {m['page']})" for m in metadata
        ])
        full_answer = answer + sources

        # Log
        document_store['conversation_history'].append({
            'timestamp': datetime.now().isoformat(),
            'query': message,
            'answer': answer
        })
        return full_answer
    except Exception as e:
        print(f"Error: {e}")
        return f"[ERROR] {str(e)}"


def download_history():
    """Write the conversation log to chat_history.json and return its path.

    Returns None when there is no history or the write fails.
    """
    if not document_store['conversation_history']:
        return None
    try:
        with open("chat_history.json", 'w') as f:
            json.dump(document_store['conversation_history'], f, indent=2)
        return "chat_history.json"
    except Exception:
        return None


# Build interface
with gr.Blocks(title="Enhanced RAG Chatbot") as demo:
    gr.Markdown("""
# Enhanced RAG-Based Chatbot
Upload PDF/DOCX files and ask questions!

**Features:** Multiple files, Semantic search, Source references, Chat history
""")

    with gr.Row():
        with gr.Column(scale=1):
            file_upload = gr.File(
                label="Upload Documents (PDF/DOCX)",
                file_count="multiple",
                file_types=[".pdf", ".docx"]
            )
            process_btn = gr.Button("Process Documents", variant="primary")
            process_output = gr.Markdown()

            gr.Markdown("### History")
            download_btn = gr.Button("Download (JSON)")
            download_file = gr.File(label="Download")

        with gr.Column(scale=2):
            # Minimal ChatInterface compatible with Gradio 4.44.1
            chat_interface = gr.ChatInterface(
                fn=chat
            )

    # Process files
    process_btn.click(process_files, [file_upload], [process_output])
    # Download
    download_btn.click(download_history, None, [download_file])

    gr.Markdown("""
---
### How It Works:
1. Upload PDF/DOCX files and click "Process Documents"
2. Ask questions - RAG finds relevant chunks and generates answers
3. Sources are cited with page numbers
""")

if __name__ == "__main__":
    demo.launch()