| """Tests for tool system: registry, calculator, search, and schema generation.""" | |
| from __future__ import annotations | |
| import uuid | |
| from dataclasses import dataclass, field | |
| import pytest | |
| from agent_bench.tools.calculator import CalculatorTool | |
| from agent_bench.tools.registry import ToolRegistry | |
| from agent_bench.tools.search import SearchTool | |
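
# The tests below exercise only the public tool contract. Assumed here but
# defined elsewhere in agent_bench:
#   * execute() returns a ToolResult-like object with .success (bool),
#     .result (str), and .metadata (dict).
#   * definition() returns an object with .name and a JSON-Schema-style
#     .parameters dict.
# Also assumed: pytest-asyncio (or an equivalent plugin) runs in auto mode,
# so the bare `async def` tests are collected and awaited without per-test
# markers.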

# --- Mock retriever for SearchTool tests ---


@dataclass
class MockChunk:
    content: str
    source: str
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:16])


@dataclass
class MockSearchResult:
    chunk: MockChunk
    score: float


class MockRetrievalResult:
    """Mimics RetrievalResult for test mocks."""

    def __init__(self, results: list, pre_rerank_count: int = 0) -> None:
        self.results = results
        self.pre_rerank_count = pre_rerank_count


class MockRetriever:
    """Fake retriever that returns canned results."""

    def __init__(self, results: list[MockSearchResult] | None = None) -> None:
        self._results = results or []

    async def search(
        self, query: str, top_k: int = 5, strategy: str | None = None
    ) -> MockRetrievalResult:
        return MockRetrievalResult(results=self._results[:top_k])
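
# NOTE: the mocks above are duck-typed stand-ins. They mirror only the
# attributes the code under test actually reads (e.g. chunk.content,
# chunk.source, result.score), so drift in the real retriever types will
# not be caught here.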


# --- Registry tests ---


class TestToolRegistry:
    def test_register_and_retrieve(self):
        registry = ToolRegistry()
        tool = CalculatorTool()
        registry.register(tool)
        assert registry.get("calculator") is tool

    def test_get_unknown_returns_none(self):
        registry = ToolRegistry()
        assert registry.get("nonexistent") is None

    async def test_execute_unknown_tool(self):
        registry = ToolRegistry()
        result = await registry.execute("nonexistent", query="test")
        assert result.success is False
        assert "Unknown tool: nonexistent" in result.result

    async def test_execute_registered_tool(self):
        """Registry dispatches to a registered tool and returns its output."""
        registry = ToolRegistry()
        registry.register(CalculatorTool())
        result = await registry.execute("calculator", expression="1 + 2")
        assert result.success is True
        assert result.result == "3"

    def test_get_definitions(self):
        registry = ToolRegistry()
        registry.register(CalculatorTool())
        registry.register(SearchTool(retriever=MockRetriever()))
        defs = registry.get_definitions()
        assert len(defs) == 2
        names = {d.name for d in defs}
        assert names == {"calculator", "search_documents"}


# --- Calculator tests ---


class TestCalculatorTool:
    async def test_valid_expression(self):
        calc = CalculatorTool()
        result = await calc.execute(expression="2 + 3 * 4")
        assert result.success is True
        assert result.result == "14"

    async def test_float_expression(self):
        calc = CalculatorTool()
        result = await calc.execute(expression="10 / 3")
        assert result.success is True
        assert float(result.result) == pytest.approx(3.333333, rel=1e-4)

    async def test_rejects_import(self):
        calc = CalculatorTool()
        result = await calc.execute(expression="__import__('os').system('ls')")
        assert result.success is False
        assert "Could not evaluate" in result.result

    async def test_rejects_exec(self):
        calc = CalculatorTool()
        result = await calc.execute(expression="exec('print(1)')")
        assert result.success is False

    async def test_empty_expression(self):
        calc = CalculatorTool()
        result = await calc.execute(expression="")
        assert result.success is False

    def test_definition_produces_valid_schema(self):
        calc = CalculatorTool()
        defn = calc.definition()
        assert defn.name == "calculator"
        assert defn.parameters["type"] == "object"
        assert "expression" in defn.parameters["properties"]
        assert "expression" in defn.parameters["required"]


# --- Search tool tests ---


class TestSearchTool:
    async def test_returns_formatted_results(self):
        retriever = MockRetriever(
            results=[
                MockSearchResult(
                    chunk=MockChunk(
                        content="Path parameters are defined using curly braces.",
                        source="fastapi_path_params.md",
                    ),
                    score=0.95,
                ),
                MockSearchResult(
                    chunk=MockChunk(
                        content="Query parameters are automatically parsed.",
                        source="fastapi_query_params.md",
                    ),
                    score=0.82,
                ),
            ]
        )
        tool = SearchTool(retriever=retriever)
        result = await tool.execute(query="path parameters")
        assert result.success is True
        assert "[1] (fastapi_path_params.md):" in result.result
        assert "[2] (fastapi_query_params.md):" in result.result
        assert result.metadata["sources"] == [
            "fastapi_path_params.md",
            "fastapi_query_params.md",
        ]

    async def test_empty_results(self):
        tool = SearchTool(retriever=MockRetriever(results=[]))
        result = await tool.execute(query="nonexistent topic")
        assert result.success is True
        assert "No relevant documents found" in result.result
        assert result.metadata["sources"] == []

    async def test_deduplicates_sources(self):
        retriever = MockRetriever(
            results=[
                MockSearchResult(
                    chunk=MockChunk(content="Chunk 1", source="same_file.md"),
                    score=0.9,
                ),
                MockSearchResult(
                    chunk=MockChunk(content="Chunk 2", source="same_file.md"),
                    score=0.8,
                ),
            ]
        )
        tool = SearchTool(retriever=retriever)
        result = await tool.execute(query="test")
        assert result.metadata["sources"] == ["same_file.md"]

    async def test_malformed_top_k_falls_back_to_default(self):
        """Model-generated bad top_k (e.g. 'five') should not crash."""
        retriever = MockRetriever(
            results=[
                MockSearchResult(
                    chunk=MockChunk(content="Some content", source="test.md"),
                    score=0.9,
                ),
            ]
        )
        tool = SearchTool(retriever=retriever)
        result = await tool.execute(query="test", top_k="five")
        assert result.success is True

    async def test_empty_query(self):
        tool = SearchTool(retriever=MockRetriever())
        result = await tool.execute(query="")
        assert result.success is False

    def test_definition_produces_valid_schema(self):
        tool = SearchTool(retriever=MockRetriever())
        defn = tool.definition()
        assert defn.name == "search_documents"
        assert defn.parameters["type"] == "object"
        assert "query" in defn.parameters["properties"]
        assert "query" in defn.parameters["required"]


class TestSearchToolSpecSnapshot:
    """Frozen snapshot of SearchTool's LLM-facing contract.

    Any silent change to the tool name, description, or parameter schema
    that reaches the LLM invalidates invariants that callers rely on.
    For example, exposing internal SearchTool state (such as a
    Fix-2-style expansion flag) as an LLM-visible parameter would break
    the "one LLM-facing tool call per execute() invocation"
    iteration-budget guarantee. These assertions fail loudly if the
    contract drifts.
    """

    def test_tool_name(self):
        assert SearchTool.name == "search_documents"

    def test_tool_description(self):
        assert SearchTool.description == (
            "Search the technical documentation corpus for relevant passages. "
            "Returns the most relevant document chunks with source attribution."
        )

    def test_tool_parameters_schema(self):
        params = SearchTool.parameters
        assert params["type"] == "object"
        assert set(params["properties"].keys()) == {"query", "top_k"}
        assert params["properties"]["query"]["type"] == "string"
        assert params["properties"]["top_k"]["type"] == "integer"
        assert params["required"] == ["query"]


# --- Refusal gate tests ---


class TestRefusalGate:
    """Tests for grounded refusal based on retrieval score threshold."""

    def _make_results(self, scores: list[float]) -> list[MockSearchResult]:
        return [
            MockSearchResult(
                chunk=MockChunk(content=f"Content {i}", source=f"doc_{i}.md"),
                score=s,
            )
            for i, s in enumerate(scores)
        ]

    async def test_refusal_out_of_scope(self):
        """Low-scoring results below threshold trigger refusal."""
        retriever = MockRetriever(results=self._make_results([0.005, 0.003]))
        tool = SearchTool(retriever=retriever, refusal_threshold=0.02)
        result = await tool.execute(query="how to cook pasta")
        assert result.success is True
        assert "No relevant documents found" in result.result
        assert result.metadata["refused"] is True
        assert result.metadata["sources"] == []

    async def test_no_refusal_in_scope(self):
        """High-scoring results above threshold proceed normally."""
        retriever = MockRetriever(results=self._make_results([0.03, 0.025]))
        tool = SearchTool(retriever=retriever, refusal_threshold=0.02)
        result = await tool.execute(query="FastAPI authentication")
        assert result.success is True
        assert "No relevant documents found" not in result.result
        assert result.metadata.get("refused") is None
        assert len(result.metadata["sources"]) == 2

    async def test_refusal_metadata(self):
        """Refused response includes max_score and refused=True."""
        retriever = MockRetriever(results=self._make_results([0.01, 0.008]))
        tool = SearchTool(retriever=retriever, refusal_threshold=0.02)
        result = await tool.execute(query="unrelated topic")
        assert result.metadata["max_score"] == 0.01
        assert result.metadata["refused"] is True

    async def test_threshold_zero_disables(self):
        """threshold=0.0 (default) never refuses, preserving V1 behavior."""
        retriever = MockRetriever(results=self._make_results([0.001]))
        tool = SearchTool(retriever=retriever, refusal_threshold=0.0)
        result = await tool.execute(query="anything")
        assert result.success is True
        assert "No relevant documents found" not in result.result
        assert result.metadata.get("refused") is None

    async def test_threshold_configurable(self):
        """Different threshold values change refusal behavior."""
        results = self._make_results([0.015, 0.012])
        retriever = MockRetriever(results=results)
        # With threshold=0.01 -> passes (max_score 0.015 > 0.01).
        tool_low = SearchTool(retriever=retriever, refusal_threshold=0.01)
        result_low = await tool_low.execute(query="test")
        assert result_low.metadata.get("refused") is None
        # With threshold=0.02 -> refuses (max_score 0.015 < 0.02).
        tool_high = SearchTool(retriever=retriever, refusal_threshold=0.02)
        result_high = await tool_high.execute(query="test")
        assert result_high.metadata["refused"] is True

    async def test_max_score_in_normal_metadata(self):
        """Non-refused responses also include max_score in metadata."""
        retriever = MockRetriever(results=self._make_results([0.05, 0.03]))
        tool = SearchTool(retriever=retriever, refusal_threshold=0.02)
        result = await tool.execute(query="test")
        assert result.metadata["max_score"] == 0.05