# sirus/backend/ml_module/core/auth.py
# Deploy SiRUS SQL Agent backend (783a952, ranilmukesh)
"""Authentication utilities for Supabase-backed JWT verification."""
from __future__ import annotations
from functools import lru_cache
from threading import Lock
from typing import Any, Dict, Optional
import jwt
from jwt import PyJWKClient
from ml_module.core.config import settings
class SupabaseAuthError(Exception):
    """Error signalling that a Supabase JWT could not be verified.

    Raised for a missing JWKS configuration, a signing key that cannot be
    resolved, and tokens that are expired or otherwise invalid.
    """
class SupabaseJWTVerifier:
    """Verify Supabase-issued JWTs against the project's JWKS endpoint.

    A single :class:`PyJWKClient` (with key caching enabled) is built lazily
    and reused until the configured JWKS URL changes.
    """

    # Asymmetric JWS algorithms accepted for Supabase signing keys. The
    # algorithm used for verification is always the one bound to the resolved
    # JWK, so a token cannot select an algorithm of its own choosing.
    _ALLOWED_ALGORITHMS = ("RS256", "ES256")

    def __init__(self) -> None:
        # Lazily-built JWKS client and the URL it was built for.
        self._jwks_client: Optional[PyJWKClient] = None
        self._jwks_url: Optional[str] = None
        # Guards creation/replacement of the shared JWKS client.
        self._lock = Lock()

    def _resolve_jwks_url(self) -> str:
        """Return the JWKS URL to fetch signing keys from.

        An explicit ``SUPABASE_JWKS_URL`` takes precedence; otherwise the URL
        is derived from ``SUPABASE_PROJECT_ID``.

        Raises:
            SupabaseAuthError: If neither setting is configured.
        """
        if settings.SUPABASE_JWKS_URL:
            return settings.SUPABASE_JWKS_URL
        if settings.SUPABASE_PROJECT_ID:
            # NOTE(review): Supabase documents the JWKS endpoint as
            # ``/auth/v1/.well-known/jwks.json``; confirm ``/auth/v1/certs``
            # is still served, or set SUPABASE_JWKS_URL explicitly.
            return f"https://{settings.SUPABASE_PROJECT_ID}.supabase.co/auth/v1/certs"
        raise SupabaseAuthError("Supabase JWKS URL is not configured")

    def _get_client(self) -> PyJWKClient:
        """Return the cached JWKS client, rebuilding it if the URL changed.

        Raises:
            SupabaseAuthError: If no JWKS URL is configured.
        """
        jwks_url = self._resolve_jwks_url()
        with self._lock:
            client = self._jwks_client
            if client is None or jwks_url != self._jwks_url:
                client = PyJWKClient(  # type: ignore[call-arg]
                    jwks_url,
                    cache_keys=True,
                )
                self._jwks_client = client
                self._jwks_url = jwks_url
            return client

    @staticmethod
    def _key_algorithm(signing_key: Any) -> str:
        """Return the JWS algorithm bound to *signing_key*.

        Uses ``PyJWK.algorithm_name`` (the JWK's ``alg``, or one inferred from
        its key type) when the installed PyJWT provides it, and falls back to
        ``"RS256"`` otherwise.
        """
        return getattr(signing_key, "algorithm_name", None) or "RS256"

    def decode(self, token: str) -> Dict[str, Any]:
        """Verify *token* and return its claims.

        The signing key is looked up in the JWKS from the token's header and
        the signature is checked with that key's own algorithm. ``exp`` must
        be present and unexpired. ``aud`` and ``iss`` are enforced only when
        ``SUPABASE_JWT_AUDIENCE`` / ``SUPABASE_JWT_ISSUER`` are set.

        Raises:
            SupabaseAuthError: If the JWKS URL is not configured, the signing
                key cannot be resolved, the key uses an unsupported algorithm,
                or the token is expired or otherwise invalid.
        """
        client = self._get_client()
        try:
            signing_key = client.get_signing_key_from_jwt(token)
        except Exception as exc:  # pragma: no cover - library-specific errors
            # Broad on purpose: any failure here (e.g. malformed token, unknown
            # key id, JWKS fetch error) is reported as an auth failure.
            raise SupabaseAuthError(
                "Unable to resolve signing key for Supabase token"
            ) from exc

        # Pin verification to the key's algorithm instead of a fixed RS256:
        # Supabase asymmetric keys may be RSA (RS256) or EC P-256 (ES256).
        # A header ``alg`` that disagrees with the key is rejected by PyJWT.
        algorithm = self._key_algorithm(signing_key)
        if algorithm not in self._ALLOWED_ALGORITHMS:
            raise SupabaseAuthError(
                f"Unsupported signing algorithm for Supabase token: {algorithm}"
            )

        audience = settings.SUPABASE_JWT_AUDIENCE or None
        issuer = settings.SUPABASE_JWT_ISSUER or None
        options = {
            # PyJWT rejects a token carrying ``aud`` when no audience is given,
            # so audience checking is enabled only when one is configured.
            "verify_aud": audience is not None,
            # Without this, a token that omits ``exp`` would never expire.
            "require": ["exp"],
        }
        try:
            claims = jwt.decode(
                token,
                signing_key.key,
                algorithms=[algorithm],
                audience=audience,
                issuer=issuer,
                options=options,
            )
        except jwt.ExpiredSignatureError as exc:  # pragma: no cover - time dependent
            raise SupabaseAuthError("Supabase token has expired") from exc
        except jwt.InvalidTokenError as exc:  # pragma: no cover - upstream detail
            raise SupabaseAuthError("Invalid Supabase token") from exc
        return claims
@lru_cache(maxsize=1)
def get_supabase_verifier() -> SupabaseJWTVerifier:
    """Return the shared :class:`SupabaseJWTVerifier` instance.

    The function takes no arguments and is wrapped in ``lru_cache(maxsize=1)``,
    so the verifier is constructed on the first call and the same object
    (together with the JWKS client it caches) is returned on every later call.
    """
    verifier = SupabaseJWTVerifier()
    return verifier