| """Tests for the serving layer: routes, schemas, middleware.""" | |
| from __future__ import annotations | |
| import time | |
| import pytest | |
| from httpx import ASGITransport, AsyncClient | |
| from agent_bench.agents.orchestrator import Orchestrator | |
| from agent_bench.core.config import AppConfig, ProviderConfig | |
| from agent_bench.core.provider import MockProvider, ProviderTimeoutError | |
| from agent_bench.rag.store import HybridStore | |
| from agent_bench.serving.middleware import MetricsCollector, RateLimitMiddleware, RequestMiddleware | |
| from agent_bench.tools.calculator import CalculatorTool | |
| from agent_bench.tools.registry import ToolRegistry | |
| from .test_agent import FakeSearchTool | |
def _make_test_app():
    """Build a test app wired to MockProvider with an empty vector store."""
    from fastapi import FastAPI

    test_app = FastAPI(title="agent-bench-test")

    tool_registry = ToolRegistry()
    for tool in (FakeSearchTool(), CalculatorTool()):
        tool_registry.register(tool)

    test_app.state.orchestrator = Orchestrator(
        provider=MockProvider(), registry=tool_registry, max_iterations=3
    )
    test_app.state.store = HybridStore(dimension=384)  # empty store for health check
    test_app.state.config = AppConfig(provider=ProviderConfig(default="mock"))
    test_app.state.system_prompt = "You are a test assistant."
    test_app.state.start_time = time.time()
    test_app.state.metrics = MetricsCollector()
    test_app.add_middleware(RequestMiddleware)

    from agent_bench.serving.routes import router

    test_app.include_router(router)
    return test_app
def _make_timeout_app():
    """Build a test app whose provider unconditionally raises ProviderTimeoutError."""
    from fastapi import FastAPI

    class TimeoutProvider(MockProvider):
        async def complete(self, messages, tools=None, temperature=0.0, max_tokens=1024):
            raise ProviderTimeoutError("test timeout")

    timeout_app = FastAPI(title="agent-bench-timeout")

    tool_registry = ToolRegistry()
    tool_registry.register(FakeSearchTool())

    timeout_app.state.orchestrator = Orchestrator(
        provider=TimeoutProvider(), registry=tool_registry, max_iterations=1
    )
    timeout_app.state.store = HybridStore(dimension=384)
    timeout_app.state.config = AppConfig(provider=ProviderConfig(default="mock"))
    timeout_app.state.system_prompt = "You are a test assistant."
    timeout_app.state.start_time = time.time()
    timeout_app.state.metrics = MetricsCollector()
    timeout_app.add_middleware(RequestMiddleware)

    from agent_bench.serving.routes import router

    timeout_app.include_router(router)
    return timeout_app
@pytest.fixture
def test_app():
    """Fixture: app backed by MockProvider (see _make_test_app).

    NOTE: without the @pytest.fixture decorator, pytest collects this
    `test_*`-named function as a test and every method that requests
    `test_app` fails with "fixture 'test_app' not found" — the otherwise
    unused `import pytest` indicates the decorator was dropped.
    """
    return _make_test_app()
@pytest.fixture
def timeout_app():
    """Fixture: app whose provider always times out (see _make_timeout_app).

    Must be a pytest fixture — as a bare function, tests requesting
    `timeout_app` would fail fixture resolution.
    """
    return _make_timeout_app()
class TestAskEndpoint:
    """POST /ask: happy path plus request-schema validation failures."""

    async def test_valid_question_returns_200(self, test_app):
        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post("/ask", json={"question": "How do path parameters work?"})
        assert resp.status_code == 200
        payload = resp.json()
        for key in ("answer", "sources", "metadata"):
            assert key in payload
        assert len(payload["answer"]) > 0
        meta = payload["metadata"]
        assert meta["request_id"]
        assert meta["provider"] == "mock"
        assert meta["model"] == "mock-1"

    async def test_empty_question_returns_422(self, test_app):
        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post("/ask", json={"question": ""})
        assert resp.status_code == 422

    async def test_missing_question_returns_422(self, test_app):
        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post("/ask", json={})
        assert resp.status_code == 422
class TestHealthEndpoint:
    """GET /health returns the documented health schema."""

    async def test_returns_health_response(self, test_app):
        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.get("/health")
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] in ("healthy", "degraded")
        for field in ("vector_store_chunks", "provider_available", "uptime_seconds"):
            assert field in body
class TestMetricsEndpoint:
    """GET /metrics (JSON) and GET /metrics/prometheus (text exposition)."""

    async def test_returns_metrics_response(self, test_app):
        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.get("/metrics")
        assert resp.status_code == 200
        payload = resp.json()
        expected_fields = (
            "requests_total",
            "latency_p50_ms",
            "latency_p95_ms",
            "errors_total",
            "avg_cost_per_query_usd",
        )
        for field in expected_fields:
            assert field in payload

    async def test_prometheus_endpoint_returns_text_exposition(self, test_app):
        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.get("/metrics/prometheus")
        assert resp.status_code == 200
        assert "text/plain" in resp.headers["content-type"]
        exposition = resp.text
        expected_fragments = (
            "# TYPE agent_bench_requests_total counter",
            "agent_bench_requests_total ",
            "# TYPE agent_bench_latency_p95_ms gauge",
            "agent_bench_latency_p95_ms ",
            "# TYPE agent_bench_errors_total counter",
        )
        for fragment in expected_fragments:
            assert fragment in exposition
class TestHealthCheckProbesProvider:
    """GET /health must reflect the outcome of provider.health_check()."""

    @staticmethod
    def _build_app(provider):
        """Assemble a minimal app around *provider*.

        Shared by the two "degraded" tests below, which previously
        duplicated this entire setup block verbatim.
        """
        from fastapi import FastAPI

        app = FastAPI()
        registry = ToolRegistry()
        registry.register(FakeSearchTool())
        orchestrator = Orchestrator(provider=provider, registry=registry, max_iterations=1)
        app.state.orchestrator = orchestrator
        app.state.store = HybridStore(dimension=384)
        app.state.config = AppConfig(provider=ProviderConfig(default="mock"))
        app.state.system_prompt = "test"
        app.state.start_time = time.time()
        app.state.metrics = MetricsCollector()
        app.add_middleware(RequestMiddleware)
        from agent_bench.serving.routes import router

        app.include_router(router)
        return app

    async def test_healthy_when_provider_health_check_passes(self, test_app):
        """MockProvider.health_check() returns True (default), so status=healthy."""
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            response = await client.get("/health")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "healthy"
        assert data["provider_available"] is True

    async def test_degraded_when_provider_health_check_fails(self):
        """Provider whose health_check() returns False -> status=degraded."""

        class UnhealthyProvider(MockProvider):
            async def health_check(self) -> bool:
                return False

        app = self._build_app(UnhealthyProvider())
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            response = await client.get("/health")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "degraded"
        assert data["provider_available"] is False

    async def test_degraded_when_provider_health_check_raises(self):
        """Provider whose health_check() raises -> status=degraded (not 500)."""

        class CrashingProvider(MockProvider):
            async def health_check(self) -> bool:
                raise ConnectionError("upstream unreachable")

        app = self._build_app(CrashingProvider())
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            response = await client.get("/health")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "degraded"
        assert data["provider_available"] is False
class TestMiddleware:
    """Cross-cutting request-ID tagging and provider-error mapping."""

    async def test_request_id_header(self, test_app):
        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.get("/health")
        assert "x-request-id" in resp.headers
        # Canonical UUID text form: 8-4-4-4-12 hex chars = 36 chars with hyphens.
        assert len(resp.headers["x-request-id"]) == 36

    async def test_provider_timeout_returns_504(self, timeout_app):
        transport = ASGITransport(app=timeout_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post("/ask", json={"question": "This will timeout"})
        assert resp.status_code == 504
        assert "request_id" in resp.json()
        assert "x-request-id" in resp.headers
| # --- Rate limiting tests --- | |
def _make_rate_limited_app(rpm: int = 3):
    """Build a test app with RateLimitMiddleware capped at *rpm* requests/minute."""
    from fastapi import FastAPI

    limited_app = FastAPI(title="agent-bench-ratelimit")

    tool_registry = ToolRegistry()
    for tool in (FakeSearchTool(), CalculatorTool()):
        tool_registry.register(tool)

    limited_app.state.orchestrator = Orchestrator(
        provider=MockProvider(), registry=tool_registry, max_iterations=3
    )
    limited_app.state.store = HybridStore(dimension=384)
    limited_app.state.config = AppConfig(provider=ProviderConfig(default="mock"))
    limited_app.state.system_prompt = "You are a test assistant."
    limited_app.state.start_time = time.time()
    limited_app.state.metrics = MetricsCollector()

    # Request-ID middleware first, then the limiter on top.
    limited_app.add_middleware(RequestMiddleware)
    limited_app.add_middleware(RateLimitMiddleware, requests_per_minute=rpm)

    from agent_bench.serving.routes import router

    limited_app.include_router(router)
    return limited_app
@pytest.fixture
def rate_limited_app():
    """Fixture: app with a 3-requests/minute limit (see _make_rate_limited_app).

    Must be decorated — as a bare function, tests requesting
    `rate_limited_app` would fail with "fixture not found".
    """
    return _make_rate_limited_app(rpm=3)
class TestRateLimiting:
    # Behaviour of RateLimitMiddleware: a per-IP sliding 60s window on /ask,
    # with /health and /metrics exempt. The fixture app allows 3 req/min.

    async def test_allows_normal_traffic(self, rate_limited_app):
        """Requests within the limit on /ask all succeed."""
        async with AsyncClient(
            transport=ASGITransport(app=rate_limited_app), base_url="http://test"
        ) as client:
            # Exactly the fixture's quota of 3 — every one should pass.
            for _ in range(3):
                response = await client.post("/ask", json={"question": "test"})
                assert response.status_code == 200

    async def test_blocks_excess(self, rate_limited_app):
        """Request beyond the limit gets 429."""
        async with AsyncClient(
            transport=ASGITransport(app=rate_limited_app), base_url="http://test"
        ) as client:
            # Use up the quota
            for _ in range(3):
                await client.post("/ask", json={"question": "test"})
            # Next request should be blocked
            response = await client.post("/ask", json={"question": "test"})
            assert response.status_code == 429

    async def test_retry_after_header(self, rate_limited_app):
        """429 response includes Retry-After header."""
        async with AsyncClient(
            transport=ASGITransport(app=rate_limited_app), base_url="http://test"
        ) as client:
            # Exhaust quota on non-exempt path
            for _ in range(3):
                await client.post("/ask", json={"question": "test"})
            response = await client.post("/ask", json={"question": "test"})
            assert response.status_code == 429
            assert "retry-after" in response.headers
            # Header must be a positive integer number of seconds.
            assert int(response.headers["retry-after"]) > 0

    async def test_health_exempt(self):
        """Health endpoint is never rate limited."""
        app = _make_rate_limited_app(rpm=2)
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            # Exhaust quota on non-exempt path
            for _ in range(2):
                await client.post("/ask", json={"question": "test"})
            # Health should still work
            response = await client.get("/health")
            assert response.status_code == 200
            # But another ask should be blocked
            response = await client.post("/ask", json={"question": "test"})
            assert response.status_code == 429

    async def test_metrics_exempt(self):
        """Metrics endpoint is never rate limited."""
        app = _make_rate_limited_app(rpm=2)
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            # Exhaust quota on non-exempt path
            for _ in range(2):
                await client.post("/ask", json={"question": "test"})
            # Metrics should still work
            response = await client.get("/metrics")
            assert response.status_code == 200

    def test_per_ip_isolation(self):
        """Two different client IPs get independent quotas.

        ASGI test clients share one IP, so we test the middleware's
        per-IP window tracking directly.
        """
        # NOTE(review): these assertions only inspect the `windows` dict state
        # seeded below; they never invoke the middleware's dispatch path, so
        # the actual per-IP enforcement logic is not exercised here — consider
        # strengthening this test to call the limiter's check directly.
        middleware = RateLimitMiddleware(
            app=None, requests_per_minute=2,  # type: ignore[arg-type]
        )
        now = time.time()
        # IP 1 exhausts quota
        middleware.windows["10.0.0.1"] = [now, now]
        # IP 2 is fresh
        middleware.windows["10.0.0.2"] = [now]
        # IP 1 should be at limit
        assert len(middleware.windows["10.0.0.1"]) >= 2
        # IP 2 should still have room
        assert len(middleware.windows["10.0.0.2"]) < 2
        # IPs don't share state
        assert "10.0.0.1" in middleware.windows
        assert "10.0.0.2" in middleware.windows
        assert middleware.windows["10.0.0.1"] != middleware.windows["10.0.0.2"]

    async def test_window_resets_after_60s(self):
        """Sliding window prunes old timestamps, restoring quota after 60s."""
        from unittest.mock import patch

        app = _make_rate_limited_app(rpm=2)
        fake_time = time.time()
        # Patch the middleware module's `time` so the limiter sees our clock;
        # the patch must wrap the requests, not just the setup.
        with patch("agent_bench.serving.middleware.time") as mock_time:
            mock_time.time.return_value = fake_time
            async with AsyncClient(
                transport=ASGITransport(app=app), base_url="http://test"
            ) as client:
                # Exhaust quota at t=0
                for _ in range(2):
                    resp = await client.post("/ask", json={"question": "test"})
                    assert resp.status_code == 200
                # Blocked at t=0
                resp = await client.post("/ask", json={"question": "test"})
                assert resp.status_code == 429
                # Advance time past the 60s window
                mock_time.time.return_value = fake_time + 61
                # Quota should be restored — old timestamps pruned
                resp = await client.post("/ask", json={"question": "test"})
                assert resp.status_code == 200
| # --- Streaming tests --- | |
class TestStreaming:
    """SSE behaviour of POST /ask/stream, plus a non-streaming sanity check."""

    @staticmethod
    def _sse_events(body):
        """Decode every `data: ...` line of an SSE payload into a dict."""
        import json

        return [
            json.loads(line[6:])
            for line in body.strip().split("\n")
            if line.startswith("data: ")
        ]

    async def test_stream_endpoint_returns_sse(self, test_app):
        """Content-type is text/event-stream."""
        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post(
                "/ask/stream", json={"question": "How do path parameters work?"}
            )
        assert resp.status_code == 200
        assert "text/event-stream" in resp.headers["content-type"]

    async def test_stream_events_ordered(self, test_app):
        """Legacy event sequence preserved: sources -> chunk* -> done."""
        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post(
                "/ask/stream", json={"question": "How do path parameters work?"}
            )
        all_events = self._sse_events(resp.text)
        # Stage events are additive; only the legacy types are order-checked.
        legacy_types = ("sources", "chunk", "done", "_orchestrator_done")
        legacy = [event for event in all_events if event["type"] in legacy_types]
        assert len(legacy) >= 3  # at least sources + 1 chunk + done
        assert legacy[0]["type"] == "sources"
        assert legacy[-1]["type"] in ("done", "_orchestrator_done")
        assert all(event["type"] == "chunk" for event in legacy[1:-1])

    async def test_stream_chunks_assemble(self, test_app):
        """Concatenating chunks produces coherent text."""
        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post("/ask/stream", json={"question": "test"})
        fragments = [
            event["content"]
            for event in self._sse_events(resp.text)
            if event["type"] == "chunk"
        ]
        assembled = "".join(fragments)
        assert len(assembled) > 0
        assert "FastAPI" in assembled  # MockProvider mentions FastAPI

    async def test_stream_emits_single_answer_chunk(self, test_app):
        """Stream emits the complete answer as one chunk (no redundant LLM call)."""
        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post("/ask/stream", json={"question": "test"})
        chunk_events = [
            event for event in self._sse_events(resp.text) if event["type"] == "chunk"
        ]
        assert len(chunk_events) == 1
        assert len(chunk_events[0]["content"]) > 0

    async def test_non_streaming_unchanged(self, test_app):
        """POST /ask still works identically."""
        transport = ASGITransport(app=test_app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post(
                "/ask", json={"question": "How do path parameters work?"}
            )
        assert resp.status_code == 200
        payload = resp.json()
        assert "answer" in payload
        assert "sources" in payload