import logging

from pydantic import BaseModel, Field
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage, filter_messages

from src.utils.prompts import (
    execution_agent_prompt,
    compress_execution_system_prompt,
    compress_execution_human_message,
)
from src.utils.utils import think_tool, track_package, estimated_time_analysis, get_today_str

logger = logging.getLogger(__name__)

tools = [think_tool, track_package, estimated_time_analysis]
tools_by_name = {tool.name: tool for tool in tools}


class ExecutorNode:
    """
    Executor node for handling tasks:
    1. LLM reasoning
    2. Tool invocation
    3. Final compression

    Every public method takes the graph state as a plain ``dict`` and reads
    the keys ``execution_job``, ``executor_messages``, ``executor_data`` and
    ``iteration_count`` from it (each is optional and defaulted when absent).
    """

    def __init__(self, llm):
        """Bind the module-level tools to *llm* and cache the prompt templates.

        Args:
            llm: Chat model object. It must provide ``bind_tools(tools)`` and
                ``invoke(messages)``; both are called by this class.
        """
        self.llm = llm
        self.tools = tools
        self.tools_by_name = {tool.name: tool for tool in tools}
        self.model_with_tools = llm.bind_tools(tools)
        # Increased to allow more tool calls (including think_tool)
        self.MAX_ITERATIONS = 6
        self.execution_agent_prompt_template = execution_agent_prompt
        self.compress_execution_system_prompt_template = compress_execution_system_prompt
        self.compress_execution_human_message = compress_execution_human_message

        logger.debug("Available tools: %s", list(self.tools_by_name.keys()))

    def llm_call(self, state: dict) -> dict:
        """Call the tool-bound LLM with the executor history and return the updated state.

        When the history is empty and ``execution_job`` is set, the job text is
        used as the initial human message. A system message built from the
        execution prompt (formatted with today's date) is always sent first.

        Returns:
            A copy of *state* with the LLM response appended to
            ``executor_messages``. If anything raises, a copy of *state* with
            an ``error`` string and the original ``executor_messages``.
        """
        try:
            execution_job = state.get("execution_job", "")
            existing_messages = state.get("executor_messages", [])
            logger.debug("Executor messages: %s", existing_messages)
            logger.debug("Execution job: %s", execution_job)

            # If no existing messages, add the execution job as initial human message
            if not existing_messages and execution_job:
                existing_messages = [HumanMessage(content=execution_job)]

            formatted_prompt = self.execution_agent_prompt_template.format(date=get_today_str())
            messages = [SystemMessage(content=formatted_prompt)] + existing_messages

            # `messages` always holds at least the system message, so [-1] is safe.
            logger.debug("Calling LLM with %d messages", len(messages))
            logger.debug("Last message: %s", messages[-1])

            response = self.model_with_tools.invoke(messages)

            logger.debug("LLM response type: %s", type(response))
            logger.debug(
                "LLM response content: %s...",
                response.content[:100] if response.content else "No content",
            )
            logger.debug(
                "Tool calls in response: %s",
                getattr(response, "tool_calls", "No tool_calls attribute"),
            )

            return {
                **state,
                "executor_messages": existing_messages + [response],
            }
        except Exception as e:
            # Node boundary: report the failure through state instead of
            # raising, but keep the traceback in the log.
            logger.exception("LLM call failed")
            return {
                **state,
                "error": str(e),
                "executor_messages": state.get("executor_messages", []),
            }

    def _run_tool_call(self, call: dict) -> tuple:
        """Run a single tool call and wrap its outcome in a ToolMessage.

        Returns:
            ``(tool_message, recorded)`` where *recorded* is the text to add
            to ``executor_data``: the stringified result or the failure
            message. It is ``None`` for an unknown tool, whose error is only
            reported back through the ToolMessage.
        """
        tool_name = call.get("name")
        args = call.get("args", {})
        tool_id = call.get("id")
        logger.debug("Tool: %s, Args: %s, ID: %s", tool_name, args, tool_id)

        tool = self.tools_by_name.get(tool_name)
        if tool is None:
            content = f"Tool {tool_name} not found. Available: {list(self.tools_by_name.keys())}"
            logger.warning(content)
            recorded = None
        else:
            try:
                content = str(tool.invoke(args))
                logger.debug("Tool %s result: %s", tool_name, content)
            except Exception as e:
                # A failing tool must not abort the remaining calls; the
                # failure text is returned in place of a result.
                content = f"Tool {tool_name} failed: {e}"
                logger.exception("Tool error: %s", content)
            recorded = content

        message = ToolMessage(content=content, name=tool_name, tool_call_id=tool_id)
        return message, recorded

    def tool_node(self, state: dict) -> dict:
        """Execute the tool calls on the last message and append ToolMessages.

        Returns:
            *state* itself when there are no messages or the last message
            carries no tool calls (a missing or ``None`` ``tool_calls``
            attribute counts as "no tool calls"). Otherwise a copy of *state*
            with one ToolMessage per call appended to ``executor_messages``
            and the tool results / failure texts appended to
            ``executor_data``. On an unexpected exception, a copy of *state*
            with an ``error`` string.
        """
        try:
            executor_messages = state.get("executor_messages", [])
            if not executor_messages:
                logger.debug("No executor messages found")
                return state

            last_message = executor_messages[-1]
            logger.debug("Last message type: %s", type(last_message))

            # `or []` also covers an attribute that exists but is None.
            tool_calls = getattr(last_message, "tool_calls", None) or []
            logger.debug("Found %d tool calls: %s", len(tool_calls), tool_calls)
            if not tool_calls:
                logger.debug("No tool calls found in last message")
                return state

            tool_outputs, new_data = [], []
            for call in tool_calls:
                logger.debug("Processing tool call: %s", call)
                message, recorded = self._run_tool_call(call)
                tool_outputs.append(message)
                if recorded is not None:
                    new_data.append(recorded)

            logger.debug("Returning %d tool outputs", len(tool_outputs))
            return {
                **state,
                "executor_messages": executor_messages + tool_outputs,
                "executor_data": state.get("executor_data", []) + new_data,
            }
        except Exception as e:
            logger.exception("Tool execution failed")
            return {
                **state,
                "error": f"Tool execution failed: {str(e)}",
            }

    def compress_execution(self, state: dict) -> dict:
        """Summarize the execution and return the final structured output.

        Sends the compression system prompt (formatted with today's date),
        the full executor history and a closing human message (formatted with
        the execution job) to the plain, non-tool-bound LLM.

        Returns:
            A dict with exactly ``output``, ``executor_data`` and
            ``executor_messages``. On success ``executor_data`` is rebuilt
            from the non-empty contents of the message history; on failure it
            is taken from *state* and ``output`` describes the error.
        """
        try:
            execution_job = state.get("execution_job", "Complete the assigned task")
            executor_messages = state.get("executor_messages", [])

            formatted_system_prompt = self.compress_execution_system_prompt_template.format(
                date=get_today_str()
            )
            messages = [
                SystemMessage(content=formatted_system_prompt),
                *executor_messages,
                HumanMessage(
                    content=self.compress_execution_human_message.format(
                        shipment_request=execution_job
                    )
                ),
            ]

            response = self.llm.invoke(messages)

            executor_data = [
                str(m.content)
                for m in executor_messages
                if hasattr(m, "content") and m.content
            ]

            return {
                "output": str(response.content),
                "executor_data": executor_data,
                "executor_messages": executor_messages,
            }
        except Exception as e:
            logger.exception("Compression of execution failed")
            return {
                "output": f"Execution completed with errors: {str(e)}",
                "executor_data": state.get("executor_data", []),
                "executor_messages": state.get("executor_messages", []),
            }

    def route_after_llm(self, state: dict) -> str:
        """Route: decide whether to call a tool or finalize.

        Returns:
            ``"tool_node"`` when the last executor message has a truthy
            ``tool_calls`` attribute, otherwise ``"compress_execution"``
            (also when there are no messages or routing itself raises).
        """
        try:
            executor_messages = state.get("executor_messages", [])
            if not executor_messages:
                return "compress_execution"

            last_msg = executor_messages[-1]
            has_tool_calls = bool(getattr(last_msg, "tool_calls", None))
            logger.debug("Routing decision - Has tool calls: %s", has_tool_calls)

            return "tool_node" if has_tool_calls else "compress_execution"
        except Exception:
            # Finalizing is the safe fallback, but do not hide the cause.
            logger.exception("Routing failed; falling back to compress_execution")
            return "compress_execution"

    def guard_llm(self, state: dict) -> str:
        """Prevent infinite loops by limiting iterations.

        Increments ``iteration_count`` in *state* in place on every call and
        returns ``"compress_execution"`` once the count exceeds
        ``MAX_ITERATIONS``; otherwise defers to ``route_after_llm``.
        """
        # NOTE(review): the counter is written by mutating the caller's dict.
        # If the calling framework passes a fresh copy of the state to routing
        # functions, the increment will not persist between calls -- confirm.
        iteration_count = state.get("iteration_count", 0) + 1
        state["iteration_count"] = iteration_count

        logger.debug("Iteration count: %d/%d", iteration_count, self.MAX_ITERATIONS)

        if iteration_count > self.MAX_ITERATIONS:
            logger.info("Max iterations reached, finalizing...")
            return "compress_execution"

        return self.route_after_llm(state)