"""JWT issuance + verification for the FastAPI / MCP surfaces. Why --- Until now the only auth was a base64-encoded JSON ``UserContext`` carried as a bearer token. That proves nothing — any caller can craft any identity. This module replaces it with HS256-signed JWTs: - ``issue_token(user_id, org_id, roles, clearance_level)`` mints a token signed with ``settings.jwt_secret``. Suitable for the dev ``/token`` endpoint and for tests. - ``verify_token(token)`` validates signature, expiry, and (when configured) ``iss`` / ``aud`` claims, then returns a ``UserContext`` plus the raw claims (for audit logging the ``jti``). RS256 / JWKS mode ----------------- When ``settings.jwt_algorithm`` is ``"RS256"`` tokens are verified against a remote JWKS endpoint (Keycloak / Auth0 / etc.). The local ``/token`` endpoint returns 404 because we do not hold the IdP's private key. See ``utils/jwks_cache``. Fail-closed default ------------------- When ``settings.jwt_secret`` is unset the verifier rejects every token. The legacy unsigned base64 shape is accepted *only* when ``settings.allow_unsigned_tokens`` is explicitly turned on (dev/test). An unsigned token proves no identity, so it is never honoured silently in production. Deployments must set ``SAR_JWT_SECRET`` or ``SAR_JWKS_URL``. """ from __future__ import annotations import base64 import json import uuid from datetime import UTC, datetime, timedelta from typing import Any from config.settings import settings from ingestion.metadata import UserContext from utils.logging import get_logger logger = get_logger(__name__) class AuthError(Exception): """Raised when token verification fails for any reason. Carries a short machine-readable ``reason`` (``missing`` / ``expired`` / ``bad_signature`` / ``bad_claims`` / ``malformed``) so callers can map consistently to HTTP status codes. """ def __init__(self, reason: str, message: str = "") -> None: super().__init__(message or reason) self.reason = reason def _jose_available() -> bool: """Best-effort detection of python-jose. Cached on import would be nicer but a function makes the unit tests cleaner. """ try: import jose # noqa: F401 return True except ImportError: return False def _decode_header(token: str) -> dict[str, Any]: """Decode the JWT header without verification.""" parts = token.split(".") if len(parts) != 3: raise AuthError("malformed", "token must have 3 segments") padding = 4 - len(parts[0]) % 4 if padding != 4: parts[0] += "=" * padding try: return json.loads(base64.urlsafe_b64decode(parts[0]).decode("utf-8")) except Exception as exc: raise AuthError("malformed", f"invalid header: {exc}") from exc def issue_token( user_id: str, org_id: str, roles: list[str], clearance_level: int = 1, ttl_seconds: int | None = None, extra_claims: dict[str, Any] | None = None, ) -> str: """Mint a signed JWT for the given identity. Args: user_id: Stable principal identifier. org_id: Organization the principal belongs to. Drives multi-tenant collection routing. roles: List of role strings carried into the ``UserContext``. clearance_level: Numeric clearance (1=low / 3=high). ttl_seconds: Lifetime override; defaults to ``settings.jwt_ttl_seconds``. extra_claims: Optional extra claims to merge into the token payload. Returns: Compact JWT string. Raises: AuthError: If ``settings.jwt_secret`` is not configured or python-jose is missing. """ if settings.jwt_algorithm.upper() == "RS256": raise AuthError("missing", "Cannot issue RS256 tokens locally — use the external IdP") if not settings.jwt_secret: raise AuthError("missing", "SAR_JWT_SECRET is not configured") if not _jose_available(): raise AuthError("missing", "python-jose is not installed (install the [api] extra)") from jose import jwt # type: ignore[import-not-found] now = datetime.now(UTC) ttl = ttl_seconds if ttl_seconds is not None else settings.jwt_ttl_seconds payload: dict[str, Any] = { "sub": user_id, "user_id": user_id, "org_id": org_id, "roles": list(roles), "clearance_level": int(clearance_level), "iat": int(now.timestamp()), "exp": int((now + timedelta(seconds=ttl)).timestamp()), "jti": str(uuid.uuid4()), } if settings.jwt_issuer: payload["iss"] = settings.jwt_issuer if settings.jwt_audience: payload["aud"] = settings.jwt_audience if extra_claims: # Never let caller-supplied extra claims overwrite the security-relevant # registered claims (expiry, subject, token id, issuer, audience, etc.). # Otherwise a caller could mint a non-expiring or identity-spoofing token. _reserved = { "sub", "user_id", "org_id", "roles", "clearance_level", "iat", "exp", "jti", "iss", "aud", } safe_extra = {k: v for k, v in extra_claims.items() if k not in _reserved} dropped = set(extra_claims) - set(safe_extra) if dropped: logger.warning("jwt_extra_claims_dropped_reserved", dropped=sorted(dropped)) payload.update(safe_extra) token = jwt.encode(payload, settings.jwt_secret, algorithm=settings.jwt_algorithm) logger.info( "jwt_issued", user_id=user_id, org_id=org_id, roles=roles, jti=payload["jti"], ttl_seconds=ttl, ) return token def _verify_legacy_base64(token: str) -> tuple[UserContext, dict[str, Any]]: """Decode the legacy unsigned base64(json(UserContext)) shape. Used only when ``settings.jwt_secret`` is unset. Always raises a runtime warning so deployments don't forget to flip on signed JWTs. """ logger.warning( "auth_unsigned_token", message=( "SAR_JWT_SECRET unset — accepting unsigned base64 tokens. Anyone " "with network access can impersonate any user. Configure " "SAR_JWT_SECRET in production." ), ) try: payload = json.loads(base64.b64decode(token).decode("utf-8")) except Exception as exc: raise AuthError("malformed", f"base64/json decode failed: {exc}") from exc try: ctx = UserContext(**payload) except Exception as exc: raise AuthError("bad_claims", f"UserContext build failed: {exc}") from exc return ctx, {"sub": ctx.user_id, "jti": "unsigned"} def _verify_jwt(token: str) -> tuple[UserContext, dict[str, Any]]: """Verify a signed JWT (HS256 or RS256).""" if not _jose_available(): raise AuthError("missing", "python-jose is not installed") from jose import JWTError, jwt # type: ignore[import-not-found] options: dict[str, Any] = {"require": ["exp", "iat", "sub"]} audience = settings.jwt_audience or None if not audience: options["verify_aud"] = False algorithm = settings.jwt_algorithm.upper() # Resolve the verification key. if algorithm == "RS256": try: header = _decode_header(token) kid = header.get("kid") if not kid: raise AuthError("bad_claims", "RS256 token missing 'kid' header") from utils.jwks_cache import get_signing_key key = get_signing_key(kid) except AuthError: raise except Exception as exc: raise AuthError("bad_signature", f"JWKS lookup failed: {exc}") from exc else: key = settings.jwt_secret if not key: raise AuthError("missing", "SAR_JWT_SECRET is not configured") try: claims = jwt.decode( token, key, algorithms=[algorithm], audience=audience, issuer=settings.jwt_issuer or None, options=options, ) except JWTError as exc: msg = str(exc).lower() if "expired" in msg: reason = "expired" elif "signature" in msg: reason = "bad_signature" elif "claim" in msg or "audience" in msg or "issuer" in msg: reason = "bad_claims" else: reason = "malformed" raise AuthError(reason, f"jwt decode failed: {exc}") from exc try: ctx = UserContext( user_id=claims.get("user_id") or claims["sub"], org_id=claims.get("org_id", ""), roles=list(claims.get("roles", [])), clearance_level=int(claims.get("clearance_level", 1)), ) except Exception as exc: raise AuthError("bad_claims", f"UserContext build failed: {exc}") from exc return ctx, claims def verify_token(token: str) -> tuple[UserContext, dict[str, Any]]: """Resolve a bearer token to a ``UserContext`` plus the raw claims. Args: token: Raw bearer token (no ``Bearer `` prefix). Returns: ``(user_context, claims)``. ``claims`` includes at minimum ``sub`` and ``jti``; the latter is used in audit-trail entries so a tampered or replayed token is traceable. Raises: AuthError: With ``.reason`` set to one of ``missing`` / ``malformed`` / ``expired`` / ``bad_signature`` / ``bad_claims``. """ if not token or not isinstance(token, str): raise AuthError("missing", "empty token") # When jwt_secret is set OR we're in RS256 mode, use JWT verification. if settings.jwt_secret or settings.jwt_algorithm.upper() == "RS256": return _verify_jwt(token) # No signing key configured. Fail closed unless the legacy unsigned shape # is explicitly opted into (dev/test only). if settings.allow_unsigned_tokens: return _verify_legacy_base64(token) raise AuthError( "missing", "no JWT secret configured and unsigned tokens are disabled " "(set SAR_JWT_SECRET, or SAR_ALLOW_UNSIGNED_TOKENS=true for dev only)", )