# IntegraChat / backend / tests / test_agent_orchestrator.py
# Commit ef83e66 (nothingworry): "Reasoning traces, smarter tools, deterministic backend tests."
# File size: 6.68 kB
# =============================================================
# File: tests/test_agent_orchestrator.py
# =============================================================
import sys
from pathlib import Path
# Put the directory two levels above this file (the backend root, per the
# original comment) at the front of sys.path so the `api.*` imports resolve.
backend_dir = Path(__file__).parent.parent
sys.path.insert(0, str(backend_dir))

# pytest is treated as optional at import time: when it is missing, a tiny
# stand-in object supplies the decorators used below so the module still imports.
try:
    import pytest
except ImportError:
    HAS_PYTEST = False

    class MockMark:
        """Stand-in for ``pytest.mark`` providing only the ``asyncio`` marker."""

        def asyncio(self, func):
            # No-op marker: hand the decorated function back unchanged.
            return func

    class MockPytest:
        """Stand-in for the ``pytest`` module exposing ``mark`` and ``fixture``."""

        mark = MockMark()

        def fixture(self, func):
            # No-op decorator: hand the decorated function back unchanged.
            return func

    pytest = MockPytest()
else:
    HAS_PYTEST = True
import os
from api.services.agent_orchestrator import AgentOrchestrator
from api.models.agent import AgentRequest, AgentDecision, AgentResponse
from api.models.redflag import RedFlagMatch
from api.services.llm_client import LLMClient
# ---------------------------
# Mock classes
# ---------------------------
class FakeLLM(LLMClient):
    """LLM client double that answers every prompt with one canned string.

    ``LLMClient.__init__`` is intentionally not invoked here; only the
    ``output`` attribute is set up.
    """

    def __init__(self, output="LLM_RESPONSE"):
        # Canned reply handed back by every simple_call().
        self.output = output

    async def simple_call(self, prompt: str, temperature: float = 0.0):
        """Return the canned reply; ``prompt`` and ``temperature`` are ignored."""
        canned_reply = self.output
        return canned_reply
class FakeMCP:
    """In-memory replacement for the MCP client used for rag/web/admin calls.

    Each ``call_*`` coroutine remembers the most recent query it received
    (``last_rag`` / ``last_web`` / ``last_admin``) and returns a fixed payload.
    """

    def __init__(self):
        # Nothing has been called yet.
        self.last_rag = self.last_web = self.last_admin = None

    async def call_rag(self, tenant_id: str, query: str):
        """Record ``query`` and return a single canned RAG document."""
        self.last_rag = query
        documents = [{"text": "RAG_DOC_CONTENT"}]
        return {"results": documents}

    async def call_web(self, tenant_id: str, query: str):
        """Record ``query`` and return a single canned web search hit."""
        self.last_web = query
        hits = [{"title": "WebResult", "snippet": "Fresh info"}]
        return {"results": hits}

    async def call_admin(self, tenant_id: str, query: str):
        """Record ``query`` and return a canned ``allow`` action."""
        self.last_admin = query
        verdict = {"action": "allow"}
        return verdict
def assert_trace_has_step(resp, step_name):
    """Assert that ``resp.reasoning_trace`` is non-empty and names ``step_name``.

    Each trace entry is looked up with ``entry.get("step")``. Scanning stops at
    the first match. Raises ``AssertionError`` when the trace is empty/missing
    or when no entry carries the requested step.
    """
    assert resp.reasoning_trace, "reasoning trace missing"

    step_found = False
    for entry in resp.reasoning_trace:
        if entry.get("step") == step_name:
            step_found = True
            break
    assert step_found, f"{step_name} missing"
# ---------------------------
# Patch orchestrator to use fake MCP + fake redflag
# ---------------------------
@pytest.fixture
def orchestrator(monkeypatch):
    """Build an ``AgentOrchestrator`` whose collaborators are in-process fakes.

    * ``MCPClient`` in the orchestrator module is replaced by a factory that
      always hands back one shared ``FakeMCP``.
    * ``orch.llm`` is swapped for a ``FakeLLM`` answering ``"MOCK_ANSWER"``.
    * ``orch.redflag.check`` flags any text containing "salary", and
      ``orch.redflag.notify_admin`` is a no-op.
    """
    from types import MethodType

    async def _flag_salary(self, tenant_id, text):
        """Return one high-severity match when the text mentions 'salary'."""
        if "salary" not in text.lower():
            return []
        salary_match = RedFlagMatch(
            rule_id="1",
            pattern="salary",
            severity="high",
            description="salary access",
            matched_text="salary",
        )
        return [salary_match]

    async def _skip_notification(self, tenant_id, violations, src=None):
        """Swallow the admin notification."""
        return None

    shared_mcp = FakeMCP()

    # The patch has to be in place before the orchestrator is constructed.
    if HAS_PYTEST:
        def _mcp_factory(rag_url, web_url, admin_url):
            return shared_mcp

        monkeypatch.setattr(
            "api.services.agent_orchestrator.MCPClient",
            _mcp_factory,
        )

    orch = AgentOrchestrator(
        rag_mcp_url="fake_rag",
        web_mcp_url="fake_web",
        admin_mcp_url="fake_admin",
        llm_backend="ollama",
    )
    # Replace whatever LLM the constructor created with the canned-answer fake.
    orch.llm = FakeLLM(output="MOCK_ANSWER")

    # Bind the fakes onto this detector instance only, as bound methods.
    overrides = (
        ("check", _flag_salary),
        ("notify_admin", _skip_notification),
    )
    for attr_name, replacement in overrides:
        setattr(orch.redflag, attr_name, MethodType(replacement, orch.redflag))

    return orch
# ----------------------------------------------------
# TESTS
# ----------------------------------------------------
@pytest.mark.asyncio
async def test_block_on_redflag(orchestrator):
    """A message mentioning 'salary' must be blocked and routed to admin."""
    request = AgentRequest(
        tenant_id="tenant1",
        user_id="u1",
        message="Show me all salary details.",
    )

    resp = await orchestrator.handle(request)

    decision = resp.decision
    assert decision.action == "block"
    assert decision.tool == "admin"

    # The first tool trace is expected to carry the red-flag matches.
    first_flag = resp.tool_traces[0]["redflags"][0]
    assert "salary" in first_flag["matched_text"]

    assert_trace_has_step(resp, "redflag_check")
@pytest.mark.asyncio
async def test_rag_tool_path(orchestrator, monkeypatch):
    """With intent forced to 'rag', the response must record a rag tool trace.

    Checks that the decision action is ``multi_step``, that at least one tool
    trace names the ``rag`` tool, that the final text is the fake LLM's answer,
    and that the reasoning trace contains a ``tool_selection`` step.
    """
    # Force the intent classifier to classify every message as 'rag'.
    async def mock_classify(self, text):
        return "rag"

    if HAS_PYTEST:
        monkeypatch.setattr(
            "api.services.agent_orchestrator.IntentClassifier.classify",
            mock_classify,
        )

    req = AgentRequest(
        tenant_id="tenant1",
        user_id="u1",
        message="HR policy procedures",
    )
    resp = await orchestrator.handle(req)

    assert resp.decision.action == "multi_step"
    # One membership check; the original filtered on the same condition it
    # then asserted, which made the generator's predicate tautological.
    assert any(
        trace.get("tool") == "rag" for trace in resp.tool_traces
    ), "no rag tool call recorded in tool_traces"
    assert resp.text == "MOCK_ANSWER"
    assert_trace_has_step(resp, "tool_selection")
@pytest.mark.asyncio
async def test_web_tool_path(orchestrator, monkeypatch):
    """With intent forced to 'web', the response must record a web tool trace.

    Checks that the decision action is ``multi_step``, that at least one tool
    trace names the ``web`` tool, that the final text is the fake LLM's answer,
    and that the reasoning trace contains a ``tool_selection`` step.
    """
    # Force the intent classifier to classify every message as 'web'.
    async def mock_classify(self, text):
        return "web"

    if HAS_PYTEST:
        monkeypatch.setattr(
            "api.services.agent_orchestrator.IntentClassifier.classify",
            mock_classify,
        )

    req = AgentRequest(
        tenant_id="tenant1",
        user_id="u1",
        message="latest stock price",
    )
    resp = await orchestrator.handle(req)

    assert resp.decision.action == "multi_step"
    # One membership check; the original filtered on the same condition it
    # then asserted, which made the generator's predicate tautological.
    assert any(
        trace.get("tool") == "web" for trace in resp.tool_traces
    ), "no web tool call recorded in tool_traces"
    assert resp.text == "MOCK_ANSWER"
    assert_trace_has_step(resp, "tool_selection")
@pytest.mark.asyncio
async def test_default_llm_path(orchestrator, monkeypatch):
    """When the tool selector picks no tool, the answer comes from the LLM.

    ``ToolSelector.select`` is patched to always return a ``respond`` decision
    with no tool; the intent classifier itself is left untouched.
    """
    async def force_plain_response(self, intent, text, context):
        from api.models.agent import AgentDecision

        no_tool_decision = AgentDecision(
            action="respond",
            tool=None,
            tool_input=None,
            reason="forced_llm",
        )
        return no_tool_decision

    if HAS_PYTEST:
        monkeypatch.setattr(
            "api.services.agent_orchestrator.ToolSelector.select",
            force_plain_response,
        )

    request = AgentRequest(
        tenant_id="tenant1",
        user_id="u1",
        message="just a normal question",
    )
    resp = await orchestrator.handle(request)

    decision = resp.decision
    assert decision.action == "respond"
    assert decision.tool is None
    assert resp.text == "MOCK_ANSWER"
    assert_trace_has_step(resp, "intent_detection")