# IntegraChat / test_all.py
# nothingworry's picture
# Reasoning traces, smarter tools, deterministic backend tests.
# ef83e66
# raw
# history blame
# 9.07 kB
"""
Single-file test suite for IntegraChat backend (unit + integration + simulation).
This version aligns with the current backend API surface.
"""
from __future__ import annotations
import os
import sys
from pathlib import Path
from typing import List, Dict
import pytest
from fastapi.testclient import TestClient
# ---------------------------------------------------------------------------
# Ensure backend package is importable
# ---------------------------------------------------------------------------
# Directory holding this test module, and its "backend" subfolder.
PROJECT_ROOT = Path(__file__).resolve().parent
backend_path = PROJECT_ROOT / "backend"

# Prepend every location that is not on sys.path yet. PROJECT_ROOT is handled
# first, so backend_path ends up ahead of it whenever both get inserted.
for _location in (str(PROJECT_ROOT), str(backend_path)):
    if _location in sys.path:
        continue
    sys.path.insert(0, _location)
# ---------------------------------------------------------------------------
# Shared fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True, scope="session")
def set_test_env():
    """Seed default service URLs and LLM settings for the whole test session.

    Only variables that are missing from the environment are filled in;
    values that are already set are left as they are.
    """
    fallback_settings = {
        "RAG_MCP_URL": "http://mock-rag",
        "WEB_MCP_URL": "http://mock-web",
        "ADMIN_MCP_URL": "http://mock-admin",
        "OLLAMA_URL": "http://localhost:11434",
        "OLLAMA_MODEL": "llama3",
        "LLM_BACKEND": "ollama",
    }
    for variable, fallback in fallback_settings.items():
        os.environ.setdefault(variable, fallback)
@pytest.fixture
def mock_backend_dependencies(monkeypatch):
    """Patch MCP client calls and red-flag detector for deterministic tests.

    The following are replaced with canned fakes:

    * ``MCPClient.call_rag`` / ``call_web`` / ``call_admin``
    * ``RedFlagDetector.check`` / ``notify_admin``
    * ``ToolScoringService.score``

    The same fakes are also bound onto the orchestrator instance exposed by
    ``backend.api.routes.agent``. Every replacement is registered with
    ``monkeypatch`` so it is undone when the requesting test finishes,
    instead of lingering on the module-level orchestrator.
    """
    print(">> applying backend dependency patches for tests")
    from backend.api.models.redflag import RedFlagMatch
    from backend.api.services.tool_scoring import ToolScoringService
    import types

    async def fake_call_rag(self, tenant_id: str, query: str) -> Dict:
        """Return two fixed RAG hits plus retrieval metadata."""
        return {
            "results": [
                {"text": "HR policy includes onboarding, leave rules.", "relevance": 0.92},
                {"text": "General company announcement", "relevance": 0.42},
            ],
            "metadata": {"total_retrieved": 2, "returned": 2, "threshold": 0.55},
        }

    async def fake_call_web(self, tenant_id: str, query: str) -> Dict:
        """Return two fixed web search results."""
        return {
            "results": [
                {"title": "Latest inflation update", "snippet": "Inflation is 3.2%", "url": "https://example.com"},
                {"title": "Global news", "snippet": "Market highlights", "url": "https://news.example.com"},
            ]
        }

    async def fake_call_admin(self, tenant_id: str, query: str) -> Dict:
        """Echo the tenant and query back with an ``ok`` status."""
        return {"status": "ok", "tenant_id": tenant_id, "query": query}

    monkeypatch.setattr("backend.api.mcp_clients.mcp_client.MCPClient.call_rag", fake_call_rag)
    monkeypatch.setattr("backend.api.mcp_clients.mcp_client.MCPClient.call_web", fake_call_web)
    monkeypatch.setattr("backend.api.mcp_clients.mcp_client.MCPClient.call_admin", fake_call_admin)

    async def fake_redflag_check(self, tenant_id: str, text: str) -> List[RedFlagMatch]:
        """Flag any text containing "delete" (case-insensitive); otherwise nothing."""
        if "delete" in text.lower():
            return [
                RedFlagMatch(
                    rule_id="1",
                    pattern="delete",
                    severity="high",
                    description="Deletion request",
                    matched_text="delete",
                    confidence=0.9,
                    explanation="Matched on keyword 'delete'",
                )
            ]
        return []

    async def fake_notify(self, tenant_id, violations, source_payload=None):
        """Swallow admin notifications."""
        return None

    monkeypatch.setattr("backend.api.services.redflag_detector.RedFlagDetector.check", fake_redflag_check)
    monkeypatch.setattr("backend.api.services.redflag_detector.RedFlagDetector.notify_admin", fake_notify)

    def fake_score(self, message: str, intent: str, rag_results: List[Dict]) -> Dict[str, float]:
        """Return constant fitness scores regardless of input."""
        return {"rag_fitness": 0.82, "web_fitness": 0.78, "llm_only": 0.25}

    monkeypatch.setattr(ToolScoringService, "score", fake_score)

    # Ensure already-instantiated orchestrator uses the same patches
    from backend.api.routes import agent as agent_routes

    orchestrator = agent_routes.orchestrator
    instance_patches = (
        (orchestrator.mcp, "call_rag", fake_call_rag),
        (orchestrator.mcp, "call_web", fake_call_web),
        (orchestrator.mcp, "call_admin", fake_call_admin),
        (orchestrator.redflag, "check", fake_redflag_check),
        (orchestrator.redflag, "notify_admin", fake_notify),
    )
    # Patch the instance __dict__ via setitem rather than plain assignment or
    # monkeypatch.setattr: plain assignment is never undone, and setattr would
    # snapshot the "old" value with getattr, which at this point already
    # resolves to the class-level fake. setitem restores the previous entry,
    # or removes the key if there was none.
    # NOTE(review): assumes these objects have a regular __dict__ - confirm.
    for target, attr_name, fake in instance_patches:
        monkeypatch.setitem(vars(target), attr_name, types.MethodType(fake, target))
@pytest.fixture
def api_client(mock_backend_dependencies):
    """Provide a ``TestClient`` for the backend ``app``.

    Requesting ``mock_backend_dependencies`` guarantees the dependency
    patches are in place before the application module is imported here.
    """
    from backend.api.main import app as backend_app

    client = TestClient(backend_app)
    return client
# ---------------------------------------------------------------------------
# Unit tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_redflag_detector():
    """A cached "salary" rule must match a message that mentions salary.

    The detector's rule cache and rule embeddings are filled in by hand for
    one tenant before ``check`` is called for that tenant.
    """
    import time
    from backend.api.services.redflag_detector import RedFlagDetector
    from backend.api.models.redflag import RedFlagRule
    from backend.api.services.semantic_encoder import embed_text

    tenant = "tenant-x"
    detector = RedFlagDetector(supabase_url="http://fake", supabase_key="fake")
    salary_rule = RedFlagRule(
        id="rule-salary",
        pattern="salary",
        description="Salary access",
        severity="high",
        source="test",
        enabled=True,
        keywords=["salary"],
    )

    # Pre-populate both caches for the tenant under test.
    cache_entry = {"fetched_at": int(time.time()), "rules": [salary_rule]}
    detector._rules_cache[tenant] = cache_entry
    detector._rule_embeddings[tenant] = {salary_rule.id: embed_text("salary access")}

    found = await detector.check(tenant, "Show me employee salary details")

    assert found
    top_match = found[0]
    assert top_match.matched_text.lower() == "salary"
    assert top_match.confidence is not None
def test_tool_scoring():
    """A web-intent question must score higher for web than for RAG.

    Also checks that exactly the three expected fitness keys come back.
    """
    from backend.api.services.tool_scoring import ToolScoringService

    expected_keys = {"rag_fitness", "web_fitness", "llm_only"}
    fitness = ToolScoringService().score(
        "What is inflation today?",
        intent="web",
        rag_results=[],
    )

    assert set(fitness.keys()) == expected_keys
    assert fitness["web_fitness"] > fitness["rag_fitness"]
@pytest.mark.asyncio
async def test_tool_selector():
    """A mixed internal/external request must plan rag first, web somewhere, llm last."""
    from backend.api.services.tool_selector import ToolSelector

    context = {
        "rag_results": [{"text": "Policy"}],
        "tool_scores": {"rag_fitness": 0.9, "web_fitness": 0.8},
    }
    decision = await ToolSelector().select(
        intent="rag",
        text="Tell me HR policy and compare with external news",
        ctx=context,
    )
    plan = decision.tool_input["steps"]

    assert plan[0]["tool"] == "rag"

    # Stop at the first web step, like a short-circuiting any().
    saw_web = False
    for planned in plan:
        if planned["tool"] == "web":
            saw_web = True
            break
    assert saw_web

    assert plan[-1]["tool"] == "llm"
def test_reasoning_trace_via_response(api_client):
    """The /agent/message response must carry a trace with an intent_detection step."""
    response = api_client.post(
        "/agent/message",
        json={"tenant_id": "tenant1", "message": "Summarize our HR policies"},
    )
    body = response.json()

    trace = body["reasoning_trace"]
    assert trace

    recorded_steps = tuple(entry["step"] for entry in trace)
    assert "intent_detection" in recorded_steps
# ---------------------------------------------------------------------------
# Integration tests
# ---------------------------------------------------------------------------
def test_full_agent_pipeline(api_client):
    """End-to-end call: non-empty text, at least three trace entries, one of them rag."""
    request_body = {
        "tenant_id": "tenant123",
        "message": "What are our HR policies and latest updates?",
    }
    body = api_client.post("/agent/message", json=request_body).json()

    assert body["text"]

    trace = body["reasoning_trace"]
    assert len(trace) >= 3

    rag_entries = list(filter(lambda entry: entry.get("tool") == "rag", trace))
    assert rag_entries, "expected rag tool execution in reasoning trace"
def test_parallel_execution_detected(api_client):
    """A combined HR/news request must report both rag and web in tool_traces."""
    request_body = {
        "tenant_id": "t1",
        "message": "Summarize HR policies and latest news updates",
    }
    body = api_client.post("/agent/message", json=request_body).json()

    # Collect every truthy tool name that appears in the traces.
    tools_used = set()
    for trace in body["tool_traces"]:
        if trace.get("tool"):
            tools_used.add(trace.get("tool"))

    assert all(name in tools_used for name in ("rag", "web"))
# ---------------------------------------------------------------------------
# Simulation tests
# ---------------------------------------------------------------------------
# User messages replayed one by one against the agent endpoint.
SIM_QUERIES = [
    "What is the inflation rate today?",
    "Summarize our HR policies",
    "Delete all records",  # the only entry containing "delete"
    "Explain our refund policy",
    "How many employees are in the company?",
]
@pytest.mark.parametrize("message", SIM_QUERIES)
def test_agent_simulation(api_client, message):
    """Every simulated message must yield text and a reasoning trace.

    Messages containing "delete" (case-insensitive) must additionally end in
    a ``block`` or ``multi_step`` decision whose reason mentions "admin" or
    "redflag".
    """
    request_body = {"tenant_id": "demo", "message": message}
    body = api_client.post("/agent/message", json=request_body).json()

    assert body["text"]
    assert body["reasoning_trace"]

    if "delete" not in message.lower():
        return

    assert body["decision"]["action"] in {"block", "multi_step"}
    # A missing/None reason is treated as an empty string.
    reason_text = (body["decision"]["reason"] or "").lower()
    assert any(marker in reason_text for marker in ("admin", "redflag"))