File size: 12,338 Bytes
36a9ab7
 
 
 
27c2e17
 
36a9ab7
 
 
 
 
 
 
 
 
 
 
 
 
 
27c2e17
36a9ab7
 
 
 
 
 
 
 
c5573d3
 
 
 
 
 
 
36a9ab7
 
 
 
 
 
3fc43ec
 
c5573d3
 
36a9ab7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8766b82
 
 
 
 
 
 
 
 
36a9ab7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8766b82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36a9ab7
 
 
 
 
 
 
 
 
 
 
 
 
c410788
 
27c2e17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c410788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
"""Tests for tool system: registry, calculator, search, and schema generation."""

from __future__ import annotations

import uuid
from dataclasses import dataclass, field

import pytest

from agent_bench.tools.calculator import CalculatorTool
from agent_bench.tools.registry import ToolRegistry
from agent_bench.tools.search import SearchTool

# --- Mock retriever for SearchTool tests ---


def _fresh_chunk_id() -> str:
    """Return the first 16 hex digits of a newly generated random UUID."""
    return uuid.uuid4().hex[:16]


@dataclass
class MockChunk:
    """Minimal stand-in for a document chunk handed back by a retriever.

    Holds the chunk text, the name of the document it came from, and an
    identifier that is generated per instance unless one is supplied.
    """

    content: str
    source: str
    id: str = field(default_factory=_fresh_chunk_id)


@dataclass
class MockSearchResult:
    """Test double for a single search hit: a chunk plus its score."""

    # The mock chunk this hit refers to (content, source and id).
    chunk: MockChunk
    # Score attached to the hit; the tests in this module pass small floats.
    score: float


class MockRetrievalResult:
    """Lightweight imitation of RetrievalResult used by the mock retriever.

    Both constructor arguments are stored unchanged as attributes of the
    same name; ``pre_rerank_count`` is 0 when not supplied.
    """

    def __init__(self, results: list, pre_rerank_count: int = 0) -> None:
        # Plain attribute storage: no copying and no validation.
        self.pre_rerank_count = pre_rerank_count
        self.results = results


class MockRetriever:
    """Retriever test double that serves a fixed, pre-supplied list of hits."""

    def __init__(self, results: list[MockSearchResult] | None = None) -> None:
        # A missing or empty argument is replaced by a fresh empty list.
        self._results = results if results else []

    async def search(
        self, query: str, top_k: int = 5, strategy: str | None = None
    ) -> MockRetrievalResult:
        """Return at most ``top_k`` of the canned hits.

        ``query`` and ``strategy`` are accepted but ignored.
        """
        canned = self._results[:top_k]
        return MockRetrievalResult(results=canned)


# --- Registry tests ---


class TestToolRegistry:
    """Checks for ToolRegistry registration, lookup, dispatch and listing."""

    def test_register_and_retrieve(self):
        """A registered tool comes back, by identity, under its name."""
        reg = ToolRegistry()
        calculator = CalculatorTool()
        reg.register(calculator)
        fetched = reg.get("calculator")
        assert fetched is calculator

    def test_get_unknown_returns_none(self):
        """Looking up a name that was never registered gives None."""
        missing = ToolRegistry().get("nonexistent")
        assert missing is None

    @pytest.mark.asyncio
    async def test_execute_unknown_tool(self):
        """Executing an unregistered name yields a failed result naming it."""
        outcome = await ToolRegistry().execute("nonexistent", query="test")
        assert outcome.success is False
        assert "Unknown tool: nonexistent" in outcome.result

    @pytest.mark.asyncio
    async def test_execute_registered_tool(self):
        """The registry forwards the call to the registered tool."""
        reg = ToolRegistry()
        reg.register(CalculatorTool())
        outcome = await reg.execute("calculator", expression="1 + 2")
        assert outcome.success is True
        assert outcome.result == "3"

    def test_get_definitions(self):
        """One definition is listed per registered tool."""
        reg = ToolRegistry()
        reg.register(CalculatorTool())
        search_tool = SearchTool(retriever=MockRetriever())
        reg.register(search_tool)
        definitions = reg.get_definitions()
        assert len(definitions) == 2
        listed_names = {definition.name for definition in definitions}
        assert listed_names == {"calculator", "search_documents"}


# --- Calculator tests ---


class TestCalculatorTool:
    """Checks for CalculatorTool evaluation, input rejection and schema."""

    @staticmethod
    async def _evaluate(expression: str):
        """Run ``expression`` through a fresh CalculatorTool and return the result."""
        return await CalculatorTool().execute(expression=expression)

    @pytest.mark.asyncio
    async def test_valid_expression(self):
        """Integer arithmetic is evaluated with normal operator precedence."""
        outcome = await self._evaluate("2 + 3 * 4")
        assert outcome.success is True
        assert outcome.result == "14"

    @pytest.mark.asyncio
    async def test_float_expression(self):
        """Division produces a result that parses as the expected float."""
        outcome = await self._evaluate("10 / 3")
        assert outcome.success is True
        assert float(outcome.result) == pytest.approx(3.333333, rel=1e-4)

    @pytest.mark.asyncio
    async def test_rejects_import(self):
        """An ``__import__`` call is reported as a failed evaluation."""
        outcome = await self._evaluate("__import__('os').system('ls')")
        assert outcome.success is False
        assert "Could not evaluate" in outcome.result

    @pytest.mark.asyncio
    async def test_rejects_exec(self):
        """An ``exec`` call is reported as a failure."""
        outcome = await self._evaluate("exec('print(1)')")
        assert outcome.success is False

    @pytest.mark.asyncio
    async def test_empty_expression(self):
        """An empty expression string is reported as a failure."""
        outcome = await self._evaluate("")
        assert outcome.success is False

    def test_definition_produces_valid_schema(self):
        """The definition names the tool and requires an ``expression``."""
        spec = CalculatorTool().definition()
        assert spec.name == "calculator"
        schema = spec.parameters
        assert schema["type"] == "object"
        assert "expression" in schema["properties"]
        assert "expression" in schema["required"]


# --- Search tool tests ---


class TestSearchTool:
    """Checks for SearchTool output, driven by a canned MockRetriever."""

    @staticmethod
    def _hit(content: str, source: str, score: float) -> MockSearchResult:
        """Build one canned search hit for the mock retriever."""
        return MockSearchResult(
            chunk=MockChunk(content=content, source=source),
            score=score,
        )

    @pytest.mark.asyncio
    async def test_returns_formatted_results(self):
        """Hits are rendered as numbered entries tagged with their source."""
        hits = [
            self._hit(
                "Path parameters are defined using curly braces.",
                "fastapi_path_params.md",
                0.95,
            ),
            self._hit(
                "Query parameters are automatically parsed.",
                "fastapi_query_params.md",
                0.82,
            ),
        ]
        search_tool = SearchTool(retriever=MockRetriever(results=hits))
        outcome = await search_tool.execute(query="path parameters")

        assert outcome.success is True
        assert "[1] (fastapi_path_params.md):" in outcome.result
        assert "[2] (fastapi_query_params.md):" in outcome.result
        expected_sources = [
            "fastapi_path_params.md",
            "fastapi_query_params.md",
        ]
        assert outcome.metadata["sources"] == expected_sources

    @pytest.mark.asyncio
    async def test_empty_results(self):
        """No hits still succeeds, with a 'nothing found' message."""
        search_tool = SearchTool(retriever=MockRetriever(results=[]))
        outcome = await search_tool.execute(query="nonexistent topic")
        assert outcome.success is True
        assert "No relevant documents found" in outcome.result
        assert outcome.metadata["sources"] == []

    @pytest.mark.asyncio
    async def test_deduplicates_sources(self):
        """Two chunks from one file are listed as a single source."""
        hits = [
            self._hit("Chunk 1", "same_file.md", 0.9),
            self._hit("Chunk 2", "same_file.md", 0.8),
        ]
        search_tool = SearchTool(retriever=MockRetriever(results=hits))
        outcome = await search_tool.execute(query="test")
        assert outcome.metadata["sources"] == ["same_file.md"]

    @pytest.mark.asyncio
    async def test_malformed_top_k_falls_back_to_default(self):
        """Model-generated bad top_k (e.g. 'five') should not crash."""
        hits = [self._hit("Some content", "test.md", 0.9)]
        search_tool = SearchTool(retriever=MockRetriever(results=hits))
        outcome = await search_tool.execute(query="test", top_k="five")
        assert outcome.success is True

    @pytest.mark.asyncio
    async def test_empty_query(self):
        """An empty query string is reported as a failure."""
        search_tool = SearchTool(retriever=MockRetriever())
        outcome = await search_tool.execute(query="")
        assert outcome.success is False

    def test_definition_produces_valid_schema(self):
        """The definition names the tool and requires a ``query``."""
        spec = SearchTool(retriever=MockRetriever()).definition()
        assert spec.name == "search_documents"
        schema = spec.parameters
        assert schema["type"] == "object"
        assert "query" in schema["properties"]
        assert "query" in schema["required"]


class TestSearchToolSpecSnapshot:
    """Pin the exact contract SearchTool presents to the LLM.

    The tool name, description and parameter schema are compared against
    frozen values so that any change reaching the LLM is noticed. Callers
    depend on this contract: exposing internal SearchTool state (for
    example a Fix-2-style expansion flag) as an LLM-visible parameter
    would break the "one LLM-facing tool call per execute() invocation"
    iteration-budget guarantee. A drift in any of the three makes these
    tests fail.
    """

    def test_tool_name(self):
        """The class-level tool name is unchanged."""
        expected_name = "search_documents"
        assert SearchTool.name == expected_name

    def test_tool_description(self):
        """The class-level description matches the frozen text exactly."""
        expected_description = (
            "Search the technical documentation corpus for relevant passages. "
            "Returns the most relevant document chunks with source attribution."
        )
        assert SearchTool.description == expected_description

    def test_tool_parameters_schema(self):
        """The schema exposes only ``query`` and ``top_k``; ``query`` is required."""
        schema = SearchTool.parameters
        assert schema["type"] == "object"
        properties = schema["properties"]
        assert set(properties.keys()) == {"query", "top_k"}
        assert properties["query"]["type"] == "string"
        assert properties["top_k"]["type"] == "integer"
        assert schema["required"] == ["query"]


# --- Refusal gate tests ---


class TestRefusalGate:
    """Checks that SearchTool refuses when retrieval scores miss the threshold."""

    def _make_results(self, scores: list[float]) -> list[MockSearchResult]:
        """Build one numbered mock hit per score, in the order given."""
        built = []
        for idx, score in enumerate(scores):
            chunk = MockChunk(content=f"Content {idx}", source=f"doc_{idx}.md")
            built.append(MockSearchResult(chunk=chunk, score=score))
        return built

    def _build_tool(self, scores: list[float], threshold: float) -> SearchTool:
        """Create a SearchTool over canned hits with the given refusal threshold."""
        canned = self._make_results(scores)
        return SearchTool(
            retriever=MockRetriever(results=canned),
            refusal_threshold=threshold,
        )

    @pytest.mark.asyncio
    async def test_refusal_out_of_scope(self):
        """Scores entirely below the threshold produce a refusal."""
        gate = self._build_tool([0.005, 0.003], 0.02)
        outcome = await gate.execute(query="how to cook pasta")

        assert outcome.success is True
        assert "No relevant documents found" in outcome.result
        assert outcome.metadata["refused"] is True
        assert outcome.metadata["sources"] == []

    @pytest.mark.asyncio
    async def test_no_refusal_in_scope(self):
        """Scores above the threshold are answered normally."""
        gate = self._build_tool([0.03, 0.025], 0.02)
        outcome = await gate.execute(query="FastAPI authentication")

        assert outcome.success is True
        assert "No relevant documents found" not in outcome.result
        assert outcome.metadata.get("refused") is None
        assert len(outcome.metadata["sources"]) == 2

    @pytest.mark.asyncio
    async def test_refusal_metadata(self):
        """A refusal reports the top score and sets refused=True."""
        gate = self._build_tool([0.01, 0.008], 0.02)
        outcome = await gate.execute(query="unrelated topic")

        assert outcome.metadata["max_score"] == 0.01
        assert outcome.metadata["refused"] is True

    @pytest.mark.asyncio
    async def test_threshold_zero_disables(self):
        """threshold=0.0 (default) never refuses, preserving V1 behavior."""
        gate = self._build_tool([0.001], 0.0)
        outcome = await gate.execute(query="anything")

        assert outcome.success is True
        assert "No relevant documents found" not in outcome.result
        assert outcome.metadata.get("refused") is None

    @pytest.mark.asyncio
    async def test_threshold_configurable(self):
        """The same hits pass or are refused depending on the threshold."""
        canned = self._make_results([0.015, 0.012])
        shared_retriever = MockRetriever(results=canned)

        # Threshold 0.01 sits below the top score of 0.015, so no refusal.
        lenient = SearchTool(retriever=shared_retriever, refusal_threshold=0.01)
        lenient_outcome = await lenient.execute(query="test")
        assert lenient_outcome.metadata.get("refused") is None

        # Threshold 0.02 sits above the top score of 0.015, so it refuses.
        strict = SearchTool(retriever=shared_retriever, refusal_threshold=0.02)
        strict_outcome = await strict.execute(query="test")
        assert strict_outcome.metadata["refused"] is True

    @pytest.mark.asyncio
    async def test_max_score_in_normal_metadata(self):
        """A non-refused response also carries max_score in its metadata."""
        gate = self._build_tool([0.05, 0.03], 0.02)
        outcome = await gate.execute(query="test")

        assert outcome.metadata["max_score"] == 0.05