agentbench / tests /test_agent.py
Nomearod's picture
feat: enrich SearchTool metadata with scores, previews, PII count
1d16fd9
"""Tests for the agent orchestrator."""
from __future__ import annotations
import pytest
from agent_bench.agents.orchestrator import AgentResponse, Orchestrator
from agent_bench.core.provider import MockProvider
from agent_bench.core.types import (
CompletionResponse,
Role,
TokenUsage,
ToolCall,
)
from agent_bench.tools.base import Tool, ToolOutput
from agent_bench.tools.calculator import CalculatorTool
from agent_bench.tools.registry import ToolRegistry
# --- Helpers ---
class FakeSearchTool(Tool):
"""Deterministic search tool for orchestrator tests."""
name = "search_documents"
description = "Search docs"
parameters = {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
}
async def execute(self, **kwargs: object) -> ToolOutput:
return ToolOutput(
success=True,
result="[1] (fastapi_path_params.md): Path parameters use curly braces.",
metadata={
"sources": ["fastapi_path_params.md"],
"ranked_sources": ["fastapi_path_params.md"],
"source_chunks": ["Path parameters use curly braces."],
"max_score": 0.85,
"pre_rerank_count": 10,
"chunks": [{"source": "fastapi_path_params.md", "score": 0.85,
"preview": "Path parameters use curly braces."}],
"pii_redactions_count": 0,
},
)
class AlwaysToolCallProvider(MockProvider):
"""Provider that always returns tool_calls, never a final answer."""
async def complete(self, messages, tools=None, temperature=0.0, max_tokens=1024):
self.call_count += 1
if tools is None:
# Forced final call (no tools) — return text
return CompletionResponse(
content="Forced answer after max iterations.",
tool_calls=[],
usage=TokenUsage(input_tokens=100, output_tokens=20, estimated_cost_usd=0.0001),
provider="mock",
model="mock-1",
latency_ms=1.0,
)
return CompletionResponse(
content="",
tool_calls=[
ToolCall(
id=f"call_{self.call_count}",
name="search_documents",
arguments={"query": "test"},
)
],
usage=TokenUsage(input_tokens=100, output_tokens=20, estimated_cost_usd=0.0001),
provider="mock",
model="mock-1",
latency_ms=1.0,
)
class MultiSearchProvider(MockProvider):
"""Provider that searches twice (different queries), then answers."""
async def complete(self, messages, tools=None, temperature=0.0, max_tokens=1024):
self.call_count += 1
tool_results = [m for m in messages if m.role == Role.TOOL]
if tools and len(tool_results) == 0:
return CompletionResponse(
content="",
tool_calls=[
ToolCall(id="call_1", name="search_documents", arguments={"query": "first"})
],
usage=TokenUsage(input_tokens=100, output_tokens=20, estimated_cost_usd=0.0001),
provider="mock",
model="mock-1",
latency_ms=1.0,
)
elif tools and len(tool_results) == 1:
return CompletionResponse(
content="",
tool_calls=[
ToolCall(id="call_2", name="search_documents", arguments={"query": "second"})
],
usage=TokenUsage(input_tokens=150, output_tokens=25, estimated_cost_usd=0.0002),
provider="mock",
model="mock-1",
latency_ms=1.0,
)
else:
return CompletionResponse(
content="Answer from two searches. [source: fastapi_path_params.md]",
tool_calls=[],
usage=TokenUsage(input_tokens=200, output_tokens=50, estimated_cost_usd=0.0003),
provider="mock",
model="mock-1",
latency_ms=2.0,
)
def _make_registry() -> ToolRegistry:
registry = ToolRegistry()
registry.register(FakeSearchTool())
registry.register(CalculatorTool())
return registry
SYSTEM_PROMPT = "You are a helpful assistant."
# --- Tests ---
class TestOrchestrator:
@pytest.mark.asyncio
async def test_produces_agent_response_with_all_fields(self):
"""Orchestrator returns AgentResponse with all required fields."""
orchestrator = Orchestrator(
provider=MockProvider(), registry=_make_registry(), max_iterations=3
)
response = await orchestrator.run("How do path params work?", SYSTEM_PROMPT)
assert isinstance(response, AgentResponse)
assert len(response.answer) > 0
assert response.iterations >= 1
assert response.usage.input_tokens > 0
assert response.usage.output_tokens > 0
assert response.latency_ms > 0
assert isinstance(response.sources, list)
assert isinstance(response.tools_used, list)
@pytest.mark.asyncio
async def test_respects_max_iterations(self):
"""When provider always returns tool_calls, orchestrator stops at max_iterations."""
provider = AlwaysToolCallProvider()
orchestrator = Orchestrator(provider=provider, registry=_make_registry(), max_iterations=2)
response = await orchestrator.run("test question", SYSTEM_PROMPT)
# 2 iterations of tool calls + 1 forced final call = 3 provider calls total
assert provider.call_count == 3
assert response.iterations == 2
assert response.answer == "Forced answer after max iterations."
@pytest.mark.asyncio
async def test_accumulates_sources_from_multiple_searches(self):
"""Sources from multiple search calls are accumulated and deduplicated."""
orchestrator = Orchestrator(
provider=MultiSearchProvider(), registry=_make_registry(), max_iterations=3
)
response = await orchestrator.run("multi search question", SYSTEM_PROMPT)
# FakeSearchTool always returns fastapi_path_params.md
assert len(response.sources) == 1 # deduplicated
assert response.sources[0].source == "fastapi_path_params.md"
assert response.tools_used.count("search_documents") == 2
# Token usage accumulated across 3 provider calls
assert response.usage.input_tokens == 100 + 150 + 200
assert response.usage.output_tokens == 20 + 25 + 50
@pytest.mark.asyncio
async def test_deterministic_output(self):
"""Fixed question + MockProvider → exact expected answer."""
orchestrator = Orchestrator(
provider=MockProvider(), registry=_make_registry(), max_iterations=3
)
response = await orchestrator.run("How do path params work?", SYSTEM_PROMPT)
# MockProvider: first call returns tool_calls for search_documents,
# second call (with tool results) returns the canned answer
assert "path parameters" in response.answer.lower()
assert "[source: fastapi_path_params.md]" in response.answer
assert "search_documents" in response.tools_used
assert response.iterations == 2
class TestOrchestratorIntegration:
"""Integration test using real SearchTool + Retriever + HybridStore."""
@pytest.mark.asyncio
async def test_real_rag_path(self, test_retriever):
"""Orchestrator with MockProvider + real SearchTool/Retriever returns RAG results."""
from agent_bench.tools.search import SearchTool
registry = ToolRegistry()
registry.register(SearchTool(retriever=test_retriever))
registry.register(CalculatorTool())
orchestrator = Orchestrator(provider=MockProvider(), registry=registry, max_iterations=3)
response = await orchestrator.run(
"How do path params work?", SYSTEM_PROMPT, top_k=3, strategy="hybrid"
)
# MockProvider drives the loop, but the real SearchTool executes
# against the real Retriever/HybridStore and returns real chunks
assert isinstance(response, AgentResponse)
assert len(response.answer) > 0
assert "search_documents" in response.tools_used
# Sources come from the real store (sample_chunks in conftest)
assert len(response.sources) > 0