orgstate / infra /api /auth_cache.py
Legal-i's picture
Initial OrgState deploy via Stage 150 free-tier stack
d2d1903 verified
"""
infra.api.auth_cache — Stage 134 — auth-header → tenant policy
LRU cache.
Stage 128 added per-tenant rate-limit fail-closed override.
The rate-limit middleware looks up tenant + policy on every
authenticated request via ``svc.api_keys.verify`` (single
SHA256-indexed query) + ``svc.get_tenant_rate_limit_fail_
closed``. Both are sub-millisecond, but at 1k req/s the
constant load on the connection pool adds up — and a single
API key arrives THOUSANDS of times/second from the same
customer. Caching the resolution removes the wasted lookups.
Discipline:
* **Bounded size**: hard cap on entries (default 4096) with
LRU eviction. Memory profile stays predictable even when
a noisy client hammers with random bogus tokens (the
bogus rows expire when evicted by real traffic).
* **TTL**: each entry has an expiry timestamp (default 60s).
A tenant flipping the policy via ``infra tenant rate-
limit-fail-closed-set`` propagates to all pods within
the TTL window. Operators who need instant propagation
can call ``cache.invalidate_tenant(tenant_id)`` from the
service layer.
* **Negative caching**: unknown tokens cache as
``(None, None, expires_at)`` for a shorter TTL (default
10s). Bogus traffic stops hammering the DB after the
first lookup.
* **Thread-safe**: single coarse lock around the dict — same
pattern as ``ratelimit.MemoryBackend``. Contention at
realistic API rates is negligible (dict ops only, no I/O
inside the lock).
* **Fail-soft**: if the underlying service raises, the
caller catches and treats as cache miss + uses backend
default. Caching ONLY adds to read performance; it must
never reduce correctness.
Stdlib only — uses ``collections.OrderedDict`` for LRU
ordering and ``time.monotonic`` for TTL.
"""
from __future__ import annotations
import threading
import time
from collections import OrderedDict
from typing import Optional, Tuple
# Sentinel for "we looked up this token and it didn't resolve
# to a tenant" — distinct from "we haven't looked it up yet"
# (= dict miss). Caching negative results stops a stream of
# bogus tokens from re-hitting the DB.
# NOTE(review): nothing in this module as shown references
# ``_UNKNOWN``; negative entries are stored as
# ``(None, None, expires_at)`` instead. Confirm no other module
# imports it before removing.
_UNKNOWN = object()
class TenantAuthCache:
    """Bounded, TTL-expiring LRU map from a hashed ``Authorization``
    header to ``(tenant_id, fail_closed_policy, expires_at)``.

    Entry tuple shape:
        (tenant_id_str_OR_None, fail_closed_bool_OR_None,
         expires_at_monotonic)

    * (tid, policy, exp) — known tenant with optional policy
    * (None, None, exp)  — negative cache (token didn't resolve)

    Every access to the store, and every counter update, happens
    while holding ``self._lock``.
    """

    def __init__(self, *,
                 max_size: int = 4096,
                 positive_ttl_seconds: int = 60,
                 negative_ttl_seconds: int = 10):
        """Create an empty cache.

        Args:
            max_size: maximum number of entries kept; inserting
                beyond it evicts the least recently used entry.
            positive_ttl_seconds: lifetime of entries that carry a
                tenant id.
            negative_ttl_seconds: lifetime of entries recorded for
                tokens that did not resolve (``tenant_id=None``).

        Raises:
            ValueError: if ``max_size`` or either TTL is not > 0.
        """
        if max_size <= 0:
            raise ValueError("max_size must be > 0")
        if positive_ttl_seconds <= 0 or negative_ttl_seconds <= 0:
            raise ValueError("TTLs must be > 0 seconds")
        self._max_size = max_size
        self._positive_ttl = positive_ttl_seconds
        self._negative_ttl = negative_ttl_seconds
        # Insertion/recency order doubles as the LRU order: the
        # first item is always the eviction candidate.
        self._store: OrderedDict[str, Tuple[
            Optional[str], Optional[bool], float,
        ]] = OrderedDict()
        self._lock = threading.Lock()
        # Counters for /metrics + ops debugging.
        self.hits = 0
        self.misses = 0
        self.evictions = 0
        self.negative_hits = 0

    @staticmethod
    def _clock() -> float:
        """Return the current monotonic time in seconds.

        Kept as its own hook so tests can substitute a fake clock
        without patching ``time`` itself.
        """
        return time.monotonic()

    @staticmethod
    def _key_for(authorization: str) -> str:
        """Return the cache key for an authorization header.

        The raw header is hashed so the bearer token itself is not
        kept in the store (defense in depth). The key is the first
        32 hex characters of the SHA-256 digest, i.e. 128 bits —
        far more than a cache of a few thousand entries needs to
        make collisions a non-issue.

        ``surrogatepass`` keeps hashing total over ``str``: a header
        that carries lone surrogates (e.g. bytes decoded with
        ``surrogateescape``) would otherwise raise
        ``UnicodeEncodeError`` on a client-controlled value. Keys
        for strings without surrogates are unaffected.
        """
        import hashlib
        digest = hashlib.sha256(
            authorization.encode("utf-8", "surrogatepass"),
        )
        return digest.hexdigest()[:32]

    def get(self, authorization: str,
            ) -> Tuple[bool, Optional[str], Optional[bool]]:
        """Look up ``authorization``.

        Returns ``(hit, tenant_id_or_None, policy_or_None)``.

        ``hit=True`` means the cache had a fresh entry; the
        tenant_id may still be None (negative cache), in which case
        the caller decides what "token did not resolve" means.

        ``hit=False`` (empty header, unknown key, or expired entry)
        means the caller must compute the answer and call ``put``
        to seed the cache. An expired entry is removed on the way.
        """
        if not authorization:
            return False, None, None
        key = self._key_for(authorization)
        now = self._clock()
        with self._lock:
            entry = self._store.get(key)
            if entry is None:
                self.misses += 1
                return False, None, None
            tenant_id, policy, expires_at = entry
            if now >= expires_at:
                # Expired — drop it and report a miss.
                self._store.pop(key, None)
                self.misses += 1
                return False, None, None
            # Fresh — mark as most recently used.
            self._store.move_to_end(key)
            self.hits += 1
            if tenant_id is None:
                self.negative_hits += 1
            return True, tenant_id, policy

    def put(self, authorization: str,
            tenant_id: Optional[str],
            policy: Optional[bool],
            ) -> None:
        """Insert or refresh the entry for ``authorization``.

        ``tenant_id=None`` records a negative entry (token didn't
        resolve) and uses the negative TTL; anything else uses the
        positive TTL. An empty ``authorization`` is ignored.
        """
        if not authorization:
            return
        key = self._key_for(authorization)
        ttl = (self._positive_ttl
               if tenant_id is not None
               else self._negative_ttl)
        expires_at = self._clock() + ttl
        with self._lock:
            # One path for both insert and refresh: assigning keeps
            # an existing key's position, so move it to the MRU end
            # explicitly. A refresh never grows the store, so the
            # size check below only fires for genuinely new keys.
            self._store[key] = (tenant_id, policy, expires_at)
            self._store.move_to_end(key)
            if len(self._store) > self._max_size:
                # Evict the least recently used entry.
                self._store.popitem(last=False)
                self.evictions += 1

    def invalidate_tenant(self, tenant_id: str) -> int:
        """Drop every entry pointing to ``tenant_id``.

        Intended for the service layer to call after a tenant's
        policy changes, so the change propagates faster than the
        positive TTL. Negative entries are never matched because an
        empty/None ``tenant_id`` returns immediately.

        Returns the count of entries dropped.
        """
        if not tenant_id:
            return 0
        with self._lock:
            # Collect first: the dict must not change size while it
            # is being iterated.
            stale_keys = [
                key for key, (tid, _, _) in self._store.items()
                if tid == tenant_id
            ]
            for key in stale_keys:
                self._store.pop(key, None)
            return len(stale_keys)

    def invalidate_all(self) -> int:
        """Drop every entry (for broad configuration changes).

        Returns the number of entries that were removed.
        """
        with self._lock:
            dropped = len(self._store)
            self._store.clear()
            return dropped

    def size(self) -> int:
        """Return the current number of stored entries, including
        any expired ones not yet removed by a lookup."""
        with self._lock:
            return len(self._store)
def resolve_tenant_fail_closed_cached(
    svc, authorization: Optional[str],
    cache: Optional["TenantAuthCache"],
) -> Optional[bool]:
    """Stage 134 — resolve a request's tenant fail-closed override,
    consulting ``cache`` before the service layer.

    Args:
        svc: service object exposing ``api_keys.verify(token)`` and
            ``get_tenant_rate_limit_fail_closed(tenant_id)``.
        authorization: raw ``Authorization`` header value, or
            None/empty for an anonymous request.
        cache: cache to consult and seed, keyed by the raw header.
            ``cache=None`` uses the uncached path (useful for
            tests that want to disable caching cleanly).

    Returns:
        * None — no override (anonymous request, unknown token,
          tenant has no override, or lookup raised)
        * True/False — explicit tenant override

    Never raises: a failing service lookup yields None, and a
    failing cache is bypassed so it cannot change the answer.
    """
    if not authorization:
        return None

    if cache is not None:
        try:
            hit, _tenant_id, cached_policy = cache.get(authorization)
        except Exception:  # noqa: BLE001
            # The cache only exists to save lookups; if it cannot
            # handle this header, treat it as a miss and skip it
            # for the rest of this call.
            cache = None
        else:
            if hit:
                # Negative entries carry policy=None, which is also
                # the "no override" answer, so the stored policy is
                # the result for both positive and negative hits.
                return cached_policy

    # Miss / no cache — do the real lookup.
    try:
        token = authorization.strip()
        if token.lower().startswith("bearer "):
            token = token[7:].strip()
        api_key = svc.api_keys.verify(token)
        if api_key is None:
            tenant_id, policy = None, None
        else:
            tenant_id = api_key.tenant_id
            policy = svc.get_tenant_rate_limit_fail_closed(tenant_id)
    except Exception:  # noqa: BLE001
        # Same fail-soft as Stage 128 — never propagate. Nothing is
        # cached, so the next request retries the lookup.
        return None

    if cache is not None:
        try:
            cache.put(authorization, tenant_id, policy)
        except Exception:  # noqa: BLE001
            # Seeding is best-effort: a failed write must not
            # discard the policy that was just resolved.
            pass
    return policy