# =============================================================
# File: tests/test_agent_orchestrator.py
# =============================================================
"""Tests for ``AgentOrchestrator`` using in-memory fakes.

The orchestrator under test is wired to a canned LLM, a fake MCP client and
stubbed red-flag hooks, so the tests exercise only its routing decisions.
"""

import sys
from pathlib import Path

# Add backend directory to Python path
backend_dir = Path(__file__).parent.parent
sys.path.insert(0, str(backend_dir))

try:
    import pytest
except ImportError:
    HAS_PYTEST = False

    # Create a mock pytest decorator if pytest is not available
    class MockMark:
        def asyncio(self, func):
            return func

    class MockPytest:
        mark = MockMark()

        def fixture(self, func):
            return func

    pytest = MockPytest()
else:
    HAS_PYTEST = True

import os

from api.services.agent_orchestrator import AgentOrchestrator
from api.models.agent import AgentRequest, AgentDecision, AgentResponse
from api.models.redflag import RedFlagMatch
from api.services.llm_client import LLMClient


# ---------------------------
# Mock classes
# ---------------------------

class FakeLLM(LLMClient):
    """LLM stand-in whose ``simple_call`` always yields a preset string."""

    def __init__(self, output="LLM_RESPONSE"):
        # The parent initializer is intentionally not invoked here.
        self.output = output

    async def simple_call(self, prompt: str, temperature: float = 0.0):
        """Return the preset output, ignoring the prompt and temperature."""
        return self.output


class FakeMCP:
    """Fake MCP server client used for rag/web/admin calls."""

    def __init__(self):
        # Most recent query seen by each endpoint; None until it is called.
        self.last_rag = None
        self.last_web = None
        self.last_admin = None

    async def call_rag(self, tenant_id: str, query: str):
        """Remember the query and return a single canned RAG document."""
        self.last_rag = query
        return {"results": [{"text": "RAG_DOC_CONTENT"}]}

    async def call_web(self, tenant_id: str, query: str):
        """Remember the query and return a single canned web result."""
        self.last_web = query
        return {"results": [{"title": "WebResult", "snippet": "Fresh info"}]}

    async def call_admin(self, tenant_id: str, query: str):
        """Remember the query and return a canned admin action."""
        self.last_admin = query
        return {"action": "allow"}


def assert_trace_has_step(resp, step_name):
    """Assert that ``resp.reasoning_trace`` is non-empty and has ``step_name``."""
    assert resp.reasoning_trace, "reasoning trace missing"
    recorded_steps = [entry.get("step") for entry in resp.reasoning_trace]
    assert step_name in recorded_steps, f"{step_name} missing"


# Dotted path of the classifier method that the routing tests replace.
_CLASSIFY_TARGET = "api.services.agent_orchestrator.IntentClassifier.classify"


def _build_request(message):
    """Build an ``AgentRequest`` for the fixed test tenant and user."""
    return AgentRequest(tenant_id="tenant1", user_id="u1", message=message)


def _force_intent(monkeypatch, intent):
    """Patch the intent classifier so that it always reports ``intent``."""

    async def mock_classify(self, text):
        return intent

    if HAS_PYTEST:
        monkeypatch.setattr(_CLASSIFY_TARGET, mock_classify)


def _has_tool_trace(resp, tool_name):
    """Return True when any entry of ``resp.tool_traces`` names ``tool_name``."""
    return any(trace.get("tool") == tool_name for trace in resp.tool_traces)


# ---------------------------
# Patch orchestrator to use fake MCP + fake redflag
# ---------------------------

@pytest.fixture
def orchestrator(monkeypatch):
    """Provide an ``AgentOrchestrator`` wired to the fakes defined above."""
    # Fake LLM that always returns "MOCK_ANSWER"
    fake_llm = FakeLLM(output="MOCK_ANSWER")
    fake_mcp = FakeMCP()

    def fake_mcp_factory(rag_url, web_url, admin_url):
        return fake_mcp

    # Patch MCPClient before the orchestrator is constructed
    if HAS_PYTEST:
        monkeypatch.setattr(
            "api.services.agent_orchestrator.MCPClient",
            fake_mcp_factory,
        )

    # Create orchestrator with fake URLs first
    agent = AgentOrchestrator(
        rag_mcp_url="fake_rag",
        web_mcp_url="fake_web",
        admin_mcp_url="fake_admin",
        llm_backend="ollama",
    )
    agent.llm = fake_llm  # override with fake LLM

    # Replacement for the red-flag check, bound to the instance below
    async def fake_check(self, tenant_id, text):
        """Fake check function that matches 'salary' keyword."""
        if "salary" not in text.lower():
            return []
        return [
            RedFlagMatch(
                rule_id="1",
                pattern="salary",
                severity="high",
                description="salary access",
                matched_text="salary",
            )
        ]

    # Replacement for notify_admin that does nothing
    async def fake_notify(self, tenant_id, violations, src=None):
        """Fake notify function that does nothing."""
        return None

    # Bind the fake functions directly to the instance
    import types

    detector = agent.redflag
    detector.check = types.MethodType(fake_check, detector)
    detector.notify_admin = types.MethodType(fake_notify, detector)

    return agent


# ----------------------------------------------------
# TESTS
# ----------------------------------------------------

@pytest.mark.asyncio
async def test_block_on_redflag(orchestrator):
    """A message containing 'salary' is blocked and routed to admin."""
    resp = await orchestrator.handle(
        _build_request("Show me all salary details.")
    )

    decision = resp.decision
    assert decision.action == "block"
    assert decision.tool == "admin"
    first_flag = resp.tool_traces[0]["redflags"][0]
    assert "salary" in first_flag["matched_text"]
    assert_trace_has_step(resp, "redflag_check")


@pytest.mark.asyncio
async def test_rag_tool_path(orchestrator, monkeypatch):
    """With intent forced to 'rag', a rag tool trace is recorded."""
    # Force intent classifier to classify as 'rag'
    _force_intent(monkeypatch, "rag")

    resp = await orchestrator.handle(_build_request("HR policy procedures"))

    assert resp.decision.action == "multi_step"
    assert _has_tool_trace(resp, "rag")
    assert resp.text == "MOCK_ANSWER"
    assert_trace_has_step(resp, "tool_selection")


@pytest.mark.asyncio
async def test_web_tool_path(orchestrator, monkeypatch):
    """With intent forced to 'web', a web tool trace is recorded."""
    # Force intent to classify as web
    _force_intent(monkeypatch, "web")

    resp = await orchestrator.handle(_build_request("latest stock price"))

    assert resp.decision.action == "multi_step"
    assert _has_tool_trace(resp, "web")
    assert resp.text == "MOCK_ANSWER"
    assert_trace_has_step(resp, "tool_selection")


@pytest.mark.asyncio
async def test_default_llm_path(orchestrator, monkeypatch):
    """With tool selection forced to 'respond', the LLM answers directly."""

    # Force intent = general and force tool selector to NOT call any tool
    async def mock_select(self, intent, text, context):
        return AgentDecision(
            action="respond",
            tool=None,
            tool_input=None,
            reason="forced_llm",
        )

    if HAS_PYTEST:
        monkeypatch.setattr(
            "api.services.agent_orchestrator.ToolSelector.select",
            mock_select,
        )

    resp = await orchestrator.handle(_build_request("just a normal question"))

    assert resp.decision.action == "respond"
    assert resp.decision.tool is None
    assert resp.text == "MOCK_ANSWER"
    assert_trace_has_step(resp, "intent_detection")