File size: 7,574 Bytes
8816dfd
 
 
 
 
 
 
 
 
 
 
 
 
5dd4236
8816dfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
from typing import Dict, Any
from langchain_core.messages import ToolMessage
import json
import logging

# Module-level logger. It is named explicitly (rather than via __name__), so
# records from this module are emitted under "ReAct Tool Execution".
logger = logging.getLogger("ReAct Tool Execution")


def _get_tools_from_registry(workflow_id: int):
    """
    Look up the tool list registered for ``workflow_id``.

    Tools are kept in the ``_TOOLS_REGISTRY`` mapping of ``graph_ReAct``,
    keyed by workflow ID, so the non-serializable tool objects never have to
    be stored in the graph state itself.

    Args:
        workflow_id: Key under which the workflow's tools were registered.

    Returns:
        The registered tools for that workflow.

    Raises:
        ValueError: If nothing is registered for ``workflow_id``.
    """
    # Imported at call time rather than at module load; presumably this
    # sidesteps a circular import with graph_ReAct -- TODO confirm.
    from ComputeAgent.graph.graph_ReAct import _TOOLS_REGISTRY

    registered_tools = _TOOLS_REGISTRY.get(workflow_id)
    if registered_tools is not None:
        return registered_tools
    raise ValueError(f"Tools not found in registry for workflow_id: {workflow_id}")


def _parse_instance_result(result: Any) -> Dict[str, Any]:
    """
    Extract instance details from a ``create_compute_instance`` tool result.

    The result may be a JSON string or a dict, and the instance payload may be
    nested under a ``"result"`` key. A result of any other type yields empty
    details rather than an error.

    Args:
        result: Raw value returned by the tool's ``ainvoke``.

    Returns:
        Dict with ``instance_id`` (the payload's ``"id"``, default ``""``) and
        ``instance_status`` (the payload's ``"status"`` converted to ``str``,
        default ``""``).

    Raises:
        ValueError: If a string result is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass) or does
            not decode to a JSON object.
    """
    if isinstance(result, str):
        parsed = json.loads(result)
    elif isinstance(result, dict):
        parsed = result
    else:
        parsed = {}

    # json.loads may return a list, number, string, etc. Only an object can
    # carry instance fields; rejecting other shapes here stops them from
    # raising TypeError/AttributeError in the lookups below.
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")

    # The result may be nested in a 'result' key
    nested = parsed.get("result")
    instance_data = nested if isinstance(nested, dict) else parsed
    return {
        "instance_id": instance_data.get("id", ""),
        "instance_status": str(instance_data.get("status", "")),
    }


async def tool_execution_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Execute the approved tool calls in ``state`` and return the updated state.

    Tools are looked up by name in the registry entry for
    ``state["workflow_id"]`` and awaited via ``ainvoke`` with each call's
    ``args``. Two tools get extra handling:

    * ``create_compute_instance``: the result is parsed for the instance id
      and status. These are stored as ``instance_id`` / ``instance_status``
      with ``instance_created`` set to ``True``. If parsing fails,
      ``instance_created`` is set to ``False``.
    * ``research``: sets ``researcher_used``, and the node reports
      ``researcher_executed`` / ``skip_refinement`` as ``True``.

    An unknown tool name, or an exception raised by a tool, is reported as an
    error ``ToolMessage`` rather than propagated. A result that cannot be
    parsed never replaces the tool's actual output.

    Args:
        state: Current ReAct state. Reads ``approved_tool_calls`` (each with
            ``name``, ``id`` and ``args`` keys), ``workflow_id``,
            ``messages`` and ``tool_results``.

    Returns:
        ``state`` itself, unchanged, if there are no approved calls, no
        ``workflow_id``, or no tools registered for it.

        Otherwise, a shallow copy in which:

        * the new ``ToolMessage`` objects are appended to ``tool_results``
          and ``messages``;
        * ``approved_tool_calls`` is cleared;
        * ``researcher_executed`` and ``skip_refinement`` are set;
        * ``current_step`` is ``"tool_execution_complete"``;
        * ``force_refinement`` is removed.

        The input dict is never modified.
    """
    approved_calls = state.get("approved_tool_calls", [])

    if not approved_calls:
        logger.info("ℹ️ No approved tool calls to execute")
        return state

    # Get tools from registry using workflow_id (avoids serialization issues).
    # Compare against None so that a workflow_id of 0 is still accepted.
    workflow_id = state.get("workflow_id")
    if workflow_id is None:
        logger.error("❌ No workflow_id in state - cannot retrieve tools")
        return state

    try:
        tools = _get_tools_from_registry(workflow_id)
    except ValueError as e:
        logger.error(f"❌ {e}")
        return state
    tools_dict = {tool.name: tool for tool in tools}
    logger.info(f"✅ Retrieved {len(tools)} tools from registry")

    # Work on a copy so the caller's state dict is never mutated; every flag
    # set below is written to this copy only.
    updated_state = state.copy()
    tool_results = []
    researcher_executed = False

    logger.info(f"⚑ Executing {len(approved_calls)} approved tool call(s)")

    for tool_call in approved_calls:
        tool_name = tool_call['name']

        tool = tools_dict.get(tool_name)
        if tool is None:
            error_msg = f"Error: Tool {tool_name} not found."
            logger.error(error_msg)
            tool_results.append(
                ToolMessage(content=error_msg, tool_call_id=tool_call['id'])
            )
            continue

        # Only the invocation itself is guarded. A failing tool is reported
        # back to the model as a ToolMessage instead of aborting the node.
        try:
            logger.info(f"🔄 Executing tool: {tool_name}")
            result = await tool.ainvoke(tool_call['args'])
        except Exception as e:
            error_msg = f"Error executing tool {tool_name}: {str(e)}"
            logger.error(error_msg)
            tool_results.append(
                ToolMessage(content=error_msg, tool_call_id=tool_call['id'])
            )
            continue

        if tool_name == "create_compute_instance":
            logger.info("🚀 create_compute_instance tool executed - storing instance details")
            logger.info(f"📋 Raw result type: {type(result)}, value: {result}")
            try:
                details = _parse_instance_result(result)
            except ValueError as e:
                # Best-effort: the raw result is still returned to the model
                # below; only the instance flags are withheld.
                logger.warning(f"⚠️ Could not parse result from create_compute_instance: {e}")
                logger.warning(f"⚠️ Result: {result}")
                updated_state["instance_created"] = False
            else:
                instance_id = details["instance_id"]
                instance_status = details["instance_status"]
                logger.info(f"📋 Extracted instance_id: '{instance_id}', status: '{instance_status}'")

                # Store in state for generate node
                updated_state["instance_id"] = instance_id
                updated_state["instance_status"] = instance_status
                updated_state["instance_created"] = True
                logger.info(f"✅ Instance created and stored in state: {instance_id} (status: {instance_status})")
        elif tool_name == "research":
            researcher_executed = True
            logger.info("🌐 Researcher tool executed - storing results for generation")
            # Set flag to indicate researcher was used
            updated_state["researcher_used"] = True
            logger.info("✅ Researcher tool completed - results stored for generation")
        else:
            logger.info(f"✅ Tool {tool_name} executed successfully")

        tool_results.append(
            ToolMessage(content=str(result), tool_call_id=tool_call['id'])
        )

    # Append new tool results to existing ones for multi-tool scenarios
    updated_state["tool_results"] = updated_state.get("tool_results", []) + tool_results
    updated_state["messages"] = state.get("messages", []) + tool_results
    updated_state["approved_tool_calls"] = []  # Clear approved calls
    updated_state["researcher_executed"] = researcher_executed
    # Skip refinement if researcher executed
    updated_state["skip_refinement"] = researcher_executed
    updated_state["current_step"] = "tool_execution_complete"

    # Clear force_refinement flag after tool execution
    updated_state.pop("force_refinement", None)

    # NOTE: Don't remove tools here - agent_reasoning needs them next
    # Tools are only removed in terminal nodes (generate, tool_rejection_exit)

    logger.info(f"📈 Tool execution completed: {len(tool_results)} new results, {len(updated_state['tool_results'])} total results")

    return updated_state