Spaces:
Running
Running
| """Authentication utilities for Supabase-backed JWT verification.""" | |
| from __future__ import annotations | |
| from functools import lru_cache | |
| from threading import Lock | |
| from typing import Any, Dict, Optional | |
| import jwt | |
| from jwt import PyJWKClient | |
| from ml_module.core.config import settings | |
class SupabaseAuthError(Exception):
    """Signals that a Supabase-issued JWT could not be verified.

    Covers both configuration problems (for example, no JWKS URL configured)
    and token problems (signing key not resolvable, token expired, or token
    otherwise rejected during verification).
    """
class SupabaseJWTVerifier:
    """Verify Supabase-issued JWTs against the project's JWKS endpoint.

    A ``PyJWKClient`` (with ``cache_keys=True``) is created lazily on first use
    and reused by later calls. It is rebuilt if the resolved JWKS URL changes.

    Args:
        algorithms: Signature algorithms accepted by :meth:`decode`. Defaults
            to ``("RS256",)``. Verification keys always come from the JWKS
            endpoint, so only asymmetric algorithms (e.g. ``"RS256"``,
            ``"ES256"``) are meaningful here.

    Raises:
        ValueError: If ``algorithms`` is empty.
    """

    def __init__(self, *, algorithms: tuple[str, ...] = ("RS256",)) -> None:
        if not algorithms:
            raise ValueError("At least one JWT signing algorithm must be allowed")
        self._algorithms = list(algorithms)
        self._jwks_client: Optional[PyJWKClient] = None
        # URL the current client was built for; compared on every call so a
        # changed setting results in a fresh client.
        self._jwks_url: Optional[str] = None
        # Guards creation/replacement of the shared client.
        self._lock = Lock()

    def _resolve_jwks_url(self) -> str:
        """Return the JWKS URL from settings.

        An explicit ``SUPABASE_JWKS_URL`` wins. Otherwise the URL is derived
        from ``SUPABASE_PROJECT_ID``.

        Raises:
            SupabaseAuthError: If neither setting is configured.
        """
        if settings.SUPABASE_JWKS_URL:
            return settings.SUPABASE_JWKS_URL
        if settings.SUPABASE_PROJECT_ID:
            # NOTE(review): Supabase's docs describe the JWKS endpoint as
            # /auth/v1/.well-known/jwks.json; confirm /auth/v1/certs resolves
            # for this project before relying on this fallback.
            return f"https://{settings.SUPABASE_PROJECT_ID}.supabase.co/auth/v1/certs"
        raise SupabaseAuthError("Supabase JWKS URL is not configured")

    def _get_client(self) -> PyJWKClient:
        """Return the shared JWKS client, creating or replacing it as needed."""
        jwks_url = self._resolve_jwks_url()
        with self._lock:
            client = self._jwks_client
            # Build on first use, or rebuild when the configured URL changed.
            if client is None or jwks_url != self._jwks_url:
                client = PyJWKClient(  # type: ignore[call-arg]
                    jwks_url,
                    cache_keys=True,
                )
                self._jwks_client = client
                self._jwks_url = jwks_url
            return client

    def decode(self, token: str) -> Dict[str, Any]:
        """Verify ``token`` and return its decoded claims.

        Args:
            token: The raw JWT string. Presumably any ``Bearer`` prefix is
                already stripped by the caller (confirm).

        Returns:
            The verified claims dictionary.

        Raises:
            SupabaseAuthError: If the JWKS URL is not configured, the signing
                key cannot be resolved, or the token is malformed, expired,
                missing ``exp``, or otherwise fails verification.
        """
        client = self._get_client()
        try:
            signing_key = client.get_signing_key_from_jwt(token)
        except jwt.DecodeError as exc:
            # The token header could not even be parsed: this is a bad token,
            # not a JWKS / key-resolution failure.
            raise SupabaseAuthError("Invalid Supabase token") from exc
        except Exception as exc:  # pragma: no cover - library-specific errors
            raise SupabaseAuthError("Unable to resolve signing key for Supabase token") from exc

        # Empty settings are normalised to None. PyJWT then skips the
        # corresponding "aud" / "iss" check.
        audience = settings.SUPABASE_JWT_AUDIENCE or None
        issuer = settings.SUPABASE_JWT_ISSUER or None
        options = {
            # Only verify "aud" when an expected audience is configured. PyJWT
            # rejects a token carrying "aud" if no audience is passed while
            # verification is on.
            "verify_aud": audience is not None,
            # PyJWT validates "exp" only when present; require it so a signed
            # token without an expiry is not accepted indefinitely.
            "require": ["exp"],
        }
        try:
            claims = jwt.decode(
                token,
                signing_key.key,
                algorithms=self._algorithms,
                audience=audience,
                issuer=issuer,
                options=options,
            )
        except jwt.ExpiredSignatureError as exc:  # pragma: no cover - time dependent
            raise SupabaseAuthError("Supabase token has expired") from exc
        except jwt.InvalidTokenError as exc:  # pragma: no cover - upstream detail
            raise SupabaseAuthError("Invalid Supabase token") from exc
        return claims
@lru_cache(maxsize=1)
def get_supabase_verifier() -> SupabaseJWTVerifier:
    """Return the shared :class:`SupabaseJWTVerifier` instance.

    The result is cached so every caller uses one verifier and therefore one
    ``PyJWKClient`` key cache. Without the cache, each call would build a
    fresh verifier whose JWKS client starts empty, re-fetching signing keys
    for every token.

    Call ``get_supabase_verifier.cache_clear()`` to force a new instance
    (e.g. in tests).
    """
    return SupabaseJWTVerifier()