"""
Agent Reasoning Node for ReAct Pattern - Enhanced Version

This module implements enhanced agent reasoning with support for:
1. Initial tool selection based on query
2. Re-reasoning after tool execution with results
3. Re-reasoning after user feedback/modifications
4. Memory context integration

Key Enhancements:
    - User feedback integration for re-reasoning
    - Modified tool context awareness
    - Conversation history preservation
    - Memory-enhanced reasoning

Author: ComputeAgent Team
"""

from typing import Dict, Any
from langchain_core.messages import HumanMessage
from constant import Constants
import logging

logger = logging.getLogger("ReAct Agent Reasoning")


def _get_llm_from_registry(workflow_id: int):
    """
    Get LLM from the global registry using workflow ID.
    This avoids storing non-serializable LLM objects in state.
    """
    from ComputeAgent.graph.graph_ReAct import _LLM_REGISTRY
    llm = _LLM_REGISTRY.get(workflow_id)
    if llm is None:
        raise ValueError(f"LLM not found in registry for workflow_id: {workflow_id}")
    return llm


async def agent_reasoning_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Enhanced agent reasoning node that handles initial reasoning and re-reasoning.

    Supports three reasoning scenarios:
    1. Initial reasoning: Fresh query, no prior tool executions
    2. Post-execution reasoning: After tools executed, decide if more tools needed
    3. Re-reasoning: After user feedback/modifications, reconsider approach

    Special handling for deployment workflow:
    - Detects when in deployment mode (capacity_approved=True)
    - Provides specific instructions for calling create_compute_instance
    - Passes deployment parameters from capacity estimation

    Args:
        state: Current ReAct state

    Returns:
        Updated state with tool calls or completion decision
    """
    # Extract state information
    query = state.get("query", "")
    messages = state.get("messages", [])
    tool_results = state.get("tool_results", [])
    user_id = state.get("user_id", "")
    session_id = state.get("session_id", "")
    needs_re_reasoning = state.get("needs_re_reasoning", False)
    re_reasoning_feedback = state.get("re_reasoning_feedback", "")
    modified_tool_calls = state.get("modified_tool_calls", [])

    # Extract deployment-specific information
    capacity_approved = state.get("capacity_approved", False)
    model_name = state.get("model_name", "")
    model_info = state.get("model_info", {})
    gpu_requirements = state.get("gpu_requirements", {})
    estimated_gpu_memory = state.get("estimated_gpu_memory", 0)

    # Get LLM from registry using workflow_id (avoids serialization issues)
    workflow_id = state.get("workflow_id")
    if not workflow_id:
        logger.error("❌ No workflow_id in state - cannot retrieve LLM")
        updated_state = state.copy()
        updated_state["pending_tool_calls"] = []
        updated_state["current_step"] = "agent_reasoning_error"
        updated_state["error"] = "Missing workflow_id"
        return updated_state

    try:
        llm = _get_llm_from_registry(workflow_id)
        logger.info(f"✅ Retrieved LLM from registry for workflow_id: {workflow_id}")
    except ValueError as e:
        logger.error(f"❌ {e}")
        updated_state = state.copy()
        updated_state["pending_tool_calls"] = []
        updated_state["current_step"] = "agent_reasoning_error"
        updated_state["error"] = str(e)
        return updated_state
    
    # Determine reasoning scenario
    if needs_re_reasoning:
        logger.info("🔄 Re-reasoning mode: User requested reconsideration")
        reasoning_mode = "re_reasoning"
    elif tool_results:
        logger.info("🔄 Post-execution mode: Evaluating if more tools needed")
        reasoning_mode = "post_execution"
    else:
        logger.info("🎯 Initial reasoning mode: Processing fresh query")
        reasoning_mode = "initial"
    
    # Build memory context if available
    memory_context = ""
    if user_id and session_id:
        try:
            from helpers.memory import get_memory_manager
            memory_manager = get_memory_manager()
            memory_context = await memory_manager.build_context_for_node(
                user_id, 
                session_id, 
                "agent_reasoning"
            )
            if memory_context:
                logger.info("🧠 Using memory context for reasoning")
        except Exception as e:
            logger.warning(f"⚠️ Could not load memory context: {e}")
    
    # Build reasoning prompt based on scenario
    reasoning_prompt = _build_reasoning_prompt(
        query=query,
        reasoning_mode=reasoning_mode,
        memory_context=memory_context,
        tool_results=tool_results,
        re_reasoning_feedback=re_reasoning_feedback,
        modified_tool_calls=modified_tool_calls,
        # Pass deployment context
        capacity_approved=capacity_approved,
        model_name=model_name,
        model_info=model_info,
        gpu_requirements=gpu_requirements,
        estimated_gpu_memory=estimated_gpu_memory
    )
    
    # Prepare messages for LLM - ALWAYS include conversation history for context
    if messages:
        # Include conversation history so agent can reference previous responses
        llm_messages = messages + [HumanMessage(content=reasoning_prompt)]
        logger.info(f"📝 Including {len(messages)} previous messages for context")
    else:
        # First message in conversation
        llm_messages = [HumanMessage(content=reasoning_prompt)]
        logger.info("📝 Starting new conversation (no previous messages)")
    
    logger.info(f"🤖 Invoking LLM for {reasoning_mode} reasoning...")
    
    try:
        # Invoke LLM with tools bound
        response = await llm.ainvoke(llm_messages)
        
        # Extract tool calls if any
        tool_calls = []
        if hasattr(response, 'tool_calls') and response.tool_calls:
            tool_calls = [
                {
                    "id": tc.get("id", f"call_{i}"),
                    "name": tc.get("name"),
                    "args": tc.get("args", {})
                }
                for i, tc in enumerate(response.tool_calls)
            ]
            logger.info(f"🔧 Agent selected {len(tool_calls)} tool(s)")
        else:
            logger.info("✅ Agent decided no tools needed - ready to generate response")
        
        # Update state
        updated_state = state.copy()
        updated_state["messages"] = llm_messages + [response]
        updated_state["pending_tool_calls"] = tool_calls
        updated_state["current_step"] = "agent_reasoning_complete"
        
        # Clear re-reasoning flags after processing
        if needs_re_reasoning:
            updated_state["needs_re_reasoning"] = False
            updated_state["re_reasoning_feedback"] = ""
            logger.info("🔄 Re-reasoning complete, flags cleared")
        
        # Clear modified tool calls after processing
        if modified_tool_calls:
            updated_state["modified_tool_calls"] = []

        # NOTE: Don't remove tools here - they may be needed for next node
        # Tools are only removed in terminal nodes (generate, tool_rejection_exit)

        return updated_state
        
    except Exception as e:
        logger.error(f"❌ Error in agent reasoning: {e}")

        # Fallback: set empty tool calls to proceed to generation
        updated_state = state.copy()
        updated_state["pending_tool_calls"] = []
        updated_state["current_step"] = "agent_reasoning_error"
        updated_state["error"] = str(e)

        # NOTE: Don't remove tools here - they may be needed for next node
        # Tools are only removed in terminal nodes (generate, tool_rejection_exit)

        return updated_state


def _build_reasoning_prompt(
    query: str,
    reasoning_mode: str,
    memory_context: str,
    tool_results: list,
    re_reasoning_feedback: str,
    modified_tool_calls: list,
    capacity_approved: bool = False,
    model_name: str = "",
    model_info: dict = None,
    gpu_requirements: dict = None,
    estimated_gpu_memory: float = 0
) -> str:
    """
    Build appropriate reasoning prompt based on the reasoning scenario.

    Args:
        query: Original user query
        reasoning_mode: "initial", "post_execution", or "re_reasoning"
        memory_context: Conversation memory context
        tool_results: Previous tool execution results
        re_reasoning_feedback: User's feedback for re-reasoning
        modified_tool_calls: Tools that were modified by user
        capacity_approved: Whether in deployment workflow with approved capacity
        model_name: Name of model to deploy
        model_info: Model information from capacity estimation
        gpu_requirements: GPU requirements from capacity estimation
        estimated_gpu_memory: Estimated GPU memory

    Returns:
        Formatted reasoning prompt
    """
    base_prompt = Constants.GENERAL_SYSTEM_PROMPT

    # Handle deployment workflow
    if capacity_approved and reasoning_mode == "initial":
        # Deployment-specific reasoning
        if model_info is None:
            model_info = {}
        if gpu_requirements is None:
            gpu_requirements = {}

        # Get deployment parameters
        location = model_info.get("location", "UAE-1")
        gpu_type = model_info.get("GPU_type", "RTX 4090")
        num_gpus = gpu_requirements.get(gpu_type, 1)
        config = f"{num_gpus}x {gpu_type}"

        deployment_instructions = f"""
🚀 **DEPLOYMENT MODE ACTIVATED** 🚀

You are in a model deployment workflow. The capacity has been approved and you need to create a compute instance.

**Deployment Information:**
- Model to deploy: {model_name}
- Approved Location: {location}
- Required GPU Configuration: {config}
- GPU Memory Required: {estimated_gpu_memory:.2f} GB

**YOUR TASK:**
Call the `create_compute_instance` tool with appropriate arguments based on the deployment information above.

**IMPORTANT:**
1. Review the tool's specification to understand the valid parameter values
2. Use the deployment information provided to determine the correct arguments:
   - For the `name` parameter: Format the model name "{model_name}" following these rules:
     * Convert to lowercase
     * Replace forward slashes (/) with hyphens (-)
     * Replace dots (.) with hyphens (-)
     * Replace underscores (_) with hyphens (-)
     * Keep existing hyphens as-is
   - For the `location` parameter: Map the approved location to the tool's valid location format (see mapping below)
   - For the `config` parameter: Use the exact GPU configuration "{config}"
3. After the tool returns the instance_id and status, do NOT call any other tools
4. The generate node will handle creating the deployment instructions

**Location Mapping (map approved location to MCP tool format):**
- "UAE-1" or "uae-1" or "UAE" → use "uae"
- "UAE-2" or "uae-2" → use "uae-2"
- "France" or "FRANCE" → use "france"
- "Texas" or "TEXAS" → use "texas"

**Example name formatting:**
- "meta-llama/Llama-3.1-8B" → "meta-llama-llama-3-1-8b"
- "Qwen/Qwen2.5-7B" → "qwen-qwen2-5-7b"
- "google/gemma-2-9b" → "google-gemma-2-9b"

Make sure your tool call arguments exactly match the MCP tool's specification format.
"""

        prompt = f"""{base_prompt}

{deployment_instructions}

User Query: {query}

{f"Conversation Context: {memory_context}" if memory_context else ""}"""

        return prompt

    if reasoning_mode == "initial":
        # Initial reasoning (non-deployment)
        # Include available model information for tool calls
        model_info_text = f"""
Available Models:
- For general queries: {Constants.DEFAULT_LLM_NAME}
- For function calling: {Constants.DEFAULT_LLM_FC}

When calling the research tool, use the model parameter: "{Constants.DEFAULT_LLM_NAME}"
"""
        prompt = f"""{base_prompt}

{model_info_text}

User Query: {query}

{f"Conversation Context: {memory_context}" if memory_context else ""}

IMPORTANT INSTRUCTIONS:
1. **Check conversation history first**: If this is a follow-up question, review previous messages to see if you already have the information.
2. **Avoid redundant tool calls**: Don't call tools to fetch information you've already provided in this conversation.
3. **Answer directly when possible**: If you can answer based on previous responses or your knowledge, respond without calling tools.
4. **Use tools only when necessary**: Only call tools if you genuinely need new information that isn't available in the conversation history.

When calling tools that require a "model" parameter (like the research tool),
use the model "{Constants.DEFAULT_LLM_NAME}" unless the user explicitly requests a different model."""
    
    elif reasoning_mode == "post_execution":
        # Post-execution reasoning
        tool_results_summary = "\n\n".join([
            f"Tool {i+1} ({getattr(r, 'name', 'unknown')}): {getattr(r, 'content', str(r))}"
            for i, r in enumerate(tool_results)
        ])

        prompt = f"""{base_prompt}

Original Query: {query}

{f"Conversation Context: {memory_context}" if memory_context else ""}

Tool Execution Results:
{tool_results_summary}

IMPORTANT: Evaluate if you have enough information to answer the user's query.

Decision Logic:
1. If the tool results provide sufficient information to answer the query → DO NOT call any tools (respond without tool calls)
2. Only if critical information is still missing → Select specific tools to gather that information

Remember:
- The generate node will format your final response, so you don't need to call tools just to format data
- Be efficient - don't call tools unless absolutely necessary
- If you respond without calling tools, the workflow will move to generate the final answer"""
    
    else:  # re_reasoning
        # Re-reasoning after user feedback
        # Use a distinct name so we don't shadow the model_info parameter
        model_info_text = f"""
Available Models:
- For general queries: {Constants.DEFAULT_LLM_NAME}
- For function calling: {Constants.DEFAULT_LLM_FC}

When calling the research tool, use the model parameter: "{Constants.DEFAULT_LLM_NAME}"
"""
        modified_summary = ""
        if modified_tool_calls:
            modified_summary = "\n\nUser Modified These Tools:\n" + "\n".join([
                f"- Tool {mod['index']}: {mod['modified']['name']} with args {mod['modified']['args']}"
                for mod in modified_tool_calls
            ])

        prompt = f"""{base_prompt}

{model_info_text}

Original Query: {query}

{f"Conversation Context: {memory_context}" if memory_context else ""}

User Feedback: {re_reasoning_feedback}

{modified_summary}

The user has provided feedback on your previous tool selection. Please reconsider your approach:
1. Review the user's feedback carefully
2. Consider the modified tool arguments if provided
3. Determine a new strategy that addresses the user's concerns

Select appropriate tools based on this feedback, or proceed without tools if you can now answer directly.

IMPORTANT: When calling tools that require a "model" parameter (like the research tool),
use the model "{Constants.DEFAULT_LLM_NAME}" unless the user explicitly requests a different model."""
    
    return prompt