| | from typing import Dict, Any |
| | from langchain_core.messages import ToolMessage |
| | import json |
| | import logging |
| |
|
# Module-level logger; records from this file are emitted under the
# logger name "ReAct Tool Execution".
logger = logging.getLogger("ReAct Tool Execution")
| |
|
| |
|
def _get_tools_from_registry(workflow_id: int):
    """
    Look up the tools registered for a workflow in the global registry.

    Tool objects are kept in the registry instead of the graph state so that
    non-serializable objects never have to be stored in state.

    Args:
        workflow_id: Key under which the workflow's tools were registered.

    Returns:
        The registry entry stored for ``workflow_id``.

    Raises:
        ValueError: If the registry has no entry (or a ``None`` entry) for
            ``workflow_id``.
    """
    # Imported at call time rather than at module load
    # (presumably to avoid a circular import - TODO confirm).
    from ComputeAgent.graph.graph_ReAct import _TOOLS_REGISTRY

    registered_tools = _TOOLS_REGISTRY.get(workflow_id)
    if registered_tools is not None:
        return registered_tools

    raise ValueError(f"Tools not found in registry for workflow_id: {workflow_id}")
| |
|
| |
|
def _extract_instance_details(result: Any) -> Dict[str, Any]:
    """
    Pull the instance id and status out of a create_compute_instance result.

    Args:
        result: Raw tool output. A ``str`` is parsed as JSON, a ``dict`` is
            used as-is, and any other type is treated as an empty payload.

    Returns:
        Dict with ``"instance_id"`` (the payload's ``"id"`` value, default
        ``""``) and ``"instance_status"`` (``str()`` of the payload's
        ``"status"`` value, default ``""``).

    Raises:
        ValueError: If the string is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError`` subclass) or the decoded value is not a JSON
            object.
    """
    if isinstance(result, str):
        payload = json.loads(result)
    elif isinstance(result, dict):
        payload = result
    else:
        payload = {}

    # json.loads may legitimately return a list, number, bool or None; none of
    # those can carry instance fields, so report them as unparseable instead
    # of letting a TypeError/AttributeError escape further down.
    if not isinstance(payload, dict):
        raise ValueError(
            f"expected a JSON object, got {type(payload).__name__}"
        )

    # When the payload holds a dict under "result", the instance fields are
    # read from that nested dict; otherwise from the payload itself.
    nested = payload.get("result")
    instance_data = nested if isinstance(nested, dict) else payload

    return {
        "instance_id": instance_data.get("id", ""),
        "instance_status": str(instance_data.get("status", "")),
    }


async def tool_execution_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Node that executes approved tools and handles special researcher tool case.

    Every approved tool call yields exactly one ToolMessage: the stringified
    tool result on success, or an error description when the tool is unknown
    or raises. Two tool names get extra handling:

    * ``create_compute_instance`` - the result is parsed and ``instance_id``,
      ``instance_status`` and ``instance_created`` are written to the
      returned state (``instance_created`` is False when parsing fails).
    * ``research`` - ``researcher_used``, ``researcher_executed`` and
      ``skip_refinement`` are set to True in the returned state.

    The input ``state`` is never mutated; all changes are made on a shallow
    copy.

    Args:
        state: Current ReAct state with approved tool calls. Reads
            ``approved_tool_calls``, ``workflow_id``, ``messages`` and
            ``tool_results``.

    Returns:
        Updated state with tool results and special handling for researcher.
        The input ``state`` object itself is returned unchanged when there is
        nothing to execute, ``workflow_id`` is missing, or no tools are
        registered for the workflow.
    """
    approved_calls = state.get("approved_tool_calls", [])
    if not approved_calls:
        logger.info("No approved tool calls to execute")
        return state

    # Compare against None so a workflow id of 0 is still usable; any id that
    # is not registered falls through to the ValueError branch below.
    workflow_id = state.get("workflow_id")
    if workflow_id is None:
        logger.error("No workflow_id in state - cannot retrieve tools")
        return state

    try:
        tools = _get_tools_from_registry(workflow_id)
    except ValueError as e:
        logger.error(f"{e}")
        return state

    tools_dict = {tool.name: tool for tool in tools}
    logger.info(f"Retrieved {len(tools)} tools from registry")

    # Work on a copy from the start so the caller's dict is left untouched.
    updated_state = state.copy()
    tool_results = []
    researcher_executed = False

    logger.info(f"Executing {len(approved_calls)} approved tool call(s)")

    for tool_call in approved_calls:
        tool_name = tool_call["name"]
        call_id = tool_call["id"]

        # Broad catch on purpose: one failing tool must not abort the node;
        # the failure is reported back as a ToolMessage instead.
        try:
            tool = tools_dict.get(tool_name)
            if tool is None:
                error_msg = f"Error: Tool {tool_name} not found."
                logger.error(error_msg)
                tool_results.append(
                    ToolMessage(content=error_msg, tool_call_id=call_id)
                )
                continue

            logger.info(f"Executing tool: {tool_name}")
            result = await tool.ainvoke(tool_call["args"])

            if tool_name == "create_compute_instance":
                logger.info(
                    "create_compute_instance tool executed - storing instance details"
                )
                logger.info(f"Raw result type: {type(result)}, value: {result}")
                try:
                    details = _extract_instance_details(result)
                except ValueError as e:
                    # The tool itself succeeded; only the bookkeeping failed,
                    # so the raw result is still passed on below.
                    logger.warning(
                        f"Could not parse result from create_compute_instance: {e}"
                    )
                    logger.warning(f"Result: {result}")
                    updated_state["instance_created"] = False
                else:
                    instance_id = details["instance_id"]
                    instance_status = details["instance_status"]
                    logger.info(
                        f"Extracted instance_id: '{instance_id}', status: '{instance_status}'"
                    )
                    updated_state["instance_id"] = instance_id
                    updated_state["instance_status"] = instance_status
                    updated_state["instance_created"] = True
                    logger.info(
                        f"Instance created and stored in state: {instance_id} (status: {instance_status})"
                    )
            elif tool_name == "research":
                researcher_executed = True
                logger.info("Researcher tool executed - storing results for generation")
                updated_state["researcher_used"] = True
                logger.info("Researcher tool completed - results stored for generation")
            else:
                logger.info(f"Tool {tool_name} executed successfully")

            tool_results.append(
                ToolMessage(content=str(result), tool_call_id=call_id)
            )
        except Exception as e:
            error_msg = f"Error executing tool {tool_name}: {str(e)}"
            # logger.exception keeps the traceback alongside the message.
            logger.exception(error_msg)
            tool_results.append(
                ToolMessage(content=error_msg, tool_call_id=call_id)
            )

    # Tolerate a missing or None history instead of raising.
    existing_results = updated_state.get("tool_results") or []
    updated_state["tool_results"] = list(existing_results) + tool_results
    updated_state["messages"] = list(state.get("messages") or []) + tool_results

    updated_state["approved_tool_calls"] = []
    updated_state["researcher_executed"] = researcher_executed
    updated_state["skip_refinement"] = researcher_executed
    updated_state["current_step"] = "tool_execution_complete"

    # Whenever the flag is set (now or by an earlier step), make sure the
    # companion keys exist.
    if updated_state.get("instance_created"):
        updated_state.setdefault("instance_id", "")
        updated_state.setdefault("instance_status", "")
        logger.info(
            f"Instance creation flags preserved in state: {updated_state['instance_id']}"
        )

    # force_refinement is always removed from the returned state.
    updated_state.pop("force_refinement", None)

    logger.info(
        f"Tool execution completed: {len(tool_results)} new results, "
        f"{len(updated_state['tool_results'])} total results"
    )

    return updated_state