"""
infra.api.auth_cache (Stage 134): auth-header → tenant-policy
LRU cache.

Stage 128 added a per-tenant rate-limit fail-closed override.
The rate-limit middleware looks up tenant + policy on every
authenticated request via ``svc.api_keys.verify`` (a single
SHA256-indexed query) plus
``svc.get_tenant_rate_limit_fail_closed``. Both are
sub-millisecond, but at 1k req/s the constant load on the
connection pool adds up, and a single API key arrives
THOUSANDS of times/second from the same customer. Caching
the resolution removes these redundant lookups.
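
A minimal, self-contained sketch of that cache-aside pattern
follows. ``CountingBackend`` and ``resolve_cached`` are
hypothetical names. The backend stands in for
``svc.api_keys.verify`` plus the policy lookup and only counts
calls. The plain dict has none of the size or TTL limits that
the discipline list adds.

```python
from typing import Dict, Optional, Tuple

# Cached value: (tenant_id or None, fail-closed policy or None).
Resolution = Tuple[Optional[str], Optional[bool]]


class CountingBackend:
    # Hypothetical stand-in for the two DB lookups (API key -> tenant,
    # tenant -> fail-closed policy); it only counts how often it runs.
    def __init__(self, keys: Dict[str, Resolution]) -> None:
        self._keys = keys
        self.lookups = 0

    def resolve(self, token: str) -> Resolution:
        self.lookups += 1
        return self._keys.get(token, (None, None))


def resolve_cached(backend: CountingBackend,
                   cache: Dict[str, Resolution],
                   token: str) -> Optional[bool]:
    # Cache-aside: serve from the cache when possible; otherwise ask
    # the backend once and remember the answer, even "unknown".
    if token not in cache:
        cache[token] = backend.resolve(token)
    return cache[token][1]


backend = CountingBackend({'key-a': ('tenant-a', True)})
cache: Dict[str, Resolution] = {}
# 1,000 requests carrying the same API key.
results = [resolve_cached(backend, cache, 'key-a') for _ in range(1000)]
```

After 1,000 requests with the same key, ``backend.lookups`` is
1. An unknown token is looked up once and is then served from
the cache as ``None``.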

Discipline:
* **Bounded size**: hard cap on entries (default 4096) with
  LRU eviction. The memory profile stays predictable even
  when a noisy client hammers with random bogus tokens:
  bogus rows are rarely re-read, so they sink to the
  least-recently-used end and are evicted first as real
  traffic arrives.
* **TTL**: each entry has an expiry timestamp (default 60s).
  A policy change made via
  ``infra tenant rate-limit-fail-closed-set`` reaches all
  pods within the TTL window. Operators who need faster
  propagation can call ``cache.invalidate_tenant(tenant_id)``
  from the service layer. This clears only the calling
  process's cache, so other pods still wait out the TTL.
* **Negative caching**: unknown tokens are cached as
  ``(None, None, expires_at)`` with a shorter TTL (default
  10s). A repeated bogus token therefore reaches the DB
  roughly once per negative-TTL window instead of on every
  request.
* **Thread-safe**: a single coarse lock around the dict, the
  same pattern as ``ratelimit.MemoryBackend``. Contention at
  realistic API rates is negligible (dict ops only, no I/O
  inside the lock).
* **Fail-soft**: if the underlying service raises, the
  resolver catches the exception and returns None (no
  override, so the backend default applies). It caches
  nothing, so the next request retries the lookup. Caching
  ONLY improves read performance; it must never reduce
  correctness.
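
The TTL, negative-caching and fail-soft rules above can be
sketched in isolation. This is an illustration under stated
assumptions, not ``TenantAuthCache`` itself:

* ``FakeClock``, ``TtlCache`` and ``lookup`` are hypothetical
  names.
* The size cap and the lock are omitted.
* The clock is injected so expiry can be exercised without
  sleeping.

```python
from typing import Callable, Dict, Optional, Tuple


class FakeClock:
    # Hypothetical test clock standing in for time.monotonic().
    def __init__(self) -> None:
        self.now = 0.0

    def __call__(self) -> float:
        return self.now


class TtlCache:
    # Resolved tokens live for positive_ttl seconds; negative entries
    # (tenant_id is None) get the shorter negative_ttl.
    def __init__(self, clock: Callable[[], float],
                 positive_ttl: float = 60.0,
                 negative_ttl: float = 10.0) -> None:
        self._clock = clock
        self._positive_ttl = positive_ttl
        self._negative_ttl = negative_ttl
        self._store: Dict[str, Tuple[Optional[str], float]] = {}

    def get(self, key: str) -> Tuple[bool, Optional[str]]:
        entry = self._store.get(key)
        if entry is None:
            return False, None
        tenant_id, expires_at = entry
        if self._clock() >= expires_at:
            del self._store[key]  # expired: drop it and report a miss
            return False, None
        return True, tenant_id

    def put(self, key: str, tenant_id: Optional[str]) -> None:
        ttl = (self._positive_ttl if tenant_id is not None
               else self._negative_ttl)
        self._store[key] = (tenant_id, self._clock() + ttl)


def lookup(cache: TtlCache, key: str,
           backend: Callable[[str], Optional[str]]) -> Optional[str]:
    # Fail-soft: a backend error yields None and is NOT cached, so the
    # next request retries instead of pinning the failure for a TTL.
    hit, tenant_id = cache.get(key)
    if hit:
        return tenant_id
    try:
        tenant_id = backend(key)
    except Exception:
        return None
    cache.put(key, tenant_id)
    return tenant_id


clock = FakeClock()
cache = TtlCache(clock)
cache.put('good', 'tenant-a')
cache.put('bogus', None)
clock.now = 15.0  # past the 10s negative TTL, inside the 60s positive TTL
```

At 15s the negative entry has expired while the positive one
is still served.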

Stdlib only: uses ``collections.OrderedDict`` for LRU
ordering and ``time.monotonic`` for TTL.
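
A standalone sketch of how those two stdlib pieces yield LRU
behavior follows. ``lru_put``, ``lru_touch`` and ``MAX_SIZE``
are hypothetical names. ``TenantAuthCache`` inlines the same
``move_to_end`` / ``popitem(last=False)`` calls under its
lock.

```python
import time
from collections import OrderedDict

MAX_SIZE = 3  # hypothetical tiny cap so eviction is easy to see


def lru_touch(store: OrderedDict, key: str) -> None:
    # A read marks the key most-recently-used by moving it to the end.
    store.move_to_end(key)


def lru_put(store: OrderedDict, key: str, value: object) -> None:
    # Plain assignment to an existing key keeps its old position, so
    # move it to the end explicitly, then enforce the size cap.
    store[key] = value
    store.move_to_end(key)
    if len(store) > MAX_SIZE:
        store.popitem(last=False)  # drop the least-recently-used key


store: OrderedDict = OrderedDict()
for k in ('a', 'b', 'c'):
    lru_put(store, k, k.upper())
lru_touch(store, 'a')     # order is now b, c, a
lru_put(store, 'd', 'D')  # over the cap: 'b' is evicted

# time.monotonic() is unaffected by system clock updates, so TTL
# arithmetic cannot jump when the wall clock is adjusted.
t0 = time.monotonic()
```

Reading ``a`` protects it, so inserting ``d`` evicts ``b``,
which leaves ``c, a, d``.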
"""
from __future__ import annotations

import threading
import time
from collections import OrderedDict
from typing import Optional, Tuple

# Sentinel for "we looked up this token and it didn't resolve
# to a tenant", distinct from "we haven't looked it up yet"
# (= dict miss). Caching negative results stops a stream of
# bogus tokens from re-hitting the DB.
# NOTE: currently unused; TenantAuthCache stores negative
# results as ``(None, None, expires_at)`` entries instead.
_UNKNOWN = object()


class TenantAuthCache:
    """Maps a truncated SHA-256 of the Authorization header →
    (tenant_id, fail_closed_policy, expires_at). LRU-bounded,
    TTL-expiring.

    Entry tuple shape:
        (tenant_id_str_OR_None, fail_closed_bool_OR_None,
         expires_at_monotonic)

    * (tid, policy, exp): known tenant with optional policy
    * (None, None, exp):  negative cache (token didn't resolve)
    """

    def __init__(self, *,
                  max_size: int = 4096,
                  positive_ttl_seconds: int = 60,
                  negative_ttl_seconds: int = 10):
        if max_size <= 0:
            raise ValueError("max_size must be > 0")
        if positive_ttl_seconds <= 0 or negative_ttl_seconds <= 0:
            raise ValueError("TTLs must be > 0 seconds")
        self._max_size = max_size
        self._positive_ttl = positive_ttl_seconds
        self._negative_ttl = negative_ttl_seconds
        self._store: OrderedDict[str, Tuple[
            Optional[str], Optional[bool], float,
        ]] = OrderedDict()
        self._lock = threading.Lock()
        # Counters for /metrics + ops debugging.
        self.hits = 0
        self.misses = 0
        self.evictions = 0
        self.negative_hits = 0

    @staticmethod
    def _clock() -> float:
        """Monotonic clock reading in seconds. Kept as its own
        static method so tests can monkeypatch it without
        touching ``time``."""
        return time.monotonic()

    @staticmethod
    def _key_for(authorization: str) -> str:
        """Hash the authorization header. We deliberately don't
        keep the raw bearer token in the cache (defense in depth:
        a process memory dump shows only hashes for cached
        entries). A sha256 hex digest truncated to 32 hex chars
        still carries 128 bits, which makes collisions in a
        4k-entry cache negligible (birthday bound around 2^-105)."""
        import hashlib
        return hashlib.sha256(
            authorization.encode("utf-8"),
        ).hexdigest()[:32]

    def get(self, authorization: str,
              ) -> Tuple[bool, Optional[str], Optional[bool]]:
        """Returns ``(hit, tenant_id_or_None, policy_or_None)``.

        ``hit=True`` means the cache had a fresh entry; the
        tenant_id may still be None (negative cache). It is the
        caller's responsibility to decide what to do on a
        negative hit. Typically, treat it as "no tenant
        resolvable" and let the backend default fire.

        ``hit=False`` means the caller must compute the answer
        and call ``put`` to seed the cache.
        """
        if not authorization:
            return False, None, None
        key = self._key_for(authorization)
        now = self._clock()
        with self._lock:
            entry = self._store.get(key)
            if entry is None:
                self.misses += 1
                return False, None, None
            tid, policy, exp = entry
            if now >= exp:
                # Expired β€” evict; treat as miss
                self._store.pop(key, None)
                self.misses += 1
                return False, None, None
            # Fresh β€” move-to-end for LRU
            self._store.move_to_end(key)
            self.hits += 1
            if tid is None:
                self.negative_hits += 1
            return True, tid, policy

    def put(self, authorization: str,
              tenant_id: Optional[str],
              policy: Optional[bool],
              ) -> None:
        """Seed a cache entry. ``tenant_id=None`` indicates a
        negative cache entry (token didn't resolve)."""
        if not authorization:
            return
        key = self._key_for(authorization)
        ttl = (self._positive_ttl
                if tenant_id is not None
                else self._negative_ttl)
        expires_at = self._clock() + ttl
        with self._lock:
            # If the key exists, overwrite the value and move it to
            # the most-recently-used end (plain assignment keeps the
            # old position). Otherwise insert, evicting if over cap.
            if key in self._store:
                self._store[key] = (tenant_id, policy, expires_at)
                self._store.move_to_end(key)
                return
            self._store[key] = (tenant_id, policy, expires_at)
            if len(self._store) > self._max_size:
                # Evict the least-recently-used entry (popitem(last=False)).
                self._store.popitem(last=False)
                self.evictions += 1

    def invalidate_tenant(self, tenant_id: str) -> int:
        """Drop every entry in this process's cache that points to
        ``tenant_id``. Called from the service layer after
        ``set_tenant_rate_limit_fail_closed`` so the change
        propagates faster than ``positive_ttl``.

        Returns the count of entries dropped."""
        if not tenant_id:
            return 0
        with self._lock:
            to_drop = [
                k for k, (tid, _, _) in self._store.items()
                if tid == tenant_id
            ]
            for k in to_drop:
                self._store.pop(k, None)
            return len(to_drop)

    def invalidate_all(self) -> int:
        """Drop every entry. Called on encryption key rotation
        or other broad config changes. Returns the count."""
        with self._lock:
            n = len(self._store)
            self._store.clear()
            return n

    def size(self) -> int:
        """Current number of entries, including expired ones that
        have not yet been read or evicted."""
        with self._lock:
            return len(self._store)


def resolve_tenant_fail_closed_cached(
        svc, authorization: Optional[str],
        cache: Optional[TenantAuthCache],
        ) -> Optional[bool]:
    """Stage 134: replacement for ``app._resolve_tenant_fail_closed``
    that consults ``cache`` first. Same return contract
    as Stage 128:

      * None: no override (anonymous request, unknown token,
        tenant has no override, or lookup raised)
      * True/False: explicit tenant override

    ``cache=None`` falls back to the uncached path (useful for
    tests that want to disable caching cleanly).
    """
    if not authorization:
        return None
    if cache is not None:
        hit, tid, policy = cache.get(authorization)
        if hit:
            # Cached: tid may be None (negative), meaning the
            # token didn't resolve last time. Return policy (None
            # for a negative entry; the tenant's actual override
            # for a positive one).
            return policy
    # Miss / no cache β€” do the real lookup, store the result.
    try:
        raw = authorization.strip()
        if raw.lower().startswith("bearer "):
            raw = raw[7:].strip()
        api_key = svc.api_keys.verify(raw)
        if api_key is None:
            if cache is not None:
                cache.put(authorization, None, None)
            return None
        policy = svc.get_tenant_rate_limit_fail_closed(
            api_key.tenant_id,
        )
        if cache is not None:
            cache.put(authorization, api_key.tenant_id, policy)
        return policy
    except Exception:    # noqa: BLE001
        # Same fail-soft as Stage 128: never propagate, and cache
        # nothing so the next request retries the lookup.
        return None