secureagentrag-api / utils /jwks_cache.py
LeomordKaly's picture
deploy: phase 3 BYOK backend (Dockerfile.hf, FastAPI on 7860)
09ed8ca verified
"""In-memory JWKS cache with TTL for RS256 token verification.
Fetches the JSON Web Key Set from a remote endpoint and caches it for
``settings.jwks_cache_ttl_seconds`` (default 5 min). When a token arrives
with a ``kid`` not present in the cache, the cache is refreshed once
before giving up.
Usage::
from utils.jwks_cache import get_signing_key
key = get_signing_key(token_kid, jwks_url="https://idp/realm/protocol/openid-connect/certs")
"""
from __future__ import annotations
import json
import time
from typing import Any
from config.settings import settings
from utils.logging import get_logger
logger = get_logger(__name__)
# Module-level cache, keyed by JWKS endpoint URL:
#   {jwks_url: {"fetched_at": float, "keys": dict[str, str]}}
# "fetched_at" is the ``time.monotonic()`` reading taken by the call that
# populated the entry; "keys" maps each ``kid`` to its PEM-encoded public key.
_cache: dict[str, dict[str, Any]] = {}
def _fetch_jwks(url: str) -> dict[str, str]:
    """Fetch the JWKS document at ``url`` and index its RSA keys by ``kid``.

    Entries without a ``kid``, entries whose ``kty`` is not ``"RSA"``, and
    entries that ``_jwk_to_pem`` cannot convert are left out of the result.

    Args:
        url: JWKS endpoint to download (10 second timeout).

    Returns:
        Mapping of ``kid`` to PEM-encoded public key.
    """
    import urllib.request

    request = urllib.request.Request(url, headers={"Accept": "application/json"})
    with urllib.request.urlopen(request, timeout=10) as response:
        payload = json.loads(response.read().decode("utf-8"))

    pem_by_kid: dict[str, str] = {}
    for entry in payload.get("keys", []):
        key_id = entry.get("kid")
        # A key is usable only if it can be looked up by kid and is RSA
        # (the only key type that applies to RS256).
        if not key_id or entry.get("kty") != "RSA":
            continue
        converted = _jwk_to_pem(entry)
        if converted:
            pem_by_kid[key_id] = converted

    logger.info("jwks_fetched", url=url, key_count=len(pem_by_kid))
    return pem_by_kid
def _jwk_to_pem(jwk: dict[str, Any]) -> str | None:
    """Convert a single RSA JWK to a PEM-encoded public key via python-jose.

    Args:
        jwk: One entry of a JWKS ``keys`` array. Only its ``n`` (modulus) and
            ``e`` (public exponent) members are used for the conversion;
            ``kid`` is read for logging only.

    Returns:
        The PEM text of the public key, or ``None`` when python-jose is not
        installed or the JWK cannot be converted. Both cases are logged as
        warnings rather than raised, so one bad key does not abort a fetch.
    """
    try:
        from jose.backends import RSAKey
    except ImportError:
        logger.warning("python_jose_missing_for_jwk_conversion")
        return None
    try:
        # python-jose's RSAKey is constructed as RSAKey(key, algorithm) and,
        # for dict input, expects a JWK: "kty" == "RSA" with base64url-encoded
        # "n" and "e". Handing it pre-decoded integers without an algorithm
        # raises, so the JWK members are passed through unchanged.
        # Only the public members are forwarded so that a stray private
        # component in the input can never end up in the returned PEM.
        public_jwk = {"kty": "RSA", "n": jwk["n"], "e": jwk["e"]}
        rsa_key = RSAKey(public_jwk, "RS256")
        return rsa_key.to_pem().decode("utf-8")
    except Exception as exc:
        # Best-effort by design: report the failure and let the caller skip
        # this key.
        logger.warning("jwk_to_pem_failed", error=str(exc), kid=jwk.get("kid"))
        return None
def get_signing_key(
    kid: str,
    jwks_url: str | None = None,
    force_refresh: bool = False,
) -> str:
    """Look up the PEM-encoded RSA public key that matches ``kid``.

    A cached key is returned only when the cache entry for the endpoint is
    younger than the TTL and already contains ``kid``. In every other case
    the JWKS is downloaded again and the cache entry replaced. If that
    download fails, a previously cached key for ``kid`` is served regardless
    of its age.

    Args:
        kid: Key ID taken from the JWT header.
        jwks_url: JWKS endpoint; falls back to ``settings.jwks_url``.
        force_refresh: Skip the cache lookup and always re-fetch.

    Returns:
        PEM-encoded public key string.

    Raises:
        RuntimeError: No JWKS URL is configured, the fetch failed with no
            cached key to fall back on, or ``kid`` is absent after a refresh.
    """
    endpoint = jwks_url or settings.jwks_url
    if not endpoint:
        raise RuntimeError("jwks_url is not configured")

    max_age = getattr(settings, "jwks_cache_ttl_seconds", 300)
    started = time.monotonic()
    cached = _cache.get(endpoint)

    # Fast path: entry is fresh and already knows this kid.
    if cached and not force_refresh:
        is_fresh = (started - cached["fetched_at"]) < max_age
        if is_fresh and kid in cached["keys"]:
            return cached["keys"][kid]

    # Forced refresh, missing/expired entry, or unknown kid: re-download.
    try:
        fresh_keys = _fetch_jwks(endpoint)
    except Exception as exc:
        # Prefer a stale key over failing when the endpoint is unreachable.
        stale_keys = cached.get("keys", {}) if cached else {}
        if kid in stale_keys:
            logger.warning("jwks_fetch_failed_serving_stale", error=str(exc), url=endpoint)
            return cached["keys"][kid]
        raise RuntimeError(f"JWKS fetch failed: {exc}") from exc

    _cache[endpoint] = {"fetched_at": started, "keys": fresh_keys}

    if kid in fresh_keys:
        return fresh_keys[kid]
    raise RuntimeError(f"kid '{kid}' not found in JWKS")
def clear_cache(url: str | None = None) -> None:
    """Forget cached JWKS data; intended for use in tests.

    Args:
        url: When given, only the entry for this endpoint is dropped (a no-op
            if it was never cached). When omitted, every entry is dropped.
    """
    if url is not None:
        _cache.pop(url, None)
        return
    _cache.clear()