File size: 5,242 Bytes
3998131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import json
import logging
import os
from datetime import datetime, timedelta
from typing import Dict, Optional

from fastapi import HTTPException, status
from jwt import PyJWKClient, decode, InvalidTokenError, get_unverified_header

from api.core.config import settings


class SupabaseJWT:
    """Handle Supabase JWT token verification and user extraction.

    Signing keys are looked up (by the token's ``kid`` header) in the JWKS
    document served at ``<supabase_url>/auth/v1/.well-known/jwks.json``.
    """

    # A JWKS document publishes public keys, so only asymmetric algorithms are
    # accepted. Allowing a symmetric algorithm (HS256) next to public keys is
    # the classic algorithm-confusion pitfall.
    _ALLOWED_ALGORITHMS = ("RS256", "ES256")

    _logger = logging.getLogger(__name__)

    def __init__(self, supabase_url: str):
        """Store the project URL and derive its JWKS endpoint.

        Args:
            supabase_url: Base URL of the Supabase project; a trailing slash
                is stripped.

        Raises:
            ValueError: If ``supabase_url`` is empty.
        """
        if not supabase_url:
            raise ValueError("SUPABASE_URL is not configured")

        self.supabase_url = supabase_url.rstrip('/')
        self.jwks_url = f"{self.supabase_url}/auth/v1/.well-known/jwks.json"
        # Created on first access through the ``jwk_client`` property.
        self._jwk_client: Optional[PyJWKClient] = None
        self._logger.debug("Initialized SupabaseJWT with URL: %s", self.supabase_url)
        self._logger.debug("JWKS URL: %s", self.jwks_url)

    @property
    def jwk_client(self) -> PyJWKClient:
        """Lazily initialize and cache the JWK client.

        Raises:
            Exception: Whatever ``PyJWKClient`` raises on construction; it is
                logged and re-raised.
        """
        if self._jwk_client is None:
            try:
                self._jwk_client = PyJWKClient(self.jwks_url)
            except Exception:
                self._logger.exception("Failed to initialize PyJWKClient")
                raise
            self._logger.debug("PyJWKClient initialized successfully")
        return self._jwk_client

    def verify_token(self, token: str) -> Dict:
        """
        Verify a Supabase JWT token and return the payload.

        The signature is checked against the JWKS key matching the token's
        ``kid``; only RS256/ES256 are accepted, ``exp`` must be present and
        is enforced, and the audience is not checked.

        Args:
            token: JWT token string

        Returns:
            Decoded token payload

        Raises:
            HTTPException: 401 if the token is malformed, its signing key
                cannot be found, or it is invalid or expired.
        """
        try:
            # Header is read unverified, for diagnostics only. A malformed
            # token raises DecodeError (an InvalidTokenError) here.
            unverified_header = get_unverified_header(token)
            self._logger.debug(
                "Token alg=%s kid=%s",
                unverified_header.get("alg"),
                unverified_header.get("kid"),
            )

            try:
                signing_key = self.jwk_client.get_signing_key_from_jwt(token)
            except Exception as e:
                # The cause (e.g. a network error while fetching the JWKS)
                # describes server internals, so it is logged but not
                # echoed back to the client.
                self._logger.error("Failed to get signing key: %s", e)
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Could not retrieve signing key",
                    headers={"WWW-Authenticate": "Bearer"},
                ) from e
            self._logger.debug("Successfully retrieved signing key")

            payload = decode(
                token,
                signing_key.key,
                algorithms=list(self._ALLOWED_ALGORITHMS),
                # Audience is intentionally not verified. ``exp`` is required
                # so a token without an expiry is never accepted forever.
                options={"verify_aud": False, "require": ["exp"]},
            )
            self._logger.debug("Token verified successfully")
            return payload

        except HTTPException:
            raise
        except InvalidTokenError as e:
            # PyJWT messages describe the client's own token ("Signature has
            # expired", ...), so they are safe to return.
            self._logger.warning("Invalid token: %s", e)
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=f"Invalid authentication credentials: {str(e)}",
                headers={"WWW-Authenticate": "Bearer"},
            ) from e
        except Exception as e:
            self._logger.exception("Token verification failed")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Could not validate credentials",
                headers={"WWW-Authenticate": "Bearer"},
            ) from e

    def extract_user(self, payload: Dict) -> Dict:
        """
        Extract user information from token payload.

        Args:
            payload: Decoded JWT payload

        Returns:
            Dict with ``id`` (the ``sub`` claim), ``email``, ``role``,
            ``phone`` and ``user_metadata``. ``role`` is the top-level
            ``role`` claim, falling back to ``app_metadata.role`` and then
            to ``"user"``. ``user_metadata`` defaults to ``{}``.
        """
        # ``or {}`` also covers claims that are present but null, which
        # would otherwise make ``.get`` fail on None.
        app_metadata = payload.get("app_metadata") or {}
        return {
            "id": payload.get("sub"),
            "email": payload.get("email"),
            "role": payload.get("role") or app_metadata.get("role", "user"),
            "phone": payload.get("phone"),
            "user_metadata": payload.get("user_metadata") or {},
        }


# Initialize the module-level Supabase JWT handler. It stays None when
# SUPABASE_URL is unset or initialization fails; the helpers below then
# respond with HTTP 500 instead of the module failing to import.
supabase_jwt = None
try:
    if settings.supabase_url:
        supabase_jwt = SupabaseJWT(settings.supabase_url)
    else:
        logging.getLogger(__name__).warning("SUPABASE_URL not configured")
except Exception:
    # Deliberately best-effort, but keep the traceback in the logs.
    logging.getLogger(__name__).exception("Failed to initialize Supabase JWT")


def verify_supabase_token(token: str) -> Dict:
    """Verify *token* with the module-level Supabase handler and return its claims.

    Raises:
        HTTPException: 500 when no handler was configured at import time;
            otherwise whatever ``SupabaseJWT.verify_token`` raises.
    """
    handler = supabase_jwt
    if handler:
        return handler.verify_token(token)
    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail="Supabase not configured",
    )


def extract_user_from_token(payload: Dict) -> Dict:
    """Build the user dict for a decoded token payload.

    Raises:
        HTTPException: 500 when no Supabase handler is configured, or 401
            when the extracted user has no ``id`` (missing/empty ``sub``).
    """
    handler = supabase_jwt
    if not handler:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Supabase not configured",
        )

    user_info = handler.extract_user(payload)
    if user_info.get("id"):
        return user_info
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid token payload",
    )