Spaces:
Running
Running
File size: 2,590 Bytes
5539271 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 | """Tests for the in-memory rate limiter middleware."""
from __future__ import annotations
from unittest.mock import patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from infra.rate_limiter import RateLimiterMiddleware, _ClientBucket
class TestClientBucket:
    """Unit tests for the per-client timestamp bucket used by the limiter."""

    def test_count_recent_filters_old_entries(self):
        """Timestamps that fall before the window are not counted."""
        stamps = [1.0, 2.0, 3.0, 10.0]
        # Window of 5.0 ending at 12.0 covers [7.0, 12.0]: only 10.0 remains.
        assert _ClientBucket(timestamps=stamps).count_recent(window=5.0, now=12.0) == 1

    def test_count_recent_keeps_all_when_within_window(self):
        """Every timestamp inside a wide window is counted."""
        recent = _ClientBucket(timestamps=[10.0, 11.0, 12.0])
        assert recent.count_recent(window=60.0, now=15.0) == 3

    def test_add(self):
        """Each ``add`` call records exactly one timestamp."""
        fresh = _ClientBucket()
        for stamp in (1.0, 2.0):
            fresh.add(stamp)
        assert len(fresh.timestamps) == 2
@pytest.fixture
def limited_app():
    """Build a FastAPI app with a deliberately tiny rate limit for tests.

    The middleware allows 3 requests per 60-second window. ``/test`` is
    subject to the limit, while ``/health`` is listed in ``exclude_paths``.
    """
    application = FastAPI()
    application.add_middleware(
        RateLimiterMiddleware,
        requests_per_window=3,
        window_seconds=60,
        exclude_paths=("/health",),
    )

    @application.get("/test")
    def limited_route():
        return {"ok": True}

    @application.get("/health")
    def health():
        return {"status": "ok"}

    return application
@pytest.fixture
def client(limited_app):
    """Provide a TestClient wired to the rate-limited app fixture."""
    test_client = TestClient(limited_app)
    return test_client
class TestRateLimiterMiddleware:
    """End-to-end tests of RateLimiterMiddleware through the ``client`` fixture.

    The ``limited_app`` fixture allows 3 requests per 60-second window and
    excludes ``/health`` from limiting.
    """

    # Must match ``requests_per_window`` in the ``limited_app`` fixture.
    LIMIT = 3

    def _exhaust(self, client):
        """Spend the whole request budget on ``/test``.

        Each in-budget request is asserted to succeed, so a failure is
        reported at the request that went wrong, not later.
        """
        for _ in range(self.LIMIT):
            assert client.get("/test").status_code == 200

    def test_allows_requests_under_limit(self, client):
        for _ in range(self.LIMIT):
            resp = client.get("/test")
            assert resp.status_code == 200

    def test_blocks_requests_over_limit(self, client):
        self._exhaust(client)
        resp = client.get("/test")
        assert resp.status_code == 429
        assert resp.json()["detail"] == "Too many requests"
        assert "Retry-After" in resp.headers

    def test_health_excluded_from_limit(self, client):
        self._exhaust(client)
        # Prove the limiter is actually engaged. Without this, a 200 from
        # /health would pass even if no request were ever blocked.
        assert client.get("/test").status_code == 429
        # Health should still work
        resp = client.get("/health")
        assert resp.status_code == 200

    def test_window_resets(self, client):
        """After the window expires, requests should be allowed again."""
        self._exhaust(client)
        assert client.get("/test").status_code == 429
        # Simulate time passing far beyond the 60 s window.
        # NOTE(review): this patches ``monotonic`` on the global ``time``
        # module. It only affects the limiter if infra.rate_limiter calls
        # ``time.monotonic()`` via the module attribute (not
        # ``from time import monotonic``) — confirm against the implementation.
        with patch("time.monotonic", return_value=1e12):
            resp = client.get("/test")
        assert resp.status_code == 200
|