Spaces:
Sleeping
Sleeping
Commit ·
88b5dd6
1
Parent(s): 0b51467
update rag_search in build_retrievier
Browse files
agent.py
CHANGED
|
@@ -8,6 +8,8 @@ from typing import List, TypedDict, Annotated, Optional
|
|
| 8 |
import requests
|
| 9 |
from urllib.parse import urlparse
|
| 10 |
import shutil
|
|
|
|
|
|
|
| 11 |
|
| 12 |
from langchain.tools import tool, StructuredTool
|
| 13 |
from langchain_community.document_loaders import (
|
|
@@ -21,6 +23,7 @@ from langchain.schema import BaseMessage, SystemMessage, HumanMessage
|
|
| 21 |
from langgraph.graph.message import add_messages
|
| 22 |
from langgraph.graph import START, StateGraph
|
| 23 |
from langgraph.prebuilt import ToolNode, tools_condition
|
|
|
|
| 24 |
|
| 25 |
from youtube_transcript_api import YouTubeTranscriptApi
|
| 26 |
from PIL import Image
|
|
@@ -328,19 +331,32 @@ class MyAgent:
|
|
| 328 |
"""
|
| 329 |
if not self.docs:
|
| 330 |
return
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
|
| 345 |
def __call__(
|
| 346 |
self,
|
|
@@ -382,8 +398,11 @@ class MyAgent:
|
|
| 382 |
builder.add_edge(START, "assistant")
|
| 383 |
|
| 384 |
# Always allow the assistant to hand off to the tools node
|
| 385 |
-
builder.
|
| 386 |
-
|
|
|
|
|
|
|
|
|
|
| 387 |
builder.add_edge("tools", "assistant")
|
| 388 |
graph = builder.compile()
|
| 389 |
|
|
@@ -402,10 +421,22 @@ class MyAgent:
|
|
| 402 |
return self(question, file_paths)
|
| 403 |
|
| 404 |
def _assistant_node(self, state: dict) -> dict:
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
|
| 410 |
|
| 411 |
|
|
|
|
| 8 |
import requests
|
| 9 |
from urllib.parse import urlparse
|
| 10 |
import shutil
|
| 11 |
+
import io
|
| 12 |
+
from typing import Dict, Any
|
| 13 |
|
| 14 |
from langchain.tools import tool, StructuredTool
|
| 15 |
from langchain_community.document_loaders import (
|
|
|
|
| 23 |
from langgraph.graph.message import add_messages
|
| 24 |
from langgraph.graph import START, StateGraph
|
| 25 |
from langgraph.prebuilt import ToolNode, tools_condition
|
| 26 |
+
from langchain_core.documents import Document
|
| 27 |
|
| 28 |
from youtube_transcript_api import YouTubeTranscriptApi
|
| 29 |
from PIL import Image
|
|
|
|
| 331 |
"""
|
| 332 |
if not self.docs:
|
| 333 |
return
|
| 334 |
+
|
| 335 |
+
# Build retriever
|
| 336 |
+
try:
|
| 337 |
+
self.retriever = BM25Retriever.from_documents(self.docs)
|
| 338 |
+
|
| 339 |
+
# Define tool with proper error handling
|
| 340 |
+
@tool(name="rag_search")
|
| 341 |
+
def rag_search(query: str) -> str:
|
| 342 |
+
"""Search loaded documents for relevant information."""
|
| 343 |
+
try:
|
| 344 |
+
if not self.retriever:
|
| 345 |
+
return "No documents have been loaded for search."
|
| 346 |
+
|
| 347 |
+
res = self.retriever.get_relevant_documents(query)
|
| 348 |
+
if res:
|
| 349 |
+
return "\n\n".join(f"Document {i+1}:\n{doc.page_content}"
|
| 350 |
+
for i, doc in enumerate(res[:3]))
|
| 351 |
+
return "No relevant information found in loaded documents."
|
| 352 |
+
except Exception as e:
|
| 353 |
+
return f"Error searching documents: {e}"
|
| 354 |
+
|
| 355 |
+
# Remove existing rag_search if present to prevent duplicates
|
| 356 |
+
self.tools = [t for t in self.tools if t.name != "rag_search"]
|
| 357 |
+
self.tools.append(rag_search)
|
| 358 |
+
except Exception as e:
|
| 359 |
+
print(f"Error building retriever: {e}")
|
| 360 |
|
| 361 |
def __call__(
|
| 362 |
self,
|
|
|
|
| 398 |
builder.add_edge(START, "assistant")
|
| 399 |
|
| 400 |
# Always allow the assistant to hand off to the tools node
|
| 401 |
+
builder.add_conditional_edges(
|
| 402 |
+
"assistant",
|
| 403 |
+
lambda s: any(t.name in s["messages"][-1].content for t in self.tools),
|
| 404 |
+
{True: "tools", False: "assistant"}
|
| 405 |
+
)
|
| 406 |
builder.add_edge("tools", "assistant")
|
| 407 |
graph = builder.compile()
|
| 408 |
|
|
|
|
| 421 |
return self(question, file_paths)
|
| 422 |
|
| 423 |
def _assistant_node(self, state: dict) -> dict:
|
| 424 |
+
"""Process messages with the LLM."""
|
| 425 |
+
try:
|
| 426 |
+
# Check if messages exist
|
| 427 |
+
if not state["messages"]:
|
| 428 |
+
# Add a system message if empty
|
| 429 |
+
state["messages"].append(SystemMessage(content=SYSTEM_PROMPT))
|
| 430 |
+
|
| 431 |
+
# Invoke the chat model with our BaseMessage list
|
| 432 |
+
resp = self.llm(state["messages"])
|
| 433 |
+
state["messages"].append(resp)
|
| 434 |
+
return state
|
| 435 |
+
except Exception as e:
|
| 436 |
+
# Handle errors by adding an error message
|
| 437 |
+
error_msg = f"Error calling LLM: {str(e)}"
|
| 438 |
+
print(error_msg)
|
| 439 |
+
return state
|
| 440 |
|
| 441 |
|
| 442 |
|