Viraj0112 commited on
Commit
b683ba5
·
verified ·
1 Parent(s): 1b49b1c

Update backend/utils/auth.py

Browse files
Files changed (1) hide show
  1. backend/utils/auth.py +65 -9
backend/utils/auth.py CHANGED
@@ -5,10 +5,12 @@ Authentication utilities for Supabase JWT verification.
5
  import os
6
  import jwt
7
  import logging
 
8
  from typing import Optional, Dict
9
  from fastapi import HTTPException, Security, Header
10
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
11
  from dotenv import load_dotenv
 
12
 
13
  load_dotenv()
14
 
@@ -22,29 +24,83 @@ SUPABASE_ANON_KEY = os.getenv("SUPABASE_ANON_KEY", "")
22
  # Security scheme
23
  security = HTTPBearer(auto_error=False)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def verify_supabase_token(token: str) -> Optional[Dict]:
27
  """
28
  Verify a Supabase JWT token and return the decoded payload.
29
 
 
 
30
  Args:
31
  token: JWT token string
32
 
33
  Returns:
34
  Decoded token payload if valid, None otherwise
35
  """
36
- if not SUPABASE_JWT_SECRET:
37
- logger.warning("⚠️ SUPABASE_JWT_SECRET not set. Auth verification disabled.")
38
  return None
39
 
40
  try:
41
- # Decode and verify the JWT token
42
- decoded = jwt.decode(
43
- token,
44
- SUPABASE_JWT_SECRET,
45
- algorithms=["ES256", "RS256", "HS256"],
46
- audience="authenticated"
47
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  logger.debug(f"✅ Token verified for user: {decoded.get('sub')}")
50
  return decoded
 
5
  import os
6
  import jwt
7
  import logging
8
+ import requests
9
  from typing import Optional, Dict
10
  from fastapi import HTTPException, Security, Header
11
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
12
  from dotenv import load_dotenv
13
+ from jwt import PyJWKClient
14
 
15
  load_dotenv()
16
 
 
24
  # Security scheme
25
  security = HTTPBearer(auto_error=False)
26
 
27
+ # JWKS client for ES256 verification (cached)
28
+ _jwks_client: Optional[PyJWKClient] = None
29
+
30
+
31
+ def get_jwks_client() -> Optional[PyJWKClient]:
32
+ """
33
+ Get or create a cached JWKS client for the Supabase project.
34
+ """
35
+ global _jwks_client
36
+
37
+ if _jwks_client is not None:
38
+ return _jwks_client
39
+
40
+ if not SUPABASE_URL:
41
+ logger.warning("⚠️ SUPABASE_URL not set. Cannot create JWKS client.")
42
+ return None
43
+
44
+ try:
45
+ # Supabase JWKS endpoint
46
+ jwks_url = f"{SUPABASE_URL}/auth/v1/.well-known/jwks.json"
47
+ _jwks_client = PyJWKClient(jwks_url, cache_keys=True)
48
+ logger.info(f"✅ JWKS client initialized for {jwks_url}")
49
+ return _jwks_client
50
+ except Exception as e:
51
+ logger.error(f"❌ Failed to create JWKS client: {e}")
52
+ return None
53
+
54
 
55
  def verify_supabase_token(token: str) -> Optional[Dict]:
56
  """
57
  Verify a Supabase JWT token and return the decoded payload.
58
 
59
+ Uses JWKS for ES256/RS256 verification, falls back to secret for HS256.
60
+
61
  Args:
62
  token: JWT token string
63
 
64
  Returns:
65
  Decoded token payload if valid, None otherwise
66
  """
67
+ if not SUPABASE_URL and not SUPABASE_JWT_SECRET:
68
+ logger.warning("⚠️ No Supabase auth configuration. Auth verification disabled.")
69
  return None
70
 
71
  try:
72
+ # First, peek at the token header to determine the algorithm
73
+ unverified_header = jwt.get_unverified_header(token)
74
+ algorithm = unverified_header.get("alg", "HS256")
75
+
76
+ if algorithm in ["ES256", "RS256"]:
77
+ # Use JWKS for asymmetric algorithms
78
+ jwks_client = get_jwks_client()
79
+ if not jwks_client:
80
+ logger.error("❌ JWKS client not available for asymmetric verification")
81
+ return None
82
+
83
+ # Get the signing key from JWKS
84
+ signing_key = jwks_client.get_signing_key_from_jwt(token)
85
+
86
+ decoded = jwt.decode(
87
+ token,
88
+ signing_key.key,
89
+ algorithms=[algorithm],
90
+ audience="authenticated"
91
+ )
92
+ else:
93
+ # Use secret for symmetric algorithms (HS256)
94
+ if not SUPABASE_JWT_SECRET:
95
+ logger.warning("⚠️ SUPABASE_JWT_SECRET not set for HS256 verification.")
96
+ return None
97
+
98
+ decoded = jwt.decode(
99
+ token,
100
+ SUPABASE_JWT_SECRET,
101
+ algorithms=["HS256"],
102
+ audience="authenticated"
103
+ )
104
 
105
  logger.debug(f"✅ Token verified for user: {decoded.get('sub')}")
106
  return decoded