File size: 8,572 Bytes
447e09c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d16fd9
 
 
 
 
 
 
 
 
 
447e09c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fc43ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""Tests for the agent orchestrator."""

from __future__ import annotations

import pytest

from agent_bench.agents.orchestrator import AgentResponse, Orchestrator
from agent_bench.core.provider import MockProvider
from agent_bench.core.types import (
    CompletionResponse,
    Role,
    TokenUsage,
    ToolCall,
)
from agent_bench.tools.base import Tool, ToolOutput
from agent_bench.tools.calculator import CalculatorTool
from agent_bench.tools.registry import ToolRegistry

# --- Helpers ---


class FakeSearchTool(Tool):
    """Stand-in search tool that always reports the same single hit.

    Arguments passed to ``execute`` are ignored; every call returns one chunk
    from ``fastapi_path_params.md`` so orchestrator tests see stable sources.
    """

    name = "search_documents"
    description = "Search docs"
    parameters = {
        "type": "object",
        "properties": {"query": {"type": "string"}},
        "required": ["query"],
    }

    async def execute(self, **kwargs: object) -> ToolOutput:
        """Return the canned search hit, ignoring ``kwargs``."""
        doc = "fastapi_path_params.md"
        snippet = "Path parameters use curly braces."
        top_score = 0.85

        # Metadata mirrors what the real search tool reports, with a single chunk.
        meta = {
            "sources": [doc],
            "ranked_sources": [doc],
            "source_chunks": [snippet],
            "max_score": top_score,
            "pre_rerank_count": 10,
            "chunks": [{"source": doc, "score": top_score, "preview": snippet}],
            "pii_redactions_count": 0,
        }
        return ToolOutput(
            success=True,
            result="[1] (fastapi_path_params.md): Path parameters use curly braces.",
            metadata=meta,
        )


class AlwaysToolCallProvider(MockProvider):
    """Mock provider that never volunteers a final answer on its own.

    Whenever tools are offered it requests ``search_documents`` again; only a
    call made with ``tools=None`` (the forced final call) returns text.
    """

    async def complete(self, messages, tools=None, temperature=0.0, max_tokens=1024):
        self.call_count += 1

        if tools is None:
            # Forced final call (no tools) — answer with plain text.
            text = "Forced answer after max iterations."
            calls = []
        else:
            text = ""
            calls = [
                ToolCall(
                    id=f"call_{self.call_count}",
                    name="search_documents",
                    arguments={"query": "test"},
                )
            ]

        return CompletionResponse(
            content=text,
            tool_calls=calls,
            usage=TokenUsage(input_tokens=100, output_tokens=20, estimated_cost_usd=0.0001),
            provider="mock",
            model="mock-1",
            latency_ms=1.0,
        )


class MultiSearchProvider(MockProvider):
    """Mock provider scripted to search twice before answering.

    The reply depends on how many tool-result messages are already in the
    conversation: none -> search "first", one -> search "second", anything
    else (or no tools offered) -> the final cited answer.
    """

    async def complete(self, messages, tools=None, temperature=0.0, max_tokens=1024):
        self.call_count += 1
        seen = sum(1 for msg in messages if msg.role == Role.TOOL)

        # (call id, query, input tokens, output tokens, cost) for each search turn.
        search_steps = {
            0: ("call_1", "first", 100, 20, 0.0001),
            1: ("call_2", "second", 150, 25, 0.0002),
        }
        step = search_steps.get(seen) if tools else None

        if step is None:
            return CompletionResponse(
                content="Answer from two searches. [source: fastapi_path_params.md]",
                tool_calls=[],
                usage=TokenUsage(input_tokens=200, output_tokens=50, estimated_cost_usd=0.0003),
                provider="mock",
                model="mock-1",
                latency_ms=2.0,
            )

        call_id, query, tokens_in, tokens_out, cost = step
        return CompletionResponse(
            content="",
            tool_calls=[
                ToolCall(id=call_id, name="search_documents", arguments={"query": query})
            ],
            usage=TokenUsage(
                input_tokens=tokens_in, output_tokens=tokens_out, estimated_cost_usd=cost
            ),
            provider="mock",
            model="mock-1",
            latency_ms=1.0,
        )


def _make_registry() -> ToolRegistry:
    """Build a registry holding the fake search tool and the calculator."""
    reg = ToolRegistry()
    # Instantiate and register one at a time, search tool first.
    for tool_cls in (FakeSearchTool, CalculatorTool):
        reg.register(tool_cls())
    return reg


# System prompt passed to every Orchestrator.run call in this module.
SYSTEM_PROMPT = "You are a helpful assistant."


# --- Tests ---


class TestOrchestrator:
    """Unit tests for Orchestrator driven by scripted mock providers.

    Every test builds its registry with ``_make_registry`` (FakeSearchTool +
    CalculatorTool), so search results are fixed and source assertions are exact.
    """

    @pytest.mark.asyncio
    async def test_produces_agent_response_with_all_fields(self):
        """Orchestrator returns AgentResponse with all required fields."""
        orchestrator = Orchestrator(
            provider=MockProvider(), registry=_make_registry(), max_iterations=3
        )
        response = await orchestrator.run("How do path params work?", SYSTEM_PROMPT)

        assert isinstance(response, AgentResponse)
        assert len(response.answer) > 0
        assert response.iterations >= 1
        assert response.usage.input_tokens > 0
        assert response.usage.output_tokens > 0
        assert response.latency_ms > 0
        assert isinstance(response.sources, list)
        assert isinstance(response.tools_used, list)

    @pytest.mark.asyncio
    async def test_respects_max_iterations(self):
        """When provider always returns tool_calls, orchestrator stops at max_iterations."""
        provider = AlwaysToolCallProvider()
        orchestrator = Orchestrator(provider=provider, registry=_make_registry(), max_iterations=2)
        response = await orchestrator.run("test question", SYSTEM_PROMPT)

        # 2 iterations of tool calls + 1 forced final call = 3 provider calls total
        assert provider.call_count == 3
        assert response.iterations == 2
        assert response.answer == "Forced answer after max iterations."

    @pytest.mark.asyncio
    async def test_accumulates_sources_from_multiple_searches(self):
        """Sources from multiple search calls are accumulated and deduplicated."""
        provider = MultiSearchProvider()
        orchestrator = Orchestrator(provider=provider, registry=_make_registry(), max_iterations=3)
        response = await orchestrator.run("multi search question", SYSTEM_PROMPT)

        # FakeSearchTool always returns fastapi_path_params.md
        assert len(response.sources) == 1  # deduplicated
        assert response.sources[0].source == "fastapi_path_params.md"
        assert response.tools_used.count("search_documents") == 2
        # Two search turns + one answer turn = 3 provider calls. Pin the count
        # so the token totals below are known to cover exactly those calls.
        assert provider.call_count == 3
        assert response.usage.input_tokens == 100 + 150 + 200
        assert response.usage.output_tokens == 20 + 25 + 50

    @pytest.mark.asyncio
    async def test_deterministic_output(self):
        """Fixed question + MockProvider → one search, then a cited path-params answer."""
        orchestrator = Orchestrator(
            provider=MockProvider(), registry=_make_registry(), max_iterations=3
        )
        response = await orchestrator.run("How do path params work?", SYSTEM_PROMPT)

        # MockProvider: first call returns tool_calls for search_documents,
        # second call (with tool results) returns the canned answer
        assert "path parameters" in response.answer.lower()
        assert "[source: fastapi_path_params.md]" in response.answer
        assert "search_documents" in response.tools_used
        assert response.iterations == 2


class TestOrchestratorIntegration:
    """End-to-end check: real SearchTool over the Retriever/HybridStore fixture."""

    @pytest.mark.asyncio
    async def test_real_rag_path(self, test_retriever):
        """MockProvider drives the loop while a real SearchTool returns RAG results."""
        from agent_bench.tools.search import SearchTool

        tools = ToolRegistry()
        tools.register(SearchTool(retriever=test_retriever))
        tools.register(CalculatorTool())
        agent = Orchestrator(provider=MockProvider(), registry=tools, max_iterations=3)

        result = await agent.run(
            "How do path params work?",
            SYSTEM_PROMPT,
            top_k=3,
            strategy="hybrid",
        )

        # Only the provider is mocked: the SearchTool queries the real
        # Retriever/HybridStore, so the sources come from the fixture's chunks.
        assert isinstance(result, AgentResponse)
        assert len(result.answer) > 0
        assert "search_documents" in result.tools_used
        assert len(result.sources) > 0