File size: 7,574 Bytes
8816dfd 5dd4236 8816dfd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 | from typing import Dict, Any
from langchain_core.messages import ToolMessage
import json
import logging
# Logger for the tool-execution node. It is registered under a fixed name
# (not __name__), so logging configuration must target "ReAct Tool Execution".
logger = logging.getLogger("ReAct Tool Execution")
def _get_tools_from_registry(workflow_id: int):
    """
    Look up the tools registered for a workflow.

    Tool objects are not serializable, so they are kept in a global registry
    keyed by workflow ID instead of being stored in the graph state.

    Args:
        workflow_id: Key under which the tools were registered.

    Returns:
        Whatever ``_TOOLS_REGISTRY`` holds for ``workflow_id``.

    Raises:
        ValueError: If nothing is registered for ``workflow_id``.
    """
    # Deferred import; presumably avoids a circular import with graph_ReAct -- confirm.
    from ComputeAgent.graph.graph_ReAct import _TOOLS_REGISTRY

    registered = _TOOLS_REGISTRY.get(workflow_id)
    if registered is not None:
        return registered
    raise ValueError(f"Tools not found in registry for workflow_id: {workflow_id}")
def _extract_instance_details(result: Any):
    """
    Pull the instance ID and status out of a create_compute_instance result.

    The tool result may be a JSON string or a dict. The instance fields may sit
    at the top level or be nested under a ``"result"`` key. A result of any
    other type is treated as an empty payload, so both values come back as "".

    Args:
        result: Raw value returned by the tool's ``ainvoke``.

    Returns:
        An ``(instance_id, instance_status)`` pair. ``instance_id`` is the
        payload's ``"id"`` value ("" if absent). ``instance_status`` is the
        ``str()`` of its ``"status"`` value ("" if absent).

    Raises:
        ValueError: If ``result`` is a string that is not valid JSON
            (``json.JSONDecodeError`` subclasses ``ValueError``), or that
            decodes to something other than a JSON object.
    """
    if isinstance(result, str):
        payload = json.loads(result)
    elif isinstance(result, dict):
        payload = result
    else:
        payload = {}

    # json.loads can return a list, number, bool or None. Treating those like
    # a dict raises TypeError, which previously escaped the parse handling and
    # turned a successful tool call into an error message.
    if not isinstance(payload, dict):
        raise ValueError(f"expected a JSON object, got {type(payload).__name__}")

    nested = payload.get("result")
    instance_data = nested if isinstance(nested, dict) else payload
    return instance_data.get("id", ""), str(instance_data.get("status", ""))


def _store_instance_details(target: Dict[str, Any], result: Any) -> None:
    """
    Parse a create_compute_instance result and record it in ``target``.

    On success this sets ``instance_id``, ``instance_status`` and
    ``instance_created = True``. If the result cannot be parsed, it logs a
    warning and sets ``instance_created = False``; any earlier ``instance_id``
    and ``instance_status`` values are left untouched.
    """
    logger.info("create_compute_instance tool executed - storing instance details")
    logger.info(f"Raw result type: {type(result)}, value: {result}")
    try:
        instance_id, instance_status = _extract_instance_details(result)
    except ValueError as e:
        logger.warning(f"⚠️ Could not parse result from create_compute_instance: {e}")
        logger.warning(f"⚠️ Result: {result}")
        target["instance_created"] = False
        return

    logger.info(f"Extracted instance_id: '{instance_id}', status: '{instance_status}'")
    # Stored in the state for the generate node.
    target["instance_id"] = instance_id
    target["instance_status"] = instance_status
    target["instance_created"] = True
    logger.info(f"✅ Instance created and stored in state: {instance_id} (status: {instance_status})")


async def tool_execution_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Execute the approved tool calls and record their results in the state.

    Two tools get extra bookkeeping:

    * ``create_compute_instance``: the instance ID and status are parsed from
      the result into ``instance_id``, ``instance_status`` and
      ``instance_created``.
    * ``research``: ``researcher_used`` is set, and ``researcher_executed`` and
      ``skip_refinement`` are True for this round.

    Every approved call yields exactly one ``ToolMessage``. Its content is the
    tool output on success, or an error string if the tool is unknown or raises.

    Args:
        state: Current ReAct state. Reads ``approved_tool_calls`` (dicts with
            ``name``, ``id`` and ``args`` keys), ``workflow_id``, ``messages``
            and ``tool_results``.

    Returns:
        A shallow copy of ``state`` in which:

        * the new ``ToolMessage`` objects are appended to ``tool_results``
          and ``messages``;
        * ``approved_tool_calls`` is cleared;
        * ``current_step`` is ``"tool_execution_complete"``;
        * ``force_refinement`` is removed.

        If there is nothing to run, or the tools cannot be retrieved, the input
        ``state`` is returned unchanged. The input dict is never mutated.
    """
    approved_calls = state.get("approved_tool_calls", [])
    if not approved_calls:
        logger.info("ℹ️ No approved tool calls to execute")
        return state

    # Tools come from a registry keyed by workflow_id, because tool objects are
    # not serializable and cannot live in the state.
    workflow_id = state.get("workflow_id")
    if workflow_id is None:  # 0 is a valid int ID, so don't test truthiness
        logger.error("❌ No workflow_id in state - cannot retrieve tools")
        return state

    try:
        tools = _get_tools_from_registry(workflow_id)
    except ValueError as e:
        logger.error(f"❌ {e}")
        return state
    tools_dict = {tool.name: tool for tool in tools}
    logger.info(f"✅ Retrieved {len(tools)} tools from registry")

    # Write all updates into a copy so the caller's state is not mutated in place.
    updated_state = state.copy()
    tool_results = []
    researcher_executed = False

    logger.info(f"⚡ Executing {len(approved_calls)} approved tool call(s)")

    for tool_call in approved_calls:
        tool_name = tool_call['name']
        try:
            tool = tools_dict.get(tool_name)
            if tool is None:
                error_msg = f"Error: Tool {tool_name} not found."
                logger.error(error_msg)
                tool_results.append(
                    ToolMessage(content=error_msg, tool_call_id=tool_call['id'])
                )
                continue

            logger.info(f"Executing tool: {tool_name}")
            result = await tool.ainvoke(tool_call['args'])

            if tool_name == "create_compute_instance":
                _store_instance_details(updated_state, result)
            elif tool_name == "research":
                researcher_executed = True
                # Flag that the researcher was used; its result is kept for generation.
                updated_state["researcher_used"] = True
                logger.info("✅ Researcher tool completed - results stored for generation")
            else:
                logger.info(f"✅ Tool {tool_name} executed successfully")

            tool_results.append(
                ToolMessage(content=str(result), tool_call_id=tool_call['id'])
            )
        except Exception as e:
            # Report the failure to the model as a ToolMessage, not by raising,
            # so every approved call still gets a response.
            error_msg = f"Error executing tool {tool_name}: {str(e)}"
            logger.exception(error_msg)
            tool_results.append(
                ToolMessage(content=error_msg, tool_call_id=tool_call['id'])
            )

    # Append to earlier results so multi-tool scenarios keep the full history.
    updated_state["tool_results"] = list(updated_state.get("tool_results") or []) + tool_results
    updated_state["messages"] = list(updated_state.get("messages") or []) + tool_results
    updated_state["approved_tool_calls"] = []  # these calls have now been executed
    updated_state["researcher_executed"] = researcher_executed
    # Skip refinement when the researcher tool ran in this round.
    updated_state["skip_refinement"] = researcher_executed
    updated_state["current_step"] = "tool_execution_complete"
    # Clear the force_refinement flag now that tools have executed.
    updated_state.pop("force_refinement", None)

    # NOTE: Don't remove tools here - agent_reasoning needs them next.
    # Tools are only removed in terminal nodes (generate, tool_rejection_exit).
    logger.info(
        f"Tool execution completed: {len(tool_results)} new results, "
        f"{len(updated_state['tool_results'])} total results"
    )
    return updated_state