"""Firebase JWT verification middleware.""" from fastapi import Depends, HTTPException, status from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from sqlalchemy import select from src.utils.database import AsyncSessionLocal from src.models.database import User from src.config import get_settings import httpx import jwt as pyjwt import os import time bearer_scheme = HTTPBearer(auto_error=False) JWKS_URL = f"https://www.googleapis.com/service_accounts/v1/jwk/securetoken@system.gserviceaccount.com" _PUBLIC_KEYS_CACHE = {"expires_at": 0.0, "keys": None} def _firebase_project_id() -> str: return os.getenv("FIREBASE_PROJECT_ID", "") or get_settings().firebase_project_id async def _get_firebase_public_keys() -> dict: now = time.time() if _PUBLIC_KEYS_CACHE["keys"] and _PUBLIC_KEYS_CACHE["expires_at"] > now: return _PUBLIC_KEYS_CACHE["keys"] # Try certifi first (production-safe). Fall back to verify=False if SSL fails # (Windows dev machines often have antivirus interfering with cert chain). import certifi try: async with httpx.AsyncClient(verify=certifi.where(), timeout=10) as client: r = await client.get(JWKS_URL) r.raise_for_status() data = r.json() except Exception as e: msg = str(e).lower() if "certificate" in msg or "ssl" in msg: from loguru import logger logger.warning(f"[AUTH] SSL verify failed locally, retrying without verify: {e}") async with httpx.AsyncClient(verify=False, timeout=10) as client: r = await client.get(JWKS_URL) r.raise_for_status() data = r.json() else: raise _PUBLIC_KEYS_CACHE["keys"] = data _PUBLIC_KEYS_CACHE["expires_at"] = now + 3600 return data async def verify_firebase_token_string(token: str) -> dict: """Verify a raw Firebase JWT and return its decoded payload.""" project_id = _firebase_project_id() if not project_id: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Firebase project is not configured", ) try: header = pyjwt.get_unverified_header(token) keys_data = await _get_firebase_public_keys() public_key = None for key in keys_data.get("keys", []): if key.get("kid") == header.get("kid"): public_key = pyjwt.algorithms.RSAAlgorithm.from_jwk(key) break if not public_key: print(f"[AUTH] No matching key for kid={header.get('kid')}; available={[k.get('kid') for k in keys_data.get('keys', [])]}") raise HTTPException(status_code=401, detail="Invalid token key") return pyjwt.decode( token, public_key, algorithms=["RS256"], audience=project_id, ) except pyjwt.ExpiredSignatureError: print("[AUTH] Token expired") raise HTTPException(status_code=401, detail="Token expired") except HTTPException: raise except Exception as e: print(f"[AUTH] Verify failed: {type(e).__name__}: {e}; project_id={project_id}") raise HTTPException(status_code=401, detail=f"Invalid token: {str(e)}") async def verify_firebase_token( credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme), ) -> dict: if not credentials: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated") return await verify_firebase_token_string(credentials.credentials) async def get_current_user(payload: dict = Depends(verify_firebase_token)) -> "User": """Get or create user from Firebase token.""" firebase_uid = payload.get("uid") or payload.get("sub") email = payload.get("email", "") async with AsyncSessionLocal() as session: result = await session.execute(select(User).where(User.firebase_uid == firebase_uid)) user = result.scalar_one_or_none() if not user: user = User( firebase_uid=firebase_uid, email=email, display_name=payload.get("name", email.split("@")[0]), photo_url=payload.get("picture"), is_active=True, ) session.add(user) await session.commit() await session.refresh(user) return user