Nomearod Claude Opus 4.6 (1M context) commited on
Commit
871820a
·
1 Parent(s): 2fc13b5

feat: add provider retry with backoff and API rate limiting

Browse files

Part A: OpenAI 429 errors trigger exponential backoff (1s, 2s, 4s)
before raising ProviderRateLimitError. Retry wraps the raw openai call
inside the existing error translation — not outside it.

Part B: In-memory sliding window rate limiter (10 RPM per IP default).
/health and /metrics exempt. 429 response with Retry-After header.

- RetryConfig added to AppConfig (max_retries, base_delay, max_delay)
- rate_limit_rpm added to ServingConfig
- RateLimitMiddleware registered in app.py
- 8 new tests (117 total), lint + types clean
- DECISIONS.md entries for both

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

DECISIONS.md CHANGED
@@ -168,6 +168,32 @@ than a computed subset. The reranker's `top_k` handles truncation.
168
  This is simpler and more robust than computing an input size from
169
  per-system candidate counts.
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  ## Why ranked_sources separate from deduplicated sources?
172
 
173
  The deduplicated `sources` list in `AgentResponse` is for the API
 
168
  This is simpler and more robust than computing an input size from
169
  per-system candidate counts.
170
 
171
+ ## Why provider retry with exponential backoff
172
+
173
+ OpenAI returns 429 (rate limit) errors under load. Without retry logic,
174
+ a single 429 causes a user-visible failure. We add exponential backoff:
175
+ attempt after 1s, 2s, 4s. After 3 retries, raise `ProviderRateLimitError`
176
+ so the middleware returns a clear 503.
177
+
178
+ The retry wraps the raw `openai.RateLimitError` — it must fire BEFORE
179
+ the error gets translated to `ProviderRateLimitError`, otherwise retry
180
+ logic is dead code. Other errors (400, 401, timeout) fail immediately.
181
+
182
+ ## Why in-memory API rate limiting
183
+
184
+ A public-facing API needs abuse protection. We use a simple in-memory
185
+ sliding window limiter: 10 requests/minute per IP. Sufficient for a
186
+ demo deployment; a production system would use Redis.
187
+
188
+ Known limitation: the per-IP dict grows without bound across distinct
189
+ IPs. Acceptable for Fly.io with auto-stop (memory resets). If running
190
+ continuously under bot traffic, add a periodic sweep or switch to a
191
+ TTL-based structure.
192
+
193
+ Design choices:
194
+ - `/health` and `/metrics` exempt: monitoring should never be rate-limited.
195
+ - `Retry-After` header: follows HTTP 429 spec, lets clients back off.
196
+
197
  ## Why ranked_sources separate from deduplicated sources?
198
 
199
  The deduplicated `sources` list in `AgentResponse` is for the API
agent_bench/core/config.py CHANGED
@@ -53,6 +53,12 @@ class RAGConfig(BaseModel):
53
  refusal_threshold: float = 0.0 # 0.0 = disabled (V1 behavior)
54
 
55
 
 
 
 
 
 
 
56
  class EmbeddingConfig(BaseModel):
57
  model: str = "all-MiniLM-L6-v2"
58
  cache_dir: str = ".cache/embeddings"
@@ -62,6 +68,7 @@ class ServingConfig(BaseModel):
62
  host: str = "0.0.0.0"
63
  port: int = 8000
64
  request_timeout_seconds: int = 30
 
65
 
66
 
67
  class EvaluationConfig(BaseModel):
@@ -73,6 +80,7 @@ class AppConfig(BaseModel):
73
  agent: AgentConfig = AgentConfig()
74
  provider: ProviderConfig = ProviderConfig()
75
  rag: RAGConfig = RAGConfig()
 
76
  embedding: EmbeddingConfig = EmbeddingConfig()
77
  serving: ServingConfig = ServingConfig()
78
  evaluation: EvaluationConfig = EvaluationConfig()
 
53
  refusal_threshold: float = 0.0 # 0.0 = disabled (V1 behavior)
54
 
55
 
56
+ class RetryConfig(BaseModel):
57
+ max_retries: int = 3
58
+ base_delay: float = 1.0 # seconds
59
+ max_delay: float = 8.0 # cap for exponential backoff
60
+
61
+
62
  class EmbeddingConfig(BaseModel):
63
  model: str = "all-MiniLM-L6-v2"
64
  cache_dir: str = ".cache/embeddings"
 
68
  host: str = "0.0.0.0"
69
  port: int = 8000
70
  request_timeout_seconds: int = 30
71
+ rate_limit_rpm: int = 10 # requests per minute per IP
72
 
73
 
74
  class EvaluationConfig(BaseModel):
 
80
  agent: AgentConfig = AgentConfig()
81
  provider: ProviderConfig = ProviderConfig()
82
  rag: RAGConfig = RAGConfig()
83
+ retry: RetryConfig = RetryConfig()
84
  embedding: EmbeddingConfig = EmbeddingConfig()
85
  serving: ServingConfig = ServingConfig()
86
  evaluation: EvaluationConfig = EvaluationConfig()
agent_bench/core/provider.py CHANGED
@@ -2,10 +2,13 @@
2
 
3
  from __future__ import annotations
4
 
 
5
  import json
6
  import time
7
  from abc import ABC, abstractmethod
8
 
 
 
9
  from agent_bench.core.config import AppConfig, load_config
10
  from agent_bench.core.types import (
11
  CompletionResponse,
@@ -16,6 +19,8 @@ from agent_bench.core.types import (
16
  ToolDefinition,
17
  )
18
 
 
 
19
 
20
  class ProviderTimeoutError(Exception):
21
  """Raised when the LLM provider times out."""
@@ -173,7 +178,7 @@ class OpenAIProvider(LLMProvider):
173
  temperature: float = 0.0,
174
  max_tokens: int = 1024,
175
  ) -> CompletionResponse:
176
- from openai import APITimeoutError
177
 
178
  formatted_messages = format_messages_openai(messages)
179
  kwargs: dict = {
@@ -186,15 +191,30 @@ class OpenAIProvider(LLMProvider):
186
  kwargs["tools"] = self.format_tools(tools)
187
  kwargs["tool_choice"] = "auto"
188
 
 
189
  start = time.perf_counter()
190
- try:
191
- response = await self.client.chat.completions.create(**kwargs)
192
- except APITimeoutError as e:
193
- raise ProviderTimeoutError(f"OpenAI timed out: {e}") from e
194
- except Exception as e:
195
- if "insufficient_quota" in str(e) or "rate_limit" in str(e).lower():
196
- raise ProviderRateLimitError(f"OpenAI rate limit / quota: {e}") from e
197
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  latency_ms = (time.perf_counter() - start) * 1000
199
 
200
  choice = response.choices[0]
 
2
 
3
  from __future__ import annotations
4
 
5
+ import asyncio
6
  import json
7
  import time
8
  from abc import ABC, abstractmethod
9
 
10
+ import structlog
11
+
12
  from agent_bench.core.config import AppConfig, load_config
13
  from agent_bench.core.types import (
14
  CompletionResponse,
 
19
  ToolDefinition,
20
  )
21
 
22
+ log = structlog.get_logger()
23
+
24
 
25
  class ProviderTimeoutError(Exception):
26
  """Raised when the LLM provider times out."""
 
178
  temperature: float = 0.0,
179
  max_tokens: int = 1024,
180
  ) -> CompletionResponse:
181
+ from openai import APITimeoutError, RateLimitError
182
 
183
  formatted_messages = format_messages_openai(messages)
184
  kwargs: dict = {
 
191
  kwargs["tools"] = self.format_tools(tools)
192
  kwargs["tool_choice"] = "auto"
193
 
194
+ retry_cfg = self.config.retry
195
  start = time.perf_counter()
196
+
197
+ for attempt in range(retry_cfg.max_retries + 1):
198
+ try:
199
+ response = await self.client.chat.completions.create(**kwargs)
200
+ break # success
201
+ except RateLimitError as e:
202
+ if attempt == retry_cfg.max_retries:
203
+ log.error("provider_rate_limited",
204
+ attempts=attempt + 1, error=str(e))
205
+ raise ProviderRateLimitError(
206
+ f"Rate limited after {retry_cfg.max_retries} retries: {e}"
207
+ ) from e
208
+ wait = min(
209
+ retry_cfg.base_delay * (2 ** attempt),
210
+ retry_cfg.max_delay,
211
+ )
212
+ log.warning("provider_retry",
213
+ attempt=attempt + 1, wait_seconds=wait, error=str(e))
214
+ await asyncio.sleep(wait)
215
+ except APITimeoutError as e:
216
+ raise ProviderTimeoutError(f"OpenAI timed out: {e}") from e
217
+
218
  latency_ms = (time.perf_counter() - start) * 1000
219
 
220
  choice = response.choices[0]
agent_bench/serving/app.py CHANGED
@@ -13,7 +13,7 @@ from agent_bench.core.provider import create_provider
13
  from agent_bench.rag.embedder import Embedder
14
  from agent_bench.rag.retriever import Retriever
15
  from agent_bench.rag.store import HybridStore
16
- from agent_bench.serving.middleware import MetricsCollector, RequestMiddleware
17
  from agent_bench.serving.routes import router
18
  from agent_bench.tools.calculator import CalculatorTool
19
  from agent_bench.tools.registry import ToolRegistry
@@ -99,8 +99,9 @@ def create_app(config: AppConfig | None = None) -> FastAPI:
99
  app.state.start_time = time.time()
100
  app.state.metrics = metrics
101
 
102
- # Middleware and routes
103
  app.add_middleware(RequestMiddleware)
 
104
  app.include_router(router)
105
 
106
  return app
 
13
  from agent_bench.rag.embedder import Embedder
14
  from agent_bench.rag.retriever import Retriever
15
  from agent_bench.rag.store import HybridStore
16
+ from agent_bench.serving.middleware import MetricsCollector, RateLimitMiddleware, RequestMiddleware
17
  from agent_bench.serving.routes import router
18
  from agent_bench.tools.calculator import CalculatorTool
19
  from agent_bench.tools.registry import ToolRegistry
 
99
  app.state.start_time = time.time()
100
  app.state.metrics = metrics
101
 
102
+ # Middleware and routes (order matters: rate limit checked first)
103
  app.add_middleware(RequestMiddleware)
104
+ app.add_middleware(RateLimitMiddleware, requests_per_minute=config.serving.rate_limit_rpm)
105
  app.include_router(router)
106
 
107
  return app
agent_bench/serving/middleware.py CHANGED
@@ -1,10 +1,10 @@
1
- """Request middleware: ID generation, logging, error handling, metrics."""
2
 
3
  from __future__ import annotations
4
 
5
  import time
6
  import uuid
7
- from collections import deque
8
 
9
  import structlog
10
  from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
@@ -47,6 +47,44 @@ class MetricsCollector:
47
  return self.total_cost_usd / self.requests_total
48
 
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  class RequestMiddleware(BaseHTTPMiddleware):
51
  """Adds request ID, timing, structured logging, and error handling."""
52
 
 
1
+ """Request middleware: ID generation, logging, error handling, metrics, rate limiting."""
2
 
3
  from __future__ import annotations
4
 
5
  import time
6
  import uuid
7
+ from collections import defaultdict, deque
8
 
9
  import structlog
10
  from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
 
47
  return self.total_cost_usd / self.requests_total
48
 
49
 
50
+ class RateLimitMiddleware(BaseHTTPMiddleware):
51
+ """In-memory sliding window rate limiter, per client IP."""
52
+
53
+ EXEMPT_PATHS = {"/health", "/metrics"}
54
+
55
+ def __init__(self, app: object, requests_per_minute: int = 10) -> None:
56
+ super().__init__(app) # type: ignore[arg-type]
57
+ self.rpm = requests_per_minute
58
+ self.windows: dict[str, list[float]] = defaultdict(list)
59
+
60
+ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
61
+ if request.url.path in self.EXEMPT_PATHS:
62
+ return await call_next(request)
63
+
64
+ client_ip = request.client.host if request.client else "unknown"
65
+ now = time.time()
66
+ window_start = now - 60
67
+
68
+ # Prune timestamps outside the window
69
+ self.windows[client_ip] = [
70
+ t for t in self.windows[client_ip] if t > window_start
71
+ ]
72
+
73
+ if len(self.windows[client_ip]) >= self.rpm:
74
+ retry_after = max(1, int(60 - (now - self.windows[client_ip][0])))
75
+ logger.warning("rate_limited",
76
+ client_ip=client_ip,
77
+ requests_in_window=len(self.windows[client_ip]))
78
+ return JSONResponse(
79
+ status_code=429,
80
+ content={"error": "Rate limit exceeded", "retry_after": retry_after},
81
+ headers={"Retry-After": str(retry_after)},
82
+ )
83
+
84
+ self.windows[client_ip].append(now)
85
+ return await call_next(request)
86
+
87
+
88
  class RequestMiddleware(BaseHTTPMiddleware):
89
  """Adds request ID, timing, structured logging, and error handling."""
90
 
tests/test_provider.py CHANGED
@@ -1,11 +1,20 @@
1
  """Tests for core types, config, and provider abstraction."""
2
 
 
 
3
  import pytest
4
 
5
- from agent_bench.core.config import AppConfig, ProviderConfig, load_config, load_task_config
 
 
 
 
 
 
6
  from agent_bench.core.provider import (
7
  AnthropicProvider,
8
  MockProvider,
 
9
  create_provider,
10
  format_messages_openai,
11
  format_tools_openai,
@@ -395,3 +404,168 @@ class TestProviderFactory:
395
  config = AppConfig(provider=ProviderConfig(default="unknown"))
396
  with pytest.raises(ValueError, match="Unknown provider"):
397
  create_provider(config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Tests for core types, config, and provider abstraction."""
2
 
3
+ from unittest.mock import patch
4
+
5
  import pytest
6
 
7
+ from agent_bench.core.config import (
8
+ AppConfig,
9
+ ProviderConfig,
10
+ RetryConfig,
11
+ load_config,
12
+ load_task_config,
13
+ )
14
  from agent_bench.core.provider import (
15
  AnthropicProvider,
16
  MockProvider,
17
+ ProviderRateLimitError,
18
  create_provider,
19
  format_messages_openai,
20
  format_tools_openai,
 
404
  config = AppConfig(provider=ProviderConfig(default="unknown"))
405
  with pytest.raises(ValueError, match="Unknown provider"):
406
  create_provider(config)
407
+
408
+
409
+ # --- Retry logic ---
410
+
411
+
412
+ class TestProviderRetry:
413
+ """Tests for OpenAI provider retry with exponential backoff."""
414
+
415
+ MOCK_SUCCESS_RESPONSE = {
416
+ "id": "chatcmpl-retry",
417
+ "object": "chat.completion",
418
+ "created": 1234567890,
419
+ "model": "gpt-4o-mini",
420
+ "choices": [
421
+ {
422
+ "index": 0,
423
+ "message": {
424
+ "role": "assistant",
425
+ "content": "Success after retry.",
426
+ "tool_calls": None,
427
+ },
428
+ "finish_reason": "stop",
429
+ }
430
+ ],
431
+ "usage": {"prompt_tokens": 50, "completion_tokens": 10, "total_tokens": 60},
432
+ }
433
+
434
+ @pytest.mark.asyncio
435
+ async def test_retry_on_rate_limit(self, monkeypatch):
436
+ """Two failures then success — returns answer."""
437
+ monkeypatch.setenv("OPENAI_API_KEY", "test-key-fake")
438
+
439
+ import httpx
440
+ import respx
441
+
442
+ from agent_bench.core.provider import OpenAIProvider
443
+
444
+ config = AppConfig(
445
+ provider=ProviderConfig(default="openai"),
446
+ retry=RetryConfig(max_retries=3, base_delay=0.01, max_delay=0.1),
447
+ )
448
+ provider = OpenAIProvider(config)
449
+
450
+ call_count = 0
451
+
452
+ def side_effect(request):
453
+ nonlocal call_count
454
+ call_count += 1
455
+ if call_count <= 2:
456
+ return httpx.Response(429, json={"error": {"message": "Rate limit exceeded"}})
457
+ return httpx.Response(200, json=self.MOCK_SUCCESS_RESPONSE)
458
+
459
+ with respx.mock:
460
+ respx.post("https://api.openai.com/v1/chat/completions").mock(
461
+ side_effect=side_effect
462
+ )
463
+ from agent_bench.core.types import Message, Role
464
+
465
+ response = await provider.complete(
466
+ [Message(role=Role.USER, content="test")]
467
+ )
468
+
469
+ assert response.content == "Success after retry."
470
+ assert call_count == 3
471
+
472
+ @pytest.mark.asyncio
473
+ async def test_retry_exhausted(self, monkeypatch):
474
+ """All retries fail — raises ProviderRateLimitError."""
475
+ monkeypatch.setenv("OPENAI_API_KEY", "test-key-fake")
476
+
477
+ import httpx
478
+ import respx
479
+
480
+ from agent_bench.core.provider import OpenAIProvider
481
+
482
+ config = AppConfig(
483
+ provider=ProviderConfig(default="openai"),
484
+ retry=RetryConfig(max_retries=2, base_delay=0.01, max_delay=0.1),
485
+ )
486
+ provider = OpenAIProvider(config)
487
+
488
+ with respx.mock:
489
+ respx.post("https://api.openai.com/v1/chat/completions").mock(
490
+ return_value=httpx.Response(429, json={"error": {"message": "Rate limit"}})
491
+ )
492
+ from agent_bench.core.types import Message, Role
493
+
494
+ with pytest.raises(ProviderRateLimitError, match="Rate limited after"):
495
+ await provider.complete(
496
+ [Message(role=Role.USER, content="test")]
497
+ )
498
+
499
+ @pytest.mark.asyncio
500
+ async def test_no_retry_on_other_errors(self, monkeypatch):
501
+ """Non-rate-limit errors fail immediately without retry."""
502
+ monkeypatch.setenv("OPENAI_API_KEY", "test-key-fake")
503
+
504
+ import httpx
505
+ import respx
506
+
507
+ from agent_bench.core.provider import OpenAIProvider
508
+
509
+ config = AppConfig(
510
+ provider=ProviderConfig(default="openai"),
511
+ retry=RetryConfig(max_retries=3, base_delay=0.01, max_delay=0.1),
512
+ )
513
+ provider = OpenAIProvider(config)
514
+
515
+ call_count = 0
516
+
517
+ def side_effect(request):
518
+ nonlocal call_count
519
+ call_count += 1
520
+ return httpx.Response(400, json={"error": {"message": "Bad request"}})
521
+
522
+ with respx.mock:
523
+ respx.post("https://api.openai.com/v1/chat/completions").mock(
524
+ side_effect=side_effect
525
+ )
526
+ from agent_bench.core.types import Message, Role
527
+
528
+ with pytest.raises(Exception):
529
+ await provider.complete(
530
+ [Message(role=Role.USER, content="test")]
531
+ )
532
+
533
+ assert call_count == 1 # no retry
534
+
535
+ @pytest.mark.asyncio
536
+ async def test_retry_backoff_timing(self, monkeypatch):
537
+ """Verify exponential backoff delays between retries."""
538
+ monkeypatch.setenv("OPENAI_API_KEY", "test-key-fake")
539
+
540
+ import httpx
541
+ import respx
542
+
543
+ from agent_bench.core.provider import OpenAIProvider
544
+
545
+ config = AppConfig(
546
+ provider=ProviderConfig(default="openai"),
547
+ retry=RetryConfig(max_retries=3, base_delay=1.0, max_delay=8.0),
548
+ )
549
+ provider = OpenAIProvider(config)
550
+
551
+ sleep_calls: list[float] = []
552
+
553
+ async def mock_sleep(seconds):
554
+ sleep_calls.append(seconds)
555
+
556
+ with respx.mock, patch("asyncio.sleep", side_effect=mock_sleep):
557
+ respx.post("https://api.openai.com/v1/chat/completions").mock(
558
+ return_value=httpx.Response(429, json={"error": {"message": "Rate limit"}})
559
+ )
560
+ from agent_bench.core.types import Message, Role
561
+
562
+ with pytest.raises(ProviderRateLimitError):
563
+ await provider.complete(
564
+ [Message(role=Role.USER, content="test")]
565
+ )
566
+
567
+ # 3 retries: delays should be 1.0, 2.0, 4.0
568
+ assert len(sleep_calls) == 3
569
+ assert sleep_calls[0] == pytest.approx(1.0)
570
+ assert sleep_calls[1] == pytest.approx(2.0)
571
+ assert sleep_calls[2] == pytest.approx(4.0)
tests/test_serving.py CHANGED
@@ -11,7 +11,7 @@ from agent_bench.agents.orchestrator import Orchestrator
11
  from agent_bench.core.config import AppConfig, ProviderConfig
12
  from agent_bench.core.provider import MockProvider, ProviderTimeoutError
13
  from agent_bench.rag.store import HybridStore
14
- from agent_bench.serving.middleware import MetricsCollector, RequestMiddleware
15
  from agent_bench.tools.calculator import CalculatorTool
16
  from agent_bench.tools.registry import ToolRegistry
17
 
@@ -174,3 +174,96 @@ class TestMiddleware:
174
  data = response.json()
175
  assert "request_id" in data
176
  assert "x-request-id" in response.headers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  from agent_bench.core.config import AppConfig, ProviderConfig
12
  from agent_bench.core.provider import MockProvider, ProviderTimeoutError
13
  from agent_bench.rag.store import HybridStore
14
+ from agent_bench.serving.middleware import MetricsCollector, RateLimitMiddleware, RequestMiddleware
15
  from agent_bench.tools.calculator import CalculatorTool
16
  from agent_bench.tools.registry import ToolRegistry
17
 
 
174
  data = response.json()
175
  assert "request_id" in data
176
  assert "x-request-id" in response.headers
177
+
178
+
179
+ # --- Rate limiting tests ---
180
+
181
+
182
+ def _make_rate_limited_app(rpm: int = 3):
183
+ """Create a test app with rate limiting enabled."""
184
+ from fastapi import FastAPI
185
+
186
+ app = FastAPI(title="agent-bench-ratelimit")
187
+
188
+ registry = ToolRegistry()
189
+ registry.register(FakeSearchTool())
190
+ registry.register(CalculatorTool())
191
+
192
+ provider = MockProvider()
193
+ orchestrator = Orchestrator(provider=provider, registry=registry, max_iterations=3)
194
+
195
+ app.state.orchestrator = orchestrator
196
+ app.state.store = HybridStore(dimension=384)
197
+ app.state.config = AppConfig(provider=ProviderConfig(default="mock"))
198
+ app.state.system_prompt = "You are a test assistant."
199
+ app.state.start_time = time.time()
200
+ app.state.metrics = MetricsCollector()
201
+
202
+ app.add_middleware(RequestMiddleware)
203
+ app.add_middleware(RateLimitMiddleware, requests_per_minute=rpm)
204
+
205
+ from agent_bench.serving.routes import router
206
+
207
+ app.include_router(router)
208
+ return app
209
+
210
+
211
+ @pytest.fixture
212
+ def rate_limited_app():
213
+ return _make_rate_limited_app(rpm=3)
214
+
215
+
216
+ class TestRateLimiting:
217
+ @pytest.mark.asyncio
218
+ async def test_allows_normal_traffic(self, rate_limited_app):
219
+ """Requests within the limit all succeed."""
220
+ async with AsyncClient(
221
+ transport=ASGITransport(app=rate_limited_app), base_url="http://test"
222
+ ) as client:
223
+ for _ in range(3):
224
+ response = await client.get("/health")
225
+ assert response.status_code == 200
226
+
227
+ @pytest.mark.asyncio
228
+ async def test_blocks_excess(self, rate_limited_app):
229
+ """Request beyond the limit gets 429."""
230
+ async with AsyncClient(
231
+ transport=ASGITransport(app=rate_limited_app), base_url="http://test"
232
+ ) as client:
233
+ # Use up the quota
234
+ for _ in range(3):
235
+ await client.post("/ask", json={"question": "test"})
236
+ # Next request should be blocked
237
+ response = await client.post("/ask", json={"question": "test"})
238
+ assert response.status_code == 429
239
+
240
+ @pytest.mark.asyncio
241
+ async def test_retry_after_header(self, rate_limited_app):
242
+ """429 response includes Retry-After header."""
243
+ async with AsyncClient(
244
+ transport=ASGITransport(app=rate_limited_app), base_url="http://test"
245
+ ) as client:
246
+ # Exhaust quota on non-exempt path
247
+ for _ in range(3):
248
+ await client.post("/ask", json={"question": "test"})
249
+ response = await client.post("/ask", json={"question": "test"})
250
+ assert response.status_code == 429
251
+ assert "retry-after" in response.headers
252
+ assert int(response.headers["retry-after"]) > 0
253
+
254
+ @pytest.mark.asyncio
255
+ async def test_health_exempt(self):
256
+ """Health endpoint is never rate limited."""
257
+ app = _make_rate_limited_app(rpm=2)
258
+ async with AsyncClient(
259
+ transport=ASGITransport(app=app), base_url="http://test"
260
+ ) as client:
261
+ # Exhaust quota on non-exempt path
262
+ for _ in range(2):
263
+ await client.post("/ask", json={"question": "test"})
264
+ # Health should still work
265
+ response = await client.get("/health")
266
+ assert response.status_code == 200
267
+ # But another ask should be blocked
268
+ response = await client.post("/ask", json={"question": "test"})
269
+ assert response.status_code == 429