Nomearod Claude Opus 4.6 (1M context) committed on
Commit
36a9ab7
·
1 Parent(s): 9d976db

feat: Day 2 — tool system with registry, calculator, and search

Browse files

- Tool ABC with definition() → ToolDefinition for provider format_tools()
- ToolRegistry: dict-based register/get/execute/get_definitions
- CalculatorTool: simpleeval with try/except safety wrapper
- SearchTool: numbered passage format [1] (filename.md): content
with source deduplication in metadata
- SearchTool uses Protocol for retriever dependency (decoupled from rag/)
- 15 new tests: registry CRUD, unknown tool handling, calculator
valid/invalid/dangerous expressions, search formatting/empty/dedup

38 total tests, lint + mypy clean.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

agent_bench/tools/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Tool system: base interface, registry, and built-in tools."""
agent_bench/tools/base.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Base tool interface and output model."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+ from agent_bench.core.types import ToolDefinition
10
+
11
+
12
class ToolOutput(BaseModel):
    """Result envelope returned by every tool invocation.

    Carries a success flag, the textual result, and an optional free-form
    metadata mapping that defaults to a fresh empty dict per instance.
    """

    # True when the tool completed its work, False on any failure path.
    success: bool
    # Text payload: the answer on success, an explanation on failure.
    result: str
    # Extra structured data; default_factory avoids a shared mutable default.
    metadata: dict = Field(default_factory=dict)
16
+
17
+
18
class Tool(ABC):
    """Common interface implemented by every tool the agent may call.

    Subclasses supply the three descriptive attributes below and implement
    the asynchronous ``execute`` coroutine.
    """

    # Identifier the tool is registered and dispatched under.
    name: str
    # Human-readable explanation of what the tool does.
    description: str
    # JSON Schema describing the keyword arguments accepted by ``execute``.
    parameters: dict

    @abstractmethod
    async def execute(self, **kwargs: object) -> ToolOutput:
        """Run the tool with the supplied keyword arguments.

        Returns:
            A ``ToolOutput`` describing the outcome of the call.
        """
        ...

    def definition(self) -> ToolDefinition:
        """Build the ``ToolDefinition`` consumed by provider format_tools().

        Returns:
            A definition populated from this tool's ``name``,
            ``description`` and ``parameters`` attributes.
        """
        spec = ToolDefinition(
            name=self.name,
            description=self.description,
            parameters=self.parameters,
        )
        return spec
agent_bench/tools/calculator.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Calculator tool: safe math expression evaluation via simpleeval."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from simpleeval import simple_eval
6
+
7
+ from agent_bench.tools.base import Tool, ToolOutput
8
+
9
+
10
class CalculatorTool(Tool):
    """Evaluate mathematical expressions with ``simpleeval``.

    Evaluation failures are reported as an unsuccessful ``ToolOutput``
    instead of being raised to the caller.
    """

    name = "calculator"
    description = "Evaluate a mathematical expression. Supports +, -, *, /, **, %, and parentheses."
    parameters = {
        "type": "object",
        "properties": {
            "expression": {
                "type": "string",
                "description": "The mathematical expression to evaluate, e.g. '2 + 3 * 4'",
            },
        },
        "required": ["expression"],
    }

    async def execute(self, **kwargs: object) -> ToolOutput:
        """Evaluate the ``expression`` keyword argument.

        Args:
            **kwargs: Must contain ``expression``; any non-string value is
                converted with ``str``. A missing, ``None``, empty or
                whitespace-only expression is rejected.

        Returns:
            ``ToolOutput(success=True)`` with the stringified value, or
            ``ToolOutput(success=False)`` with an explanatory message. On an
            evaluation failure the exception type and text are recorded
            under ``metadata["error"]``.
        """
        raw = kwargs.get("expression")
        # An explicit None must count as "missing": str(None) is "None",
        # which simpleeval resolves as a name and would report as success.
        expression = "" if raw is None else str(raw).strip()
        if not expression:
            return ToolOutput(success=False, result="No expression provided")

        try:
            value = simple_eval(expression)
        except Exception as exc:
            # Tool boundary: the evaluator can fail in many ways (syntax
            # errors, disallowed names, arithmetic errors). Report failure
            # rather than crash, but keep the reason for debugging.
            return ToolOutput(
                success=False,
                result=f"Could not evaluate: {expression}",
                metadata={"error": f"{type(exc).__name__}: {exc}"},
            )
        return ToolOutput(success=True, result=str(value))
agent_bench/tools/registry.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tool registry: register, retrieve, and dispatch tools by name."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from agent_bench.core.types import ToolDefinition
6
+ from agent_bench.tools.base import Tool, ToolOutput
7
+
8
+
9
class ToolRegistry:
    """Name-keyed collection of tools with lookup and dispatch helpers."""

    def __init__(self) -> None:
        """Start with no tools registered."""
        self._tools: dict[str, Tool] = {}

    def register(self, tool: Tool) -> None:
        """Store ``tool`` under ``tool.name``.

        Registering a second tool with the same name replaces the first.
        """
        key = tool.name
        self._tools[key] = tool

    def get(self, name: str) -> Tool | None:
        """Return the tool registered as ``name``, or ``None`` if absent."""
        try:
            return self._tools[name]
        except KeyError:
            return None

    def get_definitions(self) -> list[ToolDefinition]:
        """Collect ``definition()`` from every tool, in registration order."""
        registered = self._tools.values()
        return [entry.definition() for entry in registered]

    async def execute(self, name: str, **kwargs: object) -> ToolOutput:
        """Dispatch ``kwargs`` to the tool registered as ``name``.

        Returns:
            The tool's own output, or an unsuccessful ``ToolOutput`` naming
            the tool when nothing is registered under ``name``.
        """
        handler = self._tools.get(name)
        if handler is not None:
            return await handler.execute(**kwargs)
        return ToolOutput(success=False, result=f"Unknown tool: {name}")
agent_bench/tools/search.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Search tool: RAG retrieval over the document corpus."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Protocol
6
+
7
+ from agent_bench.tools.base import Tool, ToolOutput
8
+
9
+
10
class SearchResult(Protocol):
    """Structural type for one retriever hit (defined fully in rag.store)."""

    @property
    def chunk(self) -> object:
        """The retrieved chunk object."""
        ...

    @property
    def score(self) -> float:
        """Score the retriever assigned to this hit."""
        ...
18
+
19
+
20
class Retriever(Protocol):
    """Structural type for the retriever dependency (defined fully in rag.retriever)."""

    async def search(self, query: str, top_k: int = 5) -> list:
        """Return a list of hits for ``query``; ``top_k`` defaults to 5."""
        ...
24
+
25
+
26
class SearchTool(Tool):
    """Search the document corpus and return relevant passages.

    Delegates retrieval to an injected ``Retriever`` and formats the hits
    as numbered passages with their source attached.
    """

    name = "search_documents"
    description = (
        "Search the technical documentation corpus for relevant passages. "
        "Returns the most relevant document chunks with source attribution."
    )
    parameters = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "The search query to find relevant documentation",
            },
            "top_k": {
                "type": "integer",
                "description": "Number of results to return (default 5)",
            },
        },
        "required": ["query"],
    }

    # Used when the caller omits ``top_k`` or passes it as None.
    _DEFAULT_TOP_K = 5

    def __init__(self, retriever: Retriever) -> None:
        """Keep a reference to the retriever used by ``execute``."""
        self._retriever = retriever

    @staticmethod
    def _coerce_top_k(value: object) -> int | None:
        """Interpret ``value`` as a positive integer, or return ``None``.

        Accepts ints, floats with an integral value (e.g. ``5.0``) and
        strings that parse as integers. Booleans, non-integral floats,
        unparsable text, and values below 1 all yield ``None``.
        """
        # bool is a subclass of int; True/False are not meaningful counts.
        if isinstance(value, bool):
            return None
        if isinstance(value, int):
            count = value
        elif isinstance(value, float):
            # is_integer() is False for inf/nan as well as fractional values.
            if not value.is_integer():
                return None
            count = int(value)
        else:
            try:
                count = int(str(value).strip())
            except ValueError:
                return None
        return count if count > 0 else None

    async def execute(self, **kwargs: object) -> ToolOutput:
        """Search for ``query`` and format the hits.

        Args:
            **kwargs: ``query`` (required) is the search text. ``top_k``
                (optional) is the number of results to request; when
                omitted or ``None`` it defaults to 5.

        Returns:
            An unsuccessful ``ToolOutput`` when the query is missing or
            blank, or when ``top_k`` cannot be read as a positive integer.
            Otherwise a successful ``ToolOutput`` whose ``result`` holds the
            passages as ``[n] (source): content`` separated by blank lines,
            and whose ``metadata["sources"]`` lists each distinct source
            once, in first-seen order.
        """
        raw_query = kwargs.get("query")
        # Treat an explicit None like a missing query instead of searching
        # for the literal text "None".
        query = "" if raw_query is None else str(raw_query)
        if not query.strip():
            return ToolOutput(success=False, result="No query provided")

        raw_top_k = kwargs.get("top_k")
        if raw_top_k is None:
            top_k: int | None = self._DEFAULT_TOP_K
        else:
            top_k = self._coerce_top_k(raw_top_k)
        if top_k is None:
            # Report bad arguments as a tool failure rather than raising.
            return ToolOutput(
                success=False,
                result=f"Invalid top_k: {raw_top_k!r} (expected a positive integer)",
            )

        results = await self._retriever.search(query, top_k=top_k)

        if not results:
            return ToolOutput(
                success=True,
                result="No relevant documents found.",
                metadata={"sources": []},
            )

        # Format as numbered passages with filename attribution.
        passages = [
            f"[{index}] ({hit.chunk.source}): {hit.chunk.content}"
            for index, hit in enumerate(results, 1)
        ]
        # dict.fromkeys drops repeats while preserving first-seen order.
        sources = list(dict.fromkeys(hit.chunk.source for hit in results))

        return ToolOutput(
            success=True,
            result="\n\n".join(passages),
            metadata={"sources": sources},
        )
tests/test_tools.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for tool system: registry, calculator, search, and schema generation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+
7
+ import pytest
8
+
9
+ from agent_bench.tools.calculator import CalculatorTool
10
+ from agent_bench.tools.registry import ToolRegistry
11
+ from agent_bench.tools.search import SearchTool
12
+
13
+ # --- Mock retriever for SearchTool tests ---
14
+
15
+
16
@dataclass
class MockChunk:
    """Minimal chunk stand-in exposing the two fields SearchTool reads."""

    # Passage text.
    content: str
    # Name of the document the passage came from.
    source: str
20
+
21
+
22
@dataclass
class MockSearchResult:
    """Minimal search-hit stand-in pairing a chunk with its score."""

    # The chunk this hit refers to.
    chunk: MockChunk
    # Score attached to the hit.
    score: float
26
+
27
+
28
class MockRetriever:
    """In-memory retriever double that serves a fixed list of results."""

    def __init__(self, results: list[MockSearchResult] | None = None) -> None:
        """Remember ``results``; ``None`` or an empty list means no hits."""
        self._results = results if results else []

    async def search(self, query: str, top_k: int = 5) -> list[MockSearchResult]:
        """Return the first ``top_k`` canned results; ``query`` is ignored."""
        canned = self._results
        return canned[:top_k]
36
+
37
+
38
+ # --- Registry tests ---
39
+
40
+
41
class TestToolRegistry:
    """Registry behaviour: registration, lookup, dispatch, definitions."""

    def test_register_and_retrieve(self):
        """A registered tool is returned by identity under its name."""
        calc = CalculatorTool()
        reg = ToolRegistry()
        reg.register(calc)
        assert reg.get("calculator") is calc

    def test_get_unknown_returns_none(self):
        """Looking up an unregistered name yields None."""
        assert ToolRegistry().get("nonexistent") is None

    @pytest.mark.asyncio
    async def test_execute_unknown_tool(self):
        """Dispatching to an unregistered name yields a failure output."""
        reg = ToolRegistry()
        outcome = await reg.execute("nonexistent", query="test")
        assert outcome.success is False
        assert "Unknown tool: nonexistent" in outcome.result

    def test_get_definitions(self):
        """Definitions are produced for every registered tool."""
        reg = ToolRegistry()
        for tool in (CalculatorTool(), SearchTool(retriever=MockRetriever())):
            reg.register(tool)
        definitions = reg.get_definitions()
        assert len(definitions) == 2
        assert {d.name for d in definitions} == {"calculator", "search_documents"}
67
+
68
+
69
+ # --- Calculator tests ---
70
+
71
+
72
class TestCalculatorTool:
    """CalculatorTool behaviour for valid, invalid and hostile input."""

    @pytest.mark.asyncio
    async def test_valid_expression(self):
        """Operator precedence is respected for integer arithmetic."""
        outcome = await CalculatorTool().execute(expression="2 + 3 * 4")
        assert outcome.success is True
        assert outcome.result == "14"

    @pytest.mark.asyncio
    async def test_float_expression(self):
        """True division produces a float result."""
        outcome = await CalculatorTool().execute(expression="10 / 3")
        assert outcome.success is True
        assert float(outcome.result) == pytest.approx(3.333333, rel=1e-4)

    @pytest.mark.asyncio
    async def test_rejects_import(self):
        """An __import__ call is refused with the standard failure text."""
        hostile = "__import__('os').system('ls')"
        outcome = await CalculatorTool().execute(expression=hostile)
        assert outcome.success is False
        assert "Could not evaluate" in outcome.result

    @pytest.mark.asyncio
    async def test_rejects_exec(self):
        """An exec call is refused."""
        outcome = await CalculatorTool().execute(expression="exec('print(1)')")
        assert outcome.success is False

    @pytest.mark.asyncio
    async def test_empty_expression(self):
        """An empty expression is reported as a failure."""
        outcome = await CalculatorTool().execute(expression="")
        assert outcome.success is False

    def test_definition_produces_valid_schema(self):
        """The definition exposes the expected name and JSON Schema."""
        schema = CalculatorTool().definition()
        assert schema.name == "calculator"
        assert schema.parameters["type"] == "object"
        assert "expression" in schema.parameters["properties"]
        assert "expression" in schema.parameters["required"]
113
+
114
+
115
+ # --- Search tool tests ---
116
+
117
+
118
class TestSearchTool:
    """SearchTool behaviour against a canned in-memory retriever."""

    @pytest.mark.asyncio
    async def test_returns_formatted_results(self):
        """Hits are rendered as numbered passages with their sources."""
        path_hit = MockSearchResult(
            chunk=MockChunk(
                content="Path parameters are defined using curly braces.",
                source="fastapi_path_params.md",
            ),
            score=0.95,
        )
        query_hit = MockSearchResult(
            chunk=MockChunk(
                content="Query parameters are automatically parsed.",
                source="fastapi_query_params.md",
            ),
            score=0.82,
        )
        search = SearchTool(retriever=MockRetriever(results=[path_hit, query_hit]))
        output = await search.execute(query="path parameters")

        assert output.success is True
        for marker in (
            "[1] (fastapi_path_params.md):",
            "[2] (fastapi_query_params.md):",
        ):
            assert marker in output.result
        assert output.metadata["sources"] == [
            "fastapi_path_params.md",
            "fastapi_query_params.md",
        ]

    @pytest.mark.asyncio
    async def test_empty_results(self):
        """No hits is still a success, with an explanatory message."""
        search = SearchTool(retriever=MockRetriever(results=[]))
        output = await search.execute(query="nonexistent topic")
        assert output.success is True
        assert "No relevant documents found" in output.result
        assert output.metadata["sources"] == []

    @pytest.mark.asyncio
    async def test_deduplicates_sources(self):
        """Two chunks from one file yield a single source entry."""
        hits = [
            MockSearchResult(
                chunk=MockChunk(content="Chunk 1", source="same_file.md"),
                score=0.9,
            ),
            MockSearchResult(
                chunk=MockChunk(content="Chunk 2", source="same_file.md"),
                score=0.8,
            ),
        ]
        search = SearchTool(retriever=MockRetriever(results=hits))
        output = await search.execute(query="test")
        assert output.metadata["sources"] == ["same_file.md"]

    @pytest.mark.asyncio
    async def test_empty_query(self):
        """An empty query is reported as a failure."""
        search = SearchTool(retriever=MockRetriever())
        output = await search.execute(query="")
        assert output.success is False

    def test_definition_produces_valid_schema(self):
        """The definition exposes the expected name and JSON Schema."""
        schema = SearchTool(retriever=MockRetriever()).definition()
        assert schema.name == "search_documents"
        assert schema.parameters["type"] == "object"
        assert "query" in schema.parameters["properties"]
        assert "query" in schema.parameters["required"]