mabelwang21 commited on
Commit
0b51467
·
1 Parent(s): 96229ca

update rag_search and build_retriever

Browse files
Files changed (1) hide show
  1. agent.py +55 -55
agent.py CHANGED
@@ -17,7 +17,7 @@ from langchain_community.document_loaders import AssemblyAIAudioTranscriptLoader
17
  from langchain.chat_models import init_chat_model
18
  from langchain.agents import initialize_agent, AgentType
19
  from langchain_community.retrievers import BM25Retriever
20
- from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
21
  from langgraph.graph.message import add_messages
22
  from langgraph.graph import START, StateGraph
23
  from langgraph.prebuilt import ToolNode, tools_condition
@@ -260,8 +260,8 @@ tools: List[StructuredTool] = [
260
 
261
  class AgentState(TypedDict):
262
  # The document provided
263
- input_file: Optional[str] # Contains file path (PDF/PNG)
264
- messages: Annotated[list[AnyMessage], add_messages]
265
 
266
  # === Agent Class ===
267
  class MyAgent:
@@ -333,7 +333,7 @@ class MyAgent:
333
  @tool(name="rag_search")
334
  def rag_search(query: str) -> str:
335
  """Retrieve top-3 relevant document chunks via BM25."""
336
- res = self.retriever.invoke(query)
337
  if res:
338
  return "\n\n".join([doc.page_content for doc in res[:3]])
339
  return ""
@@ -347,63 +347,63 @@ class MyAgent:
347
  question: str,
348
  file_paths: Optional[List[str]] = None
349
  ) -> str:
350
- # Prepare state graph
351
- state: Dict[str, Any] = {"messages": [], "input_file": None}
 
352
 
353
- # Use structured tool attributes
354
- tool_desc = "\n".join(f"{t.name}: {t.description}" for t in self.tools)
355
-
356
- # Enhanced system prompt with RAG guidance
357
- rag_prompt = """
358
- If the question seems to be about any loaded documents, ALWAYS:
359
- 1. Use the rag_search tool first to find relevant information
360
- 2. Base your answer on the retrieved content
361
- 3. If no relevant content is found, say so
362
- """
363
-
364
- sys_msg = SystemMessage(content=f"{SYSTEM_PROMPT}\n\n{rag_prompt if file_paths else ''}\n\nTools:\n{tool_desc}")
365
- state["messages"].append(sys_msg)
366
-
367
- # Optionally load RAG docs
368
- if file_paths:
369
- self.add_files(file_paths)
370
- self.build_retriever()
371
-
372
- # Add user question
373
- state["messages"].append(HumanMessage(content=question))
374
- if file_paths:
375
- state["input_file"] = file_paths
376
-
377
- # Build graph
378
- builder = StateGraph(dict)
379
- builder.add_node("assistant", self._assistant_node)
380
- builder.add_node("tools", ToolNode(self.tools))
381
- builder.add_edge(START, "assistant")
382
-
383
- # Updated tool detection logic
384
- builder.add_conditional_edges(
385
- "assistant",
386
- lambda s: any(t.name in s["messages"][-1].content for t in self.tools),
387
- "tools"
388
- )
389
- builder.add_edge("tools", "assistant")
390
- graph = builder.compile()
391
-
392
- # Use invoke() instead of run()
393
- out = graph.invoke(state)
394
- last_message = out["messages"][-1].content
395
-
396
- # Extract only the FINAL ANSWER part
397
- if "FINAL ANSWER:" in last_message:
398
- return last_message.split("FINAL ANSWER:")[-1].strip()
399
- return last_message.strip()
400
 
401
  def run(self, question: str, file_paths: Optional[List[str]] = None) -> str:
402
  return self(question, file_paths)
403
 
404
  def _assistant_node(self, state: dict) -> dict:
405
- # Invoke LLM on current messages
406
- resp = self.llm.invoke(state["messages"])
407
  state["messages"].append(resp)
408
  return state
409
 
 
17
  from langchain.chat_models import init_chat_model
18
  from langchain.agents import initialize_agent, AgentType
19
  from langchain_community.retrievers import BM25Retriever
20
+ from langchain.schema import BaseMessage, SystemMessage, HumanMessage
21
  from langgraph.graph.message import add_messages
22
  from langgraph.graph import START, StateGraph
23
  from langgraph.prebuilt import ToolNode, tools_condition
 
260
 
261
  class AgentState(TypedDict):
262
  # The document provided
263
+ input_file: Optional[List[str]] # Contains file path (PDF/PNG)
264
+ messages: Annotated[List[BaseMessage], add_messages]
265
 
266
  # === Agent Class ===
267
  class MyAgent:
 
333
  @tool(name="rag_search")
334
  def rag_search(query: str) -> str:
335
  """Retrieve top-3 relevant document chunks via BM25."""
336
+ res = self.retriever.get_relevant_documents(query)
337
  if res:
338
  return "\n\n".join([doc.page_content for doc in res[:3]])
339
  return ""
 
347
  question: str,
348
  file_paths: Optional[List[str]] = None
349
  ) -> str:
350
+ try:
351
+ # Prepare state graph
352
+ state: Dict[str, Any] = {"messages": [], "input_file": None}
353
 
354
+ # Use structured tool attributes
355
+ tool_desc = "\n".join(f"{t.name}: {t.description}" for t in self.tools)
356
+
357
+ # Enhanced system prompt with RAG guidance
358
+ rag_prompt = """
359
+ If the question seems to be about any loaded documents, ALWAYS:
360
+ 1. Use the rag_search tool first to find relevant information
361
+ 2. Base your answer on the retrieved content
362
+ 3. If no relevant content is found, say so
363
+ """
364
+
365
+ sys_msg = SystemMessage(content=f"{SYSTEM_PROMPT}\n\n{rag_prompt if file_paths else ''}\n\nTools:\n{tool_desc}")
366
+ state["messages"].append(sys_msg)
367
+
368
+ # Optionally load RAG docs
369
+ if file_paths:
370
+ self.add_files(file_paths)
371
+ self.build_retriever()
372
+
373
+ # Add user question
374
+ state["messages"].append(HumanMessage(content=question))
375
+ if file_paths:
376
+ state["input_file"] = file_paths
377
+
378
+ # Build graph
379
+ builder = StateGraph(dict)
380
+ builder.add_node("assistant", self._assistant_node)
381
+ builder.add_node("tools", ToolNode(self.tools))
382
+ builder.add_edge(START, "assistant")
383
+
384
+ # Always allow the assistant to hand off to the tools node
385
+ builder.add_edge("assistant", "tools")
386
+ # And then return from tools back to the assistant
387
+ builder.add_edge("tools", "assistant")
388
+ graph = builder.compile()
389
+
390
+ # Use invoke() instead of run()
391
+ out = graph.invoke(state)
392
+ last_message = out["messages"][-1].content
393
+
394
+ # Extract only the FINAL ANSWER part
395
+ if "FINAL ANSWER:" in last_message:
396
+ return last_message.split("FINAL ANSWER:")[-1].strip()
397
+ return last_message.strip()
398
+ except Exception as e:
399
+ return f"Error processing question: {e}"
 
400
 
401
  def run(self, question: str, file_paths: Optional[List[str]] = None) -> str:
402
  return self(question, file_paths)
403
 
404
  def _assistant_node(self, state: dict) -> dict:
405
+ # Invoke the chat model with our BaseMessage list
406
+ resp = self.llm(state["messages"])
407
  state["messages"].append(resp)
408
  return state
409