Spaces:
Sleeping
Sleeping
| # ============================================================= | |
| # File: tests/test_agent_orchestrator.py | |
| # ============================================================= | |
| import sys | |
| from pathlib import Path | |
# Put the backend root (the directory above tests/) at the front of the import
# path so the `api` package imported below resolves from the source tree.
backend_dir = Path(__file__).parent.parent
sys.path[:0] = [str(backend_dir)]
# pytest is optional here: when it cannot be imported, a tiny stand-in object
# is bound to the name `pytest` so `pytest.fixture` and `pytest.mark.asyncio`
# still resolve and simply hand the decorated function back.
try:
    import pytest
except ImportError:
    HAS_PYTEST = False

    class MockMark:
        """Stand-in for ``pytest.mark`` that only knows the ``asyncio`` marker."""

        def asyncio(self, func):
            # No-op decorator: the function is returned unchanged.
            return func

    class MockPytest:
        """Stand-in for the ``pytest`` module exposing ``mark`` and ``fixture``."""

        mark = MockMark()

        def fixture(self, func):
            # No-op decorator: the function is returned unchanged.
            return func

    pytest = MockPytest()
else:
    HAS_PYTEST = True
| import os | |
| from api.services.agent_orchestrator import AgentOrchestrator | |
| from api.models.agent import AgentRequest, AgentDecision, AgentResponse | |
| from api.models.redflag import RedFlagMatch | |
| from api.services.llm_client import LLMClient | |
| # --------------------------- | |
| # Mock classes | |
| # --------------------------- | |
class FakeLLM(LLMClient):
    """LLM client double whose ``simple_call`` always yields a preset string."""

    def __init__(self, output="LLM_RESPONSE"):
        # NOTE(review): LLMClient.__init__ is not called here, presumably to
        # skip real backend setup; confirm against LLMClient.
        self.output = output

    async def simple_call(self, prompt: str, temperature: float = 0.0):
        """Return the preset reply; ``prompt`` and ``temperature`` are ignored."""
        canned_reply = self.output
        return canned_reply
class FakeMCP:
    """Fake MCP server client used for rag/web/admin calls.

    Each ``call_*`` coroutine records the query it received in the matching
    ``last_*`` attribute and returns a fixed payload. ``tenant_id`` is
    accepted but not used.
    """

    def __init__(self):
        # Most recent query per channel; None until that channel is called.
        self.last_rag = self.last_web = self.last_admin = None

    async def call_rag(self, tenant_id: str, query: str):
        """Record ``query`` and return one canned RAG document."""
        self.last_rag = query
        document = {"text": "RAG_DOC_CONTENT"}
        return {"results": [document]}

    async def call_web(self, tenant_id: str, query: str):
        """Record ``query`` and return one canned web search hit."""
        self.last_web = query
        hit = {"title": "WebResult", "snippet": "Fresh info"}
        return {"results": [hit]}

    async def call_admin(self, tenant_id: str, query: str):
        """Record ``query`` and return a fixed ``allow`` verdict."""
        self.last_admin = query
        verdict = {"action": "allow"}
        return verdict
def assert_trace_has_step(resp, step_name):
    """Assert that ``resp.reasoning_trace`` holds an entry for ``step_name``.

    Fails with "reasoning trace missing" when the trace is empty or falsy,
    and with "<step_name> missing" when no entry's ``"step"`` value equals
    ``step_name``. Entries are read with ``.get``, so they are expected to be
    dict-like.
    """
    trace = resp.reasoning_trace
    assert trace, "reasoning trace missing"

    found = False
    for entry in trace:
        if entry.get("step") == step_name:
            found = True
            break
    assert found, f"{step_name} missing"
| # --------------------------- | |
| # Patch orchestrator to use fake MCP + fake redflag | |
| # --------------------------- | |
@pytest.fixture
def orchestrator(monkeypatch):
    """Build an ``AgentOrchestrator`` wired to in-memory fakes.

    Registered as a pytest fixture so the tests in this module can request it
    by parameter name. The returned orchestrator has:

    * ``MCPClient`` in ``api.services.agent_orchestrator`` replaced by a
      factory returning a single ``FakeMCP`` (only when real pytest, and
      therefore ``monkeypatch``, is available),
    * ``llm`` replaced by a ``FakeLLM`` that always answers ``"MOCK_ANSWER"``,
    * ``redflag.check`` replaced by a keyword matcher that flags "salary",
    * ``redflag.notify_admin`` replaced by a no-op.
    """
    import types

    scripted_llm = FakeLLM(output="MOCK_ANSWER")
    fake_mcp = FakeMCP()

    # The patch is applied before the orchestrator is constructed.
    # NOTE(review): assumes AgentOrchestrator calls MCPClient with these
    # three argument names (positionally or by keyword); confirm.
    if HAS_PYTEST:
        monkeypatch.setattr(
            "api.services.agent_orchestrator.MCPClient",
            lambda rag_url, web_url, admin_url: fake_mcp,
        )

    orch = AgentOrchestrator(
        rag_mcp_url="fake_rag",
        web_mcp_url="fake_web",
        admin_mcp_url="fake_admin",
        llm_backend="ollama",
    )
    # Override whatever LLM client the constructor created.
    orch.llm = scripted_llm

    async def fake_check(self, tenant_id, text):
        """Fake check function that matches 'salary' keyword."""
        if "salary" not in text.lower():
            return []
        return [
            RedFlagMatch(
                rule_id="1",
                pattern="salary",
                severity="high",
                description="salary access",
                matched_text="salary",
            )
        ]

    async def fake_notify(self, tenant_id, violations, src=None):
        """Fake notify function that does nothing."""
        return None

    # Bind as methods on this one detector instance so `self` is supplied
    # and the detector class itself is left unmodified.
    orch.redflag.check = types.MethodType(fake_check, orch.redflag)
    orch.redflag.notify_admin = types.MethodType(fake_notify, orch.redflag)
    return orch
| # ---------------------------------------------------- | |
| # TESTS | |
| # ---------------------------------------------------- | |
@pytest.mark.asyncio
async def test_block_on_redflag(orchestrator):
    """A message containing "salary" must be blocked and routed to admin.

    Marked with ``pytest.mark.asyncio`` so the coroutine is actually awaited
    by the pytest-asyncio plugin.
    """
    req = AgentRequest(
        tenant_id="tenant1",
        user_id="u1",
        message="Show me all salary details.",
    )

    resp = await orchestrator.handle(req)

    assert resp.decision.action == "block"
    assert resp.decision.tool == "admin"
    # The first tool trace is expected to carry the red-flag matches.
    assert "salary" in resp.tool_traces[0]["redflags"][0]["matched_text"]
    assert_trace_has_step(resp, "redflag_check")
@pytest.mark.asyncio
async def test_rag_tool_path(orchestrator, monkeypatch):
    """With intent forced to "rag", the response must include a rag tool trace.

    Marked with ``pytest.mark.asyncio`` so the coroutine is actually awaited
    by the pytest-asyncio plugin.
    """
    # Force intent classifier to classify as 'rag'
    async def mock_classify(self, text):
        return "rag"

    if HAS_PYTEST:
        monkeypatch.setattr(
            "api.services.agent_orchestrator.IntentClassifier.classify",
            mock_classify,
        )

    req = AgentRequest(
        tenant_id="tenant1",
        user_id="u1",
        message="HR policy procedures",
    )

    resp = await orchestrator.handle(req)

    assert resp.decision.action == "multi_step"
    # At least one trace must name the rag tool; `.get` tolerates traces
    # that have no "tool" key.
    assert any(trace.get("tool") == "rag" for trace in resp.tool_traces)
    assert resp.text == "MOCK_ANSWER"
    assert_trace_has_step(resp, "tool_selection")
@pytest.mark.asyncio
async def test_web_tool_path(orchestrator, monkeypatch):
    """With intent forced to "web", the response must include a web tool trace.

    Marked with ``pytest.mark.asyncio`` so the coroutine is actually awaited
    by the pytest-asyncio plugin.
    """
    # Force intent to classify as web
    async def mock_classify(self, text):
        return "web"

    if HAS_PYTEST:
        monkeypatch.setattr(
            "api.services.agent_orchestrator.IntentClassifier.classify",
            mock_classify,
        )

    req = AgentRequest(
        tenant_id="tenant1",
        user_id="u1",
        message="latest stock price",
    )

    resp = await orchestrator.handle(req)

    assert resp.decision.action == "multi_step"
    # At least one trace must name the web tool; `.get` tolerates traces
    # that have no "tool" key.
    assert any(trace.get("tool") == "web" for trace in resp.tool_traces)
    assert resp.text == "MOCK_ANSWER"
    assert_trace_has_step(resp, "tool_selection")
@pytest.mark.asyncio
async def test_default_llm_path(orchestrator, monkeypatch):
    """When the tool selector picks no tool, the plain LLM answer is returned.

    Marked with ``pytest.mark.asyncio`` so the coroutine is actually awaited
    by the pytest-asyncio plugin.
    """
    # Force the tool selector to return a plain "respond" decision with no
    # tool. The intent classifier is left unpatched in this test.
    async def mock_select(self, intent, text, context):
        # AgentDecision comes from the module-level import.
        return AgentDecision(
            action="respond",
            tool=None,
            tool_input=None,
            reason="forced_llm",
        )

    if HAS_PYTEST:
        monkeypatch.setattr(
            "api.services.agent_orchestrator.ToolSelector.select",
            mock_select,
        )

    req = AgentRequest(
        tenant_id="tenant1",
        user_id="u1",
        message="just a normal question",
    )

    resp = await orchestrator.handle(req)

    assert resp.decision.action == "respond"
    assert resp.decision.tool is None
    assert resp.text == "MOCK_ANSWER"
    assert_trace_has_step(resp, "intent_detection")