Spaces:
Sleeping
Sleeping
| """ | |
| Single-file test suite for IntegraChat backend (unit + integration + simulation). | |
| This version aligns with the current backend API surface. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import sys | |
| from pathlib import Path | |
| from typing import List, Dict | |
| import pytest | |
| from fastapi.testclient import TestClient | |
# ---------------------------------------------------------------------------
# Make the project root and its ``backend`` directory importable
# ---------------------------------------------------------------------------
PROJECT_ROOT = Path(__file__).resolve().parent
backend_path = PROJECT_ROOT / "backend"

# Each directory is pushed to the front of sys.path only when it is not
# already present. They are handled in this order, so when both are added the
# ``backend`` directory ends up ahead of the project root.
for _import_dir in (PROJECT_ROOT, backend_path):
    _entry = str(_import_dir)
    if _entry not in sys.path:
        sys.path.insert(0, _entry)
del _import_dir, _entry
| # --------------------------------------------------------------------------- | |
| # Shared fixtures | |
| # --------------------------------------------------------------------------- | |
@pytest.fixture(autouse=True)
def set_test_env():
    """Apply default MCP / LLM endpoint settings before every test in this module.

    The ``autouse=True`` decorator makes pytest run this automatically.
    Without it, this helper was never called and the defaults were silently
    missing.

    ``os.environ.setdefault`` keeps any value already exported in the real
    environment, so the suite can still be pointed at live services. The
    values are not removed after the test; they persist for the rest of the
    process.
    """
    defaults = {
        "RAG_MCP_URL": "http://mock-rag",
        "WEB_MCP_URL": "http://mock-web",
        "ADMIN_MCP_URL": "http://mock-admin",
        "OLLAMA_URL": "http://localhost:11434",
        "OLLAMA_MODEL": "llama3",
        "LLM_BACKEND": "ollama",
    }
    for name, value in defaults.items():
        os.environ.setdefault(name, value)
@pytest.fixture
def mock_backend_dependencies(monkeypatch):
    """Patch MCP client calls, red-flag detection and tool scoring for deterministic tests.

    Every patch goes through ``monkeypatch`` so that pytest reverts it when the
    requesting test finishes. Besides the class-level patches, the existing
    ``orchestrator`` object exposed by ``backend.api.routes.agent`` has its
    ``mcp`` and ``redflag`` helpers patched directly on the instances. This
    ensures the running app uses the fakes too.
    """
    print(">> applying backend dependency patches for tests")
    import types

    from backend.api.models.redflag import RedFlagMatch
    from backend.api.services.tool_scoring import ToolScoringService

    # --- MCP client fakes --------------------------------------------------
    async def fake_call_rag(self, tenant_id: str, query: str) -> Dict:
        # One high-relevance and one low-relevance hit, so threshold logic
        # (0.55) has something to filter.
        return {
            "results": [
                {"text": "HR policy includes onboarding, leave rules.", "relevance": 0.92},
                {"text": "General company announcement", "relevance": 0.42},
            ],
            "metadata": {"total_retrieved": 2, "returned": 2, "threshold": 0.55},
        }

    async def fake_call_web(self, tenant_id: str, query: str) -> Dict:
        return {
            "results": [
                {"title": "Latest inflation update", "snippet": "Inflation is 3.2%", "url": "https://example.com"},
                {"title": "Global news", "snippet": "Market highlights", "url": "https://news.example.com"},
            ]
        }

    async def fake_call_admin(self, tenant_id: str, query: str) -> Dict:
        return {"status": "ok", "tenant_id": tenant_id, "query": query}

    monkeypatch.setattr("backend.api.mcp_clients.mcp_client.MCPClient.call_rag", fake_call_rag)
    monkeypatch.setattr("backend.api.mcp_clients.mcp_client.MCPClient.call_web", fake_call_web)
    monkeypatch.setattr("backend.api.mcp_clients.mcp_client.MCPClient.call_admin", fake_call_admin)

    # --- Red-flag detector fakes -------------------------------------------
    async def fake_redflag_check(self, tenant_id: str, text: str) -> List[RedFlagMatch]:
        # Only messages mentioning "delete" (case-insensitive) are flagged.
        if "delete" in text.lower():
            return [
                RedFlagMatch(
                    rule_id="1",
                    pattern="delete",
                    severity="high",
                    description="Deletion request",
                    matched_text="delete",
                    confidence=0.9,
                    explanation="Matched on keyword 'delete'",
                )
            ]
        return []

    async def fake_notify(self, tenant_id, violations, source_payload=None):
        return None

    monkeypatch.setattr("backend.api.services.redflag_detector.RedFlagDetector.check", fake_redflag_check)
    monkeypatch.setattr("backend.api.services.redflag_detector.RedFlagDetector.notify_admin", fake_notify)

    # --- Tool scoring fake -------------------------------------------------
    def fake_score(self, message: str, intent: str, rag_results: List[Dict]) -> Dict[str, float]:
        return {"rag_fitness": 0.82, "web_fitness": 0.78, "llm_only": 0.25}

    monkeypatch.setattr(ToolScoringService, "score", fake_score)

    # --- Already-instantiated orchestrator ---------------------------------
    # The orchestrator holds its own helper instances, so also bind the fakes
    # directly on those instances. The previous version used plain attribute
    # assignment here, which was never undone and leaked into later tests.
    # ``setitem`` on the instance ``__dict__`` records that the key was absent,
    # so teardown deletes the override and the class method becomes visible
    # again. (``monkeypatch.setattr`` would instead restore a stale bound
    # method as an instance attribute.)
    from backend.api.routes import agent as agent_routes

    orchestrator = agent_routes.orchestrator
    instance_patches = (
        (orchestrator.mcp, "call_rag", fake_call_rag),
        (orchestrator.mcp, "call_web", fake_call_web),
        (orchestrator.mcp, "call_admin", fake_call_admin),
        (orchestrator.redflag, "check", fake_redflag_check),
        (orchestrator.redflag, "notify_admin", fake_notify),
    )
    for target, attr, fake in instance_patches:
        monkeypatch.setitem(vars(target), attr, types.MethodType(fake, target))
@pytest.fixture
def api_client(mock_backend_dependencies):
    """Return a FastAPI ``TestClient`` for the backend app with external calls faked.

    ``mock_backend_dependencies`` is requested only for its side effect: it
    guarantees the patches are installed before the app is imported and used.
    The app is imported lazily so module import does not pull in the backend.
    """
    from backend.api.main import app

    return TestClient(app)
| # --------------------------------------------------------------------------- | |
| # Unit tests | |
| # --------------------------------------------------------------------------- | |
@pytest.mark.asyncio
async def test_redflag_detector():
    """A message mentioning "salary" should match a pre-seeded "salary" rule.

    The detector's private rule and embedding caches are primed directly. The
    intent is to avoid a real Supabase lookup (the URL and key are fakes); this
    assumes the detector reads from these caches before fetching — TODO
    confirm.
    """
    import time

    from backend.api.models.redflag import RedFlagRule
    from backend.api.services.redflag_detector import RedFlagDetector
    from backend.api.services.semantic_encoder import embed_text

    detector = RedFlagDetector(supabase_url="http://fake", supabase_key="fake")
    rule = RedFlagRule(
        id="rule-salary",
        pattern="salary",
        description="Salary access",
        severity="high",
        source="test",
        enabled=True,
        keywords=["salary"],
    )
    # Mark the cached rules as freshly fetched so they are used as-is.
    detector._rules_cache["tenant-x"] = {"fetched_at": int(time.time()), "rules": [rule]}
    detector._rule_embeddings["tenant-x"] = {rule.id: embed_text("salary access")}

    matches = await detector.check("tenant-x", "Show me employee salary details")

    assert matches
    assert matches[0].matched_text.lower() == "salary"
    assert matches[0].confidence is not None
def test_tool_scoring():
    """Scoring a current-events question with intent="web" should rank web above RAG."""
    from backend.api.services.tool_scoring import ToolScoringService

    result = ToolScoringService().score(
        "What is inflation today?",
        intent="web",
        rag_results=[],
    )

    expected_keys = {"rag_fitness", "web_fitness", "llm_only"}
    assert set(result.keys()) == expected_keys
    assert result["rag_fitness"] < result["web_fitness"]
@pytest.mark.asyncio
async def test_tool_selector():
    """A mixed internal/external question should yield a multi-step rag -> web -> llm plan.

    The context rates RAG as the best fit (0.9) with web close behind (0.8).
    The assertions require the plan to start with ``rag``, include at least
    one ``web`` step, and finish with ``llm``.
    """
    from backend.api.services.tool_selector import ToolSelector

    selector = ToolSelector()
    decision = await selector.select(
        intent="rag",
        text="Tell me HR policy and compare with external news",
        ctx={"rag_results": [{"text": "Policy"}], "tool_scores": {"rag_fitness": 0.9, "web_fitness": 0.8}},
    )

    steps = decision.tool_input["steps"]
    # Check non-emptiness first so an empty plan reports clearly instead of
    # raising IndexError on steps[0].
    assert steps, "expected a non-empty multi-step plan"
    assert steps[0]["tool"] == "rag"
    assert any(step["tool"] == "web" for step in steps)
    assert steps[-1]["tool"] == "llm"
def test_reasoning_trace_via_response(api_client):
    """The /agent/message response should carry a reasoning trace that includes intent detection."""
    payload = {"tenant_id": "tenant1", "message": "Summarize our HR policies"}
    res = api_client.post("/agent/message", json=payload)
    # Surface the server's error body instead of failing later with an
    # opaque KeyError on an error payload.
    assert 200 <= res.status_code < 300, res.text

    data = res.json()
    assert data["reasoning_trace"]
    step_names = [entry["step"] for entry in data["reasoning_trace"]]
    assert "intent_detection" in step_names
| # --------------------------------------------------------------------------- | |
| # Integration tests | |
| # --------------------------------------------------------------------------- | |
def test_full_agent_pipeline(api_client):
    """End-to-end agent call: expect answer text, at least three trace steps, and a RAG step."""
    payload = {"tenant_id": "tenant123", "message": "What are our HR policies and latest updates?"}
    response = api_client.post("/agent/message", json=payload)
    # Fail with the response body if the endpoint errored.
    assert 200 <= response.status_code < 300, response.text

    data = response.json()
    assert data["text"]
    assert len(data["reasoning_trace"]) >= 3
    rag_steps = [step for step in data["reasoning_trace"] if step.get("tool") == "rag"]
    assert rag_steps, "expected rag tool execution in reasoning trace"
def test_parallel_execution_detected(api_client):
    """A question spanning internal policy and external news should invoke both rag and web tools.

    Only tool usage is checked here (via ``tool_traces``). The test does not
    verify that the calls actually ran concurrently.
    """
    payload = {"tenant_id": "t1", "message": "Summarize HR policies and latest news updates"}
    response = api_client.post("/agent/message", json=payload)
    assert 200 <= response.status_code < 300, response.text

    data = response.json()
    tools_used = {trace.get("tool") for trace in data["tool_traces"] if trace.get("tool")}
    # Separate assertions so a failure names the missing tool.
    assert "rag" in tools_used, f"rag not used; tools seen: {sorted(tools_used)}"
    assert "web" in tools_used, f"web not used; tools seen: {sorted(tools_used)}"
| # --------------------------------------------------------------------------- | |
| # Simulation tests | |
| # --------------------------------------------------------------------------- | |
# Representative user messages replayed through the full agent endpoint.
SIM_QUERIES = [
    "What is the inflation rate today?",
    "Summarize our HR policies",
    "Delete all records",
    "Explain our refund policy",
    "How many employees are in the company?",
]


@pytest.mark.parametrize("message", SIM_QUERIES)
def test_agent_simulation(api_client, message):
    """Every simulated query gets an answer and a trace; deletion requests must be escalated.

    Without the ``parametrize`` marker, ``message`` had no fixture to resolve
    to, so this test could not run.
    """
    res = api_client.post("/agent/message", json={"tenant_id": "demo", "message": message})
    assert 200 <= res.status_code < 300, res.text

    data = res.json()
    assert data["text"]
    assert data["reasoning_trace"]

    if "delete" in message.lower():
        decision = data["decision"]
        assert decision["action"] in {"block", "multi_step"}
        # ``reason`` may be missing or null; treat both as an empty string.
        reason = (decision.get("reason") or "").lower()
        assert "admin" in reason or "redflag" in reason