Spaces:
Sleeping
Sleeping
Commit ·
4ea88b5
1
Parent(s): 88b5dd6
update _assistant_node
Browse files
agent.py
CHANGED
|
@@ -266,20 +266,27 @@ class AgentState(TypedDict):
|
|
| 266 |
input_file: Optional[List[str]] # Contains file path (PDF/PNG)
|
| 267 |
messages: Annotated[List[BaseMessage], add_messages]
|
| 268 |
|
|
|
|
| 269 |
# === Agent Class ===
|
| 270 |
class MyAgent:
|
| 271 |
def __init__(
|
| 272 |
self,
|
| 273 |
-
model_name: str = "anthropic:claude-3-5-sonnet-latest",
|
| 274 |
temperature: float = 0.0
|
| 275 |
):
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
|
| 284 |
def add_files(self, file_paths: List[str]):
|
| 285 |
"""
|
|
@@ -391,23 +398,25 @@ class MyAgent:
|
|
| 391 |
if file_paths:
|
| 392 |
state["input_file"] = file_paths
|
| 393 |
|
| 394 |
-
# Build graph
|
| 395 |
builder = StateGraph(dict)
|
| 396 |
builder.add_node("assistant", self._assistant_node)
|
| 397 |
builder.add_node("tools", ToolNode(self.tools))
|
| 398 |
builder.add_edge(START, "assistant")
|
| 399 |
|
| 400 |
-
#
|
| 401 |
builder.add_conditional_edges(
|
| 402 |
"assistant",
|
| 403 |
-
|
| 404 |
-
|
| 405 |
)
|
| 406 |
builder.add_edge("tools", "assistant")
|
|
|
|
|
|
|
| 407 |
graph = builder.compile()
|
| 408 |
-
|
| 409 |
-
# Use invoke()
|
| 410 |
-
out = graph.invoke(state)
|
| 411 |
last_message = out["messages"][-1].content
|
| 412 |
|
| 413 |
# Extract only the FINAL ANSWER part
|
|
@@ -423,19 +432,31 @@ class MyAgent:
|
|
| 423 |
def _assistant_node(self, state: dict) -> dict:
|
| 424 |
"""Process messages with the LLM."""
|
| 425 |
try:
|
| 426 |
-
# Check if messages exist
|
| 427 |
-
if not state["messages"]:
|
| 428 |
# Add a system message if empty
|
| 429 |
-
state["messages"]
|
| 430 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
# Invoke the chat model with our BaseMessage list
|
| 432 |
-
resp = self.llm(state["messages"])
|
| 433 |
state["messages"].append(resp)
|
| 434 |
return state
|
| 435 |
except Exception as e:
|
| 436 |
-
# Handle errors by adding an error message
|
| 437 |
error_msg = f"Error calling LLM: {str(e)}"
|
| 438 |
print(error_msg)
|
|
|
|
|
|
|
|
|
|
| 439 |
return state
|
| 440 |
|
| 441 |
|
|
|
|
| 266 |
input_file: Optional[List[str]] # Contains file path (PDF/PNG)
|
| 267 |
messages: Annotated[List[BaseMessage], add_messages]
|
| 268 |
|
| 269 |
+
|
| 270 |
# === Agent Class ===
|
| 271 |
class MyAgent:
|
| 272 |
def __init__(
|
| 273 |
self,
|
| 274 |
+
model_name: str = "anthropic:claude-3-5-sonnet-latest", # <-- Use a valid model name
|
| 275 |
temperature: float = 0.0
|
| 276 |
):
|
| 277 |
+
try:
|
| 278 |
+
self.llm = init_chat_model(
|
| 279 |
+
model_name,
|
| 280 |
+
temperature=temperature
|
| 281 |
+
)
|
| 282 |
+
# Base tools
|
| 283 |
+
self.tools = tools
|
| 284 |
+
# RAG components
|
| 285 |
+
self.docs: List[Any] = []
|
| 286 |
+
self.retriever: Optional[BM25Retriever] = None
|
| 287 |
+
except Exception as e:
|
| 288 |
+
print(f"Error initializing LLM: {e}")
|
| 289 |
+
raise
|
| 290 |
|
| 291 |
def add_files(self, file_paths: List[str]):
|
| 292 |
"""
|
|
|
|
| 398 |
if file_paths:
|
| 399 |
state["input_file"] = file_paths
|
| 400 |
|
| 401 |
+
# Build graph with proper conditional edge to prevent loops
|
| 402 |
builder = StateGraph(dict)
|
| 403 |
builder.add_node("assistant", self._assistant_node)
|
| 404 |
builder.add_node("tools", ToolNode(self.tools))
|
| 405 |
builder.add_edge(START, "assistant")
|
| 406 |
|
| 407 |
+
# Fix conditional edges with better check
|
| 408 |
builder.add_conditional_edges(
|
| 409 |
"assistant",
|
| 410 |
+
tools_condition, # Use built-in tools_condition
|
| 411 |
+
"tools"
|
| 412 |
)
|
| 413 |
builder.add_edge("tools", "assistant")
|
| 414 |
+
|
| 415 |
+
# Add recursion_limit to prevent infinite loops
|
| 416 |
graph = builder.compile()
|
| 417 |
+
|
| 418 |
+
# Use invoke() with higher recursion limit
|
| 419 |
+
out = graph.invoke(state, {"recursion_limit": 10}) # Lower limit
|
| 420 |
last_message = out["messages"][-1].content
|
| 421 |
|
| 422 |
# Extract only the FINAL ANSWER part
|
|
|
|
| 432 |
def _assistant_node(self, state: dict) -> dict:
|
| 433 |
"""Process messages with the LLM."""
|
| 434 |
try:
|
| 435 |
+
# Check if messages exist and ensure proper format
|
| 436 |
+
if not state.get("messages") or len(state["messages"]) == 0:
|
| 437 |
# Add a system message if empty
|
| 438 |
+
state["messages"] = [SystemMessage(content=SYSTEM_PROMPT)]
|
| 439 |
+
|
| 440 |
+
# Ensure we have at least a system and user message
|
| 441 |
+
has_system = any(isinstance(m, SystemMessage) for m in state["messages"])
|
| 442 |
+
has_human = any(isinstance(m, HumanMessage) for m in state["messages"])
|
| 443 |
+
|
| 444 |
+
if not has_system:
|
| 445 |
+
state["messages"].insert(0, SystemMessage(content=SYSTEM_PROMPT))
|
| 446 |
+
|
| 447 |
+
if not has_human:
|
| 448 |
+
state["messages"].append(HumanMessage(content="Hello"))
|
| 449 |
+
|
| 450 |
# Invoke the chat model with our BaseMessage list
|
| 451 |
+
resp = self.llm.invoke(state["messages"])
|
| 452 |
state["messages"].append(resp)
|
| 453 |
return state
|
| 454 |
except Exception as e:
|
|
|
|
| 455 |
error_msg = f"Error calling LLM: {str(e)}"
|
| 456 |
print(error_msg)
|
| 457 |
+
print(f"Message count: {len(state.get('messages', []))}")
|
| 458 |
+
if state.get("messages"):
|
| 459 |
+
print(f"Message types: {[type(m).__name__ for m in state['messages']]}")
|
| 460 |
return state
|
| 461 |
|
| 462 |
|