Spaces:
Running
Running
| """Firebase JWT verification middleware.""" | |
| from fastapi import Depends, HTTPException, status | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from sqlalchemy import select | |
| from src.utils.database import AsyncSessionLocal | |
| from src.models.database import User | |
| from src.config import get_settings | |
| import httpx | |
| import jwt as pyjwt | |
| import os | |
| import time | |
# Bearer-credential extractor used as a FastAPI dependency. It is built with
# auto_error=False; verify_firebase_token checks for missing credentials
# itself and raises its own 401.
bearer_scheme = HTTPBearer(auto_error=False)

# JWK set published for the securetoken@system.gserviceaccount.com service
# account; _get_firebase_public_keys downloads it to obtain signing keys.
JWKS_URL = (
    "https://www.googleapis.com/service_accounts/v1/jwk/"
    "securetoken@system.gserviceaccount.com"
)

# In-process cache for the downloaded key set.
#   "expires_at": time.time() value after which the entry is stale
#   "keys":       the parsed JSON document, or None before the first fetch
_PUBLIC_KEYS_CACHE: dict = {
    "expires_at": 0.0,
    "keys": None,
}
def _firebase_project_id() -> str:
    """Return the Firebase project id.

    The ``FIREBASE_PROJECT_ID`` environment variable wins when it is set to a
    non-empty value; otherwise the value from ``get_settings()`` is returned.
    """
    from_env = os.getenv("FIREBASE_PROJECT_ID", "")
    if from_env:
        return from_env
    return get_settings().firebase_project_id
async def _get_firebase_public_keys() -> dict:
    """Return the JWK set used to verify Firebase tokens, cached in-process.

    A cached copy is returned while it is younger than one hour; otherwise
    the document at ``JWKS_URL`` is downloaded, validated and cached.

    TLS verification uses the certifi CA bundle. Because these keys are the
    trust anchor for every token this service accepts, an unverified retry
    is only attempted when BOTH of the following hold:

    * the first attempt failed with an error mentioning "ssl"/"certificate";
    * the ``FIREBASE_JWKS_INSECURE_SSL`` environment variable is set to
      ``1``, ``true`` or ``yes`` (intended for local development machines
      where TLS interception breaks the certificate chain).

    Without the opt-in, a TLS failure propagates to the caller -- silently
    downgrading would let anyone able to tamper with the connection supply
    their own signing keys.

    Returns:
        The parsed JWKS document (a dict with a non-empty ``"keys"`` list).

    Raises:
        ValueError: if the response is not a JSON object containing keys.
        Exception: any error raised by httpx while fetching the document.
    """
    now = time.time()
    cached = _PUBLIC_KEYS_CACHE["keys"]
    if cached and _PUBLIC_KEYS_CACHE["expires_at"] > now:
        return cached

    import certifi

    async def _fetch(verify) -> dict:
        # One GET of the JWKS document with the given TLS verification setting.
        async with httpx.AsyncClient(verify=verify, timeout=10) as client:
            response = await client.get(JWKS_URL)
            response.raise_for_status()
            return response.json()

    try:
        data = await _fetch(certifi.where())
    except Exception as e:
        msg = str(e).lower()
        is_ssl_failure = "certificate" in msg or "ssl" in msg
        insecure_opt_in = os.getenv("FIREBASE_JWKS_INSECURE_SSL", "").strip().lower() in {
            "1",
            "true",
            "yes",
        }
        if not (is_ssl_failure and insecure_opt_in):
            raise
        from loguru import logger

        logger.warning(f"[AUTH] SSL verify failed locally, retrying without verify: {e}")
        data = await _fetch(False)

    # Never cache a document that cannot be used to verify anything.
    if not isinstance(data, dict) or not data.get("keys"):
        raise ValueError("Firebase JWKS response did not contain any keys")

    _PUBLIC_KEYS_CACHE["keys"] = data
    _PUBLIC_KEYS_CACHE["expires_at"] = now + 3600
    return data
async def verify_firebase_token_string(token: str) -> dict:
    """Verify a raw Firebase JWT and return its decoded payload.

    The signing key is selected by the ``kid`` in the token header from the
    key set returned by ``_get_firebase_public_keys``. The token is then
    decoded with RS256 and checked for signature, expiry, audience (the
    project id) and issuer (``https://securetoken.google.com/<project id>``).
    The ``exp``, ``iat``, ``aud``, ``iss`` and ``sub`` claims must be present.

    Args:
        token: The encoded JWT string.

    Returns:
        The decoded claims.

    Raises:
        HTTPException: 500 when no Firebase project id is configured;
            503 when the signing keys cannot be fetched;
            401 when the token is expired, has no matching key, or fails
            any other validation.
    """
    project_id = _firebase_project_id()
    if not project_id:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Firebase project is not configured",
        )

    try:
        # Parse the header first so malformed tokens never trigger a key fetch.
        header = pyjwt.get_unverified_header(token)
        kid = header.get("kid")

        try:
            keys_data = await _get_firebase_public_keys()
        except Exception as e:
            # Failing to reach the key endpoint is a server-side problem, not
            # an invalid token: report it as such instead of a 401.
            print(f"[AUTH] Could not fetch signing keys: {type(e).__name__}: {e}")
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Unable to fetch token signing keys",
            ) from e

        jwks = keys_data.get("keys", [])
        matching_jwk = next((k for k in jwks if k.get("kid") == kid), None)
        if matching_jwk is None:
            print(f"[AUTH] No matching key for kid={kid}; available={[k.get('kid') for k in jwks]}")
            raise HTTPException(status_code=401, detail="Invalid token key")

        public_key = pyjwt.algorithms.RSAAlgorithm.from_jwk(matching_jwk)
        return pyjwt.decode(
            token,
            public_key,
            algorithms=["RS256"],
            audience=project_id,
            issuer=f"https://securetoken.google.com/{project_id}",
            options={"require": ["exp", "iat", "aud", "iss", "sub"]},
        )
    except pyjwt.ExpiredSignatureError:
        print("[AUTH] Token expired")
        raise HTTPException(status_code=401, detail="Token expired")
    except HTTPException:
        raise
    except pyjwt.PyJWTError as e:
        print(f"[AUTH] Verify failed: {type(e).__name__}: {e}; project_id={project_id}")
        raise HTTPException(status_code=401, detail=f"Invalid token: {str(e)}") from e
    except Exception as e:
        # Unexpected failure: log the details, but do not echo internal error
        # text back to the client.
        print(f"[AUTH] Verify failed: {type(e).__name__}: {e}; project_id={project_id}")
        raise HTTPException(status_code=401, detail="Invalid token") from e
async def verify_firebase_token(
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
) -> dict:
    """FastAPI dependency: verify the request's bearer token.

    Returns the decoded payload from ``verify_firebase_token_string``.
    Raises a 401 ``HTTPException`` when no credentials were supplied.
    """
    if credentials:
        decoded = await verify_firebase_token_string(credentials.credentials)
        return decoded

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Not authenticated",
    )
async def get_current_user(payload: dict = Depends(verify_firebase_token)) -> "User":
    """Get or create the user that a verified Firebase token belongs to.

    The user is looked up by ``firebase_uid`` (the ``uid`` claim, falling
    back to ``sub``). If no row exists, one is inserted using the ``email``,
    ``name`` and ``picture`` claims and then refreshed from the database.

    Args:
        payload: Decoded token claims provided by ``verify_firebase_token``.

    Returns:
        The existing or newly created ``User``.

    Raises:
        HTTPException: 401 when the payload carries neither ``uid`` nor
            ``sub`` -- otherwise the lookup/insert would be keyed on ``None``.
    """
    firebase_uid = payload.get("uid") or payload.get("sub")
    if not firebase_uid:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has no subject",
        )

    # Treat a missing claim and an explicitly null/empty claim the same way, so
    # a null email cannot crash ``split`` and a null name still gets a default.
    email = payload.get("email") or ""
    display_name = payload.get("name") or email.split("@")[0]

    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(User).where(User.firebase_uid == firebase_uid)
        )
        user = result.scalar_one_or_none()
        if user is None:
            # NOTE(review): two concurrent first requests for the same uid can
            # both reach this branch; whether the second insert fails depends
            # on a uniqueness constraint not visible here -- confirm.
            user = User(
                firebase_uid=firebase_uid,
                email=email,
                display_name=display_name,
                photo_url=payload.get("picture"),
                is_active=True,
            )
            session.add(user)
            await session.commit()
            await session.refresh(user)
    return user