"""Tests for core types, config, and provider abstraction.""" from unittest.mock import patch import pytest from agent_bench.core.config import ( AppConfig, ProviderConfig, RetryConfig, load_config, load_task_config, ) from agent_bench.core.provider import ( AnthropicProvider, MockProvider, ProviderRateLimitError, create_provider, format_messages_anthropic, format_messages_openai, format_tools_anthropic, format_tools_openai, ) from agent_bench.core.types import ( CompletionResponse, Message, Role, TokenUsage, ToolCall, ToolDefinition, ) # --- Core types --- class TestCoreTypes: def test_message_creation(self): msg = Message(role=Role.USER, content="hello") assert msg.role == Role.USER assert msg.content == "hello" assert msg.tool_call_id is None assert msg.tool_calls is None def test_tool_call_creation(self): tc = ToolCall(id="call_123", name="search", arguments={"query": "test"}) assert tc.id == "call_123" assert tc.name == "search" assert tc.arguments == {"query": "test"} def test_token_usage_creation(self): usage = TokenUsage(input_tokens=100, output_tokens=50, estimated_cost_usd=0.001) assert usage.input_tokens == 100 assert usage.output_tokens == 50 assert usage.estimated_cost_usd == pytest.approx(0.001) def test_completion_response_defaults(self): resp = CompletionResponse( content="answer", usage=TokenUsage(input_tokens=10, output_tokens=5, estimated_cost_usd=0.0), provider="mock", model="mock-1", latency_ms=50.0, ) assert resp.tool_calls == [] assert resp.content == "answer" def test_tool_definition_schema(self): td = ToolDefinition( name="calculator", description="Evaluate math", parameters={ "type": "object", "properties": {"expression": {"type": "string"}}, "required": ["expression"], }, ) assert td.name == "calculator" assert "expression" in td.parameters["properties"] # --- Config --- class TestConfig: def test_load_default_config(self): config = load_config() assert config.provider.default == "openai" assert config.agent.max_iterations == 3 assert config.agent.temperature == 0.0 assert config.rag.chunking.strategy == "recursive" assert config.rag.chunking.chunk_size == 512 assert config.rag.retrieval.rrf_k == 60 assert config.rag.retrieval.top_k == 5 def test_model_pricing_available(self): config = load_config() models = config.provider.models assert "gpt-4o-mini" in models assert models["gpt-4o-mini"].input_cost_per_mtok == 0.15 assert models["gpt-4o-mini"].output_cost_per_mtok == 0.60 def test_cost_calculation(self): config = load_config() model_config = config.provider.models["gpt-4o-mini"] input_tokens = 1000 output_tokens = 500 expected_cost = (1000 * 0.15 + 500 * 0.60) / 1_000_000 cost = ( input_tokens * model_config.input_cost_per_mtok + output_tokens * model_config.output_cost_per_mtok ) / 1_000_000 assert cost == pytest.approx(expected_cost) def test_load_task_config(self): task = load_task_config("tech_docs") assert task.name == "tech_docs" assert "search_documents" in task.system_prompt assert "[source:" in task.system_prompt # --- MockProvider --- class TestMockProvider: @pytest.mark.asyncio async def test_returns_tool_calls_on_first_call(self, mock_provider): messages = [ Message(role=Role.SYSTEM, content="You are helpful."), Message(role=Role.USER, content="Search for FastAPI path params"), ] tools = [ ToolDefinition( name="search_documents", description="Search docs", parameters={"type": "object", "properties": {"query": {"type": "string"}}}, ) ] response = await mock_provider.complete(messages, tools=tools) assert len(response.tool_calls) > 0 assert response.tool_calls[0].name == "search_documents" assert response.provider == "mock" assert response.usage.input_tokens > 0 @pytest.mark.asyncio async def test_returns_final_answer_when_tool_results_present(self, mock_provider): messages = [ Message(role=Role.SYSTEM, content="You are helpful."), Message(role=Role.USER, content="Search for FastAPI path params"), Message( role=Role.ASSISTANT, content="", tool_calls=[ ToolCall( id="call_1", name="search_documents", arguments={"query": "path params"} ) ], ), Message(role=Role.TOOL, content="Path params use curly braces.", tool_call_id="call_1"), ] response = await mock_provider.complete(messages) assert response.tool_calls == [] assert len(response.content) > 0 assert response.usage.input_tokens > 0 @pytest.mark.asyncio async def test_returns_answer_without_tools(self, mock_provider): messages = [ Message(role=Role.SYSTEM, content="You are helpful."), Message(role=Role.USER, content="Hello"), ] response = await mock_provider.complete(messages, tools=None) assert response.tool_calls == [] assert len(response.content) > 0 def test_format_tools_returns_list(self, mock_provider): tools = [ ToolDefinition( name="calc", description="Calculate", parameters={"type": "object", "properties": {}}, ) ] formatted = mock_provider.format_tools(tools) assert isinstance(formatted, list) assert len(formatted) == 1 # --- OpenAI format functions (tested as pure functions, no API key needed) --- class TestOpenAIFormat: def test_format_tools_produces_openai_schema(self): tools = [ ToolDefinition( name="search_documents", description="Search the documentation corpus", parameters={ "type": "object", "properties": { "query": {"type": "string", "description": "Search query"}, "top_k": {"type": "integer", "description": "Number of results"}, }, "required": ["query"], }, ) ] formatted = format_tools_openai(tools) assert len(formatted) == 1 assert formatted[0]["type"] == "function" func = formatted[0]["function"] assert func["name"] == "search_documents" assert func["description"] == "Search the documentation corpus" assert func["parameters"]["required"] == ["query"] def test_format_messages_maps_roles(self): messages = [ Message(role=Role.SYSTEM, content="system prompt"), Message(role=Role.USER, content="user question"), Message( role=Role.ASSISTANT, content="", tool_calls=[ToolCall(id="call_1", name="search", arguments={"q": "test"})], ), Message(role=Role.TOOL, content="tool result", tool_call_id="call_1"), ] formatted = format_messages_openai(messages) assert formatted[0]["role"] == "system" assert formatted[1]["role"] == "user" assert formatted[2]["role"] == "assistant" assert formatted[2]["tool_calls"][0]["id"] == "call_1" assert formatted[2]["tool_calls"][0]["function"]["name"] == "search" assert formatted[3]["role"] == "tool" assert formatted[3]["tool_call_id"] == "call_1" # --- OpenAI provider (mocked HTTP) --- class TestOpenAIProvider: def test_factory_creates_openai_provider(self, monkeypatch): """Factory returns OpenAIProvider for 'openai' config.""" monkeypatch.setenv("OPENAI_API_KEY", "test-key-fake") from agent_bench.core.provider import OpenAIProvider config = AppConfig(provider=ProviderConfig(default="openai")) provider = create_provider(config) assert isinstance(provider, OpenAIProvider) def test_format_tools_via_instance(self, monkeypatch): """OpenAIProvider.format_tools delegates to format_tools_openai correctly.""" monkeypatch.setenv("OPENAI_API_KEY", "test-key-fake") from agent_bench.core.provider import OpenAIProvider config = AppConfig(provider=ProviderConfig(default="openai")) provider = OpenAIProvider(config) tools = [ ToolDefinition( name="search_documents", description="Search docs", parameters={"type": "object", "properties": {"query": {"type": "string"}}}, ) ] formatted = provider.format_tools(tools) assert formatted[0]["type"] == "function" assert formatted[0]["function"]["name"] == "search_documents" @pytest.mark.asyncio async def test_complete_with_mocked_response(self, monkeypatch): """OpenAI complete() parses a mocked API response correctly.""" monkeypatch.setenv("OPENAI_API_KEY", "test-key-fake") import httpx import respx from agent_bench.core.provider import OpenAIProvider config = AppConfig(provider=ProviderConfig(default="openai")) provider = OpenAIProvider(config) mock_response = { "id": "chatcmpl-test", "object": "chat.completion", "created": 1234567890, "model": "gpt-4o-mini", "choices": [ { "index": 0, "message": { "role": "assistant", "content": "FastAPI uses curly braces. [source: path_params.md]", "tool_calls": None, }, "finish_reason": "stop", } ], "usage": {"prompt_tokens": 100, "completion_tokens": 30, "total_tokens": 130}, } with respx.mock: respx.post("https://api.openai.com/v1/chat/completions").mock( return_value=httpx.Response(200, json=mock_response) ) response = await provider.complete( [Message(role=Role.USER, content="How do path params work?")] ) assert response.content == "FastAPI uses curly braces. [source: path_params.md]" assert response.tool_calls == [] assert response.provider == "openai" assert response.usage.input_tokens == 100 assert response.usage.output_tokens == 30 assert response.usage.estimated_cost_usd > 0 assert response.latency_ms > 0 @pytest.mark.asyncio async def test_complete_parses_tool_calls(self, monkeypatch): """OpenAI complete() correctly parses tool_calls from response.""" monkeypatch.setenv("OPENAI_API_KEY", "test-key-fake") import json import httpx import respx from agent_bench.core.provider import OpenAIProvider config = AppConfig(provider=ProviderConfig(default="openai")) provider = OpenAIProvider(config) mock_response = { "id": "chatcmpl-test2", "object": "chat.completion", "created": 1234567890, "model": "gpt-4o-mini", "choices": [ { "index": 0, "message": { "role": "assistant", "content": None, "tool_calls": [ { "id": "call_abc123", "type": "function", "function": { "name": "search_documents", "arguments": json.dumps({"query": "path parameters"}), }, } ], }, "finish_reason": "tool_calls", } ], "usage": {"prompt_tokens": 80, "completion_tokens": 20, "total_tokens": 100}, } tools = [ ToolDefinition( name="search_documents", description="Search docs", parameters={"type": "object", "properties": {"query": {"type": "string"}}}, ) ] with respx.mock: respx.post("https://api.openai.com/v1/chat/completions").mock( return_value=httpx.Response(200, json=mock_response) ) response = await provider.complete( [Message(role=Role.USER, content="search for path params")], tools=tools, ) assert len(response.tool_calls) == 1 assert response.tool_calls[0].id == "call_abc123" assert response.tool_calls[0].name == "search_documents" assert response.tool_calls[0].arguments == {"query": "path parameters"} # --- Anthropic stub --- class TestAnthropicFormat: def test_format_tools_produces_anthropic_schema(self): tools = [ ToolDefinition( name="search_documents", description="Search docs", parameters={ "type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"], }, ) ] formatted = format_tools_anthropic(tools) assert len(formatted) == 1 assert formatted[0]["name"] == "search_documents" assert "input_schema" in formatted[0] assert "parameters" not in formatted[0] assert formatted[0]["input_schema"]["required"] == ["query"] def test_format_messages_extracts_system(self): messages = [ Message(role=Role.SYSTEM, content="You are helpful."), Message(role=Role.USER, content="Hello"), ] system, formatted = format_messages_anthropic(messages) assert system == "You are helpful." assert len(formatted) == 1 assert formatted[0]["role"] == "user" def test_format_messages_tool_result(self): messages = [ Message(role=Role.USER, content="search for X"), Message( role=Role.ASSISTANT, content="", tool_calls=[ ToolCall( id="tc_1", name="search", arguments={"query": "X"}, ) ], ), Message( role=Role.TOOL, content="Result for X", tool_call_id="tc_1", ), ] _, formatted = format_messages_anthropic(messages) assert len(formatted) == 3 # Assistant with tool_use block assert formatted[1]["content"][0]["type"] == "tool_use" assert formatted[1]["content"][0]["id"] == "tc_1" # Tool result as user message with tool_result block assert formatted[2]["role"] == "user" assert formatted[2]["content"][0]["type"] == "tool_result" assert formatted[2]["content"][0]["tool_use_id"] == "tc_1" class TestAnthropicProvider: def test_factory_creates_anthropic_provider(self, monkeypatch): monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key-fake") config = AppConfig(provider=ProviderConfig(default="anthropic")) provider = create_provider(config) assert isinstance(provider, AnthropicProvider) def test_format_tools_via_instance(self, monkeypatch): monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key-fake") config = AppConfig(provider=ProviderConfig(default="anthropic")) provider = AnthropicProvider(config) tools = [ ToolDefinition( name="search_documents", description="Search docs", parameters={ "type": "object", "properties": {"query": {"type": "string"}}, }, ) ] formatted = provider.format_tools(tools) assert formatted[0]["name"] == "search_documents" assert "input_schema" in formatted[0] @pytest.mark.asyncio async def test_complete_with_mocked_response(self, monkeypatch): monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key-fake") import httpx import respx config = AppConfig(provider=ProviderConfig(default="anthropic")) provider = AnthropicProvider(config) mock_response = { "id": "msg_test", "type": "message", "role": "assistant", "model": "claude-haiku-4-5-20251001", "content": [ { "type": "text", "text": "FastAPI uses curly braces. [source: path_params.md]", } ], "stop_reason": "end_turn", "usage": { "input_tokens": 100, "output_tokens": 30, }, } with respx.mock: respx.post("https://api.anthropic.com/v1/messages").mock( return_value=httpx.Response(200, json=mock_response) ) response = await provider.complete( [ Message(role=Role.SYSTEM, content="Be helpful."), Message(role=Role.USER, content="How do path params work?"), ] ) assert "curly braces" in response.content assert response.tool_calls == [] assert response.provider == "anthropic" assert response.usage.input_tokens == 100 @pytest.mark.asyncio async def test_complete_parses_tool_calls(self, monkeypatch): monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key-fake") import httpx import respx config = AppConfig(provider=ProviderConfig(default="anthropic")) provider = AnthropicProvider(config) mock_response = { "id": "msg_test2", "type": "message", "role": "assistant", "model": "claude-haiku-4-5-20251001", "content": [ { "type": "tool_use", "id": "toolu_abc123", "name": "search_documents", "input": {"query": "path parameters"}, } ], "stop_reason": "tool_use", "usage": { "input_tokens": 80, "output_tokens": 20, }, } with respx.mock: respx.post("https://api.anthropic.com/v1/messages").mock( return_value=httpx.Response(200, json=mock_response) ) response = await provider.complete( [Message(role=Role.USER, content="search for path params")], tools=[ ToolDefinition( name="search_documents", description="Search docs", parameters={ "type": "object", "properties": { "query": {"type": "string"}, }, }, ) ], ) assert len(response.tool_calls) == 1 assert response.tool_calls[0].id == "toolu_abc123" assert response.tool_calls[0].name == "search_documents" assert response.tool_calls[0].arguments == {"query": "path parameters"} # --- Provider factory --- class TestProviderFactory: def test_create_mock_provider(self): config = AppConfig(provider=ProviderConfig(default="mock")) provider = create_provider(config) assert isinstance(provider, MockProvider) def test_create_unknown_provider_raises(self): config = AppConfig(provider=ProviderConfig(default="unknown")) with pytest.raises(ValueError, match="Unknown provider"): create_provider(config) # --- Retry logic --- class TestProviderRetry: """Tests for OpenAI provider retry with exponential backoff.""" MOCK_SUCCESS_RESPONSE = { "id": "chatcmpl-retry", "object": "chat.completion", "created": 1234567890, "model": "gpt-4o-mini", "choices": [ { "index": 0, "message": { "role": "assistant", "content": "Success after retry.", "tool_calls": None, }, "finish_reason": "stop", } ], "usage": {"prompt_tokens": 50, "completion_tokens": 10, "total_tokens": 60}, } @pytest.mark.asyncio async def test_retry_on_rate_limit(self, monkeypatch): """Two failures then success — returns answer.""" monkeypatch.setenv("OPENAI_API_KEY", "test-key-fake") import httpx import respx from agent_bench.core.provider import OpenAIProvider config = AppConfig( provider=ProviderConfig(default="openai"), retry=RetryConfig(max_retries=3, base_delay=0.01, max_delay=0.1), ) provider = OpenAIProvider(config) call_count = 0 def side_effect(request): nonlocal call_count call_count += 1 if call_count <= 2: return httpx.Response(429, json={"error": {"message": "Rate limit exceeded"}}) return httpx.Response(200, json=self.MOCK_SUCCESS_RESPONSE) with respx.mock: respx.post("https://api.openai.com/v1/chat/completions").mock( side_effect=side_effect ) from agent_bench.core.types import Message, Role response = await provider.complete( [Message(role=Role.USER, content="test")] ) assert response.content == "Success after retry." assert call_count == 3 @pytest.mark.asyncio async def test_retry_exhausted(self, monkeypatch): """All retries fail — raises ProviderRateLimitError.""" monkeypatch.setenv("OPENAI_API_KEY", "test-key-fake") import httpx import respx from agent_bench.core.provider import OpenAIProvider config = AppConfig( provider=ProviderConfig(default="openai"), retry=RetryConfig(max_retries=2, base_delay=0.01, max_delay=0.1), ) provider = OpenAIProvider(config) with respx.mock: respx.post("https://api.openai.com/v1/chat/completions").mock( return_value=httpx.Response(429, json={"error": {"message": "Rate limit"}}) ) from agent_bench.core.types import Message, Role with pytest.raises(ProviderRateLimitError, match="Rate limited after"): await provider.complete( [Message(role=Role.USER, content="test")] ) @pytest.mark.asyncio async def test_no_retry_on_other_errors(self, monkeypatch): """Non-rate-limit errors fail immediately without retry.""" monkeypatch.setenv("OPENAI_API_KEY", "test-key-fake") import httpx import respx from agent_bench.core.provider import OpenAIProvider config = AppConfig( provider=ProviderConfig(default="openai"), retry=RetryConfig(max_retries=3, base_delay=0.01, max_delay=0.1), ) provider = OpenAIProvider(config) call_count = 0 def side_effect(request): nonlocal call_count call_count += 1 return httpx.Response(400, json={"error": {"message": "Bad request"}}) with respx.mock: respx.post("https://api.openai.com/v1/chat/completions").mock( side_effect=side_effect ) from agent_bench.core.types import Message, Role with pytest.raises(Exception): await provider.complete( [Message(role=Role.USER, content="test")] ) assert call_count == 1 # no retry @pytest.mark.asyncio async def test_retry_backoff_timing(self, monkeypatch): """Verify exponential backoff delays between retries.""" monkeypatch.setenv("OPENAI_API_KEY", "test-key-fake") import httpx import respx from agent_bench.core.provider import OpenAIProvider config = AppConfig( provider=ProviderConfig(default="openai"), retry=RetryConfig(max_retries=3, base_delay=1.0, max_delay=8.0), ) provider = OpenAIProvider(config) sleep_calls: list[float] = [] async def mock_sleep(seconds): sleep_calls.append(seconds) with respx.mock, patch("asyncio.sleep", side_effect=mock_sleep): respx.post("https://api.openai.com/v1/chat/completions").mock( return_value=httpx.Response(429, json={"error": {"message": "Rate limit"}}) ) from agent_bench.core.types import Message, Role with pytest.raises(ProviderRateLimitError): await provider.complete( [Message(role=Role.USER, content="test")] ) # 3 retries: delays should be 1.0, 2.0, 4.0 assert len(sleep_calls) == 3 assert sleep_calls[0] == pytest.approx(1.0) assert sleep_calls[1] == pytest.approx(2.0) assert sleep_calls[2] == pytest.approx(4.0) class TestStreamingRetry: """Tests for stream_complete() retry/timeout parity with complete().""" @pytest.mark.asyncio async def test_stream_retry_on_rate_limit(self, monkeypatch): """stream_complete retries on 429 then succeeds.""" monkeypatch.setenv("OPENAI_API_KEY", "test-key-fake") import httpx import respx from agent_bench.core.provider import OpenAIProvider config = AppConfig( provider=ProviderConfig(default="openai"), retry=RetryConfig(max_retries=3, base_delay=0.01, max_delay=0.1), ) provider = OpenAIProvider(config) call_count = 0 # Streaming API: first 2 calls return 429, third returns SSE chunks def side_effect(request): nonlocal call_count call_count += 1 if call_count <= 2: return httpx.Response( 429, json={"error": {"message": "Rate limit"}} ) # Simulate streaming response with SSE format sse_body = ( 'data: {"id":"x","object":"chat.completion.chunk",' '"choices":[{"index":0,"delta":{"content":"hello"},' '"finish_reason":null}]}\n\n' 'data: [DONE]\n\n' ) return httpx.Response( 200, content=sse_body.encode(), headers={"content-type": "text/event-stream"}, ) with respx.mock: respx.post("https://api.openai.com/v1/chat/completions").mock( side_effect=side_effect ) from agent_bench.core.types import Message, Role chunks = [] async for chunk in provider.stream_complete( [Message(role=Role.USER, content="test")] ): chunks.append(chunk) assert call_count == 3 assert len(chunks) > 0 @pytest.mark.asyncio async def test_stream_retry_exhausted(self, monkeypatch): """stream_complete raises ProviderRateLimitError after retries.""" monkeypatch.setenv("OPENAI_API_KEY", "test-key-fake") import httpx import respx from agent_bench.core.provider import OpenAIProvider config = AppConfig( provider=ProviderConfig(default="openai"), retry=RetryConfig(max_retries=2, base_delay=0.01, max_delay=0.1), ) provider = OpenAIProvider(config) with respx.mock: respx.post("https://api.openai.com/v1/chat/completions").mock( return_value=httpx.Response( 429, json={"error": {"message": "Rate limit"}} ) ) from agent_bench.core.types import Message, Role with pytest.raises(ProviderRateLimitError, match="Rate limited"): async for _ in provider.stream_complete( [Message(role=Role.USER, content="test")] ): pass # pragma: no cover @pytest.mark.asyncio async def test_stream_timeout_raises(self, monkeypatch): """stream_complete translates APITimeoutError to ProviderTimeoutError.""" monkeypatch.setenv("OPENAI_API_KEY", "test-key-fake") from agent_bench.core.provider import OpenAIProvider, ProviderTimeoutError config = AppConfig( provider=ProviderConfig(default="openai"), retry=RetryConfig(max_retries=1, base_delay=0.01, max_delay=0.1), ) provider = OpenAIProvider(config) from openai import APITimeoutError async def mock_create(**kwargs): raise APITimeoutError(request=None) provider.client.chat.completions.create = mock_create # type: ignore[assignment] from agent_bench.core.types import Message, Role with pytest.raises(ProviderTimeoutError, match="timed out"): async for _ in provider.stream_complete( [Message(role=Role.USER, content="test")] ): pass # pragma: no cover class TestAnthropicStreamingRetry: """Tests for Anthropic stream_complete() retry/timeout parity.""" @pytest.mark.asyncio async def test_stream_retry_exhausted(self, monkeypatch): """stream_complete raises ProviderRateLimitError after retries.""" monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key-fake") from anthropic import RateLimitError as AnthropicRateLimitError config = AppConfig( provider=ProviderConfig(default="anthropic"), retry=RetryConfig(max_retries=2, base_delay=0.01, max_delay=0.1), ) provider = AnthropicProvider(config) call_count = 0 def mock_stream(**kwargs): nonlocal call_count call_count += 1 url = "https://api.anthropic.com/v1/messages" mock_req = type("Req", (), {"method": "POST", "url": url})() mock_resp = type( "Resp", (), {"status_code": 429, "headers": {}, "request": mock_req} )() raise AnthropicRateLimitError( message="rate limited", response=mock_resp, body=None, ) provider.client.messages.stream = mock_stream # type: ignore[assignment] from agent_bench.core.types import Message, Role with pytest.raises(ProviderRateLimitError, match="Rate limited"): async for _ in provider.stream_complete( [Message(role=Role.USER, content="test")] ): pass # pragma: no cover assert call_count == 3 # initial + 2 retries @pytest.mark.asyncio async def test_stream_timeout_raises(self, monkeypatch): """stream_complete translates APITimeoutError to ProviderTimeoutError.""" monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key-fake") from agent_bench.core.provider import ProviderTimeoutError config = AppConfig( provider=ProviderConfig(default="anthropic"), retry=RetryConfig(max_retries=1, base_delay=0.01, max_delay=0.1), ) provider = AnthropicProvider(config) from anthropic import APITimeoutError as AnthropicTimeoutError def mock_stream(**kwargs): raise AnthropicTimeoutError(request=None) provider.client.messages.stream = mock_stream # type: ignore[assignment] from agent_bench.core.types import Message, Role with pytest.raises(ProviderTimeoutError, match="timed out"): async for _ in provider.stream_complete( [Message(role=Role.USER, content="test")] ): pass # pragma: no cover