Spaces:
Running
Running
File size: 10,510 Bytes
"""JWT issuance + verification for the FastAPI / MCP surfaces.
Why
---
Until now the only auth was a base64-encoded JSON ``UserContext`` carried as
a bearer token. That proves nothing — any caller can craft any identity. This
module replaces it with HS256-signed JWTs:
- ``issue_token(user_id, org_id, roles, clearance_level)`` mints a token
signed with ``settings.jwt_secret``. Suitable for the dev ``/token``
endpoint and for tests.
- ``verify_token(token)`` validates signature, expiry, and (when configured)
``iss`` / ``aud`` claims, then returns a ``UserContext`` plus the raw
claims (for audit logging the ``jti``).
RS256 / JWKS mode
-----------------
When ``settings.jwt_algorithm`` is ``"RS256"`` tokens are verified against a
remote JWKS endpoint (Keycloak / Auth0 / etc.). The local ``/token`` endpoint
returns 404 because we do not hold the IdP's private key. See ``utils/jwks_cache``.
Fail-closed default
-------------------
When ``settings.jwt_secret`` is unset and RS256 / JWKS mode is not enabled,
the verifier rejects every token. The legacy unsigned base64 shape is accepted
*only* when ``settings.allow_unsigned_tokens`` is explicitly turned on
(dev/test). An unsigned token proves no identity, so it is never honoured
silently in production. Deployments must set ``SAR_JWT_SECRET`` or
``SAR_JWKS_URL``.
"""
from __future__ import annotations
import base64
import json
import uuid
from datetime import UTC, datetime, timedelta
from typing import Any
from config.settings import settings
from ingestion.metadata import UserContext
from utils.logging import get_logger
# Module-level logger named after this module; the functions in this file
# emit their auth events (token issued, unsigned token accepted, reserved
# claims dropped) through it.
logger = get_logger(__name__)
class AuthError(Exception):
    """Signal that a bearer token could not be verified.

    Every instance exposes a short, machine-readable ``reason`` — one of
    ``missing`` / ``expired`` / ``bad_signature`` / ``bad_claims`` /
    ``malformed`` — so callers can map failures to HTTP status codes
    consistently.
    """

    def __init__(self, reason: str, message: str = "") -> None:
        # An empty message falls back to the reason code so ``str(exc)`` is
        # never blank.
        text = message if message else reason
        super().__init__(text)
        self.reason = reason
def _jose_available() -> bool:
    """Report whether python-jose can be imported in this environment.

    The probe runs on every call rather than being cached at import time,
    which keeps the unit tests simpler.
    """
    try:
        import jose  # noqa: F401
    except ImportError:
        return False
    else:
        return True
def _decode_header(token: str) -> dict[str, Any]:
    """Decode the JWT header segment without verifying the signature.

    The result is untrusted input: it is only suitable for selecting a
    verification key (e.g. reading ``kid``), never for authorization.

    Args:
        token: Compact JWT string (``header.payload.signature``).

    Returns:
        The decoded header as a dict.

    Raises:
        AuthError: With reason ``malformed`` when the token does not have
            exactly three segments, when the header is not base64url-encoded
            UTF-8 JSON, or when the header JSON is not an object.
    """
    segments = token.split(".")
    if len(segments) != 3:
        raise AuthError("malformed", "token must have 3 segments")

    encoded = segments[0]
    # JWTs strip base64 padding; restore it so the decoder accepts the input.
    encoded += "=" * (-len(encoded) % 4)

    try:
        header = json.loads(base64.urlsafe_b64decode(encoded).decode("utf-8"))
    except Exception as exc:
        # Deliberately broad: the header is attacker-controlled, so any
        # decoding failure must surface as a rejected token (fail closed).
        raise AuthError("malformed", f"invalid header: {exc}") from exc

    # Valid JSON is not necessarily an object (``[1]``, ``"x"``, ``3``).
    # Callers use ``.get`` on the result, so reject anything else here with
    # an accurate reason instead of letting it fail later.
    if not isinstance(header, dict):
        raise AuthError("malformed", "invalid header: not a JSON object")
    return header
def issue_token(
    user_id: str,
    org_id: str,
    roles: list[str],
    clearance_level: int = 1,
    ttl_seconds: int | None = None,
    extra_claims: dict[str, Any] | None = None,
) -> str:
    """Mint a signed JWT for the given identity.

    Args:
        user_id: Stable principal identifier. Written to both the ``sub``
            and ``user_id`` claims.
        org_id: Organization the principal belongs to. Drives multi-tenant
            collection routing.
        roles: List of role strings carried into the ``UserContext``.
        clearance_level: Numeric clearance (1=low / 3=high).
        ttl_seconds: Lifetime override; defaults to ``settings.jwt_ttl_seconds``.
        extra_claims: Optional extra claims to merge into the token payload.
            Keys that collide with the security-relevant claims set here are
            dropped (and logged), never applied.

    Returns:
        Compact JWT string.

    Raises:
        AuthError: With reason ``missing`` when RS256 mode is configured
            (tokens must come from the external IdP), when
            ``settings.jwt_secret`` is not configured, or when python-jose
            is not installed.
    """
    # Normalize once so the RS256 guard and the signing call agree. The
    # verifier upper-cases the configured algorithm as well, so signing with
    # the raw setting could disagree with verification for a lower-case value.
    algorithm = settings.jwt_algorithm.upper()
    if algorithm == "RS256":
        raise AuthError("missing", "Cannot issue RS256 tokens locally — use the external IdP")
    if not settings.jwt_secret:
        raise AuthError("missing", "SAR_JWT_SECRET is not configured")
    if not _jose_available():
        raise AuthError("missing", "python-jose is not installed (install the [api] extra)")
    from jose import jwt  # type: ignore[import-not-found]

    now = datetime.now(UTC)
    ttl = ttl_seconds if ttl_seconds is not None else settings.jwt_ttl_seconds
    payload: dict[str, Any] = {
        "sub": user_id,
        "user_id": user_id,
        "org_id": org_id,
        "roles": list(roles),
        "clearance_level": int(clearance_level),
        "iat": int(now.timestamp()),
        "exp": int((now + timedelta(seconds=ttl)).timestamp()),
        "jti": str(uuid.uuid4()),
    }
    if settings.jwt_issuer:
        payload["iss"] = settings.jwt_issuer
    if settings.jwt_audience:
        payload["aud"] = settings.jwt_audience

    if extra_claims:
        # Never let caller-supplied extra claims overwrite the security-relevant
        # registered claims (expiry, subject, token id, issuer, audience, etc.).
        # Otherwise a caller could mint a non-expiring or identity-spoofing token.
        reserved = frozenset(
            {
                "sub",
                "user_id",
                "org_id",
                "roles",
                "clearance_level",
                "iat",
                "exp",
                "jti",
                "iss",
                "aud",
            }
        )
        safe_extra = {k: v for k, v in extra_claims.items() if k not in reserved}
        dropped = set(extra_claims) - set(safe_extra)
        if dropped:
            logger.warning("jwt_extra_claims_dropped_reserved", dropped=sorted(dropped))
        payload.update(safe_extra)

    token = jwt.encode(payload, settings.jwt_secret, algorithm=algorithm)
    logger.info(
        "jwt_issued",
        user_id=user_id,
        org_id=org_id,
        roles=roles,
        jti=payload["jti"],
        ttl_seconds=ttl,
    )
    return token
def _verify_legacy_base64(token: str) -> tuple[UserContext, dict[str, Any]]:
    """Decode the legacy unsigned ``base64(json(UserContext))`` token shape.

    Nothing is verified here — the payload is taken at face value — so a
    warning is logged on every call.

    Args:
        token: Base64-encoded JSON object whose keys are passed as keyword
            arguments to ``UserContext``.

    Returns:
        ``(user_context, claims)`` where ``claims`` is a synthetic dict with
        ``sub`` set to the context's user id and ``jti`` fixed to
        ``"unsigned"``.

    Raises:
        AuthError: ``malformed`` when base64 / UTF-8 / JSON decoding fails,
            ``bad_claims`` when the payload cannot build a ``UserContext``.
    """
    warning_text = (
        "SAR_JWT_SECRET unset — accepting unsigned base64 tokens. Anyone "
        "with network access can impersonate any user. Configure "
        "SAR_JWT_SECRET in production."
    )
    logger.warning("auth_unsigned_token", message=warning_text)

    try:
        raw_json = base64.b64decode(token).decode("utf-8")
        fields = json.loads(raw_json)
    except Exception as exc:
        raise AuthError("malformed", f"base64/json decode failed: {exc}") from exc

    try:
        user_ctx = UserContext(**fields)
    except Exception as exc:
        raise AuthError("bad_claims", f"UserContext build failed: {exc}") from exc

    synthetic_claims = {"sub": user_ctx.user_id, "jti": "unsigned"}
    return user_ctx, synthetic_claims
def _verify_jwt(token: str) -> tuple[UserContext, dict[str, Any]]:
    """Verify a signed JWT (HS256 or RS256) and build a ``UserContext``.

    In RS256 mode the verification key is looked up by the token's ``kid``
    header via ``utils.jwks_cache.get_signing_key``; otherwise
    ``settings.jwt_secret`` is used as the key.

    Args:
        token: Compact JWT string.

    Returns:
        ``(user_context, claims)`` where ``claims`` is the decoded payload.

    Raises:
        AuthError: ``missing`` when python-jose or the signing secret is
            unavailable; ``bad_claims`` when an RS256 token has no ``kid``,
            a required/registered claim fails validation, or the claims
            cannot build a ``UserContext``; ``bad_signature`` when the JWKS
            lookup or signature check fails; ``expired`` when the token has
            expired; ``malformed`` for any other decode failure.
    """
    if not _jose_available():
        raise AuthError("missing", "python-jose is not installed")
    from jose import JWTError, jwt  # type: ignore[import-not-found]

    # python-jose expresses required claims as individual ``require_<claim>``
    # flags. A PyJWT-style ``{"require": [...]}`` entry is not a recognised
    # option and is ignored, which would let a validly signed token that has
    # no ``exp`` claim be accepted indefinitely.
    options: dict[str, Any] = {
        "require_exp": True,
        "require_iat": True,
        "require_sub": True,
    }
    audience = settings.jwt_audience or None
    if audience:
        # python-jose skips audience validation when the token carries no
        # ``aud`` claim at all, so demand its presence whenever an audience
        # is configured; otherwise omitting ``aud`` bypasses the check.
        options["require_aud"] = True
    else:
        options["verify_aud"] = False

    algorithm = settings.jwt_algorithm.upper()

    # Resolve the verification key.
    if algorithm == "RS256":
        try:
            header = _decode_header(token)
            kid = header.get("kid")
            if not kid:
                raise AuthError("bad_claims", "RS256 token missing 'kid' header")
            from utils.jwks_cache import get_signing_key

            key = get_signing_key(kid)
        except AuthError:
            # Already carries the right reason; do not re-wrap.
            raise
        except Exception as exc:
            raise AuthError("bad_signature", f"JWKS lookup failed: {exc}") from exc
    else:
        key = settings.jwt_secret
        if not key:
            raise AuthError("missing", "SAR_JWT_SECRET is not configured")

    try:
        claims = jwt.decode(
            token,
            key,
            algorithms=[algorithm],
            audience=audience,
            issuer=settings.jwt_issuer or None,
            options=options,
        )
    except JWTError as exc:
        # Classify by message text. "expired" is tested first because the
        # expiry message also mentions the signature.
        msg = str(exc).lower()
        if "expired" in msg:
            reason = "expired"
        elif "signature" in msg:
            reason = "bad_signature"
        elif "claim" in msg or "audience" in msg or "issuer" in msg:
            reason = "bad_claims"
        else:
            reason = "malformed"
        raise AuthError(reason, f"jwt decode failed: {exc}") from exc

    try:
        ctx = UserContext(
            user_id=claims.get("user_id") or claims["sub"],
            org_id=claims.get("org_id", ""),
            roles=list(claims.get("roles", [])),
            clearance_level=int(claims.get("clearance_level", 1)),
        )
    except Exception as exc:
        raise AuthError("bad_claims", f"UserContext build failed: {exc}") from exc
    return ctx, claims
def verify_token(token: str) -> tuple[UserContext, dict[str, Any]]:
    """Turn a bearer token into a ``UserContext`` and its raw claims.

    Args:
        token: Raw bearer token (no ``Bearer `` prefix).

    Returns:
        ``(user_context, claims)``. ``claims`` includes at minimum ``sub`` and
        ``jti``; the latter is used in audit-trail entries so a tampered or
        replayed token is traceable.

    Raises:
        AuthError: With ``.reason`` set to one of
            ``missing`` / ``malformed`` / ``expired`` / ``bad_signature`` /
            ``bad_claims``.
    """
    if not (token and isinstance(token, str)):
        raise AuthError("missing", "empty token")

    # Signed verification applies when an HMAC secret is configured or the
    # deployment is in RS256 mode.
    signed_mode = bool(settings.jwt_secret) or settings.jwt_algorithm.upper() == "RS256"
    if signed_mode:
        return _verify_jwt(token)

    # No signing key configured: fail closed unless the legacy unsigned shape
    # has been explicitly opted into (dev/test only).
    if not settings.allow_unsigned_tokens:
        raise AuthError(
            "missing",
            "no JWT secret configured and unsigned tokens are disabled "
            "(set SAR_JWT_SECRET, or SAR_ALLOW_UNSIGNED_TOKENS=true for dev only)",
        )
    return _verify_legacy_base64(token)
|