# NOTE: copy-paste residue from a GitHub commit header — author "Hasnan Ramadhan",
# commit message "change slider", SHA 89afa92 — commented out so the module parses.
import gradio as gr
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, List, Union, Dict, Any, Annotated
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from hybrid_retriever import build_hybrid_retriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
from langchain_core.documents import Document
from groq import Groq
import os
from dotenv import load_dotenv
import tempfile
import time
import logging
from operator import add
# Load environment variables (GROQ_API_KEY, api_key_es) from a local .env file.
load_dotenv()
# Check if GROQ_API_KEY is available
if not os.getenv("GROQ_API_KEY"):
    # Warn early at import time; Groq calls below will fail without the key.
    print("Warning: GROQ_API_KEY not found in environment variables")
def add_messages(left, right):
    """Reducer used by AgentState: concatenate two message lists into a new one."""
    return [*left, *right]
class AgentState(TypedDict):
    """Shared state threaded through the LangGraph workflow nodes."""
    messages: Annotated[List[Union[HumanMessage, AIMessage, ToolMessage]], add_messages]  # chat transcript; merged with the add_messages reducer
    query: str  # the user's current question
    documents: List[str]  # retrieved passages, formatted as "[Doc N] ..."
    final_answer: str  # answer from the generate node (or the fallback text)
    needs_search: bool  # set by the decide node: retry retrieval when True
    search_count: int  # number of retrieval attempts made so far
    metrics: Dict[str, Any]  # timing metrics slot (nodes write timings to the tracker instead)
class ResponseTimeTracker:
    """Accumulates per-request latency figures for the RAG pipeline."""

    def __init__(self):
        # Seed every tracked phase with zero so consumers always see the keys.
        self.metrics = {
            phase: 0
            for phase in ("retrieval_time", "llm_processing_time", "total_time")
        }

    def update_retrieval_metrics(self, retrieval_metrics):
        """Merge retriever-supplied timing values into the tracked metrics."""
        for name, value in retrieval_metrics.items():
            self.metrics[name] = value

    def get_metrics_dict(self):
        """Return the live metrics mapping (not a copy)."""
        return self.metrics
class CustomAgentExecutor:
    """LangGraph-driven RAG agent: search -> decide -> (search | generate | end).

    Retrieval is retried up to ``max_searches`` times when it comes back empty;
    generation is delegated to Groq's chat completion API with the retrieved
    passages as the only allowed context.
    """

    # Sentinel emitted by _search_node and matched by _decide_node; kept as a
    # shared constant so the two can never drift apart.
    _NO_RESULTS = "No relevant information found in the knowledge base."

    def __init__(self, retriever):
        self.retriever = retriever
        self.groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
        self.response_tracker = ResponseTimeTracker()
        self.max_searches = 3  # upper bound on retrieval attempts per query
        # Compiled LangGraph workflow; built once per executor.
        self.workflow = self._create_workflow()

    def _create_workflow(self):
        """Build and compile the LangGraph state machine."""
        workflow = StateGraph(AgentState)
        # Nodes
        workflow.add_node("search", self._search_node)
        workflow.add_node("generate", self._generate_node)
        workflow.add_node("decide", self._decide_node)
        # Edges: always search first, then decide what to do next.
        workflow.add_edge(START, "search")
        workflow.add_edge("search", "decide")
        workflow.add_conditional_edges(
            "decide",
            self._should_continue,
            {
                "search": "search",
                "generate": "generate",
                "end": END,
            },
        )
        workflow.add_edge("generate", END)
        return workflow.compile()

    def _search_node(self, state: AgentState) -> AgentState:
        """Retrieve documents for the current query and tag them as [Doc N]."""
        query = state.get("query", "")
        search_count = state.get("search_count", 0)
        retrieval_start = time.time()
        try:
            docs = self.retriever.get_relevant_documents(query)
        except Exception as e:
            logging.error(f"Retrieval error: {e}")
            docs = []
        finally:
            # Record elapsed retrieval time on success and failure alike
            # (previously duplicated in both branches).
            self.response_tracker.metrics["retrieval_time"] = time.time() - retrieval_start
        if docs:
            formatted_docs = [
                f"[Doc {i}] {doc.page_content.strip()}"
                for i, doc in enumerate(docs, 1)
            ]
        else:
            formatted_docs = [self._NO_RESULTS]
        return {
            **state,
            "documents": formatted_docs,
            "search_count": search_count + 1,
            "needs_search": False,
        }

    def _decide_node(self, state: AgentState) -> AgentState:
        """Decide whether to retry retrieval, generate an answer, or give up."""
        documents = state.get("documents", [])
        search_count = state.get("search_count", 0)
        if not documents or documents == [self._NO_RESULTS]:
            if search_count < self.max_searches:
                # Nothing found yet and retries remain: search again.
                return {**state, "needs_search": True}
            # Retries exhausted: short-circuit with the fallback answer.
            return {**state, "needs_search": False, "final_answer": "I don't have the knowledge."}
        # Documents in hand: proceed to generation.
        return {**state, "needs_search": False}

    def _generate_node(self, state: AgentState) -> AgentState:
        """Generate a grounded answer from the retrieved documents via Groq."""
        query = state.get("query", "")
        documents = state.get("documents", [])
        doc_context = "\n\n".join(documents)
        system_prompt = (
            "You are a helpful assistant that answers questions based only on the provided documents. "
            "Each passage is tagged with a source like [Doc 1], [Doc 2], etc. "
            "When answering, cite the relevant document(s) using these tags. "
            "You are prohibited from using your past knowledge. "
            "When the answer is not directly explained in the document(s), you MUST answer with 'I don't have the knowledge'."
        )
        user_prompt = f"Context:\n{doc_context}\n\nQuestion: {query}\n\nAnswer:"
        llm_start = time.time()
        try:
            response = self.groq_client.chat.completions.create(
                model="llama-3.1-8b-instant",
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
            )
            answer = response.choices[0].message.content
        except Exception as e:
            # Surface the error as the "answer" so the UI shows something useful.
            answer = f"LLM generation error: {str(e)}"
            logging.error(f"LLM error: {e}", exc_info=True)
        finally:
            # Record LLM latency regardless of outcome
            # (previously duplicated in both branches).
            self.response_tracker.metrics["llm_processing_time"] = time.time() - llm_start
        return {
            **state,
            "final_answer": answer,
            "messages": state.get("messages", []) + [
                HumanMessage(content=query),
                AIMessage(content=answer),
            ],
        }

    def _should_continue(self, state: AgentState) -> str:
        """Route out of the decide node: retry search, end early, or generate."""
        if state.get("needs_search", False):
            return "search"
        elif state.get("final_answer"):
            # A final_answer set before generation means the fallback fired.
            return "end"
        else:
            return "generate"

    def get_last_response_metrics(self) -> Dict[str, Any]:
        """Get the metrics from the last query response."""
        return self.response_tracker.get_metrics_dict()

    def query(self, question: str) -> str:
        """Run *question* through the workflow and return the final answer text."""
        initial_state = {
            "messages": [],
            "query": question,
            "documents": [],
            "final_answer": "",
            "needs_search": False,
            "search_count": 0,
            "metrics": {},
        }
        total_start = time.time()
        try:
            final_state = self.workflow.invoke(initial_state)
            return final_state.get("final_answer", "No answer generated")
        except Exception as e:
            logging.error(f"Query processing error: {e}")
            return f"Error processing query: {str(e)}"
        finally:
            # Total wall-clock time, recorded on both paths.
            self.response_tracker.metrics["total_time"] = time.time() - total_start
# Global variables for RAG system
# Module-level singletons, (re)built each time a PDF is indexed.
vector_store = None  # hybrid retriever once a PDF has been processed
agent_executor = None  # CustomAgentExecutor wrapping that retriever
def create_vector_store(
    pdf_path: str,
    index_name: str = "try-rag",
    es_url: str = "https://my-elasticsearch-project-c8f88b.es.ap-southeast-1.aws.elastic.cloud:443",
) -> bool:
    """Build the RAG stack (hybrid retriever + agent executor) from a PDF.

    Loads the PDF, splits it into overlapping chunks, embeds them, and indexes
    them into an Elasticsearch-backed hybrid (dense + sparse) retriever. On
    success the module globals ``vector_store`` and ``agent_executor`` are
    replaced.

    Args:
        pdf_path: Filesystem path to the PDF to index.
        index_name: Elasticsearch index to write to (previously hard-coded).
        es_url: Elasticsearch endpoint URL (previously hard-coded).

    Returns:
        True on success, False on any failure (the error is logged).
    """
    global vector_store, agent_executor
    try:
        # Load PDF documents
        loader = PyMuPDFLoader(pdf_path)
        documents = loader.load()
        # Split documents into overlapping chunks for retrieval.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )
        chunks = text_splitter.split_documents(documents)
        # Dense embedding model for the vector side of the hybrid retriever.
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        # Raw texts for the sparse (keyword) side.
        texts = [doc.page_content for doc in chunks]
        # Build hybrid retriever against Elasticsearch Cloud.
        hybrid_retriever = build_hybrid_retriever(
            texts=texts,
            index_name=index_name,
            embedding=embeddings,
            es_url=es_url,
            es_api_key=os.getenv("api_key_es"),
            top_k_dense=5,
            top_k_sparse=5,
        )
        hybrid_retriever.add_documents(chunks)
        # Publish the new retriever and agent for the rest of the app.
        vector_store = hybrid_retriever
        agent_executor = CustomAgentExecutor(hybrid_retriever)
        return True
    except Exception as e:
        logging.error(f"Error creating vector store: {e}")
        return False
def get_groq_response(prompt, system_message=None, max_tokens=None, temperature=None, top_p=None):
    """Get a single-shot chat completion from the Groq API.

    Backward compatible: called with only *prompt* it behaves exactly as
    before. The optional parameters let callers forward the UI's sampling
    sliders, which the old signature could not accept.

    Args:
        prompt: User-role message content.
        system_message: Optional system-role message prepended to the chat.
        max_tokens: Optional completion-length cap (coerced to int).
        temperature: Optional sampling temperature.
        top_p: Optional nucleus-sampling cutoff.

    Returns:
        The model's reply text.
    """
    client = Groq(api_key=os.getenv("GROQ_API_KEY"))
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    messages.append({"role": "user", "content": prompt})
    # Only forward parameters the caller actually supplied, so the default
    # call is byte-identical to the previous behavior.
    extra = {}
    if max_tokens is not None:
        extra["max_tokens"] = int(max_tokens)
    if temperature is not None:
        extra["temperature"] = temperature
    if top_p is not None:
        extra["top_p"] = top_p
    completion = client.chat.completions.create(
        model="llama-3.1-8b-instant",
        messages=messages,
        **extra,
    )
    return completion.choices[0].message.content
def summarize_document(pdf_path: str) -> str:
    """Ask the LLM for a three-sentence summary of the PDF at *pdf_path*."""
    try:
        pages = PyMuPDFLoader(pdf_path).load()
        # Cap the context: at most the first 5 pages, 1000 chars each.
        excerpts = [page.page_content[:1000] for page in pages[:5]]
        full_text = "\n\n".join(excerpts)
        prompt = f"""Summarize the following document in exactly 3 sentences. Include page references where relevant.
Document content:
{full_text}
Write 3 sentences that capture the main points of the document."""
        return get_groq_response(prompt)
    except Exception as e:
        return f"Error summarizing document: {str(e)}"
def process_pdf_and_chat_messages(pdf_file, message, history, system_message, max_tokens, temperature, top_p):
    """Answer *message* against the uploaded PDF via the RAG agent executor."""
    global agent_executor
    if pdf_file is None:
        return "Please upload a PDF file first."
    try:
        if isinstance(pdf_file, str):
            # Newer Gradio versions hand over a filesystem path directly.
            pdf_path = pdf_file
        else:
            # Older Gradio versions pass a file-like object: spool it to disk.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
                tmp_file.write(pdf_file.read())
                pdf_path = tmp_file.name
        # Lazily build the RAG stack on first use.
        if agent_executor is None and not create_vector_store(pdf_path):
            return "Error processing PDF for RAG system."
        if agent_executor is None:
            return "RAG system not initialized. Please try uploading the PDF again."
        return agent_executor.query(message)
    except Exception as e:
        return f"Error processing PDF: {str(e)}"
def respond_messages(message, history, system_message, max_tokens, temperature, top_p):
    """Handle chat without a PDF by calling the Groq API directly.

    Bug fix: ``max_tokens``, ``temperature`` and ``top_p`` come from the UI
    sliders but were previously accepted and silently ignored — they are now
    forwarded to the API. The system message is also sent as a proper
    ``system`` role instead of being prepended to the user prompt.

    Args:
        message: The user's latest chat message.
        history: Prior chat turns (unused; the call is single-shot).
        system_message: System prompt from the UI textbox.
        max_tokens: Completion-length cap from the UI slider.
        temperature: Sampling temperature from the UI slider.
        top_p: Nucleus-sampling cutoff from the UI slider.

    Returns:
        The model's reply text.
    """
    client = Groq(api_key=os.getenv("GROQ_API_KEY"))
    completion = client.chat.completions.create(
        model="llama-3.1-8b-instant",
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": message},
        ],
        max_tokens=int(max_tokens),
        temperature=temperature,
        top_p=top_p,
    )
    return completion.choices[0].message.content
def auto_summarize_pdf(pdf_file):
    """Upload hook: index the PDF for RAG and post an automatic summary."""
    global agent_executor
    if pdf_file is None:
        return []
    try:
        if isinstance(pdf_file, str):
            # Newer Gradio versions pass a filesystem path.
            pdf_path = pdf_file
        else:
            # File-like object from older Gradio: persist it to a temp file.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
                tmp_file.write(pdf_file.read())
                pdf_path = tmp_file.name
        # Build the retriever/agent before summarizing.
        if not create_vector_store(pdf_path):
            return [{"role": "assistant", "content": "Error processing PDF for RAG system."}]
        summary = summarize_document(pdf_path)
        banner = (
            f"**Document Summary:**\n{summary}\n\n"
            "*The document has been processed and is ready for questions using RAG system.*"
        )
        return [{"role": "assistant", "content": banner}]
    except Exception as e:
        return [{"role": "assistant", "content": f"Error processing PDF: {str(e)}"}]
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Document Summarizer with RAG")
    gr.Markdown("Upload a PDF document to get an automatic summary and ask questions using Retrieval-Augmented Generation (RAG).")
    with gr.Row():
        with gr.Column(scale=1):
            # Left column: file upload plus generation settings.
            pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
            system_message = gr.Textbox(
                value="You are a helpful assistant for summarizing and finding related information needed.",
                label="System message"
            )
            # Sampling controls passed as inputs to the chat callbacks.
            max_tokens = gr.Slider(minimum=1, maximum=2000, value=512, step=1, label="Max new tokens")
            temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.3, step=0.1, label="Temperature")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.05, label="Top-p (nucleus sampling)")
        with gr.Column(scale=2):
            # Right column: chat area (messages-style history of role/content dicts).
            chatbot = gr.Chatbot(type='messages')
            msg = gr.Textbox(label="Message")
            clear = gr.Button("Clear")

    def user_input(message, history):
        """Clear the textbox and append the user's turn to the chat history."""
        return "", history + [{"role": "user", "content": message}]

    def bot_response(history, pdf_file, system_message, max_tokens, temperature, top_p):
        """Route the last user message to the RAG path (PDF present) or plain chat."""
        message = history[-1]["content"]
        if pdf_file is not None:
            response = process_pdf_and_chat_messages(pdf_file, message, history[:-1], system_message, max_tokens, temperature, top_p)
        else:
            response = respond_messages(message, history[:-1], system_message, max_tokens, temperature, top_p)
        # Rebuild history as: prior turns + user turn + assistant reply.
        return history[:-1] + [{"role": "user", "content": message}, {"role": "assistant", "content": response}]

    # On submit: first echo the user's turn, then compute the bot reply.
    msg.submit(user_input, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_response, [chatbot, pdf_upload, system_message, max_tokens, temperature, top_p], chatbot
    )
    # Clear button resets the chat display.
    clear.click(lambda: None, None, chatbot, queue=False)
    # Auto-summarize and create vector store when PDF is uploaded
    pdf_upload.upload(auto_summarize_pdf, [pdf_upload], [chatbot])
if __name__ == "__main__":
    # share=True additionally exposes a public Gradio link beyond the local server.
    demo.launch(share=True)