Spaces:
Running
Running
| # In /backend/core/auth.py | |
| import base64 | |
| import json | |
| import os | |
| import threading | |
| import time | |
| from typing import Dict, Optional, Tuple | |
| from fastapi import Depends, HTTPException, Request | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from supabase import create_client, Client # Updated import for modern supabase library | |
| # Load environment variables from .env for local/dev/test | |
| try: | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| except ImportError: | |
| pass | |
# Initialize Supabase Admin Client (lazy to avoid import-time crash on HF Spaces).
# Stays None until _get_supabase() creates the client on first use.
_supabase_client = None
# Upper bound on the number of cached tokens, read once at import from the
# AUTH_CACHE_MAX_ENTRIES environment variable (default 2048). _cache_user()
# evicts entries once the cache grows past it.
AUTH_CACHE_MAX_ENTRIES = int(os.environ.get("AUTH_CACHE_MAX_ENTRIES", "2048"))
# Maps a raw bearer token to (authenticated user, expiry as Unix seconds).
_auth_user_cache: Dict[str, Tuple["AuthUser", int]] = {}
# Held by _get_cached_user() and _cache_user() while they touch the cache.
_auth_user_cache_lock = threading.Lock()
def _decode_jwt_payload(token: str) -> dict:
    """Return the claims object from a JWT's payload segment, unverified.

    The second dot-separated segment is base64url-decoded (after restoring
    any missing ``=`` padding) and parsed as JSON. No signature check is
    performed here.

    Returns:
        The decoded claims when they form a JSON object; an empty dict when
        the token has fewer than two segments, cannot be decoded, or decodes
        to something other than an object.
    """
    try:
        segments = token.split(".")
        if len(segments) < 2:
            return {}
        body = segments[1]
        # Pad the segment up to a multiple of four characters before decoding.
        padding = "=" * (-len(body) % 4)
        raw = base64.urlsafe_b64decode((body + padding).encode("utf-8"))
        claims = json.loads(raw.decode("utf-8"))
    except Exception:
        # Best-effort decode: any malformed input simply yields "no claims".
        return {}
    if isinstance(claims, dict):
        return claims
    return {}
def _extract_token_exp(token: str) -> int:
    """Return the token's ``exp`` claim as whole Unix seconds.

    Falls back to five minutes from now when the claim is missing, is not a
    number, or is a number that cannot be converted to an integer, so that
    tokens with an unreadable expiry are cached only briefly.
    """
    exp = _decode_jwt_payload(token).get("exp")
    if isinstance(exp, (int, float)):
        try:
            return int(exp)
        except (OverflowError, ValueError):
            # json.loads accepts Infinity and NaN, which int() cannot
            # convert; treat them like an unreadable expiry.
            pass
    # If token exp can't be read, keep cache window short.
    return int(time.time()) + 300
def _get_cached_user(token_value: str) -> Optional["AuthUser"]:
    """Look up a previously authenticated user by raw token.

    Returns the cached user while its recorded expiry is still in the
    future. An entry whose expiry has passed is removed and reported as a
    miss (None), as is a token that was never cached.
    """
    current_time = int(time.time())
    with _auth_user_cache_lock:
        entry = _auth_user_cache.get(token_value)
        if entry:
            cached_user, expires_at = entry
            if expires_at > current_time:
                return cached_user
            # Stale entry: drop it so it is not consulted again.
            _auth_user_cache.pop(token_value, None)
        return None
def _cache_user(token_value: str, user: "AuthUser") -> None:
    """Store *user* under its token, then prune the cache.

    The entry's expiry comes from _extract_token_exp(). After the insert,
    every entry whose expiry is not in the future is dropped, and then
    entries are removed in dict iteration order until at most
    AUTH_CACHE_MAX_ENTRIES remain.
    """
    now = int(time.time())
    exp = _extract_token_exp(token_value)
    with _auth_user_cache_lock:
        _auth_user_cache[token_value] = (user, exp)
        # Remove expired entries first. Keys are collected up front because
        # a dict must not change size while it is being iterated.
        expired_keys = [
            key for key, (_, cached_exp) in _auth_user_cache.items() if cached_exp <= now
        ]
        for key in expired_keys:
            del _auth_user_cache[key]
        # Keep cache bounded. The emptiness check ends the loop when the
        # limit is configured below zero; without it next() would raise
        # StopIteration on the empty dict.
        while _auth_user_cache and len(_auth_user_cache) > AUTH_CACHE_MAX_ENTRIES:
            first_key = next(iter(_auth_user_cache))
            del _auth_user_cache[first_key]
def _get_supabase():
    """Return the shared Supabase client, creating it on the first call.

    Raises:
        RuntimeError: if SUPABASE_URL or SUPABASE_SERVICE_ROLE_KEY is unset
            or empty at the moment the client has to be created.
    """
    global _supabase_client
    if _supabase_client is not None:
        return _supabase_client
    supabase_url = os.environ.get("SUPABASE_URL")
    service_key = os.environ.get("SUPABASE_SERVICE_ROLE_KEY")
    if not (supabase_url and service_key):
        raise RuntimeError("Missing SUPABASE_URL or SUPABASE_SERVICE_ROLE_KEY")
    _supabase_client = create_client(supabase_url, service_key)
    return _supabase_client
# Bearer-token security scheme; get_current_user() depends on it to obtain the
# credentials sent with the request.
auth_scheme = HTTPBearer()
class AuthUser:
    """Pydantic-like model holding the authenticated user's identity.

    Built from a dict carrying the user's ``id`` and ``user_metadata``.

    Attributes (all set in ``__init__``):
        id: ``user_data['id']``, or None when absent.
        claims: the ``user_metadata`` mapping, or an empty dict when it is
            missing or None.
        tenant_id: ``claims['tenant_id']``, or None when absent.
        role: ``claims['role']``, or None when absent.
    """

    def __init__(self, user_data: dict):
        self.id = user_data.get('id')
        # NOTE(review): Supabase documents user_metadata as editable by the
        # signed-in user. Confirm tenant_id/role cannot be self-assigned, or
        # read them from app_metadata instead.
        # `or {}` also covers a key that is present but None, which the
        # default argument of .get() does not.
        self.claims = user_data.get('user_metadata') or {}
        self.tenant_id = self.claims.get('tenant_id')
        self.role = self.claims.get('role')
async def get_current_user(
    token: HTTPAuthorizationCredentials = Depends(auth_scheme)
) -> AuthUser:
    """
    FastAPI dependency to validate Supabase JWT and return user info.
    This will be used by ALL user-facing API endpoints.

    A user already cached for this exact token is returned without
    contacting Supabase; a freshly validated user is cached before return.

    Raises:
        HTTPException: 401 when Supabase rejects the token or the user lacks
            an id, tenant_id or role; 503 when the failure message looks
            like a network problem.
    """
    token_value = token.credentials
    cached_user = _get_cached_user(token_value)
    if cached_user is not None:
        return cached_user
    try:
        # The 'get_user' function validates the JWT (token.credentials)
        # NOTE(review): this call is not awaited, so it appears to run
        # synchronously inside an async dependency - confirm it cannot stall
        # the event loop while Supabase responds.
        user_data_response = _get_supabase().auth.get_user(jwt=token_value)
        user_obj = user_data_response.user
        # Try different ways to access user data
        try:
            user_data = vars(user_obj)
        except TypeError:
            # vars() raises TypeError for objects that have no __dict__.
            user_data = {}
        # If vars() doesn't work, try direct attributes
        if not user_data:
            user_data = {
                'id': getattr(user_obj, 'id', None),
                'user_metadata': getattr(user_obj, 'user_metadata', {}),
                'email': getattr(user_obj, 'email', None),
            }
        user = AuthUser(user_data)
        # The custom claims hook ensures these fields exist
        if not user.id or not user.tenant_id or not user.role:
            raise HTTPException(status_code=401, detail="Invalid token claims")
        _cache_user(token_value, user)
        return user
    except HTTPException:
        # Already the response we want; do not reclassify it below.
        raise
    except Exception as exc:
        # Classify by message text: substrings that suggest the auth service
        # could not be reached map to 503, anything else to 401.
        error_text = str(exc).lower()
        transient_markers = (
            "timed out",
            "timeout",
            "handshake",
            "ssl",
            "connection",
            "network",
            "temporarily unavailable",
        )
        if any(marker in error_text for marker in transient_markers):
            raise HTTPException(
                status_code=503,
                detail="Authentication service temporarily unavailable. Please retry.",
            ) from exc
        raise HTTPException(status_code=401, detail="Invalid token") from exc
async def get_tenant_admin(
    user: AuthUser = Depends(get_current_user)
) -> AuthUser:
    """
    Dependency that resolves the current user and additionally requires the
    TENANT_ADMIN role. Intended for the self-service API endpoints.

    Raises:
        HTTPException: 403 when the user's role is not 'TENANT_ADMIN'.
    """
    is_admin = user.role == 'TENANT_ADMIN'
    if is_admin:
        return user
    raise HTTPException(status_code=403, detail="Forbidden: Admin access required")