mabelwang21 commited on
Commit
88b5dd6
·
1 Parent(s): 0b51467

update rag_search in build_retrievier

Browse files
Files changed (1) hide show
  1. agent.py +50 -19
agent.py CHANGED
@@ -8,6 +8,8 @@ from typing import List, TypedDict, Annotated, Optional
8
  import requests
9
  from urllib.parse import urlparse
10
  import shutil
 
 
11
 
12
  from langchain.tools import tool, StructuredTool
13
  from langchain_community.document_loaders import (
@@ -21,6 +23,7 @@ from langchain.schema import BaseMessage, SystemMessage, HumanMessage
21
  from langgraph.graph.message import add_messages
22
  from langgraph.graph import START, StateGraph
23
  from langgraph.prebuilt import ToolNode, tools_condition
 
24
 
25
  from youtube_transcript_api import YouTubeTranscriptApi
26
  from PIL import Image
@@ -328,19 +331,32 @@ class MyAgent:
328
  """
329
  if not self.docs:
330
  return
331
- self.retriever = BM25Retriever.from_documents(self.docs)
332
-
333
- @tool(name="rag_search")
334
- def rag_search(query: str) -> str:
335
- """Retrieve top-3 relevant document chunks via BM25."""
336
- res = self.retriever.get_relevant_documents(query)
337
- if res:
338
- return "\n\n".join([doc.page_content for doc in res[:3]])
339
- return ""
340
-
341
- # Register RAG tool
342
- self.tools.append(rag_search)
343
-
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
345
  def __call__(
346
  self,
@@ -382,8 +398,11 @@ class MyAgent:
382
  builder.add_edge(START, "assistant")
383
 
384
  # Always allow the assistant to hand off to the tools node
385
- builder.add_edge("assistant", "tools")
386
- # And then return from tools back to the assistant
 
 
 
387
  builder.add_edge("tools", "assistant")
388
  graph = builder.compile()
389
 
@@ -402,10 +421,22 @@ class MyAgent:
402
  return self(question, file_paths)
403
 
404
  def _assistant_node(self, state: dict) -> dict:
405
- # Invoke the chat model with our BaseMessage list
406
- resp = self.llm(state["messages"])
407
- state["messages"].append(resp)
408
- return state
 
 
 
 
 
 
 
 
 
 
 
 
409
 
410
 
411
 
 
8
  import requests
9
  from urllib.parse import urlparse
10
  import shutil
11
+ import io
12
+ from typing import Dict, Any
13
 
14
  from langchain.tools import tool, StructuredTool
15
  from langchain_community.document_loaders import (
 
23
  from langgraph.graph.message import add_messages
24
  from langgraph.graph import START, StateGraph
25
  from langgraph.prebuilt import ToolNode, tools_condition
26
+ from langchain_core.documents import Document
27
 
28
  from youtube_transcript_api import YouTubeTranscriptApi
29
  from PIL import Image
 
331
  """
332
  if not self.docs:
333
  return
334
+
335
+ # Build retriever
336
+ try:
337
+ self.retriever = BM25Retriever.from_documents(self.docs)
338
+
339
+ # Define tool with proper error handling
340
+ @tool(name="rag_search")
341
+ def rag_search(query: str) -> str:
342
+ """Search loaded documents for relevant information."""
343
+ try:
344
+ if not self.retriever:
345
+ return "No documents have been loaded for search."
346
+
347
+ res = self.retriever.get_relevant_documents(query)
348
+ if res:
349
+ return "\n\n".join(f"Document {i+1}:\n{doc.page_content}"
350
+ for i, doc in enumerate(res[:3]))
351
+ return "No relevant information found in loaded documents."
352
+ except Exception as e:
353
+ return f"Error searching documents: {e}"
354
+
355
+ # Remove existing rag_search if present to prevent duplicates
356
+ self.tools = [t for t in self.tools if t.name != "rag_search"]
357
+ self.tools.append(rag_search)
358
+ except Exception as e:
359
+ print(f"Error building retriever: {e}")
360
 
361
  def __call__(
362
  self,
 
398
  builder.add_edge(START, "assistant")
399
 
400
  # Always allow the assistant to hand off to the tools node
401
+ builder.add_conditional_edges(
402
+ "assistant",
403
+ lambda s: any(t.name in s["messages"][-1].content for t in self.tools),
404
+ {True: "tools", False: "assistant"}
405
+ )
406
  builder.add_edge("tools", "assistant")
407
  graph = builder.compile()
408
 
 
421
  return self(question, file_paths)
422
 
423
  def _assistant_node(self, state: dict) -> dict:
424
+ """Process messages with the LLM."""
425
+ try:
426
+ # Check if messages exist
427
+ if not state["messages"]:
428
+ # Add a system message if empty
429
+ state["messages"].append(SystemMessage(content=SYSTEM_PROMPT))
430
+
431
+ # Invoke the chat model with our BaseMessage list
432
+ resp = self.llm(state["messages"])
433
+ state["messages"].append(resp)
434
+ return state
435
+ except Exception as e:
436
+ # Handle errors by adding an error message
437
+ error_msg = f"Error calling LLM: {str(e)}"
438
+ print(error_msg)
439
+ return state
440
 
441
 
442