""" Single-file test suite for IntegraChat backend (unit + integration + simulation). This version aligns with the current backend API surface. """ from __future__ import annotations import os import sys from pathlib import Path from typing import List, Dict import pytest from fastapi.testclient import TestClient # --------------------------------------------------------------------------- # Ensure backend package is importable # --------------------------------------------------------------------------- PROJECT_ROOT = Path(__file__).resolve().parent if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) backend_path = PROJECT_ROOT / "backend" if str(backend_path) not in sys.path: sys.path.insert(0, str(backend_path)) # --------------------------------------------------------------------------- # Shared fixtures # --------------------------------------------------------------------------- @pytest.fixture(autouse=True, scope="session") def set_test_env(): os.environ.setdefault("RAG_MCP_URL", "http://mock-rag") os.environ.setdefault("WEB_MCP_URL", "http://mock-web") os.environ.setdefault("ADMIN_MCP_URL", "http://mock-admin") os.environ.setdefault("OLLAMA_URL", "http://localhost:11434") os.environ.setdefault("OLLAMA_MODEL", "llama3") os.environ.setdefault("LLM_BACKEND", "ollama") @pytest.fixture def mock_backend_dependencies(monkeypatch): print(">> applying backend dependency patches for tests") """Patch MCP client calls and red-flag detector for deterministic tests.""" from backend.api.models.redflag import RedFlagMatch from backend.api.services.tool_scoring import ToolScoringService import types async def fake_call_rag(self, tenant_id: str, query: str) -> Dict: return { "results": [ {"text": "HR policy includes onboarding, leave rules.", "relevance": 0.92}, {"text": "General company announcement", "relevance": 0.42} ], "metadata": {"total_retrieved": 2, "returned": 2, "threshold": 0.55} } async def fake_call_web(self, tenant_id: str, query: str) -> Dict: return { "results": [ {"title": "Latest inflation update", "snippet": "Inflation is 3.2%", "url": "https://example.com"}, {"title": "Global news", "snippet": "Market highlights", "url": "https://news.example.com"} ] } async def fake_call_admin(self, tenant_id: str, query: str) -> Dict: return {"status": "ok", "tenant_id": tenant_id, "query": query} monkeypatch.setattr("backend.api.mcp_clients.mcp_client.MCPClient.call_rag", fake_call_rag) monkeypatch.setattr("backend.api.mcp_clients.mcp_client.MCPClient.call_web", fake_call_web) monkeypatch.setattr("backend.api.mcp_clients.mcp_client.MCPClient.call_admin", fake_call_admin) async def fake_redflag_check(self, tenant_id: str, text: str) -> List[RedFlagMatch]: if "delete" in text.lower(): return [ RedFlagMatch( rule_id="1", pattern="delete", severity="high", description="Deletion request", matched_text="delete", confidence=0.9, explanation="Matched on keyword 'delete'" ) ] return [] async def fake_notify(self, tenant_id, violations, source_payload=None): return None monkeypatch.setattr("backend.api.services.redflag_detector.RedFlagDetector.check", fake_redflag_check) monkeypatch.setattr("backend.api.services.redflag_detector.RedFlagDetector.notify_admin", fake_notify) def fake_score(self, message: str, intent: str, rag_results: List[Dict]) -> Dict[str, float]: return {"rag_fitness": 0.82, "web_fitness": 0.78, "llm_only": 0.25} monkeypatch.setattr(ToolScoringService, "score", fake_score) # Ensure already-instantiated orchestrator uses the same patches from backend.api.routes import agent as agent_routes agent_routes.orchestrator.mcp.call_rag = types.MethodType(fake_call_rag, agent_routes.orchestrator.mcp) agent_routes.orchestrator.mcp.call_web = types.MethodType(fake_call_web, agent_routes.orchestrator.mcp) agent_routes.orchestrator.mcp.call_admin = types.MethodType(fake_call_admin, agent_routes.orchestrator.mcp) agent_routes.orchestrator.redflag.check = types.MethodType(fake_redflag_check, agent_routes.orchestrator.redflag) agent_routes.orchestrator.redflag.notify_admin = types.MethodType(fake_notify, agent_routes.orchestrator.redflag) @pytest.fixture def api_client(mock_backend_dependencies): from backend.api.main import app return TestClient(app) # --------------------------------------------------------------------------- # Unit tests # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_redflag_detector(): import time from backend.api.services.redflag_detector import RedFlagDetector from backend.api.models.redflag import RedFlagRule from backend.api.services.semantic_encoder import embed_text detector = RedFlagDetector(supabase_url="http://fake", supabase_key="fake") rule = RedFlagRule( id="rule-salary", pattern="salary", description="Salary access", severity="high", source="test", enabled=True, keywords=["salary"] ) detector._rules_cache["tenant-x"] = {"fetched_at": int(time.time()), "rules": [rule]} detector._rule_embeddings["tenant-x"] = {rule.id: embed_text("salary access")} matches = await detector.check("tenant-x", "Show me employee salary details") assert matches assert matches[0].matched_text.lower() == "salary" assert matches[0].confidence is not None def test_tool_scoring(): from backend.api.services.tool_scoring import ToolScoringService scorer = ToolScoringService() scores = scorer.score("What is inflation today?", intent="web", rag_results=[]) assert set(scores.keys()) == {"rag_fitness", "web_fitness", "llm_only"} assert scores["web_fitness"] > scores["rag_fitness"] @pytest.mark.asyncio async def test_tool_selector(): from backend.api.services.tool_selector import ToolSelector selector = ToolSelector() decision = await selector.select( intent="rag", text="Tell me HR policy and compare with external news", ctx={"rag_results": [{"text": "Policy"}], "tool_scores": {"rag_fitness": 0.9, "web_fitness": 0.8}} ) steps = decision.tool_input["steps"] assert steps[0]["tool"] == "rag" assert any(step["tool"] == "web" for step in steps) assert steps[-1]["tool"] == "llm" def test_reasoning_trace_via_response(api_client): payload = {"tenant_id": "tenant1", "message": "Summarize our HR policies"} res = api_client.post("/agent/message", json=payload) data = res.json() assert data["reasoning_trace"] step_names = [entry["step"] for entry in data["reasoning_trace"]] assert "intent_detection" in step_names # --------------------------------------------------------------------------- # Integration tests # --------------------------------------------------------------------------- def test_full_agent_pipeline(api_client): payload = {"tenant_id": "tenant123", "message": "What are our HR policies and latest updates?"} response = api_client.post("/agent/message", json=payload) data = response.json() assert data["text"] assert len(data["reasoning_trace"]) >= 3 rag_steps = [step for step in data["reasoning_trace"] if step.get("tool") == "rag"] assert rag_steps, "expected rag tool execution in reasoning trace" def test_parallel_execution_detected(api_client): payload = {"tenant_id": "t1", "message": "Summarize HR policies and latest news updates"} response = api_client.post("/agent/message", json=payload) data = response.json() tools_used = {trace.get("tool") for trace in data["tool_traces"] if trace.get("tool")} assert "rag" in tools_used and "web" in tools_used # --------------------------------------------------------------------------- # Simulation tests # --------------------------------------------------------------------------- SIM_QUERIES = [ "What is the inflation rate today?", "Summarize our HR policies", "Delete all records", "Explain our refund policy", "How many employees are in the company?" ] @pytest.mark.parametrize("message", SIM_QUERIES) def test_agent_simulation(api_client, message): res = api_client.post("/agent/message", json={"tenant_id": "demo", "message": message}) data = res.json() assert data["text"] assert data["reasoning_trace"] if "delete" in message.lower(): assert data["decision"]["action"] in {"block", "multi_step"} reason = (data["decision"]["reason"] or "").lower() assert "admin" in reason or "redflag" in reason