| """ |
| Tests for Stage 84 — request correlation IDs + tracing log fields. |
| |
| Coverage: |
| 1. sanitize_upstream_id — accept valid, reject empty/whitespace/ |
| too-long/special-chars |
| 2. generate_request_id — 16 hex chars, two calls produce |
| different values |
| 3. set/clear_request_context + contextvar reads |
| 4. JsonLogFormatter auto-injects request_id + tenant_id when |
| set in context; omits when None |
| 5. Integration: X-Request-ID present in response (generated) |
| 6. Integration: upstream X-Request-ID preserved (valid) |
| 7. Integration: malicious upstream rejected → fresh generated |
| 8. Integration: different requests get different IDs |
| 9. Integration: authenticated request sets tenant_id in context |
| visible to handler logs |
| 10. Context cleared after request returns (no leakage) |
| """ |
| import json |
| import logging |
|
|
| import pytest |
|
|
| pytest.importorskip("fastapi") |
| pytest.importorskip("httpx") |
| from fastapi.testclient import TestClient |
|
|
| from infra import OrgStateService |
| from infra.api import create_app |
| from infra.api.request_context import ( |
| clear_request_context, |
| current_request_id, |
| current_tenant_id, |
| generate_request_id, |
| sanitize_upstream_id, |
| set_request_context, |
| set_tenant_context, |
| ) |
| from infra.deployment.observability import JsonLogFormatter |
|
|
| |
| |
| |
|
|
def test_sanitize_accepts_uuid_hex():
    """A 32-character hex string (the ``uuid.hex`` shape) is returned as-is."""
    candidate = "aabbccdd11223344aabbccdd11223344"
    assert sanitize_upstream_id(candidate) == candidate
|
|
|
|
def test_sanitize_accepts_uuid_with_dashes():
    """A dash-separated UUID string is returned unchanged."""
    candidate = "aabbccdd-1122-3344-aabb-ccdd11223344"
    assert sanitize_upstream_id(candidate) == candidate
|
|
|
|
def test_sanitize_accepts_cloudflare_ray():
    """A CF-RAY style value (letters, digits and a dash) is returned unchanged."""
    ray_id = "8a3f12cd6a17abcd-DUB"
    assert sanitize_upstream_id(ray_id) == ray_id
|
|
|
|
def test_sanitize_rejects_none_and_empty():
    """``None``, the empty string and a whitespace-only string all yield ``None``."""
    for blank in (None, "", " "):
        assert sanitize_upstream_id(blank) is None
|
|
|
|
def test_sanitize_rejects_too_long():
    """Over-long upstream IDs are rejected; the boundary length is accepted.

    Defense against log bloat/injection: an upstream sending a huge request
    ID would inflate every log line. This test pins the boundary it
    exercises: 129 characters is rejected, 128 characters is accepted.
    """
    longest_accepted = 128

    too_long = "a" * (longest_accepted + 1)
    assert sanitize_upstream_id(too_long) is None

    at_limit = "a" * longest_accepted
    assert sanitize_upstream_id(at_limit) == at_limit
|
|
|
|
def test_sanitize_rejects_special_chars():
    """Values containing whitespace, punctuation or control chars yield ``None``.

    Such characters would break log grep and Sentry tag rules.
    """
    rejected_inputs = (
        "has space",
        "has\nnewline",
        "has;semi",
        "has<lt",
        "has\"quote",
        "has\x00null",
    )
    for bad in rejected_inputs:
        outcome = sanitize_upstream_id(bad)
        assert outcome is None, f"should have rejected {bad!r}"
|
|
|
|
def test_sanitize_strips_whitespace():
    """Surrounding spaces are trimmed and the remaining valid value is returned.

    A browser or proxy might pad the header value with spaces.
    """
    padded = " abc123 "
    assert sanitize_upstream_id(padded) == "abc123"
|
|
|
|
| |
| |
| |
|
|
def test_generate_request_id_is_16_hex():
    """A generated ID is exactly 16 characters, all lowercase hex digits."""
    hex_digits = set("0123456789abcdef")
    generated = generate_request_id()
    assert len(generated) == 16
    assert set(generated) <= hex_digits
|
|
|
|
def test_generate_request_id_unique_per_call():
    """100 consecutive generated IDs are all distinct.

    Birthday collisions are only expected after roughly 2^32 calls, so a
    small sample should differ with overwhelming probability.
    """
    sample_size = 100
    seen = set()
    for _ in range(sample_size):
        seen.add(generate_request_id())
    assert len(seen) == sample_size
|
|
|
|
| |
| |
| |
|
|
def test_context_var_default_none():
    """After an argument-less clear, both context readers return ``None``."""
    clear_request_context()
    request_id, tenant_id = current_request_id(), current_tenant_id()
    assert request_id is None
    assert tenant_id is None
|
|
|
|
def test_set_and_clear_context():
    """Values set via ``set_request_context`` are readable, then gone after clear."""
    reset_tokens = set_request_context("rid_abc", tenant_id="acme")
    try:
        assert current_request_id() == "rid_abc"
        assert current_tenant_id() == "acme"
    finally:
        # Always restore the context so a failure here cannot leak into
        # other tests.
        clear_request_context(reset_tokens)

    # Once cleared with the tokens, both readers are back to None.
    assert current_request_id() is None
    assert current_tenant_id() is None
|
|
|
|
def test_set_tenant_context_after_request_id():
    """``set_tenant_context`` adds a tenant without disturbing the request ID.

    Mirrors the intended flow: the middleware sets request_id first, and the
    auth dependency later upgrades the context with tenant_id.
    """
    reset_tokens = set_request_context("rid_xyz")
    try:
        # No tenant yet: only the request ID was supplied.
        assert current_tenant_id() is None

        set_tenant_context("globex")

        assert current_tenant_id() == "globex"
        assert current_request_id() == "rid_xyz"
    finally:
        clear_request_context(reset_tokens)
|
|
|
|
| |
| |
| |
|
|
def test_log_omits_fields_when_context_empty():
    """With a cleared context the JSON log line has no request/tenant keys."""
    clear_request_context()
    formatter = JsonLogFormatter()
    log_record = logging.LogRecord(
        name="test",
        level=logging.INFO,
        pathname="x",
        lineno=1,
        msg="hello",
        args=(),
        exc_info=None,
    )
    rendered = formatter.format(log_record)
    fields = json.loads(rendered)
    for absent_key in ("request_id", "tenant_id"):
        assert absent_key not in fields
|
|
|
|
def test_log_includes_request_id_when_set():
    """Both request_id and tenant_id from the context appear in the JSON line."""
    reset_tokens = set_request_context("rid_logtest", tenant_id="acme")
    try:
        formatter = JsonLogFormatter()
        log_record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="x",
            lineno=1,
            msg="hello",
            args=(),
            exc_info=None,
        )
        fields = json.loads(formatter.format(log_record))
        assert fields["request_id"] == "rid_logtest"
        assert fields["tenant_id"] == "acme"
    finally:
        # Restore the context even when an assertion fails.
        clear_request_context(reset_tokens)
|
|
|
|
def test_log_includes_request_id_only_when_no_tenant():
    """With ``tenant_id=None`` the log carries request_id but no tenant_id key.

    Example: log lines emitted by the middleware before authentication
    have a request ID but no tenant yet.
    """
    reset_tokens = set_request_context("rid_pre_auth", tenant_id=None)
    try:
        formatter = JsonLogFormatter()
        log_record = logging.LogRecord(
            name="test",
            level=logging.INFO,
            pathname="x",
            lineno=1,
            msg="pre-auth",
            args=(),
            exc_info=None,
        )
        fields = json.loads(formatter.format(log_record))
        assert fields["request_id"] == "rid_pre_auth"
        assert "tenant_id" not in fields
    finally:
        clear_request_context(reset_tokens)
|
|
|
|
| |
| |
| |
|
|
def _bootstrap(tmp_path):
    """Create a throwaway database with one tenant and one operator API key.

    Registers tenant ``"acme"`` in ``rctx.sqlite3`` under *tmp_path*, creates
    an operator-role key for it, and closes the service before returning.

    Returns:
        A ``(dbfile, keys)`` tuple: the database path as a string, and a dict
        mapping ``"acme_op"`` to the raw operator key.
    """
    db_path = str(tmp_path / "rctx.sqlite3")
    service = OrgStateService(db_path)
    try:
        service.register_tenant("acme", "ACME")
        operator_key = service.create_api_key("acme", role="operator")
        api_keys = {"acme_op": operator_key.raw}
    finally:
        # Close the service whether or not setup succeeded.
        service.close()
    return db_path, api_keys
|
|
|
|
| def _auth(k): |
| return {"Authorization": f"Bearer {k}"} |
|
|
|
|
def test_response_includes_generated_request_id(tmp_path):
    """``GET /health`` without an upstream ID returns a 16-char X-Request-ID."""
    db_path, _unused_keys = _bootstrap(tmp_path)
    client = TestClient(create_app(db_path))

    response = client.get("/health")

    assert response.status_code == 200
    request_id = response.headers.get("X-Request-ID")
    assert request_id is not None
    assert len(request_id) == 16
|
|
|
|
def test_upstream_request_id_preserved(tmp_path):
    """A valid caller-supplied X-Request-ID is echoed back in the response.

    This lets an operator trace a request across services by passing the
    same ID at every hop.
    """
    upstream_id = "trace-abc-123"
    db_path, _unused_keys = _bootstrap(tmp_path)
    client = TestClient(create_app(db_path))

    response = client.get("/health", headers={"X-Request-ID": upstream_id})

    assert response.headers["X-Request-ID"] == upstream_id
|
|
|
|
def test_malicious_upstream_request_id_rejected(tmp_path):
    """An oversized upstream X-Request-ID is discarded for a fresh 16-char ID.

    A multi-kilobyte ID would inflate every subsequent log line, so the
    response must not echo the 5000-character value that was sent.
    """
    oversized = "a" * 5000
    db_path, _unused_keys = _bootstrap(tmp_path)
    client = TestClient(create_app(db_path))

    response = client.get("/health", headers={"X-Request-ID": oversized})

    returned_id = response.headers["X-Request-ID"]
    assert returned_id != oversized
    assert len(returned_id) == 16
|
|
|
|
def test_different_requests_get_different_ids(tmp_path):
    """Ten sequential ``GET /health`` calls yield ten distinct X-Request-IDs."""
    request_count = 10
    db_path, _unused_keys = _bootstrap(tmp_path)
    client = TestClient(create_app(db_path))

    seen_ids = {
        client.get("/health").headers["X-Request-ID"]
        for _ in range(request_count)
    }

    assert len(seen_ids) == request_count
|
|
|
|
def test_context_visible_to_handler_via_request_state(tmp_path):
    """A handler can read the request ID from ``request.state.request_id``.

    Useful for explicit cases such as embedding the ID in an error
    response. The handler parameter carries the ``Request`` annotation so
    that FastAPI injects the request object.
    """
    from fastapi import Request

    upstream_id = "abc-trace-xyz"
    db_path, _unused_keys = _bootstrap(tmp_path)
    app = create_app(db_path)

    # Test-only route that echoes back what the middleware stored on
    # request.state.
    @app.get("/__rid_echo")
    async def echo_rid(request: Request):
        return {"echoed": request.state.request_id}

    client = TestClient(app)
    response = client.get(
        "/__rid_echo",
        headers={"X-Request-ID": upstream_id},
    )
    body = response.json()
    assert body["echoed"] == upstream_id
|
|
|
|
def test_log_during_authenticated_request_carries_tenant_id(tmp_path, caplog):
    """End-to-end: a log formatted inside an authenticated handler is tagged.

    The log line MUST include both request_id (from the middleware) and
    tenant_id (from the auth dependency). This is the headline value of
    Stage 84 — a customer says "I saw a 500 at 14:23, here's my request ID"
    and we can grep all log lines for that single request, including which
    tenant was hit.

    ``caplog`` is requested but not read; the handler formats its record
    directly with ``JsonLogFormatter`` and the test inspects that output.
    """
    from fastapi import Depends

    from infra.api.auth import make_dependency
    from infra.auth import ApiKey

    dbfile, keys = _bootstrap(tmp_path)
    app = create_app(dbfile)
    captured: list = []

    # Build the auth dependency from the service instance the app holds so
    # the test route authenticates against the same key store.
    svc = app.state.svc
    auth_dep = make_dependency(svc)

    # Test-only authenticated route: format one record while the request is
    # in flight and stash the rendered JSON for inspection.
    @app.get("/__authed_log")
    async def emit_authed_log(key: ApiKey = Depends(auth_dep)):
        fmt = JsonLogFormatter()
        record = logging.LogRecord(
            name="test.authed", level=logging.INFO, pathname="x",
            lineno=1, msg="handler ran", args=(), exc_info=None,
        )
        captured.append(fmt.format(record))
        return {"ok": True}

    client = TestClient(app)
    r = client.get("/__authed_log", headers=_auth(keys["acme_op"]))
    assert r.status_code == 200
    assert len(captured) == 1
    payload = json.loads(captured[0])

    # The request ID is generated per request, so only its presence is
    # checked; the tenant comes from the operator key created for "acme".
    assert "request_id" in payload
    assert payload["tenant_id"] == "acme"
|
|
|
|
def test_context_cleared_between_requests(tmp_path):
    """Tenant context from one request must not leak into the next.

    The contextvar cleanup in the middleware has to actually run; otherwise
    request N+1, made without auth, would see request N's tenant_id.
    """
    from fastapi import Request

    db_path, api_keys = _bootstrap(tmp_path)
    app = create_app(db_path)

    # Test-only unauthenticated route that reports the current context.
    @app.get("/__leak_check")
    async def show_ctx(request: Request):
        return {
            "rid": current_request_id(),
            "tid": current_tenant_id(),
        }

    client = TestClient(app)

    # First request: authenticated, so a tenant is associated with it.
    authed_response = client.get(
        "/tenants/acme",
        headers=_auth(api_keys["acme_op"]),
    )
    assert authed_response.status_code == 200

    # Second request: no credentials. It still gets a request ID of its
    # own, but must not see the previous request's tenant.
    observed = client.get("/__leak_check").json()
    assert observed["rid"] is not None
    assert observed["tid"] is None
|
|