File size: 2,934 Bytes
ba5110e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import asyncio
import sys
import os
from unittest.mock import MagicMock, patch

# Add project root to path: three nested dirname() calls on this file's
# absolute path give the directory two levels above the one containing this
# file (presumably the repository root holding `backend/` — confirm if the
# file moves), so the `backend.*` imports below resolve when run as a script.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from backend.agent.state import create_initial_state
from backend.agent.nodes import parallel_executor_node
from langchain_core.messages import AIMessage

# ANSI SGR escape sequences for colouring console output.
GREEN, BLUE, RED = "\033[92m", "\033[94m", "\033[91m"  # bright green / blue / red
RESET = "\033[0m"  # restore the terminal's default attributes

async def test_code_smart_retry():
    print(f"{BLUE}📌 TEST: Code Tool Smart Retry (Self-Correction){RESET}")
    
    state = create_initial_state(session_id="test_retry")
    state["execution_plan"] = {
        "questions": [
            {"id": 1, "type": "code", "content": "Fix me", "tool_input": "Run bad code"}
        ]
    }
    
    with patch("backend.agent.nodes.CodeTool") as mock_code_tool_cls:
        with patch("backend.agent.nodes.get_model") as mock_get_model:
            
            # --- MOCK LLM RESPONSES ---
            mock_llm = MagicMock()
            
            # Response 1: Bad Code
            # Response 2: Fixed Code
            async def mock_llm_call(messages):
                content = messages[0].content
                if "LỖI GẶP PHẢI" in content: # Check if it's the FIX prompt
                    print(f"   [LLM Input]: Received Error Feedback -> Generating Fix...")
                    return AIMessage(content="```python\nprint('Fixed')\n```")
                else:
                    print(f"   [LLM Input]: First Attempt -> Generating Bad Code...")
                    return AIMessage(content="```python\nprint(1/0)\n```")
                    
            mock_llm.ainvoke.side_effect = mock_llm_call
            mock_get_model.return_value = mock_llm
            
            # --- MOCK CODE EXECUTOR ---
            mock_tool_instance = MagicMock()
            
            async def mock_exec(code):
                if "1/0" in code:
                    return {"success": False, "error": "ZeroDivisionError"}
                else:
                    return {"success": True, "output": "Fixed Output"}
            
            mock_tool_instance.execute.side_effect = mock_exec
            mock_code_tool_cls.return_value = mock_tool_instance
            
            # --- RUN EXECUTOR ---
            state = await parallel_executor_node(state)
            
    # --- ASSERTIONS ---
    results = state.get("question_results", [])
    if not results:
        print(f"{RED}❌ No results found{RESET}")
        return False
        
    res = results[0]
    result_text = str(res.get("result"))
    
    if "Fixed Output" in result_text:
        print(f"{GREEN}✅ Code succeeded after retry{RESET}")
        return True
    else:
        print(f"{RED}❌ Failed to self-correct. Result: {result_text}, Error: {res.get('error')}{RESET}")
        return False

if __name__ == "__main__":
    # Propagate the test outcome as the process exit status (0 = passed,
    # 1 = failed) so shells and CI can detect a failed self-correction run;
    # asyncio.run returns the coroutine's boolean result.
    sys.exit(0 if asyncio.run(test_code_smart_retry()) else 1)