File size: 2,640 Bytes
783a952
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Authentication utilities for Supabase-backed JWT verification."""
from __future__ import annotations

from functools import lru_cache
from threading import Lock
from typing import Any, Dict, Optional

import jwt
from jwt import PyJWKClient

from ml_module.core.config import settings


class SupabaseAuthError(Exception):
    """Error raised by the Supabase JWT verification helpers.

    Used for a missing JWKS configuration, a signing key that cannot be
    resolved, and tokens that are expired or otherwise invalid.
    """


class SupabaseJWTVerifier:
    """Validate Supabase-issued JWTs with signing keys looked up via JWKS.

    One ``PyJWKClient`` is kept per verifier instance and is replaced only
    when the JWKS URL resolved from ``settings`` differs from the URL the
    current client was built for.
    """

    def __init__(self) -> None:
        """Create a verifier with no JWKS client; it is built on first use."""
        self._lock = Lock()
        self._jwks_url: Optional[str] = None
        self._jwks_client: Optional[PyJWKClient] = None

    def _resolve_jwks_url(self) -> str:
        """Return the JWKS endpoint that signing keys are fetched from.

        ``settings.SUPABASE_JWKS_URL`` wins when it is set; otherwise the
        URL is derived from ``settings.SUPABASE_PROJECT_ID``.

        Raises:
            SupabaseAuthError: If neither setting has a truthy value.
        """
        explicit_url = settings.SUPABASE_JWKS_URL
        if explicit_url:
            return explicit_url

        project_id = settings.SUPABASE_PROJECT_ID
        if not project_id:
            raise SupabaseAuthError("Supabase JWKS URL is not configured")
        return f"https://{project_id}.supabase.co/auth/v1/certs"

    def _get_client(self) -> PyJWKClient:
        """Return the JWKS client matching the currently configured URL.

        The URL is resolved on every call, before the lock is taken. Under
        the lock, a new client is built when none exists yet or when the
        resolved URL differs from the one recorded for the current client.

        Raises:
            SupabaseAuthError: Propagated from ``_resolve_jwks_url``.
        """
        current_url = self._resolve_jwks_url()
        with self._lock:
            reusable = self._jwks_client and current_url == self._jwks_url
            if not reusable:
                # Build the client first so the recorded URL is only updated
                # once construction has succeeded.
                fresh_client = PyJWKClient(current_url, cache_keys=True)  # type: ignore[call-arg]
                self._jwks_client = fresh_client
                self._jwks_url = current_url
        return self._jwks_client  # type: ignore[return-value]

    def decode(self, token: str) -> Dict[str, Any]:
        """Verify ``token`` and return its claims.

        The signing key is looked up through the JWKS client based on the
        token itself, then the token is decoded accepting RS256 only. The
        audience is verified only when ``settings.SUPABASE_JWT_AUDIENCE`` is
        truthy, and an issuer is passed to the decoder only when
        ``settings.SUPABASE_JWT_ISSUER`` is truthy.

        Args:
            token: The encoded JWT.

        Returns:
            The claims mapping produced by ``jwt.decode``.

        Raises:
            SupabaseAuthError: If the JWKS URL is not configured, the signing
                key cannot be resolved, the token has expired, or the token
                is otherwise invalid.
        """
        jwks_client = self._get_client()
        try:
            key_record = jwks_client.get_signing_key_from_jwt(token)
        except Exception as exc:  # pragma: no cover - library-specific errors
            raise SupabaseAuthError("Unable to resolve signing key for Supabase token") from exc

        # Falsy settings (empty string, None) mean "do not enforce".
        expected_audience = settings.SUPABASE_JWT_AUDIENCE or None
        expected_issuer = settings.SUPABASE_JWT_ISSUER or None
        decode_options = {"verify_aud": expected_audience is not None}

        try:
            return jwt.decode(
                token,
                key_record.key,
                algorithms=["RS256"],
                audience=expected_audience,
                issuer=expected_issuer,
                options=decode_options,
            )
        # ExpiredSignatureError must be handled before the broader
        # InvalidTokenError so that expiry keeps its own message.
        except jwt.ExpiredSignatureError as exc:  # pragma: no cover - time dependent
            raise SupabaseAuthError("Supabase token has expired") from exc
        except jwt.InvalidTokenError as exc:  # pragma: no cover - upstream detail
            raise SupabaseAuthError("Invalid Supabase token") from exc


@lru_cache(maxsize=1)
def get_supabase_verifier() -> SupabaseJWTVerifier:
    """Return the shared ``SupabaseJWTVerifier``.

    ``lru_cache`` memoizes this zero-argument call, so the verifier is
    constructed on the first call and that same instance is returned on
    later calls (until the cache is cleared).
    """
    verifier = SupabaseJWTVerifier()
    return verifier