""" Comprehensive tests for: 1. Per-Tool Latency Prediction 2. Context-Aware MCP Routing 3. Tool Output Schemas Tests all three new features for intelligent tool selection and output validation. """ import pytest from unittest.mock import Mock, patch, AsyncMock from backend.api.services.tool_metadata import ( get_tool_latency_estimate, estimate_path_latency, get_fastest_path, validate_tool_output, get_tool_schema, TOOL_LATENCY_METADATA, TOOL_OUTPUT_SCHEMAS ) from backend.api.services.tool_selector import ToolSelector from backend.api.services.agent_orchestrator import AgentOrchestrator class TestLatencyPrediction: """Test per-tool latency prediction""" def test_get_tool_latency_estimate_basic(self): """Test basic latency estimation without context""" rag_latency = get_tool_latency_estimate("rag") web_latency = get_tool_latency_estimate("web") admin_latency = get_tool_latency_estimate("admin") llm_latency = get_tool_latency_estimate("llm") # Check that latencies are within expected ranges assert 60 <= rag_latency <= 120 assert 400 <= web_latency <= 1800 assert 5 <= admin_latency <= 20 assert 500 <= llm_latency <= 5000 def test_get_tool_latency_estimate_with_context(self): """Test latency estimation with context""" # RAG with long query rag_long = get_tool_latency_estimate("rag", {"query_length": 200}) rag_short = get_tool_latency_estimate("rag", {"query_length": 10}) assert rag_long >= rag_short # Longer queries should take more time # Web with complexity web_complex = get_tool_latency_estimate("web", {"query_complexity": "high"}) web_simple = get_tool_latency_estimate("web", {"query_complexity": "low"}) assert web_complex >= web_simple # Complex queries should take more time def test_estimate_path_latency(self): """Test total latency estimation for tool sequences""" # Single tool single = estimate_path_latency(["admin"]) assert single > 0 assert single <= 20 # Multiple tools multi = estimate_path_latency(["rag", "web", "llm"]) assert multi > 0 # Should be sum of individual latencies assert multi >= get_tool_latency_estimate("rag") assert multi >= get_tool_latency_estimate("web") assert multi >= get_tool_latency_estimate("llm") def test_get_fastest_path(self): """Test fastest path optimization""" tools = ["llm", "admin", "rag", "web"] fastest = get_fastest_path(tools) # Should be sorted by latency (fastest first) assert len(fastest) == len(tools) assert "admin" in fastest # Fastest tool assert fastest[0] == "admin" # Should be first # Verify order is optimized latencies = [get_tool_latency_estimate(t) for t in fastest] assert latencies == sorted(latencies) # Should be in ascending order def test_latency_metadata_structure(self): """Test that latency metadata has correct structure""" for tool_name, metadata in TOOL_LATENCY_METADATA.items(): assert metadata.tool_name == tool_name assert metadata.min_ms > 0 assert metadata.max_ms >= metadata.min_ms assert metadata.avg_ms >= metadata.min_ms assert metadata.avg_ms <= metadata.max_ms assert len(metadata.description) > 0 class TestToolOutputSchemas: """Test tool output schema validation""" def test_get_tool_schema(self): """Test schema retrieval""" rag_schema = get_tool_schema("rag") web_schema = get_tool_schema("web") admin_schema = get_tool_schema("admin") llm_schema = get_tool_schema("llm") assert rag_schema is not None assert web_schema is not None assert admin_schema is not None assert llm_schema is not None assert rag_schema.tool_name == "rag" assert web_schema.tool_name == "web" assert admin_schema.tool_name == "admin" assert 
llm_schema.tool_name == "llm" def test_validate_rag_output_valid(self): """Test validation of valid RAG output""" valid_rag = { "results": [ { "text": "Document chunk", "similarity": 0.85, "metadata": {"title": "Test"}, "doc_id": "doc123" } ], "query": "test query", "tenant_id": "tenant1", "hits_count": 1, "avg_score": 0.85, "top_score": 0.85, "latency_ms": 90 } is_valid, error = validate_tool_output("rag", valid_rag) assert is_valid is True assert error is None def test_validate_rag_output_missing_field(self): """Test validation catches missing required fields""" invalid_rag = { "results": [], # Missing "query" and "tenant_id" "hits_count": 0 } is_valid, error = validate_tool_output("rag", invalid_rag) assert is_valid is False assert "Missing required field" in error def test_validate_web_output_valid(self): """Test validation of valid Web output""" valid_web = { "results": [ { "title": "Result Title", "snippet": "Result snippet", "link": "https://example.com", "displayLink": "example.com" } ], "query": "search query", "total_results": 10, "latency_ms": 800 } is_valid, error = validate_tool_output("web", valid_web) assert is_valid is True assert error is None def test_validate_admin_output_valid(self): """Test validation of valid Admin output""" valid_admin = { "violations": [ { "rule_id": "rule1", "rule_pattern": ".*password.*", "severity": "high", "matched_text": "password", "confidence": 0.95, "message_preview": "User asked for password" } ], "checked": True, "rules_count": 5, "latency_ms": 10 } is_valid, error = validate_tool_output("admin", valid_admin) assert is_valid is True assert error is None def test_validate_llm_output_valid(self): """Test validation of valid LLM output""" valid_llm = { "text": "Generated response", "tokens_used": 150, "latency_ms": 2000, "model": "llama3.1:latest", "temperature": 0.0 } is_valid, error = validate_tool_output("llm", valid_llm) assert is_valid is True assert error is None def test_validate_type_mismatch(self): """Test validation catches type mismatches""" invalid_rag = { "results": "not an array", # Should be array "query": "test", "tenant_id": "tenant1" } is_valid, error = validate_tool_output("rag", invalid_rag) assert is_valid is False assert "must be array" in error def test_schema_examples(self): """Test that all schemas have examples""" for tool_name, schema in TOOL_OUTPUT_SCHEMAS.items(): assert schema.example is not None assert isinstance(schema.example, dict) # Example should be valid is_valid, error = validate_tool_output(tool_name, schema.example) assert is_valid is True, f"Schema example for {tool_name} is invalid: {error}" class TestContextAwareRouting: """Test context-aware MCP routing""" @pytest.fixture def tool_selector(self): """Create a ToolSelector instance""" return ToolSelector(llm_client=None) def test_analyze_context_rag_high_score(self, tool_selector): """Test context analysis when RAG returns high score""" rag_results = [ {"similarity": 0.85, "text": "High quality result"}, {"similarity": 0.90, "text": "Another high quality result"} ] memory = [] admin_violations = [] tool_scores = {"rag_fitness": 0.8, "web_fitness": 0.5} hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores) assert hints.get("skip_web_if_rag_high") is True assert hints.get("rag_high_confidence") is True def test_analyze_context_rag_low_score(self, tool_selector): """Test context analysis when RAG returns low score""" rag_results = [ {"similarity": 0.3, "text": "Low quality result"} ] memory = [] admin_violations = [] 
tool_scores = {"rag_fitness": 0.3, "web_fitness": 0.7} hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores) # Should not skip web if RAG score is low assert hints.get("skip_web_if_rag_high") is not True def test_analyze_context_memory_relevant(self, tool_selector): """Test context analysis when relevant memory exists""" rag_results = [] memory = [ { "tool": "rag", "result": { "results": [ {"similarity": 0.80, "text": "Recent RAG result"} ] } } ] admin_violations = [] tool_scores = {} hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores) assert hints.get("has_relevant_memory") is True # Should suggest skipping RAG if memory is recent and high quality if memory[0]["result"]["results"][0]["similarity"] >= 0.75: assert hints.get("skip_rag_if_memory") is True def test_analyze_context_admin_critical(self, tool_selector): """Test context analysis when admin violation is critical""" rag_results = [] memory = [] admin_violations = [ { "severity": "critical", "rule_id": "rule1", "matched_text": "sensitive data" } ] tool_scores = {} hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores) assert hints.get("skip_agent_reasoning") is True assert hints.get("critical_violation") is True def test_analyze_context_admin_low_severity(self, tool_selector): """Test context analysis when admin violation is low severity""" rag_results = [] memory = [] admin_violations = [ { "severity": "low", "rule_id": "rule1", "matched_text": "minor issue" } ] tool_scores = {} hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores) # Low severity should not skip reasoning assert hints.get("skip_agent_reasoning") is not True @pytest.mark.asyncio async def test_tool_selection_with_context_hints(self, tool_selector): """Test tool selection uses context hints""" # Mock LLM client tool_selector.llm_client = AsyncMock() # Context with high RAG score ctx = { "tenant_id": "test_tenant", "rag_results": [ {"similarity": 0.85, "text": "High quality result"} ], "tool_scores": { "rag_fitness": 0.8, "web_fitness": 0.6, "llm_only": 0.3 }, "memory": [], "admin_violations": [] } decision = await tool_selector.select("general", "What is our company policy?", ctx) # Should include latency estimates in reason assert "latency" in decision.reason.lower() or "est." 

        # Check that steps have latency estimates (for non-LLM tools)
        if decision.tool_input and "steps" in decision.tool_input:
            steps = decision.tool_input["steps"]
            for step in steps:
                if isinstance(step, dict) and "input" in step and step.get("tool") != "llm":
                    # Non-LLM tools should have estimated latency (or be parallel)
                    assert "_estimated_latency_ms" in step["input"] or "parallel" in step or step.get("tool") == "llm"

    @pytest.mark.asyncio
    async def test_tool_selection_skips_web_on_high_rag(self, tool_selector):
        """Test that tool selection skips web when RAG has high score"""
        tool_selector.llm_client = AsyncMock()

        ctx = {
            "tenant_id": "test_tenant",
            "rag_results": [
                {"similarity": 0.90, "text": "Very high quality result"}
            ],
            "tool_scores": {
                "rag_fitness": 0.9,
                "web_fitness": 0.7,
                "llm_only": 0.2
            },
            "memory": [],
            "admin_violations": []
        }

        decision = await tool_selector.select("general", "What is our internal policy?", ctx)

        # Check reason includes context hint
        assert "skip web" in decision.reason.lower() or "rag high" in decision.reason.lower() or "context" in decision.reason.lower()

    @pytest.mark.asyncio
    async def test_tool_selection_admin_critical_skip_reasoning(self, tool_selector):
        """Test that tool selection skips reasoning for critical admin violations"""
        tool_selector.llm_client = None  # No LLM needed for admin-only path

        ctx = {
            "tenant_id": "test_tenant",
            "rag_results": [],
            "tool_scores": {},
            "memory": [],
            "admin_violations": [
                {
                    "severity": "critical",
                    "rule_id": "rule1",
                    "matched_text": "critical violation"
                }
            ]
        }

        decision = await tool_selector.select("admin", "User trying to access sensitive data", ctx)

        # Should skip LLM reasoning for critical violations
        if decision.tool_input and "steps" in decision.tool_input:
            steps = decision.tool_input["steps"]
            # Should have admin step but may skip LLM
            has_admin = any(s.get("tool") == "admin" for s in steps if isinstance(s, dict))
            assert has_admin


class TestOrchestratorIntegration:
    """Test orchestrator integration with new features"""

    @pytest.fixture
    def orchestrator(self):
        """Create an AgentOrchestrator instance"""
        return AgentOrchestrator(
            rag_mcp_url="http://localhost:8900/rag",
            web_mcp_url="http://localhost:8900/web",
            admin_mcp_url="http://localhost:8900/admin",
            llm_backend="ollama"
        )

    def test_format_rag_output(self, orchestrator):
        """Test RAG output formatting"""
        raw_output = {
            "results": [
                {"text": "Chunk 1", "similarity": 0.85},
                {"text": "Chunk 2", "similarity": 0.75}
            ],
            "query": "test query"
        }

        formatted = orchestrator._format_tool_output("rag", raw_output, 90)

        # Check schema compliance
        assert "results" in formatted
        assert "query" in formatted
        assert "tenant_id" in formatted
        assert "hits_count" in formatted
        assert "avg_score" in formatted
        assert "top_score" in formatted
        assert "latency_ms" in formatted

        # Validate against schema
        is_valid, error = validate_tool_output("rag", formatted)
        assert is_valid is True, f"Formatted RAG output invalid: {error}"

    def test_format_web_output(self, orchestrator):
        """Test Web output formatting"""
        raw_output = {
            "items": [
                {
                    "title": "Result Title",
                    "snippet": "Result snippet",
                    "link": "https://example.com"
                }
            ]
        }

        formatted = orchestrator._format_tool_output("web", raw_output, 800)

        # Check schema compliance
        assert "results" in formatted
        assert "query" in formatted
        assert "total_results" in formatted
        assert "latency_ms" in formatted

        # Validate against schema
        is_valid, error = validate_tool_output("web", formatted)
        assert is_valid is True, f"Formatted Web output invalid: {error}"
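
    # Assumed contract exercised by the formatting tests in this class: raw MCP
    # payloads arrive with provider-specific keys ("items" for web, "matches"
    # for admin, a bare string for llm), and _format_tool_output(tool_name,
    # raw_output, latency_ms) is expected to normalize them to the schema field
    # names and attach the measured latency so validate_tool_output passes.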

    def test_format_admin_output(self, orchestrator):
        """Test Admin output formatting"""
        raw_output = {
            "matches": [
                {
                    "rule_id": "rule1",
                    "pattern": ".*password.*",
                    "severity": "high",
                    "text": "password",
                    "confidence": 0.95
                }
            ]
        }

        formatted = orchestrator._format_tool_output("admin", raw_output, 10)

        # Check schema compliance
        assert "violations" in formatted
        assert "checked" in formatted
        assert "rules_count" in formatted
        assert "latency_ms" in formatted

        # Validate against schema
        is_valid, error = validate_tool_output("admin", formatted)
        assert is_valid is True, f"Formatted Admin output invalid: {error}"

    def test_format_llm_output(self, orchestrator):
        """Test LLM output formatting"""
        raw_output = "This is a generated response from the LLM."

        formatted = orchestrator._format_tool_output("llm", raw_output, 2000)

        # Check schema compliance
        assert "text" in formatted
        assert "tokens_used" in formatted
        assert "latency_ms" in formatted
        assert "model" in formatted
        assert "temperature" in formatted

        # Validate against schema
        is_valid, error = validate_tool_output("llm", formatted)
        assert is_valid is True, f"Formatted LLM output invalid: {error}"

    def test_format_output_handles_missing_fields(self, orchestrator):
        """Test output formatting handles missing fields gracefully"""
        # Minimal RAG output
        minimal = {"results": []}

        formatted = orchestrator._format_tool_output("rag", minimal, 90)

        # Should have all required fields with defaults
        assert "query" in formatted
        assert "tenant_id" in formatted
        assert "hits_count" in formatted
        assert formatted["hits_count"] == 0


class TestEndToEndRouting:
    """End-to-end tests for context-aware routing"""

    @pytest.mark.asyncio
    async def test_routing_with_high_rag_score(self):
        """Test that high RAG score prevents web search"""
        selector = ToolSelector(llm_client=None)

        ctx = {
            "tenant_id": "test",
            "rag_results": [{"similarity": 0.92, "text": "Perfect match"}],
            "tool_scores": {"rag_fitness": 0.9, "web_fitness": 0.7},
            "memory": [],
            "admin_violations": []
        }

        decision = await selector.select("general", "What is our policy?", ctx)

        # Check that context hints are applied
        if decision.tool_input and "steps" in decision.tool_input:
            steps = decision.tool_input["steps"]
            tool_names = [s.get("tool") for s in steps if isinstance(s, dict) and "tool" in s]
            # Should have RAG but may skip web due to high score
            assert "rag" in tool_names or "llm" in tool_names

    @pytest.mark.asyncio
    async def test_routing_with_memory(self):
        """Test that relevant memory prevents redundant RAG call"""
        selector = ToolSelector(llm_client=None)

        ctx = {
            "tenant_id": "test",
            "rag_results": [],
            "tool_scores": {"rag_fitness": 0.6},
            "memory": [
                {
                    "tool": "rag",
                    "result": {
                        "results": [{"similarity": 0.85, "text": "Recent result"}]
                    }
                }
            ],
            "admin_violations": []
        }

        decision = await selector.select("general", "Tell me about our policy", ctx)

        # Context should be analyzed
        # (Actual behavior depends on implementation, but should use memory)
        assert decision is not None


if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])
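
# Usage note (assumes a standard pytest install plus the pytest-asyncio plugin,
# which the @pytest.mark.asyncio tests above rely on; the file path below is
# illustrative). A single feature area can be run in isolation with -k, e.g.:
#   pytest tests/test_tool_features.py -k "TestLatencyPrediction" -v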