File size: 12,385 Bytes
d2d1903
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
"""
Tests for Stage 84 β€” request correlation IDs + tracing log fields.

Coverage:
1. sanitize_upstream_id β€” accept valid, reject empty/whitespace/
   too-long/special-chars
2. generate_request_id β€” 16 hex chars, two calls produce
   different values
3. set/clear_request_context + contextvar reads
4. JsonLogFormatter auto-injects request_id + tenant_id when
   set in context; omits when None
5. Integration: X-Request-ID present in response (generated)
6. Integration: upstream X-Request-ID preserved (valid)
7. Integration: malicious upstream rejected β†’ fresh generated
8. Integration: different requests get different IDs
9. Integration: authenticated request sets tenant_id in context
   visible to handler logs
10. Context cleared after request returns (no leakage)
"""
import json
import logging

import pytest

pytest.importorskip("fastapi")
pytest.importorskip("httpx")
from fastapi.testclient import TestClient

from infra import OrgStateService
from infra.api import create_app
from infra.api.request_context import (
    clear_request_context,
    current_request_id,
    current_tenant_id,
    generate_request_id,
    sanitize_upstream_id,
    set_request_context,
    set_tenant_context,
)
from infra.deployment.observability import JsonLogFormatter

# =========================================================
# sanitize_upstream_id — unit
# =========================================================

def test_sanitize_accepts_uuid_hex():
    """A 32-character hex string (the uuid.hex shape) comes back unchanged."""
    candidate = "aabbccdd11223344aabbccdd11223344"
    assert sanitize_upstream_id(candidate) == candidate


def test_sanitize_accepts_uuid_with_dashes():
    """A dash-separated UUID string comes back unchanged."""
    candidate = "aabbccdd-1122-3344-aabb-ccdd11223344"
    assert sanitize_upstream_id(candidate) == candidate


def test_sanitize_accepts_cloudflare_ray():
    """A CF-RAY style value (letters, digits and a dash) comes back unchanged."""
    ray_id = "8a3f12cd6a17abcd-DUB"
    assert sanitize_upstream_id(ray_id) == ray_id


def test_sanitize_rejects_none_and_empty():
    """None, the empty string and whitespace-only input all yield None."""
    for blank in (None, "", "   "):
        assert sanitize_upstream_id(blank) is None


def test_sanitize_rejects_too_long():
    """Oversized IDs are dropped so an upstream cannot bloat every log line.

    129 characters is rejected; exactly 128 characters is still accepted.
    """
    at_limit = "a" * 128
    over_limit = "a" * 129
    assert sanitize_upstream_id(over_limit) is None
    # The boundary value itself must survive untouched.
    assert sanitize_upstream_id(at_limit) == at_limit


def test_sanitize_rejects_special_chars():
    """Values with spaces, newlines, punctuation or NUL bytes yield None.

    Such characters would break log grepping and Sentry tag rules.
    """
    rejected_inputs = (
        "has space",
        "has\nnewline",
        "has;semi",
        "has<lt",
        "has\"quote",
        "has\x00null",
    )
    for candidate in rejected_inputs:
        outcome = sanitize_upstream_id(candidate)
        assert outcome is None, f"should have rejected {candidate!r}"


def test_sanitize_strips_whitespace():
    """Surrounding spaces (e.g. proxy padding) are trimmed from a valid ID."""
    padded = "  abc123  "
    assert sanitize_upstream_id(padded) == "abc123"


# =========================================================
# generate_request_id
# =========================================================

def test_generate_request_id_is_16_hex():
    """A generated ID is 16 characters long, all lowercase hex digits."""
    generated = generate_request_id()
    hex_alphabet = set("0123456789abcdef")
    assert len(generated) == 16
    assert set(generated) <= hex_alphabet


def test_generate_request_id_unique_per_call():
    """100 consecutive generated IDs are all distinct.

    With 16 hex characters a birthday collision is only expected after
    roughly 2^32 calls, so a duplicate here would indicate a real bug.
    """
    sample_size = 100
    seen = set()
    for _ in range(sample_size):
        seen.add(generate_request_id())
    assert len(seen) == sample_size


# =========================================================
# contextvars — direct API
# =========================================================

def test_context_var_default_none():
    """After an argument-less clear, both context accessors report None."""
    clear_request_context()
    request_id = current_request_id()
    tenant_id = current_tenant_id()
    assert request_id is None
    assert tenant_id is None


def test_set_and_clear_context():
    """Values set on the context are readable, then gone after clearing."""
    reset_tokens = set_request_context("rid_abc", tenant_id="acme")
    try:
        seen_request = current_request_id()
        seen_tenant = current_tenant_id()
        assert seen_request == "rid_abc"
        assert seen_tenant == "acme"
    finally:
        clear_request_context(reset_tokens)
    # Clearing with the tokens must bring both accessors back to None.
    assert current_request_id() is None
    assert current_tenant_id() is None


def test_set_tenant_context_after_request_id():
    """The tenant can be attached after the request ID is already set.

    Models the intended flow where the request ID is set first and the
    tenant is added later by authentication; the request ID must survive.
    """
    reset_tokens = set_request_context("rid_xyz")
    try:
        # No tenant was supplied yet.
        assert current_tenant_id() is None
        set_tenant_context("globex")
        tenant_after = current_tenant_id()
        request_after = current_request_id()
        assert tenant_after == "globex"
        assert request_after == "rid_xyz"
    finally:
        clear_request_context(reset_tokens)


# =========================================================
# JsonLogFormatter — auto-injection
# =========================================================

def test_log_omits_fields_when_context_empty():
    """With a cleared context the JSON output has no request/tenant keys."""
    clear_request_context()
    log_record = logging.LogRecord(
        "test", logging.INFO, "x", 1, "hello", (), None,
    )
    rendered = JsonLogFormatter().format(log_record)
    decoded = json.loads(rendered)
    for absent_key in ("request_id", "tenant_id"):
        assert absent_key not in decoded


def test_log_includes_request_id_when_set():
    """With both values in context the JSON output carries both fields."""
    reset_tokens = set_request_context("rid_logtest", tenant_id="acme")
    try:
        log_record = logging.LogRecord(
            "test", logging.INFO, "x", 1, "hello", (), None,
        )
        rendered = JsonLogFormatter().format(log_record)
        decoded = json.loads(rendered)
        assert decoded["request_id"] == "rid_logtest"
        assert decoded["tenant_id"] == "acme"
    finally:
        clear_request_context(reset_tokens)


def test_log_includes_request_id_only_when_no_tenant():
    """A request ID without a tenant yields request_id but no tenant_id key.

    This is the shape expected for log lines emitted before authentication.
    """
    reset_tokens = set_request_context("rid_pre_auth", tenant_id=None)
    try:
        log_record = logging.LogRecord(
            "test", logging.INFO, "x", 1, "pre-auth", (), None,
        )
        rendered = JsonLogFormatter().format(log_record)
        decoded = json.loads(rendered)
        assert decoded["request_id"] == "rid_pre_auth"
        assert "tenant_id" not in decoded
    finally:
        clear_request_context(reset_tokens)


# =========================================================
# Integration — middleware in a real FastAPI app
# =========================================================

def _bootstrap(tmp_path):
    """Seed a fresh state file under tmp_path for the integration tests.

    Opens an OrgStateService on ``rctx.sqlite3`` inside ``tmp_path``,
    registers the ``acme`` tenant and creates one operator API key for it.
    The service is always closed before returning.

    Returns:
        A ``(db_path, keys)`` tuple where ``keys["acme_op"]`` holds the raw
        operator key.
    """
    db_path = str(tmp_path / "rctx.sqlite3")
    service = OrgStateService(db_path)
    try:
        service.register_tenant("acme", "ACME")
        operator_key = service.create_api_key("acme", role="operator")
        api_keys = {"acme_op": operator_key.raw}
    finally:
        service.close()
    return db_path, api_keys


def _auth(k):
    """Return request headers carrying ``k`` as a Bearer token."""
    bearer_value = f"Bearer {k}"
    return {"Authorization": bearer_value}


def test_response_includes_generated_request_id(tmp_path):
    """A request without X-Request-ID gets a 16-character ID in the response."""
    db_path = _bootstrap(tmp_path)[0]
    response = TestClient(create_app(db_path)).get("/health")
    assert response.status_code == 200
    request_id = response.headers.get("X-Request-ID")
    assert request_id is not None
    # 16 characters is the length of a locally generated ID.
    assert len(request_id) == 16


def test_upstream_request_id_preserved(tmp_path):
    """A valid caller-supplied X-Request-ID is echoed back unchanged.

    This lets an operator trace one request across services by passing the
    same ID at every hop.
    """
    db_path = _bootstrap(tmp_path)[0]
    upstream_id = "trace-abc-123"
    response = TestClient(create_app(db_path)).get(
        "/health", headers={"X-Request-ID": upstream_id},
    )
    assert response.headers["X-Request-ID"] == upstream_id


def test_malicious_upstream_request_id_rejected(tmp_path):
    """An oversized X-Request-ID is discarded in favour of a generated one.

    Accepting it would bloat every subsequent log line for the request.
    """
    db_path = _bootstrap(tmp_path)[0]
    oversized = "a" * 5000
    response = TestClient(create_app(db_path)).get(
        "/health", headers={"X-Request-ID": oversized},
    )
    returned_id = response.headers["X-Request-ID"]
    assert returned_id != oversized
    assert len(returned_id) == 16


def test_different_requests_get_different_ids(tmp_path):
    """Ten consecutive requests receive ten distinct X-Request-ID values."""
    db_path = _bootstrap(tmp_path)[0]
    http = TestClient(create_app(db_path))
    request_count = 10
    seen_ids = {
        http.get("/health").headers["X-Request-ID"]
        for _ in range(request_count)
    }
    assert len(seen_ids) == request_count


def test_context_visible_to_handler_via_request_state(tmp_path):
    """A handler can read the request ID from ``request.state.request_id``.

    Useful for explicit cases such as embedding the ID in an error body.
    The ``Request`` annotation is what makes FastAPI inject the request.
    """
    from fastapi import Request

    db_path = _bootstrap(tmp_path)[0]
    app = create_app(db_path)

    @app.get("/__rid_echo")
    async def echo_rid(request: Request):
        echoed_id = request.state.request_id
        return {"echoed": echoed_id}

    upstream_id = "abc-trace-xyz"
    response = TestClient(app).get(
        "/__rid_echo", headers={"X-Request-ID": upstream_id},
    )
    assert response.json()["echoed"] == upstream_id


def test_log_during_authenticated_request_carries_tenant_id(tmp_path, caplog):
    """End-to-end: a log line formatted inside an authenticated handler
    must include both request_id and tenant_id.

    This is the headline value of Stage 84: a customer says 'I saw 500 at
    14:23, here's my request ID' and we can grep all log lines for that
    single request, including which tenant was hit.

    The handler formats a LogRecord with JsonLogFormatter directly and the
    test inspects that output; ``caplog`` is requested but not read.
    """
    from fastapi import Depends

    from infra.api.auth import make_dependency
    from infra.auth import ApiKey

    dbfile, keys = _bootstrap(tmp_path)
    app = create_app(dbfile)
    captured: list = []

    # Build the auth dependency by hand from the app's own service so the
    # route added below goes through authentication like the built-in ones.
    auth_dep = make_dependency(app.state.svc)

    @app.get("/__authed_log")
    async def emit_authed_log(key: ApiKey = Depends(auth_dep)):
        # Capture the formatter's OUTPUT, not the raw record: the context
        # fields are injected at format time.
        record = logging.LogRecord(
            name="test.authed", level=logging.INFO, pathname="x",
            lineno=1, msg="handler ran", args=(), exc_info=None,
        )
        captured.append(JsonLogFormatter().format(record))
        return {"ok": True}

    client = TestClient(app)
    r = client.get("/__authed_log", headers=_auth(keys["acme_op"]))
    assert r.status_code == 200
    assert len(captured) == 1
    payload = json.loads(captured[0])
    # request_id is expected from the middleware; tenant_id is expected
    # from the auth dependency (presumably via set_tenant_context).
    assert "request_id" in payload
    assert payload["tenant_id"] == "acme"


def test_context_cleared_between_requests(tmp_path):
    """Tenant context from one request must not leak into the next.

    If the per-request cleanup did not run, an unauthenticated request
    following an authenticated one would still see the earlier tenant_id.
    """
    from fastapi import Request

    db_path, api_keys = _bootstrap(tmp_path)
    app = create_app(db_path)

    @app.get("/__leak_check")
    async def show_ctx(request: Request):
        return {"rid": current_request_id(), "tid": current_tenant_id()}

    http = TestClient(app)

    # First request is authenticated, so a tenant is in play while it runs.
    authed = http.get("/tenants/acme", headers=_auth(api_keys["acme_op"]))
    assert authed.status_code == 200

    # Second request carries no credentials and must start from a clean
    # context: a request ID of its own, and no tenant.
    snapshot = http.get("/__leak_check").json()
    assert snapshot["rid"] is not None
    assert snapshot["tid"] is None