Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from langgraph.graph import StateGraph, START, END | |
| from typing import TypedDict, List, Union, Dict, Any, Annotated | |
| from langchain_community.document_loaders import PyMuPDFLoader | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from hybrid_retriever import build_hybrid_retriever | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_core.messages import HumanMessage, AIMessage, ToolMessage | |
| from langchain_core.documents import Document | |
| from groq import Groq | |
| import os | |
| from dotenv import load_dotenv | |
| import tempfile | |
| import time | |
| import logging | |
| from operator import add | |
# Load variables from a local .env file into the process environment.
load_dotenv()

# GROQ_API_KEY is handed to every Groq client created in this module, so flag
# its absence (unset or empty) as soon as the module is imported.
if os.getenv("GROQ_API_KEY") in (None, ""):
    print("Warning: GROQ_API_KEY not found in environment variables")
def add_messages(left, right):
    """Combine two message sequences, ``left`` first and ``right`` after it.

    Used as the merge function attached to the ``messages`` field of
    ``AgentState``.
    """
    combined = add(left, right)
    return combined
# Message types allowed in the conversation history.
_ChatMessage = Union[HumanMessage, AIMessage, ToolMessage]


class AgentState(TypedDict):
    """State dictionary passed between the nodes of the LangGraph workflow."""

    # Conversation history; updates are merged through ``add_messages``.
    messages: Annotated[List[_ChatMessage], add_messages]
    # The question being answered.
    query: str
    # Retrieved passages, already formatted as "[Doc N] <text>" strings.
    documents: List[str]
    # Text ultimately returned to the caller; empty until an answer exists.
    final_answer: str
    # True when the decide step wants another retrieval attempt.
    needs_search: bool
    # Number of retrieval attempts made so far.
    search_count: int
    # Free-form metrics slot; initialised to an empty dict by the executor.
    metrics: Dict[str, Any]
class ResponseTimeTracker:
    """Keeps the timing figures of a query in a plain dictionary."""

    def __init__(self):
        # Every stage starts at zero until a value is recorded for it.
        stage_names = ("retrieval_time", "llm_processing_time", "total_time")
        self.metrics = dict.fromkeys(stage_names, 0)

    def update_retrieval_metrics(self, retrieval_metrics):
        """Merge ``retrieval_metrics`` into the stored metrics, overwriting equal keys."""
        self.metrics.update(retrieval_metrics)

    def get_metrics_dict(self):
        """Return the metrics dictionary itself (the live object, not a copy)."""
        return self.metrics
class CustomAgentExecutor:
    """Answer questions over a retriever with a small LangGraph workflow.

    The graph runs ``search -> decide``; ``decide`` then routes back to
    ``search`` (retry), on to ``generate`` (LLM answer) or straight to the
    end (no documents after the last allowed attempt).

    Every node returns only the keys it changes. ``AgentState.messages`` is
    merged by concatenation, so echoing the whole state back from a node
    would append the existing messages to themselves.
    """

    # Placeholder stored in ``documents`` when retrieval yields nothing.
    NO_DOCUMENTS_MESSAGE = "No relevant information found in the knowledge base."
    # Answer used when every search attempt came back empty.
    NO_KNOWLEDGE_ANSWER = "I don't have the knowledge."
    # Groq model used for answer generation unless overridden.
    DEFAULT_MODEL = "llama-3.1-8b-instant"
    # Timing keys that are zeroed at the start of each query.
    _TIMING_KEYS = ("retrieval_time", "llm_processing_time", "total_time")

    def __init__(self, retriever, model=DEFAULT_MODEL, max_searches=3):
        """Store the retriever, create the Groq client and compile the graph.

        Args:
            retriever: Object exposing ``get_relevant_documents(query)``.
            model: Groq model name used by the generate node.
            max_searches: Maximum number of retrieval attempts per query.
        """
        self.retriever = retriever
        self.model = model
        self.groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
        self.response_tracker = ResponseTimeTracker()
        self.max_searches = max_searches
        # Create LangGraph workflow
        self.workflow = self._create_workflow()

    def _create_workflow(self):
        """Build and compile the search/decide/generate graph."""
        workflow = StateGraph(AgentState)

        # Add nodes
        workflow.add_node("search", self._search_node)
        workflow.add_node("generate", self._generate_node)
        workflow.add_node("decide", self._decide_node)

        # Add edges
        workflow.add_edge(START, "search")
        workflow.add_edge("search", "decide")
        workflow.add_conditional_edges(
            "decide",
            self._should_continue,
            {
                "search": "search",
                "generate": "generate",
                "end": END,
            },
        )
        workflow.add_edge("generate", END)
        return workflow.compile()

    def _reset_metrics(self):
        """Zero the timing metrics in place so one query never reports another's figures."""
        metrics = self.response_tracker.metrics
        for key in self._TIMING_KEYS:
            metrics[key] = 0

    def _search_node(self, state: "AgentState") -> Dict[str, Any]:
        """Retrieve documents for the query and format them as ``[Doc N] text``.

        A retrieval failure is logged and treated as "no documents" so the
        decide node can retry. ``retrieval_time`` accumulates across attempts.
        """
        query = state.get("query", "")
        search_count = state.get("search_count", 0)

        started = time.perf_counter()
        try:
            docs = self.retriever.get_relevant_documents(query)
        except Exception as e:
            logging.error(f"Retrieval error: {e}", exc_info=True)
            docs = []
        finally:
            self.response_tracker.metrics["retrieval_time"] += time.perf_counter() - started

        if docs:
            formatted_docs = [
                f"[Doc {i}] {doc.page_content.strip()}"
                for i, doc in enumerate(docs, 1)
            ]
        else:
            formatted_docs = [self.NO_DOCUMENTS_MESSAGE]

        return {
            "documents": formatted_docs,
            "search_count": search_count + 1,
            "needs_search": False,
        }

    def _decide_node(self, state: "AgentState") -> Dict[str, Any]:
        """Decide whether to retry the search, give up, or go on to generation."""
        documents = state.get("documents", [])
        search_count = state.get("search_count", 0)

        if documents and documents != [self.NO_DOCUMENTS_MESSAGE]:
            return {"needs_search": False}
        if search_count < self.max_searches:
            return {"needs_search": True}
        # Out of attempts: a non-empty final answer routes the graph to END.
        return {"needs_search": False, "final_answer": self.NO_KNOWLEDGE_ANSWER}

    def _generate_node(self, state: "AgentState") -> Dict[str, Any]:
        """Ask the LLM to answer the query from the retrieved documents only.

        On an API failure the error text becomes the answer, so the caller
        always receives a string.
        """
        query = state.get("query", "")
        documents = state.get("documents", [])

        # Create prompt with documents
        doc_context = "\n\n".join(documents)
        system_prompt = (
            "You are a helpful assistant that answers questions based only on the provided documents. "
            "Each passage is tagged with a source like [Doc 1], [Doc 2], etc. "
            "When answering, cite the relevant document(s) using these tags. "
            "You are prohibited from using your past knowledge. "
            "When the answer is not directly explained in the document(s), you MUST answer with 'I don't have the knowledge'."
        )
        user_prompt = f"Context:\n{doc_context}\n\nQuestion: {query}\n\nAnswer:"

        started = time.perf_counter()
        try:
            response = self.groq_client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
            )
            # The API may return no content; normalise to an empty string.
            answer = response.choices[0].message.content or ""
        except Exception as e:
            answer = f"LLM generation error: {str(e)}"
            logging.error(f"LLM error: {e}", exc_info=True)
        finally:
            self.response_tracker.metrics["llm_processing_time"] = time.perf_counter() - started

        # Only the new turns are returned; the reducer appends them to the history.
        return {
            "final_answer": answer,
            "messages": [
                HumanMessage(content=query),
                AIMessage(content=answer),
            ],
        }

    def _should_continue(self, state: "AgentState") -> str:
        """Return the name of the edge to follow after the decide node."""
        if state.get("needs_search", False):
            return "search"
        if state.get("final_answer"):
            return "end"
        return "generate"

    def get_last_response_metrics(self) -> Dict[str, Any]:
        """Get the metrics from the last query response."""
        return self.response_tracker.get_metrics_dict()

    def query(self, question: str) -> str:
        """Run the workflow for ``question`` and return the answer text.

        Never raises: a workflow failure is logged and reported as an
        ``"Error processing query: ..."`` string.
        """
        initial_state = {
            "messages": [],
            "query": question,
            "documents": [],
            "final_answer": "",
            "needs_search": False,
            "search_count": 0,
            "metrics": {},
        }

        self._reset_metrics()
        started = time.perf_counter()
        try:
            final_state = self.workflow.invoke(initial_state)
            return final_state.get("final_answer") or "No answer generated"
        except Exception as e:
            logging.error(f"Query processing error: {e}", exc_info=True)
            return f"Error processing query: {str(e)}"
        finally:
            self.response_tracker.metrics["total_time"] = time.perf_counter() - started
# Module-level RAG state; both names stay None until create_vector_store()
# has successfully built a retriever and an executor.
vector_store = agent_executor = None
def create_vector_store(
    pdf_path: str,
    chunk_size: int = 1000,
    chunk_overlap: int = 200,
    index_name: str = "try-rag",
    es_url: str = "https://my-elasticsearch-project-c8f88b.es.ap-southeast-1.aws.elastic.cloud:443",
    top_k_dense: int = 5,
    top_k_sparse: int = 5,
):
    """Index a PDF into the hybrid retriever and install the global executor.

    Args:
        pdf_path: Path of the PDF to load.
        chunk_size: Maximum characters per chunk.
        chunk_overlap: Characters shared by neighbouring chunks.
        index_name: Elasticsearch index handed to ``build_hybrid_retriever``.
        es_url: Elasticsearch endpoint handed to ``build_hybrid_retriever``.
        top_k_dense: Number of dense hits requested from the retriever.
        top_k_sparse: Number of sparse hits requested from the retriever.

    Returns:
        True when the retriever and executor were created and stored in the
        module globals; False on any failure (the globals are left untouched).
    """
    global vector_store, agent_executor
    try:
        # Load PDF documents
        documents = PyMuPDFLoader(pdf_path).load()

        # Split documents into chunks
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
        )
        chunks = text_splitter.split_documents(documents)
        if not chunks:
            # Nothing to index, e.g. a PDF without an extractable text layer.
            logging.error(f"Error creating vector store: no text extracted from {pdf_path}")
            return False

        es_api_key = os.getenv("api_key_es")
        if not es_api_key:
            logging.warning("api_key_es not found in environment variables")

        # Create embeddings
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # Extract texts for sparse retrieval
        texts = [doc.page_content for doc in chunks]

        # Build hybrid retriever using Elasticsearch Cloud
        hybrid_retriever = build_hybrid_retriever(
            texts=texts,
            index_name=index_name,
            embedding=embeddings,
            es_url=es_url,
            es_api_key=es_api_key,
            top_k_dense=top_k_dense,
            top_k_sparse=top_k_sparse,
        )

        # Add documents to the hybrid retriever
        # NOTE(review): the index name is reused across uploads; whether earlier
        # documents are replaced depends on build_hybrid_retriever - confirm.
        hybrid_retriever.add_documents(chunks)

        # Publish the new state only after every step above succeeded.
        vector_store = hybrid_retriever
        agent_executor = CustomAgentExecutor(hybrid_retriever)
        return True
    except Exception as e:
        logging.error(f"Error creating vector store: {e}", exc_info=True)
        return False
def get_groq_response(prompt):
    """Send ``prompt`` to Groq as a single user message and return the reply text."""
    api_key = os.getenv("GROQ_API_KEY")
    user_turn = {"role": "user", "content": prompt}
    reply = Groq(api_key=api_key).chat.completions.create(
        model="llama-3.1-8b-instant",
        messages=[user_turn],
    )
    first_choice = reply.choices[0]
    return first_choice.message.content
def summarize_document(pdf_path: str) -> str:
    """Return a three-sentence LLM summary of the start of a PDF.

    Only the first 1000 characters of each of the first five loaded
    documents are sent. Each excerpt is labelled ``[Page N]`` so the model
    can honour the prompt's request for page references.

    Args:
        pdf_path: Path of the PDF to summarise.

    Returns:
        The summary text, or an ``"Error summarizing document: ..."`` string
        when loading or the LLM call fails or no text could be extracted.
    """
    max_pages = 5
    max_chars_per_page = 1000
    try:
        documents = PyMuPDFLoader(pdf_path).load()

        # assumes the loader yields one Document per page, in page order - TODO confirm
        excerpts = []
        for page_number, doc in enumerate(documents[:max_pages], 1):
            text = doc.page_content[:max_chars_per_page].strip()
            if text:
                excerpts.append(f"[Page {page_number}]\n{text}")

        if not excerpts:
            # Avoid asking the model to summarise an empty prompt.
            return "Error summarizing document: no extractable text found in the PDF."

        full_text = "\n\n".join(excerpts)
        prompt = (
            "Summarize the following document in exactly 3 sentences. "
            "Include page references where relevant. "
            "Each excerpt is labelled with its page number, e.g. [Page 1].\n"
            "Document content:\n"
            f"{full_text}\n"
            "Write 3 sentences that capture the main points of the document."
        )
        return get_groq_response(prompt)
    except Exception as e:
        logging.error(f"Error summarizing document: {e}", exc_info=True)
        return f"Error summarizing document: {str(e)}"
def process_pdf_and_chat_messages(pdf_file, message, history, system_message, max_tokens, temperature, top_p):
    """Answer ``message`` with the RAG executor, building the index first if needed.

    Args:
        pdf_file: Path string, or a readable file object, of the uploaded
            PDF; ``None`` when nothing was uploaded.
        message: The user's question.
        history, system_message, max_tokens, temperature, top_p: Accepted
            for signature compatibility with the chat handlers; not used here.

    Returns:
        The answer text, or a human-readable error/instruction string.
    """
    global agent_executor

    if pdf_file is None:
        return "Please upload a PDF file first."

    temp_path = None
    try:
        # The PDF is only needed when no index exists yet, so a file object is
        # copied to disk only in that case rather than on every chat message.
        if agent_executor is None:
            if isinstance(pdf_file, str):
                pdf_path = pdf_file
            else:
                # For older versions where pdf_file is a file object
                with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
                    temp_path = tmp_file.name
                    tmp_file.write(pdf_file.read())
                pdf_path = temp_path

            if not create_vector_store(pdf_path):
                return "Error processing PDF for RAG system."

        # Use RAG system to answer the question
        if agent_executor:
            return agent_executor.query(message)
        return "RAG system not initialized. Please try uploading the PDF again."
    except Exception as e:
        logging.error(f"Error processing PDF: {e}", exc_info=True)
        return f"Error processing PDF: {str(e)}"
    finally:
        # Remove the temporary copy; it was created with delete=False.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError as cleanup_error:
                logging.warning(f"Could not remove temporary file {temp_path}: {cleanup_error}")
def respond_messages(message, history, system_message, max_tokens, temperature, top_p):
    """Handle chat without a PDF by calling Groq directly.

    The system message is sent with the ``system`` role, earlier user and
    assistant turns are replayed, and the sampling controls are forwarded
    to the API.

    Args:
        message: The new user message.
        history: Earlier turns as ``{"role": ..., "content": ...}`` dicts;
            anything else is skipped.
        system_message: Optional system prompt.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature; clamped to the 0-2 range.
        top_p: Nucleus sampling value; clamped to the 0-1 range.

    Returns:
        The reply text, or an ``"Error generating response: ..."`` string.
    """
    chat_messages = []
    if system_message:
        chat_messages.append({"role": "system", "content": system_message})

    # Forward only plain role/content pairs; UI history entries can carry
    # extra keys that do not belong in an API request.
    for turn in history or []:
        if not isinstance(turn, dict):
            continue
        role = turn.get("role")
        content = turn.get("content")
        if role in ("user", "assistant") and isinstance(content, str) and content:
            chat_messages.append({"role": role, "content": content})

    chat_messages.append({"role": "user", "content": message})

    try:
        client = Groq(api_key=os.getenv("GROQ_API_KEY"))
        completion = client.chat.completions.create(
            model="llama-3.1-8b-instant",
            messages=chat_messages,
            max_tokens=max(1, int(max_tokens)),
            temperature=min(max(float(temperature), 0.0), 2.0),
            top_p=min(max(float(top_p), 0.0), 1.0),
        )
        return completion.choices[0].message.content
    except Exception as e:
        logging.error(f"Error generating response: {e}", exc_info=True)
        return f"Error generating response: {str(e)}"
def auto_summarize_pdf(pdf_file):
    """Index a freshly uploaded PDF and return a chat history holding its summary.

    Args:
        pdf_file: Path string, or a readable file object, of the uploaded
            PDF; ``None`` when the upload was cleared.

    Returns:
        A list with one assistant message (summary or error text), or an
        empty list when ``pdf_file`` is ``None``.
    """
    global agent_executor

    if pdf_file is None:
        return []

    temp_path = None
    try:
        # Handle file path
        if isinstance(pdf_file, str):
            pdf_path = pdf_file
        else:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
                temp_path = tmp_file.name
                tmp_file.write(pdf_file.read())
            pdf_path = temp_path

        # Create vector store for RAG
        if not create_vector_store(pdf_path):
            return [{"role": "assistant", "content": "Error processing PDF for RAG system."}]

        # Generate summary
        summary = summarize_document(pdf_path)
        return [{"role": "assistant", "content": f"**Document Summary:**\n{summary}\n\n*The document has been processed and is ready for questions using RAG system.*"}]
    except Exception as e:
        logging.error(f"Error processing PDF: {e}", exc_info=True)
        return [{"role": "assistant", "content": f"Error processing PDF: {str(e)}"}]
    finally:
        # Indexing and summarising are finished (or failed) here, so the
        # temporary copy created with delete=False can be removed.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError as cleanup_error:
                logging.warning(f"Could not remove temporary file {temp_path}: {cleanup_error}")
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Document Summarizer with RAG")
    gr.Markdown("Upload a PDF document to get an automatic summary and ask questions using Retrieval-Augmented Generation (RAG).")

    with gr.Row():
        with gr.Column(scale=1):
            pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
            system_message = gr.Textbox(
                value="You are a helpful assistant for summarizing and finding related information needed.",
                label="System message"
            )
            max_tokens = gr.Slider(minimum=1, maximum=2000, value=512, step=1, label="Max new tokens")
            # Capped at 2.0, the upper bound the Groq chat API documents for temperature.
            temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.3, step=0.1, label="Temperature")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.05, label="Top-p (nucleus sampling)")
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(type='messages')
            msg = gr.Textbox(label="Message")
            clear = gr.Button("Clear")

    def user_input(message, history):
        """Clear the textbox and append the user's message to the chat history.

        Blank submissions are ignored so no empty turn is sent to the model.
        """
        history = history or []
        if not message or not message.strip():
            return "", history
        return "", history + [{"role": "user", "content": message}]

    def bot_response(history, pdf_file, system_message, max_tokens, temperature, top_p):
        """Answer the latest user turn, via RAG when a PDF is loaded, else plain chat."""
        # Nothing to answer when the last entry is not a user turn, e.g. a blank
        # submission right after the automatic summary.
        if not history or history[-1].get("role") != "user":
            return history or []

        message = history[-1]["content"]
        earlier_turns = history[:-1]
        if pdf_file is not None:
            response = process_pdf_and_chat_messages(pdf_file, message, earlier_turns, system_message, max_tokens, temperature, top_p)
        else:
            response = respond_messages(message, earlier_turns, system_message, max_tokens, temperature, top_p)
        return earlier_turns + [{"role": "user", "content": message}, {"role": "assistant", "content": response}]

    msg.submit(user_input, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, [chatbot, pdf_upload, system_message, max_tokens, temperature, top_p], chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

    # Auto-summarize and create vector store when PDF is uploaded
    pdf_upload.upload(auto_summarize_pdf, [pdf_upload], [chatbot])

if __name__ == "__main__":
    demo.launch(share=True)