File size: 3,827 Bytes
09ed8ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""In-memory JWKS cache with TTL for RS256 token verification.

Fetches the JSON Web Key Set from a remote endpoint and caches it for
``settings.jwks_cache_ttl_seconds`` (default 5 min). When a token arrives
with a ``kid`` not present in the cache, the cache is refreshed once
before giving up.

Usage::

    from utils.jwks_cache import get_signing_key
    key = get_signing_key(token_kid, jwks_url="https://idp/realm/protocol/openid-connect/certs")
"""

from __future__ import annotations

import json
import time
from typing import Any

from config.settings import settings
from utils.logging import get_logger

# Structured logger; call sites in this module pass key=value fields
# (e.g. url=..., key_count=...).
logger = get_logger(__name__)

# Module-level cache, keyed by JWKS URL:
#   {jwks_url: {"fetched_at": float, "keys": dict[str, str]}}
# "fetched_at" is a time.monotonic() timestamp (not wall-clock time) that
# get_signing_key compares against the TTL; "keys" maps kid -> PEM public key.
# NOTE(review): no lock guards this dict, so concurrent callers may fetch the
# same URL in parallel — confirm that is acceptable for the deployment.
_cache: dict[str, dict[str, Any]] = {}


def _fetch_jwks(url: str) -> dict[str, str]:
    """Download the JWKS at *url* and return a mapping ``kid -> PEM public key``.

    Only RSA keys usable for signature verification are kept. An entry is
    skipped when:

    * it is not a JSON object or has no ``kid``;
    * its ``kty`` is not ``"RSA"``;
    * its ``use`` member is present and not ``"sig"``. Some identity
      providers publish ``"enc"`` encryption keys next to their signing
      keys, and those must never verify a token signature;
    * it cannot be converted to PEM.

    Args:
        url: JWKS endpoint to download.

    Returns:
        Mapping of key ID to PEM-encoded public key. May be empty.

    Raises:
        OSError: On network failures, HTTP error statuses and timeouts
            (``urllib.error.URLError``/``HTTPError`` are subclasses).
        ValueError: If the body is not valid JSON or is not a JSON object.
    """
    import urllib.request

    req = urllib.request.Request(url, headers={"Accept": "application/json"})
    with urllib.request.urlopen(req, timeout=10) as resp:
        # json.loads accepts bytes and detects UTF-8/16/32 (and a UTF-8 BOM)
        # itself, which a hard-coded .decode("utf-8") would not.
        data = json.loads(resp.read())

    if not isinstance(data, dict):
        raise ValueError("JWKS response is not a JSON object")

    keys: dict[str, str] = {}
    # `or []` also covers an explicit `"keys": null`.
    for jwk in data.get("keys") or []:
        if not isinstance(jwk, dict):
            continue
        kid = jwk.get("kid")
        if not kid:
            continue
        # Only RSA keys are supported for RS256.
        if jwk.get("kty") != "RSA":
            continue
        # RFC 7517 §4.2: a key whose "use" is not "sig" (e.g. "enc") is not a
        # signature-verification key. An absent "use" is accepted.
        use = jwk.get("use")
        if use is not None and use != "sig":
            continue
        pem = _jwk_to_pem(jwk)
        if pem:
            keys[kid] = pem

    logger.info("jwks_fetched", url=url, key_count=len(keys))
    return keys


def _jwk_to_pem(jwk: dict[str, Any]) -> str | None:
    """Convert a single RSA JWK to a PEM-encoded public key using python-jose.

    The JWK's base64url ``n``/``e`` members are handed to python-jose in JWK
    form, and python-jose does the base64url decoding itself. Decoding them
    here with ``jose.utils.base64url_decode`` breaks on the unpadded ``str``
    values JWKs contain, because that helper pads with ``bytes``. python-jose
    also expects those members as strings, not ints.

    Args:
        jwk: One entry from a JWKS ``keys`` array (expected ``kty == "RSA"``).

    Returns:
        The PEM text produced by python-jose's ``to_pem()``. Returns ``None``
        if python-jose or an RSA backend for it is unavailable, or if the key
        is malformed.
    """
    try:
        from jose.backends import RSAKey
    except ImportError:
        logger.warning("python_jose_missing_for_jwk_conversion")
        return None
    if RSAKey is None:
        # python-jose is installed, but neither of its RSA backends
        # (cryptography / rsa) could be imported.
        logger.warning("python_jose_missing_for_jwk_conversion")
        return None

    try:
        # Forward only the public members, so private material ("d", ...)
        # is never loaded even if a provider publishes it. python-jose's RSA
        # key classes require an algorithm argument. It affects only
        # sign/verify, not PEM export, so pass the algorithm this module
        # serves.
        public_jwk = {"kty": "RSA", "n": jwk["n"], "e": jwk["e"]}
        rsa_key = RSAKey(public_jwk, "RS256")
        return rsa_key.to_pem().decode("utf-8")
    except Exception as exc:
        logger.warning("jwk_to_pem_failed", error=str(exc), kid=jwk.get("kid"))
        return None


def get_signing_key(
    kid: str,
    jwks_url: str | None = None,
    force_refresh: bool = False,
) -> str:
    """Return the PEM-encoded RSA public key for the given ``kid``.

    Lookup order:

    1. If the cached key set for the URL is younger than
       ``settings.jwks_cache_ttl_seconds`` (default 300) and contains
       ``kid``, return that key without any network I/O.
    2. If the cache is fresh but does not contain ``kid``, re-fetch only when
       the last successful fetch is at least
       ``settings.jwks_min_refresh_interval_seconds`` old (default 10;
       0 disables the limit). Otherwise fail immediately. This stops tokens
       carrying random ``kid`` values from triggering one outbound HTTP
       request per call.
    3. In every other case (no entry, stale entry, or ``force_refresh``),
       fetch the JWKS. If the fetch fails and the previous entry holds
       ``kid``, serve that stale key instead.

    Args:
        kid: Key ID from the JWT header.
        jwks_url: JWKS endpoint. Defaults to ``settings.jwks_url``.
        force_refresh: If True, bypass the cache and the refresh interval,
            and re-fetch.

    Returns:
        PEM-encoded public key string.

    Raises:
        RuntimeError: If no JWKS URL is configured, the fetch fails with no
            stale key to fall back on, or ``kid`` is not in the key set.
    """
    url = jwks_url or getattr(settings, "jwks_url", None)
    if not url:
        raise RuntimeError("jwks_url is not configured")

    ttl = getattr(settings, "jwks_cache_ttl_seconds", 300)
    min_refresh = getattr(settings, "jwks_min_refresh_interval_seconds", 10)

    entry = _cache.get(url)
    if entry is not None and not force_refresh:
        cached_keys = entry["keys"]
        age = time.monotonic() - entry["fetched_at"]
        if age < ttl:
            if kid in cached_keys:
                return cached_keys[kid]
            if age < min_refresh:
                # The key set was fetched moments ago and still lacks this kid;
                # re-fetching now would almost certainly return the same set.
                raise RuntimeError(f"kid '{kid}' not found in JWKS")

    # Cache miss, stale entry, unknown kid past the refresh interval, or forced.
    try:
        keys = _fetch_jwks(url)
    except Exception as exc:
        # On fetch failure, try to serve stale if we have it.
        if entry and kid in entry.get("keys", {}):
            logger.warning("jwks_fetch_failed_serving_stale", error=str(exc), url=url)
            return entry["keys"][kid]
        raise RuntimeError(f"JWKS fetch failed: {exc}") from exc

    # Timestamp after the download so a slow fetch does not eat into the TTL.
    _cache[url] = {"fetched_at": time.monotonic(), "keys": keys}

    if kid not in keys:
        raise RuntimeError(f"kid '{kid}' not found in JWKS")
    return keys[kid]


def clear_cache(url: str | None = None) -> None:
    """Drop cached JWKS entries so the next lookup re-fetches. Used in tests.

    Args:
        url: Remove only this endpoint's entry. ``None`` clears every entry.
            Unknown URLs are ignored.
    """
    # The module-level dict is mutated in place and never rebound, so no
    # ``global`` declaration is needed. Any code holding a reference to it
    # sees the change.
    if url is None:
        _cache.clear()
    else:
        _cache.pop(url, None)