"""Authentication utilities for Supabase-backed JWT verification.""" from __future__ import annotations from functools import lru_cache from threading import Lock from typing import Any, Dict, Optional import jwt from jwt import PyJWKClient from ml_module.core.config import settings class SupabaseAuthError(Exception): """Raised when Supabase JWT verification fails.""" class SupabaseJWTVerifier: """Verifies Supabase-issued JWTs using JWKS with basic caching.""" def __init__(self) -> None: self._jwks_client: Optional[PyJWKClient] = None self._jwks_url: Optional[str] = None self._lock = Lock() def _resolve_jwks_url(self) -> str: if settings.SUPABASE_JWKS_URL: return settings.SUPABASE_JWKS_URL if settings.SUPABASE_PROJECT_ID: return f"https://{settings.SUPABASE_PROJECT_ID}.supabase.co/auth/v1/certs" raise SupabaseAuthError("Supabase JWKS URL is not configured") def _get_client(self) -> PyJWKClient: jwks_url = self._resolve_jwks_url() with self._lock: if not self._jwks_client or jwks_url != self._jwks_url: self._jwks_client = PyJWKClient( # type: ignore[call-arg] jwks_url, cache_keys=True, ) self._jwks_url = jwks_url return self._jwks_client # type: ignore[return-value] def decode(self, token: str) -> Dict[str, Any]: client = self._get_client() try: signing_key = client.get_signing_key_from_jwt(token) except Exception as exc: # pragma: no cover - library-specific errors raise SupabaseAuthError("Unable to resolve signing key for Supabase token") from exc options = {"verify_aud": bool(settings.SUPABASE_JWT_AUDIENCE)} audience = settings.SUPABASE_JWT_AUDIENCE or None issuer = settings.SUPABASE_JWT_ISSUER or None try: claims = jwt.decode( token, signing_key.key, algorithms=["RS256"], audience=audience, issuer=issuer, options=options, ) except jwt.ExpiredSignatureError as exc: # pragma: no cover - time dependent raise SupabaseAuthError("Supabase token has expired") from exc except jwt.InvalidTokenError as exc: # pragma: no cover - upstream detail raise SupabaseAuthError("Invalid Supabase token") from exc return claims @lru_cache(maxsize=1) def get_supabase_verifier() -> SupabaseJWTVerifier: return SupabaseJWTVerifier()