agentbench / tests /test_serving.py
Nomearod's picture
style: fix ruff lint β€” import sorting, line length
12a17f8
"""Tests for the serving layer: routes, schemas, middleware."""
from __future__ import annotations
import time
import pytest
from httpx import ASGITransport, AsyncClient
from agent_bench.agents.orchestrator import Orchestrator
from agent_bench.core.config import AppConfig, ProviderConfig
from agent_bench.core.provider import MockProvider, ProviderTimeoutError
from agent_bench.rag.store import HybridStore
from agent_bench.serving.middleware import MetricsCollector, RateLimitMiddleware, RequestMiddleware
from agent_bench.tools.calculator import CalculatorTool
from agent_bench.tools.registry import ToolRegistry
from .test_agent import FakeSearchTool
def _make_test_app():
"""Create a test app with MockProvider and no real store."""
from fastapi import FastAPI
app = FastAPI(title="agent-bench-test")
registry = ToolRegistry()
registry.register(FakeSearchTool())
registry.register(CalculatorTool())
provider = MockProvider()
orchestrator = Orchestrator(provider=provider, registry=registry, max_iterations=3)
app.state.orchestrator = orchestrator
app.state.store = HybridStore(dimension=384) # empty store for health check
app.state.config = AppConfig(provider=ProviderConfig(default="mock"))
app.state.system_prompt = "You are a test assistant."
app.state.start_time = time.time()
app.state.metrics = MetricsCollector()
app.add_middleware(RequestMiddleware)
from agent_bench.serving.routes import router
app.include_router(router)
return app
def _make_timeout_app():
"""Create a test app where the provider always times out."""
from fastapi import FastAPI
class TimeoutProvider(MockProvider):
async def complete(self, messages, tools=None, temperature=0.0, max_tokens=1024):
raise ProviderTimeoutError("test timeout")
app = FastAPI(title="agent-bench-timeout")
registry = ToolRegistry()
registry.register(FakeSearchTool())
provider = TimeoutProvider()
orchestrator = Orchestrator(provider=provider, registry=registry, max_iterations=1)
app.state.orchestrator = orchestrator
app.state.store = HybridStore(dimension=384)
app.state.config = AppConfig(provider=ProviderConfig(default="mock"))
app.state.system_prompt = "You are a test assistant."
app.state.start_time = time.time()
app.state.metrics = MetricsCollector()
app.add_middleware(RequestMiddleware)
from agent_bench.serving.routes import router
app.include_router(router)
return app
@pytest.fixture
def test_app():
return _make_test_app()
@pytest.fixture
def timeout_app():
return _make_timeout_app()
class TestAskEndpoint:
@pytest.mark.asyncio
async def test_valid_question_returns_200(self, test_app):
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
response = await client.post("/ask", json={"question": "How do path parameters work?"})
assert response.status_code == 200
data = response.json()
assert "answer" in data
assert "sources" in data
assert "metadata" in data
assert len(data["answer"]) > 0
assert data["metadata"]["request_id"]
assert data["metadata"]["provider"] == "mock"
assert data["metadata"]["model"] == "mock-1"
@pytest.mark.asyncio
async def test_empty_question_returns_422(self, test_app):
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
response = await client.post("/ask", json={"question": ""})
assert response.status_code == 422
@pytest.mark.asyncio
async def test_missing_question_returns_422(self, test_app):
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
response = await client.post("/ask", json={})
assert response.status_code == 422
class TestHealthEndpoint:
@pytest.mark.asyncio
async def test_returns_health_response(self, test_app):
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] in ("healthy", "degraded")
assert "vector_store_chunks" in data
assert "provider_available" in data
assert "uptime_seconds" in data
class TestMetricsEndpoint:
@pytest.mark.asyncio
async def test_returns_metrics_response(self, test_app):
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
response = await client.get("/metrics")
assert response.status_code == 200
data = response.json()
assert "requests_total" in data
assert "latency_p50_ms" in data
assert "latency_p95_ms" in data
assert "errors_total" in data
assert "avg_cost_per_query_usd" in data
@pytest.mark.asyncio
async def test_prometheus_endpoint_returns_text_exposition(self, test_app):
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
response = await client.get("/metrics/prometheus")
assert response.status_code == 200
assert "text/plain" in response.headers["content-type"]
body = response.text
assert "# TYPE agent_bench_requests_total counter" in body
assert "agent_bench_requests_total " in body
assert "# TYPE agent_bench_latency_p95_ms gauge" in body
assert "agent_bench_latency_p95_ms " in body
assert "# TYPE agent_bench_errors_total counter" in body
class TestHealthCheckProbesProvider:
@pytest.mark.asyncio
async def test_healthy_when_provider_health_check_passes(self, test_app):
"""MockProvider.health_check() returns True (default), so status=healthy."""
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert data["provider_available"] is True
@pytest.mark.asyncio
async def test_degraded_when_provider_health_check_fails(self):
"""Provider whose health_check() returns False -> status=degraded."""
from fastapi import FastAPI
class UnhealthyProvider(MockProvider):
async def health_check(self) -> bool:
return False
app = FastAPI()
registry = ToolRegistry()
registry.register(FakeSearchTool())
provider = UnhealthyProvider()
orchestrator = Orchestrator(provider=provider, registry=registry, max_iterations=1)
app.state.orchestrator = orchestrator
app.state.store = HybridStore(dimension=384)
app.state.config = AppConfig(provider=ProviderConfig(default="mock"))
app.state.system_prompt = "test"
app.state.start_time = time.time()
app.state.metrics = MetricsCollector()
app.add_middleware(RequestMiddleware)
from agent_bench.serving.routes import router
app.include_router(router)
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "degraded"
assert data["provider_available"] is False
@pytest.mark.asyncio
async def test_degraded_when_provider_health_check_raises(self):
"""Provider whose health_check() raises -> status=degraded."""
from fastapi import FastAPI
class CrashingProvider(MockProvider):
async def health_check(self) -> bool:
raise ConnectionError("upstream unreachable")
app = FastAPI()
registry = ToolRegistry()
registry.register(FakeSearchTool())
provider = CrashingProvider()
orchestrator = Orchestrator(provider=provider, registry=registry, max_iterations=1)
app.state.orchestrator = orchestrator
app.state.store = HybridStore(dimension=384)
app.state.config = AppConfig(provider=ProviderConfig(default="mock"))
app.state.system_prompt = "test"
app.state.start_time = time.time()
app.state.metrics = MetricsCollector()
app.add_middleware(RequestMiddleware)
from agent_bench.serving.routes import router
app.include_router(router)
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "degraded"
assert data["provider_available"] is False
class TestMiddleware:
@pytest.mark.asyncio
async def test_request_id_header(self, test_app):
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
response = await client.get("/health")
assert "x-request-id" in response.headers
# UUID format: 8-4-4-4-12 hex chars
request_id = response.headers["x-request-id"]
assert len(request_id) == 36
@pytest.mark.asyncio
async def test_provider_timeout_returns_504(self, timeout_app):
async with AsyncClient(
transport=ASGITransport(app=timeout_app), base_url="http://test"
) as client:
response = await client.post("/ask", json={"question": "This will timeout"})
assert response.status_code == 504
data = response.json()
assert "request_id" in data
assert "x-request-id" in response.headers
# --- Rate limiting tests ---
def _make_rate_limited_app(rpm: int = 3):
"""Create a test app with rate limiting enabled."""
from fastapi import FastAPI
app = FastAPI(title="agent-bench-ratelimit")
registry = ToolRegistry()
registry.register(FakeSearchTool())
registry.register(CalculatorTool())
provider = MockProvider()
orchestrator = Orchestrator(provider=provider, registry=registry, max_iterations=3)
app.state.orchestrator = orchestrator
app.state.store = HybridStore(dimension=384)
app.state.config = AppConfig(provider=ProviderConfig(default="mock"))
app.state.system_prompt = "You are a test assistant."
app.state.start_time = time.time()
app.state.metrics = MetricsCollector()
app.add_middleware(RequestMiddleware)
app.add_middleware(RateLimitMiddleware, requests_per_minute=rpm)
from agent_bench.serving.routes import router
app.include_router(router)
return app
@pytest.fixture
def rate_limited_app():
return _make_rate_limited_app(rpm=3)
class TestRateLimiting:
@pytest.mark.asyncio
async def test_allows_normal_traffic(self, rate_limited_app):
"""Requests within the limit on /ask all succeed."""
async with AsyncClient(
transport=ASGITransport(app=rate_limited_app), base_url="http://test"
) as client:
for _ in range(3):
response = await client.post("/ask", json={"question": "test"})
assert response.status_code == 200
@pytest.mark.asyncio
async def test_blocks_excess(self, rate_limited_app):
"""Request beyond the limit gets 429."""
async with AsyncClient(
transport=ASGITransport(app=rate_limited_app), base_url="http://test"
) as client:
# Use up the quota
for _ in range(3):
await client.post("/ask", json={"question": "test"})
# Next request should be blocked
response = await client.post("/ask", json={"question": "test"})
assert response.status_code == 429
@pytest.mark.asyncio
async def test_retry_after_header(self, rate_limited_app):
"""429 response includes Retry-After header."""
async with AsyncClient(
transport=ASGITransport(app=rate_limited_app), base_url="http://test"
) as client:
# Exhaust quota on non-exempt path
for _ in range(3):
await client.post("/ask", json={"question": "test"})
response = await client.post("/ask", json={"question": "test"})
assert response.status_code == 429
assert "retry-after" in response.headers
assert int(response.headers["retry-after"]) > 0
@pytest.mark.asyncio
async def test_health_exempt(self):
"""Health endpoint is never rate limited."""
app = _make_rate_limited_app(rpm=2)
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
# Exhaust quota on non-exempt path
for _ in range(2):
await client.post("/ask", json={"question": "test"})
# Health should still work
response = await client.get("/health")
assert response.status_code == 200
# But another ask should be blocked
response = await client.post("/ask", json={"question": "test"})
assert response.status_code == 429
@pytest.mark.asyncio
async def test_metrics_exempt(self):
"""Metrics endpoint is never rate limited."""
app = _make_rate_limited_app(rpm=2)
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
# Exhaust quota on non-exempt path
for _ in range(2):
await client.post("/ask", json={"question": "test"})
# Metrics should still work
response = await client.get("/metrics")
assert response.status_code == 200
def test_per_ip_isolation(self):
"""Two different client IPs get independent quotas.
ASGI test clients share one IP, so we test the middleware's
per-IP window tracking directly.
"""
middleware = RateLimitMiddleware(
app=None, requests_per_minute=2, # type: ignore[arg-type]
)
now = time.time()
# IP 1 exhausts quota
middleware.windows["10.0.0.1"] = [now, now]
# IP 2 is fresh
middleware.windows["10.0.0.2"] = [now]
# IP 1 should be at limit
assert len(middleware.windows["10.0.0.1"]) >= 2
# IP 2 should still have room
assert len(middleware.windows["10.0.0.2"]) < 2
# IPs don't share state
assert "10.0.0.1" in middleware.windows
assert "10.0.0.2" in middleware.windows
assert middleware.windows["10.0.0.1"] != middleware.windows["10.0.0.2"]
@pytest.mark.asyncio
async def test_window_resets_after_60s(self):
"""Sliding window prunes old timestamps, restoring quota after 60s."""
from unittest.mock import patch
app = _make_rate_limited_app(rpm=2)
fake_time = time.time()
with patch("agent_bench.serving.middleware.time") as mock_time:
mock_time.time.return_value = fake_time
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
# Exhaust quota at t=0
for _ in range(2):
resp = await client.post("/ask", json={"question": "test"})
assert resp.status_code == 200
# Blocked at t=0
resp = await client.post("/ask", json={"question": "test"})
assert resp.status_code == 429
# Advance time past the 60s window
mock_time.time.return_value = fake_time + 61
# Quota should be restored β€” old timestamps pruned
resp = await client.post("/ask", json={"question": "test"})
assert resp.status_code == 200
# --- Streaming tests ---
class TestStreaming:
@pytest.mark.asyncio
async def test_stream_endpoint_returns_sse(self, test_app):
"""Content-type is text/event-stream."""
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
response = await client.post(
"/ask/stream", json={"question": "How do path parameters work?"}
)
assert response.status_code == 200
assert "text/event-stream" in response.headers["content-type"]
@pytest.mark.asyncio
async def test_stream_events_ordered(self, test_app):
"""Legacy event sequence preserved: sources β†’ chunk* β†’ done."""
import json as json_mod
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
response = await client.post(
"/ask/stream", json={"question": "How do path parameters work?"}
)
all_events = []
for line in response.text.strip().split("\n"):
if line.startswith("data: "):
all_events.append(json_mod.loads(line[6:]))
# Filter to legacy event types only (stage events are additive)
legacy_types = ("sources", "chunk", "done", "_orchestrator_done")
legacy = [e for e in all_events if e["type"] in legacy_types]
assert len(legacy) >= 3 # at least sources + 1 chunk + done
assert legacy[0]["type"] == "sources"
assert legacy[-1]["type"] in ("done", "_orchestrator_done")
assert all(e["type"] == "chunk" for e in legacy[1:-1])
@pytest.mark.asyncio
async def test_stream_chunks_assemble(self, test_app):
"""Concatenating chunks produces coherent text."""
import json as json_mod
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
response = await client.post(
"/ask/stream", json={"question": "test"}
)
chunks = []
for line in response.text.strip().split("\n"):
if line.startswith("data: "):
event = json_mod.loads(line[6:])
if event["type"] == "chunk":
chunks.append(event["content"])
full_text = "".join(chunks)
assert len(full_text) > 0
assert "FastAPI" in full_text # MockProvider mentions FastAPI
@pytest.mark.asyncio
async def test_stream_emits_single_answer_chunk(self, test_app):
"""Stream emits the complete answer as a single chunk (no redundant LLM call)."""
import json as json_mod
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
response = await client.post(
"/ask/stream", json={"question": "test"}
)
chunks = [
json_mod.loads(line[6:])
for line in response.text.strip().split("\n")
if line.startswith("data: ")
and json_mod.loads(line[6:])["type"] == "chunk"
]
assert len(chunks) == 1
assert len(chunks[0]["content"]) > 0
@pytest.mark.asyncio
async def test_non_streaming_unchanged(self, test_app):
"""POST /ask still works identically."""
async with AsyncClient(
transport=ASGITransport(app=test_app), base_url="http://test"
) as client:
response = await client.post(
"/ask", json={"question": "How do path parameters work?"}
)
assert response.status_code == 200
data = response.json()
assert "answer" in data
assert "sources" in data