Spaces:
Sleeping
Sleeping
Commit ·
0b51467
1
Parent(s): 96229ca
update rag_search and build_retriever
Browse files
agent.py
CHANGED
|
@@ -17,7 +17,7 @@ from langchain_community.document_loaders import AssemblyAIAudioTranscriptLoader
|
|
| 17 |
from langchain.chat_models import init_chat_model
|
| 18 |
from langchain.agents import initialize_agent, AgentType
|
| 19 |
from langchain_community.retrievers import BM25Retriever
|
| 20 |
-
from
|
| 21 |
from langgraph.graph.message import add_messages
|
| 22 |
from langgraph.graph import START, StateGraph
|
| 23 |
from langgraph.prebuilt import ToolNode, tools_condition
|
|
@@ -260,8 +260,8 @@ tools: List[StructuredTool] = [
|
|
| 260 |
|
| 261 |
class AgentState(TypedDict):
|
| 262 |
# The document provided
|
| 263 |
-
input_file: Optional[str] # Contains file path (PDF/PNG)
|
| 264 |
-
messages: Annotated[
|
| 265 |
|
| 266 |
# === Agent Class ===
|
| 267 |
class MyAgent:
|
|
@@ -333,7 +333,7 @@ class MyAgent:
|
|
| 333 |
@tool(name="rag_search")
|
| 334 |
def rag_search(query: str) -> str:
|
| 335 |
"""Retrieve top-3 relevant document chunks via BM25."""
|
| 336 |
-
res = self.retriever.
|
| 337 |
if res:
|
| 338 |
return "\n\n".join([doc.page_content for doc in res[:3]])
|
| 339 |
return ""
|
|
@@ -347,63 +347,63 @@ class MyAgent:
|
|
| 347 |
question: str,
|
| 348 |
file_paths: Optional[List[str]] = None
|
| 349 |
) -> str:
|
| 350 |
-
|
| 351 |
-
|
|
|
|
| 352 |
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
return
|
| 399 |
-
return last_message.strip()
|
| 400 |
|
| 401 |
def run(self, question: str, file_paths: Optional[List[str]] = None) -> str:
|
| 402 |
return self(question, file_paths)
|
| 403 |
|
| 404 |
def _assistant_node(self, state: dict) -> dict:
|
| 405 |
-
# Invoke
|
| 406 |
-
resp = self.llm
|
| 407 |
state["messages"].append(resp)
|
| 408 |
return state
|
| 409 |
|
|
|
|
| 17 |
from langchain.chat_models import init_chat_model
|
| 18 |
from langchain.agents import initialize_agent, AgentType
|
| 19 |
from langchain_community.retrievers import BM25Retriever
|
| 20 |
+
from langchain.schema import BaseMessage, SystemMessage, HumanMessage
|
| 21 |
from langgraph.graph.message import add_messages
|
| 22 |
from langgraph.graph import START, StateGraph
|
| 23 |
from langgraph.prebuilt import ToolNode, tools_condition
|
|
|
|
| 260 |
|
| 261 |
class AgentState(TypedDict):
|
| 262 |
# The document provided
|
| 263 |
+
input_file: Optional[List[str]] # Contains file path (PDF/PNG)
|
| 264 |
+
messages: Annotated[List[BaseMessage], add_messages]
|
| 265 |
|
| 266 |
# === Agent Class ===
|
| 267 |
class MyAgent:
|
|
|
|
| 333 |
@tool(name="rag_search")
|
| 334 |
def rag_search(query: str) -> str:
|
| 335 |
"""Retrieve top-3 relevant document chunks via BM25."""
|
| 336 |
+
res = self.retriever.get_relevant_documents(query)
|
| 337 |
if res:
|
| 338 |
return "\n\n".join([doc.page_content for doc in res[:3]])
|
| 339 |
return ""
|
|
|
|
| 347 |
question: str,
|
| 348 |
file_paths: Optional[List[str]] = None
|
| 349 |
) -> str:
|
| 350 |
+
try:
|
| 351 |
+
# Prepare state graph
|
| 352 |
+
state: Dict[str, Any] = {"messages": [], "input_file": None}
|
| 353 |
|
| 354 |
+
# Use structured tool attributes
|
| 355 |
+
tool_desc = "\n".join(f"{t.name}: {t.description}" for t in self.tools)
|
| 356 |
+
|
| 357 |
+
# Enhanced system prompt with RAG guidance
|
| 358 |
+
rag_prompt = """
|
| 359 |
+
If the question seems to be about any loaded documents, ALWAYS:
|
| 360 |
+
1. Use the rag_search tool first to find relevant information
|
| 361 |
+
2. Base your answer on the retrieved content
|
| 362 |
+
3. If no relevant content is found, say so
|
| 363 |
+
"""
|
| 364 |
+
|
| 365 |
+
sys_msg = SystemMessage(content=f"{SYSTEM_PROMPT}\n\n{rag_prompt if file_paths else ''}\n\nTools:\n{tool_desc}")
|
| 366 |
+
state["messages"].append(sys_msg)
|
| 367 |
+
|
| 368 |
+
# Optionally load RAG docs
|
| 369 |
+
if file_paths:
|
| 370 |
+
self.add_files(file_paths)
|
| 371 |
+
self.build_retriever()
|
| 372 |
+
|
| 373 |
+
# Add user question
|
| 374 |
+
state["messages"].append(HumanMessage(content=question))
|
| 375 |
+
if file_paths:
|
| 376 |
+
state["input_file"] = file_paths
|
| 377 |
+
|
| 378 |
+
# Build graph
|
| 379 |
+
builder = StateGraph(dict)
|
| 380 |
+
builder.add_node("assistant", self._assistant_node)
|
| 381 |
+
builder.add_node("tools", ToolNode(self.tools))
|
| 382 |
+
builder.add_edge(START, "assistant")
|
| 383 |
+
|
| 384 |
+
# Always allow the assistant to hand off to the tools node
|
| 385 |
+
builder.add_edge("assistant", "tools")
|
| 386 |
+
# And then return from tools back to the assistant
|
| 387 |
+
builder.add_edge("tools", "assistant")
|
| 388 |
+
graph = builder.compile()
|
| 389 |
+
|
| 390 |
+
# Use invoke() instead of run()
|
| 391 |
+
out = graph.invoke(state)
|
| 392 |
+
last_message = out["messages"][-1].content
|
| 393 |
+
|
| 394 |
+
# Extract only the FINAL ANSWER part
|
| 395 |
+
if "FINAL ANSWER:" in last_message:
|
| 396 |
+
return last_message.split("FINAL ANSWER:")[-1].strip()
|
| 397 |
+
return last_message.strip()
|
| 398 |
+
except Exception as e:
|
| 399 |
+
return f"Error processing question: {e}"
|
|
|
|
| 400 |
|
| 401 |
def run(self, question: str, file_paths: Optional[List[str]] = None) -> str:
|
| 402 |
return self(question, file_paths)
|
| 403 |
|
| 404 |
def _assistant_node(self, state: dict) -> dict:
|
| 405 |
+
# Invoke the chat model with our BaseMessage list
|
| 406 |
+
resp = self.llm(state["messages"])
|
| 407 |
state["messages"].append(resp)
|
| 408 |
return state
|
| 409 |
|