"""In-memory JWKS cache with TTL for RS256 token verification. Fetches the JSON Web Key Set from a remote endpoint and caches it for ``settings.jwks_cache_ttl_seconds`` (default 5 min). When a token arrives with a ``kid`` not present in the cache, the cache is refreshed once before giving up. Usage:: from utils.jwks_cache import get_signing_key key = get_signing_key(token_kid, jwks_url="https://idp/realm/protocol/openid-connect/certs") """ from __future__ import annotations import json import time from typing import Any from config.settings import settings from utils.logging import get_logger logger = get_logger(__name__) # Module-level cache: {jwks_url: {"fetched_at": float, "keys": dict[str, str]}} _cache: dict[str, dict[str, Any]] = {} def _fetch_jwks(url: str) -> dict[str, str]: """Download JWKS and return a mapping kid -> PEM-encoded public key.""" import urllib.request req = urllib.request.Request(url, headers={"Accept": "application/json"}) with urllib.request.urlopen(req, timeout=10) as resp: data = json.loads(resp.read().decode("utf-8")) keys: dict[str, str] = {} for jwk in data.get("keys", []): kid = jwk.get("kid") if not kid: continue # Only RSA keys are supported for RS256. if jwk.get("kty") != "RSA": continue pem = _jwk_to_pem(jwk) if pem: keys[kid] = pem logger.info("jwks_fetched", url=url, key_count=len(keys)) return keys def _jwk_to_pem(jwk: dict[str, Any]) -> str | None: """Convert a single RSA JWK to PEM using python-jose if available.""" try: from jose.backends import RSAKey from jose.utils import base64url_decode except ImportError: logger.warning("python_jose_missing_for_jwk_conversion") return None try: n = int.from_bytes(base64url_decode(jwk["n"]), "big") e = int.from_bytes(base64url_decode(jwk["e"]), "big") rsa_key = RSAKey({"n": n, "e": e}) return rsa_key.to_pem().decode("utf-8") except Exception as exc: logger.warning("jwk_to_pem_failed", error=str(exc), kid=jwk.get("kid")) return None def get_signing_key( kid: str, jwks_url: str | None = None, force_refresh: bool = False, ) -> str: """Return the PEM-encoded RSA public key for the given ``kid``. Args: kid: Key ID from the JWT header. jwks_url: JWKS endpoint. Defaults to ``settings.jwks_url``. force_refresh: If True, bypass the cache and re-fetch. Returns: PEM-encoded public key string. Raises: RuntimeError: If the key cannot be found after refresh. """ url = jwks_url or settings.jwks_url if not url: raise RuntimeError("jwks_url is not configured") ttl = getattr(settings, "jwks_cache_ttl_seconds", 300) now = time.monotonic() entry = _cache.get(url) if not force_refresh and entry and (now - entry["fetched_at"]) < ttl and kid in entry["keys"]: return entry["keys"][kid] # Cache miss or stale — fetch. try: keys = _fetch_jwks(url) except Exception as exc: # On fetch failure, try to serve stale if we have it. if entry and kid in entry.get("keys", {}): logger.warning("jwks_fetch_failed_serving_stale", error=str(exc), url=url) return entry["keys"][kid] raise RuntimeError(f"JWKS fetch failed: {exc}") from exc _cache[url] = {"fetched_at": now, "keys": keys} if kid not in keys: raise RuntimeError(f"kid '{kid}' not found in JWKS") return keys[kid] def clear_cache(url: str | None = None) -> None: """Clear the JWKS cache. Used in tests.""" global _cache if url is None: _cache.clear() else: _cache.pop(url, None)