# Source extracted from a GitHub page view; original commit:
# feat: Add AI metadata extraction, latency prediction, context-aware routing,
# and tool output schemas (d1e5882)
"""
Comprehensive tests for:
1. Per-Tool Latency Prediction
2. Context-Aware MCP Routing
3. Tool Output Schemas

Tests all three new features for intelligent tool selection and output validation.
"""
import pytest
from unittest.mock import Mock, patch, AsyncMock

from backend.api.services.tool_metadata import (
    get_tool_latency_estimate,
    estimate_path_latency,
    get_fastest_path,
    validate_tool_output,
    get_tool_schema,
    TOOL_LATENCY_METADATA,
    TOOL_OUTPUT_SCHEMAS,
)
from backend.api.services.tool_selector import ToolSelector
from backend.api.services.agent_orchestrator import AgentOrchestrator
class TestLatencyPrediction:
    """Test per-tool latency prediction."""

    def test_get_tool_latency_estimate_basic(self):
        """Test basic latency estimation without context."""
        rag_latency = get_tool_latency_estimate("rag")
        web_latency = get_tool_latency_estimate("web")
        admin_latency = get_tool_latency_estimate("admin")
        llm_latency = get_tool_latency_estimate("llm")
        # Check that latencies are within expected ranges (ms)
        assert 60 <= rag_latency <= 120
        assert 400 <= web_latency <= 1800
        assert 5 <= admin_latency <= 20
        assert 500 <= llm_latency <= 5000

    def test_get_tool_latency_estimate_with_context(self):
        """Test latency estimation with context."""
        # RAG with long vs. short query
        rag_long = get_tool_latency_estimate("rag", {"query_length": 200})
        rag_short = get_tool_latency_estimate("rag", {"query_length": 10})
        assert rag_long >= rag_short  # Longer queries should take more time
        # Web with high vs. low complexity
        web_complex = get_tool_latency_estimate("web", {"query_complexity": "high"})
        web_simple = get_tool_latency_estimate("web", {"query_complexity": "low"})
        assert web_complex >= web_simple  # Complex queries should take more time

    def test_estimate_path_latency(self):
        """Test total latency estimation for tool sequences."""
        # Single tool
        single = estimate_path_latency(["admin"])
        assert single > 0
        assert single <= 20
        # Multiple tools
        multi = estimate_path_latency(["rag", "web", "llm"])
        assert multi > 0
        # Should be at least each individual latency (i.e. behaves like a sum)
        assert multi >= get_tool_latency_estimate("rag")
        assert multi >= get_tool_latency_estimate("web")
        assert multi >= get_tool_latency_estimate("llm")

    def test_get_fastest_path(self):
        """Test fastest path optimization."""
        tools = ["llm", "admin", "rag", "web"]
        fastest = get_fastest_path(tools)
        # Should preserve all tools, sorted by latency (fastest first)
        assert len(fastest) == len(tools)
        assert "admin" in fastest  # Fastest tool
        assert fastest[0] == "admin"  # Should be first
        # Verify order is optimized
        latencies = [get_tool_latency_estimate(t) for t in fastest]
        assert latencies == sorted(latencies)  # Ascending order

    def test_latency_metadata_structure(self):
        """Test that latency metadata has correct structure."""
        for tool_name, metadata in TOOL_LATENCY_METADATA.items():
            assert metadata.tool_name == tool_name
            assert metadata.min_ms > 0
            assert metadata.max_ms >= metadata.min_ms
            # avg must fall within [min, max]
            assert metadata.avg_ms >= metadata.min_ms
            assert metadata.avg_ms <= metadata.max_ms
            assert len(metadata.description) > 0
class TestToolOutputSchemas:
    """Test tool output schema validation."""

    def test_get_tool_schema(self):
        """Test schema retrieval for all four tools."""
        rag_schema = get_tool_schema("rag")
        web_schema = get_tool_schema("web")
        admin_schema = get_tool_schema("admin")
        llm_schema = get_tool_schema("llm")
        assert rag_schema is not None
        assert web_schema is not None
        assert admin_schema is not None
        assert llm_schema is not None
        assert rag_schema.tool_name == "rag"
        assert web_schema.tool_name == "web"
        assert admin_schema.tool_name == "admin"
        assert llm_schema.tool_name == "llm"

    def test_validate_rag_output_valid(self):
        """Test validation of valid RAG output."""
        valid_rag = {
            "results": [
                {
                    "text": "Document chunk",
                    "similarity": 0.85,
                    "metadata": {"title": "Test"},
                    "doc_id": "doc123"
                }
            ],
            "query": "test query",
            "tenant_id": "tenant1",
            "hits_count": 1,
            "avg_score": 0.85,
            "top_score": 0.85,
            "latency_ms": 90
        }
        is_valid, error = validate_tool_output("rag", valid_rag)
        assert is_valid is True
        assert error is None

    def test_validate_rag_output_missing_field(self):
        """Test validation catches missing required fields."""
        invalid_rag = {
            "results": [],
            # Missing "query" and "tenant_id"
            "hits_count": 0
        }
        is_valid, error = validate_tool_output("rag", invalid_rag)
        assert is_valid is False
        assert "Missing required field" in error

    def test_validate_web_output_valid(self):
        """Test validation of valid Web output."""
        valid_web = {
            "results": [
                {
                    "title": "Result Title",
                    "snippet": "Result snippet",
                    "link": "https://example.com",
                    "displayLink": "example.com"
                }
            ],
            "query": "search query",
            "total_results": 10,
            "latency_ms": 800
        }
        is_valid, error = validate_tool_output("web", valid_web)
        assert is_valid is True
        assert error is None

    def test_validate_admin_output_valid(self):
        """Test validation of valid Admin output."""
        valid_admin = {
            "violations": [
                {
                    "rule_id": "rule1",
                    "rule_pattern": ".*password.*",
                    "severity": "high",
                    "matched_text": "password",
                    "confidence": 0.95,
                    "message_preview": "User asked for password"
                }
            ],
            "checked": True,
            "rules_count": 5,
            "latency_ms": 10
        }
        is_valid, error = validate_tool_output("admin", valid_admin)
        assert is_valid is True
        assert error is None

    def test_validate_llm_output_valid(self):
        """Test validation of valid LLM output."""
        valid_llm = {
            "text": "Generated response",
            "tokens_used": 150,
            "latency_ms": 2000,
            "model": "llama3.1:latest",
            "temperature": 0.0
        }
        is_valid, error = validate_tool_output("llm", valid_llm)
        assert is_valid is True
        assert error is None

    def test_validate_type_mismatch(self):
        """Test validation catches type mismatches."""
        invalid_rag = {
            "results": "not an array",  # Should be array
            "query": "test",
            "tenant_id": "tenant1"
        }
        is_valid, error = validate_tool_output("rag", invalid_rag)
        assert is_valid is False
        assert "must be array" in error

    def test_schema_examples(self):
        """Test that all schemas have examples and that each example validates."""
        for tool_name, schema in TOOL_OUTPUT_SCHEMAS.items():
            assert schema.example is not None
            assert isinstance(schema.example, dict)
            # The schema's own example must pass its own validation
            is_valid, error = validate_tool_output(tool_name, schema.example)
            assert is_valid is True, f"Schema example for {tool_name} is invalid: {error}"
class TestContextAwareRouting:
    """Test context-aware MCP routing."""

    # Without this decorator pytest reports "fixture 'tool_selector' not found"
    # for every test below that requests it.
    @pytest.fixture
    def tool_selector(self):
        """Create a ToolSelector instance with no LLM client."""
        return ToolSelector(llm_client=None)

    def test_analyze_context_rag_high_score(self, tool_selector):
        """Test context analysis when RAG returns high score."""
        rag_results = [
            {"similarity": 0.85, "text": "High quality result"},
            {"similarity": 0.90, "text": "Another high quality result"}
        ]
        memory = []
        admin_violations = []
        tool_scores = {"rag_fitness": 0.8, "web_fitness": 0.5}
        hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores)
        assert hints.get("skip_web_if_rag_high") is True
        assert hints.get("rag_high_confidence") is True

    def test_analyze_context_rag_low_score(self, tool_selector):
        """Test context analysis when RAG returns low score."""
        rag_results = [
            {"similarity": 0.3, "text": "Low quality result"}
        ]
        memory = []
        admin_violations = []
        tool_scores = {"rag_fitness": 0.3, "web_fitness": 0.7}
        hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores)
        # Should not skip web if RAG score is low
        assert hints.get("skip_web_if_rag_high") is not True

    def test_analyze_context_memory_relevant(self, tool_selector):
        """Test context analysis when relevant memory exists."""
        rag_results = []
        memory = [
            {
                "tool": "rag",
                "result": {
                    "results": [
                        {"similarity": 0.80, "text": "Recent RAG result"}
                    ]
                }
            }
        ]
        admin_violations = []
        tool_scores = {}
        hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores)
        assert hints.get("has_relevant_memory") is True
        # Should suggest skipping RAG if memory is recent and high quality
        if memory[0]["result"]["results"][0]["similarity"] >= 0.75:
            assert hints.get("skip_rag_if_memory") is True

    def test_analyze_context_admin_critical(self, tool_selector):
        """Test context analysis when admin violation is critical."""
        rag_results = []
        memory = []
        admin_violations = [
            {
                "severity": "critical",
                "rule_id": "rule1",
                "matched_text": "sensitive data"
            }
        ]
        tool_scores = {}
        hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores)
        assert hints.get("skip_agent_reasoning") is True
        assert hints.get("critical_violation") is True

    def test_analyze_context_admin_low_severity(self, tool_selector):
        """Test context analysis when admin violation is low severity."""
        rag_results = []
        memory = []
        admin_violations = [
            {
                "severity": "low",
                "rule_id": "rule1",
                "matched_text": "minor issue"
            }
        ]
        tool_scores = {}
        hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores)
        # Low severity should not skip reasoning
        assert hints.get("skip_agent_reasoning") is not True

    # NOTE(review): async tests require pytest-asyncio; mark explicitly so they
    # run even without asyncio_mode=auto in the project config — confirm setup.
    @pytest.mark.asyncio
    async def test_tool_selection_with_context_hints(self, tool_selector):
        """Test tool selection uses context hints."""
        # Mock LLM client
        tool_selector.llm_client = AsyncMock()
        # Context with high RAG score
        ctx = {
            "tenant_id": "test_tenant",
            "rag_results": [
                {"similarity": 0.85, "text": "High quality result"}
            ],
            "tool_scores": {
                "rag_fitness": 0.8,
                "web_fitness": 0.6,
                "llm_only": 0.3
            },
            "memory": [],
            "admin_violations": []
        }
        decision = await tool_selector.select("general", "What is our company policy?", ctx)
        # Should include latency estimates in reason
        assert "latency" in decision.reason.lower() or "est." in decision.reason.lower()
        # Check that steps have latency estimates (for non-LLM tools)
        if decision.tool_input and "steps" in decision.tool_input:
            steps = decision.tool_input["steps"]
            for step in steps:
                if isinstance(step, dict) and "input" in step and step.get("tool") != "llm":
                    # Non-LLM tools should have estimated latency (or be parallel)
                    assert "_estimated_latency_ms" in step["input"] or "parallel" in step or step.get("tool") == "llm"

    @pytest.mark.asyncio
    async def test_tool_selection_skips_web_on_high_rag(self, tool_selector):
        """Test that tool selection skips web when RAG has high score."""
        tool_selector.llm_client = AsyncMock()
        ctx = {
            "tenant_id": "test_tenant",
            "rag_results": [
                {"similarity": 0.90, "text": "Very high quality result"}
            ],
            "tool_scores": {
                "rag_fitness": 0.9,
                "web_fitness": 0.7,
                "llm_only": 0.2
            },
            "memory": [],
            "admin_violations": []
        }
        decision = await tool_selector.select("general", "What is our internal policy?", ctx)
        # Check reason includes context hint
        assert "skip web" in decision.reason.lower() or "rag high" in decision.reason.lower() or "context" in decision.reason.lower()

    @pytest.mark.asyncio
    async def test_tool_selection_admin_critical_skip_reasoning(self, tool_selector):
        """Test that tool selection skips reasoning for critical admin violations."""
        tool_selector.llm_client = None  # No LLM needed for admin-only path
        ctx = {
            "tenant_id": "test_tenant",
            "rag_results": [],
            "tool_scores": {},
            "memory": [],
            "admin_violations": [
                {
                    "severity": "critical",
                    "rule_id": "rule1",
                    "matched_text": "critical violation"
                }
            ]
        }
        decision = await tool_selector.select("admin", "User trying to access sensitive data", ctx)
        # Should skip LLM reasoning for critical violations
        if decision.tool_input and "steps" in decision.tool_input:
            steps = decision.tool_input["steps"]
            # Should have admin step but may skip LLM
            has_admin = any(s.get("tool") == "admin" for s in steps if isinstance(s, dict))
            assert has_admin
class TestOrchestratorIntegration:
    """Test orchestrator integration with new features."""

    # Without this decorator pytest reports "fixture 'orchestrator' not found"
    # for every test below that requests it.
    @pytest.fixture
    def orchestrator(self):
        """Create an AgentOrchestrator instance pointed at local MCP endpoints."""
        return AgentOrchestrator(
            rag_mcp_url="http://localhost:8900/rag",
            web_mcp_url="http://localhost:8900/web",
            admin_mcp_url="http://localhost:8900/admin",
            llm_backend="ollama"
        )

    def test_format_rag_output(self, orchestrator):
        """Test RAG output formatting."""
        raw_output = {
            "results": [
                {"text": "Chunk 1", "similarity": 0.85},
                {"text": "Chunk 2", "similarity": 0.75}
            ],
            "query": "test query"
        }
        formatted = orchestrator._format_tool_output("rag", raw_output, 90)
        # Check schema compliance
        assert "results" in formatted
        assert "query" in formatted
        assert "tenant_id" in formatted
        assert "hits_count" in formatted
        assert "avg_score" in formatted
        assert "top_score" in formatted
        assert "latency_ms" in formatted
        # Validate against schema
        is_valid, error = validate_tool_output("rag", formatted)
        assert is_valid is True, f"Formatted RAG output invalid: {error}"

    def test_format_web_output(self, orchestrator):
        """Test Web output formatting."""
        raw_output = {
            "items": [
                {
                    "title": "Result Title",
                    "snippet": "Result snippet",
                    "link": "https://example.com"
                }
            ]
        }
        formatted = orchestrator._format_tool_output("web", raw_output, 800)
        # Check schema compliance
        assert "results" in formatted
        assert "query" in formatted
        assert "total_results" in formatted
        assert "latency_ms" in formatted
        # Validate against schema
        is_valid, error = validate_tool_output("web", formatted)
        assert is_valid is True, f"Formatted Web output invalid: {error}"

    def test_format_admin_output(self, orchestrator):
        """Test Admin output formatting."""
        raw_output = {
            "matches": [
                {
                    "rule_id": "rule1",
                    "pattern": ".*password.*",
                    "severity": "high",
                    "text": "password",
                    "confidence": 0.95
                }
            ]
        }
        formatted = orchestrator._format_tool_output("admin", raw_output, 10)
        # Check schema compliance
        assert "violations" in formatted
        assert "checked" in formatted
        assert "rules_count" in formatted
        assert "latency_ms" in formatted
        # Validate against schema
        is_valid, error = validate_tool_output("admin", formatted)
        assert is_valid is True, f"Formatted Admin output invalid: {error}"

    def test_format_llm_output(self, orchestrator):
        """Test LLM output formatting."""
        raw_output = "This is a generated response from the LLM."
        formatted = orchestrator._format_tool_output("llm", raw_output, 2000)
        # Check schema compliance
        assert "text" in formatted
        assert "tokens_used" in formatted
        assert "latency_ms" in formatted
        assert "model" in formatted
        assert "temperature" in formatted
        # Validate against schema
        is_valid, error = validate_tool_output("llm", formatted)
        assert is_valid is True, f"Formatted LLM output invalid: {error}"

    def test_format_output_handles_missing_fields(self, orchestrator):
        """Test output formatting handles missing fields gracefully."""
        # Minimal RAG output
        minimal = {"results": []}
        formatted = orchestrator._format_tool_output("rag", minimal, 90)
        # Should have all required fields with defaults
        assert "query" in formatted
        assert "tenant_id" in formatted
        assert "hits_count" in formatted
        assert formatted["hits_count"] == 0
class TestEndToEndRouting:
    """End-to-end tests for context-aware routing."""

    # NOTE(review): async tests require pytest-asyncio; mark explicitly so they
    # run even without asyncio_mode=auto in the project config — confirm setup.
    @pytest.mark.asyncio
    async def test_routing_with_high_rag_score(self):
        """Test that high RAG score prevents web search."""
        selector = ToolSelector(llm_client=None)
        ctx = {
            "tenant_id": "test",
            "rag_results": [{"similarity": 0.92, "text": "Perfect match"}],
            "tool_scores": {"rag_fitness": 0.9, "web_fitness": 0.7},
            "memory": [],
            "admin_violations": []
        }
        decision = await selector.select("general", "What is our policy?", ctx)
        # Check that context hints are applied
        if decision.tool_input and "steps" in decision.tool_input:
            steps = decision.tool_input["steps"]
            tool_names = [s.get("tool") for s in steps if isinstance(s, dict) and "tool" in s]
            # Should have RAG but may skip web due to high score
            assert "rag" in tool_names or "llm" in tool_names

    @pytest.mark.asyncio
    async def test_routing_with_memory(self):
        """Test that relevant memory prevents redundant RAG call."""
        selector = ToolSelector(llm_client=None)
        ctx = {
            "tenant_id": "test",
            "rag_results": [],
            "tool_scores": {"rag_fitness": 0.6},
            "memory": [
                {
                    "tool": "rag",
                    "result": {
                        "results": [{"similarity": 0.85, "text": "Recent result"}]
                    }
                }
            ],
            "admin_violations": []
        }
        decision = await selector.select("general", "Tell me about our policy", ctx)
        # Context should be analyzed
        # (Actual behavior depends on implementation, but should use memory)
        assert decision is not None
if __name__ == "__main__":
    # Allow running this test file directly: verbose output, short tracebacks.
    pytest.main([__file__, "-v", "--tb=short"])