Spaces:
Running
Running
| """Tests for the agent orchestrator.""" | |
| from __future__ import annotations | |
| import pytest | |
| from agent_bench.agents.orchestrator import AgentResponse, Orchestrator | |
| from agent_bench.core.provider import MockProvider | |
| from agent_bench.core.types import ( | |
| CompletionResponse, | |
| Role, | |
| TokenUsage, | |
| ToolCall, | |
| ) | |
| from agent_bench.tools.base import Tool, ToolOutput | |
| from agent_bench.tools.calculator import CalculatorTool | |
| from agent_bench.tools.registry import ToolRegistry | |
| # --- Helpers --- | |
class FakeSearchTool(Tool):
    """Canned stand-in for the document search tool used by these tests.

    Every call reports the same single hit, so orchestrator behaviour can be
    asserted without a real retriever.
    """

    name = "search_documents"
    description = "Search docs"
    parameters = {
        "type": "object",
        "properties": {"query": {"type": "string"}},
        "required": ["query"],
    }

    async def execute(self, **kwargs: object) -> ToolOutput:
        """Ignore the keyword arguments and return one fixed successful result."""
        doc_name = "fastapi_path_params.md"
        snippet = "Path parameters use curly braces."
        top_score = 0.85

        # Metadata keys are kept in the same order as the tool has always
        # emitted them.
        details = {
            "sources": [doc_name],
            "ranked_sources": [doc_name],
            "source_chunks": [snippet],
            "max_score": top_score,
            "pre_rerank_count": 10,
            "chunks": [{"source": doc_name, "score": top_score, "preview": snippet}],
            "pii_redactions_count": 0,
        }
        return ToolOutput(
            success=True,
            result="[1] (fastapi_path_params.md): Path parameters use curly braces.",
            metadata=details,
        )
class AlwaysToolCallProvider(MockProvider):
    """Mock provider that requests a search on every call where tools are passed.

    It only produces text when ``tools`` is None, which the tests below treat
    as the forced final call after the iteration limit.
    """

    async def complete(self, messages, tools=None, temperature=0.0, max_tokens=1024):
        """Count the call, then return either a search tool call or the forced text."""
        self.call_count += 1

        if tools is not None:
            # Tools were passed: keep asking for another search.
            text = ""
            requested = [
                ToolCall(
                    id=f"call_{self.call_count}",
                    name="search_documents",
                    arguments={"query": "test"},
                )
            ]
        else:
            # No tools passed: answer in plain text.
            text = "Forced answer after max iterations."
            requested = []

        return CompletionResponse(
            content=text,
            tool_calls=requested,
            usage=TokenUsage(input_tokens=100, output_tokens=20, estimated_cost_usd=0.0001),
            provider="mock",
            model="mock-1",
            latency_ms=1.0,
        )
class MultiSearchProvider(MockProvider):
    """Mock provider that issues two distinct searches and then a final answer."""

    async def complete(self, messages, tools=None, temperature=0.0, max_tokens=1024):
        """Pick the next scripted step from the number of TOOL messages seen so far."""
        self.call_count += 1
        prior_results = sum(1 for msg in messages if msg.role == Role.TOOL)

        # prior tool results -> (call id, query, input tokens, output tokens, cost)
        script = {
            0: ("call_1", "first", 100, 20, 0.0001),
            1: ("call_2", "second", 150, 25, 0.0002),
        }

        if tools and prior_results in script:
            call_id, query, tokens_in, tokens_out, cost = script[prior_results]
            return CompletionResponse(
                content="",
                tool_calls=[
                    ToolCall(id=call_id, name="search_documents", arguments={"query": query})
                ],
                usage=TokenUsage(
                    input_tokens=tokens_in,
                    output_tokens=tokens_out,
                    estimated_cost_usd=cost,
                ),
                provider="mock",
                model="mock-1",
                latency_ms=1.0,
            )

        # Either no tools were passed or both searches already ran: answer.
        return CompletionResponse(
            content="Answer from two searches. [source: fastapi_path_params.md]",
            tool_calls=[],
            usage=TokenUsage(input_tokens=200, output_tokens=50, estimated_cost_usd=0.0003),
            provider="mock",
            model="mock-1",
            latency_ms=2.0,
        )
def _make_registry() -> ToolRegistry:
    """Return a fresh registry holding the fake search tool and the calculator."""
    tools = ToolRegistry()
    # Instantiate and register one at a time, search tool first.
    for tool_cls in (FakeSearchTool, CalculatorTool):
        tools.register(tool_cls())
    return tools
| SYSTEM_PROMPT = "You are a helpful assistant." | |
| # --- Tests --- | |
class TestOrchestrator:
    """Orchestrator tests driven by mock providers and the fake search tool."""

    async def test_produces_agent_response_with_all_fields(self):
        """A MockProvider run yields an AgentResponse with every field populated."""
        agent = Orchestrator(
            provider=MockProvider(),
            registry=_make_registry(),
            max_iterations=3,
        )

        result = await agent.run("How do path params work?", SYSTEM_PROMPT)

        assert isinstance(result, AgentResponse)
        assert len(result.answer) > 0
        assert result.iterations >= 1
        assert result.usage.input_tokens > 0
        assert result.usage.output_tokens > 0
        assert result.latency_ms > 0
        assert isinstance(result.sources, list)
        assert isinstance(result.tools_used, list)

    async def test_respects_max_iterations(self):
        """A provider that never stops calling tools is cut off at max_iterations."""
        looping_provider = AlwaysToolCallProvider()
        agent = Orchestrator(
            provider=looping_provider,
            registry=_make_registry(),
            max_iterations=2,
        )

        result = await agent.run("test question", SYSTEM_PROMPT)

        # Two tool-calling rounds plus one forced final call: three provider calls.
        assert looping_provider.call_count == 3
        assert result.iterations == 2
        assert result.answer == "Forced answer after max iterations."

    async def test_accumulates_sources_from_multiple_searches(self):
        """Two searches hitting the same document give one source and summed usage."""
        agent = Orchestrator(
            provider=MultiSearchProvider(),
            registry=_make_registry(),
            max_iterations=3,
        )

        result = await agent.run("multi search question", SYSTEM_PROMPT)

        # Both searches return fastapi_path_params.md, so a single
        # deduplicated source is expected.
        assert len(result.sources) == 1
        assert result.sources[0].source == "fastapi_path_params.md"
        assert result.tools_used.count("search_documents") == 2

        # Usage is the total over the three provider calls.
        assert result.usage.input_tokens == 100 + 150 + 200
        assert result.usage.output_tokens == 20 + 25 + 50

    async def test_deterministic_output(self):
        """A fixed question with MockProvider gives the canned answer in two iterations."""
        agent = Orchestrator(
            provider=MockProvider(),
            registry=_make_registry(),
            max_iterations=3,
        )

        result = await agent.run("How do path params work?", SYSTEM_PROMPT)

        # MockProvider asks for search_documents on its first call and
        # returns its canned answer once tool results are present.
        assert "path parameters" in result.answer.lower()
        assert "[source: fastapi_path_params.md]" in result.answer
        assert "search_documents" in result.tools_used
        assert result.iterations == 2
class TestOrchestratorIntegration:
    """Integration coverage: MockProvider paired with the real SearchTool."""

    async def test_real_rag_path(self, test_retriever):
        """The real SearchTool, backed by the retriever fixture, yields sources."""
        from agent_bench.tools.search import SearchTool

        tools = ToolRegistry()
        tools.register(SearchTool(retriever=test_retriever))
        tools.register(CalculatorTool())
        agent = Orchestrator(provider=MockProvider(), registry=tools, max_iterations=3)

        result = await agent.run(
            "How do path params work?",
            SYSTEM_PROMPT,
            top_k=3,
            strategy="hybrid",
        )

        # The mock provider drives the loop; the search itself is executed
        # by the real tool against the fixture-provided retriever.
        assert isinstance(result, AgentResponse)
        assert len(result.answer) > 0
        assert "search_documents" in result.tools_used

        # Sources come from the real store (sample_chunks in conftest).
        assert len(result.sources) > 0