"""Tests for tool system: registry, calculator, search, and schema generation."""
from __future__ import annotations
import uuid
from dataclasses import dataclass, field
import pytest
from agent_bench.tools.calculator import CalculatorTool
from agent_bench.tools.registry import ToolRegistry
from agent_bench.tools.search import SearchTool
# --- Mock retriever for SearchTool tests ---
@dataclass
class MockChunk:
    """Minimal stand-in for a document chunk inside a search result.

    Only the attributes that the tests in this module construct are modelled.
    """

    # Text of the chunk.
    content: str
    # Name of the originating document; tests use file names such as "test.md".
    source: str
    # 16-hex-character identifier, generated per instance unless one is passed in.
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:16])
@dataclass
class MockSearchResult:
    """Pairs a mock chunk with the relevance score assigned to it."""

    # The chunk this result refers to.
    chunk: MockChunk
    # Relevance score attached to the chunk by the test.
    score: float
class MockRetrievalResult:
    """Mimics RetrievalResult for test mocks."""

    def __init__(self, results: list, pre_rerank_count: int = 0) -> None:
        """Store the canned payload.

        Args:
            results: Search results to expose as ``self.results``.
            pre_rerank_count: Value exposed as ``self.pre_rerank_count``;
                defaults to 0.
        """
        self.results = results
        self.pre_rerank_count = pre_rerank_count
class MockRetriever:
    """Fake retriever that returns canned results."""

    def __init__(self, results: list[MockSearchResult] | None = None) -> None:
        """Remember the canned results.

        Args:
            results: Results to hand back from ``search``. ``None`` (or any
                empty value) is replaced by a new empty list.
        """
        self._results = results or []

    async def search(
        self, query: str, top_k: int = 5, strategy: str | None = None
    ) -> MockRetrievalResult:
        """Return the first ``top_k`` canned results.

        ``query`` and ``strategy`` are accepted but ignored; the canned
        results do not depend on them.
        """
        return MockRetrievalResult(results=self._results[:top_k])
# --- Registry tests ---
class TestToolRegistry:
    """Tests for ToolRegistry: registration, lookup, dispatch, and definitions."""

    def test_register_and_retrieve(self):
        """A registered tool is returned (same object) when looked up by name."""
        registry = ToolRegistry()
        tool = CalculatorTool()
        registry.register(tool)
        assert registry.get("calculator") is tool

    def test_get_unknown_returns_none(self):
        """Looking up a name that was never registered yields None."""
        registry = ToolRegistry()
        assert registry.get("nonexistent") is None

    @pytest.mark.asyncio
    async def test_execute_unknown_tool(self):
        """Executing an unregistered tool reports failure and names the tool."""
        registry = ToolRegistry()
        result = await registry.execute("nonexistent", query="test")
        assert result.success is False
        assert "Unknown tool: nonexistent" in result.result

    @pytest.mark.asyncio
    async def test_execute_registered_tool(self):
        """Registry dispatches to a registered tool and returns its output."""
        registry = ToolRegistry()
        registry.register(CalculatorTool())
        result = await registry.execute("calculator", expression="1 + 2")
        assert result.success is True
        assert result.result == "3"

    def test_get_definitions(self):
        """get_definitions() returns one definition per registered tool."""
        registry = ToolRegistry()
        registry.register(CalculatorTool())
        registry.register(SearchTool(retriever=MockRetriever()))
        defs = registry.get_definitions()
        assert len(defs) == 2
        names = {d.name for d in defs}
        assert names == {"calculator", "search_documents"}
# --- Calculator tests ---
class TestCalculatorTool:
    """Tests for CalculatorTool: evaluation, input rejection, and schema."""

    @pytest.mark.asyncio
    async def test_valid_expression(self):
        """An integer expression is evaluated with normal operator precedence."""
        calc = CalculatorTool()
        result = await calc.execute(expression="2 + 3 * 4")
        assert result.success is True
        assert result.result == "14"

    @pytest.mark.asyncio
    async def test_float_expression(self):
        """Division yields a float result that parses back to roughly 3.3333."""
        calc = CalculatorTool()
        result = await calc.execute(expression="10 / 3")
        assert result.success is True
        assert float(result.result) == pytest.approx(3.333333, rel=1e-4)

    @pytest.mark.asyncio
    async def test_rejects_import(self):
        """An ``__import__`` call is reported as a failed evaluation."""
        calc = CalculatorTool()
        result = await calc.execute(expression="__import__('os').system('ls')")
        assert result.success is False
        assert "Could not evaluate" in result.result

    @pytest.mark.asyncio
    async def test_rejects_exec(self):
        """An ``exec`` call is reported as a failure."""
        calc = CalculatorTool()
        result = await calc.execute(expression="exec('print(1)')")
        assert result.success is False

    @pytest.mark.asyncio
    async def test_empty_expression(self):
        """An empty expression string is reported as a failure."""
        calc = CalculatorTool()
        result = await calc.execute(expression="")
        assert result.success is False

    def test_definition_produces_valid_schema(self):
        """The definition is an object schema with a required ``expression``."""
        calc = CalculatorTool()
        defn = calc.definition()
        assert defn.name == "calculator"
        assert defn.parameters["type"] == "object"
        assert "expression" in defn.parameters["properties"]
        assert "expression" in defn.parameters["required"]
# --- Search tool tests ---
class TestSearchTool:
    """Tests for SearchTool output formatting, metadata, and input handling."""

    @pytest.mark.asyncio
    async def test_returns_formatted_results(self):
        """Results are numbered with their source and listed in metadata."""
        retriever = MockRetriever(
            results=[
                MockSearchResult(
                    chunk=MockChunk(
                        content="Path parameters are defined using curly braces.",
                        source="fastapi_path_params.md",
                    ),
                    score=0.95,
                ),
                MockSearchResult(
                    chunk=MockChunk(
                        content="Query parameters are automatically parsed.",
                        source="fastapi_query_params.md",
                    ),
                    score=0.82,
                ),
            ]
        )
        tool = SearchTool(retriever=retriever)
        result = await tool.execute(query="path parameters")
        assert result.success is True
        assert "[1] (fastapi_path_params.md):" in result.result
        assert "[2] (fastapi_query_params.md):" in result.result
        assert result.metadata["sources"] == [
            "fastapi_path_params.md",
            "fastapi_query_params.md",
        ]

    @pytest.mark.asyncio
    async def test_empty_results(self):
        """No retrieved results still succeeds, with a 'not found' message."""
        tool = SearchTool(retriever=MockRetriever(results=[]))
        result = await tool.execute(query="nonexistent topic")
        assert result.success is True
        assert "No relevant documents found" in result.result
        assert result.metadata["sources"] == []

    @pytest.mark.asyncio
    async def test_deduplicates_sources(self):
        """Two chunks from the same file produce a single source entry."""
        retriever = MockRetriever(
            results=[
                MockSearchResult(
                    chunk=MockChunk(content="Chunk 1", source="same_file.md"),
                    score=0.9,
                ),
                MockSearchResult(
                    chunk=MockChunk(content="Chunk 2", source="same_file.md"),
                    score=0.8,
                ),
            ]
        )
        tool = SearchTool(retriever=retriever)
        result = await tool.execute(query="test")
        assert result.metadata["sources"] == ["same_file.md"]

    @pytest.mark.asyncio
    async def test_malformed_top_k_falls_back_to_default(self):
        """Model-generated bad top_k (e.g. 'five') should not crash."""
        retriever = MockRetriever(
            results=[
                MockSearchResult(
                    chunk=MockChunk(content="Some content", source="test.md"),
                    score=0.9,
                ),
            ]
        )
        tool = SearchTool(retriever=retriever)
        result = await tool.execute(query="test", top_k="five")
        assert result.success is True

    @pytest.mark.asyncio
    async def test_empty_query(self):
        """An empty query string is reported as a failure."""
        tool = SearchTool(retriever=MockRetriever())
        result = await tool.execute(query="")
        assert result.success is False

    def test_definition_produces_valid_schema(self):
        """The definition is an object schema with a required ``query``."""
        tool = SearchTool(retriever=MockRetriever())
        defn = tool.definition()
        assert defn.name == "search_documents"
        assert defn.parameters["type"] == "object"
        assert "query" in defn.parameters["properties"]
        assert "query" in defn.parameters["required"]
class TestSearchToolSpecSnapshot:
    """Frozen snapshot of SearchTool's LLM-facing contract.

    Any silent change to the tool name, description, or parameter schema
    that reaches the LLM invalidates invariants that callers rely on —
    for example, an attempt to expose internal SearchTool state (such as
    a Fix-2-style expansion flag) as an LLM-visible parameter would
    break the "one LLM-facing tool call per execute() invocation"
    iteration-budget guarantee. These assertions fail loudly if the
    contract drifts.
    """

    def test_tool_name(self):
        """The class-level tool name is pinned to ``search_documents``."""
        assert SearchTool.name == "search_documents"

    def test_tool_description(self):
        """The class-level description is pinned to its exact wording."""
        assert SearchTool.description == (
            "Search the technical documentation corpus for relevant passages. "
            "Returns the most relevant document chunks with source attribution."
        )

    def test_tool_parameters_schema(self):
        """The schema exposes exactly ``query`` and ``top_k``; only ``query`` is required."""
        params = SearchTool.parameters
        assert params["type"] == "object"
        assert set(params["properties"].keys()) == {"query", "top_k"}
        assert params["properties"]["query"]["type"] == "string"
        assert params["properties"]["top_k"]["type"] == "integer"
        assert params["required"] == ["query"]
# --- Refusal gate tests ---
class TestRefusalGate:
    """Tests for grounded refusal based on retrieval score threshold."""

    def _make_results(self, scores: list[float]) -> list[MockSearchResult]:
        """Build one mock result per score.

        Result ``i`` has content ``"Content {i}"``, source ``"doc_{i}.md"``
        and the i-th score, so every result has a distinct source.
        """
        return [
            MockSearchResult(
                chunk=MockChunk(content=f"Content {i}", source=f"doc_{i}.md"),
                score=s,
            )
            for i, s in enumerate(scores)
        ]

    @pytest.mark.asyncio
    async def test_refusal_out_of_scope(self):
        """Low-scoring results below threshold trigger refusal."""
        retriever = MockRetriever(results=self._make_results([0.005, 0.003]))
        tool = SearchTool(retriever=retriever, refusal_threshold=0.02)
        result = await tool.execute(query="how to cook pasta")
        assert result.success is True
        assert "No relevant documents found" in result.result
        assert result.metadata["refused"] is True
        assert result.metadata["sources"] == []

    @pytest.mark.asyncio
    async def test_no_refusal_in_scope(self):
        """High-scoring results above threshold proceed normally."""
        retriever = MockRetriever(results=self._make_results([0.03, 0.025]))
        tool = SearchTool(retriever=retriever, refusal_threshold=0.02)
        result = await tool.execute(query="FastAPI authentication")
        assert result.success is True
        assert "No relevant documents found" not in result.result
        assert result.metadata.get("refused") is None
        assert len(result.metadata["sources"]) == 2

    @pytest.mark.asyncio
    async def test_refusal_metadata(self):
        """Refused response includes max_score and refused=True."""
        retriever = MockRetriever(results=self._make_results([0.01, 0.008]))
        tool = SearchTool(retriever=retriever, refusal_threshold=0.02)
        result = await tool.execute(query="unrelated topic")
        assert result.metadata["max_score"] == 0.01
        assert result.metadata["refused"] is True

    @pytest.mark.asyncio
    async def test_threshold_zero_disables(self):
        """threshold=0.0 (default) never refuses, preserving V1 behavior."""
        retriever = MockRetriever(results=self._make_results([0.001]))
        tool = SearchTool(retriever=retriever, refusal_threshold=0.0)
        result = await tool.execute(query="anything")
        assert result.success is True
        assert "No relevant documents found" not in result.result
        assert result.metadata.get("refused") is None

    @pytest.mark.asyncio
    async def test_threshold_configurable(self):
        """Different threshold values change refusal behavior."""
        results = self._make_results([0.015, 0.012])
        retriever = MockRetriever(results=results)
        # With threshold=0.01 -> passes (max_score 0.015 > 0.01)
        tool_low = SearchTool(retriever=retriever, refusal_threshold=0.01)
        result_low = await tool_low.execute(query="test")
        assert result_low.metadata.get("refused") is None
        # With threshold=0.02 -> refuses (max_score 0.015 < 0.02)
        tool_high = SearchTool(retriever=retriever, refusal_threshold=0.02)
        result_high = await tool_high.execute(query="test")
        assert result_high.metadata["refused"] is True

    @pytest.mark.asyncio
    async def test_max_score_in_normal_metadata(self):
        """Non-refused responses also include max_score in metadata."""
        retriever = MockRetriever(results=self._make_results([0.05, 0.03]))
        tool = SearchTool(retriever=retriever, refusal_threshold=0.02)
        result = await tool.execute(query="test")
        assert result.metadata["max_score"] == 0.05