carraraig's picture
finish (#8)
5dd4236 verified
from typing import Dict, Any
from langchain_core.messages import ToolMessage
import json
import logging
# Module-level logger for this node; note it uses the fixed name
# "ReAct Tool Execution" rather than __name__.
logger = logging.getLogger("ReAct Tool Execution")
def _get_tools_from_registry(workflow_id: int):
    """Look up the tool objects registered for ``workflow_id``.

    Tools are kept in ``_TOOLS_REGISTRY`` (defined in
    ``ComputeAgent.graph.graph_ReAct``) instead of in the graph state, so
    non-serializable tool objects never have to live in state.

    Args:
        workflow_id: Key under which the workflow's tools were registered.

    Returns:
        Whatever object the registry holds for ``workflow_id``.

    Raises:
        ValueError: If the registry has no entry for ``workflow_id`` (or the
            stored entry is ``None``).
    """
    # Imported at call time rather than at module load; presumably to avoid a
    # circular import with the graph module -- TODO confirm.
    from ComputeAgent.graph.graph_ReAct import _TOOLS_REGISTRY

    registered_tools = _TOOLS_REGISTRY.get(workflow_id)
    if registered_tools is not None:
        return registered_tools
    raise ValueError(f"Tools not found in registry for workflow_id: {workflow_id}")
def _extract_instance_payload(result: Any) -> Dict[str, Any]:
    """Return the instance fields carried by a ``create_compute_instance`` result.

    The result may be a JSON string or a dict, and the instance fields may be
    nested under a ``"result"`` key (used only when that value is itself a
    dict). Any other result type yields an empty dict.

    Raises:
        ValueError: If ``result`` is a string that is not valid JSON
            (``json.JSONDecodeError`` subclasses ``ValueError``) or whose JSON
            value is not an object.
    """
    if isinstance(result, str):
        payload = json.loads(result)
        # A JSON array/number/string has no fields to read. Without this check
        # the membership test below could raise TypeError or silently perform
        # a substring match on a string.
        if not isinstance(payload, dict):
            raise ValueError(f"expected a JSON object, got {type(payload).__name__}")
    elif isinstance(result, dict):
        payload = result
    else:
        payload = {}
    nested = payload.get("result")
    return nested if isinstance(nested, dict) else payload


def _record_instance_details(result: Any, state: Dict[str, Any]) -> None:
    """Parse a ``create_compute_instance`` result into ``state`` (mutated in place).

    On success, sets ``instance_id``, ``instance_status`` (stringified) and
    ``instance_created = True``. If the result cannot be parsed, only
    ``instance_created = False`` is set; any previous ``instance_id`` or
    ``instance_status`` values are left as they were.
    """
    logger.info("πŸš€ create_compute_instance tool executed - storing instance details")
    logger.info(f"πŸ“‹ Raw result type: {type(result)}, value: {result}")
    try:
        instance_data = _extract_instance_payload(result)
    except ValueError as e:
        logger.warning(f"⚠️ Could not parse result from create_compute_instance: {e}")
        logger.warning(f"⚠️ Result: {result}")
        state["instance_created"] = False
        return

    instance_id = instance_data.get("id", "")
    instance_status = str(instance_data.get("status", ""))
    logger.info(f"πŸ“‹ Extracted instance_id: '{instance_id}', status: '{instance_status}'")

    # Stored in state for the generate node.
    state["instance_id"] = instance_id
    state["instance_status"] = instance_status
    state["instance_created"] = True
    logger.info(f"βœ… Instance created and stored in state: {instance_id} (status: {instance_status})")


async def tool_execution_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """Run every approved tool call and return an updated copy of the state.

    Each entry of ``state["approved_tool_calls"]`` is read as a mapping with
    ``"name"``, ``"args"`` and ``"id"`` keys. Each tool is looked up by name
    among the tools registered for ``state["workflow_id"]`` and awaited via
    ``ainvoke``. Every call produces exactly one ``ToolMessage`` tied to the
    call's id. That holds on success, for an unknown tool, and when the tool
    raises.

    Two tools get extra handling:

    * ``create_compute_instance``: the result is parsed, and ``instance_id``,
      ``instance_status`` and ``instance_created`` are stored on the returned
      state.
    * ``research``: sets ``researcher_used``, and makes ``researcher_executed``
      and ``skip_refinement`` true.

    Args:
        state: Current ReAct state with approved tool calls.

    Returns:
        The input ``state`` object, unchanged, when there is nothing to run or
        the tools cannot be retrieved. Otherwise, a shallow copy in which:

        * the new ToolMessages are appended to ``tool_results`` and ``messages``;
        * ``approved_tool_calls`` is cleared;
        * ``current_step`` is ``"tool_execution_complete"``;
        * ``force_refinement`` is removed.

        The input dict itself is never modified.
    """
    approved_calls = state.get("approved_tool_calls", [])
    if not approved_calls:
        logger.info("ℹ️ No approved tool calls to execute")
        return state

    # Tools come from a registry keyed by workflow_id instead of the state
    # itself (avoids serialization issues). Compare against None so that a
    # legitimate falsy id such as 0 is not treated as missing.
    workflow_id = state.get("workflow_id")
    if workflow_id is None:
        logger.error("❌ No workflow_id in state - cannot retrieve tools")
        return state
    try:
        tools = _get_tools_from_registry(workflow_id)
    except ValueError as e:
        logger.error(f"❌ {e}")
        return state
    tools_dict = {tool.name: tool for tool in tools}
    logger.info(f"βœ… Retrieved {len(tools)} tools from registry")

    # Work on a copy so the caller's state dict is never mutated in place.
    updated_state = state.copy()
    tool_results = []
    researcher_executed = False

    logger.info(f"⚑ Executing {len(approved_calls)} approved tool call(s)")
    for tool_call in approved_calls:
        tool_name = tool_call['name']
        try:
            tool = tools_dict.get(tool_name)
            if not tool:
                error_msg = f"Error: Tool {tool_name} not found."
                logger.error(error_msg)
                tool_results.append(ToolMessage(content=error_msg, tool_call_id=tool_call['id']))
                continue

            logger.info(f"πŸ”„ Executing tool: {tool_name}")
            result = await tool.ainvoke(tool_call['args'])
            message = ToolMessage(content=str(result), tool_call_id=tool_call['id'])

            if tool_name == "create_compute_instance":
                _record_instance_details(result, updated_state)
                tool_results.append(message)
            elif tool_name == "research":
                researcher_executed = True
                logger.info("🌐 Researcher tool executed - storing results for generation")
                updated_state["researcher_used"] = True
                tool_results.append(message)
                logger.info("βœ… Researcher tool completed - results stored for generation")
            else:
                tool_results.append(message)
                logger.info(f"βœ… Tool {tool_name} executed successfully")
        except Exception as e:
            # Report the failure back to the agent instead of aborting the batch;
            # logger.exception keeps the traceback in the logs.
            error_msg = f"Error executing tool {tool_name}: {str(e)}"
            logger.exception(error_msg)
            tool_results.append(ToolMessage(content=error_msg, tool_call_id=tool_call['id']))

    # Append (not replace) so results accumulate across multiple tool rounds;
    # `or []` also tolerates keys that are present but None.
    updated_state["tool_results"] = (updated_state.get("tool_results") or []) + tool_results
    updated_state["messages"] = (state.get("messages") or []) + tool_results
    updated_state["approved_tool_calls"] = []  # these calls have now been consumed
    updated_state["researcher_executed"] = researcher_executed
    # Refinement is skipped whenever the researcher ran in this batch.
    updated_state["skip_refinement"] = researcher_executed
    updated_state["current_step"] = "tool_execution_complete"

    if updated_state.get("instance_created"):
        updated_state.setdefault("instance_id", "")
        updated_state.setdefault("instance_status", "")
        logger.info(f"βœ… Instance creation flags preserved in state: {updated_state['instance_id']}")

    # force_refinement is a one-shot flag; clear it after tool execution.
    updated_state.pop("force_refinement", None)

    # NOTE: Don't remove tools here - agent_reasoning needs them next.
    # Tools are only removed in terminal nodes (generate, tool_rejection_exit).
    logger.info(f"πŸ“ˆ Tool execution completed: {len(tool_results)} new results, {len(updated_state['tool_results'])} total results")
    return updated_state