"""
Comprehensive unit test suite for the agent workflow.

Covers the main question-routing scenarios of ``planner_node``,
``parallel_executor_node`` and ``route_agent``, including session-memory
tracking.

Run with:
    python backend/tests/test_workflow_comprehensive.py
or under pytest (the async tests use ``@pytest.mark.asyncio``, which is
provided by the pytest-asyncio plugin):
    pytest backend/tests/test_workflow_comprehensive.py
"""

import sys
import os

# Put the directory three levels above this file (the repository root,
# assuming this file lives in backend/tests/) on sys.path so that
# ``backend.*`` is importable when the file is run directly as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

import pytest
import asyncio
import json
from contextlib import contextmanager
from unittest.mock import AsyncMock, MagicMock, patch


# Test utilities

def create_mock_state(session_id="test-session", messages=None, image_data_list=None):
    """Build a fresh AgentState-shaped dict for node tests.

    Args:
        session_id: Value stored under ``"session_id"``.
        messages: Conversation messages. ``None`` means a single
            ``HumanMessage(content="Test question")``.
        image_data_list: Attached images. ``None`` means an empty list.

    Returns:
        A new dict populated with the state fields the workflow nodes are
        expected to read, at neutral defaults. NOTE(review): this mirrors
        AgentState by hand — keep it in sync when the state schema changes.
    """
    from langchain_core.messages import HumanMessage

    # Use ``is None`` so an explicitly passed empty list is respected.
    if messages is None:
        messages = [HumanMessage(content="Test question")]
    if image_data_list is None:
        image_data_list = []

    return {
        "session_id": session_id,
        "messages": messages,
        "image_data_list": image_data_list,
        "ocr_text": "",
        "ocr_results": [],
        "execution_plan": None,
        "question_results": [],
        "current_agent": "planner",
        "final_response": None,
        "tool_result": None,
        "tool_success": False,
        "agents_used": [],
        "tools_called": [],
        "model_calls": [],
        "context_status": "normal",
        "context_message": "",
        "session_token_count": 0,
        # Additional required fields
        "total_tokens": 0,
        "total_duration_ms": 0,
        "selected_tool": None,
        "should_use_tools": False,
        "wolfram_query": None,
        "wolfram_attempts": 0,
        "code_task": None,
        "generated_code": None,
        "error_message": None,
        "image_data": None,
    }


def _make_memory_status(status="normal", used_tokens=100, message=""):
    """Return a stand-in for the object ``memory_tracker.check_status`` returns."""
    mock_status = MagicMock()
    mock_status.status = status
    mock_status.used_tokens = used_tokens
    mock_status.message = message
    return mock_status


def _make_llm_response(content):
    """Return a stand-in LLM reply whose ``.content`` is *content*."""
    mock_response = MagicMock()
    mock_response.content = content
    return mock_response


@contextmanager
def _patched_nodes(llm_content=None, memory_status=None):
    """Patch ``get_model`` and ``memory_tracker`` in ``backend.agent.nodes``.

    Args:
        llm_content: If given, ``get_model()`` returns an ``AsyncMock`` LLM
            whose ``ainvoke`` resolves to a reply with this ``.content``.
            If ``None``, ``get_model`` stays the default patch mock.
        memory_status: Object returned by ``memory_tracker.check_status``.
            Defaults to a "normal" status with 100 used tokens.

    Yields:
        ``(mock_get_model, mock_memory)`` so tests can make extra assertions.
    """
    with patch("backend.agent.nodes.get_model") as mock_get_model, \
            patch("backend.agent.nodes.memory_tracker") as mock_memory:
        if llm_content is not None:
            mock_llm = AsyncMock()
            mock_llm.ainvoke.return_value = _make_llm_response(llm_content)
            mock_get_model.return_value = mock_llm
        if memory_status is None:
            memory_status = _make_memory_status()
        mock_memory.check_status.return_value = memory_status
        yield mock_get_model, mock_memory


class TestPlannerNode:
    """Tests for planner_node routing logic."""

    @pytest.mark.asyncio
    async def test_all_direct_returns_text(self):
        """Test Case 1: All direct questions -> Planner returns text, current_agent='done'."""
        from backend.agent.nodes import planner_node

        state = create_mock_state()
        # Plain-text (non-JSON) LLM output: every question answered directly.
        llm_text = "## Bài 1:\nĐây là lời giải câu 1.\n\n## Bài 2:\nĐây là lời giải câu 2."

        with _patched_nodes(llm_content=llm_text):
            result = await planner_node(state)

        assert result["current_agent"] == "done", "All-direct should set current_agent to 'done'"
        assert result["final_response"] is not None, "Should have final_response set"
        assert "Bài 1" in result["final_response"], "Should contain direct answer"
        print("✅ Test Case 1 PASSED: All Direct -> Text -> Done")

    @pytest.mark.asyncio
    async def test_mixed_questions_returns_json(self):
        """Test Case 2: Mixed questions -> Planner returns JSON, current_agent='executor'."""
        from backend.agent.nodes import planner_node

        state = create_mock_state()
        # JSON plan mixing a direct answer with a code-tool question.
        plan = {
            "questions": [
                {"id": 1, "content": "Câu hỏi 1", "type": "direct", "answer": "Đáp án 1"},
                {"id": 2, "content": "Câu hỏi 2", "type": "code", "tool_input": "Viết code..."},
            ]
        }

        with _patched_nodes(llm_content=json.dumps(plan)):
            result = await planner_node(state)

        assert result["current_agent"] == "executor", "Mixed questions should route to executor"
        assert result["execution_plan"] is not None, "Should have execution_plan set"
        assert len(result["execution_plan"]["questions"]) == 2, "Plan should have 2 questions"
        print("✅ Test Case 2 PASSED: Mixed -> JSON -> Executor")

    @pytest.mark.asyncio
    async def test_memory_overflow_blocks_execution(self):
        """Test Case 5: Memory overflow should stop execution."""
        from backend.agent.nodes import planner_node

        state = create_mock_state()
        plan_json = json.dumps({"questions": [{"id": 1, "type": "code", "tool_input": "x"}]})
        # Simulate memory overflow: the tracker reports a "blocked" session.
        blocked_status = _make_memory_status(
            status="blocked",
            used_tokens=100000,
            message="Bộ nhớ phiên đã đầy!",
        )

        with _patched_nodes(llm_content=plan_json, memory_status=blocked_status):
            result = await planner_node(state)

        assert result["current_agent"] == "done", "Memory overflow should stop execution"
        assert "Bộ nhớ" in result["final_response"], "Should show memory warning"
        print("✅ Test Case 5 PASSED: Memory Overflow -> Blocked")

    @pytest.mark.asyncio
    async def test_json_repair_latex_backslashes(self):
        """Test Case 6: JSON with unescaped LaTeX backslashes should be repaired."""
        from backend.agent.nodes import planner_node

        state = create_mock_state()
        # Single backslashes are NOT valid JSON escapes here: "\i" is illegal,
        # and "\f" would silently decode as a form feed. This input therefore
        # fails plain json.loads and forces the planner's repair path.
        raw_json = r'{"questions":[{"id":1,"type":"code","content":"\iint_D \frac{dx}{x}","tool_input":"calc"}]}'

        with _patched_nodes(llm_content=raw_json):
            result = await planner_node(state)

        plan = result.get("execution_plan")
        # Either the repaired JSON becomes a plan, or the planner falls back to
        # treating the reply as a direct text answer.
        assert plan is not None or result["current_agent"] == "done", \
            "Should either parse JSON or treat as direct answer"
        if plan is not None:
            assert len(plan["questions"]) == 1, "Repaired plan should keep its single question"
        print("✅ Test Case 6 PASSED: JSON Repair (LaTeX)")


class TestParallelExecutor:
    """Tests for parallel_executor_node."""

    @pytest.mark.asyncio
    async def test_direct_uses_answer_field(self):
        """Test: Direct questions should use pre-generated answer, not call LLM."""
        from backend.agent.nodes import parallel_executor_node

        state = create_mock_state()
        state["execution_plan"] = {
            "questions": [
                {"id": 1, "type": "direct", "content": "Câu hỏi", "answer": "Đáp án sẵn có"}
            ]
        }

        # get_model is patched but given no LLM: a direct question that already
        # carries an answer must not reach the model at all.
        with _patched_nodes() as (mock_get_model, _mock_memory):
            result = await parallel_executor_node(state)

        assert result["current_agent"] == "synthetic", "Should route to synthetic"
        assert len(result["question_results"]) == 1, "Should have 1 result"
        assert result["question_results"][0]["result"] == "Đáp án sẵn có", "Should use pre-generated answer"
        mock_get_model.return_value.ainvoke.assert_not_called()
        print("✅ Test: Direct with Answer Field -> Uses Cached Answer")


class TestRouteAgent:
    """Tests for route_agent function."""

    def test_route_done_returns_done(self):
        """Test: current_agent='done' should return 'done'."""
        from backend.agent.nodes import route_agent

        state = {"current_agent": "done"}
        result = route_agent(state)

        assert result == "done", "Should return 'done' for done state"
        print("✅ Test: route_agent('done') -> 'done'")

    def test_route_executor_returns_executor(self):
        """Test: current_agent='executor' should return 'executor'."""
        from backend.agent.nodes import route_agent

        state = {"current_agent": "executor"}
        result = route_agent(state)

        assert result == "executor", "Should return 'executor' for executor state"
        print("✅ Test: route_agent('executor') -> 'executor'")


# Run tests directly (without pytest) when executed as a script.
if __name__ == "__main__":
    print("=" * 60)
    print("RUNNING COMPREHENSIVE WORKFLOW UNIT TESTS")
    print("=" * 60)

    async def run_all():
        """Run every test in order; any failing assert aborts the run."""
        # Planner tests
        planner_tests = TestPlannerNode()
        await planner_tests.test_all_direct_returns_text()
        await planner_tests.test_mixed_questions_returns_json()
        await planner_tests.test_memory_overflow_blocks_execution()
        await planner_tests.test_json_repair_latex_backslashes()

        # Executor tests
        executor_tests = TestParallelExecutor()
        await executor_tests.test_direct_uses_answer_field()

        # Route tests
        route_tests = TestRouteAgent()
        route_tests.test_route_done_returns_done()
        route_tests.test_route_executor_returns_executor()

        print("\n" + "=" * 60)
        print("ALL TESTS PASSED ✅")
        print("=" * 60)

    asyncio.run(run_all())