Hasnan Ramadhan committed on
Commit
5e3f3a0
·
1 Parent(s): 9258d67

refactor to hybrid retrieval

Browse files
Files changed (2) hide show
  1. app.py +324 -398
  2. hybrid_retriever.py +139 -0
app.py CHANGED
@@ -1,458 +1,393 @@
1
  import gradio as gr
2
- from langgraph.graph import StateGraph
3
- from typing import TypedDict
4
  from langchain_community.document_loaders import PyMuPDFLoader
5
- import requests
 
 
 
 
6
  from groq import Groq
7
  import os
8
  from dotenv import load_dotenv
9
  import tempfile
10
- from googlesearch import search
11
- from bs4 import BeautifulSoup
12
- from urllib.parse import urljoin, urlparse
13
- import re
14
-
15
-
16
 
17
  load_dotenv()
 
18
  # Check if GROQ_API_KEY is available
19
  if not os.getenv("GROQ_API_KEY"):
20
  print("Warning: GROQ_API_KEY not found in environment variables")
21
- class DocumentState(TypedDict):
22
- documents: list[dict]
23
- summaries: list[str]
24
- search_results: list[dict]
25
- search_query: str
 
 
 
 
 
26
  needs_search: bool
 
 
27
 
28
- def get_llm_response(prompt):
29
- url = "http://192.168.181.215:8081/llms"
30
- headers = {"Content-Type": "application/json"}
31
- payload = {
32
- "messages": [{"role": "user", "content": prompt}],
33
- "max_new_tokens": 2000,
34
- "do_sample": True,
35
- "temperature": 0.2,
36
- "top_k": 10,
37
- "top_p": 0.90
38
- }
39
- try:
40
- response = requests.post(url, json=payload, headers=headers)
41
- response.raise_for_status()
42
- data = response.json()
43
- return {
44
- "response": data['choices'][0]['content'],
45
- "usage": data.get('usage', {}),
46
- "generation_time": data.get('generation_time', None)
47
- }
48
- except requests.exceptions.RequestException as e:
49
- return {
50
- "response": f"Error occurred: {str(e)}",
51
- "usage": {},
52
- "generation_time": None
53
  }
 
 
 
 
 
 
54
 
55
- def get_groq_response(prompt):
56
- client = Groq(api_key=os.getenv("GROQ_API_KEY"))
57
- completion = client.chat.completions.create(
58
- model="llama-3.1-8b-instant",
59
- messages=[
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  {
61
- "role": "user",
62
- "content": prompt
 
63
  }
64
- ]
65
- )
66
- return completion.choices[0].message.content
67
-
68
- def google_search_agent(state: DocumentState) -> DocumentState:
69
- """Performs Google search and extracts content from results."""
70
- search_query = state.get('search_query')
71
- if not search_query or not isinstance(search_query, str):
72
- return state
73
-
74
- try:
75
- search_results = []
76
- # Get top 3 search results
77
- for url in search(state['search_query'], num_results=3):
78
- try:
79
- response = requests.get(url, timeout=10)
80
- response.raise_for_status()
81
-
82
- soup = BeautifulSoup(response.content, 'html.parser')
83
-
84
- # Remove script and style elements
85
- for script in soup(["script", "style"]):
86
- script.decompose()
87
-
88
- # Get text content
89
- text = soup.get_text()
90
-
91
- # Clean up text
92
- lines = (line.strip() for line in text.splitlines())
93
- chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
94
- text = ' '.join(chunk for chunk in chunks if chunk)
95
-
96
- # Limit text length
97
- if len(text) > 1000:
98
- text = text[:1000] + "..."
99
-
100
- search_results.append({
101
- 'url': url,
102
- 'content': text,
103
- 'title': soup.title.string if soup.title else "No title"
104
- })
105
- except Exception as e:
106
- print(f"Error scraping {url}: {e}")
107
- continue
108
 
109
- state['search_results'] = search_results
110
- except Exception as e:
111
- print(f"Error during search: {e}")
112
- state['search_results'] = []
113
-
114
- return state
115
-
116
- def search_analyzer_agent(state: DocumentState) -> DocumentState:
117
- """Analyzes user query to determine if web search is needed."""
118
- search_query = state.get('search_query')
119
- if not search_query or not isinstance(search_query, str):
120
- return state
121
 
122
- # Keywords that typically indicate need for current information
123
- search_indicators = [
124
- 'latest', 'recent', 'current', 'news', 'update', 'today', 'now',
125
- 'what is', 'who is', 'when did', 'where is', 'how to', 'definition',
126
- 'explain', 'information about', 'tell me about', 'research'
127
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
- query_lower = search_query.lower()
130
- state['needs_search'] = any(indicator in query_lower for indicator in search_indicators)
 
 
 
 
 
 
 
 
 
 
 
131
 
132
- return state
133
-
134
- def search_response_agent(state: DocumentState) -> DocumentState:
135
- """Generates response based on search results."""
136
- search_results = state.get('search_results')
137
- search_query = state.get('search_query')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
- if not search_results or not isinstance(search_results, list):
140
- # Fallback to regular LLM response
141
- if search_query and isinstance(search_query, str):
142
- response = get_groq_response(search_query)
143
- state['summaries'] = [response]
144
- return state
 
 
145
 
146
- # Prepare search results for LLM
147
- search_context = "\n\n".join([
148
- f"Source: {result['title']} ({result['url']})\nContent: {result['content']}"
149
- for result in search_results
150
- ])
151
 
152
- prompt = f"""Based on the following search results, provide a comprehensive and accurate answer to the user's question: "{search_query}"
153
-
154
- Search Results:
155
- {search_context}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
- Please provide a well-structured response that:
158
- 1. Answers the user's question directly
159
- 2. Cites the sources when relevant
160
- 3. Is accurate and informative
161
- 4. Is concise but comprehensive
162
 
163
- Response:"""
 
 
164
 
165
- response = get_groq_response(prompt)
166
- state['summaries'] = [response]
167
- return state
168
-
169
- def document_extractor_agent(state: DocumentState, pdf_path: str) -> DocumentState:
170
- """Extracts documents from a PDF file."""
171
  try:
 
172
  loader = PyMuPDFLoader(pdf_path)
173
  documents = loader.load()
174
- state['documents'] = [
175
- {
176
- 'content': doc.page_content,
177
- 'page': doc.metadata.get('page', 0) + 1,
178
- 'source': doc.metadata.get('source', 'Unknown')
179
- } for doc in documents
180
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  except Exception as e:
182
- print(f"Error loading PDF: {e}")
183
- state['documents'] = []
184
- return state
185
 
186
- def document_summarizer_agent(state: DocumentState) -> DocumentState:
187
- """Retrieves summaries of the documents."""
188
- truncated_docs = []
189
- for doc in state['documents']:
190
- content = doc['content'][:500]
191
- truncated_docs.append(f"Page {doc['page']}: {content}")
192
-
193
- prompt = f"""Summarize these documents in exactly 3 sentences. Include page citations (p. X).
 
 
 
 
 
194
 
195
- Documents:
196
- {chr(10).join(truncated_docs)}
 
 
 
 
 
 
 
 
197
 
198
- Write 3 sentences with page citations with only refer from the document don't add up and jump to the conclusion."""
199
-
200
- summary = get_groq_response(prompt)
201
- state['summaries'] = [summary]
202
- return state
203
 
204
- def create_document_graph():
205
- talking_documents = StateGraph(DocumentState)
206
- talking_documents.add_node('document_extractor', document_extractor_agent)
207
- talking_documents.add_node('document_summarizer', document_summarizer_agent)
208
- talking_documents.set_entry_point('document_extractor')
209
- talking_documents.add_edge('document_extractor', 'document_summarizer')
210
- return talking_documents.compile()
211
 
212
- def create_search_graph():
213
- search_workflow = StateGraph(DocumentState)
214
- search_workflow.add_node('search_analyzer', search_analyzer_agent)
215
- search_workflow.add_node('google_search', google_search_agent)
216
- search_workflow.add_node('search_response', search_response_agent)
217
- search_workflow.set_entry_point('search_analyzer')
218
-
219
- # Conditional edge based on search needs
220
- def should_search(state):
221
- return "search" if state.get('needs_search', False) else "response"
222
 
223
- search_workflow.add_conditional_edges(
224
- 'search_analyzer',
225
- should_search,
226
- {
227
- "search": "google_search",
228
- "response": "search_response"
229
- }
230
- )
231
- search_workflow.add_edge('google_search', 'search_response')
232
- return search_workflow.compile()
233
-
234
- def process_pdf_and_chat(pdf_file, message, history, system_message, max_tokens, temperature, top_p, enable_search=False):
235
  if pdf_file is None:
236
- return history + [(message, "Please upload a PDF file first.")]
237
 
238
  try:
239
- # Handle file path - in newer Gradio versions, pdf_file is already a path
240
  if isinstance(pdf_file, str):
241
- tmp_pdf_path = pdf_file
242
- cleanup_needed = False
243
  else:
244
  # For older versions where pdf_file is a file object
245
  with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
246
  tmp_file.write(pdf_file.read())
247
- tmp_pdf_path = tmp_file.name
248
- cleanup_needed = True
249
 
250
- # Check if user wants to search for additional information
251
- search_keywords = ['search', 'find more', 'additional info', 'more information', 'research']
252
- if enable_search and any(keyword in message.lower() for keyword in search_keywords):
253
- # Use search workflow for additional information
254
- search_graph = create_search_graph()
255
- search_state = {
256
- 'documents': [],
257
- 'summaries': [],
258
- 'search_results': [],
259
- 'search_query': message,
260
- 'needs_search': True
261
- }
262
-
263
- search_result = search_graph.invoke(search_state)
264
-
265
- # Also process the PDF
266
- def document_extractor_with_path(state: DocumentState) -> DocumentState:
267
- return document_extractor_agent(state, tmp_pdf_path)
268
-
269
- talking_documents = StateGraph(DocumentState)
270
- talking_documents.add_node('document_extractor', document_extractor_with_path)
271
- talking_documents.add_node('document_summarizer', document_summarizer_agent)
272
- talking_documents.set_entry_point('document_extractor')
273
- talking_documents.add_edge('document_extractor', 'document_summarizer')
274
- pdf_graph = talking_documents.compile()
275
-
276
- pdf_state = {'documents': [], 'summaries': []}
277
- pdf_result = pdf_graph.invoke(pdf_state)
278
-
279
- # Combine PDF and search results
280
- combined_response = f"**PDF Summary:**\n{pdf_result['summaries'][0] if pdf_result['summaries'] else 'No summary available'}\n\n**Additional Information from Web:**\n{search_result['summaries'][0] if search_result['summaries'] else 'No additional information found'}"
281
-
282
- response = combined_response
283
- else:
284
- # Regular PDF processing
285
- def document_extractor_with_path(state: DocumentState) -> DocumentState:
286
- return document_extractor_agent(state, tmp_pdf_path)
287
-
288
- talking_documents = StateGraph(DocumentState)
289
- talking_documents.add_node('document_extractor', document_extractor_with_path)
290
- talking_documents.add_node('document_summarizer', document_summarizer_agent)
291
- talking_documents.set_entry_point('document_extractor')
292
- talking_documents.add_edge('document_extractor', 'document_summarizer')
293
- graph = talking_documents.compile()
294
-
295
- state = {'documents': [], 'summaries': []}
296
- final_state = graph.invoke(state)
297
-
298
- if final_state['summaries']:
299
- response = final_state['summaries'][0]
300
- else:
301
- response = "Unable to process the PDF. Please check the file format."
302
 
303
- # Clean up temporary file only if we created it
304
- if cleanup_needed:
305
- os.unlink(tmp_pdf_path)
 
 
306
 
307
- return history + [(message, response)]
308
 
309
  except Exception as e:
310
- return history + [(message, f"Error processing PDF: {str(e)}")]
311
 
312
- def respond_messages(message, history, system_message, max_tokens, temperature, top_p, enable_search=False):
313
- """Enhanced chat function with optional Google search - returns just the response text"""
314
- if enable_search:
315
- # Use search workflow
316
- search_graph = create_search_graph()
317
- state = {
318
- 'documents': [],
319
- 'summaries': [],
320
- 'search_results': [],
321
- 'search_query': message,
322
- 'needs_search': False
323
- }
324
-
325
- final_state = search_graph.invoke(state)
326
-
327
- if final_state['summaries']:
328
- response = final_state['summaries'][0]
329
- else:
330
- # Fallback to regular LLM response
331
- prompt = f"{system_message}\n\nUser: {message}"
332
- response = get_groq_response(prompt)
333
- else:
334
- # Regular chat without search
335
- prompt = f"{system_message}\n\nUser: {message}"
336
- response = get_groq_response(prompt)
337
-
338
- return response
339
 
340
- def process_pdf_and_chat_messages(pdf_file, message, history, system_message, max_tokens, temperature, top_p, enable_search=False):
341
- """Enhanced PDF processing function - returns just the response text"""
 
 
342
  if pdf_file is None:
343
- return "Please upload a PDF file first."
344
 
345
  try:
346
- # Handle file path - in newer Gradio versions, pdf_file is already a path
347
  if isinstance(pdf_file, str):
348
- tmp_pdf_path = pdf_file
349
- cleanup_needed = False
350
  else:
351
- # For older versions where pdf_file is a file object
352
  with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
353
  tmp_file.write(pdf_file.read())
354
- tmp_pdf_path = tmp_file.name
355
- cleanup_needed = True
356
 
357
- # Check if user wants to search for additional information
358
- search_keywords = ['search', 'find more', 'additional info', 'more information', 'research']
359
- if enable_search and any(keyword in message.lower() for keyword in search_keywords):
360
- # Use search workflow for additional information
361
- search_graph = create_search_graph()
362
- search_state = {
363
- 'documents': [],
364
- 'summaries': [],
365
- 'search_results': [],
366
- 'search_query': message,
367
- 'needs_search': True
368
- }
369
-
370
- search_result = search_graph.invoke(search_state)
371
-
372
- # Also process the PDF
373
- def document_extractor_with_path(state: DocumentState) -> DocumentState:
374
- return document_extractor_agent(state, tmp_pdf_path)
375
-
376
- talking_documents = StateGraph(DocumentState)
377
- talking_documents.add_node('document_extractor', document_extractor_with_path)
378
- talking_documents.add_node('document_summarizer', document_summarizer_agent)
379
- talking_documents.set_entry_point('document_extractor')
380
- talking_documents.add_edge('document_extractor', 'document_summarizer')
381
- pdf_graph = talking_documents.compile()
382
-
383
- pdf_state = {'documents': [], 'summaries': []}
384
- pdf_result = pdf_graph.invoke(pdf_state)
385
-
386
- # Combine PDF and search results
387
- combined_response = f"**PDF Summary:**\n{pdf_result['summaries'][0] if pdf_result['summaries'] else 'No summary available'}\n\n**Additional Information from Web:**\n{search_result['summaries'][0] if search_result['summaries'] else 'No additional information found'}"
388
-
389
- response = combined_response
390
- else:
391
- # Regular PDF processing
392
- def document_extractor_with_path(state: DocumentState) -> DocumentState:
393
- return document_extractor_agent(state, tmp_pdf_path)
394
-
395
- talking_documents = StateGraph(DocumentState)
396
- talking_documents.add_node('document_extractor', document_extractor_with_path)
397
- talking_documents.add_node('document_summarizer', document_summarizer_agent)
398
- talking_documents.set_entry_point('document_extractor')
399
- talking_documents.add_edge('document_extractor', 'document_summarizer')
400
- graph = talking_documents.compile()
401
-
402
- state = {'documents': [], 'summaries': []}
403
- final_state = graph.invoke(state)
404
-
405
- if final_state['summaries']:
406
- response = final_state['summaries'][0]
407
- else:
408
- response = "Unable to process the PDF. Please check the file format."
409
 
410
- # Clean up temporary file only if we created it
411
- if cleanup_needed:
412
- os.unlink(tmp_pdf_path)
413
 
414
- return response
415
 
416
  except Exception as e:
417
- return f"Error processing PDF: {str(e)}"
418
-
419
- def respond(message, history, system_message, max_tokens, temperature, top_p, enable_search=False):
420
- """Enhanced chat function with optional Google search"""
421
- if enable_search:
422
- # Use search workflow
423
- search_graph = create_search_graph()
424
- state = {
425
- 'documents': [],
426
- 'summaries': [],
427
- 'search_results': [],
428
- 'search_query': message,
429
- 'needs_search': False
430
- }
431
-
432
- final_state = search_graph.invoke(state)
433
-
434
- if final_state['summaries']:
435
- response = final_state['summaries'][0]
436
- else:
437
- # Fallback to regular LLM response
438
- prompt = f"{system_message}\n\nUser: {message}"
439
- response = get_groq_response(prompt)
440
- else:
441
- # Regular chat without search
442
- prompt = f"{system_message}\n\nUser: {message}"
443
- response = get_groq_response(prompt)
444
-
445
- return history + [(message, response)]
446
 
447
  # Create the Gradio interface
448
  with gr.Blocks() as demo:
449
- gr.Markdown("# Document Summarizer with Web Search")
450
- gr.Markdown("Upload a PDF document and ask questions about it, or chat normally. Enable search for additional web information.")
451
 
452
  with gr.Row():
453
  with gr.Column(scale=1):
454
  pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
455
- enable_search = gr.Checkbox(label="Enable Google Search", value=False)
456
  system_message = gr.Textbox(
457
  value="You are a helpful assistant for summarizing and finding related information needed.",
458
  label="System message"
@@ -469,29 +404,20 @@ with gr.Blocks() as demo:
469
  def user_input(message, history):
470
  return "", history + [{"role": "user", "content": message}]
471
 
472
- def bot_response(history, pdf_file, enable_search, system_message, max_tokens, temperature, top_p):
473
  message = history[-1]["content"]
474
  if pdf_file is not None:
475
- response = process_pdf_and_chat_messages(pdf_file, message, history[:-1], system_message, max_tokens, temperature, top_p, enable_search)
476
  else:
477
- response = respond_messages(message, history[:-1], system_message, max_tokens, temperature, top_p, enable_search)
478
  return history[:-1] + [{"role": "user", "content": message}, {"role": "assistant", "content": response}]
479
 
480
- def auto_summarize_pdf(pdf_file):
481
- """Automatically summarize PDF when uploaded"""
482
- if pdf_file is None:
483
- return []
484
-
485
- # Trigger automatic summarization
486
- response = process_pdf_and_chat_messages(pdf_file, "Please provide a summary of this document", [], "You are a helpful assistant for summarizing documents.", 512, 0.7, 0.95, False)
487
- return [{"role": "assistant", "content": response}]
488
-
489
  msg.submit(user_input, [msg, chatbot], [msg, chatbot], queue=False).then(
490
- bot_response, [chatbot, pdf_upload, enable_search, system_message, max_tokens, temperature, top_p], chatbot
491
  )
492
  clear.click(lambda: None, None, chatbot, queue=False)
493
 
494
- # Auto-summarize when PDF is uploaded
495
  pdf_upload.upload(auto_summarize_pdf, [pdf_upload], [chatbot])
496
 
497
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ from langgraph.graph import StateGraph, START, END
3
+ from typing import TypedDict, List, Union, Dict, Any, Annotated
4
  from langchain_community.document_loaders import PyMuPDFLoader
5
+ from langchain_community.embeddings import HuggingFaceEmbeddings
6
+ from hybrid_retriever import build_hybrid_retriever
7
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ from langchain_core.messages import HumanMessage, AIMessage, ToolMessage
9
+ from langchain_core.documents import Document
10
  from groq import Groq
11
  import os
12
  from dotenv import load_dotenv
13
  import tempfile
14
+ import time
15
+ import logging
16
+ from operator import add
 
 
 
17
 
18
  load_dotenv()
19
+
20
  # Check if GROQ_API_KEY is available
21
  if not os.getenv("GROQ_API_KEY"):
22
  print("Warning: GROQ_API_KEY not found in environment variables")
23
+
24
def add_messages(left, right):
    """Reducer for AgentState.messages: append *right* onto *left*.

    Used as the Annotated merge function so LangGraph concatenates message
    updates instead of replacing the list.
    """
    merged = left + right
    return merged
27
+
28
class AgentState(TypedDict):
    """Shared state passed between the LangGraph workflow nodes."""

    # Conversation transcript; add_messages merges updates by concatenation.
    messages: Annotated[List[Union[HumanMessage, AIMessage, ToolMessage]], add_messages]
    # The user's question being answered.
    query: str
    # Retrieved passages formatted as "[Doc N] ..." strings.
    documents: List[str]
    # Answer text produced by the generate (or decide) node.
    final_answer: str
    # Set by the decide node to route back into the search node.
    needs_search: bool
    # Number of retrieval attempts performed so far.
    search_count: int
    # Reserved for per-query metrics (populated elsewhere).
    metrics: Dict[str, Any]
36
 
37
class ResponseTimeTracker:
    """Accumulates timing metrics (in seconds) for the most recent query."""

    def __init__(self):
        # All timers start at zero; workflow nodes overwrite them as they run.
        self.metrics = {
            "retrieval_time": 0,
            "llm_processing_time": 0,
            "total_time": 0,
        }

    def update_retrieval_metrics(self, retrieval_metrics):
        """Merge retriever-reported metrics into the tracked dictionary."""
        self.metrics.update(retrieval_metrics)

    def get_metrics_dict(self):
        """Return the underlying metrics dictionary (not a copy)."""
        return self.metrics
50
 
51
class CustomAgentExecutor:
    """RAG agent: retrieve documents with the hybrid retriever, then answer
    with Groq, orchestrated as a small LangGraph state machine.

    Workflow: START -> search -> decide -> (search | generate | END),
    and generate -> END. The decide node re-triggers search (up to
    `max_searches` times) when retrieval came back empty.
    """

    def __init__(self, retriever):
        self.retriever = retriever
        self.groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
        self.response_tracker = ResponseTimeTracker()
        # Cap on retrieval attempts per query; bounds the search<->decide loop
        # when nothing relevant is indexed.
        self.max_searches = 3

        # Compile the LangGraph workflow once per executor instance.
        self.workflow = self._create_workflow()

    def _create_workflow(self):
        """Build and compile the search -> decide -> generate graph."""
        workflow = StateGraph(AgentState)

        workflow.add_node("search", self._search_node)
        workflow.add_node("generate", self._generate_node)
        workflow.add_node("decide", self._decide_node)

        workflow.add_edge(START, "search")
        workflow.add_edge("search", "decide")
        workflow.add_conditional_edges(
            "decide",
            self._should_continue,
            {
                "search": "search",
                "generate": "generate",
                "end": END,
            },
        )
        workflow.add_edge("generate", END)

        return workflow.compile()

    def _search_node(self, state: AgentState) -> AgentState:
        """Retrieve documents for the current query and format them as
        "[Doc N] <content>" strings for the generation prompt.

        Records retrieval_time on the tracker even when retrieval fails.
        """
        query = state.get("query", "")
        search_count = state.get("search_count", 0)

        retrieval_start = time.time()
        try:
            # NOTE(review): get_relevant_documents() is deprecated in newer
            # LangChain in favor of retriever.invoke(); kept for
            # compatibility with the pinned dependency versions — confirm.
            docs = self.retriever.get_relevant_documents(query)
        except Exception as e:
            logging.error(f"Retrieval error: {e}")
            docs = []
        finally:
            # Single timing write covers both success and failure paths.
            self.response_tracker.metrics["retrieval_time"] = time.time() - retrieval_start

        if docs:
            formatted_docs = [
                f"[Doc {i}] {doc.page_content.strip()}"
                for i, doc in enumerate(docs, 1)
            ]
        else:
            # Sentinel recognized by _decide_node as an empty result.
            formatted_docs = ["No relevant information found in the knowledge base."]

        return {
            **state,
            "documents": formatted_docs,
            "search_count": search_count + 1,
            "needs_search": False,
        }

    def _decide_node(self, state: AgentState) -> AgentState:
        """Decide whether to search again, give up, or generate an answer.

        Note: a repeated search reuses the identical query, so with a
        deterministic retriever retries cannot surface new documents;
        max_searches mainly bounds the loop.
        """
        documents = state.get("documents", [])
        search_count = state.get("search_count", 0)

        no_results = (
            not documents
            or documents == ["No relevant information found in the knowledge base."]
        )
        if no_results:
            if search_count < self.max_searches:
                return {**state, "needs_search": True}
            # Out of attempts: short-circuit with the refusal answer.
            return {**state, "needs_search": False, "final_answer": "I don't have the knowledge."}
        return {**state, "needs_search": False}

    def _generate_node(self, state: AgentState) -> AgentState:
        """Answer the query with Groq, grounded strictly in the retrieved
        documents; records llm_processing_time on the tracker.

        On LLM failure the error text becomes the final answer (never raises).
        """
        query = state.get("query", "")
        documents = state.get("documents", [])

        doc_context = "\n\n".join(documents)
        system_prompt = (
            "You are a helpful assistant that answers questions based only on the provided documents. "
            "Each passage is tagged with a source like [Doc 1], [Doc 2], etc. "
            "When answering, cite the relevant document(s) using these tags. "
            "You are prohibited from using your past knowledge. "
            "When the answer is not directly explained in the document(s), you MUST answer with 'I don't have the knowledge'."
        )

        user_prompt = f"Context:\n{doc_context}\n\nQuestion: {query}\n\nAnswer:"

        llm_start = time.time()
        try:
            response = self.groq_client.chat.completions.create(
                model="llama-3.1-8b-instant",
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
            )
            self.response_tracker.metrics["llm_processing_time"] = time.time() - llm_start

            response_content = response.choices[0].message.content

            return {
                **state,
                "final_answer": response_content,
                "messages": state.get("messages", []) + [
                    HumanMessage(content=query),
                    AIMessage(content=response_content),
                ],
            }
        except Exception as e:
            self.response_tracker.metrics["llm_processing_time"] = time.time() - llm_start
            error_msg = f"LLM generation error: {str(e)}"
            logging.error(f"LLM error: {e}", exc_info=True)
            return {
                **state,
                "final_answer": error_msg,
                "messages": state.get("messages", []) + [
                    HumanMessage(content=query),
                    AIMessage(content=error_msg),
                ],
            }

    def _should_continue(self, state: AgentState) -> str:
        """Route out of the decide node: re-search, finish, or generate."""
        if state.get("needs_search", False):
            return "search"
        if state.get("final_answer"):
            # decide already produced a terminal answer (refusal path).
            return "end"
        return "generate"

    def get_last_response_metrics(self) -> Dict[str, Any]:
        """Return the timing metrics recorded for the last query()."""
        return self.response_tracker.get_metrics_dict()

    def query(self, question: str) -> str:
        """Run the full workflow for *question* and return the answer text.

        Returns an error string (never raises) if the workflow fails.
        """
        initial_state = {
            "messages": [],
            "query": question,
            "documents": [],
            "final_answer": "",
            "needs_search": False,
            "search_count": 0,
            "metrics": {},
        }

        # Fix: reset per-stage timings so a failed or short-circuited run
        # does not report stale values left over from the previous query.
        self.response_tracker.metrics["retrieval_time"] = 0
        self.response_tracker.metrics["llm_processing_time"] = 0

        total_start = time.time()
        try:
            final_state = self.workflow.invoke(initial_state)
            return final_state.get("final_answer", "No answer generated")
        except Exception as e:
            logging.error(f"Query processing error: {e}")
            return f"Error processing query: {str(e)}"
        finally:
            # Covers both the success and error paths.
            self.response_tracker.metrics["total_time"] = time.time() - total_start
225
 
226
# Module-level singletons for the RAG system, initialized lazily once a PDF
# is uploaded (see create_vector_store). Despite its name, `vector_store`
# holds the hybrid retriever; `agent_executor` wraps it in the LangGraph
# workflow used to answer questions.
vector_store = None
agent_executor = None
 
 
229
 
230
def create_vector_store(pdf_path: str):
    """Build the hybrid (dense + BM25) retriever from a PDF and initialize
    the global agent executor.

    Side effects: rebinds the module globals `vector_store` (the hybrid
    retriever) and `agent_executor`.

    Returns True on success, False on any failure (logged, never raises).
    """
    global vector_store, agent_executor

    try:
        # Load the PDF and split it into overlapping chunks.
        loader = PyMuPDFLoader(pdf_path)
        documents = loader.load()

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )
        chunks = text_splitter.split_documents(documents)

        # Dense embedding model for the vector side of the retriever.
        embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # Plain texts for the sparse (BM25) side.
        texts = [doc.page_content for doc in chunks]

        # Elasticsearch connection settings are overridable via environment
        # variables; the defaults preserve the previous hard-coded values.
        hybrid_retriever = build_hybrid_retriever(
            texts=texts,
            index_name="document_index",
            embedding=embeddings,
            es_url=os.getenv("ELASTIC_URL", "http://localhost:9200"),
            es_username=os.getenv("ELASTIC_USERNAME", "elastic"),
            es_password=os.getenv("ELASTIC_PASSWORD", ""),
            top_k_dense=5,
            top_k_sparse=5,
        )

        # Index the chunks into both the dense and sparse stores.
        hybrid_retriever.add_documents(chunks)

        # NOTE: despite the name, `vector_store` holds the hybrid retriever.
        vector_store = hybrid_retriever

        agent_executor = CustomAgentExecutor(hybrid_retriever)

        return True
    except Exception as e:
        logging.error(f"Error creating vector store: {e}")
        return False
 
280
 
281
def get_groq_response(prompt):
    """Send *prompt* as a single user message to Groq and return the reply text."""
    client = Groq(api_key=os.getenv("GROQ_API_KEY"))
    chat_messages = [{"role": "user", "content": prompt}]
    completion = client.chat.completions.create(
        model="llama-3.1-8b-instant",
        messages=chat_messages,
    )
    return completion.choices[0].message.content
294
 
295
def summarize_document(pdf_path: str, max_pages: int = 5, chars_per_page: int = 1000) -> str:
    """Produce a 3-sentence LLM summary of a PDF.

    Only the first *max_pages* pages, each truncated to *chars_per_page*
    characters, are sent to the model to bound prompt size. The defaults
    match the previous hard-coded behavior (5 pages x 1000 chars).

    Returns the summary text, or an error string on failure (never raises).
    """
    try:
        loader = PyMuPDFLoader(pdf_path)
        documents = loader.load()

        # Bounded excerpt of the document used as the prompt context.
        full_text = "\n\n".join(
            doc.page_content[:chars_per_page] for doc in documents[:max_pages]
        )

        prompt = f"""Summarize the following document in exactly 3 sentences. Include page references where relevant.

Document content:
{full_text}

Write 3 sentences that capture the main points of the document."""

        return get_groq_response(prompt)
    except Exception as e:
        return f"Error summarizing document: {str(e)}"
 
 
314
 
315
def process_pdf_and_chat_messages(pdf_file, message, history, system_message, max_tokens, temperature, top_p):
    """Answer *message* about the uploaded PDF via the RAG agent.

    `history`, `system_message`, `max_tokens`, `temperature` and `top_p` are
    accepted for interface compatibility with the Gradio wiring but are not
    used by the RAG path.

    Returns the response text, or an error string (never raises).
    """
    global agent_executor

    if pdf_file is None:
        return "Please upload a PDF file first."

    tmp_created = False
    pdf_path = None
    try:
        # Newer Gradio passes a filesystem path; older versions a file object.
        if isinstance(pdf_file, str):
            pdf_path = pdf_file
        else:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
                tmp_file.write(pdf_file.read())
                pdf_path = tmp_file.name
            tmp_created = True

        # Lazily build the RAG index the first time a question is asked.
        # NOTE(review): a *new* PDF uploaded after the first one is not
        # re-indexed here; auto_summarize_pdf re-indexes on every upload.
        if agent_executor is None:
            success = create_vector_store(pdf_path)
            if not success:
                return "Error processing PDF for RAG system."

        if agent_executor:
            return agent_executor.query(message)
        return "RAG system not initialized. Please try uploading the PDF again."

    except Exception as e:
        return f"Error processing PDF: {str(e)}"
    finally:
        # Fix: delete the temp copy we created (previously leaked because
        # NamedTemporaryFile(delete=False) was never unlinked).
        if tmp_created and pdf_path:
            try:
                os.unlink(pdf_path)
            except OSError:
                pass
348
 
349
def respond_messages(message, history, system_message, max_tokens, temperature, top_p):
    """Plain (non-RAG) chat path: prepend the system message and ask Groq.

    `history`, `max_tokens`, `temperature` and `top_p` are accepted for
    interface compatibility but are not used here.
    """
    composed_prompt = f"{system_message}\n\nUser: {message}"
    return get_groq_response(composed_prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
 
354
def auto_summarize_pdf(pdf_file):
    """Gradio upload handler: index the PDF for RAG and return an initial
    assistant message containing a 3-sentence summary.

    Returns a chat history in Gradio "messages" format; an empty list when
    no file was supplied. Never raises.
    """
    global agent_executor

    if pdf_file is None:
        return []

    tmp_created = False
    pdf_path = None
    try:
        # Newer Gradio passes a filesystem path; older versions a file object.
        if isinstance(pdf_file, str):
            pdf_path = pdf_file
        else:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
                tmp_file.write(pdf_file.read())
                pdf_path = tmp_file.name
            tmp_created = True

        # (Re)build the retrieval index for the freshly uploaded document.
        success = create_vector_store(pdf_path)
        if not success:
            return [{"role": "assistant", "content": "Error processing PDF for RAG system."}]

        summary = summarize_document(pdf_path)

        return [{"role": "assistant", "content": f"**Document Summary:**\n{summary}\n\n*The document has been processed and is ready for questions using RAG system.*"}]

    except Exception as e:
        return [{"role": "assistant", "content": f"Error processing PDF: {str(e)}"}]
    finally:
        # Fix: delete the temp copy we created (previously leaked because
        # NamedTemporaryFile(delete=False) was never unlinked).
        if tmp_created and pdf_path:
            try:
                os.unlink(pdf_path)
            except OSError:
                pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
382
 
383
  # Create the Gradio interface
384
  with gr.Blocks() as demo:
385
+ gr.Markdown("# Document Summarizer with RAG")
386
+ gr.Markdown("Upload a PDF document to get an automatic summary and ask questions using Retrieval-Augmented Generation (RAG).")
387
 
388
  with gr.Row():
389
  with gr.Column(scale=1):
390
  pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"])
 
391
  system_message = gr.Textbox(
392
  value="You are a helpful assistant for summarizing and finding related information needed.",
393
  label="System message"
 
404
def user_input(message, history):
    """Append the user's turn to the chat history and clear the textbox.

    Returns ("", updated_history) in Gradio "messages" format. Defined
    inside the gr.Blocks context in the original file.
    """
    updated = history + [{"role": "user", "content": message}]
    return "", updated
406
 
407
def bot_response(history, pdf_file, system_message, max_tokens, temperature, top_p):
    """Generate the assistant reply for the newest user turn.

    Routes to the RAG pipeline when a PDF is loaded, plain Groq chat
    otherwise, and returns the history with the new user/assistant pair
    appended. Defined inside the gr.Blocks context in the original file.
    """
    message = history[-1]["content"]
    prior = history[:-1]
    if pdf_file is None:
        response = respond_messages(message, prior, system_message, max_tokens, temperature, top_p)
    else:
        response = process_pdf_and_chat_messages(pdf_file, message, prior, system_message, max_tokens, temperature, top_p)
    return prior + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": response},
    ]
414
 
 
 
 
 
 
 
 
 
 
415
  msg.submit(user_input, [msg, chatbot], [msg, chatbot], queue=False).then(
416
+ bot_response, [chatbot, pdf_upload, system_message, max_tokens, temperature, top_p], chatbot
417
  )
418
  clear.click(lambda: None, None, chatbot, queue=False)
419
 
420
+ # Auto-summarize and create vector store when PDF is uploaded
421
  pdf_upload.upload(auto_summarize_pdf, [pdf_upload], [chatbot])
422
 
423
  if __name__ == "__main__":
hybrid_retriever.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from elasticsearch import Elasticsearch
2
+ from langchain_core.documents import Document
3
+ from langchain_core.retrievers import BaseRetriever
4
+ from langchain_elasticsearch import ElasticsearchStore, BM25Strategy
5
+ from langchain_core.vectorstores import VectorStoreRetriever
6
+ from pydantic import Field
7
+ from typing import List
8
+ import logging
9
+
10
class HybridRetriever(BaseRetriever):
    """Hybrid retriever combining dense (vector) and sparse (BM25) Elasticsearch search.

    Query results from both retrievers are concatenated (dense hits first)
    and de-duplicated by exact ``page_content`` match.
    """

    # Pydantic model fields (BaseRetriever is a pydantic model).
    dense_db: ElasticsearchStore
    dense_retriever: VectorStoreRetriever
    sparse_db: ElasticsearchStore
    sparse_retriever: VectorStoreRetriever
    index_dense: str
    index_sparse: str

    top_k_dense: int = 5
    top_k_sparse: int = 5
    is_training: bool = False

    @classmethod
    def create(
        cls,
        dense_db,
        dense_retriever,
        sparse_db,
        sparse_retriever,
        index_dense,
        index_sparse,
        top_k_dense=5,
        top_k_sparse=5,
        is_training=False,
    ):
        """Alternate constructor mirroring the declared field list."""
        return cls(
            dense_db=dense_db,
            dense_retriever=dense_retriever,
            sparse_db=sparse_db,
            sparse_retriever=sparse_retriever,
            index_dense=index_dense,
            index_sparse=index_sparse,
            top_k_dense=top_k_dense,
            top_k_sparse=top_k_sparse,
            is_training=is_training,
        )

    def reset_indices(self):
        """Delete both backing Elasticsearch indices; missing indices are ignored."""
        result = self.dense_db.client.indices.delete(
            index=self.index_dense,
            ignore_unavailable=True,
            allow_no_indices=True,
        )
        logging.info("dense_db delete: %s", result.get("acknowledged"))

        result = self.sparse_db.client.indices.delete(
            index=self.index_sparse,
            ignore_unavailable=True,
            allow_no_indices=True,
        )
        logging.info("sparse_db delete: %s", result.get("acknowledged"))

    def add_documents(self, documents, batch_size=25):
        """Index documents into both stores in batches of ``batch_size``.

        Documents that are not ``Document`` instances or whose
        ``page_content`` is empty/whitespace are skipped with a warning.

        Raises:
            ValueError: if no valid documents remain after filtering.
        """
        valid_docs = []
        for doc in documents:
            # Use logging instead of bare print for per-document diagnostics.
            logging.debug("[DOC] %r", getattr(doc, "page_content", doc))

            if isinstance(doc, Document) and isinstance(doc.page_content, str) and doc.page_content.strip():
                valid_docs.append(doc)
            else:
                logging.warning("Skipped invalid or empty doc: %s", doc)

        if not valid_docs:
            raise ValueError("No valid documents to add.")

        for i in range(0, len(valid_docs), batch_size):
            logging.info("Processing batch %d", i)
            dense_batch = valid_docs[i : i + batch_size]
            # NOTE(review): the sparse side receives bare texts, so document
            # metadata is not stored in the BM25 index — confirm this is intended.
            sparse_batch = [doc.page_content for doc in dense_batch]

            self.dense_db.add_documents(dense_batch)
            self.sparse_db.add_texts(sparse_batch)

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        """Hook invoked by ``BaseRetriever.invoke()``; merges dense and sparse hits.

        Previously only the deprecated public ``get_relevant_documents`` was
        overridden, so ``retriever.invoke(query)`` bypassed the hybrid merge.
        """
        dense_docs = self.dense_retriever.invoke(query)
        sparse_docs = self.sparse_retriever.invoke(query)

        logging.debug("dense hits: %d, sparse hits: %d", len(dense_docs), len(sparse_docs))

        # De-duplicate by exact page_content, preserving order (dense first).
        seen = set()
        unique_docs = []
        for doc in dense_docs + sparse_docs:
            if doc.page_content not in seen:
                seen.add(doc.page_content)
                unique_docs.append(doc)
        return unique_docs

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Backward-compatible public entry point; delegates to the merge hook."""
        return self._get_relevant_documents(query)
98
+
99
def get_elasticsearch_client(url, username=None, password=None):
    """Create an Elasticsearch client.

    Basic authentication is enabled only when both ``username`` and
    ``password`` are provided; otherwise an anonymous client is returned.
    """
    if not (username and password):
        return Elasticsearch(url)
    return Elasticsearch(url, basic_auth=(username, password))
103
+
104
def build_hybrid_retriever(texts, index_name, embedding, es_url, es_username, es_password,
                           top_k_dense=5, top_k_sparse=5):
    """Wire up a HybridRetriever backed by one dense and one sparse ES index.

    NOTE(review): the ``texts`` parameter is accepted but never used here —
    documents appear to be ingested later via ``HybridRetriever.add_documents``;
    confirm with callers before removing or wiring it in.
    """
    es_client = get_elasticsearch_client(es_url, es_username, es_password)

    # Dense (vector) side.
    idx_dense = f"{index_name}_dense"
    store_dense = ElasticsearchStore(
        index_name=idx_dense,
        embedding=embedding,
        es_connection=es_client,
    )

    # Sparse (BM25) side. from_texts([]) is used only to select the BM25
    # strategy for the index — presumably it creates an empty index; verify.
    idx_sparse = f"{index_name}_sparse"
    store_sparse = ElasticsearchStore.from_texts(
        texts=[],
        embedding=embedding,
        index_name=idx_sparse,
        es_connection=es_client,
        strategy=BM25Strategy()
    )

    return HybridRetriever.create(
        dense_db=store_dense,
        dense_retriever=store_dense.as_retriever(search_kwargs={"k": top_k_dense}),
        sparse_db=store_sparse,
        sparse_retriever=store_sparse.as_retriever(search_kwargs={"k": top_k_sparse}),
        index_dense=idx_dense,
        index_sparse=idx_sparse,
        top_k_dense=top_k_dense,
        top_k_sparse=top_k_sparse
    )