# NOTE: the lines "Spaces: / Sleeping / Sleeping" were a Hugging Face Spaces
# UI status artifact captured with this file, not part of the source.
| # ============================================================= | |
| # File: backend/tests/test_retry_system.py | |
| # ============================================================= | |
| """ | |
| Comprehensive tests for autonomous retry and self-correction system. | |
| Tests: | |
| 1. RAG retry with low scores (threshold adjustment + query expansion) | |
| 2. Web search retry with empty results (query rewriting) | |
| 3. Safe tool call retry mechanism | |
| 4. Rule safe message rewriting | |
| 5. Integration tests with reasoning traces | |
| 6. Analytics logging verification | |
| """ | |
| import sys | |
| from pathlib import Path | |
| import pytest | |
| from unittest.mock import AsyncMock, MagicMock, patch | |
| import asyncio | |
| # Add backend directory to Python path | |
| backend_dir = Path(__file__).parent.parent | |
| sys.path.insert(0, str(backend_dir)) | |
# Optional pytest: fall back to no-op stand-ins so the module still imports
# (and the decorators still resolve) when pytest is not installed.
# BUG FIX: the try body previously contained no import, so ImportError could
# never fire and the fallback below was dead code.
try:
    import pytest
    HAS_PYTEST = True
except ImportError:
    HAS_PYTEST = False

    class MockMark:
        """Stand-in for ``pytest.mark`` exposing a no-op ``asyncio`` decorator."""

        def asyncio(self, func):
            # Return the function unchanged; the marker is purely declarative.
            return func

    class MockPytest:
        """Minimal pytest substitute so ``@pytest.fixture`` / ``@pytest.mark.asyncio`` resolve."""

        mark = MockMark()

        def fixture(self, func):
            # No-op fixture decorator.
            return func

    pytest = MockPytest()
| from api.services.agent_orchestrator import AgentOrchestrator | |
| from api.models.agent import AgentRequest | |
| from api.models.redflag import RedFlagMatch | |
| # ============================================================= | |
| # FIXTURES | |
| # ============================================================= | |
@pytest.fixture
def mock_orchestrator():
    """Create an AgentOrchestrator pointed at fake MCP URLs with all collaborators mocked.

    BUG FIX: this function is consumed as a pytest fixture by every test below
    but was missing the ``@pytest.fixture`` decorator, so pytest would fail
    with "fixture 'mock_orchestrator' not found".
    """
    orch = AgentOrchestrator(
        rag_mcp_url="http://fake:8001",
        web_mcp_url="http://fake:8002",
        admin_mcp_url="http://fake:8003",
        llm_backend="ollama"
    )
    # Replace external dependencies so no network / LLM calls are made.
    orch.mcp = MagicMock()
    orch.analytics = MagicMock()
    orch.llm = MagicMock()
    orch.redflag = MagicMock()
    return orch
| # ============================================================= | |
| # RAG RETRY TESTS | |
| # ============================================================= | |
@pytest.mark.asyncio
async def test_rag_with_repair_high_score_no_retry(mock_orchestrator):
    """RAG repair must not retry when the top score clears the threshold."""
    # Mock a single high-score result (0.85 >> default threshold).
    mock_orchestrator.mcp.call_rag = AsyncMock(return_value={
        "results": [{"text": "relevant content", "score": 0.85}]
    })
    reasoning_trace = []
    result = await mock_orchestrator.rag_with_repair(
        query="test query",
        tenant_id="tenant1",
        reasoning_trace=reasoning_trace,
        user_id="user1"
    )
    # Should only call once (no retry needed).
    assert mock_orchestrator.mcp.call_rag.call_count == 1
    assert result["results"][0]["score"] == 0.85
@pytest.mark.asyncio
async def test_rag_with_repair_low_score_retry_threshold(mock_orchestrator):
    """RAG repair retries with a lowered threshold when the score is below 0.30."""
    # First call returns a low score, second call a better one.
    mock_orchestrator.mcp.call_rag = AsyncMock(side_effect=[
        {"results": [{"text": "low relevance", "score": 0.25}]},
        {"results": [{"text": "better match", "score": 0.45}]}
    ])
    reasoning_trace = []
    result = await mock_orchestrator.rag_with_repair(
        query="test query",
        tenant_id="tenant1",
        original_threshold=0.3,
        reasoning_trace=reasoning_trace,
        user_id="user1"
    )
    # Should have retried with the lowered threshold (0.15).
    assert mock_orchestrator.mcp.call_rag.call_count == 2
    # Check the second call actually used threshold 0.15.
    second_call_kwargs = mock_orchestrator.mcp.call_rag.call_args_list[1].kwargs
    assert second_call_kwargs.get("threshold") == 0.15
    # Verify the reasoning trace recorded a retry step.
    retry_steps = [s for s in reasoning_trace if "retry" in str(s).lower()]
    assert len(retry_steps) > 0
@pytest.mark.asyncio
async def test_rag_with_repair_expand_query(mock_orchestrator):
    """RAG repair expands the query when the score stays low after the threshold retry."""
    # Sequence: very low -> still low after threshold retry -> better after expansion.
    mock_orchestrator.mcp.call_rag = AsyncMock(side_effect=[
        {"results": [{"text": "low", "score": 0.12}]},        # Initial - very low
        {"results": [{"text": "still low", "score": 0.10}]},  # After threshold retry - still low
        {"results": [{"text": "better", "score": 0.35}]}      # After query expansion - better
    ])
    reasoning_trace = []
    await mock_orchestrator.rag_with_repair(
        query="test",
        tenant_id="tenant1",
        original_threshold=0.3,
        reasoning_trace=reasoning_trace,
        user_id="user1"
    )
    # Should have called three times (initial + threshold retry + expanded query).
    assert mock_orchestrator.mcp.call_rag.call_count == 3
    # Reasoning trace must mention the query expansion step.
    expand_steps = [s for s in reasoning_trace if "expanded" in str(s).lower() or "expand" in str(s).lower()]
    assert len(expand_steps) > 0
    # Analytics should have been called for the retries, not just the first call.
    assert mock_orchestrator.analytics.log_tool_usage.call_count > 1
@pytest.mark.asyncio
async def test_rag_with_repair_no_results(mock_orchestrator):
    """RAG repair handles an empty result set gracefully."""
    mock_orchestrator.mcp.call_rag = AsyncMock(return_value={
        "results": []
    })
    reasoning_trace = []
    result = await mock_orchestrator.rag_with_repair(
        query="test query",
        tenant_id="tenant1",
        reasoning_trace=reasoning_trace,
        user_id="user1"
    )
    # Should handle gracefully (may retry or return empty) without raising.
    assert isinstance(result, dict)
    assert "results" in result
| # ============================================================= | |
| # WEB SEARCH RETRY TESTS | |
| # ============================================================= | |
@pytest.mark.asyncio
async def test_web_with_repair_has_results_no_retry(mock_orchestrator):
    """Web repair must not retry when results are found on the first call."""
    mock_orchestrator.mcp.call_web = AsyncMock(return_value={
        "results": [
            {"title": "Result 1", "snippet": "Content", "url": "http://example.com"}
        ]
    })
    reasoning_trace = []
    result = await mock_orchestrator.web_with_repair(
        query="normal query",
        tenant_id="tenant1",
        reasoning_trace=reasoning_trace,
        user_id="user1"
    )
    # Should only call once (no retry needed).
    assert mock_orchestrator.mcp.call_web.call_count == 1
    assert len(result["results"]) > 0
@pytest.mark.asyncio
async def test_web_with_repair_empty_results_retry(mock_orchestrator):
    """Web repair retries with rewritten queries when results come back empty."""
    # Sequence: empty -> empty -> success.
    mock_orchestrator.mcp.call_web = AsyncMock(side_effect=[
        {"results": []},  # Initial - empty
        {"results": []},  # First retry - still empty
        {"results": [{"title": "Found", "snippet": "Result", "url": "http://example.com"}]}  # Second retry - success
    ])
    reasoning_trace = []
    await mock_orchestrator.web_with_repair(
        query="obscure query xyz",
        tenant_id="tenant1",
        reasoning_trace=reasoning_trace,
        user_id="user1"
    )
    # Should have retried (up to 2 rewrites).
    assert mock_orchestrator.mcp.call_web.call_count >= 2
    # Reasoning trace must record the retry steps.
    retry_steps = [s for s in reasoning_trace if "retry" in str(s).lower()]
    assert len(retry_steps) > 0
    # Check that rewritten queries were used.
    # call_web takes positional args: (tenant_id, query)
    calls = mock_orchestrator.mcp.call_web.call_args_list
    rewritten_queries = []
    for call in calls:
        # Extract the query from positional args (args[1] after tenant_id).
        if len(call.args) > 1:
            rewritten_queries.append(call.args[1])
    # Should have at least the original + retry queries.
    assert len(rewritten_queries) >= 2
    # At least one rewritten query should contain one of the rewrite patterns.
    assert any("best explanation" in str(q).lower() or "facts summary" in str(q).lower()
               for q in rewritten_queries if q)
@pytest.mark.asyncio
async def test_web_with_repair_analytics_logging(mock_orchestrator):
    """Web repair logs its retry attempts to analytics."""
    mock_orchestrator.mcp.call_web = AsyncMock(side_effect=[
        {"results": []},
        {"results": [{"title": "Result", "snippet": "Content"}]}
    ])
    await mock_orchestrator.web_with_repair(
        query="test",
        tenant_id="tenant1",
        user_id="user1"
    )
    # Verify analytics received at least one log call.
    assert mock_orchestrator.analytics.log_tool_usage.called
| # ============================================================= | |
| # SAFE TOOL CALL TESTS | |
| # ============================================================= | |
@pytest.mark.asyncio
async def test_safe_tool_call_success_first_attempt(mock_orchestrator):
    """safe_tool_call succeeds on the first attempt without retrying."""
    successful_tool = AsyncMock(return_value={"success": True, "data": "result"})
    result = await mock_orchestrator.safe_tool_call(
        tool_fn=successful_tool,
        params={"param1": "value1"},
        max_retries=2,
        tool_name="test_tool",
        tenant_id="tenant1",
        user_id="user1"
    )
    # Should succeed on the first try — exactly one invocation.
    assert successful_tool.call_count == 1
    assert result["success"] is True
    assert result["data"] == "result"
@pytest.mark.asyncio
async def test_safe_tool_call_retry_on_failure(mock_orchestrator):
    """safe_tool_call retries after a failure and returns the recovered result."""
    # First call raises, second call succeeds.
    failing_tool = AsyncMock(side_effect=[
        Exception("First failure"),
        {"success": True, "data": "recovered"}
    ])
    reasoning_trace = []
    result = await mock_orchestrator.safe_tool_call(
        tool_fn=failing_tool,
        params={},
        max_retries=2,
        tool_name="test_tool",
        tenant_id="tenant1",
        user_id="user1",
        reasoning_trace=reasoning_trace
    )
    # Should have retried once after the failure.
    assert failing_tool.call_count == 2
    assert result["success"] is True
    # Reasoning trace must record the retry.
    retry_steps = [s for s in reasoning_trace if "retry" in str(s).lower()]
    assert len(retry_steps) > 0
@pytest.mark.asyncio
async def test_safe_tool_call_exhausts_retries(mock_orchestrator):
    """safe_tool_call returns an error payload after all retries are exhausted."""
    failing_tool = AsyncMock(side_effect=Exception("Always fails"))
    reasoning_trace = []
    result = await mock_orchestrator.safe_tool_call(
        tool_fn=failing_tool,
        params={},
        max_retries=2,
        tool_name="test_tool",
        tenant_id="tenant1",
        user_id="user1",
        reasoning_trace=reasoning_trace
    )
    # Should have attempted exactly max_retries times.
    assert failing_tool.call_count == 2
    assert "error" in result
    # Failures must be logged to analytics.
    assert mock_orchestrator.analytics.log_tool_usage.called
@pytest.mark.asyncio
async def test_safe_tool_call_fallback_params(mock_orchestrator):
    """safe_tool_call switches to fallback_params on the retry attempt."""
    tool_calls = []

    async def mock_tool_async(**kwargs):
        # Record each invocation's kwargs; fail only the first attempt.
        tool_calls.append(kwargs.copy())
        if len(tool_calls) == 1:
            raise Exception("First attempt failed")
        return {"success": True, "params": kwargs}

    result = await mock_orchestrator.safe_tool_call(
        tool_fn=mock_tool_async,
        params={"param1": "value1"},
        max_retries=2,
        fallback_params={"param1": "fallback_value"},
        tool_name="test_tool",
        tenant_id="tenant1"
    )
    # Attempt 1 uses the original params, attempt 2 the fallback params.
    assert len(tool_calls) == 2
    assert tool_calls[0]["param1"] == "value1"          # Original params
    assert tool_calls[1]["param1"] == "fallback_value"  # Fallback params on retry
    assert result["success"] is True
| # ============================================================= | |
| # RULE SAFE MESSAGE TESTS | |
| # ============================================================= | |
@pytest.mark.asyncio
async def test_rule_safe_message_no_violations(mock_orchestrator):
    """rule_safe_message returns the original message when no rules are violated."""
    mock_orchestrator.redflag.check = AsyncMock(return_value=[])
    safe_msg = await mock_orchestrator.rule_safe_message(
        user_message="Normal message",
        tenant_id="tenant1"
    )
    # Should return the original message untouched, after a single rule check.
    assert safe_msg == "Normal message"
    assert mock_orchestrator.redflag.check.call_count == 1
@pytest.mark.asyncio
async def test_rule_safe_message_rewrites_violation(mock_orchestrator):
    """rule_safe_message rewrites a violating message until it passes the rules."""
    # Redflag check: first call (original) violates, second call (rewritten) passes.
    violation = RedFlagMatch(
        rule_id="1",
        pattern="salary",
        severity="high",
        description="salary access",
        matched_text="salary"
    )
    mock_orchestrator.redflag.check = AsyncMock(side_effect=[
        [violation],  # Original message violates
        []            # Rewritten message is safe
    ])
    mock_orchestrator.llm.simple_call = AsyncMock(
        return_value="This is a compliant version of your request about compensation"
    )
    reasoning_trace = []
    safe_msg = await mock_orchestrator.rule_safe_message(
        user_message="I want to see salary info",
        tenant_id="tenant1",
        reasoning_trace=reasoning_trace
    )
    # Should have checked rules twice (original + rewritten).
    assert mock_orchestrator.redflag.check.call_count == 2
    # Should have used the LLM to produce the rewrite.
    assert mock_orchestrator.llm.simple_call.called
    # Should return the rewritten message, not the original.
    assert "compliant" in safe_msg.lower() or safe_msg != "I want to see salary info"
    # Reasoning trace must record the rewrite step.
    rewrite_steps = [s for s in reasoning_trace if "rewrite" in str(s).lower()]
    assert len(rewrite_steps) > 0
@pytest.mark.asyncio
async def test_rule_safe_message_brief_rule_no_rewrite(mock_orchestrator):
    """rule_safe_message leaves messages alone when only brief-response rules match."""
    # Brief response rules are handled separately, so the original should pass through.
    brief_rule = RedFlagMatch(
        rule_id="1",
        pattern="greeting",
        severity="low",
        description="greeting",
        matched_text="hi"
    )
    mock_orchestrator.redflag.check = AsyncMock(return_value=[brief_rule])
    safe_msg = await mock_orchestrator.rule_safe_message(
        user_message="Hi there",
        tenant_id="tenant1"
    )
    # Should return the original (brief rules are handled elsewhere).
    assert safe_msg == "Hi there"
@pytest.mark.asyncio
async def test_rule_safe_message_llm_failure_fallback(mock_orchestrator):
    """rule_safe_message falls back to the original message if the LLM rewrite fails."""
    violation = RedFlagMatch(
        rule_id="1",
        pattern="blocked",
        severity="high",
        description="blocked",
        matched_text="blocked"
    )
    mock_orchestrator.redflag.check = AsyncMock(return_value=[violation])
    # LLM rewrite raises — the method must not propagate the exception.
    mock_orchestrator.llm.simple_call = AsyncMock(side_effect=Exception("LLM failed"))
    original_msg = "I want blocked content"
    safe_msg = await mock_orchestrator.rule_safe_message(
        user_message=original_msg,
        tenant_id="tenant1"
    )
    # Should return the original message when the rewrite fails.
    assert safe_msg == original_msg
| # ============================================================= | |
| # INTEGRATION TESTS | |
| # ============================================================= | |
@pytest.mark.asyncio
async def test_rag_integration_reasoning_trace(mock_orchestrator):
    """RAG retry steps must appear in the reasoning trace."""
    mock_orchestrator.mcp.call_rag = AsyncMock(side_effect=[
        {"results": [{"text": "low", "score": 0.20}]},
        {"results": [{"text": "better", "score": 0.50}]}
    ])
    reasoning_trace = []
    await mock_orchestrator.rag_with_repair(
        query="test",
        tenant_id="tenant1",
        reasoning_trace=reasoning_trace,
        user_id="user1"
    )
    # The trace should mention the retry or threshold adjustment.
    trace_str = str(reasoning_trace).lower()
    assert "retry" in trace_str or "threshold" in trace_str
@pytest.mark.asyncio
async def test_web_integration_reasoning_trace(mock_orchestrator):
    """Web retry steps must appear in the reasoning trace."""
    mock_orchestrator.mcp.call_web = AsyncMock(side_effect=[
        {"results": []},
        {"results": [{"title": "Result", "snippet": "Content"}]}
    ])
    reasoning_trace = []
    await mock_orchestrator.web_with_repair(
        query="test",
        tenant_id="tenant1",
        reasoning_trace=reasoning_trace,
        user_id="user1"
    )
    # The trace should mention the retry or the rewritten query.
    trace_str = str(reasoning_trace).lower()
    assert "retry" in trace_str or "rewritten" in trace_str
@pytest.mark.asyncio
async def test_analytics_logging_on_retries(mock_orchestrator):
    """Retry attempts must be logged to analytics."""
    mock_orchestrator.mcp.call_rag = AsyncMock(side_effect=[
        {"results": [{"text": "low", "score": 0.25}]},
        {"results": [{"text": "better", "score": 0.45}]}
    ])
    await mock_orchestrator.rag_with_repair(
        query="test",
        tenant_id="tenant1",
        user_id="user1"
    )
    # Analytics should have been called (initial attempt + retry).
    assert mock_orchestrator.analytics.log_tool_usage.call_count > 0
    # The RAG search itself should also have been logged.
    assert mock_orchestrator.analytics.log_rag_search.called
@pytest.mark.asyncio
async def test_full_agent_flow_with_retry(mock_orchestrator):
    """End-to-end: the full agent flow integrates the retry system.

    BUG FIX: the trailing reasoning-trace check computed ``trace_str`` and
    promised an assertion in a comment but never asserted anything; the
    assertion is now in place.
    """
    # Set up mocks for a full agent request.
    mock_orchestrator.intent = MagicMock()
    mock_orchestrator.intent.classify = AsyncMock(return_value="rag")
    mock_orchestrator.selector = MagicMock()
    from api.models.agent import AgentDecision
    mock_orchestrator.selector.select = AsyncMock(return_value=AgentDecision(
        action="call_tool",
        tool="rag",
        tool_input={"query": "test query"},
        reason="test"
    ))
    mock_orchestrator.redflag.check = AsyncMock(return_value=[])
    mock_orchestrator.mcp.call_rag = AsyncMock(side_effect=[
        {"results": [{"text": "low relevance", "score": 0.25}]},
        {"results": [{"text": "better match", "score": 0.50}]}
    ])
    mock_orchestrator.llm.simple_call = AsyncMock(return_value="Final answer")
    # Create the request.
    req = AgentRequest(
        tenant_id="tenant1",
        user_id="user1",
        message="test query"
    )
    # Handle the request end to end.
    response = await mock_orchestrator.handle(req)
    # Verify the retry happened (2 RAG calls).
    assert mock_orchestrator.mcp.call_rag.call_count == 2
    # Verify a response was generated.
    assert response.text == "Final answer"
    # Verify the reasoning trace contains retry/repair-related steps.
    trace_str = str(response.reasoning_trace).lower()
    assert "retry" in trace_str or "repair" in trace_str or "threshold" in trace_str
| # ============================================================= | |
| # EDGE CASES | |
| # ============================================================= | |
@pytest.mark.asyncio
async def test_rag_repair_edge_case_exactly_threshold(mock_orchestrator):
    """RAG repair must not retry when the score sits exactly on the threshold."""
    # Score exactly at the threshold (0.30) — should not trigger a retry.
    mock_orchestrator.mcp.call_rag = AsyncMock(return_value={
        "results": [{"text": "content", "score": 0.30}]
    })
    reasoning_trace = []
    await mock_orchestrator.rag_with_repair(
        query="test",
        tenant_id="tenant1",
        original_threshold=0.3,
        reasoning_trace=reasoning_trace,
        user_id="user1"
    )
    # Should not retry (score >= 0.30).
    assert mock_orchestrator.mcp.call_rag.call_count == 1
@pytest.mark.asyncio
async def test_web_repair_all_retries_fail(mock_orchestrator):
    """Web repair copes with every retry returning an empty result set."""
    mock_orchestrator.mcp.call_web = AsyncMock(return_value={"results": []})
    reasoning_trace = []
    result = await mock_orchestrator.web_with_repair(
        query="very obscure query",
        tenant_id="tenant1",
        reasoning_trace=reasoning_trace,
        user_id="user1"
    )
    # Should have attempted at least one retry.
    assert mock_orchestrator.mcp.call_web.call_count >= 2
    # Should still return a result structure (even if empty).
    assert isinstance(result, dict)
if __name__ == "__main__":
    # Allow running the suite directly: `python test_retry_system.py`.
    # BUG FIX: the MockPytest fallback has no `.main`, so calling
    # pytest.main unconditionally would raise AttributeError when pytest
    # is not installed — guard on HAS_PYTEST instead.
    if HAS_PYTEST:
        print("Running retry system tests...")
        pytest.main([__file__, "-v", "--tb=short"])
    else:
        print("pytest is not installed; cannot run the retry system tests directly.")