Spaces:
Sleeping
Sleeping
File size: 6,682 Bytes
2f235a0 c16e1c9 2f235a0 c16e1c9 2f235a0 c16e1c9 ef83e66 2f235a0 c16e1c9 2f235a0 c16e1c9 2f235a0 ef83e66 2f235a0 ef83e66 2f235a0 ef83e66 2f235a0 ef83e66 2f235a0 ef83e66 2f235a0 c16e1c9 2f235a0 ef83e66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 |
# =============================================================
# File: tests/test_agent_orchestrator.py
# =============================================================
import sys
from pathlib import Path
# Put the backend root (the directory above this test file's folder) at the
# front of sys.path so the `api.*` imports further down can be resolved.
backend_dir = Path(__file__).parent.parent
sys.path.insert(0, str(backend_dir))

try:
    import pytest
except ImportError:
    HAS_PYTEST = False

    # Minimal no-op stand-ins so the decorators this module uses
    # (`@pytest.fixture`, `@pytest.mark.asyncio`) still work without pytest.
    class MockMark:
        def asyncio(self, func):
            # Decorator stand-in: hand the function back unchanged.
            return func

    class MockPytest:
        mark = MockMark()

        def fixture(self, func):
            # Decorator stand-in: hand the function back unchanged.
            return func

    pytest = MockPytest()
else:
    HAS_PYTEST = True
import os
from api.services.agent_orchestrator import AgentOrchestrator
from api.models.agent import AgentRequest, AgentDecision, AgentResponse
from api.models.redflag import RedFlagMatch
from api.services.llm_client import LLMClient
# ---------------------------
# Mock classes
# ---------------------------
class FakeLLM(LLMClient):
    """Stand-in LLM client that answers every prompt with one canned string.

    ``LLMClient.__init__`` is not invoked here, so constructing this fake does
    no backend setup of its own.
    """

    def __init__(self, output="LLM_RESPONSE"):
        # Reply returned by simple_call for every prompt.
        self.output = output

    async def simple_call(self, prompt: str, temperature: float = 0.0):
        """Ignore ``prompt`` and ``temperature`` and return the canned reply."""
        canned_reply = self.output
        return canned_reply
class FakeMCP:
    """Fake MCP server client used for rag/web/admin calls.

    Each ``call_*`` coroutine stores the latest query it received in the
    matching ``last_*`` attribute and returns a fixed payload.
    """

    def __init__(self):
        # Most recent query seen by each tool; None until that tool is called.
        self.last_rag = self.last_web = self.last_admin = None

    async def call_rag(self, tenant_id: str, query: str):
        """Record ``query``; return a single canned RAG document."""
        self.last_rag = query
        document = {"text": "RAG_DOC_CONTENT"}
        return {"results": [document]}

    async def call_web(self, tenant_id: str, query: str):
        """Record ``query``; return a single canned web search hit."""
        self.last_web = query
        hit = {"title": "WebResult", "snippet": "Fresh info"}
        return {"results": [hit]}

    async def call_admin(self, tenant_id: str, query: str):
        """Record ``query``; always answer with an ``allow`` action."""
        self.last_admin = query
        verdict = {"action": "allow"}
        return verdict
def assert_trace_has_step(resp, step_name):
    """Assert that ``resp.reasoning_trace`` records a step named ``step_name``.

    Args:
        resp: Response object exposing a ``reasoning_trace`` sequence whose
            entries are dicts that may carry a ``"step"`` key.
        step_name: Step label that must appear in the trace.

    Raises:
        AssertionError: If the trace is empty/falsy, or no entry has
            ``entry["step"] == step_name``. In the latter case the message
            lists the steps that *were* recorded, to speed up diagnosis.
    """
    trace = resp.reasoning_trace
    assert trace, "reasoning trace missing"
    recorded_steps = [entry.get("step") for entry in trace]
    assert step_name in recorded_steps, (
        f"{step_name} missing (recorded steps: {recorded_steps})"
    )
# ---------------------------
# Patch orchestrator to use fake MCP + fake redflag
# ---------------------------
@pytest.fixture
def orchestrator(monkeypatch):
    """Build an ``AgentOrchestrator`` wired to in-memory fakes.

    - ``MCPClient`` in ``api.services.agent_orchestrator`` is patched (when
      pytest is available) to always return one shared ``FakeMCP``.
    - ``orch.llm`` is replaced by a ``FakeLLM`` that always answers
      ``"MOCK_ANSWER"``.
    - ``orch.redflag.check`` flags any text containing "salary"
      (case-insensitive); ``orch.redflag.notify_admin`` does nothing.

    Returns:
        The configured orchestrator instance.
    """
    import types

    llm = FakeLLM(output="MOCK_ANSWER")
    fake_mcp = FakeMCP()

    if HAS_PYTEST:
        # Accept any call signature so the fake keeps working however the
        # orchestrator passes the MCP URLs (positional or keyword, any names).
        monkeypatch.setattr(
            "api.services.agent_orchestrator.MCPClient",
            lambda *args, **kwargs: fake_mcp,
        )

    # Patch is applied before construction in case MCPClient is instantiated
    # inside AgentOrchestrator.__init__ (not visible here — presumably so).
    orch = AgentOrchestrator(
        rag_mcp_url="fake_rag",
        web_mcp_url="fake_web",
        admin_mcp_url="fake_admin",
        llm_backend="ollama",
    )
    orch.llm = llm  # swap in the fake LLM after construction

    async def fake_check(self, tenant_id, text):
        """Return one high-severity match when ``text`` mentions 'salary'."""
        if "salary" in (text or "").lower():
            return [
                RedFlagMatch(
                    rule_id="1",
                    pattern="salary",
                    severity="high",
                    description="salary access",
                    matched_text="salary",
                )
            ]
        return []

    async def fake_notify(self, *args, **kwargs):
        """No-op replacement for the detector's ``notify_admin``."""
        return None

    # Bind onto this detector instance only; the RedFlagDetector class itself
    # is left unpatched.
    orch.redflag.check = types.MethodType(fake_check, orch.redflag)
    orch.redflag.notify_admin = types.MethodType(fake_notify, orch.redflag)
    return orch
# ----------------------------------------------------
# TESTS
# ----------------------------------------------------
@pytest.mark.asyncio
async def test_block_on_redflag(orchestrator):
    """A message mentioning 'salary' is expected to be blocked via the admin tool."""
    salary_request = AgentRequest(
        tenant_id="tenant1",
        user_id="u1",
        message="Show me all salary details.",
    )
    response = await orchestrator.handle(salary_request)

    decision = response.decision
    assert decision.action == "block"
    assert decision.tool == "admin"

    # The fixture's fake red-flag check reports the literal "salary" match.
    first_flag = response.tool_traces[0]["redflags"][0]
    assert "salary" in first_flag["matched_text"]

    assert_trace_has_step(response, "redflag_check")
@pytest.mark.asyncio
async def test_rag_tool_path(orchestrator, monkeypatch):
    """When intent is classified as 'rag', the orchestrator uses the RAG tool.

    Expects a multi-step decision, at least one tool trace for "rag", the fake
    LLM's canned answer as the final text, and a "tool_selection" step in the
    reasoning trace.
    """

    async def mock_classify(self, text):
        # Replaces IntentClassifier.classify so every message routes to RAG.
        return "rag"

    if HAS_PYTEST:
        monkeypatch.setattr(
            "api.services.agent_orchestrator.IntentClassifier.classify",
            mock_classify,
        )

    req = AgentRequest(
        tenant_id="tenant1",
        user_id="u1",
        message="HR policy procedures",
    )
    resp = await orchestrator.handle(req)

    assert resp.decision.action == "multi_step"
    assert any(trace.get("tool") == "rag" for trace in resp.tool_traces)
    assert resp.text == "MOCK_ANSWER"  # canned output of the fixture's FakeLLM
    assert_trace_has_step(resp, "tool_selection")
@pytest.mark.asyncio
async def test_web_tool_path(orchestrator, monkeypatch):
    """When intent is classified as 'web', the orchestrator uses the web tool.

    Expects a multi-step decision, at least one tool trace for "web", the fake
    LLM's canned answer as the final text, and a "tool_selection" step in the
    reasoning trace.
    """

    async def mock_classify(self, text):
        # Replaces IntentClassifier.classify so every message routes to web.
        return "web"

    if HAS_PYTEST:
        monkeypatch.setattr(
            "api.services.agent_orchestrator.IntentClassifier.classify",
            mock_classify,
        )

    req = AgentRequest(
        tenant_id="tenant1",
        user_id="u1",
        message="latest stock price",
    )
    resp = await orchestrator.handle(req)

    assert resp.decision.action == "multi_step"
    assert any(trace.get("tool") == "web" for trace in resp.tool_traces)
    assert resp.text == "MOCK_ANSWER"  # canned output of the fixture's FakeLLM
    assert_trace_has_step(resp, "tool_selection")
@pytest.mark.asyncio
async def test_default_llm_path(orchestrator, monkeypatch):
    """When no tool is selected, the orchestrator answers with the LLM alone.

    Only ``ToolSelector.select`` is patched (intent classification runs
    unpatched). It returns a "respond" decision with no tool. The response
    should carry that decision, the fake LLM's canned answer, and an
    "intent_detection" step in the reasoning trace.
    """

    async def mock_select(self, intent, text, context):
        # Force a direct LLM answer: no tool and no tool input.
        return AgentDecision(
            action="respond",
            tool=None,
            tool_input=None,
            reason="forced_llm",
        )

    if HAS_PYTEST:
        monkeypatch.setattr(
            "api.services.agent_orchestrator.ToolSelector.select",
            mock_select,
        )

    req = AgentRequest(
        tenant_id="tenant1",
        user_id="u1",
        message="just a normal question",
    )
    resp = await orchestrator.handle(req)

    assert resp.decision.action == "respond"
    assert resp.decision.tool is None
    assert resp.text == "MOCK_ANSWER"  # canned output of the fixture's FakeLLM
    assert_trace_has_step(resp, "intent_detection")
|