"""Gradio RAG chat application.

Upload a PDF: it is auto-summarized and indexed into a hybrid
(dense + sparse) Elasticsearch retriever; questions are then answered by a
LangGraph search -> decide -> generate workflow backed by a Groq-hosted LLM.
"""

import gradio as gr
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, List, Union, Dict, Any, Annotated
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from hybrid_retriever import build_hybrid_retriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain_core.documents import Document
from groq import Groq
import os
from dotenv import load_dotenv
import tempfile
import time
import logging
from operator import add

load_dotenv()

logger = logging.getLogger(__name__)

# Single model used for every Groq completion in this app.
GROQ_MODEL = "llama-3.1-8b-instant"

# Sentinel passage emitted when retrieval returns nothing; the decide node
# compares against it to know whether another search attempt is warranted.
NO_RESULTS_MSG = "No relevant information found in the knowledge base."

# Fail early (but non-fatally) if the Groq credential is missing.
if not os.getenv("GROQ_API_KEY"):
    print("Warning: GROQ_API_KEY not found in environment variables")


def add_messages(left, right):
    """State reducer for AgentState.messages: concatenate message lists."""
    return left + right


class AgentState(TypedDict):
    """Shared state threaded through the LangGraph workflow."""

    # Conversation transcript; the Annotated reducer concatenates updates.
    messages: Annotated[List[Union[HumanMessage, AIMessage, ToolMessage]], add_messages]
    # The user's question currently being answered.
    query: str
    # Formatted retrieved passages ("[Doc N] ..." strings).
    documents: List[str]
    # Answer text produced by the generate node (or the give-up message).
    final_answer: str
    # Set by the decide node when another retrieval pass should run.
    needs_search: bool
    # Number of retrieval passes performed so far.
    search_count: int
    # Reserved for per-query metrics (timings live on ResponseTimeTracker).
    metrics: Dict[str, Any]


class ResponseTimeTracker:
    """Accumulates coarse timing metrics for the most recent query."""

    def __init__(self):
        self.metrics = {
            "retrieval_time": 0,
            "llm_processing_time": 0,
            "total_time": 0,
        }

    def update_retrieval_metrics(self, retrieval_metrics):
        """Merge retriever-supplied metrics into the current snapshot."""
        self.metrics.update(retrieval_metrics)

    def get_metrics_dict(self):
        """Return the current metrics snapshot."""
        return self.metrics


class CustomAgentExecutor:
    """LangGraph-driven RAG executor: search -> decide -> generate."""

    def __init__(self, retriever):
        self.retriever = retriever
        self.groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
        self.response_tracker = ResponseTimeTracker()
        # Give up (answer "I don't have the knowledge.") after this many
        # retrieval passes that come back empty.
        self.max_searches = 3
        self.workflow = self._create_workflow()

    def _create_workflow(self):
        """Build and compile the search/decide/generate state graph."""
        workflow = StateGraph(AgentState)

        workflow.add_node("search", self._search_node)
        workflow.add_node("generate", self._generate_node)
        workflow.add_node("decide", self._decide_node)

        workflow.add_edge(START, "search")
        workflow.add_edge("search", "decide")
        workflow.add_conditional_edges(
            "decide",
            self._should_continue,
            {"search": "search", "generate": "generate", "end": END},
        )
        workflow.add_edge("generate", END)
        return workflow.compile()

    def _search_node(self, state: AgentState) -> AgentState:
        """Retrieve passages for the query and tag them as [Doc N]."""
        query = state.get("query", "")
        search_count = state.get("search_count", 0)

        retrieval_start = time.time()
        try:
            docs = self.retriever.get_relevant_documents(query)
        except Exception as e:
            logger.error(f"Retrieval error: {e}")
            docs = []
        finally:
            # Record elapsed time whether or not retrieval succeeded
            # (the original duplicated this in both branches).
            self.response_tracker.metrics["retrieval_time"] = time.time() - retrieval_start

        if docs:
            formatted_docs = [
                f"[Doc {i}] {doc.page_content.strip()}"
                for i, doc in enumerate(docs, 1)
            ]
        else:
            formatted_docs = [NO_RESULTS_MSG]

        return {
            **state,
            "documents": formatted_docs,
            "search_count": search_count + 1,
            "needs_search": False,
        }

    def _decide_node(self, state: AgentState) -> AgentState:
        """Decide whether to retry retrieval, give up, or generate an answer."""
        documents = state.get("documents", [])
        search_count = state.get("search_count", 0)

        if not documents or documents == [NO_RESULTS_MSG]:
            if search_count < self.max_searches:
                # NOTE(review): the query is unchanged between passes, so a
                # retry only helps if the retriever is non-deterministic.
                return {**state, "needs_search": True}
            return {
                **state,
                "needs_search": False,
                "final_answer": "I don't have the knowledge.",
            }
        return {**state, "needs_search": False}

    def _generate_node(self, state: AgentState) -> AgentState:
        """Ask the LLM to answer strictly from the retrieved documents."""
        query = state.get("query", "")
        documents = state.get("documents", [])

        doc_context = "\n\n".join(documents)
        system_prompt = (
            "You are a helpful assistant that answers questions based only on the provided documents. "
            "Each passage is tagged with a source like [Doc 1], [Doc 2], etc. "
            "When answering, cite the relevant document(s) using these tags. "
            "You are prohibited from using your past knowledge. "
            "When the answer is not directly explained in the document(s), you MUST answer with 'I don't have the knowledge'."
        )
        user_prompt = f"Context:\n{doc_context}\n\nQuestion: {query}\n\nAnswer:"

        llm_start = time.time()
        try:
            response = self.groq_client.chat.completions.create(
                model=GROQ_MODEL,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
            )
            response_content = response.choices[0].message.content
        except Exception as e:
            logger.error(f"LLM error: {e}", exc_info=True)
            # Surface the failure as the answer, as the original did.
            response_content = f"LLM generation error: {str(e)}"
        finally:
            self.response_tracker.metrics["llm_processing_time"] = time.time() - llm_start

        return {
            **state,
            "final_answer": response_content,
            "messages": state.get("messages", [])
            + [HumanMessage(content=query), AIMessage(content=response_content)],
        }

    def _should_continue(self, state: AgentState) -> str:
        """Route out of the decide node: 'search', 'end', or 'generate'."""
        if state.get("needs_search", False):
            return "search"
        if state.get("final_answer"):
            return "end"
        return "generate"

    def get_last_response_metrics(self) -> Dict[str, Any]:
        """Get the metrics from the last query response."""
        return self.response_tracker.get_metrics_dict()

    def query(self, question: str) -> str:
        """Run the full workflow for `question` and return the answer text."""
        initial_state = {
            "messages": [],
            "query": question,
            "documents": [],
            "final_answer": "",
            "needs_search": False,
            "search_count": 0,
            "metrics": {},
        }

        total_start = time.time()
        try:
            final_state = self.workflow.invoke(initial_state)
            return final_state.get("final_answer", "No answer generated")
        except Exception as e:
            logger.error(f"Query processing error: {e}")
            return f"Error processing query: {str(e)}"
        finally:
            self.response_tracker.metrics["total_time"] = time.time() - total_start


# Module-level RAG state, (re)initialized each time a PDF is processed.
vector_store = None
agent_executor = None


def create_vector_store(pdf_path: str):
    """Index `pdf_path` into the hybrid retriever and build the agent.

    Returns True on success, False on failure (details are logged).
    """
    global vector_store, agent_executor
    try:
        documents = PyMuPDFLoader(pdf_path).load()

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )
        chunks = text_splitter.split_documents(documents)

        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # Plain text copies are needed for the sparse retrieval side.
        texts = [doc.page_content for doc in chunks]

        # Build hybrid retriever using Elasticsearch Cloud.
        hybrid_retriever = build_hybrid_retriever(
            texts=texts,
            index_name="try-rag",
            embedding=embeddings,
            es_url="https://my-elasticsearch-project-c8f88b.es.ap-southeast-1.aws.elastic.cloud:443",
            es_api_key=os.getenv("api_key_es"),
            top_k_dense=5,
            top_k_sparse=5,
        )
        hybrid_retriever.add_documents(chunks)

        vector_store = hybrid_retriever
        agent_executor = CustomAgentExecutor(hybrid_retriever)
        return True
    except Exception as e:
        logger.error(f"Error creating vector store: {e}")
        return False


def get_groq_response(prompt, max_tokens=None, temperature=None, top_p=None):
    """One-shot Groq completion for `prompt`.

    The optional sampling parameters are forwarded to the API only when
    given, so pre-existing `get_groq_response(prompt)` callers keep the
    API defaults they had before.
    """
    client = Groq(api_key=os.getenv("GROQ_API_KEY"))
    params = {}
    if max_tokens is not None:
        params["max_tokens"] = int(max_tokens)
    if temperature is not None:
        params["temperature"] = temperature
    if top_p is not None:
        params["top_p"] = top_p
    completion = client.chat.completions.create(
        model=GROQ_MODEL,
        messages=[{"role": "user", "content": prompt}],
        **params,
    )
    return completion.choices[0].message.content


def summarize_document(pdf_path: str) -> str:
    """Summarize the uploaded document in 3 sentences via the LLM."""
    try:
        documents = PyMuPDFLoader(pdf_path).load()

        # Keep the prompt small: first 1000 chars of the first 5 pages.
        full_text = "\n\n".join([doc.page_content[:1000] for doc in documents[:5]])

        prompt = f"""Summarize the following document in exactly 3 sentences. Include page references where relevant.

Document content:
{full_text}

Write 3 sentences that capture the main points of the document."""
        return get_groq_response(prompt)
    except Exception as e:
        return f"Error summarizing document: {str(e)}"


def _materialize_pdf(pdf_file):
    """Return (path, is_temp) for a Gradio file value.

    Newer Gradio passes a filesystem path (str); older versions pass a
    file-like object, which is copied to a temp file the caller must delete.
    """
    if isinstance(pdf_file, str):
        return pdf_file, False
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
        tmp_file.write(pdf_file.read())
        return tmp_file.name, True


def process_pdf_and_chat_messages(pdf_file, message, history, system_message, max_tokens, temperature, top_p):
    """Answer `message` with the RAG agent, indexing `pdf_file` first if needed."""
    global agent_executor

    if pdf_file is None:
        return "Please upload a PDF file first."

    pdf_path, is_temp = None, False
    try:
        pdf_path, is_temp = _materialize_pdf(pdf_file)

        # Normally the upload handler has already indexed the PDF; this is
        # a fallback for the first question if that handler did not run.
        if agent_executor is None:
            if not create_vector_store(pdf_path):
                return "Error processing PDF for RAG system."

        if agent_executor:
            return agent_executor.query(message)
        return "RAG system not initialized. Please try uploading the PDF again."
    except Exception as e:
        return f"Error processing PDF: {str(e)}"
    finally:
        # Fix: the original leaked the temp copy made for old Gradio versions.
        if is_temp and pdf_path:
            try:
                os.unlink(pdf_path)
            except OSError:
                pass


def respond_messages(message, history, system_message, max_tokens, temperature, top_p):
    """Plain (non-RAG) chat turn; forwards the UI sampling sliders to Groq."""
    prompt = f"{system_message}\n\nUser: {message}"
    # Fix: the slider values were previously accepted but never used.
    return get_groq_response(prompt, max_tokens=max_tokens, temperature=temperature, top_p=top_p)


def auto_summarize_pdf(pdf_file):
    """Upload handler: index the PDF for RAG and post a summary message."""
    global agent_executor

    if pdf_file is None:
        return []

    pdf_path, is_temp = None, False
    try:
        pdf_path, is_temp = _materialize_pdf(pdf_file)

        if not create_vector_store(pdf_path):
            return [{"role": "assistant", "content": "Error processing PDF for RAG system."}]

        summary = summarize_document(pdf_path)
        return [{
            "role": "assistant",
            "content": f"**Document Summary:**\n{summary}\n\n*The document has been processed and is ready for questions using RAG system.*",
        }]
    except Exception as e:
        return [{"role": "assistant", "content": f"Error processing PDF: {str(e)}"}]
    finally:
        # Fix: delete the temp copy made for old Gradio versions.
        if is_temp and pdf_path:
            try:
                os.unlink(pdf_path)
            except OSError:
                pass


# --- Gradio interface -------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Document Summarizer with RAG")
    gr.Markdown("Upload a PDF document to get an automatic summary and ask questions using Retrieval-Augmented Generation (RAG).")

    with gr.Row():
        with gr.Column(scale=1):
            pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
            system_message = gr.Textbox(
                value="You are a helpful assistant for summarizing and finding related information needed.",
                label="System message",
            )
            max_tokens = gr.Slider(minimum=1, maximum=2000, value=512, step=1, label="Max new tokens")
            temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.3, step=0.1, label="Temperature")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.05, label="Top-p (nucleus sampling)")

        with gr.Column(scale=2):
            chatbot = gr.Chatbot(type='messages')
            msg = gr.Textbox(label="Message")
            clear = gr.Button("Clear")

    def user_input(message, history):
        # Echo the user's message into the transcript and clear the textbox.
        return "", history + [{"role": "user", "content": message}]

    def bot_response(history, pdf_file, system_message, max_tokens, temperature, top_p):
        # The last history entry is the just-submitted user message.
        message = history[-1]["content"]
        if pdf_file is not None:
            response = process_pdf_and_chat_messages(pdf_file, message, history[:-1], system_message, max_tokens, temperature, top_p)
        else:
            response = respond_messages(message, history[:-1], system_message, max_tokens, temperature, top_p)
        return history[:-1] + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": response},
        ]

    msg.submit(user_input, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response,
        [chatbot, pdf_upload, system_message, max_tokens, temperature, top_p],
        chatbot,
    )
    clear.click(lambda: None, None, chatbot, queue=False)

    # Auto-summarize and (re)index for RAG whenever a PDF is uploaded.
    pdf_upload.upload(auto_summarize_pdf, [pdf_upload], [chatbot])

if __name__ == "__main__":
    demo.launch(share=True)