# sirus/backend/core/auth.py
# Committed by ranilmukesh
# Commit a8c9ee8: "Deploy SiRUS SQL Agent backend"
# In /backend/core/auth.py
import asyncio
import base64
import json
import os
import threading
import time
from typing import Dict, Optional, Tuple

from fastapi import Depends, HTTPException, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from supabase import create_client, Client # Updated import for modern supabase library
# Load environment variables from .env for local/dev/test.
# python-dotenv is optional: if it is not installed the ImportError is ignored and
# only the real process environment is used.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass
# Supabase admin client, created on first use by _get_supabase() (lazy to avoid an
# import-time crash on HF Spaces when credentials are not available at import).
_supabase_client: Optional[Client] = None
# Maximum number of token -> user entries _cache_user keeps; overridable via env.
AUTH_CACHE_MAX_ENTRIES = int(os.environ.get("AUTH_CACHE_MAX_ENTRIES", "2048"))
# token string -> (AuthUser, expiry as a Unix timestamp in whole seconds).
_auth_user_cache: Dict[str, Tuple["AuthUser", int]] = {}
# Held by _get_cached_user and _cache_user around every access to _auth_user_cache.
_auth_user_cache_lock = threading.Lock()
def _decode_jwt_payload(token: str) -> dict:
try:
parts = token.split(".")
if len(parts) < 2:
return {}
payload = parts[1]
payload += "=" * (-len(payload) % 4)
decoded = base64.urlsafe_b64decode(payload.encode("utf-8")).decode("utf-8")
data = json.loads(decoded)
return data if isinstance(data, dict) else {}
except Exception:
return {}
def _extract_token_exp(token: str) -> int:
    """Return the token's ``exp`` claim as an integer Unix timestamp.

    The payload is read via _decode_jwt_payload (no signature check); the value is
    used by the auth cache to decide how long an entry stays valid.

    Falls back to ``now + 300`` seconds when ``exp`` is missing, is not a number,
    is a JSON boolean, or is a non-finite float.
    """
    payload = _decode_jwt_payload(token)
    exp = payload.get("exp")
    # bool is a subclass of int, but a literal true/false is not a timestamp.
    if isinstance(exp, (int, float)) and not isinstance(exp, bool):
        try:
            return int(exp)
        except (OverflowError, ValueError):
            # json.loads accepts Infinity/NaN, which int() cannot convert.
            pass
    # If token exp can't be read, keep cache window short.
    return int(time.time()) + 300
def _get_cached_user(token_value: str) -> Optional["AuthUser"]:
    """Look up a previously validated user for *token_value*.

    Returns the cached AuthUser while its stored expiry is still in the future.
    Returns None on a miss; an entry whose expiry has passed is evicted and also
    reported as a miss.
    """
    now = int(time.time())
    with _auth_user_cache_lock:
        entry = _auth_user_cache.get(token_value)
        if entry:
            cached_user, expires_at = entry
            if expires_at > now:
                return cached_user
            # Stale entry: drop it so the caller falls back to a fresh validation.
            _auth_user_cache.pop(token_value, None)
        return None
def _cache_user(token_value: str, user: "AuthUser") -> None:
    """Cache *user* under *token_value* until the token's expiry.

    After inserting, every entry whose expiry has passed is purged. Then the
    earliest-inserted entries are evicted until at most AUTH_CACHE_MAX_ENTRIES
    remain. A non-positive limit therefore leaves the cache empty, which
    effectively disables caching.
    """
    now = int(time.time())
    exp = _extract_token_exp(token_value)
    # Clamp so a negative setting cannot drive the eviction loop past an empty
    # dict, where next(iter(...)) would raise StopIteration.
    limit = max(AUTH_CACHE_MAX_ENTRIES, 0)
    with _auth_user_cache_lock:
        _auth_user_cache[token_value] = (user, exp)
        # Remove expired entries first.
        expired_keys = [
            key
            for key, (_cached_user, cached_exp) in _auth_user_cache.items()
            if cached_exp <= now
        ]
        for key in expired_keys:
            _auth_user_cache.pop(key, None)
        # Keep cache bounded; dicts preserve insertion order, so the first key is
        # the earliest-inserted entry.
        while len(_auth_user_cache) > limit:
            _auth_user_cache.pop(next(iter(_auth_user_cache)), None)
def _get_supabase():
    """Return the shared Supabase admin client, creating it on the first call.

    Credentials come from the SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY
    environment variables. The client is cached in the module-level
    ``_supabase_client``.

    Raises:
        RuntimeError: if either variable is unset or empty.
    """
    global _supabase_client
    if _supabase_client is not None:
        return _supabase_client
    supabase_url = os.environ.get("SUPABASE_URL")
    service_role_key = os.environ.get("SUPABASE_SERVICE_ROLE_KEY")
    if not (supabase_url and service_role_key):
        raise RuntimeError("Missing SUPABASE_URL or SUPABASE_SERVICE_ROLE_KEY")
    _supabase_client = create_client(supabase_url, service_role_key)
    return _supabase_client
# Shared HTTP bearer-token security scheme, injected into get_current_user via
# Depends(auth_scheme). NOTE(review): with FastAPI's default auto_error=True a
# missing/malformed Authorization header is presumably rejected before
# get_current_user runs — confirm against the FastAPI version in use.
auth_scheme = HTTPBearer()
class AuthUser:
    """Authenticated-user view built from a Supabase user record.

    Attributes:
        id: the user's id (``user_data['id']``), or None if absent.
        claims: the user's ``user_metadata`` dict; ``{}`` when missing or null.
        tenant_id: ``claims['tenant_id']``, or None.
        role: ``claims['role']``, or None.
    """

    def __init__(self, user_data: dict):
        self.id = user_data.get('id')
        # ``or {}`` also covers an explicit None, which would otherwise make the
        # .get() calls below raise AttributeError.
        self.claims = user_data.get('user_metadata') or {}
        self.tenant_id = self.claims.get('tenant_id')
        self.role = self.claims.get('role')

    def __repr__(self) -> str:
        # Deliberately omits the full claims dict to keep arbitrary metadata out of logs.
        return (
            f"{type(self).__name__}(id={self.id!r}, "
            f"tenant_id={self.tenant_id!r}, role={self.role!r})"
        )
# Lower-cased substrings that mark a Supabase failure as a network/availability
# problem (reported as 503) rather than a rejected token (reported as 401).
_TRANSIENT_AUTH_ERROR_MARKERS = (
    "timed out",
    "timeout",
    "handshake",
    "ssl",
    "connection",
    "network",
    "temporarily unavailable",
)


def _is_transient_auth_error(exc: Exception) -> bool:
    """Heuristically decide whether *exc* looks like a retryable network failure."""
    if isinstance(exc, (TimeoutError, ConnectionError)):
        return True
    error_text = str(exc).lower()
    return any(marker in error_text for marker in _TRANSIENT_AUTH_ERROR_MARKERS)


def _user_data_from_supabase_user(user_obj) -> dict:
    """Flatten a Supabase user object into the dict shape AuthUser expects."""
    try:
        user_data = vars(user_obj)
    except TypeError:
        # vars() raises TypeError for objects without a __dict__.
        user_data = {}
    # If vars() gave nothing usable, read the needed attributes directly.
    if not user_data:
        user_data = {
            'id': getattr(user_obj, 'id', None),
            'user_metadata': getattr(user_obj, 'user_metadata', {}),
            'email': getattr(user_obj, 'email', None),
        }
    return user_data


async def get_current_user(
    token: HTTPAuthorizationCredentials = Depends(auth_scheme)
) -> AuthUser:
    """
    FastAPI dependency to validate Supabase JWT and return user info.
    This will be used by ALL user-facing API endpoints.

    Validated users are cached per token (_get_cached_user / _cache_user), so
    Supabase is only contacted on a cache miss.

    Raises:
        HTTPException(401): Supabase rejected the token, returned no user, or the
            user lacks id / tenant_id / role.
        HTTPException(503): the Supabase call failed with what looks like a
            network/availability error; the client may retry.
        RuntimeError: Supabase credentials are not configured. This is a
            server-side error and is deliberately not reported as a bad token.
    """
    token_value = token.credentials
    cached_user = _get_cached_user(token_value)
    if cached_user is not None:
        return cached_user

    # Resolved outside the try so misconfiguration surfaces as a server error
    # instead of being masked as "Invalid token".
    client = _get_supabase()
    try:
        # get_user() performs a blocking HTTP request; run it in the default
        # executor so it does not stall the event loop for concurrent requests.
        loop = asyncio.get_running_loop()
        user_data_response = await loop.run_in_executor(
            None, lambda: client.auth.get_user(jwt=token_value)
        )
        user_obj = getattr(user_data_response, "user", None)
        if user_obj is None:
            raise HTTPException(status_code=401, detail="Invalid token")
        user = AuthUser(_user_data_from_supabase_user(user_obj))
        # The custom claims hook ensures these fields exist
        if not user.id or not user.tenant_id or not user.role:
            raise HTTPException(status_code=401, detail="Invalid token claims")
        _cache_user(token_value, user)
        return user
    except HTTPException:
        raise
    except Exception as exc:
        if _is_transient_auth_error(exc):
            raise HTTPException(
                status_code=503,
                detail="Authentication service temporarily unavailable. Please retry.",
            ) from exc
        raise HTTPException(status_code=401, detail="Invalid token") from exc
async def get_tenant_admin(
    user: AuthUser = Depends(get_current_user)
) -> AuthUser:
    """
    Authorize the caller as a tenant administrator.

    Chains on get_current_user, so the bearer token is validated first. Then the
    resolved user's role must be exactly 'TENANT_ADMIN'. Intended for the
    self-service API endpoints.

    Raises:
        HTTPException(403): the authenticated user is not a TENANT_ADMIN.
    """
    is_tenant_admin = user.role == 'TENANT_ADMIN'
    if is_tenant_admin:
        return user
    raise HTTPException(status_code=403, detail="Forbidden: Admin access required")