"""
Comprehensive tests for:
1. Per-Tool Latency Prediction
2. Context-Aware MCP Routing
3. Tool Output Schemas
Tests all three new features for intelligent tool selection and output validation.
"""
import pytest
from unittest.mock import AsyncMock
from backend.api.services.tool_metadata import (
get_tool_latency_estimate,
estimate_path_latency,
get_fastest_path,
validate_tool_output,
get_tool_schema,
TOOL_LATENCY_METADATA,
TOOL_OUTPUT_SCHEMAS
)
from backend.api.services.tool_selector import ToolSelector
from backend.api.services.agent_orchestrator import AgentOrchestrator
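# Interface assumptions exercised by the tests below (see the individual test classes):
#   * get_tool_latency_estimate(tool, context=None) returns a numeric estimate in ms
#   * estimate_path_latency(tools) and get_fastest_path(tools) operate on lists of tool names
#   * validate_tool_output(tool, payload) returns (is_valid: bool, error: Optional[str])
#   * TOOL_LATENCY_METADATA entries expose tool_name, min_ms, avg_ms, max_ms, description
#   * TOOL_OUTPUT_SCHEMAS entries expose tool_name and a dict `example`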
class TestLatencyPrediction:
"""Test per-tool latency prediction"""
def test_get_tool_latency_estimate_basic(self):
"""Test basic latency estimation without context"""
rag_latency = get_tool_latency_estimate("rag")
web_latency = get_tool_latency_estimate("web")
admin_latency = get_tool_latency_estimate("admin")
llm_latency = get_tool_latency_estimate("llm")
# Check that latencies are within expected ranges
assert 60 <= rag_latency <= 120
assert 400 <= web_latency <= 1800
assert 5 <= admin_latency <= 20
assert 500 <= llm_latency <= 5000
def test_get_tool_latency_estimate_with_context(self):
"""Test latency estimation with context"""
# RAG with long query
rag_long = get_tool_latency_estimate("rag", {"query_length": 200})
rag_short = get_tool_latency_estimate("rag", {"query_length": 10})
assert rag_long >= rag_short # Longer queries should take more time
# Web with complexity
web_complex = get_tool_latency_estimate("web", {"query_complexity": "high"})
web_simple = get_tool_latency_estimate("web", {"query_complexity": "low"})
assert web_complex >= web_simple # Complex queries should take more time
def test_estimate_path_latency(self):
"""Test total latency estimation for tool sequences"""
# Single tool
single = estimate_path_latency(["admin"])
assert single > 0
assert single <= 20
# Multiple tools
multi = estimate_path_latency(["rag", "web", "llm"])
assert multi > 0
        # The path estimate sums per-tool latencies, so it should be at least each individual estimate
assert multi >= get_tool_latency_estimate("rag")
assert multi >= get_tool_latency_estimate("web")
assert multi >= get_tool_latency_estimate("llm")
def test_get_fastest_path(self):
"""Test fastest path optimization"""
tools = ["llm", "admin", "rag", "web"]
fastest = get_fastest_path(tools)
# Should be sorted by latency (fastest first)
assert len(fastest) == len(tools)
assert "admin" in fastest # Fastest tool
assert fastest[0] == "admin" # Should be first
# Verify order is optimized
latencies = [get_tool_latency_estimate(t) for t in fastest]
assert latencies == sorted(latencies) # Should be in ascending order
def test_latency_metadata_structure(self):
"""Test that latency metadata has correct structure"""
for tool_name, metadata in TOOL_LATENCY_METADATA.items():
assert metadata.tool_name == tool_name
assert metadata.min_ms > 0
assert metadata.max_ms >= metadata.min_ms
assert metadata.avg_ms >= metadata.min_ms
assert metadata.avg_ms <= metadata.max_ms
assert len(metadata.description) > 0
class TestToolOutputSchemas:
"""Test tool output schema validation"""
def test_get_tool_schema(self):
"""Test schema retrieval"""
rag_schema = get_tool_schema("rag")
web_schema = get_tool_schema("web")
admin_schema = get_tool_schema("admin")
llm_schema = get_tool_schema("llm")
assert rag_schema is not None
assert web_schema is not None
assert admin_schema is not None
assert llm_schema is not None
assert rag_schema.tool_name == "rag"
assert web_schema.tool_name == "web"
assert admin_schema.tool_name == "admin"
assert llm_schema.tool_name == "llm"
def test_validate_rag_output_valid(self):
"""Test validation of valid RAG output"""
valid_rag = {
"results": [
{
"text": "Document chunk",
"similarity": 0.85,
"metadata": {"title": "Test"},
"doc_id": "doc123"
}
],
"query": "test query",
"tenant_id": "tenant1",
"hits_count": 1,
"avg_score": 0.85,
"top_score": 0.85,
"latency_ms": 90
}
is_valid, error = validate_tool_output("rag", valid_rag)
assert is_valid is True
assert error is None
def test_validate_rag_output_missing_field(self):
"""Test validation catches missing required fields"""
invalid_rag = {
"results": [],
# Missing "query" and "tenant_id"
"hits_count": 0
}
is_valid, error = validate_tool_output("rag", invalid_rag)
assert is_valid is False
assert "Missing required field" in error
def test_validate_web_output_valid(self):
"""Test validation of valid Web output"""
valid_web = {
"results": [
{
"title": "Result Title",
"snippet": "Result snippet",
"link": "https://example.com",
"displayLink": "example.com"
}
],
"query": "search query",
"total_results": 10,
"latency_ms": 800
}
is_valid, error = validate_tool_output("web", valid_web)
assert is_valid is True
assert error is None
def test_validate_admin_output_valid(self):
"""Test validation of valid Admin output"""
valid_admin = {
"violations": [
{
"rule_id": "rule1",
"rule_pattern": ".*password.*",
"severity": "high",
"matched_text": "password",
"confidence": 0.95,
"message_preview": "User asked for password"
}
],
"checked": True,
"rules_count": 5,
"latency_ms": 10
}
is_valid, error = validate_tool_output("admin", valid_admin)
assert is_valid is True
assert error is None
def test_validate_llm_output_valid(self):
"""Test validation of valid LLM output"""
valid_llm = {
"text": "Generated response",
"tokens_used": 150,
"latency_ms": 2000,
"model": "llama3.1:latest",
"temperature": 0.0
}
is_valid, error = validate_tool_output("llm", valid_llm)
assert is_valid is True
assert error is None
def test_validate_type_mismatch(self):
"""Test validation catches type mismatches"""
invalid_rag = {
"results": "not an array", # Should be array
"query": "test",
"tenant_id": "tenant1"
}
is_valid, error = validate_tool_output("rag", invalid_rag)
assert is_valid is False
assert "must be array" in error
def test_schema_examples(self):
"""Test that all schemas have examples"""
for tool_name, schema in TOOL_OUTPUT_SCHEMAS.items():
assert schema.example is not None
assert isinstance(schema.example, dict)
# Example should be valid
is_valid, error = validate_tool_output(tool_name, schema.example)
assert is_valid is True, f"Schema example for {tool_name} is invalid: {error}"
class TestContextAwareRouting:
"""Test context-aware MCP routing"""
@pytest.fixture
def tool_selector(self):
"""Create a ToolSelector instance"""
return ToolSelector(llm_client=None)
def test_analyze_context_rag_high_score(self, tool_selector):
"""Test context analysis when RAG returns high score"""
rag_results = [
{"similarity": 0.85, "text": "High quality result"},
{"similarity": 0.90, "text": "Another high quality result"}
]
memory = []
admin_violations = []
tool_scores = {"rag_fitness": 0.8, "web_fitness": 0.5}
hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores)
assert hints.get("skip_web_if_rag_high") is True
assert hints.get("rag_high_confidence") is True
def test_analyze_context_rag_low_score(self, tool_selector):
"""Test context analysis when RAG returns low score"""
rag_results = [
{"similarity": 0.3, "text": "Low quality result"}
]
memory = []
admin_violations = []
tool_scores = {"rag_fitness": 0.3, "web_fitness": 0.7}
hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores)
# Should not skip web if RAG score is low
assert hints.get("skip_web_if_rag_high") is not True
def test_analyze_context_memory_relevant(self, tool_selector):
"""Test context analysis when relevant memory exists"""
rag_results = []
memory = [
{
"tool": "rag",
"result": {
"results": [
{"similarity": 0.80, "text": "Recent RAG result"}
]
}
}
]
admin_violations = []
tool_scores = {}
hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores)
assert hints.get("has_relevant_memory") is True
        # The memory hit has similarity 0.80 (>= 0.75), so skipping RAG should be suggested
        assert hints.get("skip_rag_if_memory") is True
def test_analyze_context_admin_critical(self, tool_selector):
"""Test context analysis when admin violation is critical"""
rag_results = []
memory = []
admin_violations = [
{
"severity": "critical",
"rule_id": "rule1",
"matched_text": "sensitive data"
}
]
tool_scores = {}
hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores)
assert hints.get("skip_agent_reasoning") is True
assert hints.get("critical_violation") is True
def test_analyze_context_admin_low_severity(self, tool_selector):
"""Test context analysis when admin violation is low severity"""
rag_results = []
memory = []
admin_violations = [
{
"severity": "low",
"rule_id": "rule1",
"matched_text": "minor issue"
}
]
tool_scores = {}
hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores)
# Low severity should not skip reasoning
assert hints.get("skip_agent_reasoning") is not True
@pytest.mark.asyncio
async def test_tool_selection_with_context_hints(self, tool_selector):
"""Test tool selection uses context hints"""
# Mock LLM client
tool_selector.llm_client = AsyncMock()
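        # Assumption: the selector does not depend on the mock's response content for this
        # path; a bare AsyncMock simply stands in for a present LLM client.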
# Context with high RAG score
ctx = {
"tenant_id": "test_tenant",
"rag_results": [
{"similarity": 0.85, "text": "High quality result"}
],
"tool_scores": {
"rag_fitness": 0.8,
"web_fitness": 0.6,
"llm_only": 0.3
},
"memory": [],
"admin_violations": []
}
decision = await tool_selector.select("general", "What is our company policy?", ctx)
# Should include latency estimates in reason
assert "latency" in decision.reason.lower() or "est." in decision.reason.lower()
# Check that steps have latency estimates (for non-LLM tools)
if decision.tool_input and "steps" in decision.tool_input:
steps = decision.tool_input["steps"]
for step in steps:
if isinstance(step, dict) and "input" in step and step.get("tool") != "llm":
# Non-LLM tools should have estimated latency (or be parallel)
assert "_estimated_latency_ms" in step["input"] or "parallel" in step or step.get("tool") == "llm"
@pytest.mark.asyncio
async def test_tool_selection_skips_web_on_high_rag(self, tool_selector):
"""Test that tool selection skips web when RAG has high score"""
tool_selector.llm_client = AsyncMock()
ctx = {
"tenant_id": "test_tenant",
"rag_results": [
{"similarity": 0.90, "text": "Very high quality result"}
],
"tool_scores": {
"rag_fitness": 0.9,
"web_fitness": 0.7,
"llm_only": 0.2
},
"memory": [],
"admin_violations": []
}
decision = await tool_selector.select("general", "What is our internal policy?", ctx)
# Check reason includes context hint
assert "skip web" in decision.reason.lower() or "rag high" in decision.reason.lower() or "context" in decision.reason.lower()
@pytest.mark.asyncio
async def test_tool_selection_admin_critical_skip_reasoning(self, tool_selector):
"""Test that tool selection skips reasoning for critical admin violations"""
tool_selector.llm_client = None # No LLM needed for admin-only path
ctx = {
"tenant_id": "test_tenant",
"rag_results": [],
"tool_scores": {},
"memory": [],
"admin_violations": [
{
"severity": "critical",
"rule_id": "rule1",
"matched_text": "critical violation"
}
]
}
decision = await tool_selector.select("admin", "User trying to access sensitive data", ctx)
# Should skip LLM reasoning for critical violations
if decision.tool_input and "steps" in decision.tool_input:
steps = decision.tool_input["steps"]
# Should have admin step but may skip LLM
has_admin = any(s.get("tool") == "admin" for s in steps if isinstance(s, dict))
assert has_admin
class TestOrchestratorIntegration:
"""Test orchestrator integration with new features"""
@pytest.fixture
def orchestrator(self):
"""Create an AgentOrchestrator instance"""
return AgentOrchestrator(
rag_mcp_url="http://localhost:8900/rag",
web_mcp_url="http://localhost:8900/web",
admin_mcp_url="http://localhost:8900/admin",
llm_backend="ollama"
)
def test_format_rag_output(self, orchestrator):
"""Test RAG output formatting"""
raw_output = {
"results": [
{"text": "Chunk 1", "similarity": 0.85},
{"text": "Chunk 2", "similarity": 0.75}
],
"query": "test query"
}
formatted = orchestrator._format_tool_output("rag", raw_output, 90)
# Check schema compliance
assert "results" in formatted
assert "query" in formatted
assert "tenant_id" in formatted
assert "hits_count" in formatted
assert "avg_score" in formatted
assert "top_score" in formatted
assert "latency_ms" in formatted
# Validate against schema
is_valid, error = validate_tool_output("rag", formatted)
assert is_valid is True, f"Formatted RAG output invalid: {error}"
def test_format_web_output(self, orchestrator):
"""Test Web output formatting"""
raw_output = {
"items": [
{
"title": "Result Title",
"snippet": "Result snippet",
"link": "https://example.com"
}
]
}
formatted = orchestrator._format_tool_output("web", raw_output, 800)
# Check schema compliance
assert "results" in formatted
assert "query" in formatted
assert "total_results" in formatted
assert "latency_ms" in formatted
# Validate against schema
is_valid, error = validate_tool_output("web", formatted)
assert is_valid is True, f"Formatted Web output invalid: {error}"
def test_format_admin_output(self, orchestrator):
"""Test Admin output formatting"""
raw_output = {
"matches": [
{
"rule_id": "rule1",
"pattern": ".*password.*",
"severity": "high",
"text": "password",
"confidence": 0.95
}
]
}
formatted = orchestrator._format_tool_output("admin", raw_output, 10)
# Check schema compliance
assert "violations" in formatted
assert "checked" in formatted
assert "rules_count" in formatted
assert "latency_ms" in formatted
# Validate against schema
is_valid, error = validate_tool_output("admin", formatted)
assert is_valid is True, f"Formatted Admin output invalid: {error}"
def test_format_llm_output(self, orchestrator):
"""Test LLM output formatting"""
raw_output = "This is a generated response from the LLM."
formatted = orchestrator._format_tool_output("llm", raw_output, 2000)
# Check schema compliance
assert "text" in formatted
assert "tokens_used" in formatted
assert "latency_ms" in formatted
assert "model" in formatted
assert "temperature" in formatted
# Validate against schema
is_valid, error = validate_tool_output("llm", formatted)
assert is_valid is True, f"Formatted LLM output invalid: {error}"
def test_format_output_handles_missing_fields(self, orchestrator):
"""Test output formatting handles missing fields gracefully"""
# Minimal RAG output
minimal = {"results": []}
formatted = orchestrator._format_tool_output("rag", minimal, 90)
# Should have all required fields with defaults
assert "query" in formatted
assert "tenant_id" in formatted
assert "hits_count" in formatted
assert formatted["hits_count"] == 0
class TestEndToEndRouting:
"""End-to-end tests for context-aware routing"""
@pytest.mark.asyncio
async def test_routing_with_high_rag_score(self):
"""Test that high RAG score prevents web search"""
selector = ToolSelector(llm_client=None)
ctx = {
"tenant_id": "test",
"rag_results": [{"similarity": 0.92, "text": "Perfect match"}],
"tool_scores": {"rag_fitness": 0.9, "web_fitness": 0.7},
"memory": [],
"admin_violations": []
}
decision = await selector.select("general", "What is our policy?", ctx)
# Check that context hints are applied
if decision.tool_input and "steps" in decision.tool_input:
steps = decision.tool_input["steps"]
tool_names = [s.get("tool") for s in steps if isinstance(s, dict) and "tool" in s]
# Should have RAG but may skip web due to high score
assert "rag" in tool_names or "llm" in tool_names
@pytest.mark.asyncio
async def test_routing_with_memory(self):
"""Test that relevant memory prevents redundant RAG call"""
selector = ToolSelector(llm_client=None)
ctx = {
"tenant_id": "test",
"rag_results": [],
"tool_scores": {"rag_fitness": 0.6},
"memory": [
{
"tool": "rag",
"result": {
"results": [{"similarity": 0.85, "text": "Recent result"}]
}
}
],
"admin_violations": []
}
decision = await selector.select("general", "Tell me about our policy", ctx)
# Context should be analyzed
# (Actual behavior depends on implementation, but should use memory)
assert decision is not None
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])